jax.tree.structure#
- jax.tree.structure(tree, is_leaf=None)[原始碼]#
取得 pytree 的 treedef。
- 參數:
tree (Any) – 要取得葉節點的 pytree
is_leaf (None | Callable[[Any], bool] | None) – 可選指定的函數,將在每個扁平化步驟中呼叫。它應傳回布林值,指示是否應遍歷目前物件的扁平化,或者是否應立即停止,並將整個子樹視為葉節點。
- 傳回值:
代表樹狀結構的 PyTreeDef。
- 傳回型別:
pytreedef
範例
>>> import jax >>> jax.tree.structure([1, (2, 3), [4, 5]]) PyTreeDef([*, (*, *), [*, *]])