jax.tree.unflatten#
- jax.tree.unflatten(treedef, leaves)[原始碼]#
從 treedef 和 leaves 重建 pytree。
tree_flatten()
的反向。- 參數:
treedef (tree_util.PyTreeDef) – 要重建的 treedef
leaves (Iterable[tree_util.Leaf]) – 用於重建的 leaves 的可迭代物件。可迭代物件必須符合 treedef 的 leaves。
- 傳回:
重建的 pytree,包含由
treedef
描述的結構中放置的leaves
。- 傳回類型:
Any
範例
>>> import jax >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) >>> newvals = [100, 200, 300, 400, 500] >>> jax.tree.unflatten(treedef, newvals) [100, (200, 300), [400, 500]]