jax.tree_util.tree_unflatten

jax.tree_util.tree_unflatten(treedef, leaves)[source]

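Alias of jax.tree.unflatten(). Reconstructs a pytree from a treedef and a sequence of leaves; it is the inverse of tree_flatten().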

Parameters:
  • treedef (PyTreeDef)

  • leaves (Iterable[Leaf])

Return type:

Any
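As an illustration, here is a minimal round-trip sketch. It flattens a small pytree with jax.tree_util.tree_flatten, transforms the leaves, and rebuilds a tree of the same shape with tree_unflatten. The example tree and the doubling step are arbitrary choices for demonstration only.

```python
import jax

# An arbitrary example pytree: a dict holding a tuple and a scalar.
tree = {"a": (1, 2), "b": 3}

# Flatten into a list of leaves plus a PyTreeDef describing the structure.
# Dict keys are visited in sorted order, so the leaves under "a" come first.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)  # [1, 2, 3]

# Transform the leaves, then rebuild a tree with the original structure.
doubled = [x * 2 for x in leaves]
new_tree = jax.tree_util.tree_unflatten(treedef, doubled)
print(new_tree)  # {'a': (2, 4), 'b': 6}
```

The number of leaves passed in must match treedef.num_leaves; supplying a different count raises an error rather than producing a partial tree.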