jax.experimental.pallas.mosaic_gpu.TilingTransform#
- class jax.experimental.pallas.mosaic_gpu.TilingTransform(tiling)[原始碼]#
表示記憶體參考的 tiling 轉換。
在形狀為 (M, N) 的陣列上 tiling (X, Y) 將產生 (M // X, N // Y, X, Y) 的轉換後形狀。例如,以 (64, 32) 的 tiling 進行 tiling 的 (256, 256) 區塊將被 tiling 為 (4, 8, 64, 32)。
方法
__init__
(tiling)batch
(leading_rank)回傳一個轉換,它接受具有額外 leading_rank 維度的 ref。
to_gpu_transform
()undo
(ref)屬性
tiling