jax.lax.pmax#

jax.lax.pmax(x, axis_name, *, axis_index_groups=None)[原始碼]#

計算在 pmapped 軸 axis_name 上的 x 的 all-reduce 最大值。

如果 x 是一個 pytree,則結果等同於將此函數映射到樹中的每個葉節點。

參數:
  • x – 具有名為 axis_name 的映射軸的陣列。

  • axis_name – 用於命名 pmapped 軸的可雜湊 Python 物件(請參閱 jax.pmap() 文件以了解更多詳細資訊)。

  • axis_index_groups – 軸索引的選用列表(例如,對於大小為 4 的軸,[[0, 1], [2, 3]] 將對前兩個和後兩個副本執行 pmax)。群組必須精確地涵蓋所有軸索引一次,並且在 TPU 上,所有群組的大小必須相同。

回傳:

x 具有相同形狀的陣列,表示沿軸 axis_name 的 all-reduce 最大值的結果。