jax.experimental.pallas 模組#

Pallas 模組,用於自訂核心的 JAX 擴充功能。

請參閱 Pallas 文件:https://jax.dev.org.tw/en/latest/pallas.html

後端#

類別#

BlockSpec([block_shape, index_map, ...])

指定應如何為核心的每次調用切分陣列。

GridSpec([grid, in_specs, out_specs, ...])

編碼 jax.experimental.pallas.pallas_call() 的網格參數。

Slice(start, size[, stride])

具有起始索引和大小的切片。

MemoryRef(shape, dtype, memory_space)

類似 jax.ShapeDtypeStruct,但具有記憶體空間。

函式#

pallas_call(kernel, out_shape, *[, ...])

在某些輸入上調用 Pallas 核心。

program_id(axis)

傳回沿著網格給定軸的核心執行位置。

num_programs(axis)

傳回沿著給定軸的網格大小。

load(x_ref_or_view, idx, *[, mask, other, ...])

傳回從給定索引載入的陣列。

store(x_ref_or_view, idx, val, *[, mask, ...])

在給定索引處儲存值。

swap(x_ref_or_view, idx, val, *[, mask, ...])

交換給定索引處的值,並傳回舊值。

atomic_and(x_ref_or_view, idx, val, *[, mask])

原子地計算 `x_ref_or_view[idx] &= val`。

atomic_add(x_ref_or_view, idx, val, *[, mask])

原子地計算 `x_ref_or_view[idx] += val`。

atomic_cas(ref, cmp, val)

對 ref 中的值執行原子比較和交換,替換為給定值。

atomic_max(x_ref_or_view, idx, val, *[, mask])

原子地計算 `x_ref_or_view[idx] = max(x_ref_or_view[idx], val)`。

atomic_min(x_ref_or_view, idx, val, *[, mask])

原子地計算 `x_ref_or_view[idx] = min(x_ref_or_view[idx], val)`。

atomic_or(x_ref_or_view, idx, val, *[, mask])

原子地計算 `x_ref_or_view[idx] |= val`。

atomic_xchg(x_ref_or_view, idx, val, *[, mask])

以原子方式將給定值與給定索引處的值交換。

atomic_xor(x_ref_or_view, idx, val, *[, mask])

原子地計算 `x_ref_or_view[idx] ^= val`。

broadcast_to(a, shape)

debug_print(fmt, *args)

從 Pallas 核心內部列印值。

dot(a, b[, trans_a, trans_b, allow_tf32, ...])

max_contiguous(x, values)

multiple_of(x, values)

run_scoped(f, *types, **kw_types)

使用已分配的引用呼叫函式,並傳回結果。

when(condition)