jax.numpy 模組#

實作 NumPy API,使用 jax.lax 中的基本運算。

雖然 JAX 嘗試盡可能地遵循 NumPy API,但有時 JAX 無法完全遵循 NumPy。

  • 值得注意的是,由於 JAX 陣列是不可變的,NumPy API 中會就地變更陣列的 API 無法在 JAX 中實作。然而,JAX 通常能夠提供純函數式的替代 API。例如,JAX 提供了替代的純索引更新函數 x.at[i].set(y),而不是就地陣列更新 (x[i] = y)(請參閱 ndarray.at)。

  • 相關地,當可能時,某些 NumPy 函數通常會傳回陣列的檢視 (範例為 transpose()reshape())。JAX 版本的此類函數將會改為傳回副本,儘管當使用 jax.jit() 編譯操作序列時,XLA 通常會將其最佳化掉。

  • NumPy 在將值提升為 float64 型別方面非常積極。JAX 有時在型別提升方面較不積極(請參閱型別提升語意)。

  • 某些 NumPy 常式具有資料相關的輸出形狀(範例包括 unique()nonzero())。由於 XLA 編譯器要求陣列形狀在編譯時已知,因此此類操作與 JIT 不相容。因此,JAX 為這些函數新增了可選的 size 引數,可以靜態指定該引數,以便將它們與 JIT 一起使用。

幾乎所有適用的 NumPy 函數都在 jax.numpy 命名空間中實作;它們在下面列出。

ndarray.at

索引更新功能的輔助屬性。

abs(x, /)

jax.numpy.absolute() 的別名。

absolute(x, /)

逐元素計算絕對值。

acos(x, /)

jax.numpy.arccos() 的別名

acosh(x, /)

jax.numpy.arccosh() 的別名

add

逐元素相加兩個陣列。

all(a[, axis, out, keepdims, where])

測試給定軸上的所有陣列元素是否都評估為 True。

allclose(a, b[, rtol, atol, equal_nan])

檢查兩個陣列在公差範圍內是否逐元素近似相等。

amax(a[, axis, out, keepdims, initial, where])

jax.numpy.max() 的別名。

amin(a[, axis, out, keepdims, initial, where])

jax.numpy.min() 的別名。

angle(z[, deg])

傳回複數值數字或陣列的角度。

any(a[, axis, out, keepdims, where])

測試給定軸上的任何陣列元素是否都評估為 True。

append(arr, values[, axis])

傳回一個新陣列,其中值附加到原始陣列的末尾。

apply_along_axis(func1d, axis, arr, *args, ...)

沿軸將函數應用於 1D 陣列切片。

apply_over_axes(func, a, axes)

在指定的軸上重複應用函數。

arange(start[, stop, step, dtype, device])

建立均勻間隔值的陣列。

arccos(x, /)

計算輸入的三角餘弦的逐元素反函數。

arccosh(x, /)

計算輸入的雙曲餘弦的逐元素反函數。

arcsin(x, /)

計算輸入的三角正弦的逐元素反函數。

arcsinh(x, /)

計算輸入的雙曲正弦的逐元素反函數。

arctan(x, /)

計算輸入的三角正切的逐元素反函數。

arctan2(x1, x2, /)

計算 x1/x2 的反正切,並選擇正確的象限。

arctanh(x, /)

計算輸入的雙曲正切的逐元素反函數。

argmax(a[, axis, out, keepdims])

傳回陣列最大值的索引。

argmin(a[, axis, out, keepdims])

傳回陣列最小值的索引。

argpartition(a, kth[, axis])

傳回部分排序陣列的索引。

argsort(a[, axis, kind, order, stable, ...])

傳回排序陣列的索引。

argwhere(a, *[, size, fill_value])

尋找非零陣列元素的索引

around(a[, decimals, out])

jax.numpy.round() 的別名

array(object[, dtype, copy, order, ndmin, ...])

將物件轉換為 JAX 陣列。

array_equal(a1, a2[, equal_nan])

檢查兩個陣列是否逐元素相等。

array_equiv(a1, a2)

檢查兩個陣列是否逐元素相等。

array_repr(arr[, max_line_width, precision, ...])

傳回陣列的字串表示形式。

array_split(ary, indices_or_sections[, axis])

將陣列分割成子陣列。

array_str(a[, max_line_width, precision, ...])

傳回陣列中資料的字串表示形式。

asarray(a[, dtype, order, copy, device])

將物件轉換為 JAX 陣列。

asin(x, /)

jax.numpy.arcsin() 的別名

asinh(x, /)

jax.numpy.arcsinh() 的別名

astype(x, dtype, /, *[, copy, device])

將陣列轉換為指定的 dtype。

atan(x, /)

jax.numpy.arctan() 的別名

atanh(x, /)

jax.numpy.arctanh() 的別名

atan2(x1, x2, /)

jax.numpy.arctan2() 的別名

atleast_1d(*arys)

將輸入轉換為至少具有 1 個維度的陣列。

