jax.lax.all_to_all#

jax.lax.all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False)[原始碼]#

實現映射軸並映射不同的軸。

如果 x 是 pytree,則結果等同於將此函數映射到樹狀結構中的每個葉節點。

在輸出中,輸入映射軸 axis_name 在邏輯軸位置 concat_axis 實現,而位置 split_axis 的輸入未映射軸則以名稱 axis_name 進行映射。

映射軸大小的群組大小必須等於未映射軸的大小;也就是說,我們必須有 lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]。預設情況下,當 axis_index_groups=None 時,這涵蓋所有裝置。

參數:
  • x – 具有名為 axis_name 的映射軸的陣列。

  • axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件(詳情請參閱 jax.pmap() 文件)。

  • split_axis – 整數,指示要以名稱 axis_name 映射的 x 的未映射軸。

  • concat_axis – 整數,指示在輸出中實現具有名稱 axis_name 的輸入映射軸的位置。

  • axis_index_groups – 軸索引的選用列表(例如,對於大小為 4 的軸,[[0, 1], [2, 3]] 將對前兩個和後兩個複本執行 all_to_all)。群組必須精確地涵蓋所有軸索引一次,且所有群組的大小必須相同。

  • tiled – 當為 True 時,all_to_all 會將 split_axis 分割成區塊,並沿 concat_axis 將它們串連起來。特別是,不會新增或移除維度。預設為 False。

傳回值:

當 tiled 為 False 時,陣列的形狀由以下表達式給出

np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)

其中 axis_size 是輸入 x 中名為 axis_name 的映射軸的大小,即 axis_size = lax.psum(1, axis_name)

否則,陣列的形狀與輸入形狀相似,但 split_axis 除以軸大小,而 concat_axis 乘以軸大小。