JEP 18137:JAX NumPy 與 SciPy Wrappers 的範疇#

Jake VanderPlas

2023 年 10 月

到目前為止,jax.numpyjax.scipy 的預期範疇相對不明確。本文檔提出了這些套件的明確範疇,以更好地引導和評估未來的貢獻,並促使移除一些超出範疇的程式碼。

背景#

從一開始,JAX 的目標就是提供類似 NumPy 的 API,以便在 XLA 中執行程式碼,而專案開發的一大部分是建構 jax.numpyjax.scipy 命名空間,作為 NumPy 和 SciPy API 的 JAX 基礎實作。一直以來都有一個隱含的理解,即 numpyscipy 的某些部分超出 JAX 的範疇,但此範疇尚未明確定義。這可能會導致貢獻者感到困惑和沮喪,因為對於潛在的 jax.numpyjax.scipy 貢獻是否會被 JAX 接受,沒有明確的答案。

為何限制範圍?#

為了避免讓此事懸而未決,我們應該明確說明:事實上,任何包含在像 JAX 這樣的專案中的程式碼,都會為開發人員帶來微小但非零的持續維護負擔。專案隨著時間推移的成功,直接關係到維護人員為專案所有部分繼續進行此維護的能力:記錄功能、回覆問題、修正錯誤等等。為了任何軟體工具的長期成功和永續性,維護人員仔細權衡任何特定貢獻是否會對專案的目標和資源產生淨正面影響,至關重要。

評估標準#

本文檔提出了六個軸線的標準,可以據此判斷任何特定的 numpyscipy API 是否應納入 JAX。在所有軸線上都很強的 API 是納入 JAX 套件的絕佳候選者;在任何六個軸線上存在明顯弱點,都是反對納入 JAX 的有力論據。

軸線 1:XLA 對齊#

我們考量的第一個軸線是擬議的 API 與原生 XLA 運算的對齊程度。例如,jax.numpy.exp() 是一個或多或少直接反映 jax.lax.exp 的函數。numpyscipy.specialnumpy.linalgscipy.linalg 和其他模組中的大量函數符合此標準:在考量將這些函數納入 JAX 時,它們通過了 XLA 對齊檢查。

另一方面,有些函數如 numpy.unique(),它們並不直接對應於任何 XLA 運算,而且在某些情況下,與 JAX 目前的計算模型(需要靜態形狀陣列,例如 unique 返回一個值相關的動態陣列形狀)從根本上不相容。在考量將這些函數納入 JAX 時,它們未通過 XLA 對齊檢查。

我們也將純函數語意的需求視為此軸線的一部分。例如,numpy.random 建構於隱式更新的基於狀態的 RNG 之上,這與基於 XLA 建構的 JAX 計算模型從根本上不相容。

軸線 2:Array API 對齊#

我們考量的第二個軸線著重於 Python Array API 標準:在某種意義上,這是一個社群驅動的概要,說明了哪些陣列運算對於廣泛的使用者社群中的陣列導向程式設計至關重要。如果 numpyscipy 中的 API 列在 Array API 標準中,則強烈表示 JAX 應包含它。以上述範例為例,Array API 標準包含 numpy.unique() 的幾個變體(unique_allunique_countsunique_inverseunique_values),這表示儘管該函數與 XLA 並非完全對齊,但它對於 Python 使用者社群來說足夠重要,JAX 或許應該實作它。

軸線 3:下游實作的存在#

對於不符合軸線 1 或 2 的功能,納入 JAX 的重要考量是是否存在良好支援的下游套件,可以提供所討論的功能。一個很好的例子是 scipy.optimize:雖然 JAX 確實包含 scipy.optimize 功能的最小包裝函式集,但在 JAXopt 套件中存在更完整的處理方式,該套件由 JAX 協作者積極維護。在這種情況下,我們應該傾向於引導使用者和貢獻者使用這些專門的套件,而不是在 JAX 本身中重新實作這些 API。

軸線 4:實作的複雜性與穩健性#

對於不符合 XLA 的功能,一個考量因素是擬議實作的複雜程度。這在某種程度上與軸線 1 一致,但仍然需要明確指出。許多函數已貢獻給 JAX,它們的實作相對複雜,難以驗證,並帶來過大的維護負擔;一個例子是 jax.scipy.special.bessel_jn():在撰寫本 JEP 時,其目前的實作是一個非直接的迭代近似,在某些領域存在 收斂問題,而 擬議的修正 引入了進一步的複雜性。如果我們在接受貢獻時更仔細地權衡實作的複雜性和穩健性,我們可能會選擇不接受此貢獻納入套件。

軸線 5:函數式與物件導向 API#

JAX 最適合函數式 API,而不是物件導向 API。物件導向 API 通常會隱藏不純粹的語意,使其通常難以良好地實作。NumPy 和 SciPy 通常堅持使用函數式 API,但有時會提供物件導向的便利包裝函式。

