jax.tree_util.tree_transpose
- jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
Alias of jax.tree.transpose().
- Parameters:
outer_treedef (PyTreeDef)
inner_treedef (PyTreeDef | None)
pytree_to_transpose (Any)
- Return type:
Any
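The following is a minimal usage sketch. It assumes a recent JAX release and uses `jax.tree_util.tree_structure` to build both treedefs. The function takes a pytree whose structure is `outer_treedef` composed with `inner_treedef`, here a list of two dicts, and swaps the nesting so the result is a dict of lists. The input values are illustrative only.

```python
from jax import tree_util

# Input: an outer list of length 2, where each element is an inner dict with keys 'a' and 'b'.
pytree = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

# Treedefs that describe the outer (list) and inner (dict) structures.
# The leaf values (0) are placeholders. Only the structure matters.
outer_treedef = tree_util.tree_structure([0, 0])
inner_treedef = tree_util.tree_structure({"a": 0, "b": 0})

# Swap the nesting: the list of dicts becomes a dict of lists.
transposed = tree_util.tree_transpose(outer_treedef, inner_treedef, pytree)
print(transposed)  # {'a': [1, 3], 'b': [2, 4]}
```

Because this function is an alias, calling `jax.tree.transpose(outer_treedef, inner_treedef, pytree)` should produce the same result.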