jax.lax.pmean#
- jax.lax.pmean(x, axis_name, *, axis_index_groups=None)[原始碼]#
在 pmapped 軸
axis_name
上計算x
的 all-reduce 平均值。如果
x
是 pytree,則結果等同於將此函式映射到樹狀結構中的每個葉節點。- 參數:
x – 具有名為
axis_name
的映射軸的陣列。axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件 (請參閱
jax.pmap()
文件以取得更多詳細資訊)。axis_index_groups – 軸索引的選用列表 (例如,對於大小為 4 的軸,[[0, 1], [2, 3]] 將對前兩個和後兩個副本執行 pmean)。群組必須精確地涵蓋所有軸索引一次,並且在 TPU 上,所有群組的大小必須相同。
- 傳回:
與
x
具有相同形狀的陣列,表示沿軸axis_name
的 all-reduce 平均值的結果。
例如,在 4 個 XLA 裝置可用的情況下
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x) >>> print(y) [1.5 1.5 1.5 1.5] >>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x) >>> print(y) [0. 0.6666667 1.3333334 2. ]