jax.lax 模組#

jax.lax 是一個基本運算程式庫,為 jax.numpy 等程式庫提供基礎。轉換規則 (例如 JVP 和批次處理規則) 通常定義為 jax.lax primitives 的轉換。

許多 primitives 是圍繞等效 XLA 運算的輕薄包裝函式,如 XLA 運算語意 文件所述。在少數情況下,JAX 會偏離 XLA,通常是為了確保運算集在 JVP 和轉置規則的運算下是封閉的。

在可能的情況下,請優先使用 jax.numpy 等程式庫,而不是直接使用 jax.laxjax.numpy API 遵循 NumPy,因此比 jax.lax API 更穩定且更不易變更。

運算子#

abs(x)

元素級絕對值:\(|x|\)

acos(x)

元素級反餘弦:\(\mathrm{acos}(x)\)

acosh(x)

元素級反雙曲餘弦:\(\mathrm{acosh}(x)\)

add(x, y)

元素級加法:\(x + y\)

after_all(*operands)

合併一或多個 XLA 權杖值。

approx_max_k(operand, k[, ...])

以近似方式傳回 operand 的前 k 個最大值及其索引。

approx_min_k(operand, k[, ...])

以近似方式傳回 operand 的前 k 個最小值及其索引。

argmax(operand, axis, index_dtype)

沿著 axis 計算最大元素的索引。

argmin(operand, axis, index_dtype)

沿著 axis 計算最小元素的索引。

asin(x)

元素級反正弦:\(\mathrm{asin}(x)\)

asinh(x)

元素級反雙曲正弦:\(\mathrm{asinh}(x)\)

atan(x)

元素級反正切:\(\mathrm{atan}(x)\)

atan2(x, y)

兩個變數的元素級反正切:\(\mathrm{atan}({x \over y})\)

atanh(x)

元素級反雙曲正切:\(\mathrm{atanh}(x)\)

batch_matmul(lhs, rhs[, precision])

批次矩陣乘法。

bessel_i0e(x)

指數縮放的 0 階修正貝索函數:\(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\)

bessel_i1e(x)

指數縮放的 1 階修正貝索函數:\(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\)

betainc(a, b, x)

元素級正規化不完全貝塔積分。

bitcast_convert_type(operand, new_dtype)

元素級位元轉換。

bitwise_and(x, y)

元素級 AND:\(x \wedge y\)

bitwise_not(x)

元素級 NOT:\(\neg x\)

bitwise_or(x, y)

元素級 OR:\(x \vee y\)

bitwise_xor(x, y)

元素級互斥 OR:\(x \oplus y\)

population_count(x)

元素級 popcount,計算每個元素中設定位元的數量。

broadcast(operand, sizes[, sharding])

廣播陣列,新增前導維度

broadcast_in_dim(operand, shape, ...[, sharding])

包裝 XLA 的 BroadcastInDim 運算子。

broadcast_shapes()

傳回 NumPy 廣播 shapes 所產生的形狀。

broadcast_to_rank(x, rank)

新增 1 的前導維度,使 x 等級為 rank

broadcasted_iota(dtype, shape, dimension[, ...])

iota 的便利包裝函式。

cbrt(x)

元素級立方根:\(\sqrt[3]{x}\)

ceil(x)

元素級天花板:\(\left\lceil x \right\rceil\)

clamp(min, x, max)

元素級夾鉗。

clz(x)

元素級前導零計數。

collapse(operand, start_dimension[, ...])

將陣列的維度摺疊成單一維度。

complex(x, y)

元素級建立複數:\(x + jy\)

concatenate(operands, dimension)

沿著 dimension 串連陣列序列。

conj(x)

元素級複數共軛函數:\(\overline{x}\)

conv(lhs, rhs, window_strides, padding[, ...])

conv_general_dilated 的便利包裝函式。

convert_element_type(operand, new_dtype)

元素級轉換。

conv_dimension_numbers(lhs_shape, rhs_shape, ...)

將卷積 dimension_numbers 轉換為 ConvDimensionNumbers

conv_general_dilated(lhs, rhs, ...[, ...])

一般 n 維卷積運算子,具有選用的擴張。