atleast_2d(*arys)

將輸入轉換為至少具有 2 個維度的陣列。

atleast_3d(*arys)

將輸入轉換為至少具有 3 個維度的陣列。

average(a[, axis, weights, returned, keepdims])

計算加權平均值。

bartlett(M)

傳回大小為 M 的 Bartlett 視窗。

bincount(x[, weights, minlength, length])

計算整數陣列中每個值出現的次數。

bitwise_and

逐元素計算位元 AND 運算。

bitwise_count(x, /)

計算 x 每個元素的絕對值的二進位表示形式中 1 位元的數量。

bitwise_invert(x, /)

jax.numpy.invert() 的別名。

bitwise_left_shift(x, y, /)

jax.numpy.left_shift() 的別名。

bitwise_not(x, /)

jax.numpy.invert() 的別名。

bitwise_or

逐元素計算位元 OR 運算。

bitwise_right_shift(x1, x2, /)

jax.numpy.right_shift() 的別名。

bitwise_xor

逐元素計算位元 XOR 運算。

blackman(M)

傳回大小為 M 的 Blackman 視窗。

block(arrays)

從區塊清單建立陣列。

bool_

bool

broadcast_arrays(*args)

將陣列廣播到通用形狀。

broadcast_shapes(*shapes)

將輸入形狀廣播到通用輸出形狀。

broadcast_to(array, shape)

將陣列廣播到指定的形狀。

c_

沿最後一個軸串連切片、純量和類陣列物件。

can_cast(from_, to[, casting])

如果可以根據轉換規則在資料類型之間進行轉換,則傳回 True。

cbrt(x, /)

計算輸入陣列的逐元素立方根。

cdouble

complex128 的別名

ceil(x, /)

將輸入向上捨入到最接近的整數。

character()

所有字元串純量類型的抽象基底類別。

choose(a, choices[, out, mode])

透過堆疊選擇陣列的切片來建構陣列。

clip([arr, min, max, a, a_min, a_max])

將陣列值裁剪到指定的範圍。

column_stack(tup)

以資料行方式堆疊陣列。

complex_

complex128 的別名

complex128(x)

complex128 型別的 JAX 純量建構函式。

complex64(x)

complex64 型別的 JAX 純量建構函式。

complexfloating()

所有由浮點數組成的複數純量類型的抽象基底類別。

ComplexWarning

將複數 dtype 轉換為實數 dtype 時引發的警告。

compress(condition, a[, axis, size, ...])

使用布林條件沿給定的軸壓縮陣列。

concat(arrays, /, *[, axis])

沿現有軸聯接陣列。

concatenate(arrays[, axis, dtype])

沿現有軸聯接陣列。

conj(x, /)

jax.numpy.conjugate() 的別名

conjugate(x, /)

傳回輸入的逐元素共軛複數。

convolve(a, v[, mode, precision, ...])

兩個一維陣列的卷積。

copy(a[, order])

傳回陣列的副本。

copysign(x1, x2, /)

x2 中每個元素的正負號複製到 x1 中對應的元素。

corrcoef(x[, y, rowvar])

計算皮爾森相關係數。

correlate(a, v[, mode, precision, ...])

兩個一維陣列的相關性。

cos(x, /)

計算輸入中每個元素的三角餘弦值。

cosh(x, /)

計算輸入的逐元素雙曲餘弦值。

count_nonzero(a[, axis, keepdims])

傳回沿著指定軸的非零元素數量。

cov(m[, y, rowvar, bias, ddof, fweights, ...])

估計加權樣本共變異數。

cross(a, b[, axisa, axisb, axisc, axis])

計算兩個陣列的(批次)外積。

csingle

complex64 的別名

cumprod(a[, axis, dtype, out])

沿著軸的元素累積乘積。

cumsum(a[, axis, dtype, out])

沿著軸的元素累積總和。

cumulative_prod(x, /, *[, axis, dtype, ...])

沿著陣列軸的累積乘積。

cumulative_sum(x, /, *[, axis, dtype, ...])

沿著陣列軸的累積總和。

deg2rad(x, /)

將角度從度轉換為弧度。

degrees(x, /)

jax.numpy.rad2deg() 的別名

delete(arr, obj[, axis, assume_unique_indices])

從陣列中刪除條目。

diag(v[, k])

傳回指定的對角線或建構對角陣列。

diag_indices(n[, ndim])

傳回用於存取多維陣列主對角線的索引。

diag_indices_from(arr)

傳回用於存取給定陣列主對角線的索引。

diagflat(v[, k])

傳回一個二維陣列,其中扁平化的輸入陣列佈局在對角線上。

diagonal(a[, offset, axis1, axis2])

傳回陣列的指定對角線。

diff(a[, n, axis, prepend, append])

計算沿著指定軸的陣列元素之間的 n 階差分。

digitize(x, bins[, right, method])

將陣列轉換為 bin 索引。

divide(x1, x2, /)

jax.numpy.true_divide() 的別名。

