匯出和序列化暫存運算#

預先降低和編譯 API 產生可用於偵錯或在同一進程中進行編譯和執行的物件。 有時您希望序列化降低的 JAX 函數,以便在單獨的進程中進行編譯和執行,可能是在稍後的時間。 這將允許您

  • 在另一個進程或機器中編譯和執行函數,而無需存取 JAX 程式,也無需重複暫存和降低,例如,在推論系統中。

  • 在無法存取您想要稍後編譯和執行函數的加速器的機器上追蹤和降低函數。

  • 封存 JAX 函數的快照,例如,能夠稍後重現您的結果。 注意:查看此用例的相容性保證

如需更多詳細資訊,請參閱 jax.export API 參考。

以下是一個範例

>>> import re
>>> import numpy as np
>>> import jax
>>> from jax import export

>>> def f(x): return 2 * x * x


>>> exported: export.Exported = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((), np.float32))

>>> # You can inspect the Exported object
>>> exported.fun_name
'f'

>>> exported.in_avals
(ShapedArray(float32[]),)

>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
  func.func public @main(%arg0: tensor<f32> loc("x")) -> (tensor<f32> {jax.result_info = ""}) {

>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()

>>> # The serialized function can later be rehydrated and called from
>>> # another JAX computation, possibly in another process.
>>> rehydrated_exp: export.Exported = export.deserialize(serialized)
>>> rehydrated_exp.in_avals
(ShapedArray(float32[]),)

>>> def callee(y):
...  return 3. * rehydrated_exp.call(y * 4.)

>>> callee(1.)
Array(96., dtype=float32)

序列化分為兩個階段

  1. 匯出以產生 jax.export.Exported 物件,其中包含降低函數的 StableHLO 以及從另一個 JAX 函數呼叫它所需的元資料。 我們計劃新增程式碼以從 TensorFlow 產生 Exported 物件,並從 TensorFlow 和 PyTorch 使用 Exported 物件。

  2. 使用 flatbuffers 格式實際序列化為位元組陣列。 有關序列化為 TensorFlow 圖以便與 TensorFlow 互通的替代方案,請參閱 與 TensorFlow 互通

反向模式 AD 的支援#

序列化可以選擇性地支援高階反向模式 AD。 這是透過序列化原始函數的 jax.vjp() 以及原始函數來完成的,直到使用者指定的階數 (預設為 0,表示重新水合的函數無法微分)

>>> import jax
>>> from jax import export
>>> from typing import Callable

>>> def f(x): return 7 * x * x * x

>>> # Serialize 3 levels of VJP along with the primal function
>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3)
>>> rehydrated_f: Callable = export.deserialize(blob).call

>>> rehydrated_f(0.1)  # 7 * 0.1^3
Array(0.007, dtype=float32)

>>> jax.grad(rehydrated_f)(0.1)  # 7*3 * 0.1^2
Array(0.21000001, dtype=float32)

>>> jax.grad(jax.grad(rehydrated_f))(0.1)  # 7*3*2 * 0.1
Array(4.2, dtype=float32)

>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1)  # 7*3*2
Array(42., dtype=float32)

>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1)  
Traceback (most recent call last):
ValueError: No VJP is available

請注意,VJP 函數是在序列化時,當 JAX 程式仍然可用時延遲計算的。 這表示它遵守 JAX VJP 的所有功能,例如,jax.custom_vjp()jax.remat()

請注意,重新水合的函數不支援任何其他轉換,例如,正向模式 AD (jvp) 或 jax.vmap()

相容性保證#

您不應將僅從降低 (jax.jit(f).lower(1.).compiler_ir()) 獲得的原始 StableHLO 用於封存和在另一個進程中進行編譯,原因如下。

首先,編譯可能會使用不同版本的編譯器,支援不同版本的 StableHLO。 jax.export 模組透過使用 StableHLO 的可攜式工件功能來處理 StableHLO 操作集可能發生的演變,從而解決此問題。

自訂呼叫的相容性保證#

