jax.export.export#
- jax.export.export(fun_jit, *, platforms=None, disabled_checks=())[原始碼]#
匯出 JAX 函數以進行持久序列化。
- 參數:
fun_jit (stages.Wrapped) – 要匯出的函數。應為 jax.jit 的結果。
platforms (Sequence[str] | None | None) – 包含 'tpu'、'cpu'、'cuda'、'rocm' 子集的選用序列。如果指定多個平台,則匯出的程式碼會採用指定平台的引數。如果為 None,則使用預設 JAX 後端。多個平台的呼叫慣例在 https://jax.dev.org.tw/en/latest/export/export.html#module-calling-convention 中說明。
disabled_checks (Sequence[DisabledSafetyCheck]) – 要停用的安全檢查。請參閱 jax.export.DisabledSafetyCheck 的文件。
- 傳回:
一個函數,它接受 {class}`jax.ShapeDtypeStruct` 的 args 和 kwargs pytrees,或具有 .shape 和 .dtype 屬性的值,並傳回 Exported。
- 傳回類型:
Callable[…, Exported]
用法
>>> from jax import export >>> exported: export.Exported = export.export(jnp.sin)( ... np.arange(4, dtype=np.float32)) >>> >>> # You can inspect the Exported object >>> exported.in_avals (ShapedArray(float32[4]),) >>> blob: bytearray = exported.serialize() >>> >>> # The serialized bytes are safe to use in a separate process >>> rehydrated: export.Exported = export.deserialize(blob) >>> rehydrated.fun_name 'sin' >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32)) Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)