divmod(x1, x2, /)

逐元素計算 x1 除以 x2 的整數商和餘數

dot(a, b, *[, precision, preferred_element_type])

計算兩個陣列的點積。

double

float64 的別名

dsplit(ary, indices_or_sections)

沿深度方向將陣列分割成子陣列。

dstack(tup[, dtype])

沿深度方向堆疊陣列。

dtype(dtype[, align, copy])

建立資料類型物件。

ediff1d(ary[, to_end, to_begin])

計算扁平化陣列元素的差分。

einsum(subscripts, /, *operands[, out, ...])

愛因斯坦求和

einsum_path(subscripts, /, *operands[, optimize])

評估最佳收縮路徑,而不評估 einsum。

empty(shape[, dtype, device])

建立一個空陣列。

empty_like(prototype[, dtype, shape, device])

建立一個與陣列具有相同形狀和 dtype 的空陣列。

equal(x, y, /)

傳回 x == y 的逐元素真值。

exp(x, /)

計算輸入的逐元素指數。

exp2(x, /)

計算輸入的逐元素 2 為底指數。

expand_dims(a, axis)

將長度為 1 的維度插入陣列

expm1(x, /)

計算輸入中每個元素的 exp(x)-1

extract(condition, arr, *[, size, fill_value])

傳回陣列中滿足條件的元素。

eye(N[, M, k, dtype, device])

建立方形或矩形單位矩陣

fabs(x, /)

計算實數值輸入的逐元素絕對值。

fill_diagonal(a, val[, wrap, inplace])

傳回對角線被覆寫的陣列副本。

finfo(dtype)

浮點數類型的機器限制。

fix(x[, out])

將輸入四捨五入到最接近零的整數。

flatnonzero(a, *[, size, fill_value])

傳回扁平化陣列中非零元素的索引

flexible()

所有沒有預定義長度的純量類型的抽象基底類別。

flip(m[, axis])

沿著給定軸反轉陣列元素的順序。

fliplr(m)

沿著軸 1 反轉陣列元素的順序。

flipud(m)

沿著軸 0 反轉陣列元素的順序。

float_

float64 的別名

float_power(x, y, /)

計算 y 的逐元素底數 x 指數。

float16(x)

一個 float16 類型的 JAX 純量建構子。

float32(x)

一個 float32 類型的 JAX 純量建構子。

float64(x)

一個 float64 類型的 JAX 純量建構子。

floating()

所有浮點純量類型的抽象基底類別。

floor(x, /)

將輸入向下四捨五入到最接近的整數。

floor_divide(x1, x2, /)

逐元素計算 x1 除以 x2 的 floor 除法

fmax(x1, x2)

傳回輸入陣列的逐元素最大值。

fmin(x1, x2)

傳回輸入陣列的逐元素最小值。

fmod(x1, x2, /)

計算逐元素浮點數模數運算。

frexp(x, /)

將浮點數值分割成尾數和 2 的指數。

frombuffer(buffer[, dtype, count, offset])

將 buffer 轉換為一維 JAX 陣列。

fromfile(*args, **kwargs)

jnp.fromfile 的未實作 JAX 包裝器。

fromfunction(function, shape, *[, dtype])

從應用於索引的函數建立陣列。

fromiter(*args, **kwargs)

jnp.fromiter 的未實作 JAX 包裝器。

frompyfunc(func, /, nin, nout, *[, identity])

從任意 JAX 相容的純量函數建立 JAX ufunc。

fromstring(string[, dtype, count])

將文字字串轉換為一維 JAX 陣列。

from_dlpack(x, /, *[, device, copy])

透過 DLPack 建構 JAX 陣列。

full(shape, fill_value[, dtype, device])

建立一個充滿指定值的陣列。

full_like(a, fill_value[, dtype, shape, device])

建立一個與陣列具有相同形狀和 dtype,且充滿指定值的陣列。

gcd(x1, x2)

計算兩個陣列的最大公約數。

generic()

numpy 純量類型的基底類別。

geomspace(start, stop[, num, endpoint, ...])

產生幾何間隔的值。

get_printoptions()

傳回目前的列印選項。

gradient(f, *varargs[, axis, edge_order])

計算取樣函數的數值梯度。

greater(x, y, /)

傳回 x > y 的逐元素真值。

greater_equal(x, y, /)

傳回 x >= y 的逐元素真值。

hamming(M)

傳回大小為 M 的 Hamming 視窗。

hanning(M)

傳回大小為 M 的 Hanning 視窗。

heaviside(x1, x2, /)

計算 heaviside 步階函數。

histogram(a[, bins, range, weights, density])

計算一維直方圖。

histogram_bin_edges(a[, bins, range, weights])

計算直方圖的 bin 邊緣。

histogram2d(x, y[, bins, range, weights, ...])

計算二維直方圖。

histogramdd(sample[, bins, range, weights, ...])

計算 N 維直方圖。

hsplit(ary, indices_or_sections)

