Jax 與 Jaxlib 版本控制#

為何 jaxjaxlib 是個別的套件?#

我們將 JAX 發佈為兩個個別的 Python wheel,分別是 jax(純 Python wheel)和 jaxlib(主要為 C++ wheel),其中包含以下程式庫:

  • XLA,

  • XLA 使用的 LLVM 片段,

  • MLIR 基礎架構,例如 StableHLO Python 綁定。

  • JAX 特定的 C++ 程式庫,用於快速 JIT 和 PyTree 操作。

我們發佈個別的 jaxjaxlib 套件,是因為這樣可以輕鬆處理 JAX 的 Python 部分,而無需建置 C++ 程式碼,甚至無需安裝 C++ 工具鏈。jaxlib 是一個大型程式庫,對許多使用者來說不容易建置,但 JAX 的大多數變更僅涉及 Python 程式碼。透過允許 Python 部分獨立於 C++ 部分進行更新,我們提高了 Python 變更的開發速度。

此外,建置 jaxlib 成本不低,但我們希望能夠在 CPU 資源不多的環境中迭代和執行 JAX 測試,例如在 Github Actions 或筆記型電腦上。我們的許多 CI 建置都只是使用預先建置的 jaxlib,而無需在每個 PR 上重建 JAX 的 C++ 部分。

正如我們將看到的,個別發佈 jaxjaxlib 會帶來成本,因為這要求對 jaxlib 的變更必須維持向後相容的 API。但是,我們認為,總體而言,即使以使 C++ 變更稍微困難為代價,使 Python 變更變得容易仍然是更可取的。

jaxjaxlib 如何進行版本控制?#

摘要:jaxjaxlib 在 JAX 原始碼樹中共享相同的版本號,但作為個別的 Python 套件發佈。安裝後,jax 套件版本必須大於或等於 jaxlib 的版本,並且 jaxlib 的版本必須大於或等於 jax 指定的最低 jaxlib 版本。

jaxjaxlib 的發行版本都編號為 x.y.z,其中 x 是主要版本,y 是次要版本,而 z 是可選的修補程式發行版本。版本號必須遵循 PEP 440。版本號比較是對整數元組進行詞彙比較。

每個 jax 發行版本都有一個關聯的最低 jaxlib 版本 mx.my.mzjax 版本 x.y.z 的最低 jaxlib 版本必須不超過 x.y.z

為了使 jax 版本 x.y.zjaxlib 版本 lx.ly.lz 相容,必須滿足以下條件:

  • jaxlib 版本 (lx.ly.lz) 必須大於或等於最低 jaxlib 版本 (mx.my.mz)。

  • jax 版本 (x.y.z) 必須大於或等於 jaxlib 版本 (lx.ly.lz)。

這些限制暗示了以下發行規則:

  • jax 可以隨時單獨發行,而無需更新 jaxlib

  • 如果發行了新的 jaxlib,則必須同時發行 jax 版本。

這些版本限制目前在匯入時由 jax 檢查,而不是表示為 Python 套件版本限制。jax 在執行階段檢查 jaxlib 版本,而不是使用 pip 套件版本限制,因為我們為各種硬體和軟體版本(例如,GPU、TPU 等)提供個別的 jaxlib wheel。由於我們不知道哪個是任何特定使用者的正確選擇,因此我們不希望 pip 自動為我們安裝 jaxlib 套件。

未來,我們希望將 jaxlib 中特定於硬體的部分分離到個別的外掛程式中,屆時最低版本可以表示為 Python 套件依賴項。目前,我們確實提供特定於平台的額外需求,以安裝相容的 jaxlib 版本,例如,jax[cuda]

