jax.experimental.pallas.mosaic_gpu.wgmma_wait#

jax.experimental.pallas.mosaic_gpu.wgmma_wait(n)[source]#

等待直到飛行中的 WGMMA 運算不超過 n 個。

參數:

n (int)