使用 jax.checkpoint (jax.remat) 的梯度檢查點#

在本教學中,您將學習如何使用 jax.checkpoint() (也稱為 jax.remat()) 控制 JAX 自動微分的已儲存值,這在機器學習中特別有用。

如果您是自動微分 (autodiff) 的新手,或需要複習記憶,JAX 提供了自動微分進階自動微分教學。

重點摘要 使用 jax.checkpoint() 裝飾器 (別名為 jax.remat()) 與 jax.grad() 搭配使用,以控制前向傳遞中儲存哪些中間值,以及在反向傳遞中重新計算哪些中間值,從而在記憶體和 FLOPs 之間進行權衡。

如果您不使用 jax.checkpoint(),則 jax.grad(f)(x) 前向傳遞會儲存 Jacobian 係數和其他中間值,以在反向傳遞期間使用。這些儲存的值稱為殘差

注意: 請務必查看實務注意事項,以了解 jax.checkpoint() 如何與 jax.jit() 互動。

import jax
import jax.numpy as jnp

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[5] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[7] output of cos from /tmp/ipykernel_1031/1801108376.py:6 (g)

透過將 jax.checkpoint() 應用於子函數 (作為裝飾器或在特定應用位置),您可以強制 JAX 不儲存該子函數的任何殘差。相反地,可能只會儲存 jax.checkpoint() 裝飾函數的輸入,而反向傳遞中使用的任何殘差都會根據需要從這些輸入重新計算。

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)
f32[6] output of sin from /tmp/ipykernel_1031/1801108376.py:6 (g)

在此範例中,儲存了兩個 sin 應用程式的值,因為它們是後續 jax.checkpoint() 裝飾的 g 函數的應用程式中的引數,且可能會儲存 jax.checkpoint() 裝飾函數的輸入。但不會儲存任何 cos 應用程式的值。

若要控制哪些值可儲存,而無需編輯要微分的函數定義,您可以使用重新實體化策略。以下範例僅儲存沒有批次維度的 dot 運算的結果 (因為它們通常受 FLOP 限制,因此值得儲存而不是重新計算)。

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_1031/1801108376.py:5 (g)

您也可以使用策略來參考您使用 jax.ad_checkpoint.checkpoint_name() 命名的中間值。

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1031/2296542172.py:4 (f4)

