jax.tree.leaves_with_path#
- jax.tree.leaves_with_path(tree, is_leaf=None)[原始碼]#
取得 pytree 的分葉,如
tree_leaves
,並回傳每個分葉的金鑰路徑。- 參數:
tree (Any) – 一個 pytree。如果它包含自訂類型,建議使用
register_pytree_with_keys
註冊。is_leaf (Callable[[Any], bool] | None | None)
- 回傳:
金鑰-分葉配對的列表,每個配對包含一個分葉及其金鑰路徑。
- 回傳類型:
範例
>>> import jax >>> jax.tree.leaves_with_path([1, {'x': 3}]) [((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]