Grids 和 BlockSpecs#

grid,又名迴圈中的核心#

當使用 jax.experimental.pallas.pallas_call() 時,核心函式會在不同的輸入上執行多次,這是透過 pallas_callgrid 引數指定的。概念上

pl.pallas_call(some_kernel, grid=(n,))(...)

對應到

for i in range(n):
  some_kernel(...)

Grids 可以推廣為多維度,對應於巢狀迴圈。例如:

pl.pallas_call(some_kernel, grid=(n, m))(...)

等同於

for i in range(n):
  for j in range(m):
    some_kernel(...)

這可以推廣到任何整數元組(長度為 d 的 grid 將對應於 d 個巢狀迴圈)。核心會執行 prod(grid) 次。預設 grid 值 () 會導致一次核心調用。這些調用中的每一次都稱為「程式」。若要存取核心目前正在執行的程式(即 grid 的哪個元素),我們使用 jax.experimental.pallas.program_id()。例如,對於調用 (1, 2)program_id(axis=0) 會傳回 1,而 program_id(axis=1) 會傳回 2。您也可以使用 jax.experimental.pallas.num_programs() 來取得給定軸的 grid 大小。

請參閱 透過範例瞭解 Grids,以取得使用此 API 的簡單核心。

BlockSpec,又名如何將輸入分塊#

grid 引數結合使用,我們需要為 Pallas 提供關於如何為每次調用切分輸入的資訊。具體來說,我們需要提供 迴圈迭代 要操作的輸入和輸出的哪個區塊 之間的對應關係。這是透過 jax.experimental.pallas.BlockSpec 物件提供的。

在我們深入探討 BlockSpec 的細節之前,您可能想要重新瀏覽 Pallas 快速入門中的 透過範例瞭解 Block specs

BlockSpec 會透過 in_specsout_specs 提供給 pallas_call,每個輸入和輸出分別有一個。

首先,我們討論當 indexing_mode == pl.Blocked()BlockSpec 的語意。

非正式地說,BlockSpecindex_map 將調用索引(與 grid 元組的長度一樣多)作為引數,並傳回區塊索引(整體陣列的每個軸一個區塊索引)。然後,每個區塊索引乘以 block_shape 中的對應軸大小,以取得對應陣列軸上的實際元素索引。

注意

並非所有區塊形狀都受到支援。

  • 在 TPU 上,僅支援秩至少為 1 的區塊。此外,您的區塊形狀的最後兩個維度必須等於整體陣列的各自維度,或分別可被 8 和 128 整除。對於秩為 1 的區塊,區塊維度必須等於陣列維度,或可被 128 * (32 / bitwidth(dtype)) 整除。

  • 在 GPU 上,區塊本身的大小不受限制,但每個操作都必須在大小為 2 的冪次的陣列上運作。

如果區塊形狀無法均勻分割整體形狀,則每個軸上的最後一次迭代仍將收到對 block_shape 區塊的參考,但超出邊界的元素會在輸入時填補,並在輸出時捨棄。填補的值未指定,您應假設它們是垃圾。在 interpret=True 模式中,我們會以 NaN 填補浮點值,讓使用者有機會發現存取超出邊界元素的狀況,但不應依賴此行為。請注意,每個區塊中至少有一個元素必須在邊界內。

更精確地說,形狀為 x_shape 的輸入 x 的每個軸的切片計算方式如下面的函式 slice_for_invocation 所示

>>> import jax
>>> from jax.experimental import pallas as pl
>>> def slices_for_invocation(x_shape: tuple[int, ...],
...                           x_spec: pl.BlockSpec,
...                           grid: tuple[int, ...],
...                           invocation_indices: tuple[int, ...]) -> tuple[slice, ...]:
...   assert len(invocation_indices) == len(grid)
...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
...   block_indices = x_spec.index_map(*invocation_indices)
...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
...   elem_indices = []
...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
...     start_idx = block_idx * block_size
...     # At least one element of the block must be within bounds
...     assert start_idx < x_size
...     elem_indices.append(slice(start_idx, start_idx + block_size))
...   return elem_indices

例如

>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
...                       grid = (10, 5),
...                       invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]

>>> # Same shape of the array and blocks, but we iterate over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)),
...                       grid = (10, 5, 4),
...                       invocation_indices = (2, 4, 0))
[slice(20, 30, None), slice(80, 100, None)]

>>> # An example when the block is partially out-of-bounds in the 2nd axis.
>>> slices_for_invocation(x_shape=(100, 90),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
...                       grid = (10, 5),
...                       invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]

下面定義的函式 show_program_ids 使用 Pallas 來顯示調用索引。iota_2D_kernel 將以十進位數字填滿每個輸出區塊,其中第一位數字代表第一個軸上的調用索引,第二位數字代表第二個軸上的調用索引

