jax.Array#

class jax.Array#

JAX 的陣列基底類別

jax.Array 是 JAX 陣列和追蹤器的實例檢查和類型註釋的公共介面。其主要應用在於實例檢查和類型註釋;例如

x = jnp.arange(5)
isinstance(x, jax.Array)  # returns True both inside and outside traced functions.

def f(x: Array) -> Array:  # type annotations are valid for traced and non-traced types.
  return x

jax.Array 不應直接用於建立陣列;相反地,您應該使用 jax.numpy 中提供的陣列建立常式,例如 jax.numpy.array()jax.numpy.zeros()jax.numpy.ones()jax.numpy.full()jax.numpy.arange() 等。

__init__()#

方法

__init__()

addressable_data(index)

傳回特定索引處可定址資料的陣列。

all([axis, out, keepdims, where])

測試給定軸上的所有陣列元素是否評估為 True。

any([axis, out, keepdims, where])

測試給定軸上的任何陣列元素是否評估為 True。

argmax([axis, out, keepdims])

傳回最大值的索引。

argmin([axis, out, keepdims])

傳回最小值的索引。

argpartition(kth[, axis])

傳回部分排序陣列的索引。

argsort([axis, kind, order, stable, descending])

傳回排序陣列的索引。

astype(dtype[, copy, device])

複製陣列並轉換為指定的 dtype。

choose(choices[, out, mode])

建構一個從多個陣列的元素中選擇的陣列。

clip([min, max])

傳回一個值限制在指定範圍內的陣列。

compress(condition[, axis, out, size, ...])

傳回沿給定軸的這個陣列的選定切片。

conj()

傳回陣列的複共軛。

conjugate()

傳回陣列的複共軛。

copy()

傳回陣列的副本。

copy_to_host_async()

非同步地將 Array 複製到主機。

cumprod([axis, dtype, out])

傳回陣列的累積乘積。

cumsum([axis, dtype, out])

傳回陣列的累積總和。

diagonal([offset, axis1, axis2])

從陣列傳回指定的對角線。

dot(b, *[, precision, preferred_element_type])

計算兩個陣列的點積。

flatten([order])

將陣列展平為 1 維形狀。

item(*args)

將陣列的元素複製到標準 Python 純量並傳回。

max([axis, out, keepdims, initial, where])

傳回給定軸上陣列元素的最大值。

mean([axis, dtype, out, keepdims, where])

傳回給定軸上陣列元素的平均值。

min([axis, out, keepdims, initial, where])

傳回給定軸上陣列元素的最小值。

nonzero(*[, fill_value, size])

傳回陣列的非零元素的索引。

prod([axis, dtype, out, keepdims, initial, ...])

傳回給定軸上陣列元素的乘積。

ptp([axis, out, keepdims])

傳回給定軸上的峰對峰範圍。

ravel([order])

將陣列展平為 1 維形狀。

repeat(repeats[, axis, total_repeat_length])

從重複元素建構陣列。

reshape(*args[, order])

傳回包含相同資料但具有新形狀的陣列。

round([decimals, out])

將陣列元素四捨五入到給定的十進位。

searchsorted(v[, side, sorter, method])

在排序的陣列中執行二元搜尋。

sort([axis, kind, order, stable, descending])

傳回陣列的排序副本。

squeeze([axis])

從陣列中移除一或多個長度為 1 的軸。

std([axis, dtype, out, ddof, keepdims, ...])

計算沿給定軸的標準差。

sum([axis, dtype, out, keepdims, initial, ...])

給定軸上陣列元素的總和。

swapaxes(axis1, axis2)

交換陣列的兩個軸。

take(indices[, axis, out, mode, ...])

從陣列中取得元素。

to_device(device, *[, stream])

傳回指定裝置上陣列的副本

trace([offset, axis1, axis2, dtype, out])

傳回沿對角線的總和。

transpose(*args)

傳回軸已轉置的陣列副本。

var([axis, dtype, out, ddof, keepdims, ...])

計算沿給定軸的變異數。

view([dtype, type])

傳回陣列的位元複製,視為新的 dtype。

屬性

T

計算全軸陣列轉置。

addressable_shards

可定址分片的清單。

at

用於索引更新功能的輔助屬性。

committed

陣列是否已提交。

device

與 Array API 相容的裝置屬性。

dtype

陣列的資料類型 (numpy.dtype)。

flat

請改用 flatten()

global_shards

全域分片的清單。

imag

傳回陣列的虛部。

is_fully_addressable

此陣列是否完全可定址?

is_fully_replicated

此陣列是否完全複製?

itemsize

一個陣列元素以位元組為單位的長度。

mT

計算(批次)矩陣轉置。

nbytes

陣列元素消耗的總位元組數。

ndim

陣列中的維度數。

real

傳回陣列的實部。

shape

陣列的形狀。

sharding

陣列的分片。

size

陣列中的元素總數。