水平地將陣列分割成子陣列。

hstack(tup[, dtype])

水平堆疊陣列。

hypot(x1, x2, /)

傳回給定直角三角形邊長的逐元素斜邊長度。

i0(x)

計算第一類零階修正貝索函數。

identity(n[, dtype])

建立方形單位矩陣

iinfo(int_type)

imag(val, /)

傳回複數引數的逐元素虛部。

index_exp

一種更方便的方式來建立陣列的索引元組。

indices(dimensions[, dtype, sparse])

產生網格索引陣列。

inexact()

所有數值純量類型的抽象基底類別,這些類型在其範圍內的值具有(可能)不精確的表示形式,例如浮點數。

inner(a, b, *[, precision, ...])

計算兩個陣列的內積。

insert(arr, obj, values[, axis])

在指定索引處將條目插入陣列。

int_

int64 的別名

int16(x)

一個 int16 類型的 JAX 純量建構子。

int32(x)

一個 int32 類型的 JAX 純量建構子。

int64(x)

一個 int64 類型的 JAX 純量建構子。

int8(x)

一個 int8 類型的 JAX 純量建構子。

integer()

所有整數純量類型的抽象基底類別。

interp(x, xp, fp[, left, right, period])

一維線性內插。

intersect1d(ar1, ar2[, assume_unique, ...])

計算兩個一維陣列的集合交集。

invert(x, /)

計算輸入的位元反轉。

isclose(a, b[, rtol, atol, equal_nan])

檢查兩個陣列的元素是否在容差範圍內近似相等。

iscomplex(x)

傳回布林陣列,顯示輸入為複數的位置。

iscomplexobj(x)

檢查輸入是否為複數或包含複數元素的陣列。

isdtype(dtype, kind)

傳回布林值,指示提供的 dtype 是否為指定的種類。

isfinite(x, /)

傳回布林陣列,指示輸入的每個元素是否為有限值。

isin(element, test_elements[, ...])

判斷 element 中的元素是否出現在 test_elements 中。

isinf(x, /)

傳回布林陣列,指示輸入的每個元素是否為無限值。

isnan(x, /)

傳回布林陣列,指示輸入的每個元素是否為 NaN

isneginf(x, /[, out])

傳回布林陣列,指示輸入的每個元素是否為負無限值。

isposinf(x, /[, out])

傳回布林陣列,指示輸入的每個元素是否為正無限值。

isreal(x)

傳回布林陣列,顯示輸入為實數的位置。

isrealobj(x)

檢查輸入是否不是複數或不包含複數元素的陣列。

isscalar(element)

如果輸入是純量,則傳回 True。

issubdtype(arg1, arg2)

如果 arg1 在類型階層中等於或低於 arg2,則傳回 True。

iterable(y)

檢查物件是否可以迭代。

ix_(*args)

從 N 個一維序列傳回多維網格(開放網格)。

kaiser(M, beta)

傳回大小為 M 的 Kaiser 視窗。

kron(a, b)

計算兩個輸入陣列的 Kronecker 乘積。

lcm(x1, x2)

計算兩個陣列的最小公倍數。

ldexp(x1, x2, /)

計算 x1 * 2 ** x2

left_shift(x, y, /)

x 的位元向左移動 y 指定的量,逐元素執行。

less(x, y, /)

傳回 x < y 的逐元素真值。

less_equal(x, y, /)

傳回 x <= y 的逐元素真值。

lexsort(keys[, axis])

以字典順序排序鍵的序列。

linspace(start, stop[, num, endpoint, ...])

傳回區間內均勻間隔的數字。

load(file, *args, **kwargs)

從 npy 檔案載入 JAX 陣列。

log(x, /)

計算輸入的逐元素自然對數。

log10(x, /)

逐元素計算 x 的以 10 為底的對數

log1p(x, /)

逐元素計算 1 加上輸入值的對數,log(x+1)

log2(x, /)

逐元素計算 x 的以 2 為底的對數。

logaddexp

計算 log(exp(x1) + exp(x2)),避免溢位。

logaddexp2

以 2 為底的輸入指數和的對數,避免溢位。

logical_and

逐元素計算邏輯 AND 運算。

logical_not(x, /)

逐元素計算 NOT bool(x)。

logical_or

逐元素計算邏輯 OR 運算。

logical_xor

逐元素計算邏輯 XOR 運算。

logspace(start, stop[, num, endpoint, base, ...])

產生對數間隔值。

mask_indices(n, mask_func[, k, size])

傳回 (n, n) 陣列遮罩的索引。

matmul(a, b, *[, precision, ...])

執行矩陣乘法。

matrix_transpose(x, /)

轉置陣列的最後兩個維度。

matvec(x1, x2, /)

批次矩陣向量乘積。

max(a[, axis, out, keepdims, initial, where])

傳回沿著給定軸的陣列元素最大值。

maximum(x, y, /)

傳回輸入陣列的逐元素最大值。

