jax.lax.all_to_all#
- jax.lax.all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False)[原始碼]#
實現映射軸並映射不同的軸。
如果
x
是 pytree,則結果等同於將此函數映射到樹狀結構中的每個葉節點。在輸出中,輸入映射軸
axis_name
在邏輯軸位置concat_axis
實現,而位置split_axis
的輸入未映射軸則以名稱axis_name
進行映射。映射軸大小的群組大小必須等於未映射軸的大小;也就是說,我們必須有
lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]
。預設情況下,當axis_index_groups=None
時,這涵蓋所有裝置。- 參數:
x – 具有名為
axis_name
的映射軸的陣列。axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件(詳情請參閱
jax.pmap()
文件)。split_axis – 整數,指示要以名稱
axis_name
映射的x
的未映射軸。concat_axis – 整數,指示在輸出中實現具有名稱
axis_name
的輸入映射軸的位置。axis_index_groups – 軸索引的選用列表(例如,對於大小為 4 的軸,[[0, 1], [2, 3]] 將對前兩個和後兩個複本執行 all_to_all)。群組必須精確地涵蓋所有軸索引一次,且所有群組的大小必須相同。
tiled – 當為 True 時,all_to_all 會將 split_axis 分割成區塊,並沿 concat_axis 將它們串連起來。特別是,不會新增或移除維度。預設為 False。
- 傳回值:
當 tiled 為 False 時,陣列的形狀由以下表達式給出
np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)
其中
axis_size
是輸入x
中名為axis_name
的映射軸的大小,即axis_size = lax.psum(1, axis_name)
。否則,陣列的形狀與輸入形狀相似,但 split_axis 除以軸大小,而 concat_axis 乘以軸大小。