jax.vmap#
- jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)[原始碼]#
向量化 map。建立一個函數,將
fun
映射到引數軸上。- 參數:
fun (F) – 要映射到額外軸上的函數。
in_axes (int | None | Sequence[Any]) –
一個整數、None 或值序列,指定要映射的輸入陣列軸。
如果
fun
的每個位置引數都是陣列,則in_axes
可以是一個整數、None 或整數和 None 的元組,長度等於fun
的位置引數數量。整數或None
表示要為所有引數映射哪個陣列軸 (其中None
表示不映射任何軸),而元組表示要為每個對應的位置引數映射哪個軸。軸整數必須在每個陣列的範圍[-ndim, ndim)
內,其中ndim
是對應輸入陣列的維度 (軸) 數。如果
fun
的位置引數是容器 (pytree) 型別,則in_axes
必須是一個序列,長度等於fun
的位置引數數量,並且對於每個引數,in_axes
的對應元素可以是具有匹配 pytree 結構的容器,指定其容器元素的映射。換句話說,in_axes
必須是傳遞給fun
的位置引數元組的容器樹前綴。有關更多詳細資訊,請參閱此連結:https://jax.dev.org.tw/en/latest/pytrees.html#applying-optional-parameters-to-pytrees必須明確提供
axis_size
,或者至少一個位置引數必須具有非 None 的in_axes
。所有映射位置引數的映射輸入軸的大小必須全部相等。作為關鍵字傳遞的引數始終映射到它們的前導軸 (即軸索引 0)。
請參閱下面的範例。
out_axes (Any) – 一個整數、None 或其 (巢狀) 標準 Python 容器 (tuple/list/dict),指示映射軸應出現在輸出中的位置。所有具有映射軸的輸出都必須具有非 None 的
out_axes
規範。軸整數必須在每個輸出陣列的範圍[-ndim, ndim)
內,其中ndim
是vmap()
-ed 函數傳回的陣列的維度 (軸) 數,這比fun
傳回的對應陣列的維度 (軸) 數多一個。axis_name (AxisName | None | None) – 可選,一個可雜湊的 Python 物件,用於識別映射軸,以便可以應用平行集合運算。
axis_size (int | None | None) – 可選,一個整數,指示要映射的軸的大小。如果未提供,則映射軸大小會從引數推斷。
spmd_axis_name (AxisName | tuple[AxisName, ...] | None | None)
- 返回:
fun
的批次/向量化版本,其引數對應於fun
的引數,但在由in_axes
指示的位置具有額外的陣列軸,並且傳回值對應於fun
的傳回值,但在由out_axes
指示的位置具有額外的陣列軸。- 返回型別:
F
例如,我們可以使用向量點積來實作矩陣-矩陣乘積
>>> import jax.numpy as jnp >>> >>> vv = lambda x, y: jnp.vdot(x, y) # ([a], [a]) -> [] >>> mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis) >>> mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis)
在這裡,我們使用
[a,b]
來表示形狀為 (a,b) 的陣列。以下是一些變體>>> mv1 = vmap(vv, (0, 0), 0) # ([b,a], [b,a]) -> [b] (b is the mapped axis) >>> mv2 = vmap(vv, (0, 1), 0) # ([b,a], [a,b]) -> [b] (b is the mapped axis) >>> mm2 = vmap(mv2, (1, 1), 0) # ([b,c,a], [a,c,b]) -> [c,b] (c is the mapped axis)
以下範例說明如何在
in_axes
中使用容器型別來指定要映射的容器元素的軸>>> A, B, C, D = 2, 3, 4, 5 >>> x = jnp.ones((A, B)) >>> y = jnp.ones((B, C)) >>> z = jnp.ones((C, D)) >>> def foo(tree_arg): ... x, (y, z) = tree_arg ... return jnp.dot(x, jnp.dot(y, z)) >>> tree = (x, (y, z)) >>> print(foo(tree)) [[12. 12. 12. 12. 12.] [12. 12. 12. 12. 12.]] >>> from jax import vmap >>> K = 6 # batch size >>> x = jnp.ones((K, A, B)) # batch axis in different locations >>> y = jnp.ones((B, K, C)) >>> z = jnp.ones((C, D, K)) >>> tree = (x, (y, z)) >>> vfoo = vmap(foo, in_axes=((0, (1, 2)),)) >>> print(vfoo(tree).shape) (6, 2, 5)
這是另一個在
in_axes
中使用容器型別的範例,這次是字典,用於指定要映射的容器元素>>> dct = {'a': 0., 'b': jnp.arange(5.)} >>> x = 1. >>> def foo(dct, x): ... return dct['a'] + dct['b'] + x >>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x) >>> print(out) [1. 2. 3. 4. 5.]
向量化函數的結果可以映射或取消映射。例如,下面的函數傳回一個配對,其中第一個元素已映射,第二個元素未映射。僅對於未映射的結果,我們可以將
out_axes
指定為None
(使其保持未映射)。>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.)) (Array([4., 5.], dtype=float32), 8.0)
如果為未映射的結果指定了
out_axes
,則結果將在映射軸上廣播>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.)) (Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
如果為映射的結果指定了
out_axes
,則結果將相應地轉置。最後,這是一個使用
axis_name
與集合運算子的範例>>> xs = jnp.arange(3. * 4.).reshape(3, 4) >>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs)) [[12. 15. 18. 21.] [12. 15. 18. 21.] [12. 15. 18. 21.]]
有關涉及集合運算子的更多範例,請參閱
jax.pmap()
文件字串。