jax.tree_util
模組
用於處理樹狀容器資料結構的工具。
此模組提供一小組實用函式,用於處理樹狀資料結構,例如巢狀元組、列表和字典。我們稱這些結構為 pytrees。它們之所以稱為樹狀結構,是因為它們是遞迴定義的(任何非 pytree 都是 pytree,即葉節點,而任何 pytree 的 pytree 都是 pytree),並且可以遞迴操作(物件識別等價性不會因映射操作而保留,且結構不能包含參考循環)。
被視為 pytree 節點(例如,可以映射而不是視為葉節點)的 Python 類型集合是可擴展的。有一個單一模組級別的類型註冊表,且類別層次結構會被忽略。透過註冊新的 pytree 節點類型,該類型實際上會對此檔案中的實用函式變得透明。
此模組的主要目的是啟用使用者定義的資料結構和 JAX 轉換(例如 jit)之間的互操作性。這並非旨在成為通用的樹狀資料結構處理函式庫。
有關範例,請參閱 JAX pytrees 筆記。
函式列表
Partial (func, *args, **kw)
|
functools.partial 的一個版本,可在 pytrees 中運作。 |
all_leaves (iterable[, is_leaf])
|
測試給定可迭代物件中的所有元素是否都是葉節點。 |
build_tree (treedef, xs)
|
從巢狀可迭代結構建構 treedef |
register_dataclass (nodetype[, data_fields, ...])
|
擴展在 pytrees 中被視為內部節點的類型集合。 |
register_pytree_node (nodetype, flatten_func, ...)
|
擴展在 pytrees 中被視為內部節點的類型集合。 |
register_pytree_node_class (cls)
|
擴展在 pytrees 中被視為內部節點的類型集合。 |
register_pytree_with_keys (nodetype, ...[, ...])
|
擴展在 pytrees 中被視為內部節點的類型集合。 |
register_pytree_with_keys_class (cls)
|
擴展在 pytrees 中被視為內部節點的類型集合。 |
register_static (cls)
|
將 cls 註冊為沒有葉節點的 pytree。 |
tree_flatten_with_path (tree[, is_leaf])
|
jax.tree.flatten_with_path() 的別名。
|
tree_leaves_with_path (tree[, is_leaf])
|
jax.tree.leaves_with_path() 的別名。
|
tree_map_with_path (f, tree, *rest[, is_leaf])
|
jax.tree.map_with_path() 的別名。
|
treedef_children (treedef)
|
傳回直接子節點的 treedef 列表 |
treedef_is_leaf (treedef)
|
如果 treedef 代表葉節點,則傳回 True。 |
treedef_tuple (treedefs)
|
從子 treedef 的可迭代物件建立元組 treedef。 |
KeyEntry
|
類型變數。 |
KeyPath
|
tuple [KeyEntry , ...] 的別名
|
keystr (keys)
|
用於美觀列印金鑰元組的輔助函式。 |
舊版 API
這些 API 現在透過 jax.tree
存取。
tree_all (tree, *[, is_leaf])
|
jax.tree.all() 的別名。
|
tree_flatten (tree[, is_leaf])
|
jax.tree.flatten() 的別名。
|
tree_leaves (tree[, is_leaf])
|
jax.tree.leaves() 的別名。
|
tree_map (f, tree, *rest[, is_leaf])
|
jax.tree.map() 的別名。
|
tree_reduce (function, tree[, initializer, ...])
|
jax.tree.reduce() 的別名。
|
tree_structure (tree[, is_leaf])
|
jax.tree.structure() 的別名。
|
tree_transpose (outer_treedef, inner_treedef, ...)
|
jax.tree.transpose() 的別名。
|
tree_unflatten (treedef, leaves)
|
jax.tree.unflatten() 的別名。
|