jax.tree_util.treedef_is_leaf#

jax.tree_util.treedef_is_leaf(treedef)[原始碼]#

如果 treedef 代表葉節點,則傳回 True。

參數:

treedef (PyTreeDef) – 要檢查的樹狀結構

傳回:

如果 treedef 是葉節點(即具有單一節點),則為 True;否則為 False。

傳回類型:

bool

範例

>>> import jax
>>> tree1 = jax.tree.structure(1)
>>> jax.tree_util.treedef_is_leaf(tree1)
True
>>> tree2 = jax.tree.structure([1, 2])
>>> jax.tree_util.treedef_is_leaf(tree2)
False