jax.tree.map_with_path#

jax.tree.map_with_path(f, tree, *rest, is_leaf=None)[原始碼]#

將多輸入函式對映到 pytree 金鑰路徑和引數,以產生新的 pytree。

這是 tree_map 的更強大替代方案,可以將每個葉節點的金鑰路徑作為輸入引數。

參數:
  • f (Callable[..., Any]) – 函式,接受 2 + len(rest) 個引數,亦即金鑰路徑和 pytree 的每個對應葉節點。

  • tree (Any) – 要對映的 pytree,其中每個葉節點的金鑰路徑作為第一個位置引數,而葉節點本身作為 f 的第二個引數。

  • *rest (Any) – pytree 的元組,每個都具有與 tree 相同的結構,或以 tree 作為前綴。

  • is_leaf (Callable[[Any], bool] | None | None)

傳回值:

一個新的 pytree,其結構與 tree 相同,但每個葉節點的值由 f(kp, x, *xs) 給定,其中 kptree 中對應葉節點的葉節點金鑰路徑,x 是葉節點值,而 xsrest 中對應節點的值元組。

傳回類型:

Any

範例

>>> import jax
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
[1, 3, 5]