分散式資料載入#
這份高階指南示範了當您在多主機或多進程環境中執行 JAX,且 JAX 計算所需的資料分散在多個進程時,如何執行分散式資料載入。本文涵蓋了如何思考分散式資料載入的整體方法,以及如何將其應用於資料平行(較簡單)和模型平行(較複雜)工作負載。
分散式資料載入通常比其他替代方案更有效率(資料分散在各個進程中),但也更複雜,例如:1) 在單一進程中載入完整的全域資料,將其分割並透過 RPC 將所需部分傳送給其他進程;以及 2) 在所有進程中載入完整的全域資料,並且僅在每個進程中使用所需部分。載入完整的全域資料通常更簡單但成本更高。例如,在機器學習中,訓練迴圈可能會在等待資料時被封鎖,並且每個進程都會使用額外的網路頻寬。
注意
當使用分散式資料載入時,重要的是每個裝置(例如,每個 GPU 或 TPU)都能存取執行計算所需的輸入資料分片。這通常使得分散式資料載入比上述替代方案更複雜且更具挑戰性,難以正確實作。如果錯誤的資料分片最終出現在錯誤的裝置上,計算仍然可以無錯誤地執行,因為計算無法知道輸入資料「應該」是什麼。但是,最終結果通常會不正確,因為輸入資料與預期的不同。
載入 jax.Array
的一般方法#
考慮從非 JAX 產生的原始資料建立單一 jax.Array
的情況。這些概念適用於載入批次資料記錄之外的情況,例如任何非 JAX 計算直接產生的多進程 jax.Array
。範例包括:1) 從檢查點載入模型權重;或 2) 載入大型空間分片的影像。
每個 jax.Array
都有一個關聯的 Sharding
,用於描述每個全域裝置需要哪個全域資料分片。當您從頭開始建立 jax.Array
時,您也需要建立其 Sharding
。這就是 JAX 如何理解資料如何在裝置之間佈局的方式。您可以建立任何您想要的 Sharding
。實際上,您通常會根據您正在實作的平行處理策略來選擇 Sharding
(您將在本指南稍後更詳細地了解資料和模型平行處理)。您也可以根據原始資料在每個進程中產生方式來選擇 Sharding
。
一旦您定義了 Sharding
,您可以使用 addressable_devices()
來提供載入目前進程內資料所需的裝置列表。(注意:「可定址裝置」一詞是「本機裝置」的更通用版本。目標是確保每個進程的資料載入器為該進程的所有本機裝置提供正確的資料。
範例#
例如,考慮一個 (64, 128)
jax.Array
,您需要跨 4 個進程(每個進程 2 個裝置,總共 8 個裝置)進行分片。這將產生 8 個唯一的資料分片,每個裝置一個。有很多方法可以對這個 jax.Array
進行分片。您可以沿著 jax.Array
的第二個維度執行 1D 分片,為每個裝置提供一個 (64, 16)
分片,如下所示
在上圖中,每個資料分片都有自己的顏色,以指示哪個進程需要載入該分片。例如,您假設進程 0
的 2 個裝置包含分片 A
和 B
,對應於全域資料的第一個 (64, 32)
部分。
您可以選擇不同的分片到裝置的分配方式。例如
這是另一個範例 — 2D 分片
無論 jax.Array
如何分片,您都必須確保每個進程的資料載入器都提供/載入所需的全域資料分片。有幾種高階方法可以實現這一點:1) 在每個進程中載入全域資料;2) 使用每個裝置的資料管線;3) 使用整合的每個進程資料管線;4) 以某種方便的方式載入資料,然後在計算內部重新分片。
選項 1:在每個進程中載入全域資料#
使用此選項,每個進程
載入所需完整值;以及
僅將所需分片傳輸到該進程的本機裝置。
這不是分散式資料載入的有效方法,因為每個進程都會丟棄其本機裝置不需要的資料,並且攝取的總資料量可能會高於必要量。但是此選項有效且相對容易實作,而對於某些工作負載(例如,如果全域資料很小),效能開銷可能是可以接受的。
選項 2:使用每個裝置的資料管線#
在此選項中,每個進程為其每個本機裝置設定一個資料載入器(也就是說,每個裝置都有自己的資料載入器,僅用於其所需資料分片)。
就載入的資料而言,這是有效率的。相較於一次考慮進程的所有本機裝置(請參閱下方的選項 3:使用整合的每個進程資料管線),有時也可以更簡單地獨立考慮每個裝置。但是,擁有多個並行資料載入器有時可能會導致效能問題。
選項 3:使用整合的每個進程資料管線#
如果您選擇此選項,則每個進程
設定單一資料載入器,該載入器載入其所有本機裝置所需的資料;然後
在傳輸到每個本機裝置之前,對本機資料進行分片。
這是執行分散式載入最有效率的方式。但是,這也是最複雜的,因為需要邏輯來找出每個裝置需要哪些資料,並建立單一資料載入器,該載入器僅載入所有這些資料(理想情況下,沒有其他額外資料)。
選項 4:以某種方便的方式載入資料,在計算內部重新分片#
此選項更難以解釋,但通常比上述選項(從 1 到 3)更容易實作。
想像一下,在難以或幾乎不可能設定資料載入器的情況下,這些載入器可以精確載入您需要的資料,無論是針對每個裝置還是每個進程的載入器。但是,仍然有可能為每個進程設定一個資料載入器,該載入器載入 1 / num_processes
的資料,只是分片方式不正確。
然後,繼續使用之前 2D 範例分片,假設每個進程更容易載入單一資料列
然後,您可以建立具有 Sharding
的 jax.Array
,表示每列資料,直接將其傳遞到計算中,並使用 jax.lax.with_sharding_constraint()
立即將列分片輸入重新分片為所需的分片。由於資料在計算內部重新分片,因此它將透過加速器通訊鏈路(例如,TPU ICI 或 NVLink)重新分片。
選項 4 與選項 3(使用整合的每個進程資料管線)具有相似的優點
每個進程仍然具有單一資料載入器;以及
全域資料在所有進程中僅載入一次;以及
全域資料還具有在資料載入方式上提供更大彈性的額外優點。
但是,此方法使用加速器互連頻寬來執行重新分片,這可能會減慢某些工作負載的速度。選項 4 還要求輸入資料除了目標 Sharding
之外,還必須表示為單獨的 Sharding
。
複製#
複製描述了一個進程,其中多個裝置具有相同的資料分片。上述一般選項(選項 1 到 4)仍然適用於複製。唯一的區別是一些進程最終可能會載入相同的資料分片。本節描述完整複製和部分複製。
完整複製#
完整複製是一個進程,其中所有裝置都具有資料的完整副本(也就是說,資料「分片」是整個陣列值)。
在下面的範例中,由於總共有 8 個裝置(每個進程 2 個),您最終將獲得完整資料的 8 個副本。每個資料副本都是未分片的,也就是說,該副本位於單一裝置上
部分複製#
部分複製描述了一個進程,其中資料有多個副本,並且每個副本都跨多個裝置進行分片。對於給定的陣列值,通常有很多種可能的方式來執行部分複製(注意:對於給定的陣列形狀,始終存在單一的完整複製 Sharding
)。
以下是兩個可能的範例。
在下面的第一個範例中,每個副本都跨進程的兩個本機裝置進行分片,總共有 4 個副本。這表示每個進程都需要載入完整的全域資料,因為其本機裝置將具有資料的完整副本。
在下面的第二個範例中,每個副本仍然跨兩個裝置進行分片,但是每對裝置都分佈在兩個不同的進程中。進程 0
(粉紅色)和進程 1
(黃色)都需要僅載入資料的第一列,而進程 2
(綠色)和進程 3
(藍色)都需要僅載入資料的第二列
現在您已經了解了建立 jax.Array
的高階選項,讓我們將它們應用於 ML 應用程式的資料載入。
資料平行處理#
在純資料平行處理(不含模型平行處理)中
您在每個裝置上複製模型;以及
每個模型副本(也就是說,每個裝置)接收不同的每個副本批次的資料。
當將輸入資料表示為單一 jax.Array
時,Array 包含此步驟中所有副本的資料(這稱為全域批次),其中 jax.Array
的每個分片都包含單一的每個副本批次。您可以將其表示為跨所有裝置的 1D 分片(檢查下面的範例)— 換句話說,全域批次由跨批次軸串聯在一起的所有每個副本批次組成。
應用此框架,您可能會得出結論,進程 0
應該取得全域批次的前四分之一(8 個中的 2 個),而進程 1
應該取得第二個,依此類推。
但是您如何知道前四分之一是什麼?以及如何確保進程 0
取得前四分之一?幸運的是,關於資料平行處理有一個非常重要的技巧,這表示您不必回答這些問題,並使整個設定更簡單。
關於資料平行處理的重要技巧#
技巧是您不需要關心哪個每個副本批次落在哪個副本上。因此,哪個進程載入批次並不重要。原因是,由於每個裝置都對應於執行相同操作的模型副本,因此哪個裝置在全域批次中取得哪個每個副本批次並不重要。
這表示您可以自由地重新排列全域批次中的每個副本批次。換句話說,您可以自由地隨機化每個裝置取得哪個資料分片。
例如
通常,重新排列 jax.Array
的資料分片(如上所示)不是一個好主意 – 您實際上是在置換 jax.Array
的值!但是,對於資料平行處理,全域批次順序沒有意義,您可以自由地重新排列全域批次中的每個副本批次,如先前已提及。
這簡化了資料載入,因為這表示每個裝置僅需要每個副本批次的獨立串流,這可以在大多數資料載入器中輕鬆實作,方法是為每個進程建立獨立的管線,並將產生的每個進程批次分塊為每個副本批次。
這是選項 2:每個進程整合資料管線的一個範例。您也可以使用其他選項(例如 0、1 和 3,本文檔稍早涵蓋了這些選項),但這個選項相對簡單且有效率。
以下是如何使用 tf.data 實作此設定的範例
import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: setup the Dataset for pure data parallelism (do once)
################################################################################
# Fake example data (replace with your Dataset)
ds = tf.data.Dataset.from_tensor_slices(
[np.ones((16, 3)) * i for i in range(100)])
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step). This can be used with batches
# produced by different data loaders as well!
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
mesh = jax.make_mesh((jax.device_count(),), ('batch',))
sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))
global_batch_array = jax.make_array_from_process_local_data(
sharding, per_process_batch)
資料 + 模型平行化#
在模型平行化中,您跨多個裝置對每個模型副本進行分片。如果您使用純模型平行化(不含資料平行化)
只有一個模型副本跨所有裝置進行分片;且
資料(通常)在所有裝置上完全複製。
本指南考慮使用資料和平行模型化的情況
您跨多個裝置對多個模型副本中的每一個進行分片;且
您在每個模型副本上部分複製資料 — 同一模型副本中的每個裝置都獲得相同的每個副本批次,而跨模型副本的裝置獲得不同的每個副本批次。
進程內模型平行化#
為了資料載入的目的,最簡單的方法可以是將每個模型副本在單一進程的本機裝置內進行分片。
對於此範例,讓我們切換到 2 個進程,每個進程有 4 個裝置(而不是 4 個進程,每個進程有 2 個裝置)。考慮一種情境,其中每個模型副本在單一進程的 2 個本機裝置上進行分片。這會導致每個進程有 2 個模型副本,總共有 4 個模型副本,如下所示
在這裡,輸入資料再次表示為單一的 jax.Array
,具有 1D 分片,其中每個分片都是每個副本批次,但有一個例外
與純資料平行化情況不同,您引入了部分複製,並製作了 1D 分片的全域批次的 2 個副本。
這是因為每個模型副本都由 2 個裝置組成,每個裝置都需要每個副本批次的一個副本。
將每個模型副本保留在單一進程內可以簡化事情,因為您可以重複使用上面描述的純資料平行化設定,只是您還需要複製每個副本批次
注意
將每個副本批次複製到正確的裝置也非常重要! 雖然關於資料平行化的非常重要的技巧意味著您不關心哪個批次最終落在哪個副本上,但您確實關心單一副本只獲得單一批次。
例如,這樣是可以的
但是,如果您不小心將每個批次載入到哪個本機裝置上,即使 Sharding
(和平行化策略)表示資料已複製,您也可能會意外地建立未複製的資料
如果您意外地建立了一個 jax.Array
,其中包含應該在單一進程內複製但未複製的資料(對於跨進程的模型平行化而言,情況並非總是如此;請參閱下一節),JAX 將會引發錯誤。
以下是如何使用 tf.data
實作每個進程模型平行化和資料平行化的範例
import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: Set up the Dataset with a different data shard per-process (do once)
# (same as for pure data parallelism)
################################################################################
# Fake example data (replace with your Dataset)
per_process_batches = [np.ones((16, 3)) * i for i in range(100)]
ds = tf.data.Dataset.from_tensor_slices(per_process_batches)
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: Create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step)
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
num_model_replicas_per_process = 2 # set according to your parallelism strategy
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()
# Create an example `Mesh` for per-process data parallelism. Make sure all devices
# are grouped by process, and then resize so each row is a model replica.
mesh_devices = np.array([jax.local_devices(process_idx)
for process_idx in range(jax.process_count())])
mesh_devices = mesh_devices.reshape(num_model_replicas_total, -1)
# Double check that each replica's devices are on a single process.
for replica_devices in mesh_devices:
num_processes = len(set(d.process_index for d in replica_devices))
assert num_processes == 1
mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])
# Shard the data across model replicas. You don't shard across the
# data_parallelism mesh axis, meaning each per-replica shard will be replicated
# across that axis.
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("model_replicas"))
global_batch_array = jax.make_array_from_process_local_data(
sharding, per_process_batch)
跨進程模型平行化#
當模型副本分散在多個進程之間時,情況可能會變得更有趣,原因可能是
單一副本無法容納在一個進程內;或
裝置分配並非如此設定。
例如,回到先前每個進程有 2 個裝置的 4 個進程設定,如果您像這樣將裝置分配給副本
這與先前的每個進程模型平行化範例具有相同的平行化策略 – 4 個模型副本,每個副本跨 2 個裝置進行分片。唯一的區別是裝置分配 – 每個副本的兩個裝置分佈在不同的進程中,並且每個進程僅負責每個副本批次的一個副本(但適用於兩個副本)。
像這樣跨進程分割模型副本可能看起來是任意且不必要的(並且在此範例中可以說是這樣),但實際部署最終可能會採用這種裝置分配方式,以最佳地利用裝置之間的通訊連結。
資料載入現在變得更加複雜,因為跨進程需要額外的協調。在純資料平行化和每個進程模型平行化的情況下,重要的是每個進程僅載入唯一的資料流。現在,某些進程必須載入相同的資料,而另一些進程必須載入不同的資料。在上面的範例中,進程 0
和 2
(分別為粉紅色和綠色)必須載入相同的 2 個每個副本批次,而進程 1
和 3
(分別為黃色和藍色)也必須載入相同的 2 個每個副本批次(但與進程 0
和 2
的批次不同)。
此外,重要的是每個進程不要混淆其 2 個每個副本批次。雖然您不關心哪個批次落在哪個副本上(關於資料平行化的非常重要的技巧),但您需要關心副本中的所有裝置都獲得相同的批次。例如,這樣會很糟糕
注意
截至 2023 年 8 月,JAX 無法偵測 jax.Array
跨進程的分片是否應該被複製但卻沒有,並且在執行計算時會產生錯誤的結果。因此,請小心不要這樣做!
為了在每個裝置上獲得正確的每個副本批次,您需要將全域輸入資料表示為以下 jax.Array