快速入門#
JAX 是一個用於陣列導向數值計算的函式庫(à la NumPy),具有自動微分和 JIT 編譯功能,可實現高效能機器學習研究.
本文檔快速概述了 JAX 的基本功能,讓您可以快速開始使用 JAX
JAX 為在 CPU、GPU 或 TPU 上運行的計算提供統一的類 NumPy 介面,無論是本機或分散式設定。
JAX 透過 Open XLA(一個開源機器學習編譯器生態系統)具有內建的即時 (JIT) 編譯功能。
JAX 函式透過其自動微分轉換支援有效率的梯度評估。
JAX 函式可以自動向量化,以有效率地將它們映射到代表輸入批次的陣列上。
安裝#
JAX 可以直接從 Python Package Index 安裝在 Linux、Windows 和 macOS 上的 CPU 上
pip install jax
或者,對於 NVIDIA GPU
pip install -U "jax[cuda12]"
如需更詳細的平台特定安裝資訊,請查看安裝。
JAX 作為 NumPy#
大多數 JAX 用法是透過熟悉的 jax.numpy
API,通常以 jnp
別名匯入
import jax.numpy as jnp
透過此匯入,您可以立即以類似於典型 NumPy 程式的方式使用 JAX,包括使用 NumPy 風格的陣列建立函式、Python 函式和運算子,以及陣列屬性和方法
def selu(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(5.0)
print(selu(x))
[0. 1.05 2.1 3.1499999 4.2 ]
一旦您開始深入研究,您會發現 JAX 陣列和 NumPy 陣列之間存在一些差異;這些差異在🔪 JAX - 尖銳之處 🔪中進行了探討。
使用 jax.jit()
的即時編譯#
JAX 在 GPU 或 TPU 上透明地運行(如果沒有,則會回退到 CPU)。但是,在上面的範例中,JAX 正在一次將核心分派到晶片一個操作。如果我們有一系列操作,我們可以使用 jax.jit()
函式,使用 XLA 將這一系列操作一起編譯。
我們可以使用 IPython 的 %timeit
快速基準測試我們的 selu
函式,使用 block_until_ready()
來考量 JAX 的動態分派(請參閱非同步調度)
from jax import random
key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()
3.18 ms ± 16.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(請注意,我們使用了 jax.random
來產生一些隨機數;有關如何在 JAX 中產生隨機數的詳細資訊,請查看虛擬隨機數)。
我們可以透過 jax.jit()
轉換來加速此函式的執行,這將在第一次呼叫 selu
時進行 jit 編譯,並在之後快取。
from jax import jit
selu_jit = jit(selu)
_ = selu_jit(x) # compiles on first call
%timeit selu_jit(x).block_until_ready()
851 μs ± 2.05 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
上面的計時代表在 CPU 上的執行,但相同的程式碼可以在 GPU 或 TPU 上運行,通常可以獲得更大的加速。
有關 JAX 中 JIT 編譯的更多資訊,請查看即時編譯。
使用 jax.grad()
取得導數#
除了透過 JIT 編譯轉換函式外,JAX 還提供其他轉換。jax.grad()
就是這樣一種轉換,它執行自動微分 (autodiff)
from jax import grad
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
讓我們使用有限差分法驗證我們的結果是否正確。
def first_finite_differences(f, x, eps=1E-3):
return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1965761 0.10502338]
grad()
和 jit()
轉換可以組合並且可以任意混合。在上面的範例中,我們 jitted sum_logistic
然後取其導數。我們可以更進一步
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256
除了純量值函式之外,jax.jacobian()
轉換可用於計算向量值函式的完整 Jacobian 矩陣
from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1. 0. 0. ]
[0. 2.7182817 0. ]
[0. 0. 7.389056 ]]
對於更進階的自動微分操作,您可以使用 jax.vjp()
進行反向模式向量-Jacobian 乘積,以及 jax.jvp()
和 jax.linearize()
進行正向模式 Jacobian-向量乘積。這兩者可以與彼此以及其他 JAX 轉換任意組合。例如,jax.jvp()
和 jax.vjp()
用於定義正向模式 jax.jacfwd()
和反向模式 jax.jacrev()
,用於分別以正向和反向模式計算 Jacobian 矩陣。以下是一種組合它們以建立有效計算完整 Hessian 矩陣的函式的方法
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0. -0. -0. ]
[-0. -0.09085776 -0. ]
[-0. -0. -0.07996249]]
這種組合在實務中產生有效率的程式碼;這或多或少是 JAX 內建 jax.hessian()
函式的實作方式。
有關 JAX 中自動微分的更多資訊,請查看自動微分。
使用 jax.vmap()
的自動向量化#
另一個有用的轉換是 vmap()
,即向量化映射。它具有沿陣列軸映射函式的熟悉語意,但它不是顯式迴圈處理函式呼叫,而是將函式轉換為原生向量化版本,以獲得更好的效能。當與 jit()
組合時,它可以與手動重寫函式以在額外的批次維度上運作一樣高效能。
我們將使用一個簡單的範例,並使用 vmap()
將矩陣-向量乘積提升為矩陣-矩陣乘積。雖然在這種特定情況下手動執行很容易,但相同的技術可以應用於更複雜的函式。
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))
def apply_matrix(x):
return jnp.dot(mat, x)
apply_matrix
函式將向量映射到向量,但我們可能希望在矩陣中逐列應用它。我們可以透過在 Python 中迴圈處理批次維度來做到這一點,但這通常會導致效能不佳。
def naively_batched_apply_matrix(v_batched):
return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
1.77 ms ± 2.05 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
熟悉 jnp.dot
函式的程式設計師可能會意識到可以重寫 apply_matrix
以避免顯式迴圈,方法是使用 jnp.dot
的內建批次處理語意
import numpy as np
@jit
def batched_apply_matrix(batched_x):
return jnp.dot(batched_x, mat.T)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
47.8 μs ± 1.73 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
但是,隨著函式變得越來越複雜,這種手動批次處理變得更加困難且容易出錯。vmap()
轉換旨在自動將函式轉換為批次感知版本
from jax import vmap
@jit
def vmap_batched_apply_matrix(batched_x):
return vmap(apply_matrix)(batched_x)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
49.1 μs ± 1.3 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
正如您所預期的,vmap()
可以與 jit()
、grad()
和任何其他 JAX 轉換任意組合。
有關 JAX 中自動向量化的更多資訊,請查看自動向量化。
這只是 JAX 功能的一小部分。我們很高興看到您如何使用它!