管線化#
在本指南中,我們將介紹 TPU 中記憶體空間的運作方式,以及如何在 Pallas 中撰寫可將記憶體 I/O 與計算重疊的管線。
#@title Imports
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
TPU 及其記憶體空間#
TPU 及其 TensorCore 由記憶體空間 (陣列可以駐留於其中)、暫存器 (暫時儲存純量和陣列值) 和運算單元 (使用暫存器中的值進行運算) 組成。以下是 TPU 的示意圖,其中 x
和 y
是存在於高頻寬記憶體 (HBM) 中的陣列
讓我們更詳細地討論此圖表的組件
記憶體空間:TPU 具有高頻寬記憶體 (HBM),這通常是我們認為的「裝置記憶體」。還有向量記憶體 (VMEM),一種旨在儲存向量和陣列值的快取,以及純量記憶體 (SMEM),一種旨在儲存純量值的快取。
暫存器:TensorCore 具有兩種主要類型的暫存器:向量暫存器 (VREG) 儲存陣列值,而純量暫存器 (SREG) 儲存純量值。值可以從其各自的快取 (VREG 的 VMEM 和 SREG 的 SMEM) 載入到記憶體中。
運算單元:TensorCore 具有純量單元、向量單元 (VPU) 和矩陣單元 (MXU),可以進行數值運算。運算單元對 SREG 和 VREG 中存在的值進行運算,並將值輸出到這些暫存器中。
為了對存在於 HBM 中的值 x
和 y
執行向量化計算,我們需要
將值
x
和y
複製到 VMEM 中。從 VMEM 將值載入到 VREG 中。
使用 VPU 或 MXU 執行計算,將輸出儲存在 VREG 中。
將輸出 VREG 中的值儲存到 VMEM 中。
將 VMEM 中的輸出值複製回 HBM。
讓我們實作一個 Pallas 函數來完成這件事!
def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
# Load x and y from VMEM into VREGs
x_vregs = x_vmem_ref[:, :]
y_vregs = y_vmem_ref[:, :]
# Execute a vectorized add
z_vregs = x_vregs + y_vregs
# Store the output values in VREGs back into VMEM
z_vmem_ref[:, :] = z_vregs
def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
# pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
# It will then copy `x` and `y` from HBM into VMEM.
z = pl.pallas_call(
add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
# pallas_call will also copy the output from VMEM back into HBM.
return z
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我們撰寫了兩個函數:add_matrices_kernel
和 add_matrices
。
add_matrices_kernel
使用存在於 VMEM 中的 Ref
進行操作。從 VMEM Ref
載入會產生存在於 VREG 中的值。VREG 中的值行為類似於 jax.Array
,因為我們可以在它們上使用 jnp
和 jax.lax
運算,以產生存在於 VREG 中的新值。當我們產生想要傳回的值時,我們會將它們儲存在輸出 VMEM Ref
中。
add_matrices
函數作用於 jax.Array
並傳回 jax.Array
。在其中,我們將 x
和 y
傳遞到 pallas_call
中。pallas_call
負責將 x
和 y
複製到 VMEM 中,並分配核心運作的 VMEM 緩衝區 (包括分配 z_vmem_ref
,輸出 VMEM 緩衝區)。在核心函數完成執行後,pallas_call
也會將 z_vmem_ref
中的值複製到 HBM,從而產生輸出 jax.Array
。
使用 VMEM/SMEM 的限制#
Pallas 公開了對較低層級記憶體空間 (如 VMEM 和 SMEM) 的存取,但撰寫利用它們的核心增加了一些考量。
記憶體容量。VMEM 和 SMEM 都很小!v4 TPU 上的 VMEM 僅為 16MiB,而 SMEM 的範圍在數十到數百 KiB。如果我們的陣列太大,我們甚至無法將它們全部放入 VMEM 中。作為參考,
f32[2048, 2048]
陣列為 16MiB,因此我們上面的核心無法擴展到超出中等大小的陣列。記憶體頻寬。與大多數運算指令相比,複製到/從 HBM 和 VMEM 耗時較長。
add_matrices
函數可能花費更多時間在 HBM 和 VMEM 之間複製,而不是實際執行加法本身。
考慮到這兩個限制,我們將不得不重新思考從 TPU 獲得效能的策略。
入門:管線化#
管線化我們的計算提供了一種一次性處理記憶體容量和頻寬限制的方法。我們所說的管線化是什麼意思?
目標是:平行複製到/從 HBM 和 VMEM 同時利用我們的運算單元。天真地說,這很困難,因為在我們上面的程式中,我們在開始對它們進行任何計算之前複製所有的 x
和 y
,從而在複製和計算之間建立依賴關係。
但是,如果我們可以將我們的計算分成幾個子計算 (例如,當我們將兩個矩陣相加時,我們可以將其表示為原始矩陣的「區塊」相加),我們現在可以將其中一個子計算的複製與另一個子計算的計算重疊。讓我們逐步完成一個簡單的範例
假設我們將陣列 x
和 y
分割成 x1、x2
和 y1、y2
(例如,沿著前導軸分割,每個輸入產生兩個 (256, 512)
陣列。我們現在可以執行以下管線化計算。
將
x1
和y1
複製到 VMEM 中。開始將
x2
和y2
複製到 VMEM 中從 VMEM 將
x1、y1
載入到 VREG 中。使用運算單元執行
z1 = x1 + y1
。將
z1
儲存到 VMEM 中。開始將
z1
從 VMEM 複製回 HBM。等待直到
x2、y2
已複製到 VMEM 中。從 VMEM 將
x2、y2
載入到 VREG 中。使用運算單元執行
z2 = x2 + y2
。將
z2
儲存到 VMEM 中。等待直到
z1
已複製到 HBM 中。開始將
z2
從 VMEM 複製回 HBM。等待直到
z2
已複製到 HBM 中。
任何時候我們在這裡進行計算,我們都在非同步複製某些東西。這表示一些複製時間沒有被浪費。
用於確定管線化計算效率的兩個最重要數字是 a) 我們需要執行的浮點運算 (FLOP) 數量,以及 b) 我們需要複製多少位元組才能執行該計算。這兩者的比率 (FLOP/記憶體使用量) 稱為運算的算術強度,並決定我們的管線是受計算限制還是受記憶體限制。
Pallas 中的管線化#
我們如何在 Pallas 中實作像上面這樣的管線?這似乎是一個複雜的非同步資料運算序列和執行核心,手動實作會很麻煩。別擔心!Pallas 提供了一個 API 來表達管線,而無需過多的樣板程式碼,即透過 grid
和 BlockSpec
。
請參閱在上面的管線化範例中,我們如何多次執行相同的邏輯:步驟 3-5 和 8-10 都執行相同的運算,只是在不同的輸入上。jax.experimental.pallas.pallas_call()
提供了一種透過使用 grid
引數多次執行核心的方法。請參閱 grid,又名迴圈中的核心。
我們也使用 jax.experimental.pallas.BlockSpec
來指定如何建構每個核心調用的輸入。請參閱 BlockSpec,又名如何將輸入分塊。
在上面的管線化範例中,我們有 (512, 512)
形狀的陣列,並沿著前導維度將它們分割成兩個 (256, 512)
形狀的陣列。在此管線中,我們的 BlockSpec.block_shape
將為 (256, 512)
。在第一次迭代中,我們想要選擇 x1
,而在第二次迭代中,我們想要使用 x2
。這可以使用以下 index_map
來表示
def x_index_map(i):
return (i, 0)
然後我們將建構 BlockSpec
block_spec = pl.BlockSpec((256, 512), x_index_map)
y
和 z
的 BlockSpec
將與 x
的相同。
整合在一起#
我們透過 grid
、in_specs
和 out_specs
將這些引數提供給 pallas_call
(in_specs
對應於位置引數的元組,而 out_specs
對應於輸出)。
def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,)
)(x, y)
add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我們只在原始函數中新增了一點程式碼來新增自動管線化,但 BlockSpec
和 grid
完成了很多繁重的工作!
它是如何運作的?嗯,BlockSpec
提供了足夠的資訊來開始從 HBM 預取輸入區塊到 VMEM 中。例如,如果我們正在啟動 grid
的迭代 i
,我們可以將 i + 1
傳遞到 index_map
函數中,以取得下一次迭代所需的區塊。然後我們可以為這些區塊啟動非同步複製。同樣對於輸出,我們可以等待前一次迭代的輸出被複製,然後再啟動目前迭代輸出的複製。
參數化管線#
在我們的核心中參數化區塊形狀是很常見的。區塊大小可能是調整 Pallas 核心效能時最重要的參數!它們讓我們可以控制管線 (例如,選擇較小的區塊會為我們的管線化迴圈新增更多迭代,其中每次迭代的工作量較少)。
此外,我們也可以沿著第二個維度分割輸入和輸出 (我們現在只沿著第一個維度分割)。讓我們撰寫一個更通用的核心,可以處理這兩個功能。
def add_matrices_pipelined_2d(
x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
m, n = x.shape
block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(m // bm, n // bn),
)(x, y)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)
處理歸約#
您將如何使用 pallas_call
實作類似 jnp.sum
的功能?具體來說,我們想要跨歸約維度進行管線化。
以將 (8, 512, 512)
形狀的陣列歸約為 (512, 512)
形狀的陣列為例。
x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
為了使用 pallas_call
執行此操作,我們可以使用大小為 (8,)
的 grid,並在每次迭代 i
中將 x[i]
載入到 VMEM 中。然後我們可以將 x[i]
新增到輸出 VMEM 緩衝區。讓我們先天真地實作這個。
# Warning: this implementation is incorrect!
def naive_sum_kernel(x_ref, o_ref):
o_ref[...] += x_ref[...]
def naive_sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
naive_sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
)(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
...,
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.]], dtype=float32)
請注意我們如何設定 BlockSpec
:我們將整個 (512, 512)
維度載入到 VMEM 中 (那裡沒有管線化),但在 index_map
中為每次迭代選擇 x
的第 i
個維度。我們在區塊形狀中為該維度使用 None
,這表示我們正在從 x
中選擇一個單例維度,我們想要在核心中擠壓掉它。因此,x_ref
在 VMEM 中也是 (512, 512)
形狀。
out_spec
使用 lambda i: (0, 0)
作為其 index_map
,表示 o_ref
在管線過程中保持不變。這表示我們可以透過讀取和寫入來更新其每次迭代的值。或者可以嗎?實際上,有一個問題:o_ref
最初是垃圾,這表示我們將累積到垃圾中。這將導致整個函數輸出不正確的值!
因此,每當我們在核心中進行歸約時,我們都需要確保初始化儲存歸約值的 Ref
。我們可以透過在迭代 0 時有條件地將值寫入 out_ref
來完成此操作。我們可以透過輔助函數 pl.when
(圍繞 jax.lax.cond
的便利包裝器) 和 pl.program_id
(查詢我們在 grid 軸中的哪個迭代) 來完成此操作。
def sum_kernel(x_ref, o_ref):
@pl.when(pl.program_id(axis=0) == 0)
def _():
o_ref[...] = jnp.zeros_like(o_ref)
o_ref[...] += x_ref[...]
def sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
現在這個 sum
函數輸出正確的值!
關於 Pallas 中歸約的最後一件事要注意的是,它們必須在我們 grid 的最內層 (最右邊) 維度中完成 (在上面的範例中,我們的 grid 是 1 維的,因此我們正在其最內層維度上進行歸約)。這是因為 Pallas 使用 BlockSpec
、grid
和核心函數產生的管線不會從 HBM 讀回輸出。一旦您將輸出值寫回 HBM,您就無法再次訪問它。因此,您無法跨具有任何重新訪問的 grid 維度進行歸約,因此所有歸約都需要在最右邊的維度中發生。
Megacore 配置中的 TPU#
某些 TPU 晶片具有兩個 TensorCore,但在 JAX 使用者看來,它們是一個裝置。這稱為「megacore」。單獨的 TensorCore 具有自己單獨的 VMEM、VREG、SMEM、SREG 和運算單元,但共享 HBM。
從概念上講,Megacore 中的 TPU 行為很像非常簡單的 GPU,即它們只有兩個執行緒。我們如何修改我們的核心以同時利用兩個 TensorCore?
基本概念是,如果我們的計算中有令人尷尬地並行的維度,我們可以跨 TensorCore 分割這些維度。我們可以透過向 pallas_call
提供一個名為 dimension_semantics
的註解來指示哪些維度可以平行化。
def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
)(x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
dimension_semantics
應該是一個與 grid
長度相同的元組,其中每個條目都是 "parallel"
或 "arbitrary"
。"parallel"
向 Pallas 指示,對應於該維度的 for 迴圈的迭代可以獨立執行,而不會影響程式的正確性。"arbitrary"
向 Pallas 指示,對於此 grid 維度不能做任何假設,因此無法平行化。
透過指定 dimension_semantics
,我們現在在每個 TensorCore 上同時執行核心。Pallas 將自動處理 grid 的分割。
請注意,Megacore 目前僅在 TPU
v4
和 TPUv5p
上可用。提供dimension_semantics
註解在其他平台上是空操作,但不指定它將導致僅使用一個 TensorCore (即使有多個可用)。
結論#
在本指南中,我們介紹了如何使用 pallas_call
、grid
和 BlockSpec
來表達 TPU 管線。我們涵蓋了如何透過多維網格表達巢狀迴圈,以及如何在歸約開始時初始化累加器來處理歸約。我們也學習了如何透過在核心中加入註解來處理 Megacore。
留給讀者的練習
嘗試實作一個
sum
核心,將其他維度也進行管線化為
add
核心和sum
核心加入 megacore 支援。