安裝#
使用 JAX 需要安裝兩個套件:jax
,它是純 Python 且跨平台的,以及 jaxlib
,它包含編譯後的二進位檔案,並且針對不同的作業系統和加速器需要不同的建置版本。
摘要: 對於大多數使用者來說,典型的 JAX 安裝可能看起來像這樣
僅限 CPU (Linux/macOS/Windows)
pip install -U jax
GPU (NVIDIA, CUDA 12)
pip install -U "jax[cuda12]"
TPU (Google Cloud TPU VM)
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
支援的平台#
下表顯示所有支援的平台和安裝選項。檢查您的設定是否受支援;如果顯示“是”或“實驗性”,則點擊相應的連結以了解如何更詳細地安裝 JAX。
Linux, x86_64 |
Linux, aarch64 |
Mac, x86_64 |
Mac, aarch64 |
Windows, x86_64 |
Windows WSL2, x86_64 |
|
---|---|---|---|---|---|---|
CPU |
||||||
NVIDIA GPU |
否 |
不適用 |
否 |
|||
Google Cloud TPU |
不適用 |
不適用 |
不適用 |
不適用 |
不適用 |
|
AMD GPU |
否 |
不適用 |
否 |
否 |
||
Apple GPU |
不適用 |
否 |
不適用 |
不適用 |
不適用 |
|
Intel GPU |
不適用 |
不適用 |
不適用 |
否 |
否 |
CPU#
pip 安裝:CPU#
目前,JAX 團隊發布適用於以下作業系統和架構的 jaxlib
wheels
Linux, x86_64
Linux, aarch64
macOS, Intel
macOS, Apple ARM-based
Windows, x86_64 (實驗性)
若要安裝僅限 CPU 版本的 JAX,這對於在筆記型電腦上進行本機開發可能很有用,您可以執行
pip install --upgrade pip
pip install --upgrade jax
在 Windows 上,如果您的電腦上尚未安裝 Microsoft Visual Studio 2019 Redistributable,您可能也需要安裝它。
其他作業系統和架構需要從原始碼建置。嘗試在其他作業系統和架構上使用 pip 安裝可能會導致 jaxlib
未與 jax
一起安裝,儘管 jax
可能成功安裝 (但在執行階段會失敗)。
NVIDIA GPU#
JAX 支援 SM 版本 5.2 (Maxwell) 或更新版本的 NVIDIA GPU。請注意,由於 NVIDIA 已停止在其軟體中支援 Kepler GPU,因此 JAX 不再支援 Kepler 系列 GPU。
您必須先安裝 NVIDIA 驅動程式。建議您安裝 NVIDIA 提供的最新驅動程式,但對於 Linux 上的 CUDA 12,驅動程式版本必須 >= 525.60.13。
如果您需要將較新的 CUDA 工具組與較舊的驅動程式一起使用,例如在您無法輕易更新 NVIDIA 驅動程式的叢集上,您或許可以使用 NVIDIA 為此目的提供的 CUDA 向前相容性套件。
pip 安裝:NVIDIA GPU (CUDA,透過 pip 安裝,較容易)#
有兩種方法可以安裝支援 NVIDIA GPU 的 JAX
使用從 pip wheels 安裝的 NVIDIA CUDA 和 cuDNN
使用自行安裝的 CUDA/cuDNN
JAX 團隊強烈建議使用 pip wheels 安裝 CUDA 和 cuDNN,因為這樣容易得多!
NVIDIA 僅針對 x86_64 和 aarch64 發布了 CUDA pip 套件;在其他平台上,您必須使用 CUDA 的本機安裝。
pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"
如果 JAX 偵測到錯誤版本的 NVIDIA CUDA 程式庫,您需要檢查幾件事
確保未設定
LD_LIBRARY_PATH
,因為LD_LIBRARY_PATH
可能會覆寫 NVIDIA CUDA 程式庫。確保安裝的 NVIDIA CUDA 程式庫是 JAX 要求的那些。重新執行上面的安裝命令應該可以解決問題。
pip 安裝:NVIDIA GPU (CUDA,本機安裝,較困難)#
如果您偏好使用預先安裝的 NVIDIA CUDA 副本,您必須先安裝 NVIDIA CUDA 和 cuDNN。
JAX 僅為 Linux x86_64 和 Linux aarch64 提供預先建置的 CUDA 相容 wheels。其他作業系統和架構的組合是可能的,但需要從原始碼建置 (請參閱 從原始碼建置 以了解更多資訊}。
您應該使用至少與您的 NVIDIA CUDA 工具組的對應驅動程式版本 一樣新的 NVIDIA 驅動程式版本。如果您需要將較新的 CUDA 工具組與較舊的驅動程式一起使用,例如在您無法輕易更新 NVIDIA 驅動程式的叢集上,您或許可以使用 NVIDIA 為此目的提供的 CUDA 向前相容性套件。
JAX 目前提供一個 CUDA wheel 變體
使用以下版本建置 |
與以下版本相容 |
---|---|
CUDA 12.3 |
CUDA >=12.1 |
CUDNN 9.1 |
CUDNN >=9.1, <10.0 |
NCCL 2.19 |
NCCL >=2.18 |
JAX 會檢查您的程式庫版本,如果版本不夠新,則會報告錯誤。設定 JAX_SKIP_CUDA_CONSTRAINTS_CHECK
環境變數將停用檢查,但使用較舊版本的 CUDA 可能會導致錯誤或不正確的結果。
NCCL 是一個可選的相依性,僅在您執行多 GPU 計算時才需要。
若要安裝,請執行
pip install --upgrade pip
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]"
這些 pip
安裝不適用於 Windows,並且可能會靜默失敗;請參閱上方的表格。
您可以使用以下命令找到您的 CUDA 版本
nvcc --version
JAX 使用 LD_LIBRARY_PATH
來尋找 CUDA 程式庫,並使用 PATH
來尋找二進位檔案 (ptxas
、nvlink
)。請確保這些路徑指向正確的 CUDA 安裝。
JAX 需要 libdevice10.bc,它通常來自 cuda-nvvm 套件。請確保它存在於您的 CUDA 安裝中。
如果您在使用預先建置的 wheels 時遇到任何錯誤或問題,請在 GitHub issue tracker 上告知 JAX 團隊。
NVIDIA GPU Docker 容器#
NVIDIA 提供 JAX Toolbox 容器,這些容器是最新的容器,其中包含 jax 的每夜版發行版本和一些模型/框架。
Google Cloud TPU#
pip 安裝:Google Cloud TPU#
JAX 為 Google Cloud TPU 提供預先建置的 wheels。若要安裝 JAX 以及適當版本的 jaxlib
和 libtpu
,您可以在您的 cloud TPU VM 中執行以下命令
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
對於 Colab (https://colab.research.google.com/) 的使用者,請確保您使用的是 TPU v2,而不是較舊、已棄用的 TPU 執行階段。
Mac GPU#
pip 安裝#
Apple 提供了一個實驗性的 Metal 外掛程式。如需詳細資訊,請參閱 Apple 的 JAX on Metal 文件。
注意: Metal 外掛程式有一些注意事項
Metal 外掛程式是新的且實驗性的,並且有許多 已知問題。請在 JAX issue tracker 上報告任何問題。
Metal 外掛程式目前需要非常特定的
jax
和jaxlib
版本。隨著外掛程式 API 的成熟,此限制將會逐漸放寬。
AMD GPU (Linux)#
JAX 具有實驗性的 ROCm 支援。有兩種方法可以安裝 JAX
使用 AMD 的 Docker 容器;或
從原始碼建置。請參閱 為 AMD GPU 建置 ROCm jaxlib 的其他注意事項 章節。
Intel GPU#
Intel 提供了一個實驗性的 OneAPI 外掛程式:intel-extension-for-openxla,適用於 Intel GPU 硬體。如需更多詳細資訊和安裝說明,請參閱以下兩種方法之一
Pip 安裝:在 Intel GPU 上加速 JAX。
請回報與以下項目相關的任何問題
JAX: JAX issue tracker。
Intel 的 OpenXLA 外掛程式:Intel-extension-for-openxla issue tracker。
Conda (社群支援)#
Conda 安裝#
有一個社群支援的 jax
Conda 建置版本。若要使用 conda
安裝它,只需執行
conda install jax -c conda-forge
如果您在具有 NVIDIA GPU 的機器上執行此命令,這應該會安裝 CUDA-enabled 版本的 jaxlib
。
為了確保您安裝的 jax 版本確實是 CUDA-enabled,請執行
conda install "jaxlib=*=*cuda*" jax -c conda-forge
如果您想要覆寫 JAX 使用的 CUDA 版本,或在沒有 GPU 的機器上安裝 CUDA 建置版本,請依照 conda-forge
網站的 Tips & tricks 章節中的說明進行操作。
JAX 每夜版安裝#
每夜版發行版本反映了建置時 main JAX 儲存庫的狀態,並且可能未通過完整的測試套件。
與安裝 JAX 發行版本的說明不同,在這裡我們在命令列上明確命名 JAX 的所有套件,因此如果有較新版本可用,pip
將會升級它們。
僅限 CPU
pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
Google Cloud TPU
pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
NVIDIA GPU (CUDA 12)
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
NVIDIA GPU (CUDA 12) 舊版
針對單體式 CUDA jaxlibs 的歷史每夜版發行版本,請使用以下命令。您很可能不需要這個;不會再建置更多單體式 CUDA jaxlibs,並且現有的版本將在 2024 年 9 月到期。請使用上面的「CUDA 12」選項。
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
從原始碼建置 JAX#
請參閱 從原始碼建置。
安裝舊版 jaxlib
wheels#
由於 Python 套件索引上的儲存限制,JAX 團隊會定期從 http://pypi.org/project/jax 上的發行版本中移除舊版 jaxlib
wheels。這些仍然可以透過此處的 URL 直接安裝。例如
# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
對於特定的舊版 GPU wheels,請務必使用 jax_cuda_releases.html
URL;例如
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html