Autodiff 食譜#
JAX 具有相當通用的自動微分系統。在本筆記本中,我們將瀏覽一堆很棒的自動微分概念,您可以為自己的工作挑選使用,從基礎知識開始。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0)
梯度#
從 grad
開始#
您可以使用 grad
對函式進行微分
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
grad
接受一個函式並傳回一個函式。如果您有一個 Python 函式 f
,其評估數學函式 \(f\),則 grad(f)
是一個 Python 函式,其評估數學函式 \(\nabla f\)。這表示 grad(f)(x)
代表值 \(\nabla f(x)\)。
由於 grad
對函式進行操作,您可以將其應用於自身的輸出,以根據需要進行多次微分
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
讓我們看看如何在線性邏輯迴歸模型中使用 grad
計算梯度。首先,設定
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
將 grad
函式與其 argnums
引數搭配使用,以針對位置引數對函式進行微分。
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.43314594 -0.7354604 -1.2598921 ]
W_grad [-0.43314594 -0.7354604 -1.2598921 ]
b_grad -0.6900177
W_grad [-0.43314594 -0.7354604 -1.2598921 ]
b_grad -0.6900177
此 grad
API 與 Spivak 經典著作《流形上的微積分》(1965) 中的出色符號直接對應,也用於 Sussman 和 Wisdom 的古典力學的結構與解釋 (2015) 以及他們的泛函微分幾何 (2013)。這兩本書都是開放存取。尤其是參閱《泛函微分幾何》的「序言」章節,以了解對此符號的辯護。
基本上,當使用 argnums
引數時,如果 f
是用於評估數學函式 \(f\) 的 Python 函式,則 Python 表達式 grad(f, i)
會評估為用於評估 \(\partial_i f\) 的 Python 函式。
對巢狀列表、元組和字典進行微分#
針對標準 Python 容器進行微分即可運作,因此請隨意使用元組、列表和字典 (以及任意巢狀結構)。
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.43314594, -0.7354604 , -1.2598921 ], dtype=float32), 'b': Array(-0.6900177, dtype=float32)}
您可以註冊您自己的容器類型,使其不僅適用於 grad
,也適用於所有 JAX 轉換 (jit
、vmap
等)。
使用 value_and_grad
評估函式及其梯度#
另一個方便的函式是 value_and_grad
,用於有效率地計算函式的值及其梯度值
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 2.9729187
loss value 2.9729187
對照數值差異進行檢查#
關於導數的一大優點是,它們可以使用有限差分法直接檢查
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.6890297
b_grad_autodiff -0.6900177
W_dirderiv_numerical 1.3017654
W_dirderiv_autodiff 1.3006743
JAX 提供了一個簡單的便利函式,其基本上執行相同的操作,但會檢查您想要的任何微分階數
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
使用 grad
的 grad
計算 Hessian 向量積#
我們可以使用高階 grad
做的一件事是建立 Hessian 向量積函式。(稍後我們將撰寫一個更有效率的實作,其混合了前向模式和反向模式,但這個實作將使用純反向模式。)
Hessian 向量積函式在截斷牛頓共軛梯度演算法中可能很有用,以最小化平滑凸函式,或用於研究神經網路訓練目標的曲率 (例如 1、2、3、4)。
對於具有連續二階導數 (因此 Hessian 矩陣是對稱的) 的純量值函式 \(f : \mathbb{R}^n \to \mathbb{R}\),點 \(x \in \mathbb{R}^n\) 的 Hessian 寫為 \(\partial^2 f(x)\)。然後,Hessian 向量積函式能夠評估
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
適用於任何 \(v \in \mathbb{R}^n\)。
訣竅是不實例化完整的 Hessian 矩陣:如果 \(n\) 很大,可能在神經網路的背景下達到數百萬或數十億,那麼這可能無法儲存。
幸運的是,grad
已經提供了一種撰寫有效率的 Hessian 向量積函式的方法。我們只需要使用恆等式
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
其中 \(g(x) = \partial f(x) \cdot v\) 是一個新的純量值函式,其將 \(f\) 在 \(x\) 的梯度與向量 \(v\) 進行點積運算。請注意,我們只對向量值引數的純量值函式進行微分,這正是我們知道 grad
有效率的地方。
在 JAX 程式碼中,我們可以直接撰寫這個
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
此範例顯示您可以自由使用詞法閉包,而 JAX 永遠不會感到困擾或困惑。
一旦我們了解如何計算密集 Hessian 矩陣,我們將在下方幾個儲存格中檢查此實作。我們也將撰寫一個更好的版本,其同時使用前向模式和反向模式。
使用 jacfwd
和 jacrev
計算 Jacobian 矩陣和 Hessian 矩陣#
您可以使用 jacfwd
和 jacrev
函式計算完整的 Jacobian 矩陣
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05069415 0.1091874 0.07506633]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447992]
[ 0.00574409 -0.0193281 0.01078958]]
jacrev result, with shape (4, 3)
[[ 0.05069415 0.10918739 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
這兩個函式計算相同的值 (精確到機器數值),但在實作上有所不同:jacfwd
使用前向模式自動微分,對於「高」Jacobian 矩陣 (輸出多於輸入) 效率更高,而 jacrev
使用反向模式,對於「寬」Jacobian 矩陣 (輸入多於輸出) 效率更高。對於接近正方形的矩陣,jacfwd
可能比 jacrev
更有優勢。
您也可以將 jacfwd
和 jacrev
與容器類型搭配使用
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05069415 0.10918739 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
Jacobian from b to logits is
[0.09748875 0.16102302 0.24190766 0.00776229]
如需關於前向模式和反向模式的更多詳細資訊,以及如何盡可能有效率地實作 jacfwd
和 jacrev
,請繼續閱讀!
使用這兩個函式的組合,我們可以計算密集 Hessian 矩陣
def hessian(f):
return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02058932 0.04434624 0.03048803]
[ 0.04434623 0.09551499 0.06566654]
[ 0.03048803 0.06566655 0.04514575]]
[[-0.0743913 0.09129842 -0.01268033]
[ 0.09129842 -0.11204806 0.01556223]
[-0.01268034 0.01556223 -0.00216142]]
[[ 0.01176856 0.00135791 -0.02942139]
[ 0.00135791 0.00015668 -0.00339478]
[-0.0294214 -0.00339478 0.07355348]]
[[-0.00418412 0.014079 -0.00785936]
[ 0.014079 -0.04737393 0.02644569]
[-0.00785936 0.02644569 -0.01476286]]]
這種形狀是有道理的:如果我們從函式 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 開始,那麼在點 \(x \in \mathbb{R}^n\),我們預期會得到以下形狀
\(f(x) \in \mathbb{R}^m\),\(f\) 在 \(x\) 的值,
\(\partial f(x) \in \mathbb{R}^{m \times n}\),在 \(x\) 的 Jacobian 矩陣,
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),在 \(x\) 的 Hessian 矩陣,
依此類推。
為了實作 hessian
,我們可以使用 jacfwd(jacrev(f))
或 jacrev(jacfwd(f))
或這兩者的任何其他組合。但前向模式優於反向模式通常是最有效率的。那是因為在內部的 Jacobian 計算中,我們通常對寬 Jacobian 函式 (可能像損失函式 \(f : \mathbb{R}^n \to \mathbb{R}\)) 進行微分,而在外部的 Jacobian 計算中,我們對具有方形 Jacobian 的函式 (由於 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)) 進行微分,而這正是前向模式勝出的地方。
製作方法:兩個基礎自動微分函式#
Jacobian-向量積 (JVP,又稱前向模式自動微分)#
JAX 包含前向模式和反向模式自動微分的有效率且通用的實作。熟悉的 grad
函式建立在反向模式之上,但為了說明這兩種模式的差異,以及每種模式何時可能有用,我們需要一些數學背景知識。
數學中的 JVP#
在數學上,給定一個函式 \(f : \mathbb{R}^n \to \mathbb{R}^m\),在輸入點 \(x \in \mathbb{R}^n\) 評估的 \(f\) 的 Jacobian 矩陣 (表示為 \(\partial f(x)\)) 通常被視為 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的矩陣
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
但我們也可以將 \(\partial f(x)\) 視為線性映射,其將 \(f\) 定義域在點 \(x\) (其只是 \(\mathbb{R}^n\) 的另一個副本) 的切線空間映射到 \(f\) 值域在點 \(f(x)\) ( \(\mathbb{R}^m\) 的副本) 的切線空間
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
此映射稱為 \(f\) 在 \(x\) 的前推映射。Jacobian 矩陣只是此線性映射在標準基底中的矩陣。
如果我們不限定於一個特定的輸入點 \(x\),那麼我們可以將函式 \(\partial f\) 視為首先取得一個輸入點,然後傳回該輸入點的 Jacobian 線性映射
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
特別是,我們可以取消柯里化,以便在給定輸入點 \(x \in \mathbb{R}^n\) 和切線向量 \(v \in \mathbb{R}^n\) 的情況下,我們取回 \(\mathbb{R}^m\) 中的輸出切線向量。我們將從 \((x, v)\) 對到輸出切線向量的映射稱為 Jacobian-向量積,並將其寫為
\(\qquad (x, v) \mapsto \partial f(x) v\)
JAX 程式碼中的 JVP#
回到 Python 程式碼中,JAX 的 jvp
函式對此轉換進行建模。給定一個評估 \(f\) 的 Python 函式,JAX 的 jvp
是一種取得 Python 函式以評估 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的方法。
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
以類似 Haskell 的類型簽名而言,我們可以寫成
jvp :: (a -> b) -> a -> T a -> (b, T b)
其中我們使用 T a
來表示 a
的切線空間的類型。換句話說,jvp
接受類型為 a -> b
的函式、類型為 a
的值和類型為 T a
的切線向量值作為引數。它會傳回一對值,其中包含類型為 b
的值和類型為 T b
的輸出切線向量。
jvp
轉換後的函式的評估方式與原始函式非常相似,但它會與類型為 a
的每個原始值配對,並沿著類型為 T a
的切線值推送。對於原始函式本來會套用的每個基本數值運算,jvp
轉換後的函式會針對該基本運算執行「JVP 規則」,該規則既評估原始值上的基本運算,又將基本運算的 JVP 應用於這些原始值。
這種評估策略對計算複雜度產生了一些直接影響:由於我們在進行時評估 JVP,因此我們不需要儲存任何東西以供稍後使用,因此記憶體成本與計算的深度無關。此外,jvp
轉換後的函式的 FLOP 成本約為僅評估函式的成本的 3 倍 (一個工作單位用於評估原始函式,例如 sin(x)
;一個單位用於線性化,例如 cos(x)
;以及一個單位用於將線性化函式應用於向量,例如 cos_x * v
)。換句話說,對於固定的原始點 \(x\),我們可以評估 \(v \mapsto \partial f(x) \cdot v\),其邊際成本與評估 \(f\) 的成本大致相同。
這種記憶體複雜度聽起來相當引人注目!那麼為什麼我們在機器學習中不常看到前向模式呢?
為了回答這個問題,首先想想您如何使用 JVP 來建立完整的 Jacobian 矩陣。如果我們將 JVP 應用於 one-hot 切線向量,它會揭示 Jacobian 矩陣的一列,對應於我們輸入的非零項目。因此,我們可以一次建立一個完整的 Jacobian 矩陣列,而取得每一列的成本與一次函式評估的成本大致相同。這對於具有「高」Jacobian 矩陣的函式來說是有效率的,但對於「寬」Jacobian 矩陣來說是沒效率的。
如果您在機器學習中執行基於梯度的最佳化,您可能想要最小化從 \(\mathbb{R}^n\) 中的參數到 \(\mathbb{R}\) 中的純量損失值的損失函式。這表示此函式的 Jacobian 矩陣是一個非常寬的矩陣:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我們通常將其識別為梯度向量 \(\nabla f(x) \in \mathbb{R}^n\)。一次建立一列矩陣,每次呼叫都花費與評估原始函式相似的 FLOP 數量,這確實看起來效率不高!特別是,對於訓練神經網路,其中 \(f\) 是訓練損失函式,而 \(n\) 可能達到數百萬或數十億,這種方法根本無法擴展。
為了更好地處理此類函式,我們只需要使用反向模式。
向量-雅可比積 (VJPs,又稱反向模式自動微分)#
正向模式回傳一個用於評估雅可比-向量積的函數,我們可以接著用它來逐列建構雅可比矩陣;反向模式則是一種回傳用於評估向量-雅可比積(等價於雅可比-轉置-向量積)函數的方法,我們可以接著用它來逐行建構雅可比矩陣。
VJPs 的數學表示#
讓我們再次考慮一個函數 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。從我們對 JVPs 的符號表示開始,VJPs 的符號表示非常簡單
\(\qquad (x, v) \mapsto v \partial f(x)\),
其中 \(v\) 是 \(f\) 在 \(x\) 點的餘切空間的元素(與 \(\mathbb{R}^m\) 的另一個副本同構)。當嚴謹地看待時,我們應該將 \(v\) 視為一個線性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),而當我們寫下 \(v \partial f(x)\) 時,我們指的是函數合成 \(v \circ \partial f(x)\),其中類型是匹配的,因為 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常見情況下,我們可以將 \(v\) 視為 \(\mathbb{R}^m\) 中的向量,並且幾乎可以互換地使用這兩者,就像我們有時可能會在「行向量」和「列向量」之間切換而沒有太多註解一樣。
有了這種識別,我們可以將 VJP 的線性部分替換地視為 JVP 線性部分的轉置(或伴隨共軛)
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
對於給定點 \(x\),我們可以將簽名寫成
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
餘切空間上的對應映射通常被稱為 拉回 (pullback) \(f\) 在 \(x\) 點的值。對我們而言,關鍵在於它從看起來像 \(f\) 輸出的東西,變成了看起來像 \(f\) 輸入的東西,就像我們可能從轉置線性函數中預期的那樣。
JAX 程式碼中的 VJP#
從數學回到 Python,JAX 函數 vjp
可以接收一個用於評估 \(f\) 的 Python 函數,並回傳一個用於評估 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函數。
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
以類似 Haskell 的類型簽名而言,我們可以寫成
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
其中我們使用 CT a
來表示 a
的餘切空間的類型。換句話說,vjp
接收一個類型為 a -> b
的函數和一個類型為 a
的點作為參數,並回傳一個由類型為 b
的值和類型為 CT b -> CT a
的線性映射組成的pair。
這非常棒,因為它讓我們可以逐行建構雅可比矩陣,並且評估 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOP 成本僅約為評估 \(f\) 成本的三倍。特別是,如果我們想要函數 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我們只需呼叫一次即可完成。這就是為什麼即使對於數百萬或數十億參數的神經網路訓練損失函數等目標,grad
對於基於梯度的優化也是有效率的。
但這是有代價的:雖然 FLOPs 是友好的,但記憶體會隨著計算深度而擴展。此外,實作傳統上比正向模式更複雜,儘管 JAX 有一些訣竅(那是未來筆記本的故事!)。
有關反向模式如何運作的更多資訊,請參閱 2017 年深度學習暑期學校的這段教學影片。
使用 VJP 的向量值梯度#
如果您對取得向量值梯度(如 tf.gradients
)感興趣
from jax import vjp
def vgrad(f, x):
y, vjp_fn = vjp(f, x)
return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
[6. 6.]]
使用正向和反向模式的黑塞矩陣-向量積#
在先前的章節中,我們僅使用反向模式實作了黑塞矩陣-向量積函數(假設二階導數連續)
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
這很有效率,但我們可以透過結合使用正向模式和反向模式來做得更好並節省一些記憶體。
在數學上,給定一個要微分的函數 \(f : \mathbb{R}^n \to \mathbb{R}\)、一個函數線性化的點 \(x \in \mathbb{R}^n\) 以及一個向量 \(v \in \mathbb{R}^n\),我們想要的黑塞矩陣-向量積函數是
\((x, v) \mapsto \partial^2 f(x) v\)
考慮輔助函數 \(g : \mathbb{R}^n \to \mathbb{R}^n\),定義為 \(f\) 的導數(或梯度),即 \(g(x) = \partial f(x)\)。我們只需要它的 JVP,因為那將給我們
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
我們可以幾乎直接將其轉換為程式碼
from jax import jvp, grad
# forward-over-reverse
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
更好的是,由於我們不必直接呼叫 jnp.dot
,因此這個 hvp
函數適用於任何形狀的陣列和任意容器類型(例如儲存為巢狀列表/字典/元組的向量),甚至不依賴於 jax.numpy
。
以下是如何使用它的範例
def f(X):
return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
您可能會考慮的另一種寫法是使用反向-在-正向之上
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jvp(f, primals, tangents)[1]
return grad(g)(primals)
雖然這不是那麼好,因為正向模式的 overhead 比反向模式少,並且由於此處的外部微分運算子必須微分比內部運算子更大的計算,因此將正向模式保持在外部效果最佳
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
x, = primals
v, = tangents
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
5.36 ms ± 75.3 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
13.6 ms ± 9.57 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
19.3 ms ± 13.4 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
55.7 ms ± 2.9 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
組合 VJP、JVP 和 vmap
#
雅可比-矩陣和矩陣-雅可比積#
現在我們有了 jvp
和 vjp
轉換,它們為我們提供了每次推進或拉回單個向量的函數,我們可以利用 JAX 的 vmap
轉換一次推進和拉回整個基底。特別是,我們可以利用它來編寫快速的矩陣-雅可比和雅可比-矩陣積。
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
return jnp.vstack([vjp_fun(mi) for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
outs, = vmap(vjp_fun)(M)
return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
181 ms ± 775 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
5.91 ms ± 131 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_1261/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
# jvp immediately returns the primal and tangent values as a tuple,
# so we'll compute and select the tangents in a list comprehension
return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
_jvp = lambda s: jvp(f, (W,), (s,))[1]
return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
242 ms ± 401 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
2.92 ms ± 121 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
jacfwd
和 jacrev
的實作#
現在我們已經了解了快速的雅可比-矩陣和矩陣-雅可比積,不難猜測如何編寫 jacfwd
和 jacrev
。我們只需使用相同的技術一次推進或拉回整個標準基底(與單位矩陣同構)。
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
def jacfun(x):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
有趣的是,Autograd 無法做到這一點。我們在 Autograd 中反向模式 jacobian
的 實作必須使用外部迴圈 map
一次拉回一個向量。一次推進一個向量通過計算的效率遠低於使用 vmap
將它們全部批次處理在一起。
Autograd 無法做到的另一件事是 jit
。有趣的是,無論您在要微分的函數中使用多少 Python 動態性,我們始終可以在計算的線性部分上使用 jit
。例如
def f(x):
try:
if x < 3:
return 2 * x ** 3
else:
raise ValueError
except ValueError:
return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
複數和微分#
JAX 非常擅長處理複數和微分。為了同時支援 全純和非全純微分,從 JVP 和 VJP 的角度思考會有所幫助。
考慮一個複數到複數的函數 \(f: \mathbb{C} \to \mathbb{C}\),並將其與對應的函數 \(g: \mathbb{R}^2 \to \mathbb{R}^2\) 識別,
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def g(x, y):
return (u(x, y), v(x, y))
也就是說,我們分解了 \(f(z) = u(x, y) + v(x, y) i\),其中 \(z = x + y i\),並將 \(\mathbb{C}\) 與 \(\mathbb{R}^2\) 識別以獲得 \(g\)。
由於 \(g\) 僅涉及實數輸入和輸出,我們已經知道如何為其編寫雅可比-向量積,例如給定一個切向量 \((c, d) \in \mathbb{R}^2\),即
\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
為了獲得應用於切向量 \(c + di \in \mathbb{C}\) 的原始函數 \(f\) 的 JVP,我們只需使用相同的定義並將結果識別為另一個複數,
\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
這就是我們對 \(\mathbb{C} \to \mathbb{C}\) 函數的 JVP 的定義!請注意,\(f\) 是否為全純函數並不重要:JVP 是明確的。
這是一個檢查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# tangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_dot = c + d * 1j
# check jvp
_, ans = jvp(fun, (z,), (z_dot,))
expected = (grad(u, 0)(x, y) * c +
grad(u, 1)(x, y) * d +
grad(v, 0)(x, y) * c * 1j+
grad(v, 1)(x, y) * d * 1j)
print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True
那 VJP 呢?我們做一些非常相似的事情:對於餘切向量 \(c + di \in \mathbb{C}\),我們將 \(f\) 的 VJP 定義為
\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).
那些負號是怎麼回事?它們只是為了處理複數共軛,以及我們正在處理共向量的事實。
這是 VJP 規則的檢查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# cotangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_bar = jnp.array(c + d * 1j) # for dtype control
# check vjp
_, fun_vjp = vjp(fun, z)
ans, = fun_vjp(z_bar)
expected = (grad(u, 0)(x, y) * c +
grad(v, 0)(x, y) * (-d) +
grad(u, 1)(x, y) * c * (-1j) +
grad(v, 1)(x, y) * (-d) * (-1j))
assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
那像是 grad
、jacfwd
和 jacrev
這樣的便利包裝器呢?
對於 \(\mathbb{R} \to \mathbb{R}\) 函數,回想一下我們將 grad(f)(x)
定義為 vjp(f, x)[1](1.0)
,這之所以有效,是因為將 VJP 應用於 1.0
值會揭示梯度(即雅可比矩陣,或導數)。對於 \(\mathbb{C} \to \mathbb{R}\) 函數,我們可以做同樣的事情:我們仍然可以使用 1.0
作為餘切向量,我們只會得到一個複數結果,總結了完整的雅可比矩陣
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return x**2 + y**2
z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
對於一般的 \(\mathbb{C} \to \mathbb{C}\) 函數,雅可比矩陣具有 4 個實數值的自由度(如上面的 2x2 雅可比矩陣所示),因此我們不能期望在一個複數中表示所有這些自由度。但對於全純函數,我們可以!全純函數恰好是一個 \(\mathbb{C} \to \mathbb{C}\) 函數,其導數可以表示為單個複數。(柯西-黎曼方程式確保上面的 2x2 雅可比矩陣具有複平面中縮放和旋轉矩陣的特殊形式,即單個複數在乘法下的作用。)我們可以通過使用 1.0
的共向量單次呼叫 vjp
來揭示該複數。
由於這僅適用於全純函數,為了使用此技巧,我們需要向 JAX 保證我們的函數是全純的;否則,當 grad
用於複數輸出函數時,JAX 會引發錯誤
def f(z):
return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)
所有 holomorphic=True
保證所做的只是在輸出為複數值時停用錯誤。當函數不是全純函數時,我們仍然可以寫入 holomorphic=True
,但我們得到的答案不會代表完整的雅可比矩陣。相反,它將是函數的雅可比矩陣,我們只會丟棄輸出的虛部
def f(z):
return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)
對於 grad
在此處的工作方式,有一些有用的結果
我們可以在全純 \(\mathbb{C} \to \mathbb{C}\) 函數上使用
grad
。我們可以通過朝著
grad(f)(x)
的共軛方向邁進,來使用grad
優化 \(f : \mathbb{C} \to \mathbb{R}\) 函數,例如複數參數x
的實值損失函數。如果我們有一個 \(\mathbb{R} \to \mathbb{R}\) 函數,它碰巧在內部使用了一些複數值運算(其中一些必須是非全純的,例如卷積中使用的 FFT),那麼
grad
仍然有效,並且我們得到的結果與僅使用實數值的實作所給出的結果相同。
在任何情況下,JVP 和 VJP 始終是明確的。如果我們想計算非全純 \(\mathbb{C} \to \mathbb{C}\) 函數的完整雅可比矩陣,我們可以使用 JVP 或 VJP 來完成!
您應該期望複數在 JAX 中的任何地方都能運作。這是通過複數矩陣的 Cholesky 分解進行微分
A = jnp.array([[5., 2.+3j, 5j],
[2.-3j, 7., 1.+7j],
[-5j, 1.-7j, 12.]])
def f(X):
L = jnp.linalg.cholesky(X)
return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A)
Array([[-0.7534186 +0.j , -3.0509028 -10.940544j ,
5.9896846 +3.5423026j],
[-3.0509028 +10.940544j , -8.904491 +0.j ,
-5.1351523 -6.559373j ],
[ 5.9896846 -3.5423026j, -5.1351523 +6.559373j ,
0.01320427 +0.j ]], dtype=complex64)
更進階的自動微分#
在本筆記本中,我們逐步完成了 JAX 中自動微分的一些簡單,然後逐漸複雜的應用。我們希望您現在覺得在 JAX 中求導數既容易又強大。
還有一個自動微分技巧和功能的完整世界。我們沒有涵蓋的主題,但希望在「進階自動微分食譜」中涵蓋的主題包括
高斯-牛頓向量積,線性化一次
自訂 VJP 和 JVP
固定點的有效導數
使用隨機黑塞矩陣-向量積估計黑塞矩陣的跡。
僅使用反向模式自動微分的正向模式自動微分。
對自訂資料類型求導數。
檢查點 (用於有效反向模式的二項式檢查點,而非模型快照)。
透過雅可比矩陣預先累積來優化 VJP。