jax.experimental.pallas.mosaic_gpu.TilingTransform#

class jax.experimental.pallas.mosaic_gpu.TilingTransform(tiling)[原始碼]#

表示記憶體參考的 tiling 轉換。

在形狀為 (M, N) 的陣列上 tiling (X, Y) 將產生 (M // X, N // Y, X, Y) 的轉換後形狀。例如,以 (64, 32) 的 tiling 進行 tiling 的 (256, 256) 區塊將被 tiling 為 (4, 8, 64, 32)。

參數:

tiling (tuple[int, ...])

__init__(tiling)#
參數:

tiling (tuple[int, ...])

回傳型別:

None

方法

__init__(tiling)

batch(leading_rank)

回傳一個轉換,它接受具有額外 leading_rank 維度的 ref。

to_gpu_transform()

undo(ref)

屬性

tiling