適用於 TPU 的 Pallas 分散式運算#
在本教學課程中,我們將涵蓋 TPU 上 Pallas 分散式運算的基本知識。我們將學習 TPU 拓撲、使用遠端 DMA 基本運算進行通訊,以及使用 shard_map
從 JAX 呼叫分散式核心。我們也將涵蓋一些更進階的核心撰寫技術,例如雙重緩衝、雙向頻寬最佳化和巢狀管線化。作為教學範例,我們將學習如何實作 JAX 中的各種集合基本運算,例如 lax.ppermute
、lax.all_gather
、lax.psum
和 lax.psum_scatter
。
事先建議閱讀的一些資料
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
P = jax.sharding.PartitionSpec
num_devices = jax.local_device_count()
assert num_devices > 1, "Please run this notebook with more than one device."
assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.")
Running with 4 TPU v5 lite devices.
TPU 拓撲#
TPU 通常部署在多個裝置的 pod 中,這些裝置透過高頻寬晶片間互連 (ICI) 連接,以便在 pod 內進行通訊,這比典型的網路連線快得多。例如,TPU v5p 的規格表指出每個晶片的 ICI 頻寬為 4.8Tb/s (作為參考,TPU v5p 也具有 21Tb/s 的本機 HBM 頻寬)。ICI 讓我們能夠實作快速且高效能的分散式核心,這些核心需要在 pod 內進行高頻寬通訊,並使用資料中心網路針對頻寬密集度較低的操作進行平行化,例如跨批次維度的資料平行處理。
TPU pod 通常以 ND 環面拓撲排列。下圖提供幾種不同尺寸組態的範例。
環面攤平為圖形後,可以視覺化如下。每個邊緣 (橘色或黑色) 是兩個裝置之間的雙向連線。您通常會聽到關於環的討論,與裝置拓撲一起 — 環面的關鍵特徵是,當沿著 pod 的軸取切片時,例如節點 [(0,1), (1, 1), (2, 1), (3, 1)]
或 [(0, 1), (1, 1)]
,我們有一個裝置環。這是我們可以使用的功能,以簡化 pod 內的通訊模式。
遠端直接記憶體存取 (RDMA) 模型#
TPU 透過僅推送模型進行通訊,稱為遠端直接記憶體存取 (RDMA)。TPU 可以發出複製指令,從本機緩衝區推送到同一 pod 內另一個裝置上的任何緩衝區,該緩衝區與主程式執行緒非同步執行。但是,TPU 只能讀取本機儲存的資料。這與更傳統的多核心程式設計形成對比,在傳統的多核心程式設計中,可以從共享記憶體讀取值並寫入值。
非同步遠端複製運算#
pltpu.make_async_remote_copy
函式用於建立遠端 DMA 描述符物件,該物件參數化「傳送」運算和「接收」運算。以下是其簽名
def make_async_remote_copy(
src_ref: Ref,
dst_ref: Ref,
send_sem: Ref[SemaphoreType],
recv_sem: Ref[SemaphoreType],
device_id: int | tuple[int, ...],
device_id_type: DeviceIdType
) -> AsyncCopyDescriptor:
src_ref
是本機Ref
(在任何記憶體空間中),其中包含您希望傳送到另一個裝置上dst_ref
的資料。dst_ref
是遠端Ref
(在任何記憶體空間中),資料將複製到目標裝置上的該位置。send_sem
是一個 DMA 號誌,用於封鎖直到所有資料都已從src_ref
傳送。recv_sem
是一個 DMA 號誌,用於封鎖直到預期的位元組數已在dst_ref
接收。DMA 的傳送者將寫入接收者的recv_sem
。device_id
是要傳送到的目標裝置的裝置 ID。device_id_type
指定device_id
的格式,可以是 LOGICAL 格式 (整數裝置 ID),也可以是 MESH 格式 (邏輯裝置網格的 ND 元組索引)。預設模式為 MESH。
make_async_remote_copy
傳回描述符物件,您可以使用 .start()
方法在該物件上啟動 DMA,並使用 .wait_send()
封鎖 send_sem
,以及使用 .wait_recv()
封鎖 recv_sem
(或 .wait()
封鎖兩者)。如果裝置僅預期傳送資料,則僅呼叫 .start()
和 .wait_send()
就足夠了,同樣地,如果裝置僅接收資料,則僅呼叫 .wait_recv()
就足夠了。如果使用 SPMD 模式,其中所有裝置都執行 DMA,則每個裝置通常都會呼叫 .start()
和 .wait()
。
dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)
dma_descriptor.start() # Initiate the DMA (non-blocking).
# ... do other work
dma_descriptor.wait_send() # Block until all data has been sent.
dma_descriptor.wait_recv() # Block until all data has been received.
作為範例,讓我們視覺化一個 DMA,其中我們考慮 4 個裝置 (索引為 0、1、2、3)。我們考慮一個方案,其中裝置 0 複製到裝置 1,裝置 2 和 3 彼此複製。實際上,我們可以透過使用 @pl.when
在裝置 ID 上分支來建立這種非對稱通訊模式。
(1) 每個裝置建立 DMA 描述符。裝置 0、2 和 3 呼叫 .start()
以從 src_ref
啟動 DMA。裝置 1 跳過 .start()
並且不執行任何操作,例如,透過使用 pl.when
。
(2) 由於 .start()
是非封鎖的,因此每個裝置都可以在 DMA 進行中時自由執行其他計算。裝置 0、2 和 3 呼叫 .wait_send()
以等待 send_sem
,這會封鎖直到所有資料都已傳送。
(3) 最後,裝置 1、2 和 3 將呼叫 .wait_recv()
以等待 recv_sem
,直到所有資料都已到達 dst_ref
。
上述通訊模式可以撰寫如下
def example_kernel(input_ref, output_ref, send_sem, recv_sem):
device_id = lax.axis_index('x')
copy_0_to_1 = pltpu.make_async_remote_copy(
src_ref=input_ref,
dst_ref=output_ref,
send_sem=send_sem,
recv_sem=recv_sem,
device_id=1,
)
copy_2_to_3 = pltpu.make_async_remote_copy(
src_ref=input_ref,
dst_ref=output_ref,
send_sem=send_sem,
recv_sem=recv_sem,
device_id=3,
)
copy_3_to_2 = pltpu.make_async_remote_copy(
src_ref=input_ref,
dst_ref=output_ref,
send_sem=send_sem,
recv_sem=recv_sem,
device_id=2,
)
@pl.when(device_id == 0)
def _():
copy_0_to_1.start()
copy_0_to_1.wait_send()
@pl.when(device_id == 1)
def _():
copy_0_to_1.wait_recv()
@pl.when(device_id == 2)
def _():
copy_2_to_3.start()
copy_2_to_3.wait_send()
copy_3_to_2.wait_recv()
@pl.when(device_id == 3)
def _():
copy_3_to_2.start()
copy_3_to_2.wait_send()
copy_2_to_3.wait_recv()
DMA 號誌#
send_sem
和 recv_sem
是特殊類型號誌的實例,專門保留用於 DMA。在為 pallas_call
指定輸入規格時,它們必須使用 tpu.SemaphoreType.DMA
類型進行配置。
在內部,DMA 號誌可以被視為整數值進度追蹤器。在 DMA 開始時,本機裝置將開始非同步遞增 send_sem
和接收者的 recv_sem
的值。等待號誌將封鎖,直到號誌的值達到傳送/接收的資料總位元組數;當達到該值時,等待執行緒會被釋放,並且號誌的值會減少相同的量。這表示所有資料都已傳送 (對於 send_sem
) 或所有資料都已接收 (對於 dst_sem
)。可以使用 pl.semaphore_read
讀取號誌的值,但請注意,值的基礎語意可能會在硬體世代之間發生變化 (例如,該值可能不完全代表傳送的位元組數,儘管這是在推斷號誌行為時有用的心智模型)。
路由#
傳送者可以將資料傳送到同一 pod 內的任何接收者,即使它們沒有共享直接連線 (此規則的例外情況是 TPU v5e,其中裝置只能路由到與自身偏移 2 的冪次的位置)。TPU 具有內部路由機制,可以將資料傳遞到路徑上前往目的地的下一個裝置。但是,不建議以這種方式進行通訊,因為作為核心撰寫者,您無法控制網路爭用。我們將在本教學課程中涵蓋的範例透過僅將資料傳輸到相鄰裝置來最大程度地減少低效率的通訊。
失敗模式#
如果錯誤地使用遠端 DMA,您可能會遇到幾種難以除錯的失敗模式。錯誤 DMA 使用的一般症狀是崩潰、掛起或靜默資料損壞
如果號誌以無效的非零值退出程式,Pallas 將崩潰並退出程式。
如果等待號誌,但接收到的位元組數不足 (即沒有傳送者,或者如果傳送的資料小於接收裝置上
dst_ref
的大小),則程式可能會無限期地掛起,等待永遠不會傳送的位元組。在這種情況下,需要重新啟動程式。如果遇到競爭條件,如果發生兩個同時寫入或同時讀取和寫入,則可能會發生靜默資料損壞。
上述情況的一些常見原因包括
如果裝置呼叫
.wait_recv()
但沒有其他裝置傳送給它,則核心可能會掛起。如果傳送到裝置的位元組數多於它預期接收的位元組數,則也可能會由於非零號誌狀態而崩潰。如果傳送的位元組數較少,則可能會無限期地掛起。
如果 DMA 已啟動但未等待號誌,則程式可能會由於非零號誌狀態而崩潰。
如果兩個裝置複製到相同的目的地,則可能會由於競爭條件而遇到非決定性結果,或由於非零號誌狀態而崩潰。
範例:向右置換 (lax.ppermute
)#
讓我們深入研究一個非常基本的範例。我們將實作一個執行向右置換的核心,其中每個裝置將其資料切片傳送到其右鄰居。
假設我們有一個包含 512 個元素的陣列,我們將其分片成 128 個大小的切片,分佈在 4 個裝置上。每個裝置都會將其切片傳遞到下一個裝置,並且輸出將包含相同的資料,但切片旋轉了 1。這與 lax.ppermute
運算相同,其中置換設定為 (n, (n+1) % 4)
。
為了在分散式模式下呼叫核心,我們將 pallas_call
包裝在 shard_map
轉換中。從那裡,我們可以像撰寫一般單裝置 Pallas 核心一樣撰寫核心,除了我們現在可以存取遠端 DMA 指令。JAX 集合基本運算 (例如 lax.axis_index
) 可用於取得 device_id
,該 device_id
可用於計算要複製到的目標裝置,方法是參考傳遞到 shard_map
的相同命名軸名稱。
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# Create an input array that shards the last dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)
def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
my_id = lax.axis_index('x')
right_neighbor = lax.rem(my_id + 1, num_devices)
remote_copy_op = pltpu.make_async_remote_copy(
src_ref=input_ref,
dst_ref=output_ref,
send_sem=send_sem,
recv_sem=recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
remote_copy_op.wait()
out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
# TPUMemorySpace.ANY will (usually) place the tensor in HBM.
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=(
# We allocate DMA semaphores in scratch memory.
[pltpu.SemaphoreType.DMA] * 2
),
)
right_permute = pl.pallas_call(
right_permute_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
shard_map.shard_map(
right_permute,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False,
)
)(input_arr)
# Compare Pallas result to XLA shard_map result.
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))
xla_result = jax.jit(
shard_map.shard_map(
lambda x: lax.ppermute(x, 'x', perm),
mesh=mesh, in_specs=partition, out_specs=partition)
)(input_arr)
print('Input = ', input_arr[0, ::128])
print('Pallas Result = ', pallas_result[0, ::128])
print('lax.ppermute Result = ', xla_result[0, ::128])
print(
'Difference |Pallas - lax.ppermute| = ',
jnp.mean(jnp.abs(pallas_result - xla_result)),
)
Input = [0.9858954 0.11763906 0.9955574 0.775211 ]
Pallas Result = [0.775211 0.9858954 0.11763906 0.9955574 ]
lax.ppermute Result = [0.775211 0.9858954 0.11763906 0.9955574 ]
Difference |Pallas - lax.ppermute| = 0.0
範例:全收集 (lax.all_gather
)#
在下一個範例中,我們將實作全收集集合運算,它在 lax.all_gather
中具有 JAX 等效項。與上述僅涉及一對來源和目的地鄰居的向右置換範例相比,全收集運算需要所有裝置之間的通訊,因此我們必須考慮如何在它們之間路由資料。我們如何實作此操作的具體細節由裝置拓撲決定,我們假設裝置拓撲是一個環。
環狀通訊模式#
我們將撰寫核心,假設為環狀拓撲。環狀拓撲非常適合 TPU,因為沿著環面的任何維度切片都會產生一個環。在撰寫集合運算時,我們通常只需要一次考慮環面的 1D 切片,因為環面的不同維度保留用於不同類型的平行處理 (例如,資料與模型)。
我們將使用的策略是撰寫一個迴圈核心,其中在每次迭代中,裝置從其左鄰居接收分片陣列的一個切片,並將先前接收的切片複製到其右鄰居。在 num_devices
次迭代之後,每個裝置都將在其本機 HBM 中擁有整個陣列的副本。
我們可以重新利用 Pallas 的 grid
引數來實作迴圈。與先前教學課程中執行的迭代陣列圖塊不同,我們將網格設定為 (num_devices,)
,以指示我們想要迭代裝置的數量,並使用 pl.program_id
來取得 Pallas 核心內的迴圈迭代。以下程式碼片段示範如何實作此操作
partition = P('x', None)
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# Create an input array that shards the first dimension across
# all devices.
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
input_arr = jax.device_put(input_arr, sharding)
def all_gather_kernel(input_ref,
output_ref,
local_copy_sem,
send_sem,
recv_sems):
outer_step = pl.program_id(0)
my_id = lax.axis_index('x')
right_neighbor = lax.rem(my_id + 1, num_devices)
copy_slot = my_id - outer_step
copy_slot = lax.rem(copy_slot + num_devices, num_devices)
@pl.when(outer_step == 0)
def _():
local_copy_op = pltpu.make_async_copy(
src_ref=input_ref,
dst_ref=output_ref.at[my_id],
sem=local_copy_sem,
)
local_copy_op.start()
local_copy_op.wait()
# Copy to our right neighbor.
# Note that we will also be receiving data from our left neighbor,
# but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
# that the indices do not need to be symmetric between remote DMAs.
remote_copy_op = pltpu.make_async_remote_copy(
src_ref=output_ref.at[copy_slot],
dst_ref=output_ref.at[copy_slot],
send_sem=send_sem,
recv_sem=recv_sems.at[outer_step],
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
remote_copy_op.wait()
out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
# TPUMemorySpace.ANY will (usually) place the tensor in HBM.
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=(
# DMA semaphores are allocated in scratch memory.
# We allocated one semaphore for a local HBM-VMEM copy,
# and one for the remote send semaphore.
[pltpu.SemaphoreType.DMA] * 2
# We additionally allocate one receive semaphore per device.
# This is to avoid situations where we have multiple
# DMAs in flight, as we do not want to share a receive
# semaphore between the DMAs.
+ [pltpu.SemaphoreType.DMA((num_devices-1,))]
),
grid=(num_devices-1,)
)
all_gather = pl.pallas_call(
all_gather_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
)
# Wrap the kernel within a shard_map to call.
pallas_result = jax.jit(
shard_map.shard_map(
all_gather,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False
)
)(input_arr)
# Compare Pallas result to XLA shard_map result.
xla_result = jax.jit(
shard_map.shard_map(
lambda x: lax.all_gather(x, 'x'),
mesh=mesh, in_specs=partition, out_specs=partition
)
)(input_arr)
print('Input: ', input_arr.shape, input_arr[::8, 0])
print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])
print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])
print('Difference |Pallas - lax.all_gather| = ',
jnp.mean(jnp.abs(pallas_result - xla_result)))
Input: (32, 128) [0.9858954 0.54248166 0.9547038 0.954962 ]
Pallas Result: (16, 8, 128) [0.9858954 0.54248166 0.9547038 0.954962 0.9858954 0.54248166
0.9547038 0.954962 0.9858954 0.54248166 0.9547038 0.954962
0.9858954 0.54248166 0.9547038 0.954962 ]
lax.all_gather Result: (16, 8, 128) [0.9858954 0.54248166 0.9547038 0.954962 0.9858954 0.54248166
0.9547038 0.954962 0.9858954 0.54248166 0.9547038 0.954962
0.9858954 0.54248166 0.9547038 0.954962 ]
Difference |Pallas - lax.all_gather| = 0.0
這裡值得一提的一個細節是使用多個接收號誌。由於我們僅封鎖接收裝置,因此傳送者仍然有可能在接收者完成處理第一個 DMA 之前傳送多個正在進行中的 DMA (請參閱下一節和歸約總和範例,其中更詳細地討論了競爭條件)。在這種情況下,我們可能會遇到同一號誌同時用於多個 DMA 的情況。為了避免這種情況,我們配置了 num_devices-1
個號誌,因此沒有重複使用的風險。雖然這種競爭條件不太可能在如此小的核心上發生,但在較大的核心上,裝置更有可能失去同步,並可能導致靜默失敗。
進階技術#
現在我們已經了解如何使用遠端 DMA 運算撰寫幾個基本核心,我們將介紹用於同步處理和撰寫高效核心的更進階技術。
同步處理:常規和屏障號誌#
我們在基本教學課程中實作的範例不需要特殊處理同步處理,因為所有必要的通訊都會寫入不相交的緩衝區。但是,其他運算可能需要更複雜的通訊模式,這些模式需要額外的同步處理基本運算,以避免競爭條件。Pallas 提供了兩個額外的基本運算來協助完成此操作:常規號誌和屏障號誌。
常規號誌#
常規號誌是用於跨多個裝置同步處理的標準工具。號誌從根本上來說是計數器 — 它們可以由任何裝置遞增,之後裝置可以封鎖,直到號誌的值達到特定值 (然後遞減該值)。
可以在常規號誌上使用的三個主要運算是訊號、等待和讀取
def semaphore_signal(
sem: Ref[SemaphoreType],
inc: int,
device_id: int | tuple[int, ...],
device_id_type: DeviceIdType
) -> None:
... # Increments the semaphore `sem` on the target device `device_id` by `inc`.
def semaphore_wait(
semaphore: Ref[SemaphoreType],
value: int,
) -> None:
... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.
def semaphore_read(
sem: Ref[SemaphoreType],
) -> jax.Array:
... # Returns the current value of `sem` as an `int32[]`.
為了使用常規號誌,可以像配置 DMA 號誌一樣配置它們,但要指定 pltpu.SemaphoreType.REGULAR
而不是 pltpu.SemaphoreType.DMA
。
號誌在 Pallas 程式結束時必須為零才能成功完成。有兩種錯誤情況可能會發生這種情況
如果號誌被過度發出訊號,程式將以非零 (>0) 號誌結束。在這種情況下,程式將在完成時崩潰。這對於除錯很有用,因為非零號誌通常表示程式內部某處存在錯誤。
如果號誌被過度等待,程式將在封鎖的
semaphore_wait
呼叫上掛起,同時等待號誌遞增。在這種情況下,需要重新啟動裝置或程式。
屏障號誌#
障礙號誌 (Barrier semaphores) 是全域分配 (globally-allocated) 的號誌,用於同步整個程式中的裝置 (devices),並確保所有裝置都已進入 Pallas 核心 (kernel)。
如果 Pallas 核心在較大的 XLA 程式的上下文中執行,我們需要確保所有進行通訊的裝置都已進入核心。然而,DMA 和一般號誌 (regular semaphores) 都是局部範圍 (locally scoped) 的 - 它們僅能被已進入核心的其他裝置所理解。障礙號誌作為一種全域理解 (globally understood) 的號誌,可用於同步,無論裝置目前在 XLA 程式中的哪個位置執行。
預設情況下,如果您未指定障礙號誌,Pallas 會在程式的開頭自動插入一個障礙號誌。然而,編寫您自己的障礙號誌可能更有效率。障礙號誌與一般號誌類似,它們都是計數器 (counters),可以透過 semaphore_signal
遞增,並透過 semaphore_wait
遞減。它們透過在核心中呼叫 get_barrier_semaphore()
來建立。通常,我們在核心的開頭使用障礙號誌一次,以與所有正在通訊的裝置同步。
from jax.experimental.pallas import tpu as pltpu
def example_kernel(...):
# Use barrier semaphores at the beginning of a kernel.
# is_start_of_kernel = ...
# right_neighbor = ...
# ...
@pl.when(is_start_of_kernel)
def _():
barrier_sem = pltpu.get_barrier_semaphore()
# Increment the semaphore of your right neighbor.
pltpu.semaphore_signal(
barrier_sem,
device_id=right_neighbor,
device_id_type=pltpu.DeviceIdType.LOGICAL,
)
# Wait until your left neighbor has incremented your semaphore
pltpu.semaphore_wait(barrier_sem, 1)
# ...
當使用障礙號誌時,必須將 collective_id
編譯器參數傳遞給 pallas_call
,以指定正在使用的障礙號誌。TPU 具有少量且固定數量的可用障礙號誌(通常約為 20-30 個),因此應謹慎使用。為了確保正確性,只有共享相同通訊模式 (communication pattern) 的核心才應使用相同的 collective_id
。例如,如果兩個核心僅與同一網格軸 (mesh axis) 上的鄰居同步,則它們可以共享相同的 collective_id
。但是,如果兩個核心沿不同的軸同步,它們必須具有不同的 collective_id
。否則可能會導致難以偵錯的競爭條件 (race conditions)。
kernel = pl.pallas_call(
example_kernel,
...,
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)
雙緩衝 (Double-buffering)#
為了避免從本地 Ref
讀取,而該 Ref
也正在被另一個裝置寫入,並產生競爭條件,一種有用的技術是「雙緩衝」策略,我們為每個目標值 (destination value) 分配兩個 Ref
。在每次迭代 (iteration) 中,一個 Ref
將被指定為「工作 (working)」槽 (slot),另一個將被指定為「接收 (receiving)」槽。裝置可以自由使用工作槽進行計算,但只會將資料複製到其鄰居的接收槽。工作槽和接收槽在每次迭代時交替,以便在複製完成後,舊的接收槽變成新的工作槽,反之亦然。正確使用此方案,永遠不會從同一個緩衝區讀取和寫入資料。
以下程式碼骨架 (code skeleton) 示範了如何使用雙緩衝。我們在變數 iteration
中維護一個運行的迭代計數器,並且 working_slot
和 receiving_slot
在每次迭代時在 0 和 1 之間交替。dst_ref
被分配為雙緩衝區,大小為 [2, ...]
。在每次迭代中,我們使用 dst_ref.at[working_slot, ...]
從工作槽讀取,並使用該值執行計算。同時,我們複製到鄰居的 dst_ref.at[receiving_slot]
,以避免覆寫其 working_slot
值。透過以這種方式建構通訊,可以在最小化競爭條件風險的同時,將遠端 DMA 的通訊延遲與本地計算重疊。
def kernel(...):
# ...
iteration = pl.program_id(0)
working_slot = lax.rem(iteration, 2)
receiving_slot = 1 - working_slot
# ...
local_copy_op = pltpu.make_async_copy(
src_ref=dst_ref.at[working_slot, ...],
dst_ref=local_scratch_ref,
sem=local_copy_sem,
)
local_copy_op.start()
remote_copy_op = pltpu.make_async_remote_copy(
src_ref=src_ref,
dst_ref=dst_ref.at[receiving_slot, ...],
send_sem=send_sem,
recv_sem=recv_sem,
device_id=target_device,
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy_op.start()
local_copy_op.wait()
# ... do work on local_scratch while waiting for async_copy_op to finish.
remote_copy_op.wait()
在同步方面,如果所有裝置都在相同的迭代上執行,則雙緩衝結構 (double-buffered construction) 可以正常運作。如果發送者 (sender) 設法比其接收者 (receiver) 快一個迭代,則其 working_slot
和 receiving_slot
索引將與接收者相比翻轉,這意味著它可能在接收者從 working_slot
讀取的同時寫入到該槽中。為了避免這種情況,可能需要使用號誌來同步發送者和接收者,或添加額外的緩衝槽(「三倍 (triple)」、「四倍 (quadruple)」或 N 倍緩衝),以允許更多的超前執行 (run-ahead),但會以更多記憶體為代價。在我們之前的 all_gather
範例中,請注意核心包含一個具有 N 個槽的接收緩衝區,這完全避免了競爭條件。在我們的下一個核心中,我們將改為介紹一個使用雙緩衝和顯式同步的範例。
範例:All-Reduce Sum (lax.psum
)#
我們現在將實作一個使用雙緩衝和號誌進行同步的 all-reduce sum 核心。對於那些熟悉 JAX 中的集合運算 (collective operations) 的人來說,等效的操作是 lax.psum
。All-reduce 是一種標準的集合運算,其目標是沿陣列的一個軸 (axis) 進行歸約 (reduce),但該陣列被分片 (sharded) 到多個裝置上。
在上面的範例中,我們有一個陣列 [5, 2, 1, 3] 分片到 4 個裝置上。all-reduce sum 操作將對所有值求和,並在每個裝置上複製結果,從而產生結果 [11, 11, 11, 11] 分片到所有 4 個裝置上。
all-reduce 的 naive 實作 (naive implementation) 是將所有必要的值收集到每個裝置上,然後進行歸約。然而,我們可以透過將通訊與計算交錯 (interleaving) 來提高此實作的效能。交錯的單向 (single-direction) all-reduce 可以視覺化如下。在每次迭代中,我們從左鄰居接收一個輸入值,並同時將輸入傳遞給下一個鄰居,同時用我們的本地累加器 (local accumulator) 遞增它。經過 N-1 次迭代後,每個裝置的記憶體中都將有一個完整總和的副本。
整合在一起 (Putting it all together)#
以下核心示範了如何將這些原則組合到一個功能性核心 (functional kernel) 中。
序言 (prologue)(當 outer_step==0
時執行)首先與兩個鄰居啟動一個障礙,以確保它們也已進入核心。它還處理所有 Ref
的初始化,並處理到右鄰居「工作」槽的第一個遠端複製 (remote copy)。
主體 (main body) 假設一個值已經從先前的迭代或序言複製到我們的本地工作槽中。一個複雜的因素是我們的目標緩衝區 (destination buffers) 位於 HBM 中,但我們需要在執行算術運算之前將值載入到 VMEM 中。因此,我們同時將工作槽值複製到我們的 VMEM (receive_scratch
),並將該值傳遞到右鄰居的接收槽。一旦該值被複製到我們的 VMEM 中,我們就可以將其累加到我們的結果中(包含在 o_ref
中)。
如果一個裝置比其右鄰居快一個迴圈 (loop) 運行,則可能會發生微妙的競爭條件。在這種情況下,它可能會在接收者從 working_slot
讀取的同時複製到接收者的 working_slot
中。為了避免這種情況,每個裝置在複製到右鄰居的 dst_ref
之前,都會在 REGULAR
號誌上阻塞 (block),直到它發出訊號表示已完成從其 working_slot
的讀取。對於像此範例這樣的小型核心,這種競爭條件很少被觸發,但如果例如使用 pltpu.delay
指令來人為地掛起裝置,則可以顯式觸發它。
請注意,這不是一個最佳或完全通用的核心,因為區塊大小必須完全適合 VMEM,並且我們可以更好地交錯通訊和累加。我們將在後續章節中討論這些最佳化。
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)
def all_reduce_kernel(
x_ref,
o_ref,
hbm_scratch,
copy_sem,
remote_recv_sem,
remote_send_sem,
capacity_sem,
receive_scratch,
):
outer_step = pl.program_id(0)
working_slot = lax.rem(outer_step, 2)
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = lax.rem(my_id + 1, num_devices)
left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)
@pl.when(outer_step == 0)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
# Initialize o_ref, acc_scratch, and hbm_scratch.
o_ref[...] = jnp.zeros_like(o_ref)
receive_scratch[...] = jnp.zeros_like(receive_scratch)
initial_copy = pltpu.make_async_remote_copy(
src_ref=x_ref,
dst_ref=hbm_scratch.at[working_slot],
send_sem=remote_send_sem,
recv_sem=remote_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_copy.start()
initial_copy.wait()
# Signal to our left neighbor that we are ready to receive.
# Without this signal, our left neighbor can be >=1 iteration ahead,
# meaning it could write into our working slot.
pltpu.semaphore_signal(
capacity_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# Copy the partial result our left neighbor sent to us into VMEM for
# computation.
local_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot],
dst_ref=receive_scratch,
sem=copy_sem,
)
local_copy.start()
# Block until our right neighbor is ready to receive.
pltpu.semaphore_wait(capacity_sem, 1)
# Pass the value to our right neighbor.
remote_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot],
dst_ref=hbm_scratch.at[receiving_slot],
send_sem=remote_send_sem,
recv_sem=remote_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
remote_copy.start()
# Finish local copy and accumulate while remote_copy is happening.
local_copy.wait()
o_ref[...] += receive_scratch[...]
# Block until remote copy finishes.
remote_copy.wait()
out_shape = (
jax.ShapeDtypeStruct((8, 128), jnp.float32),
# We allocate the double-buffer as a Pallas output so that it is
# resident in HBM.
jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
# Our input lives in VMEM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
],
out_specs=[
# Our output lives in VMEM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
# Our double-buffer lives in HBM
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices,),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 3
+ [pltpu.SemaphoreType.REGULAR] # capacity_sem
+ [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch
),
)
kernel = pl.pallas_call(
all_reduce_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)
pallas_result = jax.jit(
shard_map.shard_map(
kernel,
mesh=mesh,
in_specs=partition,
out_specs=partition,
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)[0]
def lax_sum(x):
return lax.psum(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
)
)(input_arr)
print('Input = ', input_arr[0, ::128])
print('Pallas result = ', pallas_result[0, ::128])
print('lax.psum result = ', xla_result[0, ::128])
difference = jnp.mean(jnp.abs(pallas_result - xla_result))
print('Difference |Pallas - lax.psum| = ', difference)
Input = [0.9858954 0.11763906 0.9955574 0.775211 ]
Pallas result = [2.8743029 2.8743029 2.8743029 2.8743029]
lax.psum result = [2.8743029 2.8743029 2.8743029 2.8743029]
Difference |Pallas - lax.psum| = 1.4959369e-08
超前執行與競爭條件 (Run-ahead and Race Conditions)#
作為一般的經驗法則 (rule of thumb),為了最大化效能,我們希望允許一個裝置在不犧牲程式正確性的情況下,盡可能地超前於其他裝置執行,而無需同步。雖然我們可以在每次迭代開始時在所有裝置上強制執行障礙,但這會將程式的效能瓶頸 (bottlenecks) 限制在每個迴圈中最慢的裝置上。透過放寬同步並允許適量的超前執行,我們可以更好地適應迭代和裝置之間延遲的差異,因為在一次迭代中速度較慢的裝置可能會在下一次迭代中趕上。
在我們之前編寫的 all-reduce 核心中,我們允許裝置超前執行,但與其鄰居相比少於一次迭代(但是,非鄰近裝置可能相差超過 1 次迭代)。為了了解為什麼號誌同步是必要的,請考慮當一個裝置(例如裝置 2)掛起並落後於其他裝置的情況。RDMA 沒有「握手 (handshake)」— 只有接收者在等待資料到達時被阻塞。因此,每個裝置最多可以超前執行一次迭代,然後才會被阻塞等待下一個 RDMA 到達。如果我們有 N 個裝置,這意味著最後一個裝置最多可以比第一個裝置超前 N 次迭代。
如果不在另一個方向添加同步(強制發送者阻塞),裝置 1 可能會比裝置 2 超前最多 N
次迭代(N = num_devices
),在此過程中發送多次寫入並覆寫值。為了在我們之前編寫的 all_reduce
核心中解決這個問題,我們實作了一個「握手」協定,其中接收者向發送者發回訊號,表示它已準備好接收,然後發送者才開始發出下一個 RDMA。
雙向通訊 (Bi-directional Communication)#
在我們之前的核心中,我們在環狀結構中沿單一方向從左到右進行通訊。然而,由於 ICI 連線是雙向的,因此我們實際上浪費了一半的總頻寬,因為沒有沿相反方向從右到左發送值。在下一個核心中,我們將示範一個雙向通訊的範例,以最大化 ICI 頻寬。
範例:雙向 Reduce-Scatter (lax.psum_scatter
)#
reduce-scatter 運算是 all-reduce 後接 scatter 的組合。或者,all-reduce 是 reduce-scatter 後接 all-gather 的組合。
下圖描述了此運算的語義 (semantics)。我們假設每個裝置都從一組部分總和 (partial sums) 開始(用字母 + 數字表示,例如 A0
)。目標是沿一個軸(數字)進行歸約,同時沿另一個軸(字母)進行分片。
為了實作雙向通訊策略,我們將每個輸入區塊 (input block) 切成兩半,並為每一半指定一個方向。每個區塊的上半部分將從右到左傳遞,下半部分將從左到右傳遞。與我們之前的 all-reduce 和 all-gather 核心的通訊模式的第二個偏差是,我們還將傳遞累加器或部分總和,並將輸入保持在每個裝置的本地。這與之前的範例形成對比,在之前的範例中,我們傳遞輸入,但將累加器保持在裝置的本地。傳遞累加器更適合這個問題,因為與 all-reduce 相比,輸入中的大部分資料都不是最終將本地儲存在裝置上的輸出的一部分。(例如,上圖中的 B0
、C0
和 D0
將不會儲存在最終持有 A
的裝置上)。
下圖說明了這種通訊模式,其中彩色框代表累加器(不是輸入!)。最初,累加器只是輸入中包含的值。在演算法的每次迭代中,我們將從每個方向的鄰居接收一個部分總和。然後,我們計算輸入的正確切片 (slice) 以累加到部分緩衝區 (partial buffer) 中,然後將新的部分總和傳遞給我們的下一個鄰居。經過 N 次迭代後,累加器將已通過每個裝置,這意味著它最終將持有完整總和。
在核心的建構方面,我們在 Pallas 網格中引入了一個額外的 phase
維度,它表示我們目前正在計算哪個累加器(左或右)。我們讓 phase=0
表示向左移動的累加器,phase=1
表示向右移動的累加器。然後,我們將這兩個階段 (phases) 管線化 (pipeline),以便在計算一個階段的結果時,我們在相反方向傳輸我們先前計算的值,以準備下一個階段。例如,當我們處於 phase=0
(左)時,我們首先開始一個 DMA,將我們在先前迭代中計算的結果傳輸到我們的右鄰居(右 DMA)。然後,我們累加到左緩衝區 (left-buffer) 中,並將結果儲存到 HBM。然後,我們等待右 DMA 完成,以便為 phase=1
(右)做好準備。
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# We need a block size of (16, 128) to ensure that a half-slice is at least
# of size (8, 128), which is the size of a VREG. This makes tiling easier
# for the compiler.
block_size = (16, 128)
input_arr = jax.random.uniform(
jax.random.key(0),
shape=(block_size[0] * num_devices, block_size[1] * num_devices),
)
input_arr = jax.device_put(input_arr, sharding)
LEFT = 0
RIGHT = 1
def mod(x, n):
return lax.rem(x + n, n)
def signal(left_or_right, semaphore):
my_id = lax.axis_index('x')
if left_or_right == LEFT:
neighbor = mod(my_id - 1, num_devices)
else:
neighbor = mod(my_id + 1, num_devices)
pltpu.semaphore_signal(
semaphore,
inc=1,
device_id=(neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
def reduce_scatter_kernel(
x_ref,
o_ref,
hbm_scratch,
local_copy_sem,
left_recv_sem,
left_send_sem,
right_recv_sem,
right_send_sem,
left_capacity_sem,
right_capacity_sem,
accum_scratch,
):
outer_step = pl.program_id(0)
phase = pl.program_id(1)
is_start = jnp.logical_and(outer_step == 0, phase == 0)
last_iteration = outer_step == pl.num_programs(0) - 1
working_slot = lax.rem(outer_step, 2)
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = mod(my_id + 1, num_devices)
left_neighbor = mod(my_id - 1, num_devices)
left_copy_device = mod(my_id + outer_step + 1, num_devices)
right_copy_device = mod(my_id - outer_step - 1, num_devices)
# Slices can be specified using pl.ds(start, size)
left_copy_slice = pl.ds(0, block_size[0] // 2)
right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)
initial_left_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, left_copy_slice],
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_right_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
left_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot, left_copy_slice],
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
right_copy = pltpu.make_async_remote_copy(
# Note: Right copy is flipped with regards to slots since we are copying
# to the next outer_step iteration.
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# --- Prologue ---
@pl.when(is_start)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
# Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
o_ref[...] = jnp.zeros_like(o_ref[...])
accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
initial_left_copy.start()
initial_left_copy.wait()
initial_right_copy.start()
# We tell our left neighbor that it is allowed to send to the right.
# (and vice versa for right neighbor)
signal(LEFT, right_capacity_sem)
signal(RIGHT, left_capacity_sem)
# --- Body ---
# At the beginning of our kernel body, we start a DMA which copies
# the result we computed in the previous phase to our neighbor.
# This allows us to overlap the communication of sending our previous phase
# with the computation for the current phase.
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
# We block here until our right neighbor tells use we can send to
# the right.
pltpu.semaphore_wait(right_capacity_sem, 1)
right_copy.start()
@pl.when(phase == RIGHT)
def _():
# We block here until our left neighbor tells use we can send to
# the left.
pltpu.semaphore_wait(left_capacity_sem, 1)
left_copy.start()
local_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
dst_ref=accum_scratch,
sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
@pl.when(phase == LEFT)
def _():
accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
@pl.when(phase == RIGHT)
def _():
accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
src_ref=accum_scratch,
dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(is_start)
def _():
initial_right_copy.wait()
# At the end of our kernel body, we wait on the DMA of the previous phase
# to make sure the results are ready for the next phase.
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
right_copy.wait()
signal(LEFT, right_capacity_sem)
@pl.when(phase == RIGHT)
def _():
left_copy.wait()
signal(RIGHT, left_capacity_sem)
# --- Epilogue ---
# Store result on last iteration.
@pl.when(last_iteration)
def _():
# Clean up semaphores so that they exit with a value of 0.
@pl.when(phase == LEFT)
def _():
o_ref[left_copy_slice, ...] = accum_scratch[...]
pltpu.semaphore_wait(right_capacity_sem, 1)
@pl.when(phase == RIGHT)
def _():
o_ref[right_copy_slice, ...] = accum_scratch[...]
pltpu.semaphore_wait(left_capacity_sem, 1)
out_shape = (
jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output
# Shape: [working/recv, block[0], block[1]]
jax.ShapeDtypeStruct(
(2, block_size[0], block_size[1]), jnp.float32
), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
],
out_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices, 2),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 5
+ [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
+ [
pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
] # accum_scratch
),
)
def pallas_reduce_scatter(input_arr):
input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
return pl.pallas_call(
reduce_scatter_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)(input_arr)[0]
pallas_result = jax.jit(
shard_map.shard_map(
pallas_reduce_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)
# Compare our result to XLA.
def lax_reduce_sum_scatter(x):
x = x.reshape(num_devices, block_size[0], block_size[1])
return lax.psum_scatter(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_reduce_sum_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
)
)(input_arr)
print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
'Difference |Pallas - lax.psum_scatter|:',
jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (64, 512) [0.78051674 0.3524047 0.59993696 0.9714314 0.24692321 0.01347649
0.01857424 0.24841607 0.86097646 0.8261659 0.9753758 0.6902338
0.4431417 0.963323 0.3158517 0.535548 ]
Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869 1.4194957 1.4163033 1.2401303
1.1892898 2.6545286 2.221559 2.7995253 2.08431 2.2509837 3.0726733
2.4662397 1.9542246]
lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869 1.4194957 1.4163033 1.2401303
1.1892898 2.6545286 2.221559 2.7995253 2.08431 2.2509837 3.0726733
2.4662397 1.9542246]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07
巢狀遠端和本地 DMA 管線 (Nested Remote and Local DMA Pipelines)#
我們之前編寫的 all-reduce 和 reduce-scatter 核心的一個限制是,我們透過遠端 DMA 複製的區塊必須足夠小,才能容納在我們用於累加的工作 VMEM 中。對於某些核心,使用更大的區塊大小可能有利於更好地利用 TPU。例如,矩陣乘法需要大約 \(O(N^3)\) 的計算操作,但僅需要 \(O(N^2)\) 的記憶體傳輸。因此,我們希望在裝置之間傳輸的每個工作區塊都足夠大,以便操作變成計算密集型 (compute bound),並且我們可以透過管線化來隱藏通訊成本。作為參考,TPU 的 VMEM(對於 v4/v5 世代)通常約為 10-100MB,而 HBM 的範圍為 10-100GB。
為了解決這個問題,我們需要能夠編寫一個「內部核心 (inner kernel)」,在「外部核心 (outer kernel)」內部處理本地 HBM-VMEM 管線化,而「外部核心」處理裝置之間更大的 HBM-HBM 傳輸的管線化。Pallas 提供了一個 API,用於使用 emit_pipeline
函數建構巢狀管線 (nested pipelines)。emit_pipeline
的基本呼叫簽名 (call signature) 遵循標準 pallas_call
的模式,透過指定 grid
和輸入和輸出的 BlockSpec
s。
def emit_pipeline(
kernel: Callable,
grid: tuple[int],
in_specs: PyTree[BlockSpec] = None,
out_specs: PyTree[BlockSpec] = None,
should_accumulate_out: bool = False,
dimension_semantics: tuple[GridDimensionSemantics] = None,
) -> Callable:
... # Returns a custom pipeline given an inner kernel and BlockSpecs.
實際上,可以將 pallas_call
本身視為僅僅是 emit_pipeline
的包裝器 (wrapper)。由於我們的外部核心僅涉及遠端 HBM-HBM 傳輸,因此我們沒有使用 pallas_call
為 HBM-VMEM 傳輸提供的任何內建管線化。以下程式碼骨架示範了使用這種模式的典型程式結構會是什麼樣子。
def outer_kernel(...):
# ... do work to pipeline remote HBM-HBM transfers (outer kernel)
def inner_kernel(...):
# ... do work (inner kernel)
pltpu.emit_pipeline(
inner_kernel,
grid=inner_grid,
in_specs=...,
out_specs=...,
)(inner_kernel_args)
# ... do more work (outer kernel)
pl.pallas_call(
outer_kernel,
grid=outer_grid,
in_specs=...
out_specs=...
scratch=inner_kernel_allocs
)
範例:使用大型 HBM 區塊的 Reduce-Scatter (Reduce-Scatter with large HBM blocks)#
在下一個範例中,我們將修改我們之前的 reduce-scatter 範例,以利用巢狀內部管線。請注意,reduce_scatter
的通訊和計算成本都與輸入大小成線性比例縮放,因此我們不一定期望看到操作隨著更大的區塊大小而變成計算密集型。此範例純粹是為了示範如何使用管線發射器 (pipeline emitter)。
我們將增加外部核心的區塊大小,使其不適合放置在 VMEM 內部,並在 HBM 中分配所有輸入和輸出 (memory_space=TPUMemorySpace.Any
)。與我們之前的核心相比,唯一的主要變化是執行累加的核心主體。我們不再手動從 HBM 複製到 VMEM、累加,然後複製回 HBM,而是使用 emit_pipeline
來為我們處理記憶體傳輸。累加在一個具有更小、VMEM 友善的區塊大小的內部核心中完成。
在我們之前的核心中,我們有以下核心主體,用於將資料從 HBM 複製到 VMEM 累加器、遞增,然後將結果複製回 HBM
local_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
dst_ref=accum_scratch,
sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
@pl.when(phase == LEFT)
def _():
accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
@pl.when(phase == RIGHT)
def _():
accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
src_ref=accum_scratch,
dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
我們的新核心將其替換為以下 emit_pipeline
呼叫
def inner_kernel(input_ref, accum_ref):
accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(inner_kernel,
in_specs=[inner_block_spec],
out_specs=inner_block_spec,
should_accumulate_out=True,
grid=inner_grid)
@pl.when(~last_iteration)
def _():
@pl.when(phase == LEFT)
def _():
accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],
hbm_scratch.at[working_slot, left_copy_slice],
)
@pl.when(phase == RIGHT)
def _():
accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],
hbm_scratch.at[working_slot, right_copy_slice],
)
完整的核心如下所示
partition = P(None, 'x')
mesh = jax.make_mesh((num_devices,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, partition)
# We pick a large outer kernel block size that we do not want to place
# in VMEM. For pedagogical purposes we use (4096, 4096), although in
# principle this can be much larger.
outer_block_size = (4096, 4096)
# We pick a smaller VMEM block size for the inner kernel.
inner_block_size = (128, 128)
input_arr = jax.random.uniform(
jax.random.key(0),
shape=(
outer_block_size[0] * num_devices,
outer_block_size[1] * num_devices,
),
)
input_arr = jax.device_put(input_arr, sharding)
inner_grid = (
outer_block_size[0] // inner_block_size[0] // 2,
outer_block_size[1] // inner_block_size[1],
)
inner_block_spec = pl.BlockSpec(
index_map=lambda i, j: (i, j),
block_shape=inner_block_size,
memory_space=pltpu.TPUMemorySpace.ANY,
)
def reduce_scatter_kernel(
x_ref,
o_ref,
hbm_scratch,
left_recv_sem,
left_send_sem,
copy_sem,
right_recv_sem,
right_send_sem,
left_capacity_sem,
right_capacity_sem,
):
outer_step = pl.program_id(0)
phase = pl.program_id(1)
is_start = jnp.logical_and(outer_step == 0, phase == 0)
last_iteration = outer_step == pl.num_programs(0) - 1
working_slot = lax.rem(outer_step, 2)
receiving_slot = 1 - working_slot
my_id = lax.axis_index('x')
right_neighbor = mod(my_id + 1, num_devices)
left_neighbor = mod(my_id - 1, num_devices)
left_copy_device = mod(my_id + outer_step + 1, num_devices)
right_copy_device = mod(my_id - outer_step - 1, num_devices)
left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
current_phase_slice = pl.ds(
phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
)
initial_left_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, left_copy_slice],
dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
initial_right_copy = pltpu.make_async_remote_copy(
src_ref=x_ref.at[my_id, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
left_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[working_slot, left_copy_slice],
dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
send_sem=left_send_sem,
recv_sem=left_recv_sem,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
right_copy = pltpu.make_async_remote_copy(
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
send_sem=right_send_sem,
recv_sem=right_recv_sem,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
# --- Prologue ---
@pl.when(is_start)
def _():
# Barrier with both neighbors at the start, since we will be
# communicating with both.
barrier_sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(left_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_signal(
barrier_sem,
inc=1,
device_id=(right_neighbor,),
device_id_type=pltpu.DeviceIdType.MESH,
)
pltpu.semaphore_wait(barrier_sem, 2)
initial_left_copy.start()
initial_left_copy.wait()
initial_right_copy.start()
# We tell our left neighbor that it is allowed to send to the right.
# (and vice versa for right neighbor)
signal(LEFT, right_capacity_sem)
signal(RIGHT, left_capacity_sem)
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
# We block here until our right neighbor tells use we can send to
# the right.
pltpu.semaphore_wait(right_capacity_sem, 1)
right_copy.start()
@pl.when(phase == RIGHT)
def _():
# We block here until our left neighbor tells use we can send to
# the left.
pltpu.semaphore_wait(left_capacity_sem, 1)
left_copy.start()
# --- Body ---
def inner_kernel(input_ref, accum_ref):
# We do not explicitly use += because we set should_accumulate_out=True.
accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(
inner_kernel,
in_specs=[inner_block_spec],
out_specs=inner_block_spec,
should_accumulate_out=True,
grid=inner_grid,
)
@pl.when(~last_iteration)
def _():
@pl.when(phase == LEFT)
def _():
accum_pipeline(
x_ref.at[left_copy_device, left_copy_slice],
hbm_scratch.at[working_slot, left_copy_slice],
)
@pl.when(phase == RIGHT)
def _():
accum_pipeline(
x_ref.at[right_copy_device, right_copy_slice],
hbm_scratch.at[working_slot, right_copy_slice],
)
# --- Epilogue ---
@pl.when(is_start)
def _():
initial_right_copy.wait()
@pl.when(~is_start)
def _():
@pl.when(phase == LEFT)
def _():
right_copy.wait()
signal(LEFT, right_capacity_sem)
@pl.when(phase == RIGHT)
def _():
left_copy.wait()
signal(RIGHT, left_capacity_sem)
# Store result on last iteration.
@pl.when(last_iteration)
def _():
output_copy = pltpu.make_async_copy(
src_ref=hbm_scratch.at[working_slot, current_phase_slice],
dst_ref=o_ref.at[current_phase_slice],
sem=copy_sem,
)
output_copy.start()
output_copy.wait()
# Clean up semaphores so that they exit with a value of 0.
@pl.when(phase == LEFT)
def _():
pltpu.semaphore_wait(right_capacity_sem, 1)
@pl.when(phase == RIGHT)
def _():
pltpu.semaphore_wait(left_capacity_sem, 1)
out_shape = (
jax.ShapeDtypeStruct(
(outer_block_size[0], outer_block_size[1]), jnp.float32
),
# Shape: [working/recv, block[0], block[1]]
jax.ShapeDtypeStruct(
(2, outer_block_size[0], outer_block_size[1]), jnp.float32
), # hbm_scratch
)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
out_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
],
grid=(num_devices, 2),
scratch_shapes=(
[pltpu.SemaphoreType.DMA] * 5
+ [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
),
)
def pallas_reduce_scatter(input_arr):
input_arr = input_arr.reshape(
num_devices, outer_block_size[0], outer_block_size[1]
)
return pl.pallas_call(
reduce_scatter_kernel,
out_shape=out_shape,
grid_spec=grid_spec,
compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)(input_arr)[0]
pallas_result = jax.jit(
shard_map.shard_map(
pallas_reduce_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
check_rep=False,
)
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)
# Now we compare our result to XLA.
def lax_reduce_sum_scatter(x):
x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
return lax.psum_scatter(x, 'x')
xla_result = jax.jit(
shard_map.shard_map(
lax_reduce_sum_scatter,
mesh=mesh,
in_specs=P(None, 'x'),
out_specs=P('x', None),
)
)(input_arr)
print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
'Difference |Pallas - lax.psum_scatter|:',
jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (16384, 16384) [0.74162567 0.0242182 0.27751946 ... 0.05213022 0.36088037 0.04494429]
Pallas Result: (16384, 4096) [2.0648427 1.674587 1.9148926 ... 1.3371865 1.3296283 1.2887063]
lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587 1.9148926 ... 1.3371865 1.3296283 1.2887063]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07
最終注意事項 (Final Notes)#
Megacore#
某些 TPU 在 Megacore 配置中包含多個核心 (cores)。在這種配置中,我們的一般建議是僅從單一核心啟動 DMA,並且僅執行 HBM-HBM 傳輸。為此,將其中一個網格軸 (grid axes) 設定為核心數量(可以透過 jax.devices()[0].num_cores
取得),並將 dimension_semantics 設定為 "parallel"
。然後,您可以使用 core_index = pl.program_id(axis)
來取得沿該軸的核心索引,並使用 @pl.when(core_index==i)
來執行特定於該核心的程式碼。
與 XLA 互動 (Interaction with XLA)#
在本教學課程中,我們介紹了幾個核心範例,它們複製了 JAX 中集合運算的功能,例如 lax.all_gather
、lax.psum
和 lax.psum_scatter
。需要注意的一個重要警告是,Pallas 核心對於 XLA 編譯器來說在某種程度上是不透明的,並且可能導致它錯過一些它通常會執行的最佳化 (optimizations)。例如,XLA 可以非同步地 (asynchronously) 調度集合運算,以便在不編寫自訂核心的情況下交錯通訊和計算。當涉及 Pallas 核心時,這不能保證會發生,因此分析您的程式以查看這是否是一個問題非常重要。另一個範例是,我們在本教學課程中用於產生巢狀管線的 emit_pipeline
函數對於 XLA 編譯器來說是不可見的,因此無法與相鄰的操作融合 (fused)。
後續步驟 (Next Steps)#
讀者可以進行的絕佳後續練習可能包括實作分散式矩陣乘法 (distributed matrix multiplication)、實作 lax.all_to_all
,以及放寬同步以允許額外的超前執行。