jax.lax.map#

jax.lax.map(f, xs, *, batch_size=None)[source]#

將函數映射到前導陣列軸上。

類似 Python 的內建 map,但輸入和輸出採用堆疊陣列的形式。除非您需要逐元素套用函數以減少記憶體使用量,或使用其他控制流程原語進行異質計算,否則請考慮改用 vmap() 轉換。

xs 是陣列類型時,map() 的語義由以下 Python 實作給出

def map(f, xs):
  return np.stack([f(x) for x in xs])

如同 scan()map() 是基於 JAX 基本運算實作的,因此許多相較於 Python 迴圈的優勢也適用:xs 可以是任意巢狀 pytree 類型,且映射的計算只會編譯一次。

如果提供了 batch_size,計算將以該大小的批次執行,並使用 vmap() 進行平行化。這可以用作效能更高的 map 版本,或作為記憶體效率更高的 vmap 版本。如果軸無法被批次大小整除,則剩餘部分將在單獨的 vmap 中處理,並與結果串聯。

>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)

在上面的範例中,「inner shape」會印出兩次,一次是在追蹤批次計算時,另一次是在追蹤剩餘計算時。

參數:
  • f – 一個 Python 函數,用於逐元素地應用於 xs 的第一個軸或多個軸。

  • xs – 要沿著前導軸映射的值。

  • batch_size (int | None) – (可選) 整數,指定每個步驟要平行執行的批次大小。

回傳值:

映射後的值。