這方面的例子是 numpy.polynomial.Polynomial,它包裝了較低層級的運算,如 numpy.polyadd()numpy.polydiv() 等。一般來說,當同時提供函數式和物件導向 API 時,JAX 應避免為物件導向 API 提供包裝函式,而應為函數式 API 提供包裝函式。

在僅存在物件導向 API 的情況下,除非在其他軸線上表現強勁,否則 JAX 應避免提供包裝函式。

軸線 6:對 JAX 使用者與利害關係人的整體「重要性」#

在 JAX 中包含 NumPy/SciPy API 的決策也應考慮到演算法對一般使用者社群的重要性。誠然,很難量化誰是「利害關係人」以及應如何衡量這種重要性;但我們包含這一點是為了明確說明,關於在 JAX 的 NumPy 和 SciPy 包裝函式中包含哪些內容的任何決策,都將涉及一些難以量化的自由裁量權。

對於現有的 API,在 github 中搜尋使用情況可能對於確立重要性或缺乏重要性很有用;例如,我們可以回到上面討論的 jax.scipy.special.bessel_jn():搜尋顯示此函數在 github 上只有 少數使用案例,這可能部分與先前提到準確性問題有關。

評估:範圍內項目?#

在本節中,我們將嘗試根據上述標準評估 NumPy 和 SciPy API,包括來自目前 JAX API 的一些範例。這不會是所有現有函數和類別的完整清單,而是一個更通用的按子模組和主題進行的討論,並附帶相關範例。

NumPy API#

numpy 命名空間#

我們認為主要 numpy 命名空間中的函數基本上都在 JAX 的範圍內,因為它與 XLA (軸線 1) 和 Python Array API (軸線 2) 的總體對齊,以及它對 JAX 使用者社群的總體重要性 (軸線 6)。有些函數可能處於邊緣 (像 numpy.intersect1d()np.setdiff1d()np.union1d() 這樣的函數,可以說在某些方面未能通過標準),但為了簡化,我們宣告主要 numpy 命名空間中的所有陣列函數都在 JAX 的範圍內。

numpy.linalg & numpy.fft#

numpy.linalgnumpy.fft 子模組包含許多與 XLA 提供的功能廣泛對齊的函數。其他函數具有複雜的裝置特定降低,但代表利害關係人的重要性 (軸線 6) 超過複雜性的情況。基於這個原因,我們認為這兩個子模組都在 JAX 的範圍內。

numpy.random#

numpy.random 超出 JAX 的範圍,因為基於狀態的 RNG 與 JAX 的計算模型從根本上不相容。我們轉而專注於 jax.random,它使用基於計數器的 PRNG 提供類似的功能。

numpy.ma & numpy.polynomial#

numpy.manumpy.polynomial 子模組主要關注於為可以通過其他函數式方式表達的計算提供物件導向介面 (軸線 5);基於這個原因,我們認為它們超出 JAX 的範圍。

numpy.testing#

NumPy 的測試功能僅對主機端計算有意義,因此我們在 JAX 中不包含任何用於它的包裝函式。也就是說,JAX 陣列與 numpy.testing 相容,並且 JAX 在整個 JAX 測試套件中頻繁使用它。

SciPy API#

SciPy 在頂層命名空間中沒有函數,但包含許多子模組。我們在下面分別考量每個子模組,省略已棄用的模組。

scipy.cluster#

scipy.cluster 模組包含用於階層式叢集、k-means 和相關演算法的工具。這些在多個軸線上都很弱,最好由下游套件提供服務。JAX 中已經存在一個函數 (jax.scipy.cluster.vq.vq()),但在 github 上 沒有明顯的使用案例:這表示叢集對於 JAX 使用者來說並不廣泛重要。

建議:棄用並移除 jax.scipy.cluster.vq()

scipy.constants#

scipy.datasets#

scipy.fft#

scipy.integrate#

JAX 目前確實包含

基於軸 1、2、4 和 6,

建議:移除

scipy.interpolate#

JAX 目前確實有

展望未來,我們應將

scipy.io#

scipy.linalg#

scipy.ndimage#

scipy.odr#

scipy.optimize#

由於這些受到良好支援的外部套件,我們現在認為

建議:棄用

🟡 scipy.signal#

🟡 scipy.sparse#

另一方面,

建議:探索將稀疏求解器移至 Lineax,否則將

scipy.spatial#

建議:考慮棄用和移除

scipy.special#

其他函數需要更複雜的實作;上面提到的一個範例是

有一些現有的函數包裝器我們應該仔細研究一下;例如

建議:重構並提高

scipy.stats#

我們目前沒有任何假設檢定函數的包裝器,可能是因為這些函數對於 JAX 的主要使用者群來說不太有用。

關於分佈,在某些情況下,

建議:展望未來,我們應將統計分佈和摘要統計視為範圍內,並將假設檢定和相關功能通常視為範圍外。