jax.tree_util.register_pytree_node_class#
- jax.tree_util.register_pytree_node_class(cls)[原始碼]#
擴展在 pytree 中被視為內部節點的類型集合。
此函式是
register_pytree_node
的精簡包裝器,並提供面向類別的介面。- 參數:
cls (Typ) – 要註冊為 pytree 的類型
- 回傳:
輸入類別
cls
在新增到 JAX 的 pytree 登錄表後會保持不變並回傳。此回傳值允許register_pytree_node_class
作為裝飾器使用。- 回傳類型:
Typ
另請參閱
register_static()
:用於註冊靜態 pytree 的更簡單 API。register_dataclass()
:用於註冊 dataclass 的更簡單 API。
範例
在此我們將定義一個自訂容器,它將與
jax.jit()
和其他 JAX 轉換相容>>> import jax >>> @jax.tree_util.register_pytree_node_class ... class MyContainer: ... def __init__(self, x, y): ... self.x = x ... self.y = y ... def tree_flatten(self): ... return ((self.x, self.y), None) ... @classmethod ... def tree_unflatten(cls, aux_data, children): ... return cls(*children) ... >>> m = MyContainer(jnp.zeros(4), jnp.arange(4)) >>> def f(m): ... return m.x + 2 * m.y >>> jax.jit(f)(m) Array([0., 2., 4., 6.], dtype=float32)