>>> def show_program_ids(x_shape, block_shape, grid,
...                      index_map=lambda i, j: (i, j),
...                      indexing_mode=pl.Blocked()):
...   def program_ids_kernel(o_ref):  # Fill the output block with 10*program_id(1) + program_id(0)
...     axes = 0
...     for axis in range(len(grid)):
...       axes += pl.program_id(axis) * 10**(len(grid) - 1 - axis)
...     o_ref[...] = jnp.full(o_ref.shape, axes)
...   res = pl.pallas_call(program_ids_kernel,
...                        out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
...                        grid=grid,
...                        in_specs=[],
...                        out_specs=pl.BlockSpec(block_shape, index_map, indexing_mode=indexing_mode),
...                        interpret=True)()
...   print(res)

例如

>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2),
...                  index_map=lambda i, j: (i, j))
[[ 0  0  0  1  1  1]
 [ 0  0  0  1  1  1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]

>>> # An example with out-of-bounds accesses
>>> show_program_ids(x_shape=(7, 5), block_shape=(2, 3), grid=(4, 2),
...                  index_map=lambda i, j: (i, j))
[[ 0  0  0  1  1]
 [ 0  0  0  1  1]
 [10 10 10 11 11]
 [10 10 10 11 11]
 [20 20 20 21 21]
 [20 20 20 21 21]
 [30 30 30 31 31]]

>>> # It is allowed for the shape to be smaller than block_shape
>>> show_program_ids(x_shape=(1, 2), block_shape=(2, 3), grid=(1, 1),
...                  index_map=lambda i, j: (i, j))
[[0 0]]

當多個調用寫入輸出陣列的相同元素時,結果取決於平台。

在下面的範例中,我們有一個 3D grid,其中最後一個 grid 維度未在區塊選擇中使用 (index_map=lambda i, j, k: (i, j))。因此,我們對相同的輸出區塊迭代 10 次。下面顯示的輸出是在 CPU 上使用 interpret=True 模式產生的,目前該模式循序執行調用。在 TPU 上,程式以平行和循序的組合方式執行,此函式產生顯示的輸出。請參閱 值得注意的屬性和限制

>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
...                  index_map=lambda i, j, k: (i, j))
[[  9   9   9  19  19  19]
 [  9   9   9  19  19  19]
 [109 109 109 119 119 119]
 [109 109 109 119 119 119]
 [209 209 209 219 219 219]
 [209 209 209 219 219 219]
 [309 309 309 319 319 319]
 [309 309 309 319 319 319]]

block_shape 中以維度值出現的 None 值行為如同值 1,但對應的區塊軸會被擠壓。在下面的範例中,觀察到當區塊形狀指定為 (None, 2) 時,o_ref 的形狀為 (2,)(前導維度被擠壓)。

>>> def kernel(o_ref):
...   assert o_ref.shape == (2,)
...   o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0))
>>> pl.pallas_call(kernel,
...                jax.ShapeDtypeStruct((3, 4), dtype=np.int32),
...                out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)),
...                grid=(3, 2), interpret=True)()
Array([[ 0,  0, 10, 10],
       [ 1,  1, 11, 11],
       [ 2,  2, 12, 12]], dtype=int32)

當我們建構 BlockSpec 時,我們可以對 block_shape 參數使用值 None,在這種情況下,整體陣列的形狀會被用作 block_shape。如果我們對 index_map 參數使用值 None,則會使用傳回零元組的預設索引對應函式:index_map=lambda *invocation_indices: (0,) * len(block_shape)

>>> show_program_ids(x_shape=(4, 4), block_shape=None, grid=(2, 3),
...                  index_map=None)
[[12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]]

>>> show_program_ids(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3),
...                  index_map=None)
[[12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]
 [12 12 12 12]]

「未分塊」的索引模式#

上面記錄的行為適用於 indexing_mode=pl.Blocked()。當使用 pl.Unblocked 索引模式時,索引對應函式傳回的值會直接用作陣列索引,而無需先按區塊大小縮放它們。當使用未分塊模式時,您可以為每個維度指定低-高填補元組作為陣列的虛擬填補:行為就好像整體陣列在輸入時被填補一樣。未分塊模式中的填補值不提供任何保證,這與區塊形狀無法分割整體陣列形狀時的分塊索引模式的填補值類似。

未分塊模式目前僅在 TPU 上受支援。

>>> # unblocked without padding
>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2),
...                  index_map=lambda i, j: (2*i, 3*j),
...                  indexing_mode=pl.Unblocked())
    [[ 0  0  0  1  1  1]
     [ 0  0  0  1  1  1]
     [10 10 10 11 11 11]
     [10 10 10 11 11 11]
     [20 20 20 21 21 21]
     [20 20 20 21 21 21]
     [30 30 30 31 31 31]
     [30 30 30 31 31 31]]

>>> # unblocked, first pad the array with 1 row and 2 columns.
>>> show_program_ids(x_shape=(7, 7), block_shape=(2, 3), grid=(4, 3),
...                  index_map=lambda i, j: (2*i, 3*j),
...                  indexing_mode=pl.Unblocked(((1, 0), (2, 0))))
    [[ 0  1  1  1  2  2  2]
     [10 11 11 11 12 12 12]
     [10 11 11 11 12 12 12]
     [20 21 21 21 22 22 22]
     [20 21 21 21 22 22 22]
     [30 31 31 31 32 32 32]
     [30 31 31 31 32 32 32]]