純量預取和區塊稀疏計算#

在本教學中,我們將介紹 Pallas 中區塊稀疏計算的基礎知識。稀疏計算是撰寫自訂 Pallas 核心而非僅僅使用 JAX/XLA 的主要原因,因為通常很難在 XLA 中表達執行動態計算量的程式,因為陣列形狀是靜態的。在本教學中,我們將學習如何使用 Pallas 的純量預取功能,以便撰寫可以動態跳過計算和記憶體區塊的區塊稀疏核心。

import functools
import timeit
import numpy as np
import jax
from jax import numpy as jnp
from jax import lax
from jax.experimental import checkify
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print("Running on", jax.devices()[0].device_kind)
Running on TPU v5 lite

使用純量預取的動態區塊索引#

我們將利用 Pallas 的「純量預取」功能來讓我們能夠撰寫稀疏核心。純量預取允許您將少量資料傳遞到 SMEM(「純量記憶體」)中,這些資料會在管線開始之前載入(「預取」)。由於此資料在管線之前載入,因此可用於每個 BlockSpec 的 index_map 中,讓您可以執行資料相關的索引計算。本教學的主要目標是講解利用此功能的常見程式設計模式。

若要使用純量預取,請使用 pltpu.PrefetchScalarGridSpec 來取代標準的 pl.GridSpec

class PrefetchScalarGridSpec:
  def __init__(self,
    num_scalar_prefetch: int,
    grid: tuple[int, ...],
    in_specs: PyTree[BlockSpec],
    out_specs: PyTree[BlockSpec],
    scratch_shapes: tuple[MemorySpace, ...]):
      ...

num_scalar_prefetch 參數表示純量預取值的數量。當此參數設定為非零值時,它會變更核心和索引映射的呼叫簽章,以預期額外的預取值。傳遞到 index_map 和核心的預取 Ref 都配置在 SMEM 中,並且不會分割成區塊,因為它們沒有定義 BlockSpec。此外,index_map 和核心的引數順序始終是固定的,如下所述

  • 現在每個 BlockSpecindex_map 預期預取 Ref 會在網格索引之後出現

def index_map(*grid_indices, *prefetch_refs):
    ...
  • 使用者定義的核心預期預取 Ref 會在輸入 Ref 之前出現。此外,暫存 refs 會在輸出 Ref 之後出現。

def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
    ...
  • 當使用 pallas_call 呼叫新核心時,pallas_call 傳回的函式也預期純量預取引數會在輸入之前出現,例如

kernel = pl.pallas_call(...)
result = kernel(*prefetch_args, *input_args)

範例:使用純量預取的區塊動態切片#