mean(a[, axis, dtype, out, keepdims, where])

傳回沿著給定軸的陣列元素平均值。

median(a[, axis, out, overwrite_input, keepdims])

傳回沿著給定軸的陣列元素中位數。

meshgrid(*xi[, copy, sparse, indexing])

從 N 個一維向量建構 N 維網格陣列。

mgrid

傳回密集的多元 "meshgrid"。

min(a[, axis, out, keepdims, initial, where])

傳回沿著給定軸的陣列元素最小值。

minimum(x, y, /)

傳回輸入陣列的逐元素最小值。

mod(x1, x2, /)

jax.numpy.remainder() 的別名

modf(x, /[, out])

傳回輸入陣列的逐元素小數和整數部分。

moveaxis(a, source, destination)

將陣列軸移動到新位置

multiply

逐元素相乘兩個陣列。

nan_to_num(x[, copy, nan, posinf, neginf])

取代陣列中的 NaN 和無限值條目。

nanargmax(a[, axis, out, keepdims])

傳回陣列最大值的索引,忽略 NaN。

nanargmin(a[, axis, out, keepdims])

傳回陣列最小值的索引,忽略 NaN。

nancumprod(a[, axis, dtype, out])

沿著軸的元素累積乘積,忽略 NaN 值。

nancumsum(a[, axis, dtype, out])

沿著軸的元素累積總和,忽略 NaN 值。

nanmax(a[, axis, out, keepdims, initial, where])

傳回沿著給定軸的陣列元素最大值,忽略 NaN。

nanmean(a[, axis, dtype, out, keepdims, where])

傳回沿著給定軸的陣列元素平均值,忽略 NaN。

nanmedian(a[, axis, out, overwrite_input, ...])

傳回沿著給定軸的陣列元素中位數,忽略 NaN。

nanmin(a[, axis, out, keepdims, initial, where])

傳回沿著給定軸的陣列元素最小值,忽略 NaN。

nanpercentile(a, q[, axis, out, ...])

計算沿著指定軸的資料百分位數,忽略 NaN 值。

nanprod(a[, axis, dtype, out, keepdims, ...])

傳回沿著給定軸的陣列元素乘積,忽略 NaN。

nanquantile(a, q[, axis, out, ...])

計算沿著指定軸的資料分位數,忽略 NaN 值。

nanstd(a[, axis, dtype, out, ddof, ...])

計算沿著給定軸的標準差,忽略 NaN。

nansum(a[, axis, dtype, out, keepdims, ...])

傳回沿著給定軸的陣列元素總和,忽略 NaN。

nanvar(a[, axis, dtype, out, ddof, ...])

計算沿著給定軸的陣列元素變異數,忽略 NaN。

ndarray

Array 的別名

ndim(a)

傳回陣列的維度數量。

negative

傳回輸入的逐元素負值。

nextafter(x, y, /)

傳回逐元素位於 x 之後朝向 y 的下一個浮點數值。

nonzero(a, *[, size, fill_value])

傳回陣列中非零元素的索引。

not_equal(x, y, /)

傳回 x != y 的逐元素真值。

number()

所有數值純量類型的抽象基底類別。

object_

任何 Python 物件。

ogrid

傳回開放的多元 "meshgrid"。

ones(shape[, dtype, device])

建立一個充滿 1 的陣列。

ones_like(a[, dtype, shape, device])

建立一個與陣列具有相同形狀和 dtype 的全 1 陣列。

outer(a, b[, out])

計算兩個陣列的外積。

packbits(a[, axis, bitorder])

將位元陣列封裝成 uint8 陣列。

pad(array, pad_width[, mode])

在陣列中加入填充。

partition(a, kth[, axis])

傳回陣列的部分排序副本。

percentile(a, q[, axis, out, ...])

計算沿著指定軸的資料百分位數。

permute_dims(a, /, axes)

置換陣列的軸/維度。

piecewise(x, condlist, funclist, *args, **kw)

評估在整個域中分段定義的函數。

place(arr, mask, vals, *[, inplace])

根據遮罩更新陣列元素。

poly(seq_of_zeros)

傳回給定根序列之多項式的係數。

polyadd(a1, a2)

傳回兩個多項式的總和。

polyder(p[, m])

傳回指定階多項式導數的係數。

polydiv(u, v, *[, trim_leading_zeros])

傳回多項式除法的商和餘數。

polyfit(x, y, deg[, rcond, full, w, cov])

資料的最小平方多項式擬合。

polyint(p[, m, k])

傳回指定階多項式積分的係數。

polymul(a1, a2, *[, trim_leading_zeros])

傳回兩個多項式的乘積。

polysub(a1, a2)

傳回兩個多項式的差。

polyval(p, x, *[, unroll])

在特定值評估多項式。

positive(x, /)

傳回輸入的逐元素正值。

pow(x1, x2, /)

jax.numpy.power() 的別名

power(x1, x2, /)

計算 x2 的逐元素底數 x1 指數。

printoptions(*args, **kwargs)

