使用 jax.checkpoint (又名 jax.remat) 控制 autodiff 的儲存值#

import jax
import jax.numpy as jnp

摘要#

使用 jax.checkpoint 裝飾器 (別名為 jax.remat) 與 jax.grad 來控制在前向傳遞中儲存哪些中間值,以及在反向傳遞中重新計算哪些中間值,以權衡記憶體和 FLOPs。

請勿錯過實用注意事項,以了解 jax.checkpoint 如何與 jax.jit 互動的討論。

在不使用 jax.checkpoint 的情況下,jax.grad(f)(x) 的前向傳遞會儲存 Jacobian 係數和其他中間值的值,以供反向傳遞使用。我們將這些儲存的值稱為殘差

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)

透過將 jax.checkpoint 應用於子函數,作為裝飾器或在特定應用位置,我們強制 JAX 不儲存該子函數的任何殘差。相反地,只有 jax.checkpoint 裝飾函數的輸入可能會被儲存,並且在反向傳遞中消耗的任何殘差都會根據需要從這些輸入重新計算

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)

此處儲存了兩個 sin 應用程式的值,因為它們是後續 jax.checkpoint 裝飾的 g 函數的應用程式中的參數,並且 jax.checkpoint 裝飾函數的輸入可能會被儲存。但是沒有儲存 cos 應用程式的值。

為了控制哪些值是可儲存的,而無需編輯要微分的函數的定義,您可以使用重新實體化政策。以下範例僅儲存沒有批次維度的 dot 運算的結果 (因為它們通常受 FLOP 限制,因此值得儲存而不是重新計算)

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)

您也可以使用政策來參考您使用 jax.ad_checkpoint.checkpoint_name 命名的中間值

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)

在玩這些玩具範例時,我們可以更仔細地查看使用此筆記本中定義的 print_fwd_bwd 工具發生了什麼事

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                                      
  forward computation:                                                        backward computation:                                                                   
                                                                                                                                                                      
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let                   { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4]  
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d        i:f32[5,4] j:f32[7]. let                                                            
      f:f32[5] = sin e                                                            k:f32[7] = mul j a                                                                  
      g:f32[5] = cos e                                                            l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c                
      h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f        m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b                
      i:f32[6] = sin h                                                            n:f32[6] = mul l d                                                                  
      j:f32[6] = cos h                                                            o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f                
      k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i        p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e                
      l:f32[7] = sin k                                                            q:f32[5] = mul o g                                                                  
      m:f32[7] = cos k                                                            r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i                
    in (l, m, i, c, j, f, b, g, d, a) }                                           s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h                
                                                                                in (s, p, m, r) }                                                                     
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                             
  forward computation:                                                        backward computation:                                                                          
                                                                                                                                                                             
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let                   { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d        i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
      f:f32[5] = sin e                                                              differentiated=True                                                                      
      g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f          jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      h:f32[6] = sin g                                                                  s:f32[4] t:f32[7]. let                                                               
      i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h              u:f32[5] = sin m                                                                     
      j:f32[7] = sin i                                                                  v:f32[5] = cos m                                                                     
    in (j, e, g, i, a, b, c, d) }                                                       w:f32[6] = sin n                                                                     
                                                                                        x:f32[6] = cos n                                                                     
                                                                                        y:f32[7] = cos o                                                                     
                                                                                        z:f32[7] = mul t y                                                                   
                                                                                        ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r                
                                                                                        bb:f32[6] = mul ba x                                                                 
                                                                                        bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q               
                                                                                        bd:f32[5] = mul bc v                                                                 
                                                                                        be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p               
                                                                                        bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s               
                                                                                        bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u               
                                                                                        bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w                
                                                                                      in (bf, bg, bh, be) }                                                                  
                                                                                    policy=<function dot_with_no_batch_dims at 0x7f5e469b1700>                               
                                                                                    prevent_cse=True                                                                         
                                                                                  ] a b c d e f g h                                                                          
                                                                                in (i, j, k, l) }                                                                            

讓我們逐步思考#

您可能想先 (重新) 閱讀自動微分食譜第 1 部分

jax.checkpoint 的基本原理#

jax.linearizejax.vjp 中,某些值的計算方式和時間具有彈性。不同的選擇可以在記憶體使用量和 FLOPs 之間權衡。JAX 透過 jax.checkpoint 提供對這些選擇的控制。

其中一個選擇是在前向傳遞中 (在輸入可用時立即) 或在反向傳遞中 (在需要係數之前) 執行 Jacobian 係數計算。考慮 sin_vjp 的範例

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

另一個有效的實作是在反向傳遞而不是在前向傳遞中計算 jnp.cos(x) 的值

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

對於這個特定的函數,兩個版本使用的記憶體量相同,儘管我們減少了原始計算 (即前向傳遞) 的 FLOPs,並增加了餘切計算 (即反向傳遞) 的 FLOPs。

在函數組合方面還有另一個選擇。回想一下我們兩個函數組合的 VJP 規則

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

另一種選擇是

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

