jax.tree.map#

jax.tree.map(f, tree, *rest, is_leaf=None)[原始碼]#

將多輸入函數映射到 pytree 參數上,以產生新的 pytree。

參數:
  • f (Callable[..., Any]) – 接受 1 + len(rest) 個參數的函數,將應用於 pytree 的對應葉節點。

  • tree (Any) – 要映射的 pytree,每個葉節點都提供 f 的第一個位置參數。

  • rest (Any) – pytree 的元組,每個 pytree 都具有與 tree 相同的結構,或以 tree 作為前綴。

  • is_leaf (Callable[[Any], bool] | None | None) – 可選指定的函數,將在每個展平步驟中呼叫。它應返回一個布林值,指示展平是否應遍歷當前物件,或者是否應立即停止,並將整個子樹視為葉節點。

返回值:

一個新的 pytree,其結構與 tree 相同,但每個葉節點的值由 f(x, *xs) 給出,其中 xtree 中對應葉節點的值,而 xsrest 中對應節點的值元組。

返回型別:

Any

範例

>>> import jax
>>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42})
{'x': 8, 'y': 43}

如果傳遞多個輸入,則樹的結構取自第一個輸入;後續輸入只需要以 tree 作為前綴

>>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
[[5, 7, 9], [6, 1, 2]]