jax.tree_util.tree_unflatten

jax.tree_util.tree_unflatten(treedef, leaves)

Alias of jax.tree.unflatten().

Parameters:
  treedef (PyTreeDef)
  leaves (Iterable[Leaf])

Return type: Any
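A minimal round-trip sketch, assuming a JAX version in which both jax.tree_util and the jax.tree module are available. It pairs tree_unflatten with jax.tree_util.tree_flatten, which produces the leaves and the PyTreeDef that tree_unflatten consumes. The nested dictionary and the doubling step are illustrative only.

```python
import jax

# An illustrative nested container (pytree): a dict holding a tuple and a list.
tree = {"a": (1, 2), "b": [3, {"c": 4}]}

# tree_flatten returns the leaves in a deterministic order (dict keys sorted)
# plus a PyTreeDef that records the container structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)  # [1, 2, 3, 4]

# The PyTreeDef knows how many leaves it expects.
print(treedef.num_leaves)  # 4

# Transform the flat list of leaves (here: double each one).
new_leaves = [2 * x for x in leaves]

# tree_unflatten rebuilds a pytree with the original structure from new leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, new_leaves)
print(rebuilt)  # {'a': (2, 4), 'b': [6, {'c': 8}]}

# The documented alias, jax.tree.unflatten, gives the same result.
assert jax.tree.unflatten(treedef, new_leaves) == rebuilt
```

Because the structure (treedef) and the data (leaves) are kept separate, any function that maps the flat list to a new list of the same length can run between the two calls; jax.tree_util.tree_map packages this pattern for the common element-wise case.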