jax.tree_util.register_static#

jax.tree_util.register_static(cls)[原始碼]#

cls 註冊為沒有 leaves 的 pytree。

實例會被 jax.jit()jax.pmap() 等視為靜態。這可以替代使用 jitstatic_argnumsstatic_argnames kwargs、pmapstatic_broadcasted_argnums 等將輸入標記為靜態。

參數:

cls (type[H]) – 要註冊為靜態的類型。必須是可雜湊的,如 https://docs.python.org/3/glossary.html#term-hashable 中所定義。

返回:

輸入類別 cls 在新增至 JAX 的 pytree 註冊表後,會保持不變地返回。這允許將 register_static 用作裝飾器。

返回類型:

type[H]

範例

>>> import jax
>>> @jax.tree_util.register_static
... class StaticStr(str):
...   pass

此靜態字串現在可以直接在 jax.jit() 編譯的函式中使用,而無需使用 static_argnums 標記變數為靜態

>>> @jax.jit
... def f(x, y, s):
...   return x + y if s == 'add' else x - y
...
>>> f(1, 2, StaticStr('add'))
Array(3, dtype=int32, weak_type=True)