jax.lax.psum#

jax.lax.psum(x, axis_name, *, axis_index_groups=None)[原始碼]#

在 pmapped 軸 axis_name 上,計算 x 的 all-reduce 總和。

如果 x 是一個 pytree,則結果等同於將此函數映射到樹狀結構中的每個葉節點。

布林資料類型的輸入在縮減之前會轉換為整數。

參數:
  • x – 具有名為 axis_name 的映射軸的陣列。

  • axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件 (請參閱 jax.pmap() 文件以瞭解更多詳細資訊)。

  • axis_index_groups – 包含軸索引的可選列表 (例如,對於大小為 4 的軸,[[0, 1], [2, 3]] 將對前兩個和後兩個副本執行 psum)。群組必須精確地涵蓋所有軸索引一次。

傳回:

x 具有相同形狀的陣列,表示沿軸 axis_name 的 all-reduce 總和的結果。

範例

例如,在 4 個 XLA 裝置可用的情況下

>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]

假設我們想要在兩個群組之間執行 psum,一個群組包含 device0device1,另一個群組包含 device2device3

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[1 1 5 5]

使用 2D 形狀 x 的範例。每列是來自一個裝置的資料。

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]

跨所有裝置的完整 psum

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[[24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]]

在兩個群組之間執行 psum

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[[ 4  6  8 10]
 [ 4  6  8 10]
 [20 22 24 26]
 [20 22 24 26]]