jax.tree.unflatten#

jax.tree.unflatten(treedef, leaves)[原始碼]#

從 treedef 和 leaves 重建 pytree。

tree_flatten() 的反向。

參數:
  • treedef (tree_util.PyTreeDef) – 要重建的 treedef

  • leaves (Iterable[tree_util.Leaf]) – 用於重建的 leaves 的可迭代物件。可迭代物件必須符合 treedef 的 leaves。

傳回:

重建的 pytree,包含由 treedef 描述的結構中放置的 leaves

傳回類型:

Any

範例

>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> newvals = [100, 200, 300, 400, 500]
>>> jax.tree.unflatten(treedef, newvals)
[100, (200, 300), [400, 500]]