jax.make_jaxpr#

jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = False, abstracted_axes: Any | None = None) Callable[..., core.ClosedJaxpr][原始碼]#
jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = False, abstracted_axes: Any | None = None) Callable[..., tuple[core.ClosedJaxpr, Any]]

建立一個函式,根據範例引數產生其 jaxpr。

參數:
  • fun – 要計算其 jaxpr 的函式。其位置引數和傳回值應為陣列、純量或它們的標準 Python 容器(tuple/list/dict)。

  • static_argnums – 請參閱 jax.jit() 文件字串。

  • axis_env – 選擇性,一連串的配對,其中第一個元素是軸名稱,第二個元素是代表具有該名稱的對應軸大小的正整數。當降低涉及平行通訊集合的函式時,此參數很有用,並且它指定了軸名稱/大小環境,該環境將由 jax.pmap() 的應用程式設定。

  • return_shape – 選擇性布林值,預設為 False。如果為 True,則包裝後的函式會傳回一個配對,其中第一個元素是 funClosedJaxpr 表示法,第二個元素是 pytree,其結構與 fun 的輸出相同,且葉節點是具有 shapedtype 屬性的物件,代表輸出葉節點的對應型別。

傳回值:

應用於範例引數時,會傳回 fun 在這些引數上的 ClosedJaxpr 表示法的 fun 包裝版本。如果引數 return_shapeTrue,則傳回的函式會改為傳回一個配對,其中第一個元素是 funClosedJaxpr 表示法,第二個元素是 pytree,代表 fun 輸出的結構、形狀、dtypes 和具名形狀。

jaxpr 是 JAX 用於程式追蹤的中繼表示法。jaxpr 語言基於具有 let 綁定的簡單型別一階 lambda 演算。make_jaxpr() 改編一個函式以傳回其 jaxpr,我們可以檢查它以了解 JAX 在內部執行的操作。傳回的 jaxpr 是抽象化為 ShapedArray 層級的 fun 追蹤。內部存在其他抽象層級。

我們在此不詳細描述 jaxpr 語言的語意,而是提供一些範例。

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0 d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }