多主機和多進程環境#

簡介#

本指南說明如何在多種環境中使用 JAX,例如 GPU 叢集和 Cloud TPU pod,在這些環境中,加速器分散在多個 CPU 主機或 JAX 進程中。我們將這些環境稱為「多進程」環境。

本指南特別著重於如何在多進程設定中使用集體通訊操作 (例如 jax.lax.psum() ),儘管其他通訊方法也可能根據您的使用案例而有用 (例如 RPC、mpi4jax)。如果您尚不熟悉 JAX 的集體操作,我們建議從「平行程式設計簡介」章節開始。JAX 中多進程環境的一個重要要求是加速器之間的直接通訊連結,例如 Cloud TPU 的高速互連或 GPU 的 NCCL。這些連結允許集體操作跨越多個進程的加速器執行,並具有高效能。

多進程程式設計模型#

關鍵概念

  • 您必須在每部主機上執行至少一個 JAX 進程。

  • 您應該使用 jax.distributed.initialize() 初始化叢集。

  • 每個進程都有一組不同的本地裝置,它可以定址。全域裝置是所有進程中所有裝置的集合。

  • 使用標準 JAX 平行化 API,例如 jit() (請參閱「平行程式設計簡介」教學) 和 shard_map()jax.jit 僅接受全域形狀的陣列。shard_map 允許您降至每裝置形狀。

  • 確保所有進程都以相同的順序執行相同的平行計算。

  • 確保所有進程都具有相同數量的本地裝置。

  • 確保所有裝置都相同 (例如,全部都是 V100 或全部都是 H100)。

啟動 JAX 進程#

與其他分散式系統 (其中單一控制器節點管理多個工作節點) 不同,JAX 使用「多控制器」程式設計模型,其中每個 JAX Python 進程獨立執行,有時稱為 Single Program, Multiple Data (SPMD) 模型。一般來說,相同的 JAX Python 程式會在每個進程中執行,每個進程的執行之間只有些微差異 (例如,不同的進程會載入不同的輸入資料)。此外,您必須在每部主機上手動執行您的 JAX 程式! JAX 不會從單一程式調用自動啟動多個進程。

(需要多個進程是本指南未以筆記本形式提供的原因 – 我們目前沒有從單一筆記本管理多個 Python 進程的好方法。)

初始化叢集#

若要初始化叢集,您應該在每個進程開始時呼叫 jax.distributed.initialize()jax.distributed.initialize() 必須在程式的早期呼叫,在執行任何 JAX 計算之前。

API jax.distributed.initialize() 接受幾個引數,即

  • coordinator_address:叢集中進程 0 的 IP 位址,以及該進程上可用的埠。進程 0 將啟動透過該 IP 位址和埠公開的 JAX 服務,叢集中的其他進程將連線到該服務。

  • coordinator_bind_address:叢集中進程 0 上的 JAX 服務將繫結的 IP 位址和埠。預設情況下,它將使用與 coordinator_address 相同的埠繫結到所有可用的介面。

  • num_processes:叢集中的進程數

  • process_id:此進程的 ID 號碼,範圍在 [0 .. num_processes) 內。

  • local_device_ids:將目前進程的可見裝置限制為 local_device_ids

例如在 GPU 上,典型的用法是

import jax

jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                           num_processes=2,
                           process_id=0)

在 Cloud TPU、Slurm 和 Open MPI 環境中,您可以直接呼叫 jax.distributed.initialize(),不帶任何引數。引數的預設值將自動選擇。當在具有 Slurm 和 Open MPI 的 GPU 上執行時,假設每個 GPU 啟動一個進程,即每個進程將僅被分配一個可見的本地裝置。否則,假設每部主機啟動一個進程,即每個進程將被分配所有本地裝置。Open MPI 自動初始化僅在 JAX 進程透過 mpirun/mpiexec 啟動時使用。

import jax

jax.distributed.initialize()

目前在 TPU 上呼叫 jax.distributed.initialize() 是可選的,但建議使用,因為它可以啟用額外的檢查點和健康檢查功能。

本地裝置 vs. 全域裝置#

在我們開始從您的程式執行多進程計算之前,務必了解本地全域裝置之間的區別。

進程的本地裝置是它可以直接定址和啟動計算的裝置。 例如,在 GPU 叢集上,每部主機只能在其直接連接的 GPU 上啟動計算。在 Cloud TPU pod 上,每部主機只能在其直接連接的 8 個 TPU 核心上啟動計算 (如需更多詳細資訊,請參閱 Cloud TPU 系統架構 文件)。您可以使用 jax.local_devices() 查看進程的本地裝置。

全域裝置是所有進程中的裝置。 只要每個進程在其本地裝置上啟動計算,計算就可以跨進程的裝置執行,並透過裝置之間的直接通訊連結執行集體操作。您可以使用 jax.devices() 查看所有可用的全域裝置。進程的本地裝置始終是全域裝置的子集。

執行多進程計算#

那麼您實際上如何執行涉及跨進程通訊的計算?使用與單一進程中相同的平行評估 API!

例如,shard_map() 可用於跨多個進程執行平行計算。(如果您尚不熟悉如何使用 shard_map 在單一進程中的多個裝置上執行,請查看「平行程式設計簡介」教學。) 從概念上講,這可以視為在跨主機分片的單一陣列上執行 pmap,其中每部主機僅「看到」其輸入和輸出的本地分片。

這是一個多進程 pmap 實際運作的範例

# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
>>> jax.device_count()  # total number of accelerator devices in the cluster
32
>>> jax.local_device_count()  # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)

所有進程都以相同的順序執行相同的跨進程計算非常重要。 在每個進程中執行相同的 JAX Python 程式通常就足夠了。以下是一些需要注意的常見陷阱,這些陷阱可能會導致即使執行相同的程式,計算順序也不同

  • 將不同形狀的輸入傳遞給相同平行函數的進程可能會導致掛起或不正確的傳回值。只要不同形狀的輸入在各個進程中產生形狀相同的每裝置資料分片,它們就是安全的;例如,為了在每個進程不同數量的本地裝置上執行,傳遞不同的前導批次大小是可以的,但讓每個進程將其批次填充到不同的最大範例長度則是不行的。

  • 「最後一批次」問題,其中在 (訓練) 迴圈中呼叫平行函數,並且一個或多個進程比其他進程更早退出迴圈。這將導致其餘進程掛起,等待已完成的進程開始計算。

  • 基於集合非決定性排序的條件可能會導致程式碼進程掛起。例如,在目前的 Python 版本上迭代 set 或在 Python 3.7 之前 迭代 dict,即使具有相同的插入順序,也可能導致不同進程上的順序不同。