矩陣乘法#

在本指南中,我們將使用 Pallas 撰寫矩陣乘法常式。我們也將探討如何在 TPU 上思考矩陣乘法效能,以及如何範本化矩陣乘法核心以融合運算。

#@title Imports
import functools
from typing import Callable

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np

背景#

矩陣乘法是現代深度學習和語言建模核心的基本線性代數運算。我們希望使用像 TPU 和 GPU 這樣的專用加速器,盡可能地加速矩陣乘法,因為它們都具有用於快速矩陣乘法的專用單元。

為了有效地利用 TPU 進行矩陣乘法,我們需要涵蓋幾個背景概念:區塊矩陣乘法、平鋪和管線化。

區塊矩陣乘法#

假設我們想要實作 matmul(x, y),它通常將 (m, k) 陣列與 (k, n) 陣列相乘,但有一個變化。我們只允許使用基本運算 matmul_small,它將小矩陣相乘(例如 m、k、n <= 256)。我們該如何做到?

矩陣乘法的一個良好特性是,輸出的每個區塊都可以表示為輸入的行區塊和列區塊的幾個較小矩陣乘法的總和。形式上,如果我們有輸入陣列 \(x \in \mathbb{R}^{m \times k}\)\(y \in \mathbb{R}^{k \times n}\) 以及輸出 \(z \in \mathbb{R}^{m \times n}\),我們沿著大小為 \(b_m, b_k, b_n\) 的維度將它們分解成區塊。

例如,\(x\) 將被分解為

\[\begin{split} \begin{bmatrix} x_{0, 0} & \cdots & x_{0, i_k} \\ x_{1, 0} & \cdots & x_{1, i_k} \\ \vdots & \ddots & \vdots \\ x_{i_m, 0} & \cdots & x_{i_m, i_k} \\ \end{bmatrix} \end{split}\]

其中 \(x_{ik} \in \mathbb{R}^{b_m \times b_k}\)。(我們可以類似地分解 \(y\)\(z\)。)

對於特定的輸出區塊 \(z_{ij}\),我們可以將其計算為

\[ z_{ij} = \sum_k x_{ik} y_{kj} \]

因此,每個輸出區塊 \(z_{ij}\) 是幾個較小的區塊矩陣乘法 \(x_{ik} y_{kj}\) 的總和。以下是我們在 NumPy 中實作此演算法的方式

def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  m, k, n = x.shape[0], x.shape[1], y.shape[0]
  assert m <= 256
  assert k <= 256
  assert n <= 256
  return np.matmul(x, y)

