JAX:高效能陣列運算#

高效能陣列運算

JAX 是一個 Python 函式庫,用於加速器導向的陣列計算和程式轉換,專為高效能數值計算和大規模機器學習而設計。

熟悉的 API

JAX 提供熟悉的 NumPy 風格 API,方便研究人員和工程師採用。

轉換

JAX 包含可組合的函式轉換,用於編譯、批次處理、自動微分和平行化。

隨處執行

相同的程式碼可在多個後端執行,包括 CPU、GPU 和 TPU

安裝
安裝
開始使用
JAX 入門
使用者指南
使用者指南

如果您希望訓練神經網路,請使用 Flax 並從其教學開始。如需在 JAX 上建構的端對端轉換器函式庫,請參閱 MaxText

生態系統#

JAX 本身範圍狹窄,專注於高效的陣列運算和程式轉換。圍繞 JAX 建構的是不斷發展的機器學習和數值計算工具生態系統;以下僅是其中一小部分範例

神經網路

最佳化器和求解器

雜項工具

機率程式設計

物理與模擬

已開發更多基於 JAX 的函式庫;社群運營的 Awesome JAX 頁面維護著最新的列表。