jax.experimental.pallas.mosaic_gpu 模組#

Pallas 針對 H100 的實驗性 GPU 後端。

這些 API 極度不穩定,可能每週變更。使用風險自負。

類別#

Barrier(num_arrivals[, num_barriers])

GPUBlockSpec([block_shape, index_map, ...])

GPUCompilerParams(*[, approx_math, ...])

Mosaic GPU 編譯器參數。

GPUMemorySpace(value)

列舉。

Layout(value)

列舉。

SwizzleTransform(swizzle)

TilingTransform(tiling)

表示記憶體參考的平鋪轉換。

TransposeTransform(permutation)

轉置平鋪的 memref。

WGMMAAccumulatorRef(shape, dtype, _init)

函式#

barrier_arrive(barrier)

到達給定的屏障。

barrier_wait(barrier)

等待給定的屏障。

commit_smem()

提交所有寫入至 SMEM,使其對載入、TMA 和 WGMMA 可見。

copy_gmem_to_smem(src, dst, barrier)

非同步地將 GMEM 參考複製到 SMEM 參考。

copy_smem_to_gmem(src, dst[, predicate])

非同步地將 SMEM 參考複製到 GMEM 參考。

emit_pipeline(body, *, grid[, in_specs, ...])

建立函式以在 Pallas 核心內發出手動管線。

layout_cast(x, new_layout)

轉換給定陣列的佈局。

set_max_registers(n, *, action)

設定 warp 擁有的最大暫存器數量。

wait_smem_to_gmem(n[, wait_read_only])

等待直到飛行中的 SMEM->GMEM 複製操作不超過 n 個。

wgmma(acc, a, b)

在給定的參考上執行非同步 warp group matmul-accumulate。

wgmma_wait(n)

等待直到飛行中的 WGMMA 操作不超過 n 個。

別名#

ACC

WGMMAAccumulatorRef 的別名

GMEM

jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM 的別名。

SMEM

jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM 的別名。