jax.make_jaxpr#
- jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = False, abstracted_axes: Any | None = None) Callable[..., core.ClosedJaxpr] [原始碼]#
- jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = False, abstracted_axes: Any | None = None) Callable[..., tuple[core.ClosedJaxpr, Any]]
建立一個函式,根據範例引數產生其 jaxpr。
- 參數:
fun – 要計算其
jaxpr
的函式。其位置引數和傳回值應為陣列、純量或它們的標準 Python 容器(tuple/list/dict)。static_argnums – 請參閱
jax.jit()
文件字串。axis_env – 選擇性,一連串的配對,其中第一個元素是軸名稱,第二個元素是代表具有該名稱的對應軸大小的正整數。當降低涉及平行通訊集合的函式時,此參數很有用,並且它指定了軸名稱/大小環境,該環境將由
jax.pmap()
的應用程式設定。return_shape – 選擇性布林值,預設為
False
。如果為True
,則包裝後的函式會傳回一個配對,其中第一個元素是fun
的ClosedJaxpr
表示法,第二個元素是 pytree,其結構與fun
的輸出相同,且葉節點是具有shape
和dtype
屬性的物件,代表輸出葉節點的對應型別。
- 傳回值:
應用於範例引數時,會傳回
fun
在這些引數上的ClosedJaxpr
表示法的fun
包裝版本。如果引數return_shape
為True
,則傳回的函式會改為傳回一個配對,其中第一個元素是fun
的ClosedJaxpr
表示法,第二個元素是 pytree,代表fun
輸出的結構、形狀、dtypes 和具名形狀。
jaxpr
是 JAX 用於程式追蹤的中繼表示法。jaxpr
語言基於具有 let 綁定的簡單型別一階 lambda 演算。make_jaxpr()
改編一個函式以傳回其jaxpr
,我們可以檢查它以了解 JAX 在內部執行的操作。傳回的jaxpr
是抽象化為ShapedArray
層級的fun
追蹤。內部存在其他抽象層級。我們在此不詳細描述
jaxpr
語言的語意,而是提供一些範例。>>> import jax >>> >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x)) >>> print(f(3.0)) -0.83602 >>> jax.make_jaxpr(f)(3.0) { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) } >>> jax.make_jaxpr(jax.grad(f))(3.0) { lambda ; a:f32[]. let b:f32[] = cos a c:f32[] = sin a _:f32[] = sin b d:f32[] = cos b e:f32[] = mul 1.0 d f:f32[] = neg e g:f32[] = mul f c in (g,) }