jax.tree.structure#

jax.tree.structure(tree, is_leaf=None)[原始碼]#

取得 pytree 的 treedef。

參數:
  • tree (Any) – 要取得葉節點的 pytree

  • is_leaf (None | Callable[[Any], bool] | None) – 可選指定的函數,將在每個扁平化步驟中呼叫。它應傳回布林值,指示是否應遍歷目前物件的扁平化,或者是否應立即停止,並將整個子樹視為葉節點。

傳回值:

代表樹狀結構的 PyTreeDef。

傳回型別:

pytreedef

範例

>>> import jax
>>> jax.tree.structure([1, (2, 3), [4, 5]])
PyTreeDef([*, (*, *), [*, *]])