jax.tree.all#
- jax.tree.all(tree, *, is_leaf=None)[原始碼]#
在樹狀結構的葉節點上呼叫 all()。
- 參數:
tree (Any) – 要評估的 pytree
is_leaf (Callable[[Any], bool] | None | None) – 可選指定的函數,將在每個扁平化步驟中呼叫。它應該返回一個布林值,指示扁平化是否應遍歷當前物件,或者是否應立即停止,並將整個子樹視為葉節點。
- 返回:
布林值 True 或 False
- 返回類型:
result
範例
>>> import jax >>> jax.tree.all([True, {'a': True, 'b': (True, True)}]) True >>> jax.tree.all([False, (True, False)]) False