jax.tree_util.register_pytree_with_keys#

jax.tree_util.register_pytree_with_keys(nodetype, flatten_with_keys, unflatten_func, flatten_func=None)[原始碼]#

擴展被視為 pytree 內部節點的類型集合。

這是一個比 register_pytree_node 更強大的替代方案,讓您在展平與樹狀結構對應時,可以存取每個 pytree 葉節點的金鑰路徑。

參數:
  • nodetype (type[T]) – 要視為內部 pytree 節點的 Python 類型。

  • flatten_with_keys (Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]]) – 在展平期間使用的函數,接受 nodetype 類型的值,並傳回一個配對,其中 (1) 是一個可迭代物件,用於包含每個金鑰路徑及其子節點的元組,以及 (2) 一些可雜湊的輔助資料,這些資料將儲存在 treedef 中,並傳遞給 unflatten_func

  • unflatten_func (Callable[[_AuxData, Iterable[Any]], T]) – 接受兩個引數的函數:由 flatten_func 傳回並儲存在 treedef 中的輔助資料,以及未展平的子節點。此函數應傳回 nodetype 的實例。

  • flatten_func (None | Callable[[T], tuple[Iterable[Any], _AuxData]] | None) – 一個與 flatten_with_keys 相似的可選函數,但僅傳回子節點和輔助資料。它必須以與 flatten_with_keys 相同的順序傳回子節點,並傳回相同的輔助資料。此引數為可選,僅在呼叫不帶金鑰的函數(如 tree_maptree_flatten)時,為了更快速的遍歷才需要。

範例

首先,我們將定義一個自訂類型

>>> class MyContainer:
...   def __init__(self, size):
...     self.x = jnp.zeros(size)
...     self.y = jnp.ones(size)
...     self.size = size

現在使用金鑰感知展平函數註冊它

>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
>>> def flatten_with_keys(obj):
...   children = [(GetAttrKey('x'), obj.x),
...               (GetAttrKey('y'), obj.y)]  # children must contain arrays & pytrees
...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
...   return children, aux_data
...
>>> def unflatten(aux_data, children):
...   # Here we avoid `__init__` because it has extra logic we don't require:
...   obj = object.__new__(MyContainer)
...   obj.x, obj.y = children
...   obj.size, = aux_data
...   return obj
...
>>> jax.tree_util.register_pytree_node(MyContainer, flatten_with_keys, unflatten)

現在這可以用於像 tree_flatten_with_path() 這樣的函數

>>> m = MyContainer(4)
>>> leaves, treedef = jax.tree_util.tree_flatten_with_path(m)