jax.tree_util.treedef_children#
- jax.tree_util.treedef_children(treedef)[source]#
返回立即子項的 treedefs 列表
- 參數:
treedef (PyTreeDef) – 單一 PyTreeDef
- 返回值:
代表 treedef 子項的 PyTreeDefs 列表。
- 返回型別:
list[PyTreeDef]
範例
>>> import jax >>> x = [(1, 2), 3, {'a': 4}] >>> treedef = jax.tree.structure(x) >>> jax.tree_util.treedef_children(treedef) [PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})] >>> _ == [jax.tree.structure(vals) for vals in x] True