jax.distributed.initialize#
- jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None, cluster_detection_method=None, initialization_timeout=300, coordinator_bind_address=None)[原始碼]#
初始化 JAX 分散式系統。
呼叫
initialize()
準備 JAX 以在多主機 GPU 和 Cloud TPU 上執行。initialize()
必須在執行任何 JAX 計算之前呼叫。JAX 分散式系統有多項作用:
它允許 JAX 進程互相發現並共享拓撲資訊,
它執行健康檢查,確保在任何進程終止時所有進程都關閉,以及
它用於分散式檢查點。
如果您使用 TPU、Slurm 或 Open MPI,則所有引數都是選填的:如果省略,它們將會自動選擇。
cluster_detection_method
可用於選擇偵測這些分散式引數的特定方法。您可以將任何自動spec_detect_methods
傳遞給此引數,儘管在 TPU、Slurm 或 Open MPI 的情況下並非必要。對於其他 MPI 安裝,如果您已安裝功能正常的mpi4py
,您可以傳遞cluster_detection_method="mpi4py"
來引導所需的引數。否則,您必須提供
coordinator_address
、num_processes
、process_id
和local_device_ids
引數給initialize()
。當提供所有四個引數時,將會跳過叢集環境自動偵測。請注意:在某些系統上,特別是僅透過代理變數(例如 HTTP_PROXY、HTTPS_PROXY 等)存取外部網路的 HPC 叢集,呼叫
initialize()
可能會逾時。您可能需要在應用程式啟動之前取消設定這些變數。- 參數:
coordinator_address (str | None | None) – 進程 0 的 IP 位址,以及該進程應在其上啟動協調器服務的埠。埠的選擇並不重要,只要埠在協調器上可用且所有進程都同意該埠即可。僅在支援的環境中可以為
None
,在這種情況下它將自動選擇。請注意,諸如localhost
或127.0.0.1
之類的特殊位址通常表示程式將繫結到本機介面,並且不適用於在多主機環境中執行。num_processes (int | None | None) – 進程數。僅在支援的環境中可以為
None
,在這種情況下它將自動選擇。process_id (int | None | None) – 目前進程的 ID 編號。叢集中的
process_id
值必須是密集的範圍0
、1
、…、num_processes - 1
。僅在支援的環境中可以為None
;如果為None
,它將自動選擇。local_device_ids (int | Sequence[int] | None | None) – 將目前進程的可見裝置限制為
local_device_ids
。如果為None
,則預設為所有本機裝置對進程可見,除非進程是透過 Slurm 和 Open MPI 在 GPU 上啟動的。在這種情況下,它將預設為每個進程單一裝置。cluster_detection_method (str | None | None) – 一個選填的字串,用於嘗試自動偵測分散式執行的配置。請注意,「mpi4py」方法要求您在環境中安裝可運作的
mpi4py
安裝,並使用與 MPI 相容的作業啟動器(例如mpiexec
或mpirun
)啟動應用程式。舊版自動偵測選項「ompi」(OMPI) 和「slurm」(Slurm) 仍然啟用。「deactivate」會繞過自動叢集偵測。initialization_timeout (int) – 連線將重試的時間段(以秒為單位)。如果初始化花費的時間超過指定的逾時時間,則初始化將會錯誤。預設為 300 秒,即 5 分鐘。
coordinator_bind_address (str | None | None) – 進程 0 上的協調器服務應繫結到的位址和埠。如果未指定,則預設為繫結到與
coordinator_address
相同的埠上的所有可用位址。在每個節點有多個網路介面的系統上,僅讓協調器服務監聽一個位址/介面可能不足。
- 引發:
RuntimeError – 如果
initialize()
被呼叫多次,或在後端已初始化後呼叫。
範例
假設有兩個 GPU 進程,且進程 0 是指定的協調器,位址為
10.0.0.1:1234
。若要初始化 GPU 叢集,請在執行任何其他操作之前執行以下命令。在進程 0 上
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)
在進程 1 上
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)