jax.tree_util 模組#

用於處理樹狀容器資料結構的工具。

此模組提供一小組實用函式,用於處理樹狀資料結構,例如巢狀元組、列表和字典。我們稱這些結構為 pytrees。它們之所以稱為樹狀結構,是因為它們是遞迴定義的(任何非 pytree 都是 pytree,即葉節點,而任何 pytree 的 pytree 都是 pytree),並且可以遞迴操作(物件識別等價性不會因映射操作而保留,且結構不能包含參考循環)。

被視為 pytree 節點(例如,可以映射而不是視為葉節點)的 Python 類型集合是可擴展的。有一個單一模組級別的類型註冊表,且類別層次結構會被忽略。透過註冊新的 pytree 節點類型,該類型實際上會對此檔案中的實用函式變得透明。

此模組的主要目的是啟用使用者定義的資料結構和 JAX 轉換(例如 jit)之間的互操作性。這並非旨在成為通用的樹狀資料結構處理函式庫。

有關範例,請參閱 JAX pytrees 筆記

函式列表#

Partial(func, *args, **kw)

functools.partial 的一個版本,可在 pytrees 中運作。

all_leaves(iterable[, is_leaf])

測試給定可迭代物件中的所有元素是否都是葉節點。

build_tree(treedef, xs)

從巢狀可迭代結構建構 treedef

register_dataclass(nodetype[, data_fields, ...])

擴展在 pytrees 中被視為內部節點的類型集合。

register_pytree_node(nodetype, flatten_func, ...)

擴展在 pytrees 中被視為內部節點的類型集合。

register_pytree_node_class(cls)

擴展在 pytrees 中被視為內部節點的類型集合。

register_pytree_with_keys(nodetype, ...[, ...])

擴展在 pytrees 中被視為內部節點的類型集合。

register_pytree_with_keys_class(cls)

擴展在 pytrees 中被視為內部節點的類型集合。

register_static(cls)

cls 註冊為沒有葉節點的 pytree。

tree_flatten_with_path(tree[, is_leaf])

jax.tree.flatten_with_path() 的別名。

tree_leaves_with_path(tree[, is_leaf])

jax.tree.leaves_with_path() 的別名。

tree_map_with_path(f, tree, *rest[, is_leaf])

jax.tree.map_with_path() 的別名。

treedef_children(treedef)

傳回直接子節點的 treedef 列表

treedef_is_leaf(treedef)

如果 treedef 代表葉節點,則傳回 True。

treedef_tuple(treedefs)

從子 treedef 的可迭代物件建立元組 treedef。

KeyEntry

類型變數。

KeyPath

tuple[KeyEntry, ...] 的別名

keystr(keys)

用於美觀列印金鑰元組的輔助函式。

舊版 API#

這些 API 現在透過 jax.tree 存取。

tree_all(tree, *[, is_leaf])

jax.tree.all() 的別名。

tree_flatten(tree[, is_leaf])

jax.tree.flatten() 的別名。

tree_leaves(tree[, is_leaf])

jax.tree.leaves() 的別名。

tree_map(f, tree, *rest[, is_leaf])

jax.tree.map() 的別名。

tree_reduce(function, tree[, initializer, ...])

jax.tree.reduce() 的別名。

tree_structure(tree[, is_leaf])

jax.tree.structure() 的別名。

tree_transpose(outer_treedef, inner_treedef, ...)

jax.tree.transpose() 的別名。

tree_unflatten(treedef, leaves)

jax.tree.unflatten() 的別名。