conv_general_dilated_local(lhs, rhs, ...[, ...])

一般 n 維非共用卷積運算子,具有選用的擴張。

conv_general_dilated_patches(lhs, ...[, ...])

擷取受 conv_general_dilated 感受野約束的圖塊。

conv_transpose(lhs, rhs, strides, padding[, ...])

用於計算 N 維卷積「轉置」的便利包裝函式。

conv_with_general_padding(lhs, rhs, ...[, ...])

conv_general_dilated 的便利包裝函式。

cos(x)

元素級餘弦:\(\mathrm{cos}(x)\)

cosh(x)

元素級雙曲餘弦:\(\mathrm{cosh}(x)\)

cumlogsumexp(operand[, axis, reverse])

沿著 axis 計算累計 logsumexp。

cummax(operand[, axis, reverse])

沿著 axis 計算累計最大值。

cummin(operand[, axis, reverse])

沿著 axis 計算累計最小值。

cumprod(operand[, axis, reverse])

沿著 axis 計算累計乘積。

cumsum(operand[, axis, reverse])

沿著 axis 計算累計總和。

digamma(x)

元素級 digamma:\(\psi(x)\)

div(x, y)

元素級除法:\(x \over y\)

dot(lhs, rhs[, precision, ...])

向量/向量、矩陣/向量和矩陣/矩陣乘法。

dot_general(lhs, rhs, dimension_numbers[, ...])

一般點積/縮約運算子。

dynamic_index_in_dim(operand, index[, axis, ...])

圍繞 dynamic_slice 的便利包裝函式,用於執行整數索引。

dynamic_slice(operand, start_indices, ...)

包裝 XLA 的 DynamicSlice 運算子。

dynamic_slice_in_dim(operand, start_index, ...)

套用至單一維度的 lax.dynamic_slice() 便利包裝函式。

dynamic_update_index_in_dim(operand, update, ...)

圍繞 dynamic_update_slice() 的便利包裝函式,用於在單一 axis 中更新大小為 1 的切片。

dynamic_update_slice(operand, update, ...)

包裝 XLA 的 DynamicUpdateSlice 運算子。

dynamic_update_slice_in_dim(operand, update, ...)

圍繞 dynamic_update_slice() 的便利包裝函式,用於在單一 axis 中更新切片。

eq(x, y)

元素級等於:\(x = y\)

erf(x)

元素級誤差函數:\(\mathrm{erf}(x)\)

erfc(x)

元素級互補誤差函數:\(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)

erf_inv(x)

元素級反誤差函數:\(\mathrm{erf}^{-1}(x)\)

exp(x)

元素級指數:\(e^x\)

expand_dims(array, dimensions)

將任意數量的尺寸 1 維度插入陣列。

expm1(x)

元素級 \(e^{x} - 1\)

fft(x, fft_type, fft_lengths)

floor(x)

元素級地板:\(\left\lfloor x \right\rfloor\)

full(shape, fill_value[, dtype, sharding])

傳回以 fill_value 填滿的 shape 陣列。

full_like(x, fill_value[, dtype, shape, ...])

根據範例陣列 x 建立類似 np.full 的完整陣列。

gather(operand, start_indices, ...[, ...])

Gather 運算子。

ge(x, y)

元素級大於等於: \(x \geq y\)

gt(x, y)

元素級大於: \(x > y\)

igamma(a, x)

元素級正則化不完全伽瑪函數。

igammac(a, x)

元素級互補正則化不完全伽瑪函數。

imag(x)

元素級提取虛部: \(\mathrm{Im}(x)\)

index_in_dim(operand, index[, axis, keepdims])

圍繞 lax.slice() 的便捷封裝器,用於執行整數索引。

index_take(src, idxs, axes)

integer_pow(x, y)

元素級冪運算: \(x^y\),其中 \(y\) 是一個固定的整數。

iota(dtype, size)

封裝了 XLA 的 Iota 運算符。

is_finite(x)

元素級 \(\mathrm{isfinite}\)

le(x, y)

元素級小於等於: \(x \leq y\)

lgamma(x)

元素級對數伽瑪函數: \(\mathrm{log}(\Gamma(x))\)

log(x)

