jax.experimental.multihost_utils.process_allgather#

jax.experimental.multihost_utils.process_allgather(in_tree, tiled=False)[原始碼]#

從跨程序收集資料。

參數:
  • in_tree (Any) – 陣列的 pytree - 每個陣列_必須_在主機之間具有相同的形狀。

  • tiled (bool) – 是否堆疊或串聯輸出。預設為 False,即堆疊到索引 0 的新位置軸中。

返回:

numpy 陣列的 Pytrees。
  • 如果輸入是非完全可定址的 jax.Array,則資料會完全複製。

  • 如果輸入是 numpy 陣列或完全可定址的 jax.Array,則輸出形狀取決於 tiled 引數。如果為 False,則輸出將被堆疊,否則將被串聯。

  • 如果輸入是純量,則輸出將被堆疊。

返回類型:

Any