jax.experimental.pallas.triton 模組#

Triton 專用的 Pallas API。

類別#

TritonCompilerParams([num_warps, ...])

Triton 的編譯器參數。

函式#

approx_tanh(x)

元素級近似雙曲正切:\(\mathrm{tanh}(x)\)

debug_barrier()

同步網格中的所有核心執行。

elementwise_inline_asm(asm, *, args, ...)

內聯組譯碼,應用元素級操作。