jax.experimental.pallas.triton.TritonCompilerParams#

class jax.experimental.pallas.triton.TritonCompilerParams(num_warps=None, num_stages=None, serialized_metadata=None)[原始碼]#

Triton 的編譯器參數。

參數:
  • num_warps (int | None)

  • num_stages (int | None)

  • serialized_metadata (bytes | None)

num_warps#

用於核心的 warp 數量。每個 warp 由 32 個執行緒組成。

類型:

int | None

num_stages#

編譯器應對軟體管線化迴圈使用的階段數。

類型:

int | None

serialized_metadata#

其他編譯器中繼資料。此欄位不穩定,未來可能會移除。

類型:

bytes | None

__init__(num_warps=None, num_stages=None, serialized_metadata=None)#
參數:
  • num_warps (int | None | None)

  • num_stages (int | None | None)

  • serialized_metadata (bytes | None | None)

回傳類型:

None

方法

__init__([num_warps, num_stages, ...])

屬性

PLATFORM

num_stages

num_warps

serialized_metadata