jax.lax.psum_scatter#

jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[原始碼]#

類似 psum(x, axis_name),但每個裝置僅保留部分結果。

例如,psum_scatter(x, axis_name, scatter_dimension=0, tiled=False) 計算的值與 psum(x, axis_name)[axis_index(axis_name)] 相同,但效率更高。因此,psum 結果會分散在映射軸上。

計算 psum(x, axis_name) 的一種有效演算法是執行 psum_scatter,然後執行 all_gather,基本上是評估 all_gather(psum_scatter(x, axis_name))。因此,我們可以將 psum_scatter 視為 psum 的「前半部分」。

參數:
  • x – 具有名為 axis_name 的映射軸的陣列。

  • axis_name – 用於命名映射軸的可雜湊 Python 物件 (請參閱 jax.pmap() 文件以取得更多詳細資訊)。

  • scatter_dimension – 將沿 axis_name 的 all-reduce 結果分散到的位置軸。

  • axis_index_groups – 包含軸索引的選用整數列表列表。例如,對於大小為 4 的軸,axis_index_groups=[[0, 1], [2, 3]] 將在前兩個和後兩個軸索引上執行 reduce-scatter。群組必須精確地涵蓋所有軸索引一次,並且所有群組的大小必須相同。

  • tiled – 布林值,表示是否使用保留秩的「平鋪」行為。當 False (預設值) 時,scatter_dimension 中的維度大小必須與軸 axis_name 的大小 (或群組大小,如果給定 axis_index_groups) 相符。在沿 scatter_dimension 分散 all-reduce 結果後,輸出會透過移除 scatter_dimension 來壓縮,因此結果的秩低於輸入。當 True 時,scatter_dimension 中的維度大小必須可被軸 axis_name 的大小 (或群組大小,如果給定 axis_index_groups) 整除,並且保留 scatter_dimension 軸 (因此結果的秩與輸入相同)。

傳回:

形狀與 x 相似的陣列,但位置 scatter_dimension 中的維度大小除以軸 axis_name 的大小 (當 tiled=True 時),或位置 scatter_dimension 中的維度被消除 (當 tiled=False 時)。

例如,使用 4 個 XLA 裝置

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x)
>>> print(y)
[24 28 32 36]

如果使用平鋪

>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x)
>>> print(y)
[[24]
 [28]
 [32]
 [36]]

使用 axis_index_groups 的範例

>>> def f(x):
...   return jax.lax.psum_scatter(
...       x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True)
>>> y = jax.pmap(f, axis_name='i')(x)
>>> print(y)
[[ 8 10]
 [20 22]
 [12 14]
 [16 18]]