在試用這些玩具範例時,您可以使用此筆記本中定義的自訂 print_fwd_bwd 實用程式,更仔細地查看正在發生的情況。

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                         
  forward computation:                                         backward computation:                                                                     
                                                                                                                                                         
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let    { lambda ; a:f32[4] b:f32[5,4] c:f32[5] d:f32[5] e:f32[6,5] f:f32[6] g:f32[6] h:f32[7,6]  
      e:f32[5] = dot_general[                                      i:f32[7] j:f32[7]. let                                                                
        dimension_numbers=(([1], [0]), ([], []))                   k:f32[7] = mul j i                                                                    
        preferred_element_type=float32                             l:f32[6] = dot_general[                                                               
      ] a d                                                          dimension_numbers=(([0], [0]), ([], []))                                            
      f:f32[5] = sin e                                               preferred_element_type=float32                                                      
      g:f32[5] = cos e                                             ] k h                                                                                 
      h:f32[6] = dot_general[                                      m:f32[7,6] = dot_general[                                                             
        dimension_numbers=(([1], [0]), ([], []))                     dimension_numbers=(([], []), ([], []))                                              
        preferred_element_type=float32                               preferred_element_type=float32                                                      
      ] b f                                                        ] k g                                                                                 
      i:f32[6] = sin h                                             n:f32[6] = mul l f                                                                    
      j:f32[6] = cos h                                             o:f32[5] = dot_general[                                                               
      k:f32[7] = dot_general[                                        dimension_numbers=(([0], [0]), ([], []))                                            
        dimension_numbers=(([1], [0]), ([], []))                     preferred_element_type=float32                                                      
        preferred_element_type=float32                             ] n e                                                                                 
      ] c i                                                        p:f32[6,5] = dot_general[                                                             
      l:f32[7] = sin k                                               dimension_numbers=(([], []), ([], []))                                              
      m:f32[7] = cos k                                               preferred_element_type=float32                                                      
    in (l, d, a, g, f, b, j, i, c, m) }                            ] n d                                                                                 
                                                                   q:f32[5] = mul o c                                                                    
                                                                   r:f32[4] = dot_general[                                                               
                                                                     dimension_numbers=(([0], [0]), ([], []))                                            
                                                                     preferred_element_type=float32                                                      
                                                                   ] q b                                                                                 
                                                                   s:f32[5,4] = dot_general[                                                             
                                                                     dimension_numbers=(([], []), ([], []))                                              
                                                                     preferred_element_type=float32                                                      
                                                                   ] q a                                                                                 
                                                                 in (s, p, m, r) }                                                                       
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                        
  forward computation:                                                   backward computation:                                                                          
                                                                                                                                                                        
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let              { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[                                                i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
        dimension_numbers=(([1], [0]), ([], []))                               differentiated=True                                                                      
        preferred_element_type=float32                                         jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      ] a d                                                                        s:f32[4] t:f32[7]. let                                                               
      f:f32[5] = reduce_precision[exponent_bits=8 mantissa_bits=23] e              u:f32[5] = sin m                                                                     
      g:f32[5] = sin f                                                             v:f32[5] = cos m                                                                     
      h:f32[6] = dot_general[                                                      w:f32[6] = sin n                                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   x:f32[6] = cos n                                                                     
        preferred_element_type=float32                                             y:f32[7] = cos o                                                                     
      ] b g                                                                        z:f32[7] = mul t y                                                                   
      i:f32[6] = reduce_precision[exponent_bits=8 mantissa_bits=23] h              ba:f32[6] = dot_general[                                                             
      j:f32[6] = sin i                                                               dimension_numbers=(([0], [0]), ([], []))                                           
      k:f32[7] = dot_general[                                                        preferred_element_type=float32                                                     
        dimension_numbers=(([1], [0]), ([], []))                                   ] z r                                                                                
        preferred_element_type=float32                                             bb:f32[6] = mul ba x                                                                 
      ] c j                                                                        bc:f32[5] = dot_general[                                                             
      l:f32[7] = reduce_precision[exponent_bits=8 mantissa_bits=23] k                dimension_numbers=(([0], [0]), ([], []))                                           
      m:f32[7] = sin l                                                               preferred_element_type=float32                                                     
    in (m, f, i, l, a, b, c, d) }                                                  ] bb q                                                                               
                                                                                   bd:f32[5] = mul bc v                                                                 
                                                                                   be:f32[4] = dot_general[                                                             
                                                                                     dimension_numbers=(([0], [0]), ([], []))                                           
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd p                                                                               
                                                                                   bf:f32[5,4] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bd s                                                                               
                                                                                   bg:f32[6,5] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] bb u                                                                               
                                                                                   bh:f32[7,6] = dot_general[                                                           
                                                                                     dimension_numbers=(([], []), ([], []))                                             
                                                                                     preferred_element_type=float32                                                     
                                                                                   ] z w                                                                                
                                                                                 in (bf, bg, bh, be) }                                                                  
                                                                               policy=<function dot_with_no_batch_dims_saveable at 0x7f6ca10ebc70>                      
                                                                               prevent_cse=True                                                                         
                                                                             ] a b c d e f g h                                                                          
                                                                           in (i, j, k, l) }                                                                            

讓我們逐步思考#

注意: 在繼續之前,查看進階自動微分教學可能會有所幫助。

jax.checkpoint 基礎知識#

jax.linearize()jax.vjp() 中,值的計算方式和時間點都具有彈性。不同的選擇可以在記憶體使用量和 FLOPs 之間進行權衡。JAX 透過 jax.checkpoint() 提供對這些選擇的控制。

其中一個選擇是在前向傳遞中 (在輸入可用時立即) 執行 Jacobian 係數計算,還是在反向傳遞中 (在需要係數之前) 執行。考慮 sin_vjp 的範例:

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

另一個有效的實作會在反向傳遞而不是前向傳遞中計算 jnp.cos(x) 的值:

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

對於這個特定函數,兩個版本使用的記憶體量相同,儘管您減少了原始計算 (前向傳遞) 的 FLOPs,並增加了共切線計算 (反向傳遞) 的 FLOPs。

在函數組合方面還有另一個選擇。回想一下兩個函數組合的 VJP 規則:

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

另一種選擇是:

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

用文字來說,這種替代實作不會在前向傳遞中計算 g_vjp 或其閉包中的殘差值。相反地,它只會在反向傳遞 f_bwd2 中計算它們。這表示 f_vjp_checkpoint 需要更少的記憶體:如果 gh 各自的殘差需要相似的記憶體量,且都遠大於 x,則 f_vjp_checkpoint(x) 產生的函數所需的記憶體是 f_vjp(x) 的一半!

