jax.lax.all_gather#
- jax.lax.all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False)[原始碼]#
跨所有副本收集 x 的值。
如果
x
是 pytree,則結果等同於將此函式映射到樹狀結構中的每個葉節點。這等同於 all_to_all(broadcast(x)),但速度更快。
- 參數:
x – 具有名為
axis_name
的映射軸的陣列。axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件(詳情請參閱
jax.pmap()
文件)。axis_index_groups – 軸索引的選用列表(例如,對於大小為 4 的軸,[[0, 1], [2, 3]] 將對前兩個和後兩個副本執行 all gather)。群組必須精確地涵蓋所有軸索引一次,且所有群組的大小必須相同。
axis – 一個位置軸,沿
axis_name
的區塊將串連到其中。tiled – 當
False
時,區塊將堆疊到輸出中索引axis
的全新位置軸中。當True
時,axis
必須參照現有的位置維度,且區塊將串連到該維度中。
- 傳回:
表示沿軸
axis_name
進行 all-gather 的結果的陣列。形狀與x.shape
相同,但當
tiled
為False
時,在位置axis
中有一個等於軸axis_name
大小的新維度,當
tiled
為True
時,位置axis
中維度的大小乘以軸axis_name
的大小。
例如,在 4 個 XLA 裝置可用的情況下
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x) >>> print(y) [[0 1 2 3] [0 1 2 3] [0 1 2 3] [0 1 2 3]]
使用 axis_index_groups 的範例,群組依偶數和奇數裝置 ID 分割
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> def f(x): ... return jax.lax.all_gather( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]] [[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]]]