jax.tree.flatten_with_path#

jax.tree.flatten_with_path(tree, is_leaf=None)[原始碼]#

展平 pytree,類似 tree_flatten,但也會傳回每個葉節點的鍵路徑。

參數:
  • tree (Any) – 要展平的 pytree。如果它包含自訂類型,建議使用 register_pytree_with_keys 註冊。

  • is_leaf (Callable[[Any], bool] | None | None)

傳回:

一個配對,其中第一個元素是鍵-葉節點配對的列表,每個配對包含一個葉節點及其鍵路徑。第二個元素是表示展平樹狀結構的 treedef。

傳回類型:

tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]

範例

>>> import jax
>>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
>>> path_vals
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
>>> treedef
PyTreeDef([*, {'x': *}])