jax.tree_util.treedef_is_leaf#
- jax.tree_util.treedef_is_leaf(treedef)[原始碼]#
如果 treedef 代表葉節點,則傳回 True。
- 參數:
treedef (PyTreeDef) – 要檢查的樹狀結構
- 傳回:
如果 treedef 是葉節點(即具有單一節點),則為 True;否則為 False。
- 傳回類型:
範例
>>> import jax >>> tree1 = jax.tree.structure(1) >>> jax.tree_util.treedef_is_leaf(tree1) True >>> tree2 = jax.tree.structure([1, 2]) >>> jax.tree_util.treedef_is_leaf(tree2) False