jax.tree_util.register_pytree_node#

jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func, flatten_with_keys_func=None)[原始碼]#

擴充 pytree 中被視為內部節點的類型集合。

請參閱範例用法

參數:
  • nodetype (type[T]) – 要註冊為 pytree 的 Python 類型。

  • flatten_func (Callable[[T], tuple[_Children, _AuxData]]) – 用於展平期間的函式,接受 nodetype 類型的值並傳回一個配對,其中 (1) 包含要遞迴展平的子項的可迭代物件,以及 (2) 要儲存在 treedef 中並傳遞給 unflatten_func 的一些可雜湊輔助資料。

  • unflatten_func (Callable[[_AuxData, _Children], T]) – 接受兩個引數的函式:由 flatten_func 傳回並儲存在 treedef 中的輔助資料,以及未展平的子項。此函式應傳回 nodetype 的執行個體。

  • flatten_with_keys_func (Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None | None)

傳回類型:

None

另請參閱

範例

首先,我們將定義自訂類型

>>> class MyContainer:
...   def __init__(self, size):
...     self.x = jnp.zeros(size)
...     self.y = jnp.ones(size)
...     self.size = size

如果我們嘗試在 JIT 編譯的函式中使用此類型,我們會收到錯誤,因為 JAX 尚不知道如何處理此類型

>>> m = MyContainer(size=5)
>>> def f(m):
...   return m.x + m.y + jnp.arange(m.size)
>>> jax.jit(f)(m)  
Traceback (most recent call last):
  ...
TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute

為了使我們的物件被 JAX 識別,我們必須將其註冊為 pytree

>>> def flatten_func(obj):
...   children = (obj.x, obj.y)  # children must contain arrays & pytrees
...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
...   return (children, aux_data)
...
>>> def unflatten_func(aux_data, children):
...   # Here we avoid `__init__` because it has extra logic we don't require:
...   obj = object.__new__(MyContainer)
...   obj.x, obj.y = children
...   obj.size, = aux_data
...   return obj
...
>>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)

現在定義完成後,我們可以在 JIT 編譯的函式中使用此類型的執行個體。

>>> jax.jit(f)(m)
Array([1., 2., 3., 4., 5.], dtype=float32)