jax.experimental.multihost_utils.broadcast_one_to_all#

jax.experimental.multihost_utils.broadcast_one_to_all(in_tree, is_source=None)[原始碼]#

將資料從來源主機 (預設為主機 0) 廣播到所有其他主機。

參數:
  • in_tree (Any) – 陣列的 pytree - 每個陣列必須在所有主機上具有相同的形狀。

  • is_source (bool | None | None) – 可選的布林值,表示呼叫者是否為來源。只有「來源主機」會提供用於廣播的資料。如果為 None,則使用主機 0。

傳回:

一個與 in_tree 相符的 pytree,其中葉節點現在都包含來自第一個主機的資料。

傳回類型:

Any