jax.tree.leaves_with_path#

jax.tree.leaves_with_path(tree, is_leaf=None)[原始碼]#

取得 pytree 的分葉,如 tree_leaves,並回傳每個分葉的金鑰路徑。

參數:
  • tree (Any) – 一個 pytree。如果它包含自訂類型,建議使用 register_pytree_with_keys 註冊。

  • is_leaf (Callable[[Any], bool] | None | None)

回傳:

金鑰-分葉配對的列表,每個配對包含一個分葉及其金鑰路徑。

回傳類型:

list[tuple[tree_util.KeyPath, Any]]

範例

>>> import jax
>>> jax.tree.leaves_with_path([1, {'x': 3}])
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]