jax.lax.pswapaxes#

jax.lax.pswapaxes(x, axis_name, axis, *, axis_index_groups=None)[原始碼]#

將 pmapped 軸 axis_name 與未 mapped 軸 axis 交換。

如果 x 是 pytree,則結果等同於將此函數映射到樹中的每個 leaf。

mapped 軸大小的群組大小必須等於未 mapped 軸的大小;也就是說,我們必須有 lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]。預設情況下,當 axis_index_groups=None 時,這包含所有裝置。

此函數是 all_to_all 的特殊情況,其中輸入的 pmapped 軸放置在輸出的位置 axis。也就是說,它等同於 all_to_all(x, axis_name, axis, axis)

參數:
  • x – 具有名為 axis_name 的 mapped 軸的陣列。

  • axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件 (請參閱 jax.pmap() 文件以取得更多詳細資訊)。

  • axis – 整數,指示要使用名稱 axis_name 映射的 x 的未 mapped 軸。

  • axis_index_groups – 軸索引的選用列表 (例如,對於大小為 4 的軸,[[0, 1], [2, 3]] 將在首兩個和最後兩個副本上執行 pswapaxes)。群組必須精確地涵蓋所有軸索引一次,並且所有群組的大小必須相同。

傳回:

x 具有相同形狀的陣列。