jax.export 模組#

jax.export 是一個用於匯出和序列化 JAX 函式以進行持久封存的庫。

請參閱匯出與序列化文件。

類別#

class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)[source]#

一個降低到 StableHLO 的 JAX 函式。

參數:
  • fun_name (str)

  • in_tree (tree_util.PyTreeDef)

  • in_avals (tuple[core.ShapedArray, ...])

  • out_tree (tree_util.PyTreeDef)

  • out_avals (tuple[core.ShapedArray, ...])

  • in_shardings_hlo (tuple[HloSharding | None, ...])

  • out_shardings_hlo (tuple[HloSharding | None, ...])

  • nr_devices (int)

  • platforms (tuple[str, ...])

  • ordered_effects (tuple[effects.Effect, ...])

  • unordered_effects (tuple[effects.Effect, ...])

  • disabled_safety_checks (Sequence[DisabledSafetyCheck])

  • mlir_module_serialized (bytes)

  • calling_convention_version (int)

  • module_kept_var_idx (tuple[int, ...])

  • uses_global_constants (bool)

  • _get_vjp (Callable[[Exported], Exported] | None)

fun_name#

匯出函式的名稱,用於錯誤訊息。

型別:

str

in_tree#

一個 PyTreeDef,描述降低後的 JAX 函式的元組 (args, kwargs)。實際的降低並不依賴 in_tree,但這可以用於使用相同的引數結構調用匯出的函式。

型別:

tree_util.PyTreeDef

in_avals#

輸入抽象值的扁平元組。形狀中可能包含維度表達式。

型別:

tuple[core.ShapedArray, …]

out_tree#

一個 PyTreeDef,描述降低後的 JAX 函式的結果。

型別:

tree_util.PyTreeDef

out_avals#

輸出抽象值的扁平元組。形狀中可能包含維度表達式,維度變數與 in_avals 中的維度變數相同。

型別:

tuple[core.ShapedArray, …]

in_shardings_hlo#

扁平化的輸入分片,一個與 in_avals 長度相同的序列。None 表示未指定分片。請注意,這些不包含網格或網格中使用的實際裝置。請參閱 in_shardings_jax,以了解如何將這些轉換為可用於 JAX API 的分片規範。

型別:

tuple[HloSharding | None, …]

out_shardings_hlo#

扁平化的輸出分片,一個與 out_avals 長度相同的序列。None 表示未指定分片。請注意,這些不包含網格或網格中使用的實際裝置。請參閱 out_shardings_jax,以了解如何將這些轉換為可用於 JAX API 的分片規範。

型別:

tuple[HloSharding | None, …]

nr_devices#

模組已降低的裝置數量。

型別:

int

platforms#

一個元組,包含應匯出函式的平台。JAX 中的平台集合是開放式的;使用者可以新增平台。JAX 內建平台為:'tpu'、'cpu'、'cuda'、'rocm'。請參閱 https://jax.dev.org.tw/en/latest/export/export.html#cross-platform-and-multi-platform-export

型別:

tuple[str, …]

ordered_effects#

序列化模組中存在的有序效應。從序列化版本 9 開始存在。有關存在有序效應時的調用慣例,請參閱 https://jax.dev.org.tw/en/latest/export/export.html#module-calling-convention

型別:

tuple[effects.Effect, …]

unordered_effects#

序列化模組中存在的無序效應。從序列化版本 9 開始存在。

型別:

tuple[effects.Effect, …]

mlir_module_serialized#

序列化降低的 VHLO 模組。

型別:

bytes

calling_convention_version#

匯出模組的調用慣例的版本號碼。有關更多版本控制詳細資訊,請參閱 https://jax.dev.org.tw/en/latest/export/export.html#calling-convention-versions

型別:

int

module_kept_var_idx#

必須傳遞到模組的 in_avals 引數中的已排序索引。其他引數已被捨棄,因為它們未使用。

型別:

tuple[int, …]

uses_global_constants#

mlir_module_serialized 是否使用形狀多型或多平台匯出。這可能是因為 in_avals 包含維度變數,或由於具有維度變數或平台索引引數的 Exported 模組的內部調用。此類模組在 XLA 編譯之前需要形狀精化。

型別:

bool

disabled_safety_checks#

在匯出時已停用的安全檢查描述符列表。請參閱 DisabledSafetyCheck 的 docstring。

型別:

Sequence[DisabledSafetyCheck]

_get_vjp#

