jax.lax.psum_scatter#
- jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[原始碼]#
類似
psum(x, axis_name)
,但每個裝置僅保留部分結果。例如,
psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)
計算的值與psum(x, axis_name)[axis_index(axis_name)]
相同,但效率更高。因此,psum
結果會分散在映射軸上。計算
psum(x, axis_name)
的一種有效演算法是執行psum_scatter
,然後執行all_gather
,基本上是評估all_gather(psum_scatter(x, axis_name))
。因此,我們可以將psum_scatter
視為psum
的「前半部分」。- 參數:
x – 具有名為
axis_name
的映射軸的陣列。axis_name – 用於命名映射軸的可雜湊 Python 物件 (請參閱
jax.pmap()
文件以取得更多詳細資訊)。scatter_dimension – 將沿
axis_name
的 all-reduce 結果分散到的位置軸。axis_index_groups – 包含軸索引的選用整數列表列表。例如,對於大小為 4 的軸,
axis_index_groups=[[0, 1], [2, 3]]
將在前兩個和後兩個軸索引上執行 reduce-scatter。群組必須精確地涵蓋所有軸索引一次,並且所有群組的大小必須相同。tiled – 布林值,表示是否使用保留秩的「平鋪」行為。當
False
(預設值) 時,scatter_dimension
中的維度大小必須與軸axis_name
的大小 (或群組大小,如果給定axis_index_groups
) 相符。在沿scatter_dimension
分散 all-reduce 結果後,輸出會透過移除scatter_dimension
來壓縮,因此結果的秩低於輸入。當True
時,scatter_dimension
中的維度大小必須可被軸axis_name
的大小 (或群組大小,如果給定axis_index_groups
) 整除,並且保留scatter_dimension
軸 (因此結果的秩與輸入相同)。
- 傳回:
形狀與
x
相似的陣列,但位置scatter_dimension
中的維度大小除以軸axis_name
的大小 (當tiled=True
時),或位置scatter_dimension
中的維度被消除 (當tiled=False
時)。
例如,使用 4 個 XLA 裝置
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x) >>> print(y) [24 28 32 36]
如果使用平鋪
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x) >>> print(y) [[24] [28] [32] [36]]
使用 axis_index_groups 的範例
>>> def f(x): ... return jax.lax.psum_scatter( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[ 8 10] [20 22] [12 14] [16 18]]