jax.lax.psum#
- jax.lax.psum(x, axis_name, *, axis_index_groups=None)[原始碼]#
在 pmapped 軸
axis_name
上,計算x
的 all-reduce 總和。如果
x
是一個 pytree,則結果等同於將此函數映射到樹狀結構中的每個葉節點。布林資料類型的輸入在縮減之前會轉換為整數。
- 參數:
x – 具有名為
axis_name
的映射軸的陣列。axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件 (請參閱
jax.pmap()
文件以瞭解更多詳細資訊)。axis_index_groups – 包含軸索引的可選列表 (例如,對於大小為 4 的軸,[[0, 1], [2, 3]] 將對前兩個和後兩個副本執行 psum)。群組必須精確地涵蓋所有軸索引一次。
- 傳回:
與
x
具有相同形狀的陣列,表示沿軸axis_name
的 all-reduce 總和的結果。
範例
例如,在 4 個 XLA 裝置可用的情況下
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x) >>> print(y) [6 6 6 6] >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x) >>> print(y) [0. 0.16666667 0.33333334 0.5 ]
假設我們想要在兩個群組之間執行
psum
,一個群組包含device0
和device1
,另一個群組包含device2
和device3
,>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x) >>> print(y) [1 1 5 5]
使用 2D 形狀 x 的範例。每列是來自一個裝置的資料。
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]]
跨所有裝置的完整
psum
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x) >>> print(y) [[24 28 32 36] [24 28 32 36] [24 28 32 36] [24 28 32 36]]
在兩個群組之間執行
psum
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x) >>> print(y) [[ 4 6 8 10] [ 4 6 8 10] [20 22 24 26] [20 22 24 26]]