換句話說,這種替代實作不會在前向傳遞中計算 g_vjp 或其閉包中的殘差值。相反地,它僅在反向傳遞 f_bwd2 中計算它們。這表示 f_vjp_checkpoint 需要更少的記憶體:如果 gh 各自的殘差需要相似的記憶體量,並且都遠大於 x,那麼 f_vjp_checkpoint(x) 產生的函數需要的記憶體是 f_vjp(x) 的一半!

我們付出的代價是冗餘工作:在 f_bwd2 中,我們必須重新評估 g(x) 作為 jax.vjp(g, x) 的一部分,只是為了丟棄它的值 (在 _, g_vjp = jax.vjp(g, x) 行上的底線變數中)。

我們可以透過在原始函數 f 的替代定義中使用 jax.checkpoint,在自動微分中獲得此 VJP 行為,而無需直接撰寫 VJP 函數

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

換句話說,我們將 jax.checkpoint 應用於 g,即 f 的第一階段,而不是 f 本身。這樣,當我們評估 jax.grad(f_checkpoint)(x) 時,我們會得到類似以下的計算

  1. 執行 g 的前向傳遞,丟棄殘差值;

  2. 執行 h 的前向傳遞,儲存殘差;

  3. 執行 h 的反向傳遞,消耗步驟 2 的殘差;

  4. 重新執行 g 的前向傳遞,儲存殘差;

  5. 執行 g 的反向傳遞,消耗步驟 4 的殘差。

也就是說,透過評估 jax.grad(f_checkpoint)(x),我們會得到與以下相同的計算

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h)(y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

一般來說,jax.checkpoint(foo) 是一個新函數,它具有與 foo 相同的輸入輸出行為,但在自動微分下 (尤其是在 jax.linearizejax.vjp (及其包裝器,如 jax.grad) 下) 行為不同,但在 jax.jvp 下則不然。當微分時,只有 jax.checkpoint 微分函數的輸入會儲存在前向傳遞中;在反向傳遞中,殘差 (即 foo 的中間值及其反向傳遞所需的 Jacobian 係數值) 會被重新計算。

請注意,如果 f = lambda x: h(g(x)) 是我們要微分的函數,也就是說,如果我們要應用 jax.grad(f),則將 jax.checkpoint 應用於 f 本身不會節省任何記憶體。這是因為評估 jax.grad(jax.checkpoint(f))(x) 會導致類似以下的計算

  1. 執行前向傳遞,丟棄所有殘差;

  2. 立即重新執行前向傳遞,儲存殘差;

  3. 執行反向傳遞,消耗步驟 2 的殘差。

也就是說,在程式碼中,我們會得到類似以下的內容

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar

jax.checkpoint 應用於 f 的第二階段 h 也無法節省任何記憶體。這是因為評估 jax.grad(lambda x: jax.checkpoint(h)(g(x))) 會導致類似以下的計算

  1. 執行 g 的前向傳遞,儲存殘差;

  2. 執行 h 的前向傳遞,丟棄殘差;

  3. 立即重新執行 h 的前向傳遞,儲存殘差;

  4. 執行 h 的反向傳遞,消耗步驟 3 的殘差;

  5. 執行 g 的反向傳遞,消耗步驟 1 的殘差。

也就是說,在程式碼中,我們會得到類似以下的內容

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 3
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

更廣泛地說,如果我們有一個函數的鏈式組合,例如 f = lambda x: f3(f2(f1(x))),並且我們有興趣評估 jax.grad(f),我們可以說

  • 我們不應該將 jax.checkpoint 應用於整個函數 f,因為這不會節省任何記憶體 (並且會執行浪費的重新計算);

  • 我們不應該將 jax.checkpoint 應用於最後一個子函數 f3,因為這不會節省任何記憶體 (並且會執行浪費的重新計算);

  • 我們可以將 jax.checkpoint 應用於 f1f2 或它們的組合 lambda x: f2(f1(x)),因為它們中的任何一個都可能節省記憶體,並且會表達不同的記憶體/重新計算權衡。

可儲存內容的自訂政策#

如目前所示,使用 jax.checkpoint 會從一個極端切換到另一個極端

  • 在沒有 jax.checkpoint 的情況下,JAX 的自動微分傾向於在前向傳遞中計算所有可能的值,並將其儲存以供反向傳遞使用;

  • 使用 jax.checkpoint 裝飾器,我們在前向傳遞中計算盡可能少的值,並在反向傳遞中根據需要重新計算值。

為了在這兩個極端之間運作,儲存某些內容而不儲存其他內容,我們可以仔細地將 jax.checkpoint 裝飾器放在子函數上。但這需要編輯要微分的函數,例如模型程式碼,這可能很不方便。也很難實驗變體。

因此,另一種方法是使用 jax.checkpointpolicy 參數。政策是一個可呼叫物件 (即函數),它將一階基本運算應用程式的型別層級規格作為輸入,並傳回布林值,指示是否允許將對應的輸出值儲存為殘差 (或者是否必須在 (餘) 切線計算中根據需要重新計算)。為了撰寫穩健的程式碼,政策應從 jax.checkpoint_policies 的屬性中選取,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable,因為撰寫自訂政策可呼叫物件的 API 被認為是內部 API。

例如,考慮這個要微分的函數

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)