您付出的代價是多餘的工作:在 f_bwd2 中,您必須重新評估 g(x) 作為 jax.vjp(g, x) 的一部分,只是為了捨棄其值 (在 _, g_vjp = jax.vjp(g, x) 行的底線變數中)。

您可以在自動微分中獲得此 VJP 行為,而無需直接撰寫 VJP 函數,而是透過在原始函數 f 的替代定義中使用 jax.checkpoint()

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

換句話說,您可以將 jax.checkpoint() 應用於 g ( f 的第一階段),而不是應用於 f 本身。這樣,當您評估 jax.grad(f_checkpoint)(x) 時,您會得到類似以下的計算:

  1. 執行 g 的前向傳遞,捨棄殘差值。

  2. 執行 h 的前向傳遞,儲存殘差。

  3. 執行 h 的反向傳遞,使用步驟 2 的殘差。

  4. 重新執行 g 的前向傳遞,儲存殘差。

  5. 執行 g 的反向傳遞,使用步驟 4 的殘差。

也就是說,透過評估 jax.grad(f_checkpoint)(x),我們會得到與以下計算相同的結果:

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h)(y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

一般而言,jax.checkpoint(foo) 是一個新函數,其輸入-輸出行為與 foo 相同,但在自動微分下 (特別是在 jax.linearize()jax.vjp() (及其包裝器,例如 jax.grad()),但非 jax.jvp() 下) 行為不同。當進行微分時,只會在前向傳遞中儲存 jax.checkpoint() 微分函數的輸入。在反向傳遞中,會重新計算殘差 (foo 的中間值及其反向傳遞所需的 Jacobian 係數值)。

請注意,如果 f = lambda x: h(g(x)) 是您要微分的函數 (換句話說,如果您想要應用 jax.grad(f)),則透過將 jax.checkpoint() 應用於 f 本身,您不會獲得任何記憶體節省。這是因為評估 jax.grad(jax.checkpoint(f))(x) 會導致如下的計算:

  1. 執行前向傳遞,捨棄所有殘差。

  2. 立即重新執行前向傳遞,儲存殘差。

  3. 執行反向傳遞,使用步驟 2 的殘差。

在程式碼中,您會得到類似以下的內容:

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar

透過將 jax.checkpoint() 應用於 h ( f 的第二階段),您也不會獲得任何記憶體節省。這是因為評估 jax.grad(lambda x: jax.checkpoint(h)(g(x))) 會導致如下的計算:

  1. 執行 g 的前向傳遞,儲存殘差。

  2. 執行 h 的前向傳遞,捨棄殘差。

  3. 立即重新執行 h 的前向傳遞,儲存殘差。

  4. 執行 h 的反向傳遞,使用步驟 3 的殘差。

  5. 執行 g 的反向傳遞,使用步驟 1 的殘差。

在程式碼中,您會得到類似以下的內容:

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 3
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar

更廣泛來說,如果您有一個函數鏈組合,例如 f = lambda x: f3(f2(f1(x))),並且對評估 jax.grad(f) 感興趣,您可以說您:

  • 不應將 jax.checkpoint() 應用於整個函數 f,因為這樣不會節省任何記憶體 (並且會執行浪費的重新計算)。

  • 不應將 jax.checkpoint() 應用於最後一個子函數 f3,因為這樣不會節省任何記憶體 (並且會執行浪費的重新計算)。

  • 可以將 jax.checkpoint() 應用於 f1f2 或其組合 lambda x: f2(f1(x)),因為這些都可能節省記憶體,並且會表達不同的記憶體/重新計算權衡。

可儲存內容的自訂策略#

如目前所示,使用 jax.checkpoint() 會在一個極端到另一個極端之間切換:

  • 在沒有 jax.checkpoint() 的情況下,JAX 的自動微分傾向於在前向傳遞中計算所有可能的內容,並將其儲存以供反向傳遞使用。

  • 使用 jax.checkpoint() 裝飾器,您可以在前向傳遞中計算盡可能少的內容,並在反向傳遞中根據需要重新計算值。

若要在這兩個極端之間操作 (儲存某些內容,而不儲存其他內容),您可以仔細地將 jax.checkpoint() 裝飾器放置在子函數上。但這需要編輯要微分的函數 (例如模型程式碼),這可能很不方便。也很難實驗變體。

因此,替代方案是使用 jax.checkpoint()policy 引數。策略是一個可呼叫物件 (即函數),它將一階 primitive 應用程式的型別層級規格作為輸入,並傳回一個布林值,指示是否允許將對應的輸出值儲存為殘差 (或者是否必須在 (共)切線計算中根據需要重新計算)。為了撰寫穩健的程式碼,應從 jax.checkpoint_policies() 上的屬性中選取策略,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable(),因為撰寫自訂策略可呼叫物件的 API 被視為內部 API。

例如,考慮要微分的此函數:

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of sin from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of mul from /tmp/ipykernel_1031/4230705069.py:2 (loss)

您可能只想儲存沒有批次維度的矩陣乘法的結果 (因為它們可能是受 FLOP 限制而不是受記憶體限制),而不是在前向傳遞中儲存這麼多值。您可以使用策略 jax.checkpoint_policies.dots_with_no_batch_dims_saveable() 來執行此操作:

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1031/4230705069.py:8 (predict)

另請注意,透過提供策略,您不需要編輯定義 losspredictlayer 的程式碼。如果您想要在呼叫程式碼 (例如訓練腳本) 中實驗策略,而無需變更程式庫程式碼 (例如神經網路程式庫),這尤其方便。

某些策略可以參考使用 jax.ad_checkpoint.checkpoint_name() 命名的值。

from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

就其本身而言,jax.ad_checkpoint import.checkpoint_name() 只是一個恆等函數。但由於某些策略函數知道要尋找它們,因此您可以使用名稱來控制 jax.ad_checkpoint import.checkpoint_name() 輸出的某些值是否被視為可儲存。

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_1031/178264713.py:7 (predict)
f32[4] output of cos from /tmp/ipykernel_1031/4230705069.py:12 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_1031/178264713.py:7 (predict)
f32[4] output of mul from /tmp/ipykernel_1031/4230705069.py:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y

另一個參考名稱的策略是 jax.checkpoint_policies.save_only_these_names

卸載的自訂策略#

當進行檢查點以節省加速器記憶體時,您可以考慮卸載到 CPU 記憶體而不是重新計算。jax.checkpoint_policies.offload_dot_with_no_batch_dims 可以將沒有批次維度的矩陣乘法的結果卸載到 CPU。

from jax.ad_checkpoint import checkpoint

def checkpoint_offload_dot_with_no_batch_dims(self):
  policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
      "device", "pinned_host")

  @functools.partial(checkpoint, policy=policy)
  def f(x):
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.sum(x)
    return x

