jax.experimental.multihost_utils.broadcast_one_to_all#
- jax.experimental.multihost_utils.broadcast_one_to_all(in_tree, is_source=None)[原始碼]#
將資料從來源主機 (預設為主機 0) 廣播到所有其他主機。
- 參數:
in_tree (Any) – 陣列的 pytree - 每個陣列必須在所有主機上具有相同的形狀。
is_source (bool | None | None) – 可選的布林值,表示呼叫者是否為來源。只有「來源主機」會提供用於廣播的資料。如果為 None,則使用主機 0。
- 傳回:
一個與 in_tree 相符的 pytree,其中葉節點現在都包含來自第一個主機的資料。
- 傳回類型:
Any