用於設定列印選項的上下文管理器。

prod(a[, axis, dtype, out, keepdims, ...])

傳回給定軸上陣列元素的乘積。

promote_types(a, b)

傳回二元運算應將其引數轉換成的類型。

ptp(a[, axis, out, keepdims])

傳回沿著給定軸的峰對峰範圍。

put(a, ind, v[, mode, inplace])

將元素放入陣列中的給定索引位置。

put_along_axis(arr, indices, values, axis[, ...])

透過比對一維索引和資料切片,將值放入目標陣列。

quantile(a, q[, axis, out, overwrite_input, ...])

計算沿著指定軸的資料分位數。

r_

沿著第一個軸串聯切片、純量和類陣列物件。

rad2deg(x, /)

將角度從弧度轉換為度。

radians(x, /)

jax.numpy.deg2rad() 的別名

ravel(a[, order])

將陣列展平成一維形狀。

ravel_multi_index(multi_index, dims[, mode, ...])

將多維索引轉換為扁平索引。

real(val, /)

傳回複數引數的逐元素實部。

reciprocal(x, /)

計算輸入的逐元素倒數。

remainder(x1, x2, /)

傳回除法的逐元素餘數。

repeat(a, repeats[, axis, total_repeat_length])

從重複的元素建構陣列。

reshape(a[, shape, order, newshape, copy])

傳回陣列的重塑副本。

resize(a, new_shape)

傳回具有指定形狀的新陣列。

result_type(*args)

傳回將 JAX 提升規則應用於輸入的結果。

right_shift(x1, x2, /)

x1 的位元向右移動 x2 指定的量。

rint(x, /)

將 x 的元素四捨五入到最接近的整數

roll(a, shift[, axis])

沿著指定的軸滾動陣列的元素。

rollaxis(a, axis[, start])

將指定的軸滾動到給定的位置。

roots(p, *[, strip_zeros])

傳回給定係數 p 的多項式根。

rot90(m[, k, axes])

在軸指定的平面中將陣列逆時針旋轉 90 度。

round(a[, decimals, out])

將輸入均勻四捨五入到給定的十進位位數。

s_

一種更方便的方式來建立陣列的索引元組。

save(file, arr[, allow_pickle, fix_imports])

以 NumPy .npy 格式將陣列儲存到二進位檔案。

savez(file, *args[, allow_pickle])

以未壓縮的 .npz 格式將多個陣列儲存到單一檔案中。

searchsorted(a, v[, side, sorter, method])

在排序陣列中執行二元搜尋。

select(condlist, choicelist[, default])

根據一系列條件選取值。

set_printoptions([precision, threshold, ...])

設定列印選項。

setdiff1d(ar1, ar2[, assume_unique, size, ...])

計算兩個一維陣列的集合差集。

setxor1d(ar1, ar2[, assume_unique, size, ...])

計算兩個陣列中元素的集合互斥或。

shape(a)

傳回陣列的形狀。

sign(x, /)

傳回輸入的逐元素符號指示。

signbit(x, /)

傳回陣列元素的符號位元。

signedinteger()

所有帶符號整數純量類型的抽象基底類別。

sin(x, /)

計算輸入的每個元素之三角正弦值。

sinc(x, /)

計算正規化的 sinc 函數。

single

float32 的別名

sinh(x, /)

計算輸入的逐元素雙曲正弦值。

size(a[, axis])

傳回沿著給定軸的元素數量。

sort(a[, axis, kind, order, stable, descending])

傳回陣列的排序副本。

sort_complex(a)

傳回複數陣列的排序副本。

spacing(x, /)

傳回 x 與下一個相鄰數字之間的間隔。

split(ary, indices_or_sections[, axis])

將陣列分割成子陣列。

sqrt(x, /)

計算輸入陣列的逐元素非負平方根。

square(x, /)

計算輸入陣列的逐元素平方。

squeeze(a[, axis])

從陣列中移除一個或多個長度為 1 的軸

stack(arrays[, axis, out, dtype])

沿著新軸連接陣列。

std(a[, axis, dtype, out, ddof, keepdims, ...])

計算沿著給定軸的標準差。

subtract

逐元素相減兩個陣列。

sum(a[, axis, dtype, out, keepdims, ...])

沿著給定軸向計算陣列元素總和。

swapaxes(a, axis1, axis2)

交換陣列的兩個軸。

take(a, indices[, axis, out, mode, ...])

從陣列中取出元素。

take_along_axis(arr, indices, axis[, mode, ...])

從陣列中取出元素。

tan(x, /)

計算輸入中每個元素的三角正切值。

tanh(x, /)

逐元素計算輸入的雙曲正切值。

tensordot(a, b[, axes, precision, ...])

計算兩個 N 維陣列的張量點積。

tile(A, reps)

通過沿指定維度重複 A 來構建陣列。

trace(a[, offset, axis1, axis2, dtype, out])

計算沿給定軸的輸入對角線總和。

