公開 API:jax
套件#
子套件#
jax.numpy
模組jax.scipy
模組jax.lax
模組jax.random
模組jax.sharding
模組jax.debug
模組jax.dlpack
模組jax.distributed
模組jax.dtypes
模組jax.ffi
模組jax.extend.ffi
模組 (已棄用)jax.flatten_util
模組jax.image
模組jax.nn
模組jax.ops
模組jax.profiler
模組jax.stages
模組jax.tree
模組jax.tree_util
模組jax.typing
模組jax.export
模組jax.extend
模組jax.example_libraries
模組jax.experimental
模組
組態#
用於 jax_check_tracer_leaks 組態選項的上下文管理器。 |
|
用於 jax_check_tracer_leaks 組態選項的上下文管理器。 |
|
用於 jax_debug_nans 組態選項的上下文管理器。 |
|
用於 jax_debug_infs 組態選項的上下文管理器。 |
|
用於 jax_default_device 組態選項的上下文管理器。 |
|
用於 jax_default_matmul_precision 組態選項的上下文管理器。 |
|
用於 jax_default_prng_impl 組態選項的上下文管理器。 |
|
用於 jax_enable_checks 組態選項的上下文管理器。 |
|
用於 jax_enable_custom_prng 組態選項的上下文管理器 (暫時性)。 |
|
用於 jax_enable_custom_vjp_by_custom_transpose 組態選項的上下文管理器 (暫時性)。 |
|
用於 jax_log_compiles 組態選項的上下文管理器。 |
|
用於 jax_numpy_rank_promotion 組態選項的上下文管理器。 |
|
|
用於控制所有傳輸的傳輸保護等級的上下文管理器。 |
即時編譯 (jit
)#
|
設定 |
|
在其動態上下文下禁用 |
確保在追蹤/編譯時評估(或錯誤)的上下文管理器。 |
|
|
建立一個函數,該函數產生給定範例參數的 jaxpr。 |
|
計算 |
|
用於陣列的形狀、dtype 和其他靜態屬性的容器。 |
|
將 |
|
將 |
返回預設 XLA 後端的平台名稱。 |
|
|
在分段輸出 JAX 計算時,向函數添加使用者指定的名稱。 |
|
一個上下文管理器,向 JAX 名稱堆疊添加使用者指定的名稱。 |
嘗試在 pytree 葉子上調用 |
|
|
建立具有指定形狀和軸名稱的有效網格。 |
自動微分#
|
建立一個評估 |
|
建立一個同時評估 |
|
|
|
使用前向模式 AD 逐列評估 |
|
使用反向模式 AD 逐行評估 |
|
作為密集陣列的 |
|
計算 |
使用 |
|
|
轉置保證為線性的函數。 |
|
計算 |
|
用於定義自訂 VJP 規則(又名自訂梯度)的便利函數。 |
|
閉包轉換實用程式,用於高階自訂導數。 |
|
使 |
自訂#
custom_jvp
#
|
設定一個 JAX 可轉換函數以進行自訂 JVP 規則定義。 |
|
為此實例表示的函數定義自訂 JVP 規則。 |
|
用於為每個參數單獨定義 JVP 的便利包裝器。 |
custom_vjp
#
|
設定一個 JAX 可轉換函數以進行自訂 VJP 規則定義。 |
|
為此實例表示的函數定義自訂 VJP 規則。 |
custom_batching
#
自訂 JAX 可轉換函數的 vmap 行為。 |
|
|
為此 custom_vmap 函數定義 vmap 規則。 |
使用迴圈的 |
jax.Array (jax.Array
)#
|
JAX 的陣列基底類別 |
|
透過從 |
|
從每個都在單個裝置上的 |
|
使用進程中可用的資料建立分散式張量。 |
陣列屬性和方法#
可定址分片的列表。 |
|
|
測試給定軸上的所有陣列元素是否評估為 True。 |
|
測試給定軸上的任何陣列元素是否評估為 True。 |
|
返回最大值的索引。 |
|
返回最小值的索引。 |
|
返回部分排序陣列的索引。 |
|
返回排序陣列的索引。 |
|
複製陣列並轉換為指定的 dtype。 |
用於索引更新功能的輔助屬性。 |
|
|
建構一個從多個陣列的元素中選擇的陣列。 |
|
返回一個值限制在指定範圍內的陣列。 |
|
沿給定軸返回此陣列的選定切片。 |
陣列是否已提交。 |
|
返回陣列的複數共軛。 |
|
返回陣列的複數共軛。 |
|
返回陣列的副本。 |
|
將 |
|
|
返回陣列的累積乘積。 |
|
返回陣列的累積總和。 |
Array API 相容的裝置屬性。 |
|
|
從陣列返回指定的對角線。 |
|
計算兩個陣列的點積。 |
陣列的資料類型 ( |
|
請改用 |
|
|
將陣列展平為一維形狀。 |
全域分片的列表。 |
|
返回陣列的虛部。 |
|
此陣列是否完全可定址? |
|
此陣列是否完全複製? |
|
|
將陣列的元素複製到標準 Python 純量並返回。 |
一個陣列元素以位元組為單位的長度。 |
|
|
返回給定軸上陣列元素的最大值。 |
|
返回給定軸上陣列元素的平均值。 |
|
返回給定軸上陣列元素的最小值。 |
陣列元素消耗的總位元組數。 |
|
陣列中的維度數。 |
|
|
返回陣列的非零元素的索引。 |
|
返回給定軸上陣列元素的乘積。 |
|
返回給定軸上的峰峰值範圍。 |
|
將陣列展平為一維形狀。 |
返回陣列的實部。 |
|
|
從重複元素建構陣列。 |
|
返回包含相同資料但具有新形狀的陣列。 |
|
將陣列元素四捨五入到給定的十進位數。 |
|
在排序的陣列中執行二元搜尋。 |
陣列的形狀。 |
|
陣列的分片。 |
|
陣列中的元素總數。 |
|
|
傳回陣列的排序副本。 |
|
從陣列中移除一或多個長度為 1 的軸。 |
|
計算沿著給定軸的標準差。 |
|
計算陣列元素在給定軸上的總和。 |
|
交換陣列的兩個軸。 |
|
從陣列中取出元素。 |
|
傳回在指定裝置上的陣列副本 |
|
傳回沿對角線的總和。 |
|
傳回軸已轉置的陣列副本。 |
|
計算沿著給定軸的變異數。 |
|
傳回陣列的位元複製,並將其視為新的 dtype。 |
計算全軸陣列轉置。 |
|
計算(批次)矩陣轉置。 |
向量化 (vmap
)#
|
向量化映射。 |
|
定義具有廣播的向量化函式。 |
平行化 (pmap
)#
|
支援集合運算的平行映射。 |
|
傳回給定後端的所有裝置列表。 |
|
類似於 |
|
傳回此程序的整數程序索引。 |
|
傳回裝置總數。 |
|
傳回此程序可定址的裝置數量。 |
|
傳回與後端關聯的 JAX 程序數量。 |
|
傳回與後端關聯的所有 JAX 程序索引列表。 |
回呼函數#
|
呼叫純 Python 回呼函數。 |
|
呼叫不純 Python 回呼函數。 |
|
呼叫可階段化的 Python 回呼函數。 |
|
印出值,並可在階段化輸出的 JAX 函數中使用。 |
雜項#
可用裝置的描述符。 |
|
|
傳回包含本機環境與 JAX 安裝資訊的字串。 |
|
傳回 platform 後端中的所有活動陣列。 |
清除所有編譯和階段快取。 |