jax.tree.map#
- jax.tree.map(f, tree, *rest, is_leaf=None)[原始碼]#
將多輸入函數映射到 pytree 參數上,以產生新的 pytree。
- 參數:
f (Callable[..., Any]) – 接受
1 + len(rest)
個參數的函數,將應用於 pytree 的對應葉節點。tree (Any) – 要映射的 pytree,每個葉節點都提供 f 的第一個位置參數。
rest (Any) – pytree 的元組,每個 pytree 都具有與
tree
相同的結構,或以tree
作為前綴。is_leaf (Callable[[Any], bool] | None | None) – 可選指定的函數,將在每個展平步驟中呼叫。它應返回一個布林值,指示展平是否應遍歷當前物件,或者是否應立即停止,並將整個子樹視為葉節點。
- 返回值:
一個新的 pytree,其結構與
tree
相同,但每個葉節點的值由f(x, *xs)
給出,其中x
是tree
中對應葉節點的值,而xs
是rest
中對應節點的值元組。- 返回型別:
Any
範例
>>> import jax >>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42}) {'x': 8, 'y': 43}
如果傳遞多個輸入,則樹的結構取自第一個輸入;後續輸入只需要以
tree
作為前綴>>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) [[5, 7, 9], [6, 1, 2]]