jax.experimental.pallas.pallas_call#
- jax.experimental.pallas.pallas_call(kernel, out_shape, *, grid_spec=None, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=(), input_output_aliases={}, debug=False, interpret=False, name=None, compiler_params=None, cost_estimate=None, backend=None)[source]#
在某些輸入上調用 Pallas 核心。
請參閱Pallas 快速入門。
- 參數:
kernel (Callable[..., None]) – 核心函數,接收每個輸入和輸出的 Ref。Ref 的形狀由對應的
in_specs
和out_specs
中的block_shape
給定。out_shape (Any) –
jax.ShapeDtypeStruct
的 PyTree,描述輸出的形狀和 dtype。grid_spec (GridSpec | None | None) – 指定
grid
、in_specs
、out_specs
和scratch_shapes
的替代方法。如果給定,則不得同時給定其他參數。grid (TupleGrid) – 迭代空間,以整數元組表示。核心將執行
prod(grid)
次。詳情請參閱 grid,又名迴圈中的核心。in_specs (BlockSpecTree) –
jax.experimental.pallas.BlockSpec
的 PyTree,其結構與位置引數的結構匹配。in_specs
的預設值為所有輸入指定整個陣列,例如pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)
。詳情請參閱 BlockSpec,又名如何將輸入分塊。out_specs (BlockSpecTree) –
jax.experimental.pallas.BlockSpec
的 PyTree,其結構與輸出的結構匹配。out_specs
的預設值為輸出指定整個陣列,例如pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)
。詳情請參閱 BlockSpec,又名如何將輸入分塊。scratch_shapes (ScratchShapeTree) – 核心所需的後端特定暫時物件的 PyTree,例如暫時緩衝區、同步化基本元件等。
input_output_aliases (dict[int, int]) – 一個字典,將某些輸入的索引映射到別名它們的輸出的索引。這些索引位於展平的輸入和輸出中。
debug (bool) – 如果為 True,Pallas 會在處理核心時印出核心的各種中間形式。
interpret (bool) – 將
pallas_call
作為掃描網格的jax.jit
執行,該網格的主體是作為 JAX 函數降低的核心。這不需要 TPU 或 GPU,並且是在 CPU 上執行 Pallas 核心的唯一方法。這對於偵錯很有用。name (str | None | None) – 如果存在,則指定用於偵錯和錯誤訊息中此核心呼叫的名稱。在此名稱中,我們會附加定義核心函數的檔案和行,例如:{name} for kernel function {kernel_name} at {file}:{line}。如果遺失,則我們使用 {kernel_name} at {file}:{line}。
compiler_params (dict[str, Any] | pallas_core.CompilerParams | None | None) – 可選的編譯器參數。如果提供字典,則其格式應為 {platform: {param_name: param_value}},其中 platform 為 ‘mosaic’ 或 ‘triton’。也可以為 TPU 傳入 jax.experimental.pallas.tpu.TPUCompilerParams,以及為 Triton/GPU 傳入 jax.experimental.pallas.gpu.TritonCompilerParams。
backend (_Backend | None | None) – 可選的字串文字,可以是 “mosaic_tpu”、“triton” 或 “mosaic_gpu” 之一,用於決定要使用的後端。None 表示讓 pallas 決定。
cost_estimate (CostEstimate | None | None)
- 傳回:
一個函數,可以針對多個位置陣列引數呼叫以調用 Pallas 核心。
- 傳回類型:
Callable[…, Any]