jax.tree_util.register_pytree_with_keys_class#
- jax.tree_util.register_pytree_with_keys_class(cls)[原始碼]#
擴展了在 pytree 中被視為內部節點的類型集合。
此函式與
register_pytree_node_class
類似,但需要一個類別來定義如何使用金鑰將其展平。它是
register_pytree_with_keys
的精簡封裝器,並提供面向類別的介面- 參數:
cls (Typ) – 要註冊為 pytree 的類型
- 返回:
輸入類別
cls
在新增到 JAX 的 pytree 登錄後,會保持不變地返回。此傳回值允許將register_pytree_node_class
用作裝飾器。- 返回類型:
Typ
參見
register_static()
:用於註冊靜態 pytree 的更簡單 API。register_dataclass()
:用於註冊 dataclass 的更簡單 API。
範例
>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey >>> @register_pytree_with_keys_class ... class Special: ... def __init__(self, x, y): ... self.x = x ... self.y = y ... def tree_flatten_with_keys(self): ... return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None) ... @classmethod ... def tree_unflatten(cls, aux_data, children): ... return cls(*children)