安裝#

使用 JAX 需要安裝兩個套件:jax,它是純 Python 且跨平台的,以及 jaxlib,它包含編譯後的二進位檔案,並且針對不同的作業系統和加速器需要不同的建置版本。

摘要: 對於大多數使用者來說,典型的 JAX 安裝可能看起來像這樣

  • 僅限 CPU (Linux/macOS/Windows)

    pip install -U jax
    
  • GPU (NVIDIA, CUDA 12)

    pip install -U "jax[cuda12]"
    
  • TPU (Google Cloud TPU VM)

    pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    

支援的平台#

下表顯示所有支援的平台和安裝選項。檢查您的設定是否受支援;如果顯示“是”“實驗性”,則點擊相應的連結以了解如何更詳細地安裝 JAX。

Linux, x86_64

Linux, aarch64

Mac, x86_64

Mac, aarch64

Windows, x86_64

Windows WSL2, x86_64

CPU

NVIDIA GPU

不適用

實驗性

Google Cloud TPU

不適用

不適用

不適用

不適用

不適用

AMD GPU

實驗性

實驗性

不適用

Apple GPU

不適用

不適用

實驗性

不適用

不適用

Intel GPU

實驗性

不適用

不適用

不適用

CPU#

pip 安裝:CPU#

目前,JAX 團隊發布適用於以下作業系統和架構的 jaxlib wheels

  • Linux, x86_64

  • Linux, aarch64

  • macOS, Intel

  • macOS, Apple ARM-based

  • Windows, x86_64 (實驗性)

若要安裝僅限 CPU 版本的 JAX,這對於在筆記型電腦上進行本機開發可能很有用,您可以執行

pip install --upgrade pip
pip install --upgrade jax

在 Windows 上,如果您的電腦上尚未安裝 Microsoft Visual Studio 2019 Redistributable,您可能也需要安裝它。

其他作業系統和架構需要從原始碼建置。嘗試在其他作業系統和架構上使用 pip 安裝可能會導致 jaxlib 未與 jax 一起安裝,儘管 jax 可能成功安裝 (但在執行階段會失敗)。

NVIDIA GPU#

JAX 支援 SM 版本 5.2 (Maxwell) 或更新版本的 NVIDIA GPU。請注意,由於 NVIDIA 已停止在其軟體中支援 Kepler GPU,因此 JAX 不再支援 Kepler 系列 GPU。

您必須先安裝 NVIDIA 驅動程式。建議您安裝 NVIDIA 提供的最新驅動程式,但對於 Linux 上的 CUDA 12,驅動程式版本必須 >= 525.60.13。

如果您需要將較新的 CUDA 工具組與較舊的驅動程式一起使用,例如在您無法輕易更新 NVIDIA 驅動程式的叢集上,您或許可以使用 NVIDIA 為此目的提供的 CUDA 向前相容性套件

pip 安裝:NVIDIA GPU (CUDA,透過 pip 安裝,較容易)#

有兩種方法可以安裝支援 NVIDIA GPU 的 JAX

  • 使用從 pip wheels 安裝的 NVIDIA CUDA 和 cuDNN

  • 使用自行安裝的 CUDA/cuDNN

JAX 團隊強烈建議使用 pip wheels 安裝 CUDA 和 cuDNN,因為這樣容易得多!

NVIDIA 僅針對 x86_64 和 aarch64 發布了 CUDA pip 套件;在其他平台上,您必須使用 CUDA 的本機安裝。

pip install --upgrade pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"

如果 JAX 偵測到錯誤版本的 NVIDIA CUDA 程式庫,您需要檢查幾件事

  • 確保未設定 LD_LIBRARY_PATH,因為 LD_LIBRARY_PATH 可能會覆寫 NVIDIA CUDA 程式庫。

  • 確保安裝的 NVIDIA CUDA 程式庫是 JAX 要求的那些。重新執行上面的安裝命令應該可以解決問題。

pip 安裝:NVIDIA GPU (CUDA,本機安裝,較困難)#

如果您偏好使用預先安裝的 NVIDIA CUDA 副本,您必須先安裝 NVIDIA CUDAcuDNN

JAX 僅為 Linux x86_64 和 Linux aarch64 提供預先建置的 CUDA 相容 wheels。其他作業系統和架構的組合是可能的,但需要從原始碼建置 (請參閱 從原始碼建置 以了解更多資訊}。

您應該使用至少與您的 NVIDIA CUDA 工具組的對應驅動程式版本 一樣新的 NVIDIA 驅動程式版本。如果您需要將較新的 CUDA 工具組與較舊的驅動程式一起使用,例如在您無法輕易更新 NVIDIA 驅動程式的叢集上,您或許可以使用 NVIDIA 為此目的提供的 CUDA 向前相容性套件

JAX 目前提供一個 CUDA wheel 變體

使用以下版本建置

與以下版本相容

CUDA 12.3

CUDA >=12.1