如何安全地變更 jaxlib 的 API?#

  • 只要最低 jaxlib 版本提高到相容版本,jax 可以隨時放棄與較舊的 jaxlib 發行版本的相容性。但是,請注意,即使對於未發行的 jax 版本,最低 jaxlib 也必須是已發行的版本!這使我們能夠在 CI 建置中使用已發行的 jaxlib wheel,並允許 Python 開發人員在 HEAD 上處理 jax,而無需建置 jaxlib

    例如,若要移除 jax Python 程式碼中的舊向後相容性路徑,只需提高最低 jaxlib 版本,然後刪除相容性路徑即可。

  • jaxlib 可以放棄與低於其自身發行版本號的較舊 jax 發行版本的相容性。jax 強制的版本限制將禁止使用不相容的 jaxlib

    例如,為了使 jaxlib 放棄較舊 jax 版本使用的 Python 綁定 API,必須遞增 jaxlib 的次要或主要版本號。

  • 如果可能,對 jaxlib 的變更應以向後相容的方式進行。

    一般而言,只要遵循關於 jax 與所有至少與最低版本一樣新的 jaxlib 相容的規則,jaxlib 可以自由變更其 API。這表示 jax 必須始終與至少兩個版本的 jaxlib 相容,即上次發行版本和主幹版本,實際上是下一個發行版本。如果保持相容性,則更容易做到這一點,儘管可以使用 jax 的版本測試進行不相容的變更;請參閱下文。

    例如,向 jaxlib 新增函式通常是安全的,但如果目前的 jax 仍在使用現有的函式,則移除現有的函式或變更其簽名是不安全的。對於高於最低版本直到 HEAD 的所有 jaxlib 發行版本,對 jax 的變更必須正常運作或優雅地降級。

請注意,此處的相容性規則僅適用於 jaxjaxlib已發行版本。它們不適用於未發行的版本;也就是說,如果 API 從未發行,或者沒有已發行的 jax 版本使用該 API,則可以從 jaxlib 引入然後移除 API。

jaxlib 的原始碼如何佈局?#

jaxlib 分散在兩個主要儲存庫中,即主 JAX 儲存庫中的 jaxlib/ 子目錄XLA 原始碼樹(位於 XLA 儲存庫內)。XLA 內 JAX 特定的部分主要位於 xla/python 子目錄中。

JAX 的 C++ 部分(例如 Python 綁定和執行階段組件)位於 XLA 樹中的原因部分是歷史因素,部分是技術因素。

歷史原因是,最初 xla/python 綁定被設想為通用 Python 綁定,可以與其他框架共享。實際上,這種情況越來越少見,xla/python 包含許多 JAX 特定的部分,並且可能會包含更多。因此,最好簡單地將 xla/python 視為 JAX 的一部分。

技術原因是 XLA C++ API 不穩定。透過將 XLA:Python 綁定保留在 XLA 樹中,它們的 C++ 實作可以與 XLA 的 C++ API 原子性地更新。維護 Python API 的向後和向前相容性比 C++ API 更容易,因此 xla/python 公開 Python API,並負責在 Python 層級維護向後相容性。

jaxlib 是使用 Bazel 從 jax 儲存庫建置的。來自 XLA 儲存庫的 jaxlib 片段作為 Bazel 子模組併入建置中。若要更新建置期間使用的 XLA 版本,必須更新 Bazel WORKSPACE 中的釘選版本。這會根據需要手動完成,但可以在每次建置時覆寫。

我們如何在發行版本之間跨越 jaxjaxlib 的邊界進行變更?#

jaxlib 版本是一個粗略的工具:它僅允許我們推斷發行版本

但是,由於 jaxjaxlib 程式碼分散在無法在單一變更中原子性更新的儲存庫中,因此我們需要以比發行週期更精細的粒度來管理相容性。為了管理細粒度的相容性,我們有獨立於 jaxlib 發行版本號的額外版本控制。

我們在 XLA 儲存庫中的 xla_client.py 中維護一個額外的版本號 (_version)。這個版本號的想法是在 xla/python 中與 JAX 的 C++ 部分一起定義,JAX Python 也可以作為 jax._src.lib.xla_extension_version 存取,並且每次對 XLA/Python 程式碼進行變更時都必須遞增,而這些變更對 jax 具有向後相容性影響。然後,JAX Python 程式碼可以使用此版本號來維護向後相容性,例如:

from jax._src.lib import xla_extension_version

# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
  # Use new code path
  ...
else:
  # Use old code path.

請注意,此版本號是除了已發行版本號的限制之外的,也就是說,此版本號的存在是為了協助管理未發行程式碼在開發期間的相容性。發行版本也必須遵循上述相容性規則。