元素級自然對數: \(\mathrm{log}(x)\)

log1p(x)

元素級 \(\mathrm{log}(1 + x)\)

logistic(x)

元素級邏輯斯蒂 (sigmoid) 函數: \(\frac{1}{1 + e^{-x}}\)

lt(x, y)

元素級小於: \(x < y\)

max(x, y)

元素級最大值: \(\mathrm{max}(x, y)\)

min(x, y)

元素級最小值: \(\mathrm{min}(x, y)\)

mul(x, y)

元素級乘法: \(x \times y\)

ne(x, y)

元素級不等於: \(x \neq y\)

neg(x)

元素級負數: \(-x\)

nextafter(x1, x2)

返回 x1 之後朝 x2 方向下一個可表示的值。

optimization_barrier(operand, /)

防止編譯器跨越障礙移動操作。

pad(operand, padding_value, padding_config)

對陣列應用低、高和/或內部填充。

platform_dependent(*args[, default])

分階段輸出平台特定的程式碼。

polygamma(m, x)

元素級多伽瑪函數: \(\psi^{(m)}(x)\)

population_count(x)

元素級 popcount,計算每個元素中設定位元的數量。

pow(x, y)

元素級冪運算: \(x^y\)

random_gamma_grad(a, x)

元素級 Gamma(a, 1) 樣本的導數。

real(x)

元素級提取實部: \(\mathrm{Re}(x)\)

reciprocal(x)

元素級倒數: \(1 \over x\)

reduce(operands, init_values, computation, ...)

封裝了 XLA 的 Reduce 運算符。

reduce_precision(operand, exponent_bits, ...)

封裝了 XLA 的 ReducePrecision 運算符。

reduce_window(operand, init_value, ...[, ...])

封裝了 XLA 的 ReduceWindowWithGeneralPadding 運算符。

rem(x, y)

元素級餘數: \(x \bmod y\)

reshape(operand, new_sizes[, dimensions, ...])

封裝了 XLA 的 Reshape 運算符。

rev(operand, dimensions)

封裝了 XLA 的 Rev 運算符。

rng_bit_generator(key, shape[, dtype, algorithm])

無狀態 PRNG 位元產生器。

rng_uniform(a, b, shape)

有狀態 PRNG 產生器。

round(x[, rounding_method])

元素級四捨五入。

rsqrt(x)

元素級倒數平方根: \(1 \over \sqrt{x}\)

scatter(operand, scatter_indices, updates, ...)

分散更新運算符。

scatter_add(operand, scatter_indices, ...[, ...])

分散相加運算符。

scatter_apply(operand, scatter_indices, ...)

分散應用運算符。

scatter_max(operand, scatter_indices, ...[, ...])

分散最大值運算符。

scatter_min(operand, scatter_indices, ...[, ...])

分散最小值運算符。

scatter_mul(operand, scatter_indices, ...[, ...])

分散乘法運算符。

shift_left(x, y)

元素級左移位: \(x \ll y\)

shift_right_arithmetic(x, y)

元素級算術右移位: \(x \gg y\)

shift_right_logical(x, y)

元素級邏輯右移位: \(x \gg y\)

sign(x)

元素級符號。

sin(x)

元素級正弦: \(\mathrm{sin}(x)\)

sinh(x)

元素級雙曲正弦: \(\mathrm{sinh}(x)\)

slice(operand, start_indices, limit_indices)

封裝了 XLA 的 Slice 運算符。

slice_in_dim(operand, start_index, limit_index)

圍繞 lax.slice() 的便捷封裝器,僅應用於一個維度。

排序()

封裝了 XLA 的 Sort 運算符。

sort_key_val(keys, values[, dimension, ...])

沿 dimensionkeys 進行排序,並將相同的排列應用於 values

split(operand, sizes[, axis])

沿 axis 分割陣列。

sqrt(x)

元素級平方根: \(\sqrt{x}\)

square(x)

元素級平方: \(x^2\)

squeeze(array, dimensions)

從陣列中擠壓任意數量的尺寸為 1 的維度。

sub(x, y)

元素級減法: \(x - y\)

tan(x)

元素級正切: \(\mathrm{tan}(x)\)

