jax.tree 模組#

用於處理樹狀容器資料結構的工具。

jax.tree 命名空間包含來自 jax.tree_util 的工具別名。

函式列表#

all(tree, *[, is_leaf])

對樹的葉節點呼叫 all()。

flatten(tree[, is_leaf])

展平一個 pytree。

flatten_with_path(tree[, is_leaf])

展平一個 pytree,類似 tree_flatten,但也傳回每個葉節點的索引路徑。

leaves(tree[, is_leaf])

取得 pytree 的葉節點。

leaves_with_path(tree[, is_leaf])

取得 pytree 的葉節點,類似 tree_leaves,並傳回每個葉節點的索引路徑。

map(f, tree, *rest[, is_leaf])

將多輸入函式映射到 pytree 參數上,以產生新的 pytree。

map_with_path(f, tree, *rest[, is_leaf])

將多輸入函式映射到 pytree 索引路徑和參數上,以產生新的 pytree。

reduce()

對樹的葉節點呼叫 reduce()。

structure(tree[, is_leaf])

取得 pytree 的 treedef。

transpose(outer_treedef, inner_treedef, ...)

將具有樹狀結構 (outer, inner) 的樹轉換為具有結構 (inner, outer) 的樹。

unflatten(treedef, leaves)

從 treedef 和葉節點重建 pytree。