CUDNN 9.1

CUDNN >=9.1, <10.0

NCCL 2.19

NCCL >=2.18

JAX 會檢查您的程式庫版本,如果版本不夠新,則會報告錯誤。設定 JAX_SKIP_CUDA_CONSTRAINTS_CHECK 環境變數將停用檢查,但使用較舊版本的 CUDA 可能會導致錯誤或不正確的結果。

NCCL 是一個可選的相依性,僅在您執行多 GPU 計算時才需要。

若要安裝,請執行

pip install --upgrade pip

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]"

這些 pip 安裝不適用於 Windows,並且可能會靜默失敗;請參閱上方的表格。

您可以使用以下命令找到您的 CUDA 版本

nvcc --version

JAX 使用 LD_LIBRARY_PATH 來尋找 CUDA 程式庫,並使用 PATH 來尋找二進位檔案 (ptxasnvlink)。請確保這些路徑指向正確的 CUDA 安裝。

JAX 需要 libdevice10.bc,它通常來自 cuda-nvvm 套件。請確保它存在於您的 CUDA 安裝中。

如果您在使用預先建置的 wheels 時遇到任何錯誤或問題,請在 GitHub issue tracker 上告知 JAX 團隊。

NVIDIA GPU Docker 容器#

NVIDIA 提供 JAX Toolbox 容器,這些容器是最新的容器,其中包含 jax 的每夜版發行版本和一些模型/框架。

Google Cloud TPU#

pip 安裝:Google Cloud TPU#

JAX 為 Google Cloud TPU 提供預先建置的 wheels。若要安裝 JAX 以及適當版本的 jaxliblibtpu,您可以在您的 cloud TPU VM 中執行以下命令

pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

對於 Colab (https://colab.research.google.com/) 的使用者,請確保您使用的是 TPU v2,而不是較舊、已棄用的 TPU 執行階段。

Mac GPU#

pip 安裝#

Apple 提供了一個實驗性的 Metal 外掛程式。如需詳細資訊,請參閱 Apple 的 JAX on Metal 文件

注意: Metal 外掛程式有一些注意事項

  • Metal 外掛程式是新的且實驗性的,並且有許多 已知問題。請在 JAX issue tracker 上報告任何問題。

  • Metal 外掛程式目前需要非常特定的 jaxjaxlib 版本。隨著外掛程式 API 的成熟,此限制將會逐漸放寬。

AMD GPU (Linux)#

JAX 具有實驗性的 ROCm 支援。有兩種方法可以安裝 JAX

Intel GPU#

Intel 提供了一個實驗性的 OneAPI 外掛程式:intel-extension-for-openxla,適用於 Intel GPU 硬體。如需更多詳細資訊和安裝說明,請參閱以下兩種方法之一

  1. Pip 安裝:在 Intel GPU 上加速 JAX

  2. 使用 Intel 的 XLA Docker 容器

請回報與以下項目相關的任何問題

Conda (社群支援)#

Conda 安裝#

有一個社群支援的 jax Conda 建置版本。若要使用 conda 安裝它,只需執行

conda install jax -c conda-forge

如果您在具有 NVIDIA GPU 的機器上執行此命令,這應該會安裝 CUDA-enabled 版本的 jaxlib

為了確保您安裝的 jax 版本確實是 CUDA-enabled,請執行

conda install "jaxlib=*=*cuda*" jax -c conda-forge

如果您想要覆寫 JAX 使用的 CUDA 版本,或在沒有 GPU 的機器上安裝 CUDA 建置版本,請依照 conda-forge 網站的 Tips & tricks 章節中的說明進行操作。

前往 conda-forge jaxlibjax 儲存庫以取得更多詳細資訊。

JAX 每夜版安裝#

每夜版發行版本反映了建置時 main JAX 儲存庫的狀態,並且可能未通過完整的測試套件。

與安裝 JAX 發行版本的說明不同,在這裡我們在命令列上明確命名 JAX 的所有套件,因此如果有較新版本可用,pip 將會升級它們。

  • 僅限 CPU

pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  • Google Cloud TPU

pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  • NVIDIA GPU (CUDA 12)

pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  • NVIDIA GPU (CUDA 12) 舊版

針對單體式 CUDA jaxlibs 的歷史每夜版發行版本,請使用以下命令。您很可能不需要這個;不會再建置更多單體式 CUDA jaxlibs,並且現有的版本將在 2024 年 9 月到期。請使用上面的「CUDA 12」選項。

pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html

從原始碼建置 JAX#

請參閱 從原始碼建置

安裝舊版 jaxlib wheels#

由於 Python 套件索引上的儲存限制,JAX 團隊會定期從 http://pypi.org/project/jax 上的發行版本中移除舊版 jaxlib wheels。這些仍然可以透過此處的 URL 直接安裝。例如

# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html

對於特定的舊版 GPU wheels,請務必使用 jax_cuda_releases.html URL;例如

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html