jax.tree.leaves#
- jax.tree.leaves(tree, is_leaf=None)[原始碼]#
取得 pytree 的葉節點。
- 參數:
tree (Any) – 要取得葉節點的 pytree
is_leaf (Callable[[Any], bool] | None | None) – 可選指定的函式,將在每個扁平化步驟中呼叫。它應傳回布林值,指示是否應遍歷目前的物件,或者是否應立即停止,並將整個子樹視為葉節點。
- 傳回:
樹葉節點的清單。
- 傳回類型:
leaves
範例
>>> import jax >>> jax.tree.leaves([1, (2, 3), [4, 5]]) [1, 2, 3, 4, 5]