jax.extend:擴充模組#

@froystig@sharadmv@jakevdp@yashk2810

2023 年 5 月

import jax.extend as jex

許多專案依賴 JAX 的程式碼庫內部結構,通常是為了使用其核心機制 (例如,撰寫 在其 IR 上的轉換) 或擴充它 (例如,定義新的 primitives)。這些依賴關係面臨兩個挑戰:(a) 我們的內部結構並非都設計完善以供外部使用,以及 (b) 規避 JAX 的公共 API 是不受支援的。換句話說,我們的內部結構通常像函式庫一樣被使用,但既沒有像函式庫一樣的結構,也沒有像函式庫一樣的更新。

本提案考慮引入 jax.extend 模組,以定義 JAX 某些內部元件的函式庫視圖。我們會將其視為第二級 API,仍然保證基本上無相容性政策,但希望能讓使用者更容易發現變更發生時的情況。

jax.extend 的受眾包括 JAX 相鄰的 Python 函式庫,例如 Oryxjax-triton 和許多其他函式庫,以及實驗函式轉換、自動微分系統、數值程式設計的編譯器前端等的專案。

本說明概述了 jax.extend 現在和最終可能的樣貌。它沒有詳細說明所有事項,而是建議我們開始迭代開發此模組。

請注意,jax.extendjax.experimental 不同,後者是正在進行中的新功能和想法的暫存區。jax.experimental 中的工作通常最終會進入另一個 JAX 模組,或者完全移除。

無相容性政策#

為了保持較低的開發管理費用,jax.extend 將不會遵循公開的 API 相容性政策。它不會承諾有棄用期限,也不會保證版本之間的向後相容性。每個版本都可能破壞現有的呼叫者,而沒有簡單的補救措施 (例如,沒有重新引入先前行為的標記)。我們會依靠變更日誌來標示此類變更。

需要與 JAX 版本一起定期升級其程式碼的 jax.extend 呼叫者可能會發現將 JAX 版本固定為版本之間的中間步驟很有用。這是當今依賴 JAX 內部結構的專案中的常見習慣。不同之處在於,現在它將在變更日誌公告和關於函式庫設計和命名的更好意圖的幫助下實現。

迭代開發#

沒有相容性政策使得開始實作更容易:在第一天,我們可以從內部套件 (例如 jax._src) 和今天的 jax.corejax.interpreters 中移動一些符號。然後我們可以從那裡迭代改進。

可能的模組概述#

我們可以想像最終 jax.extend 將包含以下模組

  • core – primitives、Jaxpr IR 等。

  • interpreters – 核心轉換 (例如自動微分、批次處理) 和降低。

  • random – 隨機位元產生、金鑰分割和摺疊、金鑰陣列。

  • sharding – 圍繞分散式陣列的額外功能。

我們最初可能在模組中還有其他符號,例如 jex.api_util,因為我們致力於移除或替換它們。其他符號將隨著時間推移而決定。例如,jex.lib 可以提供 jaxlib 的入口點 (並且會在近期內這樣做),但目前尚不清楚我們是否要長期保留它。

以下是一些關於這些模組可能包含內容的初步想法。

jax.extend.core#

這應至少讓呼叫者能夠定義新的 JAX primitives 並處理 Jaxpr IR (jax.make_jaxpr(...) 的輸出)。支援此功能可能涉及提供

  • 存取現有的核心系統 primitives,例如今天的 jax._src.lax.add_p

  • 存取 IR 類型,例如目前的 jax._src.core.ShapedArray

  • 用於檢查和美觀列印 jaxpr 的函式。

  • 用於顯式建構 jaxpr 的函式,而不是透過 jax.make_jaxpr (或不使用!) 暫存 Python 函式。

在初始化時,此模組將包含比定義 primitives 和規則所需更多的符號,包括在設定 「final-style 轉換」中使用的各種名稱,例如目前的 jax._src.core.TraceTracer 類別。我們可以重新審視 jex.core 是否也應支援 final-style 擴充以及 initial style 方法,以及它是否可以透過比完全公開 TraceTracer 更窄的 API 來實現。Oryx 可能有助於引導這些決策。

我們也可以考慮將 make_jaxpr 本身重新定位到 jex.core

jax.extend.interpreters#

此模組將提供一種為 primitives 註冊各種轉換規則的方法—定義它們在 AD、批次處理、降低等方面的行為。

它最初將反映 jax._src.interpreters,提供模組 adbatchingpartial_eval (用於將 Python 暫存到 Jaxpr,以及用於 AD 中的線性化)、mlirpxlaxla。前三個可能會被 jex.core 中的單個 primitive 擴充 API 取代。後三個用於降低的模組可以簡化為一個模組,也許。

今天,為了撰寫轉換規則,例如用於 AD 和批次處理,呼叫者可能需要與 tracers 相關的符號,例如 JVPTracerBatchTracer。這在以後可能是可以避免的,並且允許我們從 jex 中移除 tracer 類型。

此模組加上 jex.core 應該足以複製今天的自訂 primitive 教學 (例如,我們的dfm 的)。例如,定義 primitive 及其在 jax.jit 下的行為在近期內可以透過以下方式實現

from jax.extend import core	         # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)

jax.extend.random#

此模組可以公開我們用於定義新 RNG 實作的機制,以及用於處理 PRNG 金鑰內部結構的函式 (請參閱問題 #9263),例如目前的 jax._src.prng.random_wraprandom_unwrap

它還可以公開作為內建 RNG 實作基礎的 keyed hash 函式,例如 jax._src.prng.threefry_2x32

jax.extend.sharding#

此模組可以公開用於分片分散式陣列的低階實用程式。

我們目前只想到一個項目。XLA 編譯器的陣列分片格式比 JAX 提供的格式更具表現力。我們可以將其作為 jex.sharding.XlaOpShardingProto 提供,對應於內部目前的 jax._src.lib.xla_client.OpSharding