一個可選函式,它接受當前匯出的函式並返回匯出的 VJP 函式。VJP 函式接受一個扁平的引數列表,從原始引數開始,然後是每個原始輸出的餘切引數。它返回一個元組,其中包含對應於扁平化原始輸入的餘切。

型別:

Callable[[Exported], Exported] | None

請參閱 [mlir_module 的調用慣例的描述](https://jax.dev.org.tw/en/latest/export/export.html#module-calling-convention)。

call(*args, **kwargs)[source]#

從 JAX 程式呼叫匯出的函式。

參數:
  • args – 要傳遞給匯出函式的位置引數。這應該是一個陣列的 pytree,其 pytree 結構與匯出函式的引數相同。

  • kwargs – 要傳遞給匯出函式的關鍵字引數。

傳回值:結果陣列的 pytree,其結構與

匯出函式的結果相同。

調用支援反向模式 AD,以及匯出支援的所有功能:形狀多型、多平台、裝置多型。請參閱 [JAX 匯出文件](https://jax.dev.org.tw/en/latest/export/export.html) 中的範例。

has_vjp()[source]#

如果此 Exported 支援 VJP,則傳回 true。

傳回型別:

bool

in_shardings_jax(mesh)[source]#

建立對應於 self.in_shardings_hlo 的 Sharding。

Exported 物件將 in_shardings_hlo 儲存為 HloSharding,它們獨立於網格或裝置集。此方法建構可用於 JAX API (例如 jax.jitjax.device_put) 的 Sharding。

範例用法

>>> from jax import export
>>> # Prepare the exported object:
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
...                             in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
...     )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
...     exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
 Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
 Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
 Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
 Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
 Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
 Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
 Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
參數:

mesh (sharding.Mesh)

傳回型別:

Sequence[sharding.Sharding | None]

mlir_module()[source]#

mlir_module_serialized 的字串表示形式。

傳回型別:

str

out_shardings_jax(mesh)[source]#

建立對應於 self.out_shardings_hlo 的 Sharding。

請參閱 in_shardings_jax 的文件。

參數:

mesh (sharding.Mesh)

傳回型別:

Sequence[sharding.Sharding | None]

serialize(vjp_order=0)[source]#

序列化 Exported。

參數:

vjp_order (int) – 要包含的最大 vjp 階數。例如,值 2 表示我們序列化原始函式和 vjp 函式的兩個階數。這應該允許反序列化函式的二階反向模式微分。即 jax.grad(jax.grad(f)).

傳回型別:

bytearray

vjp()[source]#

取得匯出的 VJP。

如果不可用則傳回 None,如果 Exported 是從沒有 VJP 的外部格式載入,則可能會發生這種情況。

傳回型別:

已匯出

class jax.export.DisabledSafetyCheck(_impl)[source]#

應在(反)序列化時跳過的安全檢查。

這些檢查大多數在序列化時執行,但有些會延遲到反序列化。停用檢查的列表會附加到序列化,例如,作為 jax.export.Exportedtf.XlaCallModuleOp 的字串屬性序列。

使用 jax2tf 時,您可以通過傳遞 TF_XLA_FLAGS=–tf_xla_call_module_disabled_checks=platform 來停用更多反序列化安全檢查。

參數:

_impl (str)

classmethod custom_call(target_name)[source]#

允許序列化已知不穩定的呼叫目標。

僅在序列化時生效。:param target_name: 要允許的自訂呼叫目標的名稱。

參數:

target_name (str)

傳回型別:

DisabledSafetyCheck

is_custom_call()[source]#

傳回此指令允許的自訂呼叫目標。

傳回型別:

str | None

classmethod platform()[source]#

允許編譯平台與匯出平台不同。

僅在反序列化時生效。

傳回型別:

DisabledSafetyCheck

函數#

export(fun_jit, *[, platforms, disabled_checks])

匯出 JAX 函數以進行持久性序列化。

deserialize(blob)

反序列化一個 Exported 物件。

minimum_supported_calling_convention_version

int([x]) -> integer int(x, base=10) -> integer

maximum_supported_calling_convention_version

int([x]) -> integer int(x, base=10) -> integer

default_export_platform()

檢索預設匯出平台。

register_pytree_node_serialization(nodetype, ...)

註冊自訂 PyTree 節點以進行序列化和反序列化。

register_namedtuple_serialization(nodetype, ...)

註冊 namedtuple 以進行序列化和反序列化。

常數#

jax.export.minimum_supported_serialization_version#

最小支援的序列化版本;請參閱調用慣例版本

jax.export.maximum_supported_serialization_version#

最大支援的序列化版本;請參閱調用慣例版本