jax.experimental.pallas.pallas_call#

jax.experimental.pallas.pallas_call(kernel, out_shape, *, grid_spec=None, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=(), input_output_aliases={}, debug=False, interpret=False, name=None, compiler_params=None, cost_estimate=None, backend=None)[source]#

在某些輸入上調用 Pallas 核心。

請參閱Pallas 快速入門

參數:
  • kernel (Callable[..., None]) – 核心函數,接收每個輸入和輸出的 Ref。Ref 的形狀由對應的 in_specsout_specs 中的 block_shape 給定。

  • out_shape (Any) – jax.ShapeDtypeStruct 的 PyTree,描述輸出的形狀和 dtype。

  • grid_spec (GridSpec | None | None) – 指定 gridin_specsout_specsscratch_shapes 的替代方法。如果給定,則不得同時給定其他參數。

  • grid (TupleGrid) – 迭代空間,以整數元組表示。核心將執行 prod(grid) 次。詳情請參閱 grid,又名迴圈中的核心

  • in_specs (BlockSpecTree) – jax.experimental.pallas.BlockSpec 的 PyTree,其結構與位置引數的結構匹配。in_specs 的預設值為所有輸入指定整個陣列,例如 pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)。詳情請參閱 BlockSpec,又名如何將輸入分塊

  • out_specs (BlockSpecTree) – jax.experimental.pallas.BlockSpec 的 PyTree,其結構與輸出的結構匹配。out_specs 的預設值為輸出指定整個陣列,例如 pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)。詳情請參閱 BlockSpec,又名如何將輸入分塊

  • scratch_shapes (ScratchShapeTree) – 核心所需的後端特定暫時物件的 PyTree,例如暫時緩衝區、同步化基本元件等。

  • input_output_aliases (dict[int, int]) – 一個字典,將某些輸入的索引映射到別名它們的輸出的索引。這些索引位於展平的輸入和輸出中。

  • debug (bool) – 如果為 True,Pallas 會在處理核心時印出核心的各種中間形式。

  • interpret (bool) – 將 pallas_call 作為掃描網格的 jax.jit 執行,該網格的主體是作為 JAX 函數降低的核心。這不需要 TPU 或 GPU,並且是在 CPU 上執行 Pallas 核心的唯一方法。這對於偵錯很有用。

  • name (str | None | None) – 如果存在,則指定用於偵錯和錯誤訊息中此核心呼叫的名稱。在此名稱中,我們會附加定義核心函數的檔案和行,例如:{name} for kernel function {kernel_name} at {file}:{line}。如果遺失,則我們使用 {kernel_name} at {file}:{line}

  • compiler_params (dict[str, Any] | pallas_core.CompilerParams | None | None) – 可選的編譯器參數。如果提供字典,則其格式應為 {platform: {param_name: param_value}},其中 platform 為 ‘mosaic’ 或 ‘triton’。也可以為 TPU 傳入 jax.experimental.pallas.tpu.TPUCompilerParams,以及為 Triton/GPU 傳入 jax.experimental.pallas.gpu.TritonCompilerParams

  • backend (_Backend | None | None) – 可選的字串文字,可以是 “mosaic_tpu”、“triton” 或 “mosaic_gpu” 之一,用於決定要使用的後端。None 表示讓 pallas 決定。

  • cost_estimate (CostEstimate | None | None)

傳回:

一個函數,可以針對多個位置陣列引數呼叫以調用 Pallas 核心。

傳回類型:

Callable[…, Any]