使用 jax.checkpoint
(jax.remat
) 的梯度檢查點#
在本教學中,您將學習如何使用 jax.checkpoint()
(也稱為 jax.remat()
) 控制 JAX 自動微分的已儲存值,這在機器學習中特別有用。
如果您是自動微分 (autodiff) 的新手,或需要複習記憶,JAX 提供了自動微分和進階自動微分教學。
重點摘要 使用 jax.checkpoint()
裝飾器 (別名為 jax.remat()
) 與 jax.grad()
搭配使用,以控制前向傳遞中儲存哪些中間值,以及在反向傳遞中重新計算哪些中間值,從而在記憶體和 FLOPs 之間進行權衡。
如果您不使用 jax.checkpoint()
,則 jax.grad(f)(x)
前向傳遞會儲存 Jacobian 係數和其他中間值,以在反向傳遞期間使用。這些儲存的值稱為殘差。
注意: 請務必查看實務注意事項,以了解 jax.checkpoint()
如何與 jax.jit()
互動。
import jax
import jax.numpy as jnp
def g(W, x):
y = jnp.dot(W, x)
return jnp.sin(y)
def f(W1, W2, W3, x):
x = g(W1, x)
x = g(W2, x)
x = g(W3, x)
return x
W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)
# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[5] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[7] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)
透過將 jax.checkpoint()
應用於子函數 (作為裝飾器或在特定應用位置),您可以強制 JAX 不儲存該子函數的任何殘差。相反地,可能只會儲存 jax.checkpoint()
裝飾函數的輸入,而反向傳遞中使用的任何殘差都會根據需要從這些輸入重新計算。
def f2(W1, W2, W3, x):
x = jax.checkpoint(g)(W1, x)
x = jax.checkpoint(g)(W2, x)
x = jax.checkpoint(g)(W3, x)
return x
jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
在此範例中,儲存了兩個 sin
應用程式的值,因為它們是後續 jax.checkpoint()
裝飾的 g
函數的應用程式中的引數,且可能會儲存 jax.checkpoint()
裝飾函數的輸入。但不會儲存任何 cos
應用程式的值。
若要控制哪些值可儲存,而無需編輯要微分的函數定義,您可以使用重新實體化策略。以下範例僅儲存沒有批次維度的 dot
運算的結果 (因為它們通常受 FLOP 限制,因此值得儲存而不是重新計算)。
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)
您也可以使用策略來參考您使用 jax.ad_checkpoint.checkpoint_name()
命名的中間值。
from jax.ad_checkpoint import checkpoint_name
def f4(W1, W2, W3, x):
x = checkpoint_name(g(W1, x), name='a')
x = checkpoint_name(g(W2, x), name='b')
x = checkpoint_name(g(W3, x), name='c')
return x
f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1031/2296542172.py:4 (f4)
在試用這些玩具範例時,您可以使用此筆記本中定義的自訂 print_fwd_bwd
實用程式,更仔細地查看正在發生的情況。
from jax.tree_util import tree_flatten, tree_unflatten
from rich.console import Console
from rich.table import Table
import rich.text
def print_fwd_bwd(f, *args, **kwargs) -> None:
args, in_tree = tree_flatten((args, kwargs))
def f_(*args):
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)
fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr
y, f_vjp = jax.vjp(f_, *args)
res, in_tree = tree_flatten(f_vjp)
def g_(*args):
*res, y = args
f_vjp = tree_unflatten(in_tree, res)
return f_vjp(y)
bwd = jax.make_jaxpr(g_)(*res, y).jaxpr
table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
table.add_row("[bold green]forward computation:",
"[bold green]backward computation:")
table.add_row(rich.text.Text.from_ansi(str(fwd)),
rich.text.Text.from_ansi(str(bwd)))
console = Console(width=240, force_jupyter=True)
console.print(table)
def _renderable_repr(self):
return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
forward computation: backward computation: { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[4] b:f32[5,4] c:f32[5] d:f32[5] e:f32[6,5] f:f32[6] g:f32[6] h:f32[7,6] e:f32[5] = dot_general[ i:f32[7] j:f32[7]. let dimension_numbers=(([1], [0]), ([], [])) k:f32[7] = mul j i preferred_element_type=float32 l:f32[6] = dot_general[ ] a d dimension_numbers=(([0], [0]), ([], [])) f:f32[5] = sin e preferred_element_type=float32 g:f32[5] = cos e ] k h h:f32[6] = dot_general[ m:f32[7,6] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 preferred_element_type=float32 ] b f ] k g i:f32[6] = sin h n:f32[6] = mul l f j:f32[6] = cos h o:f32[5] = dot_general[ k:f32[7] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 preferred_element_type=float32 ] n e ] c i p:f32[6,5] = dot_general[ l:f32[7] = sin k dimension_numbers=(([], []), ([], [])) m:f32[7] = cos k preferred_element_type=float32 in (l, d, a, g, f, b, j, i, c, m) } ] n d q:f32[5] = mul o c r:f32[4] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 ] q b s:f32[5,4] = dot_general[ dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 ] q a in (s, p, m, r) }
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
forward computation: backward computation: { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let e:f32[5] = dot_general[ i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[ dimension_numbers=(([1], [0]), ([], [])) differentiated=True preferred_element_type=float32 jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6] ] a d s:f32[4] t:f32[7]. let f:f32[5] = reduce_precision[exponent_bits=8 mantissa_bits=23] e u:f32[5] = sin m g:f32[5] = sin f v:f32[5] = cos m h:f32[6] = dot_general[ w:f32[6] = sin n dimension_numbers=(([1], [0]), ([], [])) x:f32[6] = cos n preferred_element_type=float32 y:f32[7] = cos o ] b g z:f32[7] = mul t y i:f32[6] = reduce_precision[exponent_bits=8 mantissa_bits=23] h ba:f32[6] = dot_general[ j:f32[6] = sin i dimension_numbers=(([0], [0]), ([], [])) k:f32[7] = dot_general[ preferred_element_type=float32 dimension_numbers=(([1], [0]), ([], [])) ] z r preferred_element_type=float32 bb:f32[6] = mul ba x ] c j bc:f32[5] = dot_general[ l:f32[7] = reduce_precision[exponent_bits=8 mantissa_bits=23] k dimension_numbers=(([0], [0]), ([], [])) m:f32[7] = sin l preferred_element_type=float32 in (m, f, i, l, a, b, c, d) } ] bb q bd:f32[5] = mul bc v be:f32[4] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 ] bd p bf:f32[5,4] = dot_general[ dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 ] bd s bg:f32[6,5] = dot_general[ dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 ] bb u bh:f32[7,6] = dot_general[ dimension_numbers=(([], []), ([], [])) preferred_element_type=float32 ] z w in (bf, bg, bh, be) } policy=<function dot_with_no_batch_dims_saveable at 0x7f6ca10ebc70> prevent_cse=True ] a b c d e f g h in (i, j, k, l) }
讓我們逐步思考#
注意: 在繼續之前,查看進階自動微分教學可能會有所幫助。
jax.checkpoint
基礎知識#
在 jax.linearize()
和 jax.vjp()
中,值的計算方式和時間點都具有彈性。不同的選擇可以在記憶體使用量和 FLOPs 之間進行權衡。JAX 透過 jax.checkpoint()
提供對這些選擇的控制。
其中一個選擇是在前向傳遞中 (在輸入可用時立即) 執行 Jacobian 係數計算,還是在反向傳遞中 (在需要係數之前) 執行。考慮 sin_vjp
的範例:
def sin_vjp(x):
y = jnp.sin(x)
cos_x = jnp.cos(x)
return y, lambda y_bar: cos_x * y_bar
另一個有效的實作會在反向傳遞而不是前向傳遞中計算 jnp.cos(x)
的值:
def sin_vjp2(x):
y = jnp.sin(x)
return y, lambda y_bar: jnp.cos(x) * y_bar
對於這個特定函數,兩個版本使用的記憶體量相同,儘管您減少了原始計算 (前向傳遞) 的 FLOPs,並增加了共切線計算 (反向傳遞) 的 FLOPs。
在函數組合方面還有另一個選擇。回想一下兩個函數組合的 VJP 規則:
def f(x):
y = g(x)
z = h(y)
return z
def f_vjp(x):
y, g_vjp = jax.vjp(g, x)
z, h_vjp = jax.vjp(h, y)
def f_bwd(z_bar):
y_bar, = h_vjp(z_bar)
x_bar, = g_vjp(y_bar)
return x_bar
return z, f_bwd
另一種選擇是:
def f_vjp_checkpoint(x):
y = g(x)
z, h_vjp = jax.vjp(h, y)
def f_bwd2(z_bar):
y_bar, = h_vjp(z_bar)
_, g_vjp = jax.vjp(g, x)
x_bar, = g_vjp(y_bar)
return x_bar
return z, f_bwd2
用文字來說,這種替代實作不會在前向傳遞中計算 g_vjp
或其閉包中的殘差值。相反地,它只會在反向傳遞 f_bwd2
中計算它們。這表示 f_vjp_checkpoint
需要更少的記憶體:如果 g
和 h
各自的殘差需要相似的記憶體量,且都遠大於 x
,則 f_vjp_checkpoint(x)
產生的函數所需的記憶體是 f_vjp(x)
的一半!
您付出的代價是多餘的工作:在 f_bwd2
中,您必須重新評估 g(x)
作為 jax.vjp(g, x)
的一部分,只是為了捨棄其值 (在 _, g_vjp = jax.vjp(g, x)
行的底線變數中)。
您可以在自動微分中獲得此 VJP 行為,而無需直接撰寫 VJP 函數,而是透過在原始函數 f
的替代定義中使用 jax.checkpoint()
:
def f_checkpoint(x):
y = jax.checkpoint(g)(x)
z = h(y)
return z
換句話說,您可以將 jax.checkpoint()
應用於 g
( f
的第一階段),而不是應用於 f
本身。這樣,當您評估 jax.grad(f_checkpoint)(x)
時,您會得到類似以下的計算:
執行
g
的前向傳遞,捨棄殘差值。執行
h
的前向傳遞,儲存殘差。執行
h
的反向傳遞,使用步驟 2 的殘差。重新執行
g
的前向傳遞,儲存殘差。執行
g
的反向傳遞,使用步驟 4 的殘差。
也就是說,透過評估 jax.grad(f_checkpoint)(x)
,我們會得到與以下計算相同的結果:
def f_checkpoint_grad(x):
y = g(x) # step 1
_, h_vjp = jax.vjp(h)(y) # step 2
y_bar, = h_vjp(1.0) # step 3
_, g_vjp = jax.vjp(g, x) # step 4
x_bar, = g_vjp(y_bar) # step 5
return x_bar
一般而言,jax.checkpoint(foo)
是一個新函數,其輸入-輸出行為與 foo
相同,但在自動微分下 (特別是在 jax.linearize()
和 jax.vjp()
(及其包裝器,例如 jax.grad()
),但非 jax.jvp()
下) 行為不同。當進行微分時,只會在前向傳遞中儲存 jax.checkpoint()
微分函數的輸入。在反向傳遞中,會重新計算殘差 (foo
的中間值及其反向傳遞所需的 Jacobian 係數值)。
請注意,如果 f = lambda x: h(g(x))
是您要微分的函數 (換句話說,如果您想要應用 jax.grad(f)
),則透過將 jax.checkpoint()
應用於 f
本身,您不會獲得任何記憶體節省。這是因為評估 jax.grad(jax.checkpoint(f))(x)
會導致如下的計算:
執行前向傳遞,捨棄所有殘差。
立即重新執行前向傳遞,儲存殘差。
執行反向傳遞,使用步驟 2 的殘差。
在程式碼中,您會得到類似以下的內容:
def f_grad_bad(x):
_ = f(x) # step 1
_, f_vjp = jax.vjp(f, x) # step 2
x_bar, = f_vjp(1.0) # step 3
return x_bar
透過將 jax.checkpoint()
應用於 h
( f
的第二階段),您也不會獲得任何記憶體節省。這是因為評估 jax.grad(lambda x: jax.checkpoint(h)(g(x)))
會導致如下的計算:
執行
g
的前向傳遞,儲存殘差。執行
h
的前向傳遞,捨棄殘差。立即重新執行
h
的前向傳遞,儲存殘差。執行
h
的反向傳遞,使用步驟 3 的殘差。執行
g
的反向傳遞,使用步驟 1 的殘差。
在程式碼中,您會得到類似以下的內容:
def f_grad_bad2(x):
y, g_vjp = jax.vjp(g, x) # step 1
z = h(y) # step 2
_, h_vjp = jax.vjp(h, y) # step 3
y_bar, = h_vjp(1.0) # step 3
x_bar, = g_vjp(y_bar) # step 5
return x_bar
更廣泛來說,如果您有一個函數鏈組合,例如 f = lambda x: f3(f2(f1(x)))
,並且對評估 jax.grad(f)
感興趣,您可以說您:
不應將
jax.checkpoint()
應用於整個函數f
,因為這樣不會節省任何記憶體 (並且會執行浪費的重新計算)。不應將
jax.checkpoint()
應用於最後一個子函數f3
,因為這樣不會節省任何記憶體 (並且會執行浪費的重新計算)。可以將
jax.checkpoint()
應用於f1
、f2
或其組合lambda x: f2(f1(x))
,因為這些都可能節省記憶體,並且會表達不同的記憶體/重新計算權衡。
可儲存內容的自訂策略#
如目前所示,使用 jax.checkpoint()
會在一個極端到另一個極端之間切換:
在沒有
jax.checkpoint()
的情況下,JAX 的自動微分傾向於在前向傳遞中計算所有可能的內容,並將其儲存以供反向傳遞使用。使用
jax.checkpoint()
裝飾器,您可以在前向傳遞中計算盡可能少的內容,並在反向傳遞中根據需要重新計算值。
若要在這兩個極端之間操作 (儲存某些內容,而不儲存其他內容),您可以仔細地將 jax.checkpoint()
裝飾器放置在子函數上。但這需要編輯要微分的函數 (例如模型程式碼),這可能很不方便。也很難實驗變體。
因此,替代方案是使用 jax.checkpoint()
的 policy
引數。策略是一個可呼叫物件 (即函數),它將一階 primitive 應用程式的型別層級規格作為輸入,並傳回一個布林值,指示是否允許將對應的輸出值儲存為殘差 (或者是否必須在 (共)切線計算中根據需要重新計算)。為了撰寫穩健的程式碼,應從 jax.checkpoint_policies()
上的屬性中選取策略,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable()
,因為撰寫自訂策略可呼叫物件的 API 被視為內部 API。
例如,考慮要微分的此函數:
def loss(params, x, y):
return jnp.sum((predict(params, x) - y)**2)
def predict(params, x):
*Ws, Wlast = params
for W in Ws:
x = layer(W, x)
x = jnp.dot(Wlast, x)
return x
def layer(W, x):
return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of sin from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of mul from /tmp/ipykernel_1031/4230705069.py:2 (loss)
您可能只想儲存沒有批次維度的矩陣乘法的結果 (因為它們可能是受 FLOP 限制而不是受記憶體限制),而不是在前向傳遞中儲存這麼多值。您可以使用策略 jax.checkpoint_policies.dots_with_no_batch_dims_saveable()
來執行此操作:
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:8 (predict)
另請注意,透過提供策略,您不需要編輯定義 loss
、predict
或 layer
的程式碼。如果您想要在呼叫程式碼 (例如訓練腳本) 中實驗策略,而無需變更程式庫程式碼 (例如神經網路程式庫),這尤其方便。
某些策略可以參考使用 jax.ad_checkpoint.checkpoint_name()
命名的值。
from jax.ad_checkpoint import checkpoint_name
def predict(params, x):
*Ws, Wlast = params
for i, W in enumerate(Ws):
x = layer(W, x)
x = checkpoint_name(x, name=f'layer{i}_output')
x = jnp.dot(Wlast, x)
return x
就其本身而言,jax.ad_checkpoint import.checkpoint_name()
只是一個恆等函數。但由於某些策略函數知道要尋找它們,因此您可以使用名稱來控制 jax.ad_checkpoint import.checkpoint_name()
輸出的某些值是否被視為可儲存。
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_1031/178264713.py:7 (predict)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_1031/178264713.py:7 (predict)
f32[4] output of mul from /tmp/ipykernel_1031/4230705069.py:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
另一個參考名稱的策略是 jax.checkpoint_policies.save_only_these_names
。
卸載的自訂策略#
當進行檢查點以節省加速器記憶體時,您可以考慮卸載到 CPU 記憶體而不是重新計算。jax.checkpoint_policies.offload_dot_with_no_batch_dims
可以將沒有批次維度的矩陣乘法的結果卸載到 CPU。
from jax.ad_checkpoint import checkpoint
def checkpoint_offload_dot_with_no_batch_dims(self):
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
"device", "pinned_host")
@functools.partial(checkpoint, policy=policy)
def f(x):
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
x = jnp.sum(x)
return x
JAX 的檢查點策略之一允許將指定的檢查點名稱卸載到 CPU。此策略透過 jax.checkpoint_policies.save_and_offload_only_these_names
實作,此函式有四個參數:names_which_can_be_saved
、names_which_can_be_offloaded
、卸載來源和目的地。列在 names_which_can_be_saved
中的名稱會保留在裝置上,列在 names_which_can_be_offloaded
中的名稱會移至 CPU 記憶體,而其他未命名的名稱或運算則會重新計算。例如,如果我們有檢查點名稱 y
、z
和 w
,y
可以保存在裝置上,z
可以卸載到 CPU 記憶體,而 w
則可以重新計算。
from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax._src import test_util as jtu
def checkpoint_names_saved_offloaded_recomputed(self):
mesh = jtu.create_mesh((2,), ("x",))
shape = (256, 128)
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
s = NamedSharding(mesh, P("x"))
inp = jax.device_put(np_inp, s)
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
offload_src='device', offload_dst='pinned_host')
@functools.partial(checkpoint, policy=policy)
def f(x):
def g(ys, _):
y, _ = ys
y = checkpoint_name(jnp.sin(y), "y")
z = checkpoint_name(jnp.sin(y), "z")
z = z.T
w = checkpoint_name(jnp.sin(z), "w")
return (w.T, jnp.sum(w)), None
_, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
return scan_out
此程式碼定義了一個函式 f
,其應用了具有自訂策略的檢查點。此策略決定了在執行期間哪些計算可以被保存或卸載。f
內部有一個巢狀函式 g
,用於執行核心計算。jax.lax.scan
函式用於對輸入資料重複應用 g
。
策略列表#
策略如下:
everything_saveable
(預設策略,如同完全未使用jax.checkpoint
一樣)nothing_saveable
(即重新實體化所有內容,如同完全未使用自訂策略一樣)dots_saveable
或其別名checkpoint_dots
dots_with_no_batch_dims_saveable
或其別名checkpoint_dots_with_no_batch_dims
save_anything_but_these_names
(保存除了具有任何給定名稱的checkpoint_name
輸出之外的所有值)save_any_names_but_these
(僅保存具名值,即任何checkpoint_name
的輸出,除了具有給定名稱的輸出之外)save_only_these_names
(僅保存具名值,且僅限於給定的名稱)offload_dot_with_no_batch_dims
與dots_with_no_batch_dims_saveable
相同,但卸載到 CPU 記憶體而不是重新計算。save_and_offload_only_these_names
與save_only_these_names
相同,但卸載到 CPU 記憶體而不是重新計算。save_from_both_policies(policy_1, policy_2)
(類似於邏輯or
,因此如果根據policy_1
*或*policy_2
可保存,則殘差即可保存)
策略僅指示哪些是可保存的;只有當反向傳遞實際需要某個值時,才會保存該值。
進階:遞迴 jax.checkpoint
#
透過正確地應用 jax.checkpoint()
,可以在記憶體使用量和(重新)計算之間進行許多權衡。一個令人驚訝的例子是*遞迴*檢查點,您將 jax.checkpoint()
應用於一個函式,該函式本身以某種方式呼叫 jax.checkpoint()
裝飾的函式,使得來自 \(D\) 個函式鏈式組合的記憶體使用量以 \(\mathcal{O}(\log_2 D)\) 而非 \(\mathcal{O}(D)\) 的規模增長。
作為一個玩具範例,考慮多個 jax.numpy.sin()
函式的鏈式組合
def chain_compose(funs):
def f(x):
for fun in funs:
x = fun(x)
return x
return f
f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
一般來說,儲存的殘差數量與鏈的長度成線性比例
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
但是您可以遞迴地應用 jax.checkpoint()
以改善規模增長
def recursive_checkpoint(funs):
if len(funs) == 1:
return funs[0]
elif len(funs) == 2:
f1, f2 = funs
return lambda x: f1(f2(x))
else:
f1 = recursive_checkpoint(funs[:len(funs)//2])
f2 = recursive_checkpoint(funs[len(funs)//2:])
return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
這裡的成本,與往常一樣,是重新計算:特別是,您最終會執行 \(\mathcal{O}(\log_2 D)\) 倍的 FLOPs
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
forward computation: backward computation: { lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let b:f32[] = sin a j:f32[] = mul i h c:f32[] = cos a k:f32[] = mul j g d:f32[] = sin b l:f32[] = mul k f e:f32[] = cos b m:f32[] = mul l e f:f32[] = sin d n:f32[] = mul m d g:f32[] = cos d o:f32[] = mul n c h:f32[] = sin f p:f32[] = mul o b i:f32[] = cos f q:f32[] = mul p a j:f32[] = sin h in (q,) } k:f32[] = cos h l:f32[] = sin j m:f32[] = cos j n:f32[] = sin l o:f32[] = cos l p:f32[] = sin n q:f32[] = cos n in (p, c, e, g, i, k, m, o, q) }
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
forward computation: backward computation: { lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let b:f32[] = remat2[ e:f32[] = mul d c differentiated=False f:f32[] = mul e b jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) } g:f32[] = remat2[ policy=None differentiated=True prevent_cse=True jaxpr={ lambda ; h:f32[] i:f32[]. let ] a j:f32[] = sin h f:f32[] = sin b k:f32[] = cos h g:f32[] = sin f l:f32[] = cos j h:f32[] = sin g m:f32[] = mul i l i:f32[] = sin h n:f32[] = mul m k j:f32[] = sin i in (n,) } k:f32[] = cos i policy=None l:f32[] = sin j prevent_cse=True m:f32[] = cos j ] a f in (l, g, a, k, m) } o:f32[] = remat2[ differentiated=True jaxpr={ lambda ; p:f32[] q:f32[]. let r:f32[] = sin p s:f32[] = sin r t:f32[] = sin s u:f32[] = cos s v:f32[] = cos t w:f32[] = mul q v x:f32[] = mul w u y:f32[] = remat2[ differentiated=True jaxpr={ lambda ; z:f32[] ba:f32[]. let bb:f32[] = sin z bc:f32[] = cos z bd:f32[] = cos bb be:f32[] = mul ba bd bf:f32[] = mul be bc in (bf,) } policy=None prevent_cse=True ] p x in (y,) } policy=None prevent_cse=True ] 3.0 g in (o,) }
實務注意事項#
當微分函式被暫存輸出到 XLA 以進行編譯時 — 例如透過將 jax.jit()
應用於包含 jax.grad()
呼叫的函式 — XLA 將自動最佳化計算,包括何時計算或重新實體化值的決策。因此,**對於 jax.jit()
下的微分函式,通常不需要 jax.checkpoint()
**。XLA 會為您最佳化一切。
一個例外是當使用暫存輸出的控制流程時,例如 jax.lax.scan()
。跨多個控制流程原語(例如,跨越前向傳遞 scan
和相應的反向傳遞 scan
)的自動編譯器最佳化通常不夠徹底。因此,通常最好在傳遞給 jax.lax.scan()
的主体函式上使用 jax.checkpoint()
。
例如,大型 Transformer 模型 中的一個常見模式是將架構表示為跨層的 jax.lax.scan()
,以減少編譯時間。也就是說,以簡單的全連接網路作為類比,而不是像這樣編寫
LayerParam = tuple[jnp.ndarray, jnp.ndarray] # Weights-bias pair for a layer.
ParamsList = list[LayerParam]
def net(params: ParamsList, x: jnp.ndarray):
for W, b in params:
x = jnp.maximum(jnp.dot(x, W) + b, 0.)
return x
而是使用 jax.lax.scan()
迭代層應用
params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])),
(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]
all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])
def layer(x, W_b_pair):
W, b = W_b_pair
out = jnp.maximum(jnp.dot(x, W) + b, 0.)
return out, None
def net(all_weights, all_biases, x):
x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
return x
這種跨層掃描版本減少了編譯時間,但由於阻礙了一些編譯器最佳化,可能會導致梯度計算效率低下。為了緩解這個問題,您可以在掃描函式上使用 jax.checkpoint()
from functools import partial
@partial(jax.checkpoint,
policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
W, b = W_b_pair
out = jnp.maximum(jnp.dot(x, W) + b, 0.)
return out, None
透過以這種方式使用 jax.checkpoint()
,您可以手動控制 JAX 的自動微分在正向和反向傳遞之間保存哪些值,因此不依賴 XLA 最佳化為您選擇。