jax.tree_util.treedef_children#

jax.tree_util.treedef_children(treedef)[source]#

返回立即子項的 treedefs 列表

參數:

treedef (PyTreeDef) – 單一 PyTreeDef

返回值:

代表 treedef 子項的 PyTreeDefs 列表。

返回型別:

list[PyTreeDef]

範例

>>> import jax
>>> x = [(1, 2), 3, {'a': 4}]
>>> treedef = jax.tree.structure(x)
>>> jax.tree_util.treedef_children(treedef)
[PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
>>> _ == [jax.tree.structure(vals) for vals in x]
True