trapezoid(y[, x, dx, axis])

使用複合梯形法則沿給定軸積分。

transpose(a[, axes])

返回 N 維陣列的轉置版本。

tri(N[, M, k, dtype])

返回一個在對角線及其下方為 1,其他地方為 0 的陣列。

tril(m[, k])

返回陣列的下三角部分。

tril_indices(n[, k, m])

返回大小為 (n, m) 的陣列的下三角索引。

tril_indices_from(arr[, k])

返回給定陣列的下三角索引。

trim_zeros(filt[, trim])

修剪輸入陣列的前導和/或尾隨零。

triu(m[, k])

返回陣列的上三角部分。

triu_indices(n[, k, m])

返回大小為 (n, m) 的陣列的上三角索引。

triu_indices_from(arr[, k])

返回給定陣列的上三角索引。

true_divide(x1, x2, /)

逐元素計算 x1 除以 x2 的結果。

trunc(x)

將輸入四捨五入到最接近零的整數。

ufunc(func, /, nin, nout, *[, name, nargs, ...])

通用函數,對陣列執行逐元素操作。

uint

uint64 的別名

uint16(x)

uint16 類型的 JAX 純量建構子。

uint32(x)

uint32 類型的 JAX 純量建構子。

uint64(x)

uint64 類型的 JAX 純量建構子。

uint8(x)

uint8 類型的 JAX 純量建構子。

union1d(ar1, ar2, *[, size, fill_value])

計算兩個 1 維陣列的集合聯集。

unique(ar[, return_index, return_inverse, ...])

從陣列中返回唯一值。

unique_all(x, /, *[, size, fill_value])

從 x 返回唯一值,以及索引、反向索引和計數。

unique_counts(x, /, *[, size, fill_value])

從 x 返回唯一值,以及計數。

unique_inverse(x, /, *[, size, fill_value])

從 x 返回唯一值,以及索引、反向索引和計數。

unique_values(x, /, *[, size, fill_value])

從 x 返回唯一值,以及索引、反向索引和計數。

unpackbits(a[, axis, count, bitorder])

解包 uint8 陣列中的位元。

unravel_index(indices, shape)

將扁平索引轉換為多維索引。

unstack(x, /, *[, axis])

沿軸向解堆疊陣列。

unsignedinteger()

所有無符號整數純量類型的抽象基底類別。

unwrap(p[, discont, axis, period])

解開週期性訊號。

vander(x[, N, increasing])

產生 Vandermonde 矩陣。

var(a[, axis, dtype, out, ddof, keepdims, ...])

計算沿給定軸的變異數。

vdot(a, b, *[, precision, ...])

執行兩個 1 維向量的共軛乘法。

vecdot(x1, x2, /, *[, axis, precision, ...])

執行兩個批次向量的共軛乘法。

vecmat(x1, x2, /)

批次共軛向量-矩陣乘積。

vectorize(pyfunc, *[, excluded, signature])

定義一個具有廣播功能的向量化函數。

vsplit(ary, indices_or_sections)

將陣列垂直分割為子陣列。

vstack(tup[, dtype])

垂直堆疊陣列。

where(condition[, x, y, size, fill_value])

根據條件從兩個陣列中選擇元素。

zeros(shape[, dtype, device])

建立一個充滿零的陣列。

zeros_like(a[, dtype, shape, device])

建立一個與陣列形狀和 dtype 相同的零陣列。

jax.numpy.fft#

fft(a[, n, axis, norm])

沿給定軸計算一維離散傅立葉變換。

fft2(a[, s, axes, norm])

沿給定軸計算二維離散傅立葉變換。

fftfreq(n[, d, dtype, device])

返回離散傅立葉變換的採樣頻率。

fftn(a[, s, axes, norm])

沿給定軸計算多維離散傅立葉變換。

fftshift(x[, axes])

將零頻率 fft 分量移動到頻譜中心。

hfft(a[, n, axis, norm])

計算頻譜具有 Hermitian 對稱性的陣列的 1-D FFT。

ifft(a[, n, axis, norm])

計算一維反離散傅立葉變換。

ifft2(a[, s, axes, norm])

計算二維反離散傅立葉變換。

ifftn(a[, s, axes, norm])

計算多維反離散傅立葉變換。

ifftshift(x[, axes])

jax.numpy.fft.fftshift() 的反函數。

ihfft(a[, n, axis, norm])

計算頻譜具有 Hermitian 對稱性的陣列的 1-D 反 FFT。

irfft(a[, n, axis, norm])

計算實值一維反離散傅立葉變換。

irfft2(a[, s, axes, norm])

計算實值二維反離散傅立葉變換。

irfftn(a[, s, axes, norm])

計算實值多維反離散傅立葉變換。

rfft(a[, n, axis, norm])

計算實值陣列的一維離散傅立葉變換。

rfft2(a[, s, axes, norm])

計算實值陣列的二維離散傅立葉變換。

rfftfreq(n[, d, dtype, device])

