詞彙表#
- 陣列#
JAX 中
numpy.ndarray
的類比。請參閱jax.Array
。- CPU#
中央處理單元 (Central Processing Unit) 的縮寫,CPU 是大多數電腦中可用的標準計算架構。JAX 可以在 CPU 上執行計算,但通常可以在 GPU 和 TPU 上獲得更好的效能。
- 裝置#
- 前向模式自動微分#
請參閱 JVP
- 函數式程式設計#
一種程式設計典範,其中程式透過應用和組合純函數來定義。JAX 旨在與函數式程式一起使用。
- GPU#
圖形處理單元 (Graphical Processing Unit) 的縮寫,GPU 最初專門用於與螢幕影像渲染相關的操作,但現在用途更加廣泛。JAX 能夠以 GPU 為目標,以加速陣列運算(另請參閱 CPU 和 TPU)。
- jaxpr#
JAX 運算式 (JAX expression) 的縮寫,jaxpr 是 JAX 產生的計算的中間表示形式,並轉發到 XLA 進行編譯和執行。請參閱 JAX 內部機制:jaxpr 語言,以獲得更多討論和範例。
- JIT#
即時 (Just In Time) 編譯的縮寫,JAX 中的 JIT 通常指將陣列運算編譯為 XLA,最常使用
jax.jit()
完成。- JVP#
雅可比向量積 (Jacobian Vector Product) 的縮寫,有時也稱為前向模式自動微分。如需更多詳細資訊,請參閱 Jacobian-Vector products (JVPs, aka forward-mode autodiff)。在 JAX 中,JVP 是一種轉換,透過
jax.jvp()
實作。另請參閱 VJP。- 基本運算#
基本運算是 JAX 程式中使用的基本計算單位。
jax.lax
中的大多數函數都代表個別的基本運算。當在 jaxpr 中表示計算時,jaxpr 中的每個運算都是一個基本運算。- 純函數#
- pytree#
pytree 是一種抽象化,可讓 JAX 以統一的方式處理元組、列表、字典和陣列值的其他更通用容器。請參閱使用 pytree 以獲得更詳細的討論。
- 反向模式自動微分#
請參閱 VJP。
- SPMD#
單一程式多資料 (Single Program Multi Data) 的縮寫,它指的是一種平行計算技術,其中相同的計算(例如,神經網路的前向傳遞)在不同的輸入資料(例如,批次中的不同輸入)上平行運行於不同的裝置(例如,多個 TPU)。
jax.pmap()
是一種實作 SPMD 平行化的 JAX 轉換。- 靜態#
- TPU#
張量處理單元 (Tensor Processing Unit) 的縮寫,TPU 是專門為深度學習應用中使用的 N 維張量的快速運算而設計的晶片。JAX 能夠以 TPU 為目標,以加速陣列運算(另請參閱 CPU 和 GPU)。
- 追蹤器#
一個用作 JAX 陣列的替代物件,以確定 Python 函數執行的運算順序。在內部,JAX 透過
jax.core.Tracer
類別實作此功能。- 轉換#
一種高階函數:也就是說,一個以函數作為輸入並輸出轉換後函數的函數。JAX 中的範例包括
jax.jit()
、jax.vmap()
和jax.grad()
。- VJP#
向量雅可比積 (Vector Jacobian Product) 的縮寫,有時也稱為反向模式自動微分。如需更多詳細資訊,請參閱 Vector-Jacobian products (VJPs, aka reverse-mode autodiff)。在 JAX 中,VJP 是一種轉換,透過
jax.vjp()
實作。另請參閱 JVP。- XLA#
加速線性代數 (Accelerated Linear Algebra) 的縮寫,XLA 是一種用於線性代數運算的領域特定編譯器,是 JIT 編譯的 JAX 程式碼的主要後端。請參閱 https://tensorflow.dev.org.tw/xla/。
- 弱型別#
一種 JAX 資料型別,其型別提升語意與 Python 純量相同;請參閱 JAX 中的弱型別值。