Jax 與 Jaxlib 版本控制#
為何 jax
和 jaxlib
是個別的套件?#
我們將 JAX 發佈為兩個個別的 Python wheel,分別是 jax
(純 Python wheel)和 jaxlib
(主要為 C++ wheel),其中包含以下程式庫:
XLA,
XLA 使用的 LLVM 片段,
MLIR 基礎架構,例如 StableHLO Python 綁定。
JAX 特定的 C++ 程式庫,用於快速 JIT 和 PyTree 操作。
我們發佈個別的 jax
和 jaxlib
套件,是因為這樣可以輕鬆處理 JAX 的 Python 部分,而無需建置 C++ 程式碼,甚至無需安裝 C++ 工具鏈。jaxlib
是一個大型程式庫,對許多使用者來說不容易建置,但 JAX 的大多數變更僅涉及 Python 程式碼。透過允許 Python 部分獨立於 C++ 部分進行更新,我們提高了 Python 變更的開發速度。
此外,建置 jaxlib
成本不低,但我們希望能夠在 CPU 資源不多的環境中迭代和執行 JAX 測試,例如在 Github Actions 或筆記型電腦上。我們的許多 CI 建置都只是使用預先建置的 jaxlib
,而無需在每個 PR 上重建 JAX 的 C++ 部分。
正如我們將看到的,個別發佈 jax
和 jaxlib
會帶來成本,因為這要求對 jaxlib
的變更必須維持向後相容的 API。但是,我們認為,總體而言,即使以使 C++ 變更稍微困難為代價,使 Python 變更變得容易仍然是更可取的。
jax
和 jaxlib
如何進行版本控制?#
摘要:jax
和 jaxlib
在 JAX 原始碼樹中共享相同的版本號,但作為個別的 Python 套件發佈。安裝後,jax
套件版本必須大於或等於 jaxlib
的版本,並且 jaxlib
的版本必須大於或等於 jax
指定的最低 jaxlib
版本。
jax
和 jaxlib
的發行版本都編號為 x.y.z
,其中 x
是主要版本,y
是次要版本,而 z
是可選的修補程式發行版本。版本號必須遵循 PEP 440。版本號比較是對整數元組進行詞彙比較。
每個 jax
發行版本都有一個關聯的最低 jaxlib
版本 mx.my.mz
。jax
版本 x.y.z
的最低 jaxlib
版本必須不超過 x.y.z
。
為了使 jax
版本 x.y.z
和 jaxlib
版本 lx.ly.lz
相容,必須滿足以下條件:
jaxlib 版本 (
lx.ly.lz
) 必須大於或等於最低 jaxlib 版本 (mx.my.mz
)。jax 版本 (
x.y.z
) 必須大於或等於 jaxlib 版本 (lx.ly.lz
)。
這些限制暗示了以下發行規則:
jax
可以隨時單獨發行,而無需更新jaxlib
。如果發行了新的
jaxlib
,則必須同時發行jax
版本。
這些版本限制目前在匯入時由 jax
檢查,而不是表示為 Python 套件版本限制。jax
在執行階段檢查 jaxlib
版本,而不是使用 pip
套件版本限制,因為我們為各種硬體和軟體版本(例如,GPU、TPU 等)提供個別的 jaxlib
wheel。由於我們不知道哪個是任何特定使用者的正確選擇,因此我們不希望 pip
自動為我們安裝 jaxlib
套件。
未來,我們希望將 jaxlib
中特定於硬體的部分分離到個別的外掛程式中,屆時最低版本可以表示為 Python 套件依賴項。目前,我們確實提供特定於平台的額外需求,以安裝相容的 jaxlib 版本,例如,jax[cuda]
。
如何安全地變更 jaxlib
的 API?#
只要最低
jaxlib
版本提高到相容版本,jax
可以隨時放棄與較舊的jaxlib
發行版本的相容性。但是,請注意,即使對於未發行的jax
版本,最低jaxlib
也必須是已發行的版本!這使我們能夠在 CI 建置中使用已發行的jaxlib
wheel,並允許 Python 開發人員在 HEAD 上處理jax
,而無需建置jaxlib
。例如,若要移除
jax
Python 程式碼中的舊向後相容性路徑,只需提高最低 jaxlib 版本,然後刪除相容性路徑即可。jaxlib
可以放棄與低於其自身發行版本號的較舊jax
發行版本的相容性。jax
強制的版本限制將禁止使用不相容的jaxlib
。例如,為了使
jaxlib
放棄較舊jax
版本使用的 Python 綁定 API,必須遞增jaxlib
的次要或主要版本號。如果可能,對
jaxlib
的變更應以向後相容的方式進行。一般而言,只要遵循關於
jax
與所有至少與最低版本一樣新的jaxlib
相容的規則,jaxlib
可以自由變更其 API。這表示jax
必須始終與至少兩個版本的jaxlib
相容,即上次發行版本和主幹版本,實際上是下一個發行版本。如果保持相容性,則更容易做到這一點,儘管可以使用jax
的版本測試進行不相容的變更;請參閱下文。例如,向
jaxlib
新增函式通常是安全的,但如果目前的jax
仍在使用現有的函式,則移除現有的函式或變更其簽名是不安全的。對於高於最低版本直到 HEAD 的所有jaxlib
發行版本,對jax
的變更必須正常運作或優雅地降級。
請注意,此處的相容性規則僅適用於 jax
和 jaxlib
的已發行版本。它們不適用於未發行的版本;也就是說,如果 API 從未發行,或者沒有已發行的 jax
版本使用該 API,則可以從 jaxlib
引入然後移除 API。
jaxlib
的原始碼如何佈局?#
jaxlib
分散在兩個主要儲存庫中,即主 JAX 儲存庫中的 jaxlib/
子目錄和 XLA 原始碼樹(位於 XLA 儲存庫內)。XLA 內 JAX 特定的部分主要位於 xla/python
子目錄中。
JAX 的 C++ 部分(例如 Python 綁定和執行階段組件)位於 XLA 樹中的原因部分是歷史因素,部分是技術因素。
歷史原因是,最初 xla/python
綁定被設想為通用 Python 綁定,可以與其他框架共享。實際上,這種情況越來越少見,xla/python
包含許多 JAX 特定的部分,並且可能會包含更多。因此,最好簡單地將 xla/python
視為 JAX 的一部分。
技術原因是 XLA C++ API 不穩定。透過將 XLA:Python 綁定保留在 XLA 樹中,它們的 C++ 實作可以與 XLA 的 C++ API 原子性地更新。維護 Python API 的向後和向前相容性比 C++ API 更容易,因此 xla/python
公開 Python API,並負責在 Python 層級維護向後相容性。
jaxlib
是使用 Bazel 從 jax
儲存庫建置的。來自 XLA 儲存庫的 jaxlib
片段作為 Bazel 子模組併入建置中。若要更新建置期間使用的 XLA 版本,必須更新 Bazel WORKSPACE
中的釘選版本。這會根據需要手動完成,但可以在每次建置時覆寫。
我們如何在發行版本之間跨越 jax
和 jaxlib
的邊界進行變更?#
jaxlib 版本是一個粗略的工具:它僅允許我們推斷發行版本。
但是,由於 jax
和 jaxlib
程式碼分散在無法在單一變更中原子性更新的儲存庫中,因此我們需要以比發行週期更精細的粒度來管理相容性。為了管理細粒度的相容性,我們有獨立於 jaxlib
發行版本號的額外版本控制。
我們在 XLA 儲存庫中的 xla_client.py
中維護一個額外的版本號 (_version
)。這個版本號的想法是在 xla/python
中與 JAX 的 C++ 部分一起定義,JAX Python 也可以作為 jax._src.lib.xla_extension_version
存取,並且每次對 XLA/Python 程式碼進行變更時都必須遞增,而這些變更對 jax
具有向後相容性影響。然後,JAX Python 程式碼可以使用此版本號來維護向後相容性,例如:
from jax._src.lib import xla_extension_version
# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
# Use new code path
...
else:
# Use old code path.
請注意,此版本號是除了已發行版本號的限制之外的,也就是說,此版本號的存在是為了協助管理未發行程式碼在開發期間的相容性。發行版本也必須遵循上述相容性規則。