返回離散傅立葉變換的採樣頻率。

rfftn(a[, s, axes, norm])

計算實值陣列的多維離散傅立葉變換。

jax.numpy.linalg#

cholesky(a, *[, upper])

計算矩陣的 Cholesky 分解。

cond(x[, p])

計算矩陣的條件數。

cross(x1, x2, /, *[, axis])

計算兩個 3D 向量的叉積。

det(a)

計算陣列的行列式。

diagonal(x, /, *[, offset])

提取矩陣或矩陣堆疊的對角線。

eig(a)

計算方陣的特徵值和特徵向量。

eigh(a[, UPLO, symmetrize_input])

計算 Hermitian 矩陣的特徵值和特徵向量。

eigvals(a)

計算一般矩陣的特徵值。

eigvalsh(a[, UPLO])

計算 Hermitian 矩陣的特徵值。

inv(a)

返回方陣的反矩陣。

lstsq(a, b[, rcond, numpy_resid])

返回線性方程式的最小平方解。

matmul(x1, x2, /, *[, precision, ...])

執行矩陣乘法。

matrix_norm(x, /, *[, keepdims, ord])

計算矩陣或矩陣堆疊的範數。

matrix_power(a, n)

將方陣提升為整數冪次。

matrix_rank(M[, rtol, tol])

計算矩陣的秩。

matrix_transpose(x, /)

轉置矩陣或矩陣堆疊。

multi_dot(arrays, *[, precision])

有效率地計算陣列序列之間的矩陣乘積。

norm(x[, ord, axis, keepdims])

計算矩陣或向量的範數。

outer(x1, x2, /)

計算兩個 1 維陣列的外積。

pinv(a[, rtol, hermitian, rcond])

計算矩陣的 (Moore-Penrose) 偽反矩陣。

qr(a[, mode])

計算陣列的 QR 分解。

slogdet(a, *[, method])

計算陣列行列式的符號和 (自然) 對數。

solve(a, b)

求解線性方程式系統。

svd(a[, full_matrices, compute_uv, ...])

計算奇異值分解。

svdvals(x, /)

計算矩陣的奇異值。

tensordot(x1, x2, /, *[, axes, precision, ...])

計算兩個 N 維陣列的張量點積。

tensorinv(a[, ind])

計算陣列的張量反矩陣。

tensorsolve(a, b[, axes])

求解張量方程式 a x = b 中的 x。

trace(x, /, *[, offset, dtype])

計算矩陣的跡。

vector_norm(x, /, *[, axis, keepdims, ord])

計算向量或批次向量的向量範數。

vecdot(x1, x2, /, *[, axis, precision, ...])

計算兩個陣列的(批次)向量共軛點積。

JAX 陣列#

JAX Array (及其別名 jax.numpy.ndarray) 是 JAX 中的核心陣列物件:您可以將其視為 JAX 中與 numpy.ndarray 等效的物件。與 numpy.ndarray 類似,大多數使用者不需要手動實例化 Array 物件,而是通過 jax.numpy 函數(如 array()arange()linspace() 和上面列出的其他函數)來建立它們。

複製與序列化#

JAX Array 物件旨在與適用的 Python 標準函式庫工具無縫協作。

使用內建的 copy 模組時,當 copy.copy()copy.deepcopy() 遇到 Array 時,它等效於調用 copy() 方法,這將在與原始陣列相同的裝置上建立緩衝區的副本。這在追蹤/JIT 編譯程式碼中可以正確運作,儘管複製操作可能會在此上下文中被編譯器省略。

當內建的 pickle 模組遇到 Array 時,它將通過緊湊的位元表示形式進行序列化,方式與 pickled numpy.ndarray 物件類似。當 unpickled 時,結果將是在預設裝置上的新 Array 物件。這是因為一般而言,pickling 和 unpickling 可能在不同的執行階段環境中進行,並且沒有通用的方法將一個執行階段的裝置 ID 映射到另一個執行階段的裝置 ID。如果在追蹤/JIT 編譯程式碼中使用 pickle,將會導致 ConcretizationTypeError

Python 陣列 API 標準#

注意

在 JAX v0.4.32 之前的版本,您必須 import jax.experimental.array_api 才能為 JAX 陣列啟用陣列 API。在 JAX v0.4.32 之後,不再需要導入此模組,並且會引發棄用警告。在 JAX v0.5.0 之後,此導入將引發錯誤。

從 JAX v0.4.32 開始,jax.Arrayjax.numpyPython 陣列 API 標準 相容。您可以通過 jax.Array.__array_namespace__() 訪問陣列 API 命名空間

>>> def f(x):
...   nx = x.__array_namespace__()
...   return nx.sin(x) ** 2 + nx.cos(x) ** 2

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)

JAX 在一些地方偏離了標準,主要是因為 JAX 陣列是不可變的,不支援原地更新。其中一些不相容性正在通過 array-api-compat 模組來解決。

更多資訊,請參考 Python Array API Standard 文件。