JAX 的檢查點策略之一允許將指定的檢查點名稱卸載到 CPU。此策略透過 jax.checkpoint_policies.save_and_offload_only_these_names 實作,此函式有四個參數:names_which_can_be_savednames_which_can_be_offloaded、卸載來源和目的地。列在 names_which_can_be_saved 中的名稱會保留在裝置上,列在 names_which_can_be_offloaded 中的名稱會移至 CPU 記憶體,而其他未命名的名稱或運算則會重新計算。例如,如果我們有檢查點名稱 yzwy 可以保存在裝置上,z 可以卸載到 CPU 記憶體,而 w 則可以重新計算。

from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax._src import test_util as jtu

def checkpoint_names_saved_offloaded_recomputed(self):
  mesh = jtu.create_mesh((2,), ("x",))
  shape = (256, 128)
  np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  s = NamedSharding(mesh, P("x"))
  inp = jax.device_put(np_inp, s)

  policy = jax.checkpoint_policies.save_and_offload_only_these_names(
      names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
      offload_src='device', offload_dst='pinned_host')

  @functools.partial(checkpoint, policy=policy)
  def f(x):
    def g(ys, _):
      y, _ = ys
      y = checkpoint_name(jnp.sin(y), "y")
      z = checkpoint_name(jnp.sin(y), "z")
      z = z.T
      w = checkpoint_name(jnp.sin(z), "w")
      return (w.T, jnp.sum(w)), None
    _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
    return scan_out

此程式碼定義了一個函式 f,其應用了具有自訂策略的檢查點。此策略決定了在執行期間哪些計算可以被保存或卸載。f 內部有一個巢狀函式 g,用於執行核心計算。jax.lax.scan 函式用於對輸入資料重複應用 g

策略列表#