tanh(x)

元素級雙曲正切: \(\mathrm{tanh}(x)\)

top_k(operand, k)

返回 operand 最後一個軸的前 k 個值及其索引。

transpose(operand, permutation)

封裝了 XLA 的 Transpose 運算符。

zeros_like_array(x)

zeta(x, q)

元素級 Hurwitz zeta 函數: \(\zeta(x, q)\)

控制流運算符#

associative_scan(fn, elems[, reverse, axis])

使用結合二元運算以平行方式執行掃描。

cond(pred, true_fun, false_fun, *operands[, ...])

有條件地應用 true_funfalse_fun

fori_loop(lower, upper, body_fun, init_val, *)

lower 迴圈到 upper,通過簡化為 jax.lax.while_loop()

map(f, xs, *[, batch_size])

將函數映射到前導陣列軸上。

scan(f, init[, xs, length, reverse, unroll, ...])

掃描前導陣列軸上的函數,同時攜帶狀態。

select(pred, on_true, on_false)

根據布林謂詞在兩個分支之間進行選擇。

select_n(which, *cases)

從多個案例中選擇陣列值。

switch(index, branches, *operands[, operand])

應用由 index 給定的 branches 中的恰好一個。

while_loop(cond_fun, body_fun, init_val)

cond_fun 為 True 時,在迴圈中重複調用 body_fun

自訂梯度運算符#

stop_gradient(x)

停止梯度計算。

custom_linear_solve(matvec, b, solve[, ...])

使用隱式定義的梯度執行無矩陣線性求解。

custom_root(f, initial_guess, solve, ...[, ...])

可微分地求解函數的根。

並行運算符#

all_gather(x, axis_name, *[, ...])

跨所有副本收集 x 的值。

all_to_all(x, axis_name, split_axis, ...[, ...])

實體化映射的軸並映射不同的軸。

psum(x, axis_name, *[, axis_index_groups])

在 pmapped 軸 axis_name 上計算 x 的 all-reduce 總和。

psum_scatter(x, axis_name, *[, ...])

類似於 psum(x, axis_name),但每個設備僅保留結果的一部分。

pmax(x, axis_name, *[, axis_index_groups])

在 pmapped 軸 axis_name 上計算 x 的 all-reduce 最大值。

pmin(x, axis_name, *[, axis_index_groups])

在 pmapped 軸 axis_name 上計算 x 的 all-reduce 最小值。

pmean(x, axis_name, *[, axis_index_groups])

在 pmapped 軸 axis_name 上計算 x 的 all-reduce 平均值。

ppermute(x, axis_name, perm)

根據排列 perm 執行集體排列。

pshuffle(x, axis_name, perm)

jax.lax.ppermute 的便捷封裝器,具有備用排列編碼

pswapaxes(x, axis_name, axis, *[, ...])

將 pmapped 軸 axis_name 與未映射的軸 axis 交換。

axis_index(axis_name)

返回沿映射軸 axis_name 的索引。

線性代數運算符 (jax.lax.linalg)#

cholesky(x, *[, symmetrize_input])

Cholesky 分解。

eig(x, *[, compute_left_eigenvectors, ...])

一般矩陣的特徵分解。

eigh(x, *[, lower, symmetrize_input, ...])

厄米矩陣的特徵分解。

hessenberg(a)

將方陣簡化為上 Hessenberg 形式。

lu(x)

帶部分主元的 LU 分解。

householder_product(a, taus)

基本 Householder 反射器的乘積。

qdwh(x, *[, is_hermitian, max_iterations, ...])

基於 QR 的動態加權 Halley 迭代,用於極分解。

qr()

QR 分解。

schur(x, *[, compute_schur_vectors, ...])

svd()

奇異值分解。

triangular_solve(a, b, *[, left_side, ...])

三角求解。

tridiagonal(a, *[, lower])

將對稱/厄米矩陣簡化為三對角形式。

tridiagonal_solve(dl, d, du, b)

計算三對角線性系統的解。

參數類別#

class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)[source]#

描述卷積的批次、空間和特徵維度。

