jax.remat / jax.checkpoint 變更:您需要知道的事#

目錄#

發生什麼事?#

#11830 開始,我們正在切換 jax.checkpoint() (又名 jax.remat(),這兩個名稱是彼此的別名) 的新實作。對於大多數程式碼,將不會有任何變更。 但在邊緣情況下可能會有observable差異;請參閱升級後可能出現哪些問題?

我該如何停用變更,並暫時回到舊的行為?#

如果您在此變更中遇到問題,jax==0.3.16 版本中,可以透過將 jax_new_checkpoint config 選項設定為 False 來關閉新實作,可透過以下任一方式:

  1. 設定 shell 環境變數 JAX_NEW_CHECKPOINT=0

  2. 執行 jax.config.update('jax_new_checkpoint', False)

  3. 如果您使用 absl 解析旗標,請傳遞 --jax_new_checkpoint=False 選項。

如果您需要還原為舊的實作,請在 GitHub issue 上聯繫我們,以便我們可以讓新的實作為您工作。

jax==0.3.17 開始,jax_new_checkpoint config 選項不再可用。如果您有問題,請在 issue tracker 上聯繫我們,以便我們協助修正!

我們為何要這麼做?#

在撰寫本文時,JAX 有兩個並行的 jax.checkpoint 實作。新的實作已經在幾個月內以選擇加入的方式使用 (例如,由 Pax 和 Flaxformer/T5X)。但它尚未預設啟用。

我們希望將新的實作切換為預設啟用,然後刪除舊的實作。使用新的實作並移除舊的實作,為使用者帶來多項好處。

使用者可自訂的重物化策略#

新實作的主要優點是與 policy 引數對應的新功能。其概念是讓使用者精確控制在自動微分的前向傳遞期間,哪些中間值會被儲存 (相對於重物化)。透過運用這種對記憶體使用量與重新計算權衡的控制,使用者可以獲得顯著的效能提升,尤其是在大型模型和我們的 LLM MLPerf 提交中!

此功能的完整文件仍在撰寫中,但這裡有一個簡單的範例:

from functools import partial
import jax

def apply_layer(W, x):
  return jnp.sin(jnp.dot(W, x))

@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
  for W in params[:-1]:
    x = apply_layer(W, x)
  return jnp.dot(params[-1], x)

透過在此處套用 jax.checkpointpolicy=jax.checkpoint_policies.checkpoint_dots,我們確保在前向傳遞期間只允許儲存矩陣乘法的結果。cos 應用程式的 Jacobian 係數值,以及計算這些值所需的 sin 應用程式的值,不會從前向傳遞中儲存,而是在反向傳遞期間重新計算。(像這樣的策略在 TPU 上可能很有效,其中元素級計算實際上是免費的,但來自矩陣單元的結果值得儲存。)

能夠重物化常數,而不僅限於對引數具有資料依賴性的運算#

舊的 jax.checkpoint 實作實際上無法重物化與裝飾函式的引數沒有資料依賴性的計算。考慮這個玩具範例:

@jax.checkpoint
def f(x):
  a = some_function(jnp.arange(10_000_000))  # `a` does not depend on `x`
  return a * x

舊的 jax.checkpoint 實作被迫儲存 a 的值,這可能需要大量記憶體。新的 jax.checkpoint 實作可以重物化而不是儲存 a 的值。

在某些情況下,Python 額外負擔顯著減少#

新的 jax.checkpoint 在某些情況下產生的 Python 額外負擔顯著減少。簡單的額外負擔基準測試速度提高了 10 倍。這些額外負擔僅在 eager op-by-op 執行中出現,因此在使用 jax.checkpointjax.jit 或類似情況下的常見情況下,速度提升並不相關。但儘管如此,還是很棒!

透過簡化內部機制來啟用新的 JAX 功能#

此變更也解鎖了未來使用者的大量好處,例如自訂批次處理規則 (vmap 類似於 custom_vjp) 和 custom_vjp 的前向可微分升級。它還顯著降低了 JAX 程式碼庫某些部分的複雜性,這對於整體可維護性和錯誤修復很有幫助。

升級後可能出現哪些問題?#

無害的數值變更#

由於新的實作可以重物化更多計算,包括潛在的大型常數的計算,因此某些程式碼可能會看到微小的數值變更。任何數值變更的幅度都應在我們預期從變更編譯器最佳化 (例如,浮點運算的重新排序) 中看到的範圍內。但某些過於嚴格的測試容差可能需要稍微放寬。

concrete=True 選項已移除。#

舊的 jax.checkpoint 實作有一個布林值 concrete 選項,允許在具體的 Python 值上進行追蹤 (而不是延遲所有計算,而僅在抽象值上進行追蹤)。該選項很少使用,並且在使用該選項的情況下,有更簡單的替代方案。因此,我們在新 jax.checkpoint 中移除了該選項。

例如,在 Google 程式碼中,concrete=True 最常見的用途是支援傳遞像 is_training 這樣的引數

@partial(jax.checkpoint, concrete=True)  # OLD jax.checkpoint API
def foo(x, is_training):
  if is_training:
    return g(x)
  else:
    return h(x)

使用新的 jax.checkpoint 實作,我們可以使用 static_argnums 選項來完成相同的操作

@partial(jax.checkpoint, static_argnums=(1,))  # NEW jax.checkpoint API
def foo(x, is_training):
  if is_training:
    ...

如果需要在靜態引數上執行 jax.numpy 運算,並在 Python 追蹤期間而不是延遲計算其數值結果,我們可以使用 static_argnumsjax.ensure_compile_time_eval()。但您似乎不太可能需要這樣做!