jax.pmap#

jax.pmap(fun, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None)[source]#

平行映射,支援集體運算。

pmap() 的目的是表達單一程式多資料 (SPMD) 程式。將 pmap() 應用於函數將會使用 XLA 編譯該函數(類似於 jit()),然後在 XLA 裝置(例如多個 GPU 或多個 TPU 核心)上平行執行。在語義上,它與 vmap() 相當,因為這兩種轉換都會將函數映射到陣列軸上,但 vmap() 通過將映射軸向下推送到原始運算來向量化函數,而 pmap() 則複製函數,並在自己的 XLA 裝置上平行執行每個副本。

映射軸大小必須小於或等於可用的本地 XLA 裝置數量,如 jax.local_device_count() 所返回的數量(除非指定 devices,請參閱下文)。對於巢狀 pmap() 呼叫,映射軸大小的乘積必須小於或等於 XLA 裝置的數量。

注意

pmap() 編譯 fun,因此雖然它可以與 jit() 結合使用,但通常是不必要的。

pmap() 要求所有參與裝置都相同。例如,無法使用 pmap() 在兩種不同型號的 GPU 上平行化計算。目前,同一個裝置在同一個 pmap 中參與兩次是錯誤的。

多進程平台: 在多進程平台(例如 TPU pods)上,pmap() 旨在用於 SPMD Python 程式中,其中每個進程都運行相同的 Python 程式碼,以便所有進程都以相同的順序運行相同的 pmapped 函數。每個進程仍應使用等於本地裝置數量的映射軸大小來呼叫 pmapped 函數(除非指定 devices,請參閱下文),並且通常會返回具有相同前導軸大小的陣列。但是,fun 中的任何集體運算都將通過裝置到裝置的通訊在所有參與裝置(包括其他進程上的裝置)上計算。從概念上講,這可以被認為是在單個陣列上運行 pmap,該陣列在進程之間分片,其中每個進程「僅看到」其輸入和輸出的本地分片。SPMD 模型要求必須在所有裝置上以相同的順序運行相同的多進程 pmaps,但它們可以與在單個進程中運行的任意操作交錯。

參數:
  • fun (Callable) – 要映射到引數軸上的函數。其引數和傳回值應為陣列、純量或(巢狀)標準 Python 容器(tuple/list/dict)。由 static_broadcasted_argnums 指示的位置引數可以是任何東西,只要它們是可雜湊的並且定義了相等運算。

  • axis_name (AxisName | None | None) – 可選,用於識別映射軸的可雜湊 Python 物件,以便可以應用平行集體運算。

  • in_axes – 一個非負整數、None 或其巢狀 Python 容器,用於指定要映射到哪些位置引數軸上。作為關鍵字傳遞的引數始終映射到其前導軸(即軸索引 0)上。請參閱 vmap() 以取得詳細資訊。

  • out_axes – 一個非負整數、None 或其巢狀 Python 容器,指示映射軸應出現在輸出中的位置。所有具有映射軸的輸出都必須具有非 None out_axes 規格(請參閱 vmap())。

  • static_broadcasted_argnums (int | Iterable[int]) –

    一個整數或整數集合,用於指定將哪些位置引數視為靜態(編譯時常數)。僅依賴靜態引數的運算將會被常數折疊。使用這些常數的不同值呼叫 pmapped 函數將會觸發重新編譯。如果使用少於 static_broadcasted_argnums 指示的位置引數呼叫 pmapped 函數,則會引發錯誤。每個靜態引數都將廣播到所有裝置。不是陣列或其容器的引數必須標記為靜態。預設為 ()。

    靜態引數必須是可雜湊的,表示 __hash____eq__ 都已實作,並且應該是不可變的。

  • devices (Sequence[xc.Device] | None | None) – 這是一個實驗性功能,API 可能會變更。可選,要映射到的裝置序列。(可用的裝置可以通過 jax.devices() 檢索)。必須在多進程設定中為每個進程給出相同的裝置(因此將包括跨進程的裝置)。如果指定,則映射軸的大小必須等於給定進程本地序列中裝置的數量。尚不支援在內部或外部 pmap() 中指定 devices 的巢狀 pmap()

  • backend (str | None | None) – 這是一個實驗性功能,API 可能會變更。可選,表示 XLA 後端的字串。'cpu'、'gpu' 或 'tpu'。

  • axis_size (int | None | None) – 可選;映射軸的大小。

  • donate_argnums (int | Iterable[int]) –

    指定哪些位置引數緩衝區「捐贈」給計算。如果您在計算完成後不再需要引數緩衝區,則捐贈它們是安全的。在某些情況下,XLA 可以使用捐贈的緩衝區來減少執行計算所需的記憶體量,例如回收您的輸入緩衝區之一來儲存結果。您不應重複使用您捐贈給計算的緩衝區,如果您嘗試這樣做,JAX 將會引發錯誤。請注意,donate_argnums 僅適用於位置引數,關鍵字引數將不會被捐贈。

    有關緩衝區捐贈的更多詳細資訊,請參閱 FAQ

  • global_arg_shapes (tuple[tuple[int, ...], ...] | None | None)

