JAX:高效能陣列運算#

高效能陣列運算

JAX 是一個 Python 函式庫,用於加速器導向的陣列計算和程式轉換,專為高效能數值計算和大規模機器學習而設計。

熟悉的 API

JAX 提供熟悉的 NumPy 風格 API,方便研究人員和工程師採用。

轉換

JAX 包括用於編譯、批次處理、自動微分和平行化的可組合函數轉換。

隨處執行

相同的程式碼可在多個後端執行,包括 CPU、GPU 和 TPU

安裝
安裝
開始使用
JAX 入門
使用者指南
使用者指南

如果您希望訓練神經網路,請使用 Flax 並從其教學課程開始。對於基於 JAX 建構的端對端轉換器函式庫,請參閱 MaxText

生態系統#

JAX 本身範圍狹窄,專注於高效陣列運算和程式轉換。圍繞 JAX 建構的是一個不斷發展的機器學習和數值計算工具生態系統;以下僅是現有工具的一小部分範例

神經網路

最佳化器與求解器

雜項工具

機率程式設計

物理與模擬

已開發出更多基於 JAX 的函式庫;社群運作的 Awesome JAX 頁面維護著最新的列表。