jax.lax.map#
- jax.lax.map(f, xs, *, batch_size=None)[source]#
將函數映射到前導陣列軸上。
類似 Python 的內建 map,但輸入和輸出採用堆疊陣列的形式。除非您需要逐元素套用函數以減少記憶體使用量,或使用其他控制流程原語進行異質計算,否則請考慮改用
vmap()
轉換。當
xs
是陣列類型時,map()
的語義由以下 Python 實作給出def map(f, xs): return np.stack([f(x) for x in xs])
如同
scan()
,map()
是基於 JAX 基本運算實作的,因此許多相較於 Python 迴圈的優勢也適用:xs
可以是任意巢狀 pytree 類型,且映射的計算只會編譯一次。如果提供了
batch_size
,計算將以該大小的批次執行,並使用vmap()
進行平行化。這可以用作效能更高的map
版本,或作為記憶體效率更高的vmap
版本。如果軸無法被批次大小整除,則剩餘部分將在單獨的vmap
中處理,並與結果串聯。>>> x = jnp.ones((10, 3, 4)) >>> def f(x): ... print('inner shape:', x.shape) ... return x + 1 >>> y = lax.map(f, x, batch_size=3) inner shape: (3, 4) inner shape: (3, 4) >>> y.shape (10, 3, 4)
在上面的範例中,「inner shape」會印出兩次,一次是在追蹤批次計算時,另一次是在追蹤剩餘計算時。
- 參數:
f – 一個 Python 函數,用於逐元素地應用於
xs
的第一個軸或多個軸。xs – 要沿著前導軸映射的值。
batch_size (int | None) – (可選) 整數,指定每個步驟要平行執行的批次大小。
- 回傳值:
映射後的值。