參數:
  • lhs_spec (Sequence[int]) – 包含 (批次維度、特徵維度、空間維度…) 的非負整數維度數字元組。

  • rhs_spec (Sequence[int]) – 包含 (輸出特徵維度、輸入特徵維度、空間維度…) 的非負整數維度數字元組。

  • out_spec (Sequence[int]) – 包含 (批次維度、特徵維度、空間維度…) 的非負整數維度數字元組。

jax.lax.ConvGeneralDilatedDimensionNumbers#

別名為 tuple[str, str, str] | ConvDimensionNumbers | None

class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)[source]#

指定用於計算點積的演算法。

當用於指定 dot()dot_general() 和其他點積函數的 precision 輸入時,此資料結構用於控制用於計算點積的演算法的屬性。此 API 控制用於計算的精度,並允許使用者存取硬體特定的加速。

對這些演算法的支援取決於平台,並且使用不受支援的演算法將在編譯計算時引發 Python 例外。至少在某些平台上已知支援的演算法列在 DotAlgorithmPreset 列舉中,這些是試用此 API 的良好起點。

“點積演算法”由以下參數指定

  • lhs_precision_typerhs_precision_type,操作的 LHS 和 RHS 四捨五入到的資料類型。

  • accumulation_type 用於累積的資料類型。

  • lhs_component_countrhs_component_countnum_primitive_operations 適用於將 LHS 和/或 RHS 分解為多個組件並對這些值執行多個操作的演算法,通常是為了模擬更高的精度。對於沒有分解的演算法,這些值應設定為 1

  • allow_imprecise_accumulation 指定是否允許在某些步驟中使用較低精度的累積(例如 CUBLASLT_MATMUL_DESC_FAST_ACCUM)。

StableHLO 規範 對於點積操作不要求精度類型與輸入或輸出的儲存類型相同,但某些平台可能要求這些類型匹配。此外,dot_general() 的返回類型始終由輸入演算法的 accumulation_type 參數定義(如果指定)。

範例

使用 32 位元浮點累加器累積兩個 16 位元浮點數

>>> algorithm = DotAlgorithm(
...     lhs_precision_type=np.float16,
...     rhs_precision_type=np.float16,
...     accumulation_type=np.float32,
... )
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> dot(lhs, rhs, precision=algorithm)  
array([ 1.,  4.,  9., 16.], dtype=float16)

或者,等效地,使用預設值

>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm)  
array([ 1.,  4.,  9., 16.], dtype=float16)

預設值也可以通過名稱指定

>>> dot(lhs, rhs, precision="F16_F16_F32")  
array([ 1.,  4.,  9., 16.], dtype=float16)

可以使用 preferred_element_type 參數返回輸出,而無需向下轉換累積類型

>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32)  
array([ 1.,  4.,  9., 16.], dtype=float32)
參數:
  • lhs_precision_type (DTypeLike)

  • rhs_precision_type (DTypeLike)

  • accumulation_type (DTypeLike)

  • lhs_component_count (int)

  • rhs_component_count (int)

  • num_primitive_operations (int)

  • allow_imprecise_accumulation (bool)

class jax.lax.DotAlgorithmPreset(value)[source]#

用於計算點積的已知演算法列舉。

Enum 提供了一組已命名的 DotAlgorithm 物件,已知這些物件在至少一個平台上受到支援。 有關這些演算法行為的更多詳細資訊,請參閱 DotAlgorithm 文件。

當呼叫 dot()dot_general() 或大多數其他 JAX 點積函數時,可以從此列表中選擇演算法,方法是傳遞此 Enum 的成員或其名稱字串作為 precision 參數。

例如,使用者可以直接使用此 Enum 指定預設值

>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm)  
array([ 1.,  4.,  9., 16.], dtype=float16)

或者,等效地,它們可以通過名稱指定

>>> dot(lhs, rhs, precision="F16_F16_F32")  
array([ 1.,  4.,  9., 16.], dtype=float16)

預設值的名稱通常為 LHS_RHS_ACCUM,其中 LHSRHS 分別是 lhsrhs 輸入的元素類型,而 ACCUM 是累加器的元素類型。 某些預設值具有額外的後綴,每個後綴的含義如下文所述。 支援的預設值為

DEFAULT = 1#

