jax.lax.pshuffle#

jax.lax.pshuffle(x, axis_name, perm)[原始碼]#

jax.lax.ppermute 的便利包裝函式,具有替代排列編碼

如果 x 是 pytree,則結果等同於將此函數映射到樹狀結構中的每個葉節點。

參數:
  • x – 具有名為 axis_name 的映射軸的陣列。

  • axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件 (如需更多詳細資訊,請參閱 jax.pmap() 文件)。

  • perm – 整數清單,編碼要套用至名為 axis_name 的軸的排列來源,以便軸索引 i 的輸出來自軸索引 perm[i] 的輸入。 [0, N) 中的每個整數都應針對軸大小 N 精確包含一次。

傳回:

x 具有相同形狀的陣列,其切片沿著軸 axis_namex 根據排列 perm 收集而來。