其次,原始 StableHLO 可能包含參考 C++ 函數的自訂呼叫。 JAX 使用自訂呼叫來降低少量基本運算,例如,線性代數基本運算、分片註釋或 Pallas 核心。 這些不屬於 StableHLO 的相容性保證範圍。 這些函數的 C++ 實作很少變更,但它們可能會變更。

jax.export 做出以下匯出相容性保證:JAX 匯出的工件可以由編譯器和 JAX 執行階段系統編譯和執行,這些編譯器和 JAX 執行階段系統

  • 比用於匯出的 JAX 版本新最多 6 個月 (我們說 JAX 匯出提供 6 個月的向後相容性)。 如果我們想要封存匯出的工件以供稍後編譯和執行,這非常有用。

  • 比用於匯出的 JAX 版本舊最多 3 週 (我們說 JAX 匯出提供 3 週的向前相容性)。 如果我們想要使用在匯出之前建置和部署的消費者 (例如,在完成匯出時已部署的推論系統) 來編譯和執行匯出的工件,這非常有用。

(特定的相容性視窗長度與 JAX 針對 jax2tf 承諾的長度相同,並且基於 TensorFlow 相容性。「向後相容性」術語是從消費者 (例如,推論系統) 的角度來看的。)

重要的是匯出和消費元件的建置時間,而不是匯出和編譯發生的時間。 對於外部 JAX 使用者,可以執行不同版本的 JAX 和 jaxlib;重要的是 jaxlib 版本發佈的時間。

為了減少不相容性的可能性,內部 JAX 使用者應

  • 盡可能頻繁地重建和重新部署消費者系統.

而外部使用者應

  • 盡可能使用相同版本的 jaxlib 執行匯出和消費者系統,以及

  • 使用最新發佈的 jaxlib 版本 匯出以進行封存。

如果您繞過 jax.export API 以取得 StableHLO 程式碼,則相容性保證不適用。

為了確保向前相容性,當我們變更 JAX 降低規則以使用新的自訂呼叫目標時,JAX 將在 3 週內避免使用新的目標。 若要使用最新的降低規則,您可以傳遞 --jax_export_ignore_forward_compatibility=1 組態旗標或 JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY=1 環境變數。

只有一部分自訂呼叫保證穩定且具有相容性保證 (請參閱列表)。 我們不斷將更多自訂呼叫目標新增至允許的列表,並進行向後相容性測試。 如果您嘗試序列化調用其他自訂呼叫目標的程式碼,您將在匯出期間收到錯誤。

如果您想要針對特定的自訂呼叫停用此安全檢查,例如,目標為 my_target,您可以將 export.DisabledSafetyCheck.custom_call("my_target") 新增至 export 方法的 disabled_checks 參數中,如下列範例所示

>>> import jax
>>> from jax import export
>>> from jax import lax
>>> from jax._src import core
>>> from jax._src.interpreters import mlir
>>> # Define a new primitive backed by a custom call
>>> new_prim = core.Primitive("new_prim")
>>> _ = new_prim.def_abstract_eval(lambda x: x)
>>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results)
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
}

>>> # If we try to export, we get an error
>>> export.export(jax.jit(new_prim.bind))(1.)  
Traceback (most recent call last):
ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_bind

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
>>> exp = export.export(
...    jax.jit(new_prim.bind),
...    disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_prim")])(1.)

有關確保相容性的開發人員資訊,請參閱 確保向前和向後相容性

跨平台和多平台匯出#

JAX 降低對於少量 JAX 基本運算而言是平台特定的。 預設情況下,程式碼會針對匯出機器上存在的加速器進行降低和匯出

>>> from jax import export
>>> export.default_export_platform()
'cpu'

當嘗試在不具有程式碼匯出所針對的加速器的機器上編譯 Exported 物件時,將會引發錯誤的安全檢查。

您可以明確指定程式碼應匯出到哪些平台。 這可讓您指定與匯出時可用的加速器不同的加速器,甚至可讓您指定多平台匯出,以取得可在多個平台上編譯和執行的 Exported 物件。

>>> import jax
>>> from jax import export
>>> from jax import lax

>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)

