jax.lax.pshuffle#
- jax.lax.pshuffle(x, axis_name, perm)[原始碼]#
jax.lax.ppermute 的便利包裝函式,具有替代排列編碼
如果
x
是 pytree,則結果等同於將此函數映射到樹狀結構中的每個葉節點。- 參數:
x – 具有名為
axis_name
的映射軸的陣列。axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件 (如需更多詳細資訊,請參閱
jax.pmap()
文件)。perm – 整數清單,編碼要套用至名為
axis_name
的軸的排列來源,以便軸索引 i 的輸出來自軸索引 perm[i] 的輸入。 [0, N) 中的每個整數都應針對軸大小 N 精確包含一次。
- 傳回:
與
x
具有相同形狀的陣列,其切片沿著軸axis_name
從x
根據排列perm
收集而來。