公開 API:jax 套件#

子套件#

組態#

config

check_tracer_leaks

用於 jax_check_tracer_leaks 組態選項的上下文管理器。

checking_leaks

用於 jax_check_tracer_leaks 組態選項的上下文管理器。

debug_nans

用於 jax_debug_nans 組態選項的上下文管理器。

debug_infs

用於 jax_debug_infs 組態選項的上下文管理器。

default_device

用於 jax_default_device 組態選項的上下文管理器。

default_matmul_precision

用於 jax_default_matmul_precision 組態選項的上下文管理器。

default_prng_impl

用於 jax_default_prng_impl 組態選項的上下文管理器。

enable_checks

用於 jax_enable_checks 組態選項的上下文管理器。

enable_custom_prng

用於 jax_enable_custom_prng 組態選項的上下文管理器 (暫時性)。

enable_custom_vjp_by_custom_transpose

用於 jax_enable_custom_vjp_by_custom_transpose 組態選項的上下文管理器 (暫時性)。

log_compiles

用於 jax_log_compiles 組態選項的上下文管理器。

numpy_rank_promotion

用於 jax_numpy_rank_promotion 組態選項的上下文管理器。

transfer_guard(new_val)

用於控制所有傳輸的傳輸保護等級的上下文管理器。

即時編譯 (jit)#

jit(fun[, in_shardings, out_shardings, ...])

設定 fun 以使用 XLA 進行即時編譯。

disable_jit([disable])

在其動態上下文下禁用 jit() 行為的上下文管理器。

ensure_compile_time_eval()

確保在追蹤/編譯時評估(或錯誤)的上下文管理器。

make_jaxpr([axis_env, return_shape, ...])

建立一個函數,該函數產生給定範例參數的 jaxpr。

eval_shape(fun, *args, **kwargs)

計算 fun 的形狀/dtype,而無需任何 FLOP。

ShapeDtypeStruct(shape, dtype, *[, ...])

用於陣列的形狀、dtype 和其他靜態屬性的容器。

device_put(x[, device, src, donate, may_alias])

x 傳輸到 device

device_get(x)

x 傳輸到主機。

default_backend()

返回預設 XLA 後端的平台名稱。

named_call(fun, *[, name])

在分段輸出 JAX 計算時,向函數添加使用者指定的名稱。

named_scope(name)

一個上下文管理器,向 JAX 名稱堆疊添加使用者指定的名稱。

block_until_ready(x)

嘗試在 pytree 葉子上調用 block_until_ready 方法。

make_mesh(axis_shapes, axis_names, *[, ...])

建立具有指定形狀和軸名稱的有效網格。

自動微分#

grad(fun[, argnums, has_aux, holomorphic, ...])

建立一個評估 fun 梯度的函數。

value_and_grad(fun[, argnums, has_aux, ...])

建立一個同時評估 funfun 梯度的函數。

jacobian(fun[, argnums, has_aux, ...])

jax.jacrev() 的別名。

jacfwd(fun[, argnums, has_aux, holomorphic])

使用前向模式 AD 逐列評估 fun 的 Jacobian。

jacrev(fun[, argnums, has_aux, holomorphic, ...])

使用反向模式 AD 逐行評估 fun 的 Jacobian。

hessian(fun[, argnums, has_aux, holomorphic])

作為密集陣列的 fun 的 Hessian。

jvp(fun, primals, tangents[, has_aux])

計算 fun 的(前向模式) Jacobian-向量乘積。

linearize()

使用 jvp() 和部分評估產生 fun 的線性近似。

linear_transpose(fun, *primals[, reduce_axes])

轉置保證為線性的函數。

vjp() ))

計算 fun 的(反向模式)向量- Jacobian 乘積。

custom_gradient(fun)

用於定義自訂 VJP 規則(又名自訂梯度)的便利函數。

closure_convert(fun, *example_args)

閉包轉換實用程式,用於高階自訂導數。

checkpoint(fun, *[, prevent_cse, policy, ...])

使 fun 在微分時重新計算內部線性化點。

自訂#

custom_jvp#

custom_jvp(fun[, nondiff_argnums])

設定一個 JAX 可轉換函數以進行自訂 JVP 規則定義。

custom_jvp.defjvp(jvp[, symbolic_zeros])

為此實例表示的函數定義自訂 JVP 規則。

custom_jvp.defjvps(*jvps)

用於為每個參數單獨定義 JVP 的便利包裝器。

custom_vjp#

custom_vjp(fun[, nondiff_argnums])

設定一個 JAX 可轉換函數以進行自訂 VJP 規則定義。

custom_vjp.defvjp(fwd, bwd[, ...])

為此實例表示的函數定義自訂 VJP 規則。

custom_batching#

custom_batching.custom_vmap(fun)

自訂 JAX 可轉換函數的 vmap 行為。

custom_batching.custom_vmap.def_vmap(vmap_rule)

為此 custom_vmap 函數定義 vmap 規則。

custom_batching.sequential_vmap(f)

使用迴圈的 custom_vmap 的特殊情況。

jax.Array (jax.Array)#

陣列()

JAX 的陣列基底類別

make_array_from_callback(shape, sharding, ...)

透過從 data_callback 獲取的資料返回 jax.Array

make_array_from_single_device_arrays(shape, ...)

從每個都在單個裝置上的 jax.Array 序列返回 jax.Array

make_array_from_process_local_data(sharding, ...)

使用進程中可用的資料建立分散式張量。

陣列屬性和方法#

Array.addressable_shards

可定址分片的列表。

Array.all([axis, out, keepdims, where])

測試給定軸上的所有陣列元素是否評估為 True。

Array.any([axis, out, keepdims, where])

測試給定軸上的任何陣列元素是否評估為 True。