def block_matmul(
    x: np.ndarray,
    y: np.ndarray,
    *,
    bm: int = 256,
    bk: int = 256,
    bn: int = 256,
) -> np.ndarray:
  m, k = x.shape
  _, n = y.shape

  z = np.zeros((m, n), dtype=x.dtype)
  for m_i in range(m // bm):
    for n_i in range(n // bn):
      for k_i in range(k // bk):
        m_slice = slice(m_i * bm, (m_i + 1) * bm)
        k_slice = slice(k_i * bk, (k_i + 1) * bk)
        n_slice = slice(n_i * bn, (n_i + 1) * bn)
        x_block = x[m_slice, k_slice]
        y_block = y[k_slice, n_slice]
        z[m_slice, n_slice] += matmul_small(x_block, y_block)
  return z

我們的 block_matmul 函數現在應該可以處理大於 256 的輸入(儘管我們假設輸入維度可以均勻地除以 256)。

m, k, n = 4096, 4096, 4096
x = np.random.uniform(size=(m, k)).astype(np.float32)
y = np.random.uniform(size=(k, n)).astype(np.float32)
np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)

block_matmul 透過觀察到每個大小為 (bm, bn) 的輸出區塊可以透過累積幾個 (bm, bk) x (bk, bn) 大小的矩陣乘法來計算,從而將矩陣乘法分解為許多較小的矩陣乘法。

TPU 和 GPU 就像這樣進行矩陣乘法!它們原生支援類似於 matmul_small 的小矩陣乘法,因此為了在執行更大的矩陣乘法時利用此硬體,我們將應用 block_matmul 分解。

平鋪和管線化#

先前的指南中,我們介紹了在 Pallas 中平鋪計算和管線化的工作原理。為了確保我們的運算單元始終在工作,並且永遠不會因記憶體傳輸而停滯,我們將核心的下一次迭代的記憶體傳輸與目前的一次迭代重疊。

在 Pallas 中,我們透過 BlockSpecgrid 來指定。請注意,我們在區塊矩陣乘法演算法中已經有一個巢狀 for 迴圈。我們可以透過 grid 在 Pallas 中指定它。區塊矩陣乘法中的切片也可以透過 BlockSpec 來指定。

您的第一個矩陣乘法核心#

綜合以上所述,以下是一個區塊矩陣乘法核心的實作,它將記憶體傳輸與計算管線化。我們建立了一個 3 維網格,對應於 NumPy 程式碼中的 3 層巢狀迴圈。請注意,雖然 MXU 只能夠將小區塊相乘,但 Pallas 會自動採用更大的區塊,並自動將它們平鋪到 MXU 上。

網格的最後一個維度對應於矩陣乘法的收縮維度,並且是縮減維度,因此我們需要確保初始化累加器。

def matmul_kernel(x_ref, y_ref, z_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    z_ref[...] = jnp.zeros_like(z_ref)

  z_ref[...] += x_ref[...] @ y_ref[...]

def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
                pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
      out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
      grid=(m // bm, n // bn, k // bk),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.float32)
y = random.normal(k2, (k, n), dtype=jnp.float32)
np.testing.assert_array_equal(x @ y, matmul(x, y))

矩陣乘法效能#

讓我們思考一下如何分析矩陣乘法效能。當我們考慮矩陣乘法效能時,我們通常關心兩件事:浮點運算的總數 (FLOP) 和記憶體頻寬使用量。從關於 TPU 和管線化的指南中,我們看到為了使用 TPU(以及通用 ML 加速器)上的高效運算單元,我們需要將輸入從 HBM 複製到 VMEM,更靠近運算單元。複製到 HBM 和從 HBM 複製需要時間,而有效率的核心希望將大部分時間花在實際計算上,而不是等待這些傳輸。記憶體頻寬衡量此資料傳輸的速率。

快速注意事項:在本指南中,我們將討論浮點運算,但想要區分 FLOP 與 FLOP/s。當我們說「FLOP」時,我們指的是「浮點運算」,如運算次數。當我們說「FLOP/s」時,我們指的是「每秒浮點運算」,如執行浮點運算的速率

一個 (m, k) x (k, n) 矩陣乘法中的 FLOP 數(約略)為 2 * m * k * n。(技術上它是 n * m * (2k - 1),但對於足夠大的 k 來說,我們的近似值已足夠。)

矩陣乘法的最小記憶體頻寬使用量(假設 float32)是輸入的總大小(複製到 VMEM)加上輸出的總大小(複製到 HBM)。因此,最小頻寬使用量為 (m * k + k * n + m * n) * 4 bytes/float32。如果我們多次重新讀取輸入,則記憶體使用量可能會更大,這通常是這種情況。

一個觀察結果是,矩陣乘法的 FLOP 數在其輸入中是三次方的,而最小頻寬使用量在其輸入中是二次方的。直覺上,這表示 FLOP 的成長速度快於頻寬使用量,這表示我們的矩陣乘法越大,相對於複製,我們擁有的計算就越多。

def matmul_flops(m: int, k: int, n: int):
  return 2 * m * k * n

def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
  return (m * k + k * n + m * n) * np.dtype(dtype).itemsize

print(matmul_flops(1024, 1024, 1024))
print(matmul_membw(1024, 1024, 1024, jnp.float32))
2147483648
12582912

現在我們可以計算矩陣乘法的 FLOP 總數和(最小)記憶體頻寬使用量,讓我們看看真正的 TPU 可以處理什麼。

此筆記本在 TPU v5e 晶片上執行,因此我們將使用 v5e 數字(如果您正在執行此筆記本,則您的數字可能會有所不同)。TPU v5e 具有 197 TFLOP/s 的 bf16/f32 計算和 819 GB/s 的記憶體頻寬。透過查看這些數字的比率(稱為算術強度),我們可以獲得在我們變得受 IO 限制之前,此「FLOP/記憶體頻寬使用量」比率可以降低多少的界限(在 TPU v5e 上約為 240 FLOP/byte)。

v5e_flops = 197e12
v5e_membw = 819e9
v5e_op_intensity = v5e_flops / v5e_membw  # ~240.5

粗略地說,這些數字告訴我們,矩陣乘法的 FLOP 應該花費 2 * m * k * n / (197 TFLOP/s) 秒,而複製到/從 VMEM 應該花費 (m*k + k*n + m*n) * 4 bytes / 819GB/s 秒。

def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
  flops = matmul_flops(m, k, n)
  membw = matmul_membw(m, k, n, dtype)
  return flops / membw

這個基本計算粗略地告訴我們,我們將能夠多有效率地使用 MXU。如果我們的矩陣乘法運算強度低於晶片的能力,那麼我們的計算將受記憶體限制,也就是說,我們的運算單元將在等待值傳輸時閒置。如果矩陣乘法強度高於晶片的能力,那麼我們將受計算限制

由於矩陣乘法 FLOP 在其輸入大小中是三次方的,而記憶體頻寬使用量是二次方的,因此我們預期隨著我們變得越來越大,我們將受到計算限制,但是這個交叉點非常重要!假設我們正在進行 (1024, 1024) x (1024, 1024) float32 矩陣乘法。

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte")
170.66666666666666 flops/byte

我們的矩陣乘法 FLOP 強度低於晶片的能力。這不好!我們很可能受到這種矩陣乘法的記憶體限制。但是,如果我們的輸入和輸出更大呢?在某些時候,當我們的矩陣乘法變得足夠大時,我們將從記憶體限制 переходить 到計算限制。例如,如果我們有一個矩陣乘法,其中 m = k = n,當 2m**3 / 12m**2 > 240 或當 m = k = n > 1440 時,我們將 переходить(在 TPU v5e 上)。

bfloat16 矩陣乘法#

為了使矩陣乘法更容易在 TPU 上受到計算限制,我們也可以為輸入和輸出使用更小的 dtype。我們之前的範例使用了 float32 輸入和輸出,但 TPU v5e 也支援 bfloat16 資料類型(一種 16 位元浮點格式,也稱為 bf16)用於矩陣乘法。在 TPU v5e 上,我們將具有相同的 FLOP/s,但會將記憶體頻寬使用量減半。這使得對於較小的矩陣來說,更容易受到計算限制。讓我們看看 1024 x 1024 x 1024 bf16 矩陣乘法的強度是多少

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte")
341.3333333333333 flops/byte

我們現在有一個受計算限制的矩陣乘法!

讓我們將 bf16 支援添加到我們的矩陣乘法核心。

原生 MXU bf16 矩陣乘法常式採用兩個輸入 bf16 矩陣,並在 f32 中累積它。我們將透過將 preferred_element_type=jnp.float32 傳遞到 jnp.matmul 中來觸發此常式。我們還需要一個 f32 中的累加器 Ref。然後,我們將在將輸出寫回 HBM 之前將其向下轉換回 bf16。這樣,我們不會損失任何精度,不會執行任何額外的轉換,並且仍然保留 bf16 記憶體頻寬節省。

請注意,目前配置暫存空間的唯一方法是透過 pltpu.PrefetchScalarGridSpec。現在不用擔心它到底做什麼 – 您現在只需要知道它允許您在 VMEM 中配置暫存空間。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.bfloat16)
y = random.normal(k2, (k, n), dtype=jnp.bfloat16)
np.testing.assert_array_equal(x @ y, matmul(x, y))

管線化核心的效能#

我們上面關於 FLOP 與記憶體使用量的分析適用於粗略的規模,也就是說,當我們查看總矩陣乘法的大小時。但是,請記住,在實務上,我們正在管線化區塊矩陣乘法的執行,這表示我們有一個迴圈,我們在其中使用較小的區塊進行矩陣乘法。

這表示我們實際上關心核心的每個個別實例的 FLOP 與記憶體頻寬使用量,而不是全域 FLOP 與記憶體頻寬使用量。因此,區塊大小 bmbkbn 對於效能極其重要。即使我們擁有世界上最大的矩陣,如果我們選擇非常小的 bmbkbn,我們也會受到記憶體限制,因為每次我們調用核心時,我們將有太少的 FLOP 來隱藏在背景中發生的記憶體傳輸。

因此,直覺應該是:為了受到計算限制,請使區塊盡可能大!有兩個主要約束

  1. VMEM 使用量:我們的區塊越大,我們使用的 VMEM 就越多。使用足夠大的區塊,我們將耗盡 VMEM。

  2. 管線氣泡:我們的區塊相對於矩陣大小越大,我們在管線中的迴圈迭代次數就越少。這將使管線開始和結束時的氣泡大小相對於總管線更大,並且此額外負荷可能很大。

在 Pallas 中獲得良好的矩陣乘法效能歸結為選擇良好的區塊大小,以平衡此最佳化問題。在實務上,我們通常掃描大量候選區塊大小,分析核心,然後選擇最佳的一個。

現在,讓我們進行一些非常簡單的計時實驗。我們將使用 timeit 來測量運行每個核心所需的時間量。請注意,這是核心實際執行時間的上限,因為我們正在使用 timeit 測量 Python 調度和其他額外負荷。我們將以這種方式計算我們獲得的 FLOP/s 量,並計算我們獲得的利用率百分比(與晶片提供的利用率相比),並且我們將使用一些合理的區塊大小來驗證我們的直覺。

import timeit

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    # print(f"Time: {time}")
    return time
  return run

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func):
  x = jnp.ones((m, k), dtype=dtype)
  y = jnp.ones((k, n), dtype=dtype)
  time = benchmark(mm_func)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00029766598949208854
Matmul FLOP/s:  7214407167121.377
FLOP/s utilization: 3.6621%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011771515250438824
Matmul FLOP/s:  11675553278230.387
FLOP/s utilization: 5.9267%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09183577066054567
Matmul FLOP/s:  11972585626140.668
FLOP/s utilization: 6.0775%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012708659982308746
Matmul FLOP/s:  16897797651282.135
FLOP/s utilization: 8.5776%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.00088908776990138
Matmul FLOP/s:  154584235803001.88
FLOP/s utilization: 78.4692%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006099433819763363
Matmul FLOP/s:  180264539343531.62
FLOP/s utilization: 91.5048%

更大的區塊大小有很大幫助!在較大的矩陣乘法中,我們獲得了非常好的利用率 (80-90%),但最小的矩陣乘法似乎很難獲得良好的效能。

讓我們將其與 XLA 的矩陣乘法進行比較。我們不期望 Pallas 比 XLA 做得更好,因為 XLA 非常擅長產生矩陣乘法,但希望我們接近。透過更仔細的區塊大小調整(留作未來工作),我們也可以達到 XLA 效能。

print("================ XLA matmul ===================")
mm = jnp.matmul
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================ XLA matmul ===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00011943008983507753
Matmul FLOP/s:  17981093801113.996
FLOP/s utilization: 9.1275%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008272899803705514
Matmul FLOP/s:  166131533963991.34
FLOP/s utilization: 84.3307%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006047147869830951
Matmul FLOP/s:  181823175395037.44
FLOP/s utilization: 92.2960%

Pallas 透過一些非常基本的調整,非常接近 XLA 的效能數字!透過嘗試更多區塊大小,我們應該期望完全彌合差距。

範本化矩陣乘法#

現在我們有了一個基本的矩陣乘法核心,我們現在可以嘗試將運算融合到其中。

融合右手側轉置#

要做的常見第一件事是融合轉置。我們說的融合是什麼意思?假設我們想要計算 x @ y.T 而不是 x @ y。天真地,我們可以先計算 y.T,然後將其傳遞到我們有效率的矩陣乘法核心中。但是,y.T 運算本身並非免費的 – 它涉及複製 O(n^2) 資料。理想情況下,我們可以在一個核心中執行矩陣乘法計算轉置,也就是說,將其與矩陣乘法「融合」。

加速器通常支援融合 RHS 轉置的原生矩陣乘法常式。例如 TPU v5e,MXU 允許我們對小陣列執行 x @ y.T。我們可以使用 jax.lax.dot_general 調用此常式,這將比單獨執行轉置然後執行矩陣乘法更有效率。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  # dot_general expects a data structure (contraction_dims, batch_dims),
  # where contraction_dims are the set of dimensions for LHS and RHS that will
  # be contracted (reduced) in the matmul; batch_dims, on the other hand, are
  # looped over. The remaining dimensions will be the input and output dimension
  # of the matmul.
  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            y_block_spec,
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)

我們在 matmul 函數內執行轉置 (y = y.swapaxes(0, 1))。這是因為在 JIT 編譯的 JAX 計算中,維度順序純粹是邏輯的,而不是物理的,因此重新排列維度並不意味著物理佈局差異。但是,當我們將陣列傳遞到 pallas_call 中時,我們確實強制執行了主維度到次維度的順序約束。透過在 matmul 函數內轉置 y,我們要求 y 採用轉置佈局 (n, k) 而不是通常的 (k, n)。但是,使用者仍然會以(邏輯)(k, n) 維度傳遞陣列。

注意:為了基準測試轉置,我們實際上希望當我們將 y 傳遞到核心時,y 處於物理轉置佈局中,因此我們不會測量重新佈局時間。在封裝函數中,我們將(邏輯上)將其轉置回 (k, n),然後再將其傳遞到 matmul 中,因為 matmul 期望邏輯 (k, n) 維度順序。

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = mm_func
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.0003029372810851783
Matmul FLOP/s:  7088872126624.065
FLOP/s utilization: 3.5984%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.012017967159627005
Matmul FLOP/s:  11436123235026.848
FLOP/s utilization: 5.8051%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09500920018996112
Matmul FLOP/s:  11572685861765.383
FLOP/s utilization: 5.8745%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012131539988331496
Matmul FLOP/s:  17701657415839.363
FLOP/s utilization: 8.9856%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008790623804088682
Matmul FLOP/s:  156347213275211.03
FLOP/s utilization: 79.3641%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006107717020204291
Matmul FLOP/s:  180020067095253.78
FLOP/s utilization: 91.3807%

看看我們如何在額外轉置的情況下獲得相同的利用率!

融合啟動函數#

融合啟動函數也很常見。這確保我們不會在有效率、受計算限制的矩陣乘法核心之後,接著執行速度慢、受記憶體限制的啟動核心。

def matmul_kernel(
    x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation
):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...],
      y_ref[...],
      dims,
      preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'activation'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
    activation: Callable[[jax.Array], jax.Array] = lambda x: x,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(
          matmul_kernel,
          nsteps=k // bk,
          transpose_rhs=transpose_rhs,
          activation=activation,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=[
              pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
              y_block_spec,
          ],
          out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
          scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False,
                   activation = lambda x: x):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True, activation=activation)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = functools.partial(mm_func, activation=activation)
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()


