jax.tree_util.register_pytree_node_class#

jax.tree_util.register_pytree_node_class(cls)[原始碼]#

擴展在 pytree 中被視為內部節點的類型集合。

此函式是 register_pytree_node 的精簡包裝器,並提供面向類別的介面。

參數:

cls (Typ) – 要註冊為 pytree 的類型

回傳:

輸入類別 cls 在新增到 JAX 的 pytree 登錄表後會保持不變並回傳。此回傳值允許 register_pytree_node_class 作為裝飾器使用。

回傳類型:

Typ

另請參閱

範例

在此我們將定義一個自訂容器,它將與 jax.jit() 和其他 JAX 轉換相容

>>> import jax
>>> @jax.tree_util.register_pytree_node_class
... class MyContainer:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten(self):
...     return ((self.x, self.y), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
>>> def f(m):
...   return m.x + 2 * m.y
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)