jax.tree_util.register_pytree_node#
- jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func, flatten_with_keys_func=None)[原始碼]#
擴充 pytree 中被視為內部節點的類型集合。
請參閱範例用法。
- 參數:
nodetype (type[T]) – 要註冊為 pytree 的 Python 類型。
flatten_func (Callable[[T], tuple[_Children, _AuxData]]) – 用於展平期間的函式,接受
nodetype
類型的值並傳回一個配對,其中 (1) 包含要遞迴展平的子項的可迭代物件,以及 (2) 要儲存在 treedef 中並傳遞給unflatten_func
的一些可雜湊輔助資料。unflatten_func (Callable[[_AuxData, _Children], T]) – 接受兩個引數的函式:由
flatten_func
傳回並儲存在 treedef 中的輔助資料,以及未展平的子項。此函式應傳回nodetype
的執行個體。flatten_with_keys_func (Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None | None)
- 傳回類型:
None
另請參閱
register_static()
:用於註冊靜態 pytree 的更簡單 API。register_dataclass()
:用於註冊 dataclass 的更簡單 API。
範例
首先,我們將定義自訂類型
>>> class MyContainer: ... def __init__(self, size): ... self.x = jnp.zeros(size) ... self.y = jnp.ones(size) ... self.size = size
如果我們嘗試在 JIT 編譯的函式中使用此類型,我們會收到錯誤,因為 JAX 尚不知道如何處理此類型
>>> m = MyContainer(size=5) >>> def f(m): ... return m.x + m.y + jnp.arange(m.size) >>> jax.jit(f)(m) Traceback (most recent call last): ... TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute
為了使我們的物件被 JAX 識別,我們必須將其註冊為 pytree
>>> def flatten_func(obj): ... children = (obj.x, obj.y) # children must contain arrays & pytrees ... aux_data = (obj.size,) # aux_data must contain static, hashable data. ... return (children, aux_data) ... >>> def unflatten_func(aux_data, children): ... # Here we avoid `__init__` because it has extra logic we don't require: ... obj = object.__new__(MyContainer) ... obj.x, obj.y = children ... obj.size, = aux_data ... return obj ... >>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)
現在定義完成後,我們可以在 JIT 編譯的函式中使用此類型的執行個體。
>>> jax.jit(f)(m) Array([1., 2., 3., 4., 5.], dtype=float32)