jax.tree_util.register_pytree_with_keys_class#

jax.tree_util.register_pytree_with_keys_class(cls)[原始碼]#

擴展了在 pytree 中被視為內部節點的類型集合。

此函式與 register_pytree_node_class 類似,但需要一個類別來定義如何使用金鑰將其展平。

它是 register_pytree_with_keys 的精簡封裝器,並提供面向類別的介面

參數:

cls (Typ) – 要註冊為 pytree 的類型

返回:

輸入類別 cls 在新增到 JAX 的 pytree 登錄後,會保持不變地返回。此傳回值允許將 register_pytree_node_class 用作裝飾器。

返回類型:

Typ

參見

範例

>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
>>> @register_pytree_with_keys_class
... class Special:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten_with_keys(self):
...     return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)