jax.tree.leaves#

jax.tree.leaves(tree, is_leaf=None)[原始碼]#

取得 pytree 的葉節點。

參數:
  • tree (Any) – 要取得葉節點的 pytree

  • is_leaf (Callable[[Any], bool] | None | None) – 可選指定的函式,將在每個扁平化步驟中呼叫。它應傳回布林值,指示是否應遍歷目前的物件,或者是否應立即停止,並將整個子樹視為葉節點。

傳回:

樹葉節點的清單。

傳回類型:

leaves

範例

>>> import jax
>>> jax.tree.leaves([1, (2, 3), [4, 5]])
[1, 2, 3, 4, 5]