jax.tree.map_with_path#
- jax.tree.map_with_path(f, tree, *rest, is_leaf=None)[原始碼]#
將多輸入函式對映到 pytree 金鑰路徑和引數,以產生新的 pytree。
這是
tree_map
的更強大替代方案,可以將每個葉節點的金鑰路徑作為輸入引數。- 參數:
f (Callable[..., Any]) – 函式,接受
2 + len(rest)
個引數,亦即金鑰路徑和 pytree 的每個對應葉節點。tree (Any) – 要對映的 pytree,其中每個葉節點的金鑰路徑作為第一個位置引數,而葉節點本身作為
f
的第二個引數。*rest (Any) – pytree 的元組,每個都具有與
tree
相同的結構,或以tree
作為前綴。is_leaf (Callable[[Any], bool] | None | None)
- 傳回值:
一個新的 pytree,其結構與
tree
相同,但每個葉節點的值由f(kp, x, *xs)
給定,其中kp
是tree
中對應葉節點的葉節點金鑰路徑,x
是葉節點值,而xs
是rest
中對應節點的值元組。- 傳回類型:
Any
範例
>>> import jax >>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3]) [1, 3, 5]