詞彙表#

陣列#

JAX 中 numpy.ndarray 的類比。請參閱 jax.Array

CPU#

中央處理單元 (Central Processing Unit) 的縮寫,CPU 是大多數電腦中可用的標準計算架構。JAX 可以在 CPU 上執行計算,但通常可以在 GPUTPU 上獲得更好的效能。

裝置#

用於指稱 JAX 用於執行計算的 CPUGPUTPU 的通用名稱。

前向模式自動微分#

請參閱 JVP

函數式程式設計#

一種程式設計典範,其中程式透過應用和組合純函數來定義。JAX 旨在與函數式程式一起使用。

GPU#

圖形處理單元 (Graphical Processing Unit) 的縮寫,GPU 最初專門用於與螢幕影像渲染相關的操作,但現在用途更加廣泛。JAX 能夠以 GPU 為目標,以加速陣列運算(另請參閱 CPUTPU)。

jaxpr#

JAX 運算式 (JAX expression) 的縮寫,jaxpr 是 JAX 產生的計算的中間表示形式,並轉發到 XLA 進行編譯和執行。請參閱 JAX 內部機制:jaxpr 語言,以獲得更多討論和範例。

JIT#

即時 (Just In Time) 編譯的縮寫,JAX 中的 JIT 通常指將陣列運算編譯為 XLA,最常使用 jax.jit() 完成。

JVP#

雅可比向量積 (Jacobian Vector Product) 的縮寫,有時也稱為前向模式自動微分。如需更多詳細資訊,請參閱 Jacobian-Vector products (JVPs, aka forward-mode autodiff)。在 JAX 中,JVP 是一種轉換,透過 jax.jvp() 實作。另請參閱 VJP

基本運算#

基本運算是 JAX 程式中使用的基本計算單位。jax.lax 中的大多數函數都代表個別的基本運算。當在 jaxpr 中表示計算時,jaxpr 中的每個運算都是一個基本運算。

純函數#

純函數是一種輸出僅基於其輸入,且沒有副作用的函數。JAX 的轉換模型旨在與純函數一起使用。另請參閱函數式程式設計

pytree#

pytree 是一種抽象化,可讓 JAX 以統一的方式處理元組、列表、字典和陣列值的其他更通用容器。請參閱使用 pytree 以獲得更詳細的討論。

反向模式自動微分#

請參閱 VJP

SPMD#

單一程式多資料 (Single Program Multi Data) 的縮寫,它指的是一種平行計算技術,其中相同的計算(例如,神經網路的前向傳遞)在不同的輸入資料(例如,批次中的不同輸入)上平行運行於不同的裝置(例如,多個 TPU)。jax.pmap() 是一種實作 SPMD 平行化的 JAX 轉換

靜態#

JIT 編譯中,一個未追蹤的值(請參閱 Tracer)。有時也指對靜態值進行的編譯時間計算。

TPU#

張量處理單元 (Tensor Processing Unit) 的縮寫,TPU 是專門為深度學習應用中使用的 N 維張量的快速運算而設計的晶片。JAX 能夠以 TPU 為目標,以加速陣列運算(另請參閱 CPUGPU)。

追蹤器#

一個用作 JAX 陣列的替代物件,以確定 Python 函數執行的運算順序。在內部,JAX 透過 jax.core.Tracer 類別實作此功能。

轉換#

一種高階函數:也就是說,一個以函數作為輸入並輸出轉換後函數的函數。JAX 中的範例包括 jax.jit()jax.vmap()jax.grad()

VJP#

向量雅可比積 (Vector Jacobian Product) 的縮寫,有時也稱為反向模式自動微分。如需更多詳細資訊,請參閱 Vector-Jacobian products (VJPs, aka reverse-mode autodiff)。在 JAX 中,VJP 是一種轉換,透過 jax.vjp() 實作。另請參閱 JVP

XLA#

加速線性代數 (Accelerated Linear Algebra) 的縮寫,XLA 是一種用於線性代數運算的領域特定編譯器,是 JIT 編譯的 JAX 程式碼的主要後端。請參閱 https://tensorflow.dev.org.tw/xla/

弱型別#

一種 JAX 資料型別,其型別提升語意與 Python 純量相同;請參閱 JAX 中的弱型別值