jax.lax.ppermute#

jax.lax.ppermute(x, axis_name, perm)[原始碼]#

根據排列 perm 執行集體排列。

如果 x 是一個 pytree,則結果等同於將此函式映射到樹狀結構中的每個葉節點。

此函式是 CollectivePermute HLO 的類比。

參數:
  • x – 具有名為 axis_name 的映射軸的陣列。

  • axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件(詳情請參閱 jax.pmap() 文件)。

  • perm – 整數對的列表,表示 (source_index, destination_index) 對,用於編碼應如何洗牌名為 axis_name 的映射軸。整數值被視為映射軸 axis_name 的索引。任何兩個配對都不應具有相同的來源索引或相同的目的地索引。對於軸 axis_name 的每個索引,如果該索引不對應於 perm 中的目的地索引,則結果中的對應值將以適當型別的零填充。

傳回值:

x 具有相同形狀的陣列,其沿軸 axis_name 的切片是根據排列 permx 收集而來。