Array.argmax([axis, out, keepdims])

返回最大值的索引。

Array.argmin([axis, out, keepdims])

返回最小值的索引。

Array.argpartition(kth[, axis])

返回部分排序陣列的索引。

Array.argsort([axis, kind, order, stable, ...])

返回排序陣列的索引。

Array.astype(dtype[, copy, device])

複製陣列並轉換為指定的 dtype。

Array.at

用於索引更新功能的輔助屬性。

Array.choose(choices[, out, mode])

建構一個從多個陣列的元素中選擇的陣列。

Array.clip([min, max])

返回一個值限制在指定範圍內的陣列。

Array.compress(condition[, axis, out, size, ...])

沿給定軸返回此陣列的選定切片。

Array.committed

陣列是否已提交。

Array.conj()

返回陣列的複數共軛。

Array.conjugate()

返回陣列的複數共軛。

Array.copy()

返回陣列的副本。

Array.copy_to_host_async()

Array 非同步複製到主機。

Array.cumprod([axis, dtype, out])

返回陣列的累積乘積。

Array.cumsum([axis, dtype, out])

返回陣列的累積總和。

Array.device

Array API 相容的裝置屬性。

Array.diagonal([offset, axis1, axis2])

從陣列返回指定的對角線。

Array.dot(b, *[, precision, ...])

計算兩個陣列的點積。

Array.dtype

陣列的資料類型 (numpy.dtype)。

Array.flat

請改用 flatten()

Array.flatten([order])

將陣列展平為一維形狀。

Array.global_shards

全域分片的列表。

Array.imag

返回陣列的虛部。

Array.is_fully_addressable

此陣列是否完全可定址?

Array.is_fully_replicated

此陣列是否完全複製?

Array.item(*args)

將陣列的元素複製到標準 Python 純量並返回。

Array.itemsize

一個陣列元素以位元組為單位的長度。

Array.max([axis, out, keepdims, initial, where])

返回給定軸上陣列元素的最大值。

Array.mean([axis, dtype, out, keepdims, where])

返回給定軸上陣列元素的平均值。

Array.min([axis, out, keepdims, initial, where])

返回給定軸上陣列元素的最小值。

Array.nbytes

陣列元素消耗的總位元組數。

Array.ndim

陣列中的維度數。

Array.nonzero(*[, fill_value, size])

返回陣列的非零元素的索引。

Array.prod([axis, dtype, out, keepdims, ...])

返回給定軸上陣列元素的乘積。

Array.ptp([axis, out, keepdims])

返回給定軸上的峰峰值範圍。

Array.ravel([order])

將陣列展平為一維形狀。

Array.real

返回陣列的實部。

Array.repeat(repeats[, axis, ...])

從重複元素建構陣列。

Array.reshape(*args[, order])

返回包含相同資料但具有新形狀的陣列。

Array.round([decimals, out])

將陣列元素四捨五入到給定的十進位數。

Array.searchsorted(v[, side, sorter, method])

在排序的陣列中執行二元搜尋。

Array.shape

陣列的形狀。

Array.sharding

陣列的分片。

Array.size

陣列中的元素總數。

Array.sort([axis, kind, order, stable, ...])

傳回陣列的排序副本。

Array.squeeze([axis])

從陣列中移除一或多個長度為 1 的軸。

Array.std([axis, dtype, out, ddof, ...])

計算沿著給定軸的標準差。

Array.sum([axis, dtype, out, keepdims, ...])

計算陣列元素在給定軸上的總和。

Array.swapaxes(axis1, axis2)

交換陣列的兩個軸。

Array.take(indices[, axis, out, mode, ...])

從陣列中取出元素。

Array.to_device(device, *[, stream])

傳回在指定裝置上的陣列副本

Array.trace([offset, axis1, axis2, dtype, out])

傳回沿對角線的總和。

Array.transpose(*args)

傳回軸已轉置的陣列副本。

Array.var([axis, dtype, out, ddof, ...])

計算沿著給定軸的變異數。

Array.view([dtype, type])

傳回陣列的位元複製,並將其視為新的 dtype。

Array.T

計算全軸陣列轉置。

Array.mT

計算(批次)矩陣轉置。

向量化 (vmap)#

vmap(fun[, in_axes, out_axes, axis_name, ...])

向量化映射。

numpy.vectorize(pyfunc, *[, excluded, signature])

定義具有廣播的向量化函式。

平行化 (pmap)#

pmap(fun[, axis_name, in_axes, out_axes, ...])

支援集合運算的平行映射。

devices([backend])

傳回給定後端的所有裝置列表。

local_devices([process_index, backend, host_id])

類似於 jax.devices(),但僅傳回給定程序本機的裝置。

process_index([backend])

傳回此程序的整數程序索引。

device_count([backend])

傳回裝置總數。

local_device_count([backend])

傳回此程序可定址的裝置數量。

process_count([backend])

傳回與後端關聯的 JAX 程序數量。

process_indices([backend])

傳回與後端關聯的所有 JAX 程序索引列表。

回呼函數#

pure_callback(callback, result_shape_dtypes, ...)

呼叫純 Python 回呼函數。

experimental.io_callback(callback, ...[, ...])

呼叫不純 Python 回呼函數。

debug.callback(callback, *args[, ordered])

呼叫可階段化的 Python 回呼函數。

debug.print(fmt, *args[, ordered])

印出值,並可在階段化輸出的 JAX 函數中使用。

雜項#

裝置

可用裝置的描述符。

print_environment_info([return_string])

傳回包含本機環境與 JAX 安裝資訊的字串。

live_arrays([platform])

傳回 platform 後端中的所有活動陣列。

clear_caches()

清除所有編譯和階段快取。