讓我們從一個基本範例開始,示範如何使用純量預取功能。我們將實作一個區塊對齊的動態切片核心,它只會根據使用者指定的索引,從較大的陣列中擷取一個區塊

  1. 在核心之外,我們計算要擷取的區塊索引為:block_idx = (start[0] // size[0], start[1] // size[1])

  2. 我們將 block_idx 作為純量預取引數傳遞到 pallas_call 中。

  3. 在我們的索引映射中,我們使用區塊索引來選取對應的區塊,方法是傳回 (block_idx[0], block_idx[1])

當然,此核心的限制在於我們的切片大小必須適合核心區塊內(受 VMEM 大小限制),而且我們只能從大小對齊的索引開始。更進階的核心會將核心區塊大小與切片大小解耦,並允許非對齊的起始索引。

def dynamic_slice_kernel(indices, x_ref, o_ref):
  del indices
  o_ref[...] = x_ref[...]

@checkify.checkify
@functools.partial(jax.jit, static_argnums=(2,))
def block_dynamic_slice(x, starts, sizes):
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,
      grid=(1, 1),
      in_specs=[pl.BlockSpec(
          sizes,
          lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
      out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
  )

  kernel = pl.pallas_call(
    dynamic_slice_kernel,
    grid_spec=grid_spec,
    out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
  )
  # Checkify inserts a runtime assert that starts are divisible by block size.
  checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
  checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
  block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
  return kernel(block_idx, x)

shape = (512, 512)
x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
err.throw()
ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
diff = jnp.max(jnp.abs(result - ref))
print("Error |result - lax.dynamic_slice| =", diff)
Error |result - lax.dynamic_slice| = 0

稀疏核心:表示稀疏資料#

在我們深入實作稀疏核心之前,我們先回顧一下稀疏矩陣是如何表示的。雖然有幾種流行的格式用於儲存稀疏矩陣,但我們將遵循座標列表格式 (COO) 的區塊變體,其中我們將矩陣儲存為 (block_index, block_data) 配對的列表。所有未明確儲存在列表中的區塊都假定為零,這表示如果矩陣中有許多零區塊,我們可以節省大量記憶體。

下圖示範了我們如何將 4x4 稠密矩陣(左)轉換為區塊 COO 格式(右),區塊大小為 2x2。請注意,在稀疏格式中,我們可以避免明確儲存由全零元素組成的右上區塊。

block_coo

我們將使用以下輔助函式來取樣區塊稀疏矩陣。它會傳回用於檢查結果的稠密矩陣,以及每個軸的區塊資料和索引列表。

def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
  """Returns a sampled matrix and its block-sparse representation.

  Args:
    key: RNG Key.
    M: Major array dimension.
    N: Minor array dimension.
    blk_M: Block size along M dimension.
    blk_N: Block size along N dimension.
    p: Probability that a block will be non-zero.
    dtype: dtype of the sampled matrix.

  Returns:
    dense_mat: A (M, N) dense sampled array.
    block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
      the non-zero blocks of the matrix.
    indices_i: A (num_blocks,) array of block indices for the first axis.
    indices_j: A (num_blocks,) array of block indices for the second axis.
  """
  mask_key, blocks_key = jax.random.split(key)
  num_blocks = (M // blk_M, N // blk_N)
  # We first sample a block mask, denoting which blocks are nonzero.
  block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
  num_blocks = jnp.sum(block_mask)
  indices = jnp.where(block_mask)
  # For each non-zero block, we sample a block of random values.
  block_data = jax.random.uniform(blocks_key,
                                  shape=(num_blocks, blk_M, blk_N),
                                  dtype=dtype)
  # For checking purposes, create the dense version of the sparse matrix.
  dense_mat = jnp.zeros((M, N), dtype=dtype)
  for blk in range(num_blocks):
    idx_i = indices[0][blk]
    idx_j = indices[1][blk]
    slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
    slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
    dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
  return dense_mat, block_data, indices[0], indices[1]

範例:稀疏 @ 稠密矩陣乘法#

在我們的第一個範例中,我們將稀疏 LHS 矩陣與稠密 RHS 矩陣相乘,以產生稠密輸出。

我們將使用 2 個迴圈來架構我們的核心網格 - 外迴圈遍歷 RHS/輸出的列,內迴圈遍歷 LHS 的稀疏區塊。在每個內迴圈迭代期間,我們從 LHS 載入一個區塊,並使用收縮維度 (K) 的區塊索引在 RHS 中查找對應的區塊。我們將兩個區塊相乘在一起,並累加到正確的輸出區塊中。一個外迴圈迭代將計算整個列的結果,如下圖所示

sparse_matmul

重要的是,我們在將區塊索引傳遞到核心之前,按列分組區塊索引(例如 [0, 0, 1, 2, 3, 3])。首先,在我們的核心中,我們需要知道何時最初將輸出 ref 中的累加器歸零,如果列索引分組,這很容易做到。其次,Pallas 的管線邏輯不允許我們在非連續迭代中重新訪問輸出 Ref 中的區塊,因此我們需要將所有累加都連續迭代到輸出區塊中。這是因為管線發射器會意識到我們在連續迭代中載入相同的輸出區塊,並將該區塊保留在 VMEM 中。當我們變更輸出區塊時,Pallas 最終會將輸出儲存到 HBM 中,並假設我們永遠不再觸碰它。即使核心在邏輯上是正確的,未能連續存取輸出區塊也會導致值不正確。

M = N = K = 16384
blk_M = blk_N = blk_K = 512


def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
               x_ref, y_ref, _, o_ref, # Kernel inputs.
               accum_scratch,
               ):
  """A DSD (Dense = Sparse @ Dense) matmul kernel."""
  del idxs_k_ref
  blk_idx = pl.program_id(0)
  is_start = blk_idx == 0
  changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
  @pl.when(is_start | changed_blocks)
  def _():
    accum_scratch[...] = jnp.zeros_like(accum_scratch)
  accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)

  next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
  is_end = blk_idx == (num_blocks - 1)
  @pl.when(is_end | next_block_change)
  def _():
    o_ref[...] = accum_scratch[...].astype(o_ref.dtype)


def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
  del j, blk_idxs_i, blk_idxs_k
  return (blk_idx, 0, 0)
def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
  del blk_idxs_i
  return (blk_idxs_k[blk_idx], j)
def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
  del blk_idxs_k
  return (blk_idxs_i[blk_idx], j)

(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
    jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
num_blocks = X_blocks.shape[0]
Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=2,
    # Note that while num_blocks is static here, Pallas does support
    # dynamic grid sizes.
    grid=(num_blocks, N // blk_N),
    in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              # Placeholder for a zeros-array used by input_output_aliases.
              pl.BlockSpec((blk_M, blk_N), o_map),
              ],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
  dsd_kernel,
  grid_spec=grid_spec,
  out_shape=out_shape,
  # We use input-output aliases to zero-out o_ref for blocks that we never
  # visit. By passing in an array of zeros we avoid having o_ref start with
  # uninitialized values.
  input_output_aliases={4: 0},  # Map zeros to o_ref.
)
args = (indices_i, indices_k, X_blocks, Y, zeros)
result = kernel(*args)

ref = X_dense @ Y
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 0

我們可以進行快速基準測試,以比較我們的稀疏核心與 JAX 中稠密 matmul 的效能。在 TPU v5e 晶片上,此核心實現了約 ~6 倍的速度提升,而稀疏因子理論上為 10 倍。

這裡有一些主要的效能提示,主要集中在減少 HBM/VMEM 之間的通訊開銷

  • 使用 dtype=jnp.bfloat16 對於效能至關重要,因為它可以將記憶體頻寬減少一半。

  • 使用更大的區塊大小也有幫助,因為矩陣乘法是 \(O(N^3)\) 計算和 \(O(N^2)\) 記憶體操作。隨著 \(N\) 變大,核心會變成運算受限。但是,在實務中對此的反駁是,較小的區塊大小也使資料更稀疏,因此這是一個應仔細選擇的參數。

# Benchmark Sparse Pallas kernel vs reference JAX implementation

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    return time
  return run


n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda x, y: x @ y)
time = benchmark(ref_impl, n_trials)(X_dense, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 8.136 ms (avg over 100 trials)
Reference: 46.953 ms (avg over 100 trials)

稠密資料上的稀疏存取模式#

在我們之前的範例中,我們考慮了資料本身是稀疏的情況。這在核心結構中表現為核心網格中的一個維度是動態的,並且迴圈遍歷非零區塊的數量 (num_blocks)。

當底層資料是稠密的,但我們希望對其執行稀疏計算時,會出現第二種有用的程式設計模式。在這種情況下,我們的核心網格將是稠密的,但我們希望根據區塊稀疏遮罩跳過網格中的某些區塊。當在許多機器學習應用程式中使用遮罩時,通常會出現這種程式設計模式,例如自我注意中的因果或局部遮罩。在這些情況下,我們可以完全跳過遮罩歸零的區塊中的計算。jax/experimental/pallas/ops/tpu 或 PyTorch 的 FlexAttention 中可以找到此程式設計模式的範例。

處理稠密資料上的稀疏存取模式的主要效能考量是與管線化的互動。在任何給定的核心迭代中,Pallas 管線發射器都會嘗試透過針對網格的下一次迭代,呼叫每個 BlockSpecindex_map 來預取下一個資料區塊。但是,如果我們的計算是稀疏的,我們可能會跳過網格中下一個區塊的計算,因此我們需要某種方法來告訴管線開始提取我們未跳過的下一個區塊。為了做到這一點,我們需要建構預取映射,其中包含每個核心輸入的下一個非跳過區塊的索引。下圖說明了如何為以類似 COO 格式儲存的區塊稀疏遮罩建構預取映射。

prefetch_map

左圖:稀疏存取模式,其中藍色表示我們需要計算的非零遮罩區塊。右圖:預取映射,其中陣列的每個元素都包含下一個非零區塊資料的索引。

建構預取映射後,我們可以將該映射作為純量預取引數傳遞,並在 BlockSpec 的 index_map 函式中查詢它。

def mask_index_map(prefetch_map, i, j, ...):
  next_nonzero_block = prefetch_map[i, j]
  return (next_nonzero_block, 0, 0)

我們可以為核心的其他輸入建構類似的索引映射。對於稠密輸入,您很可能需要建構預取映射,這些映射指向網格中下一個非零區塊索引。我們的下一個範例將提供使用這些預取映射的範例。

範例:具有區塊稀疏輸出遮罩的稠密 @ 稠密矩陣乘法#

在我們的下一個範例中,我們將介紹稠密矩陣乘法,並使用稀疏輸出遮罩進行融合,並使用預取映射來改善管線效能。我們將使用遮罩來選擇性地跳過計算歸零的輸出區塊,從而節省計算成本。

由於我們將使用稀疏遮罩,因此我們先實作一個函式,將以稠密格式儲存的 N x M 遮罩轉換為區塊稀疏格式。我們還需要計算預取映射,以協助管線發射器知道接下來要提取哪個區塊。總之,我們的 sparsify_mask 函式會計算

  • 形狀為 (num_N_blocks, num_M_blocks)block_mask,指示區塊是否全為零(值 0)或包含非零元素(值 1)。如果 block_mask 的值為 0,我們可以跳過在核心中計算區塊。

  • 形狀為 (num_N_blocks, num_M_blocks)prefetch_mask 陣列,由 mask_data 中下一個非零區塊的索引組成。

  • 形狀為 (num_N_blocks, num_M_blocks)prefetch_i 陣列,由遮罩的下一個非遮罩 i 索引組成。

  • 形狀為 (num_N_blocks, num_M_blocks)prefetch_j 陣列,由遮罩的下一個非遮罩 j 索引組成。

  • 形狀為 (num_blocks, blk_N, blk_M)mask_data 陣列,包含遮罩的非零區塊的資料。

def sparsify_mask(mask: jax.Array,
                  block_shape: tuple[int, int]):
  """Preprocesses a mask into a sparse reprentation.

  Args:
    mask: A boolean array of shape [M, N]
    block_shape: The size of a single block.

  Returns:
    block_mask: A block_shape array of booleans indicating whether a block
      is all-zeros (0) or contains non-zero elements (1).
    prefetch_mask: A block_shape array of integers indicating the index of the
      next non-zero block.
    mask_data: A (num_blocks, block_shape) array containing
      the data for non-zero blocks of the mask.
  """
  M, N = mask.shape
  bm, bn = block_shape

  block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
  mask_types_finder = []
  mask_data = []
  mask_type_idxs = []

  next_mask_type_idx = 0
  prefetch_mask = jnp.zeros_like(block_mask)
  next_i = (M // bm) - 1
  next_j = (N // bn) - 1
  prefetch_i = jnp.zeros_like(block_mask)
  prefetch_j = jnp.zeros_like(block_mask)
  for i in range(M // bm, -1, -1):
    for j in range(N // bn, -1, -1):
      mask_block = mask[i * bm :(i + 1) * bm,
                        j * bn :(j + 1) * bn]
      is_nonzero = jnp.any(mask_block)
      if is_nonzero:
        try:
          type_index = mask_types_finder.index(str(mask_block))
        except ValueError:
          type_index = len(mask_types_finder)
          mask_types_finder.append(str(mask_block))
          mask_data.append(mask_block)
        next_mask_type_idx = type_index
        next_i = i
        next_j = j
      else:
        type_index = -1
      mask_type_idxs.append(type_index)
      block_mask = block_mask.at[i, j].set(is_nonzero)
      prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
      prefetch_i = prefetch_i.at[i, j].set(next_i)
      prefetch_j = prefetch_j.at[i, j].set(next_j)
  return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)

在核心的結構方面,我們使用與我們在先前教學中介紹的標準矩陣乘法核心相同的網格模式,並在 NMK 維度上使用 3 個迴圈。在核心本身內部,我們先檢查 block_mask,以查看目前輸出區塊的遮罩是否全為零。如果遮罩全為零,我們可以跳過計算並移至下一個區塊;否則,我們需要計算矩陣乘法,然後遮罩結果。

M = N = K = 16384
blk_M = blk_N = 512
blk_K = 1024

def sparse_mask_matmul(
    block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.
    x_ref, y_ref, mask_ref, o_ref,  # Kernel inputs.
    accum_scratch
    ):
  del prefetch_mask, prefetch_i, prefetch_j
  i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
  should_compute = block_mask_ref[i, j] != 0
  @pl.when(k == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

  # We only compute the output for blocks with non-zero masks.
  # Otherwise we skip the computation entirely.
  @pl.when(should_compute)
  def _():
    result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
    accum_scratch[...] += result
    @pl.when(k == pl.num_programs(2) - 1)
    def _():
      o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)

X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
mask = jnp.ones((M, N), dtype=jnp.int32)
mask = jnp.tril(mask)
block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))

def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_j
  # Zero-out the k index if the mask is zero, to avoid constantly fetching
  # new blocks in the inner loop for blocks we are skipping.
  k_fetch = (block_mask[i, j] != 0) * k
  return (prefetch_i[i, j], k_fetch)

def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_i
  k_fetch = (block_mask[i, j] != 0) * k
  return (k_fetch, prefetch_j[i, j])

def mask_map(i, j, k, block_mask, prefetch_mask, *_):
  del k, block_mask
  return (prefetch_mask[i, j], 0, 0)

def o_map(i, j, k, *_):
  del k
  return (i, j)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=4,
    grid=(M // blk_M, N // blk_N, K // blk_K),
    in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              pl.BlockSpec((1, blk_M, blk_N), mask_map)],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
  sparse_mask_matmul,
  grid_spec=grid_spec,
  out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
)
args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
result = kernel(*args)

ref = mask * (X @ Y)
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 1.0252e-05

現在讓我們比較效能與原始稠密實作。在 TPU v5e 上,我們使用稀疏核心實現了約 ~1.8 倍的速度提升,而理論上的最佳情況是 2 倍,因為我們使用了下三角遮罩並且僅訪問了一半可能的輸出。

我們通常預期效能會更接近理論峰值,因為我們的輸入變得更大,因為我們沒有完全達到理論效能的幾個主要原因是

  • 由於沿對角線的區塊混合了 0 和 1,因此我們跳過的計算略少於一半,對於混合區塊,我們需要計算整個區塊。隨著輸入變大,我們的混合區塊開銷相對於整體計算變得更小。

  • 隨著輸入變大,管線氣泡也佔整體執行時間的百分比較少。

n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
time = benchmark(ref_impl, n_trials)(mask, X, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 28.648 ms (avg over 100 trials)
Reference: 49.988 ms (avg over 100 trials)