傳回:

一個平行化的 fun 版本,其引數對應於 fun 的引數,但在 in_axes 指示的位置具有額外的陣列軸,並且輸出具有額外的前導陣列軸(大小相同)。

傳回類型:

Any

例如,假設有 8 個 XLA 裝置可用,pmap() 可以用作沿著前導陣列軸的映射

>>> import jax.numpy as jnp
>>>
>>> out = pmap(lambda x: x ** 2)(jnp.arange(8))  
>>> print(out)  
[0, 1, 4, 9, 16, 25, 36, 49]

當前導維度小於可用裝置的數量時,JAX 將僅在裝置的子集上運行

>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(jnp.dot)(x, y)  
>>> print(out)  
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]

如果您的前導維度大於可用裝置的數量,您將會收到錯誤

>>> pmap(lambda x: x ** 2)(jnp.arange(9))  
ValueError: ... requires 9 replicas, but only 8 XLA devices are available

vmap() 一樣,在 in_axes 中使用 None 表示引數沒有額外的軸,應該跨副本廣播,而不是映射

>>> x, y = jnp.arange(2.), 4.
>>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)  
>>> print(out)  
([4., 5.], [8., 8.])

請注意,pmap() 始終傳回映射到其前導軸的值,相當於在 vmap() 中使用 out_axes=0

除了表達純映射之外,pmap() 也可用於表達平行單一程式多資料 (SPMD) 程式,這些程式通過集體運算進行通訊。例如

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(jnp.arange(4.))  
>>> print(out)  
[ 0.          0.16666667  0.33333334  0.5       ]
>>> print(out.sum())  
1.0

在此範例中,axis_name 是一個字串,但它可以是任何具有 __hash____eq__ 定義的 Python 物件。

傳遞給 pmap() 的引數 axis_name 命名了映射軸,以便集體運算(如 jax.lax.psum())可以引用它。軸名稱尤其在巢狀 pmap() 函數的情況下很重要,其中集體運算可以在不同的軸上運作

>>> from functools import partial
>>> import jax
>>>
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...   row_normed = x / jax.lax.psum(x, 'rows')
...   col_normed = x / jax.lax.psum(x, 'cols')
...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...   return row_normed, col_normed, doubly_normed
>>>
>>> x = jnp.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)  
>>> print(row_normed.sum(0))  
[ 1.  1.]
>>> print(col_normed.sum(1))  
[ 1.  1.  1.  1.]
>>> print(doubly_normed.sum((0, 1)))  
1.0

在多進程平台上,集體運算在所有裝置(包括其他進程上的裝置)上運作。例如,假設以下程式碼在兩個進程上運行,每個進程有 4 個 XLA 裝置

>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
>>> out = pmap(f, axis_name='i')(data)  
>>> print(out)  
[28 29 30 31] # on process 0
[32 33 34 35] # on process 1

每個進程都傳入一個不同的長度為 4 的陣列,對應於其 4 個本地裝置,並且 psum 在所有 8 個值上運作。從概念上講,這兩個長度為 4 的陣列可以被認為是一個分片的長度為 8 的陣列(在本範例中相當於 jnp.arange(8)),它被映射到,並且長度為 8 的映射軸被命名為 'i'。然後,每個進程上的 pmap 呼叫會傳回相應的長度為 4 的輸出分片。

devices 引數可用於精確指定用於運行平行計算的裝置。例如,再次假設一個具有 8 個裝置的單個進程,以下程式碼定義了兩個平行計算,一個在頭六個裝置上運行,另一個在剩餘的兩個裝置上運行

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...   return x / jax.lax.psum(x, axis_name='i')
>>>
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
... def f2(x):
...   return jax.lax.psum(x ** 2, axis_name='i')
>>>
>>> print(f1(jnp.arange(6.)))  
[0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
>>> print(f2(jnp.array([2., 3.])))  
[ 13.  13.]