activation = jax.nn.relu
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00030103540048003196
Matmul FLOP/s:  7133658182976.541
FLOP/s utilization: 3.6211%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011807117109419778
Matmul FLOP/s:  11640348122095.826
FLOP/s utilization: 5.9088%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09181861146935262
Matmul FLOP/s:  11974823079773.941
FLOP/s utilization: 6.0786%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012622540001757442
Matmul FLOP/s:  17013086492108.6
FLOP/s utilization: 8.6361%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.000896632740041241
Matmul FLOP/s:  153283442968721.44
FLOP/s utilization: 77.8089%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006130605939542875
Matmul FLOP/s:  179347953304919.88
FLOP/s utilization: 91.0396%

額外融合的啟動函數幾乎完全不影響我們的利用率!

結論#

在本指南中,我們介紹了如何使用 Pallas 在 TPU 上撰寫有效率的矩陣乘法。我們討論了區塊矩陣乘法和管線化、如何分析 TPU 矩陣乘法的效能,以及如何撰寫有效率的 bf16 矩陣乘法。最後,我們以範本化矩陣乘法來支援融合轉置和融合啟動函數。

留給讀者的練習

  • 新增對輸入融合的支援。有時我們希望將運算融合到矩陣乘法的輸入中。嘗試更進一步範本化矩陣乘法以支援此功能。

  • 新增對 int8 矩陣乘法的支援。TPU v5 支援原生 int8 矩陣乘法,其 FLOP 是 bf16 的兩倍。嘗試新增對它的支援,看看可能的利用率是多少。

  • 新增對 matmul 函數的反向傳遞支援。您可以使用 jax.custom_vjp 來完成此操作。