jax.tree_util.all_leaves#

jax.tree_util.all_leaves(iterable, is_leaf=None)[原始碼]#

測試給定可迭代物件中的所有元素是否都是葉節點。

此函式在進階情況下很有用,例如,如果一個函式庫允許對葉節點的扁平可迭代物件進行任意的映射操作,它可能想要檢查結果是否仍然是葉節點的扁平可迭代物件。

參數:
  • iterable (Iterable[Any]) – 葉節點的可迭代物件。

  • is_leaf (Callable[[Any], bool] | None | None)

返回:

一個布林值,指示輸入中的所有元素是否為葉節點。

返回類型:

bool

範例

>>> import jax
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
>>> assert not all_leaves([tree])