策略如下:

  • everything_saveable(預設策略,如同完全未使用 jax.checkpoint 一樣)

  • nothing_saveable(即重新實體化所有內容,如同完全未使用自訂策略一樣)

  • dots_saveable 或其別名 checkpoint_dots

  • dots_with_no_batch_dims_saveable 或其別名 checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names(保存除了具有任何給定名稱的 checkpoint_name 輸出之外的所有值)

  • save_any_names_but_these(僅保存具名值,即任何 checkpoint_name 的輸出,除了具有給定名稱的輸出之外)

  • save_only_these_names(僅保存具名值,且僅限於給定的名稱)

  • offload_dot_with_no_batch_dimsdots_with_no_batch_dims_saveable 相同,但卸載到 CPU 記憶體而不是重新計算。

  • save_and_offload_only_these_namessave_only_these_names 相同,但卸載到 CPU 記憶體而不是重新計算。

  • save_from_both_policies(policy_1, policy_2)(類似於邏輯 or,因此如果根據 policy_1 *或* policy_2 可保存,則殘差即可保存)

策略僅指示哪些是可保存的;只有當反向傳遞實際需要某個值時,才會保存該值。

進階:遞迴 jax.checkpoint#

透過正確地應用 jax.checkpoint(),可以在記憶體使用量和(重新)計算之間進行許多權衡。一個令人驚訝的例子是*遞迴*檢查點,您將 jax.checkpoint() 應用於一個函式,該函式本身以某種方式呼叫 jax.checkpoint() 裝飾的函式,使得來自 \(D\) 個函式鏈式組合的記憶體使用量以 \(\mathcal{O}(\log_2 D)\) 而非 \(\mathcal{O}(D)\) 的規模增長。

作為一個玩具範例,考慮多個 jax.numpy.sin() 函式的鏈式組合

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)

一般來說,儲存的殘差數量與鏈的長度成線性比例

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)
f32[] output of cos from /tmp/ipykernel_1031/410288286.py:4 (f)

但是您可以遞迴地應用 jax.checkpoint() 以改善規模增長

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of sin from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)
f32[] output of cos from /tmp/ipykernel_1031/1943107544.py:6 (<lambda>)

這裡的成本,與往常一樣,是重新計算:特別是,您最終會執行 \(\mathcal{O}(\log_2 D)\) 倍的 FLOPs

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i h                                                                    
      c:f32[] = cos a                       k:f32[] = mul j g                                                                    
      d:f32[] = sin b                       l:f32[] = mul k f                                                                    
      e:f32[] = cos b                       m:f32[] = mul l e                                                                    
      f:f32[] = sin d                       n:f32[] = mul m d                                                                    
      g:f32[] = cos d                       o:f32[] = mul n c                                                                    
      h:f32[] = sin f                       p:f32[] = mul o b                                                                    
      i:f32[] = cos f                       q:f32[] = mul p a                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, c, e, g, i, k, m, o, q) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d c                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] a f                                           
    in (l, g, a, k, m) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

實務注意事項#

當微分函式被暫存輸出到 XLA 以進行編譯時 — 例如透過將 jax.jit() 應用於包含 jax.grad() 呼叫的函式 — XLA 將自動最佳化計算,包括何時計算或重新實體化值的決策。因此,**對於 jax.jit() 下的微分函式,通常不需要 jax.checkpoint()**。XLA 會為您最佳化一切。

一個例外是當使用暫存輸出的控制流程時,例如 jax.lax.scan()。跨多個控制流程原語(例如,跨越前向傳遞 scan 和相應的反向傳遞 scan)的自動編譯器最佳化通常不夠徹底。因此,通常最好在傳遞給 jax.lax.scan() 的主体函式上使用 jax.checkpoint()

例如,大型 Transformer 模型 中的一個常見模式是將架構表示為跨層的 jax.lax.scan(),以減少編譯時間。也就是說,以簡單的全連接網路作為類比,而不是像這樣編寫

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # Weights-bias pair for a layer.
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

而是使用 jax.lax.scan() 迭代層應用

params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])), 
          (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

這種跨層掃描版本減少了編譯時間,但由於阻礙了一些編譯器最佳化,可能會導致梯度計算效率低下。為了緩解這個問題,您可以在掃描函式上使用 jax.checkpoint()

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

透過以這種方式使用 jax.checkpoint(),您可以手動控制 JAX 的自動微分在正向和反向傳遞之間保存哪些值,因此不依賴 XLA 最佳化為您選擇。