自動向量化#

在前一節中,我們討論了透過 jax.jit() 函數進行 JIT 編譯。本筆記本討論了 JAX 的另一個轉換:透過 jax.vmap() 進行向量化。

手動向量化#

考慮以下簡單的程式碼,它計算兩個一維向量的卷積

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)
Array([11., 20., 29.], dtype=float32)

假設我們想要將此函數應用於一批權重 w 到一批向量 x

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

最簡單的方法是直接在 Python 中迴圈處理批次

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

這會產生正確的結果,但效率不高。

為了有效率地批次處理計算,通常必須手動重寫函數,以確保它以向量化形式完成。這實作起來並不特別困難,但確實涉及更改函數處理索引、軸和輸入其他部分的方式。

例如,我們可以手動重寫 convolve() 以支援跨批次維度的向量化計算,如下所示

def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1] -1):
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  return jnp.stack(output, axis=1)

manually_vectorized_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

隨著函數複雜性的增加,這種重新實作可能會變得混亂且容易出錯;幸運的是,JAX 提供了另一種方法。

自動向量化#

在 JAX 中,jax.vmap() 轉換旨在自動產生函數的這種向量化實作

auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

它透過追蹤函數(類似於 jax.jit())並自動在每個輸入的開頭新增批次軸來完成此操作。

如果批次維度不是第一個,您可以使用 in_axesout_axes 引數來指定輸入和輸出中批次維度的位置。如果所有輸入和輸出的批次軸都相同,則這些可以是整數;否則可以是列表。

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst)
Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32)

jax.vmap() 也支援只有一個引數被批次處理的情況:例如,如果您想要將單一組權重 w 與一批向量 x 進行卷積;在這種情況下,in_axes 引數可以設定為 None

batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

組合轉換#

與所有 JAX 轉換一樣,jax.jit()jax.vmap() 設計為可組合的,這表示您可以使用 jit 包裝 vmap 函數,或使用 vmap 包裝 jitted 函數,並且一切都會正常運作

jitted_batch_convolve = jax.jit(auto_batch_convolve)

jitted_batch_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)