將根據輸入和輸出類型選擇演算法。

ANY_F8_ANY_F8_F32 = 2#

接受任何 float8 輸入類型,並累加到 float32 中。

ANY_F8_ANY_F8_F32_FAST_ACCUM = 3#

類似於 ANY_F8_ANY_F8_F32,但使用更快的累加,代價是較低的準確性。

ANY_F8_ANY_F8_ANY = 4#

類似於 ANY_F8_ANY_F8_F32,但累加類型由 preferred_element_type 控制。

ANY_F8_ANY_F8_ANY_FAST_ACCUM = 5#

類似於 ANY_F8_ANY_F8_F32_FAST_ACCUM,但累加類型由 preferred_element_type 控制。

F16_F16_F16 = 6#
F16_F16_F32 = 7#
BF16_BF16_BF16 = 8#
BF16_BF16_F32 = 9#
BF16_BF16_F32_X3 = 10#

_X3 後綴表示該演算法使用 3 個運算來模擬更高的精度。

BF16_BF16_F32_X6 = 11#

類似於 BF16_BF16_F32_X3,但使用 6 個運算而不是 3 個。

TF32_TF32_F32 = 12#
TF32_TF32_F32_X3 = 13#

_X3 後綴表示該演算法使用 3 個運算來模擬更高的精度。

F32_F32_F32 = 14#
F64_F64_F64 = 15#
property supported_lhs_types: tuple[DTypeLike, ...] | None[source]#
property supported_rhs_types: tuple[DTypeLike, ...] | None[source]#
property accumulation_type: DTypeLike | None[source]#
supported_output_types(lhs_dtype, rhs_dtype)[source]#
參數:
  • lhs_dtype (DTypeLike)

  • rhs_dtype (DTypeLike)

返回類型:

tuple[DTypeLike, …] | None

class jax.lax.FftType(value)[source]#

描述要執行的 FFT 運算類型。

FFT = 0#

正向複數到複數 FFT。

IFFT = 1#

反向複數到複數 FFT。

IRFFT = 3#

反向實數到複數 FFT。

RFFT = 2#

正向實數到複數 FFT。

class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())[source]#

描述 XLA 的 Gather 運算子的維度編號引數。 有關維度編號含義的更多詳細資訊,請參閱 XLA 文件。

參數:
  • offset_dims (tuple[int, ...]) – gather 輸出中偏移到從 operand 切割的陣列中的維度集合。 必須是升序整數元組,每個整數代表輸出的維度編號。

  • collapsed_slice_dims (tuple[int, ...]) – operand 中維度 i 的集合,這些維度的 slice_sizes[i] == 1,並且在 gather 的輸出中不應具有對應的維度。 必須是升序整數元組。

  • start_index_map (tuple[int, ...]) – 對於 start_indices 中的每個維度,給出要切割的 operand 中的對應維度。 必須是整數元組,大小等於 start_indices.shape[-1]

  • operand_batching_dims (tuple[int, ...]) – operand 中批次維度 i 的集合,這些維度的 slice_sizes[i] == 1,並且在 start_indices (在 start_indices_batching_dims 中的相同索引處)和 gather 的輸出中都應具有對應的維度。 必須是升序整數元組。

  • start_indices_batching_dims (tuple[int, ...]) – start_indices 中批次維度 i 的集合,這些維度在 operand (在 operand_batching_dims 中的相同索引處)和 gather 的輸出中都應具有對應的維度。 必須是整數元組(順序根據與 operand_batching_dims 的對應關係而固定)。

與 XLA 的 GatherDimensionNumbers 結構不同,index_vector_dim 是隱含的;始終存在索引向量維度,並且它必須始終是最後一個維度。 若要收集純量索引,請新增大小為 1 的尾部維度。

class jax.lax.GatherScatterMode(value)[source]#

描述如何在 gather 或 scatter 中處理超出邊界的索引。

可能的值為

CLIP

索引將被限制到最接近的範圍內值,即,使得要 gather 的整個視窗都在範圍內。

FILL_OR_DROP

如果 gathered 視窗的任何部分超出邊界,則返回的整個視窗(即使是那些原本在範圍內的元素)都將以常數填充。 如果 scattered 視窗的任何部分超出邊界,則整個視窗將被丟棄。

