jax.distributed.initialize#

jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None, cluster_detection_method=None, initialization_timeout=300, coordinator_bind_address=None)[原始碼]#

初始化 JAX 分散式系統。

呼叫 initialize() 準備 JAX 以在多主機 GPU 和 Cloud TPU 上執行。initialize() 必須在執行任何 JAX 計算之前呼叫。

JAX 分散式系統有多項作用:

  • 它允許 JAX 進程互相發現並共享拓撲資訊,

  • 它執行健康檢查,確保在任何進程終止時所有進程都關閉,以及

  • 它用於分散式檢查點。

如果您使用 TPU、Slurm 或 Open MPI,則所有引數都是選填的:如果省略,它們將會自動選擇。

cluster_detection_method 可用於選擇偵測這些分散式引數的特定方法。您可以將任何自動 spec_detect_methods 傳遞給此引數,儘管在 TPU、Slurm 或 Open MPI 的情況下並非必要。對於其他 MPI 安裝,如果您已安裝功能正常的 mpi4py,您可以傳遞 cluster_detection_method="mpi4py" 來引導所需的引數。

否則,您必須提供 coordinator_addressnum_processesprocess_idlocal_device_ids 引數給 initialize()。當提供所有四個引數時,將會跳過叢集環境自動偵測。

請注意:在某些系統上,特別是僅透過代理變數(例如 HTTP_PROXY、HTTPS_PROXY 等)存取外部網路的 HPC 叢集,呼叫 initialize() 可能會逾時。您可能需要在應用程式啟動之前取消設定這些變數。

參數:
  • coordinator_address (str | None | None) – 進程 0 的 IP 位址,以及該進程應在其上啟動協調器服務的埠。埠的選擇並不重要,只要埠在協調器上可用且所有進程都同意該埠即可。僅在支援的環境中可以為 None,在這種情況下它將自動選擇。請注意,諸如 localhost127.0.0.1 之類的特殊位址通常表示程式將繫結到本機介面,並且不適用於在多主機環境中執行。

  • num_processes (int | None | None) – 進程數。僅在支援的環境中可以為 None,在這種情況下它將自動選擇。

  • process_id (int | None | None) – 目前進程的 ID 編號。叢集中的 process_id 值必須是密集的範圍 01、…、num_processes - 1。僅在支援的環境中可以為 None;如果為 None,它將自動選擇。

  • local_device_ids (int | Sequence[int] | None | None) – 將目前進程的可見裝置限制為 local_device_ids。如果為 None,則預設為所有本機裝置對進程可見,除非進程是透過 Slurm 和 Open MPI 在 GPU 上啟動的。在這種情況下,它將預設為每個進程單一裝置。

  • cluster_detection_method (str | None | None) – 一個選填的字串,用於嘗試自動偵測分散式執行的配置。請注意,「mpi4py」方法要求您在環境中安裝可運作的 mpi4py 安裝,並使用與 MPI 相容的作業啟動器(例如 mpiexecmpirun)啟動應用程式。舊版自動偵測選項「ompi」(OMPI) 和「slurm」(Slurm) 仍然啟用。「deactivate」會繞過自動叢集偵測。

  • initialization_timeout (int) – 連線將重試的時間段(以秒為單位)。如果初始化花費的時間超過指定的逾時時間,則初始化將會錯誤。預設為 300 秒,即 5 分鐘。

  • coordinator_bind_address (str | None | None) – 進程 0 上的協調器服務應繫結到的位址和埠。如果未指定,則預設為繫結到與 coordinator_address 相同的埠上的所有可用位址。在每個節點有多個網路介面的系統上,僅讓協調器服務監聽一個位址/介面可能不足。

引發:

RuntimeError – 如果 initialize() 被呼叫多次,或在後端已初始化後呼叫。

範例

假設有兩個 GPU 進程,且進程 0 是指定的協調器,位址為 10.0.0.1:1234。若要初始化 GPU 叢集,請在執行任何其他操作之前執行以下命令。

在進程 0 上

>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)  

在進程 1 上

>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)