與其在前向傳遞中儲存這麼多值,不如我們只想儲存沒有批次維度的矩陣乘法的結果 (因為它們可能受 FLOP 限制而不是記憶體限制)。我們可以透過使用政策 jax.checkpoint_policies.dots_with_no_batch_dims_saveable 來做到這一點

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)

另請注意,透過提供政策,我們無需編輯定義 losspredictlayer 的程式碼。如果我們想在呼叫程式碼 (例如訓練腳本) 中實驗政策,而無需變更程式庫程式碼 (例如神經網路程式庫),這尤其方便。

某些政策可以參考使用 jax.ad_checkpoint.checkpoint_name 命名的值

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

就其本身而言,checkpoint_name 只是身分函數。但是由於某些政策函數知道要尋找它們,因此我們可以使用名稱來控制 checkpoint_name 輸出的某些值是否被認為是可儲存的

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'

另一個參考名稱的政策是 jax.checkpoint_policies.save_only_these_names

一些政策包括

  • everything_saveable (預設策略,如同根本未使用 jax.checkpoint)

  • nothing_saveable (即重新實體化所有內容,如同根本未使用自訂政策)

  • dots_saveable 或其別名 checkpoint_dots

  • dots_with_no_batch_dims_saveable 或其別名 checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names (儲存任何值,但具有任何給定名稱的 checkpoint_name 輸出除外)

  • save_any_names_but_these (僅儲存已命名的值,即 checkpoint_name 的任何輸出,但具有給定名稱的值除外)

  • save_only_these_names (僅儲存已命名的值,並且僅在給定的名稱中儲存)

政策僅指示什麼是可儲存的;只有當反向傳遞實際需要值時,才會儲存值。

進階:遞迴 jax.checkpoint#

透過以正確的方式應用 jax.checkpoint,可以在記憶體使用量和 (重新) 計算之間表達許多權衡。一個令人驚訝的範例是遞迴檢查點,我們將 jax.checkpoint 應用於一個函數,該函數本身以記憶體使用量從 \(D\) 個函數的鏈式組合中以 \(\mathcal{O}(\log_2 D)\) 而不是 \(\mathcal{O}(D)\) 的比例縮放的方式呼叫 jax.checkpoint 裝飾的函數。

作為一個玩具範例,考慮多個 jnp.sin 函數的鏈式組合

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)

一般來說,儲存的殘差數量與鏈的長度成線性比例

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)

但是我們可以遞迴地應用 jax.checkpoint 以改善縮放比例

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)

如同往常一樣,這裡的代價是重新計算:特別是,我們最終執行的 FLOPs 是 \(\mathcal{O}(\log_2 D)\)

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i a                                                                    
      c:f32[] = cos a                       k:f32[] = mul j b                                                                    
      d:f32[] = sin b                       l:f32[] = mul k c                                                                    
      e:f32[] = cos b                       m:f32[] = mul l d                                                                    
      f:f32[] = sin d                       n:f32[] = mul m e                                                                    
      g:f32[] = cos d                       o:f32[] = mul n f                                                                    
      h:f32[] = sin f                       p:f32[] = mul o g                                                                    
      i:f32[] = cos f                       q:f32[] = mul p h                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, q, o, m, k, i, g, e, c) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d a                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] c f                                           
    in (l, m, k, g, a) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

實用注意事項#

當微分函數被 staged out 到 XLA 以進行編譯時,例如透過將 jax.jit 應用於包含 jax.grad 呼叫的函數,XLA 將自動最佳化計算,包括何時計算或重新實體化值的決策。因此,jax.jit 下的微分函數通常不需要 jax.checkpoint。XLA 會為您最佳化。

一個例外是當使用 staged-out 控制流程時,例如 jax.lax.scan。跨多個控制流程基本運算 (例如,跨前向傳遞 scan 和對應的反向傳遞 scan) 的自動編譯器最佳化通常不夠徹底。因此,在傳遞給 jax.lax.scan 的主體函數上使用 jax.checkpoint 通常是一個好主意。

例如,大型 Transformer 模型中的一個常見模式是將架構表示為跨圖層的 jax.lax.scan,以減少編譯時間。也就是說,以簡單的全連接網路作為類比,而不是撰寫類似以下的內容

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

我們會改為使用 jax.lax.scan 迭代圖層應用

StackedWeights = jnp.ndarray  # all weight matrices stacked together
StackedBiases = jnp.ndarray   # all bias vectors stacked together

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

這個 scan-over-layers 版本減少了編譯時間,但透過阻礙某些編譯器最佳化,它可能會導致梯度計算效率低下。為了減輕這個問題,我們會在掃描函數上使用 jax.checkpoint

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

透過以這種方式使用 jax.checkpoint,我們正在手動控制 JAX 的自動微分在前向和反向傳遞之間儲存哪些值,因此不依賴 XLA 最佳化來為我們選擇。