PROMISE_IN_BOUNDS

使用者承諾索引在邊界內。 將不會執行額外檢查。 實際上,使用目前的 XLA 實作,這表示超出邊界的 gather 將被限制,但超出邊界的 scatter 將被丟棄。 如果索引超出邊界,則梯度將不正確。

class jax.lax.Precision(value)[source]#

用於 lax 矩陣乘法相關函數的精度列舉。

JAX 函數的裝置相關 precision 引數通常控制加速器後端(即 TPU 和 GPU)上陣列計算的速度和準確性之間的權衡。 對 CPU 後端沒有影響。 這僅對 float32 計算有影響,並且不影響輸入/輸出資料類型。 成員為

DEFAULT

最快模式,但準確度最低。 在 TPU 上:以 bfloat16 執行 float32 計算。 在 GPU 上:如果可用(例如在 A100 和 H100 GPU 上),則使用 tensorfloat32,否則使用標準 float32(例如在 V100 GPU 上)。 別名:'default''fastest'

HIGH

較慢但更準確。 在 TPU 上:以 3 個 bfloat16 通道執行 float32 計算。 在 GPU 上:在可用的情況下使用 tensorfloat32,否則使用 float32。 別名:'high'

HIGHEST

最慢但最準確。 在 TPU 上:以 6 個 bfloat16 執行 float32 計算。 別名:'highest'。 在 GPU 上:使用 float32。

jax.lax.PrecisionLike#

alias of None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset

class jax.lax.RandomAlgorithm(value)[source]#

描述用於 rng_bit_generator 的 PRNG 演算法。

RNG_DEFAULT = 0#

平台的預設演算法。

RNG_THREE_FRY = 1#

Threefry-2x32 PRNG 演算法。

RNG_PHILOX = 2#

Philox-4x32 PRNG 演算法。

class jax.lax.RoundingMethod(value)[source]#

用於處理 jax.lax.round() 中間值(例如 0.5)的捨入策略。

AWAY_FROM_ZERO = 0#

將中間值朝遠離零的方向捨入(例如,0.5 -> 1, -0.5 -> -1)。

TO_NEAREST_EVEN = 1#

將中間值捨入到最接近的偶數整數。 這也稱為「銀行家捨入法」(例如,0.5 -> 0, 1.5 -> 2)。

class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())[source]#

描述 XLA 的 Scatter 運算子的維度編號引數。 有關維度編號含義的更多詳細資訊,請參閱 XLA 文件。

參數:
  • update_window_dims (Sequence[int]) – updates 中作為視窗維度的維度集合。 必須是升序整數元組,每個整數代表維度編號。

  • inserted_window_dims (Sequence[int]) – 必須插入到 updates 形狀中的大小為 1 的視窗維度集合。 必須是升序整數元組,每個整數代表輸出的維度編號。 這些是 gather 情況下 collapsed_slice_dims 的鏡像影像。

  • scatter_dims_to_operand_dims (Sequence[int]) – 對於 scatter_indices 中的每個維度,給出 operand 中的對應維度。 必須是整數序列,大小等於 scatter_indices.shape[-1]

  • operand_batching_dims (Sequence[int]) – operand 中批次維度 i 的集合,這些維度在 scatter_indices (在 scatter_indices_batching_dims 中的相同索引處)和 updates 中都應具有對應的維度。 必須是升序整數元組。 這些是 gather 情況下 operand_batching_dims 的鏡像影像。

  • scatter_indices_batching_dims (Sequence[int]) – scatter_indices 中批次維度 i 的集合,這些維度在 operand (在 operand_batching_dims 中的相同索引處)和 gather 的輸出中都應具有對應的維度。 必須是整數元組(順序根據與 input_batching_dims 的對應關係而固定)。 這些是 gather 情況下 start_indices_batching_dims 的鏡像影像。

與 XLA 的 ScatterDimensionNumbers 結構不同,index_vector_dim 是隱含的;始終存在索引向量維度,並且它必須始終是最後一個維度。 若要 scatter 純量索引,請新增大小為 1 的尾部維度。