jax.lax.ppermute#
- jax.lax.ppermute(x, axis_name, perm)[原始碼]#
根據排列
perm
執行集體排列。如果
x
是一個 pytree,則結果等同於將此函式映射到樹狀結構中的每個葉節點。此函式是 CollectivePermute HLO 的類比。
- 參數:
x – 具有名為
axis_name
的映射軸的陣列。axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件(詳情請參閱
jax.pmap()
文件)。perm – 整數對的列表,表示
(source_index, destination_index)
對,用於編碼應如何洗牌名為axis_name
的映射軸。整數值被視為映射軸axis_name
的索引。任何兩個配對都不應具有相同的來源索引或相同的目的地索引。對於軸axis_name
的每個索引,如果該索引不對應於perm
中的目的地索引,則結果中的對應值將以適當型別的零填充。
- 傳回值:
與
x
具有相同形狀的陣列,其沿軸axis_name
的切片是根據排列perm
從x
收集而來。