>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.)  
Traceback (most recent call last):
ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
>>> # that the code lowered will run adequately on the current
>>> # compilation platform (which is the case for `cos` in this
>>> # example):
>>> exp_unsafe = export.export(jax.jit(lax.cos),
...    platforms=['tpu'],
...    disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)

>>> exp_unsafe.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

# and similarly with multi-platform lowering
>>> exp_multi = export.export(jax.jit(lax.cos),
...    platforms=['tpu', 'cpu', 'cuda'])(1.)
>>> exp_multi.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

對於多平台匯出,StableHLO 將包含多個降低,但僅適用於需要降低的基本運算,因此產生的模組大小應僅比預設匯出的模組大小略大。 作為極端情況,當序列化不含任何具有平台特定降低的基本運算的模組時,您將獲得與單平台匯出相同的 StableHLO。

>>> import jax
>>> from jax import export
>>> from jax import lax
>>> # A largish function
>>> def f(x):
...   for i in range(1000):
...     x = jnp.cos(x)
...   return x

>>> exp_single = export.export(jax.jit(f))(1.)
>>> len(exp_single.mlir_module_serialized)  
9220

>>> exp_multi = export.export(jax.jit(f),
...                           platforms=["cpu", "tpu", "cuda"])(1.)
>>> len(exp_multi.mlir_module_serialized)  
9282

形狀多型匯出#

在 JIT 模式下使用時,JAX 將針對每種輸入形狀組合分別追蹤和降低函數。 匯出時,在某些情況下,可以對某些輸入維度使用維度變數,以取得可用於多種輸入形狀組合的匯出工件。

請參閱 形狀多型 文件。

裝置多型匯出#

匯出的工件可能包含輸入、輸出和某些中介的分片註釋,但這些註釋並非直接參考匯出時存在的實際物理裝置。 相反,分片註釋參考邏輯裝置。 這表示您可以在用於匯出的不同物理裝置上編譯和執行匯出的工件。

實現裝置多型匯出的最乾淨方法是使用使用 jax.sharding.AbstractMesh 建構的分片,其中僅包含網格形狀和軸名稱。 但是,如果您使用針對具有具體裝置的網格建構的分片,則可以獲得相同的結果,因為網格中的實際裝置在追蹤和降低時會被忽略

>>> import jax
>>> from jax import export
>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # Use an AbstractMesh for exporting
>>> export_mesh = AbstractMesh((("a", 4),))

>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((32,), dtype=np.int32,
...                         sharding=NamedSharding(export_mesh, P("a"))))

>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
4

>>> # and it knows the shardings for the inputs. These will be applied
>>> # when the exported is called.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)

>>> # You can also use a concrete set of devices for exporting
>>> concrete_devices = jax.local_devices()[:4]
>>> concrete_mesh = Mesh(concrete_devices, ("a",))
>>> exp2 = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((32,), dtype=np.int32,
...                         sharding=NamedSharding(concrete_mesh, P("a"))))

>>> # You can expect the same results
>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo

>>> # When you call an Exported, you must use a concrete set of devices
>>> arg = jnp.arange(8 * 4)
>>> res1 = exp.call(jax.device_put(arg,
...                                NamedSharding(concrete_mesh, P("a"))))

>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
['device=TFRT_CPU_0 index=(slice(0, 8, None),)',
 'device=TFRT_CPU_1 index=(slice(8, 16, None),)']

>>> # We can call `exp` with some other 4 devices and another
>>> # mesh with a different shape, as long as the number of devices is
>>> # the same.
>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
>>> res2 = exp.call(jax.device_put(arg,
...                                NamedSharding(other_mesh, P("b"))))

>>> # Check out the first 2 shards of the result. Notice that the output is
>>> # sharded similarly; this means that the input was resharded according to the
>>> # exp.in_shardings.
>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
['device=TFRT_CPU_2 index=(slice(0, 8, None),)',
 'device=TFRT_CPU_3 index=(slice(8, 16, None),)']

