jax.experimental.multihost_utils 模組#

用於跨多個主機同步和通訊的工具。

多主機工具 API 參考#

broadcast_one_to_all(in_tree[, is_source])

將資料從來源主機 (預設為主機 0) 廣播到所有其他主機。

sync_global_devices(name)

跨所有主機/裝置建立障礙。

process_allgather(in_tree[, tiled])

從跨進程收集資料。

assert_equal(in_tree[, fail_message])

驗證所有主機是否具有相同的數值樹狀結構。

host_local_array_to_global_array(...)

將主機本機值轉換為全域分片 jax.Array。

global_array_to_host_local_array(...)

將全域 jax.Array 轉換為主機本機 jax.Array