jax.vmap#

jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)[原始碼]#

向量化 map。建立一個函數,將 fun 映射到引數軸上。

參數:
  • fun (F) – 要映射到額外軸上的函數。

  • in_axes (int | None | Sequence[Any]) –

    一個整數、None 或值序列,指定要映射的輸入陣列軸。

    如果 fun 的每個位置引數都是陣列,則 in_axes 可以是一個整數、None 或整數和 None 的元組,長度等於 fun 的位置引數數量。整數或 None 表示要為所有引數映射哪個陣列軸 (其中 None 表示不映射任何軸),而元組表示要為每個對應的位置引數映射哪個軸。軸整數必須在每個陣列的範圍 [-ndim, ndim) 內,其中 ndim 是對應輸入陣列的維度 (軸) 數。

    如果 fun 的位置引數是容器 (pytree) 型別,則 in_axes 必須是一個序列,長度等於 fun 的位置引數數量,並且對於每個引數,in_axes 的對應元素可以是具有匹配 pytree 結構的容器,指定其容器元素的映射。換句話說,in_axes 必須是傳遞給 fun 的位置引數元組的容器樹前綴。有關更多詳細資訊,請參閱此連結:https://jax.dev.org.tw/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

    必須明確提供 axis_size,或者至少一個位置引數必須具有非 None 的 in_axes。所有映射位置引數的映射輸入軸的大小必須全部相等。

    作為關鍵字傳遞的引數始終映射到它們的前導軸 (即軸索引 0)。

    請參閱下面的範例。

  • out_axes (Any) – 一個整數、None 或其 (巢狀) 標準 Python 容器 (tuple/list/dict),指示映射軸應出現在輸出中的位置。所有具有映射軸的輸出都必須具有非 None 的 out_axes 規範。軸整數必須在每個輸出陣列的範圍 [-ndim, ndim) 內,其中 ndimvmap()-ed 函數傳回的陣列的維度 (軸) 數,這比 fun 傳回的對應陣列的維度 (軸) 數多一個。

  • axis_name (AxisName | None | None) – 可選,一個可雜湊的 Python 物件,用於識別映射軸,以便可以應用平行集合運算。

  • axis_size (int | None | None) – 可選,一個整數,指示要映射的軸的大小。如果未提供,則映射軸大小會從引數推斷。

  • spmd_axis_name (AxisName | tuple[AxisName, ...] | None | None)

返回:

fun 的批次/向量化版本,其引數對應於 fun 的引數,但在由 in_axes 指示的位置具有額外的陣列軸,並且傳回值對應於 fun 的傳回值,但在由 out_axes 指示的位置具有額外的陣列軸。

返回型別:

F

例如,我們可以使用向量點積來實作矩陣-矩陣乘積

>>> import jax.numpy as jnp
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

在這裡,我們使用 [a,b] 來表示形狀為 (a,b) 的陣列。以下是一些變體

>>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)

以下範例說明如何在 in_axes 中使用容器型別來指定要映射的容器元素的軸

>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)

這是另一個在 in_axes 中使用容器型別的範例,這次是字典,用於指定要映射的容器元素

>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...  return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]

向量化函數的結果可以映射或取消映射。例如,下面的函數傳回一個配對,其中第一個元素已映射,第二個元素未映射。僅對於未映射的結果,我們可以將 out_axes 指定為 None (使其保持未映射)。

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), 8.0)

如果為未映射的結果指定了 out_axes,則結果將在映射軸上廣播

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))

如果為映射的結果指定了 out_axes,則結果將相應地轉置。

最後,這是一個使用 axis_name 與集合運算子的範例

>>> xs = jnp.arange(3. * 4.).reshape(3, 4)
>>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
 [12. 15. 18. 21.]
 [12. 15. 18. 21.]]

有關涉及集合運算子的更多範例,請參閱 jax.pmap() 文件字串。