嘗試使用與匯出時不同的裝置數量來調用匯出的工件是錯誤的

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> exp = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...                         sharding=NamedSharding(export_mesh, P("a"))))

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp.call(arg)  
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.

有一些輔助函數可用於分片輸入,以便使用在呼叫站點建構的新網格來呼叫匯出的工件

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T


>>> exp = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
...                         sharding=NamedSharding(export_mesh, P("a"))))

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> arg = jnp.arange(4 * len(export_devices))
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg)

作為特殊功能,如果函數是針對 1 個裝置匯出的,並且不包含任何分片註釋,則可以在相同形狀的引數上調用它,但在多個裝置上分片,編譯器將適當地分片函數

```python
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> def f(x):
...   return jnp.cos(x)

>>> arg = jnp.arange(4)
>>> exp = export.export(jax.jit(f))(arg)
>>> exp.in_avals
(ShapedArray(int32[4]),)

>>> exp.nr_devices
1

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg,
...                              NamedSharding(calling_mesh, P("b")))
>>> res = exp.call(sharded_arg)

呼叫慣例版本#

JAX 匯出支援隨著時間的推移而不斷發展,例如,支援效果。 為了支援相容性 (請參閱 相容性保證),我們為每個 Exported 維護一個呼叫慣例版本。 截至 2024 年 6 月,所有使用版本 9 (最新版本,請參閱 所有呼叫慣例版本) 匯出的函數

>>> from jax import export
>>> exp: export.Exported = export.export(jnp.cos)(1.)
>>> exp.calling_convention_version
9

在任何給定時間,匯出 API 都可以支援一系列呼叫慣例版本。 您可以使用 --jax_export_calling_convention_version 旗標或 JAX_EXPORT_CALLING_CONVENTION_VERSION 環境變數來控制要使用的呼叫慣例版本

>>> from jax import export
>>> (export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version)
(9, 9)

>>> from jax._src import config
>>> with config.jax_export_calling_convention_version(9):
...  exp = export.export(jnp.cos)(1.)
...  exp.calling_convention_version
9

我們保留移除對產生或使用超過 6 個月舊的呼叫慣例版本支援的權利。

模組呼叫慣例#

Exported.mlir_module 具有一個 main 函數,如果模組支援多個平台 (len(platforms) > 1),則該函數會採用選用的第一個平台索引引數,後跟對應於已排序效果的權杖引數,後跟保留的陣列引數 (對應於 module_kept_var_idxin_avals)。 平台索引是一個 i32 或 i64 純量,用於將目前編譯平台的索引編碼到 platforms 序列中。

內部函數使用不同的呼叫慣例:選用的平台索引引數、選用的維度變數引數 (i32 或 i64 類型的純量張量)、後跟選用的權杖引數 (在存在已排序效果的情況下)、後跟常規陣列引數。 維度引數對應於 args_avals 中出現的維度變數,依其名稱的排序順序排列。

考慮降低具有一個類型為 f32[w, 2 * h] 的陣列引數的函數,其中 wh 是兩個維度變數。 假設我們使用多平台降低,並且我們有一個已排序的效果。 main 函數將如下所示

      func public main(
            platform_index: i32 {jax.global_constant="_platform_index"},
            token_in: token,
            arg: f32[?, ?]) {
         arg_w = hlo.get_dimension_size(arg, 0)
         dim1 = hlo.get_dimension_size(arg, 1)
         arg_h = hlo.floordiv(dim1, 2)
         call _check_shape_assertions(arg)  # See below
         token = new_token()
         token_out, res = call _wrapped_jax_export_main(platform_index,
                                                        arg_h,
                                                        arg_w,
                                                        token_in,
                                                        arg)
         return token_out, res
      }

實際的計算在 _wrapped_jax_export_main 中,同時也採用 hw 維度變數的值。

_wrapped_jax_export_main 的簽名是

      func private _wrapped_jax_export_main(
          platform_index: i32 {jax.global_constant="_platform_index"},
          arg_h: i32 {jax.global_constant="h"},
          arg_w: i32 {jax.global_constant="w"},
          arg_token: stablehlo.token {jax.token=True},
          arg: f32[?, ?]) -> (stablehlo.token, ...)

在呼叫慣例版本 9 之前,效果的呼叫慣例是不同的:main 函數不採用或傳回權杖。 相反,函數會建立 i1[0] 類型的虛擬權杖,並將其傳遞給 _wrapped_jax_export_main_wrapped_jax_export_main 採用 i1[0] 類型的虛擬權杖,並在內部建立真實權杖以傳遞給內部函數。 內部函數使用真實權杖 (在呼叫慣例版本 9 之前和之後)

同樣從呼叫慣例版本 9 開始,包含平台索引或維度變數值的函數引數具有 jax.global_constant 字串屬性,其值是全域常數的名稱,即 _platform_index 或維度變數名稱。 如果全域常數名稱未知,則可能為空。 一些全域常數計算使用內部函數,例如,用於 floor_divide。 此類函數的引數對於所有屬性都具有 jax.global_constant 屬性,這表示函數的結果也是全域常數。

請注意,main 包含對 _check_shape_assertions 的呼叫。 JAX 追蹤假設 arg.shape[1] 是偶數,並且 wh 的值都 >= 1。 當我們調用模組時,我們必須檢查這些約束。 我們使用特殊的自訂呼叫 @shape_assertion,它採用布林值的第一個運算元、字串 error_message 屬性 (可能包含格式規範符 {0}{1}、…),以及對應於格式規範符的可變數量的整數純量運算元。

       func private _check_shape_assertions(arg: f32[?, ?]) {
         # Check that w is >= 1
         arg_w = hlo.get_dimension_size(arg, 0)
         custom_call @shape_assertion(arg_w >= 1, arg_w,
            error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
         # Check that dim1 is even
         dim1 = hlo.get_dimension_size(arg, 1)
         custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2,
            error_message="Division had remainder {0} when computing the value of 'h')
         # Check that h >= 1
         arg_h = hlo.floordiv(dim1, 2)
         custom_call @shape_assertion(arg_h >= 1, arg_h,
            error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}")

呼叫慣例版本#

我們在此列出呼叫慣例版本號碼的歷史記錄

  • 版本 1 使用 MHLO 和 CHLO 來序列化程式碼,不再支援。

  • 版本 2 支援 StableHLO 和 CHLO。 自 2022 年 10 月起使用。 不再支援。

  • 版本 3 支援平台檢查和多個平台。 自 2023 年 2 月起使用。 不再支援。

  • 版本 4 支援具有相容性保證的 StableHLO。 這是 JAX 原生序列化啟動時的最早版本。 自 2023 年 3 月 15 日 (cl/516885716) 起在 JAX 中使用。 從 2023 年 3 月 28 日開始,我們停止使用 dim_args_spec (cl/520033493)。 對此版本的支援已於 2023 年 10 月 17 日 (cl/573858283) 移除。

  • 版本 5 新增了對 call_tf_graph 的支援。 目前用於某些特殊用途案例。 自 2023 年 5 月 3 日 (cl/529106145) 起在 JAX 中使用。

  • 版本 6 新增了對 disabled_checks 屬性的支援。 此版本強制執行非空的 platforms 屬性。 自 2023 年 6 月 7 日起受 XlaCallModule 支援,並自 2023 年 6 月 13 日起在 JAX 中提供 (JAX 0.4.13)。

  • 版本 7 新增了對 stablehlo.shape_assertion 操作和在 disabled_checks 中指定的 shape_assertions 的支援。 請參閱 形狀多型存在時的錯誤。 自 2023 年 7 月 12 日起受 XlaCallModule 支援 (cl/547482522),自 2023 年 7 月 20 日起在 JAX 序列化中提供 (JAX 0.4.14),並自 2023 年 8 月 12 日起成為預設值 (JAX 0.4.15)。

  • 版本 8 新增了對 jax.uses_shape_polymorphism 模組屬性的支援,並且僅當屬性存在時才啟用形狀精簡傳遞。 自 2023 年 7 月 21 日起受 XlaCallModule 支援 (cl/549973693),自 2023 年 7 月 26 日起在 JAX 中提供 (JAX 0.4.14),並自 2023 年 10 月 21 日起成為預設值 (JAX 0.4.20)。

  • 版本 9 新增了對效果的支援。 有關精確的呼叫慣例,請參閱 export.Exported 的文件字串。 在此呼叫慣例版本中,我們也使用 jax.global_constant 屬性標記平台索引和維度變數引數。 自 2023 年 10 月 27 日起受 XlaCallModule 支援,自 2023 年 10 月 20 日起在 JAX 中提供 (JAX 0.4.20),並自 2024 年 2 月 1 日起成為預設值 (JAX 0.4.24)。 截至 2024 年 3 月 27 日,這是唯一支援的版本。

開發人員文件#

偵錯#

您可以記錄匯出的模組,在 OSS 和 Google 中的旗標略有不同。 在 OSS 中,您可以執行以下操作

# Log from python
python tests/export_test.py JaxExportTest.test_basic -v=3
# Or, log from pytest to /tmp/mylog.txt
pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt

您將看到如下格式的記錄行

I0619 10:54:18.978733 8299482112 _export.py:606] Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
I0619 10:54:18.978767 8299482112 _export.py:607] Define JAX_DUMP_IR_TO to dump the module.

如果您將環境變數 JAX_DUMP_IR_TO 設定為目錄,則匯出的 (和 JIT 編譯的) HLO 模組將儲存在該處。

JAX_DUMP_IR_TO=/tmp/export.dumps pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
INFO     absl:_export.py:606 Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
INFO     absl:_export.py:607 The module was dumped to jax_ir0_jit_sin_export.mlir.

您將看到匯出的模組 (名為 ..._export.mlir) 和 JIT 編譯的模組 (名為 ..._compile.mlir)

$ ls -l /tmp/export.dumps/
total 32
-rw-rw-r--@ 1 necula  wheel  2316 Jun 19 11:04 jax_ir0_jit_sin_export.mlir
-rw-rw-r--@ 1 necula  wheel  2279 Jun 19 11:04 jax_ir1_jit_sin_compile.mlir
-rw-rw-r--@ 1 necula  wheel  3377 Jun 19 11:04 jax_ir2_jit_call_exported_compile.mlir
-rw-rw-r--@ 1 necula  wheel  2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir

在 Google 內部,您可以透過使用 --vmodule 引數來指定不同模組的記錄層級來開啟記錄,例如,--vmodule=_export=3

確保向前和向後相容性#

本節討論 JAX 開發人員應使用的程序,以確保 相容性保證

一個複雜之處在於,外部使用者在不同的套件中安裝 JAX 和 jaxlib,並且使用者通常最終使用比 JAX 更舊的 jaxlib。 我們觀察到自訂呼叫存在於 jaxlib 中,並且只有 jaxlib 與匯出工件的消費者相關。 為了簡化此程序,我們正在為外部使用者設定期望,即相容性視窗是根據 jaxlib 版本定義的,即使 JAX 可以使用較舊的版本,他們也有責任確保使用新的 jaxlib 進行匯出。

因此,我們只關心 jaxlib 版本。 當我們發佈 jaxlib 版本時,即使我們不強制將其設為允許的最低版本,我們也可以啟動向後相容性棄用時鐘。

假設我們需要新增、刪除或變更 JAX 降低規則使用的自訂呼叫目標 T 的語意。 以下是可能的時間順序 (用於變更 jaxlib 中存在的自訂呼叫目標)

  1. 「D - 1」天,變更之前。 假設活動的內部 JAX 版本為 0.4.31 (下一個 JAX 和 jaxlib 版本的版本)。 JAX 降低規則使用自訂呼叫 T

  2. 「D」天,我們新增新的自訂呼叫目標 T_NEW。 我們應該建立新的自訂呼叫目標,並在大約 6 個月後清理舊目標,而不是就地更新 T

    • 請參閱 PR #20997,其中實作了以下步驟。

    • 我們新增自訂呼叫目標 T_NEW

    • 我們變更了 JAX 降低規則,先前使用 T,現在改為有條件地使用 T_NEW,條件如下:

    from jax._src import config
    from jax._src.lib import version as jaxlib_version
    
    def my_lowering_rule(ctx: LoweringRuleContext, ...):
      if ctx.is_forward_compat() or jaxlib_version < (0, 4, 31):
        # this is the old lowering, using target T, while we
        # are in forward compatibility mode for T, or we
        # are in OSS and are using an old jaxlib.
        return hlo.custom_call("T", ...)
      else:
        # This is the new lowering, using target T_NEW, for
        # when we use a jaxlib with version `>= (0, 4, 31)`
        # (or when this is internal usage), and also we are
        # in JIT mode.
        return hlo.custom_call("T_NEW", ...)
    
    • 請注意,在 JIT 模式下,或使用者傳遞 --jax_export_ignore_forward_compatibility=true 時,向前相容模式始終為 false。

    • 我們將 T_NEW 新增至 _export.py_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE 的清單中。

  3. 第「D + 21」天(向前相容性視窗結束;甚至可以晚於 21 天):我們在降低程式碼中移除 forward_compat_mode,因此只要我們使用新的 jaxlib,現在匯出將開始使用新的自訂呼叫目標 T_NEW

    • 我們為 T_NEW 新增向後相容性測試。

  4. 第「RELEASE > D」天(D 之後的第一個 JAX 發布日期,當我們發布版本 0.4.31 時):我們開始為期 6 個月的向後相容性計時。請注意,這僅在 T 屬於我們已保證穩定性的自訂呼叫目標時才相關,即列在 _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE 中的目標。

    • 如果 RELEASE 在向前相容性視窗 [D, D + 21] 中,且如果我們將 RELEASE 作為允許的最低 jaxlib 版本,那麼我們可以移除 JIT 分支中 jaxlib_version < (0, 4, 31) 的條件判斷式。

  5. 第「RELEASE + 180」天(向後相容性視窗結束,甚至可以晚於 180 天):到目前為止,我們必須已提升最低 jaxlib 版本,以便移除降低條件判斷式 jaxlib_version < (0, 4, 31),且 JAX 降低規則無法再產生對 T 的自訂呼叫。

    • 我們移除舊的自訂呼叫目標 T 的 C++ 實作。

    • 我們也移除針對 T 的向後相容性測試

從 jax.experimental.export 遷移指南#

在 2024 年 6 月 18 日(JAX 版本 0.4.30),我們棄用了 jax.experimental.export API,改用 jax.export API。其中進行了一些小變更。

  • jax.experimental.export.export:

    • 舊函式過去允許任何 Python callable,或 jax.jit 的結果。現在僅接受後者。您必須在呼叫 export 之前,手動將 jax.jit 應用於要匯出的函式。

    • 舊的 lowering_parameters kwarg 現在命名為 platforms

  • jax.experimental.export.default_lowering_platform() 現在位於 jax.export.default_export_platform()

  • jax.experimental.export.call 現在是 jax.export.Exported 物件的方法。您應該使用 exp.call,而不是 export.call(exp)

  • jax.experimental.export.serialize 現在是 jax.export.Exported 物件的方法。您應該使用 exp.serialize(),而不是 export.serialize(exp)

  • 組態旗標 --jax-serialization-version 已被棄用。請改用 --jax-export-calling-convention-version

  • jax.experimental.export.minimum_supported_serialization_version 現在位於 jax.export.minimum_supported_calling_convention_version

  • jax.export.Exported 的以下欄位已重新命名

    • uses_shape_polymorphism 現在是 uses_global_constants

    • mlir_module_serialization_version 現在是 calling_convention_version

    • lowering_platforms 現在是 platforms