jax.remat
/ jax.checkpoint
變更:您需要知道的事#
目錄#
發生什麼事?#
從 #11830 開始,我們正在切換 jax.checkpoint()
(又名 jax.remat()
,這兩個名稱是彼此的別名) 的新實作。對於大多數程式碼,將不會有任何變更。 但在邊緣情況下可能會有observable差異;請參閱升級後可能出現哪些問題?
我該如何停用變更,並暫時回到舊的行為?#
如果您在此變更中遇到問題,在 jax==0.3.16
版本中,可以透過將 jax_new_checkpoint
config 選項設定為 False 來關閉新實作,可透過以下任一方式:
設定 shell 環境變數
JAX_NEW_CHECKPOINT=0
;執行
jax.config.update('jax_new_checkpoint', False)
;如果您使用
absl
解析旗標,請傳遞--jax_new_checkpoint=False
選項。
如果您需要還原為舊的實作,請在 GitHub issue 上聯繫我們,以便我們可以讓新的實作為您工作。
從 jax==0.3.17
開始,jax_new_checkpoint
config 選項不再可用。如果您有問題,請在 issue tracker 上聯繫我們,以便我們協助修正!
我們為何要這麼做?#
在撰寫本文時,JAX 有兩個並行的 jax.checkpoint
實作。新的實作已經在幾個月內以選擇加入的方式使用 (例如,由 Pax 和 Flaxformer/T5X)。但它尚未預設啟用。
我們希望將新的實作切換為預設啟用,然後刪除舊的實作。使用新的實作並移除舊的實作,為使用者帶來多項好處。
使用者可自訂的重物化策略#
新實作的主要優點是與 policy
引數對應的新功能。其概念是讓使用者精確控制在自動微分的前向傳遞期間,哪些中間值會被儲存 (相對於重物化)。透過運用這種對記憶體使用量與重新計算權衡的控制,使用者可以獲得顯著的效能提升,尤其是在大型模型和我們的 LLM MLPerf 提交中!
此功能的完整文件仍在撰寫中,但這裡有一個簡單的範例:
from functools import partial
import jax
def apply_layer(W, x):
return jnp.sin(jnp.dot(W, x))
@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
for W in params[:-1]:
x = apply_layer(W, x)
return jnp.dot(params[-1], x)
透過在此處套用 jax.checkpoint
和 policy=jax.checkpoint_policies.checkpoint_dots
,我們確保在前向傳遞期間只允許儲存矩陣乘法的結果。cos
應用程式的 Jacobian 係數值,以及計算這些值所需的 sin
應用程式的值,不會從前向傳遞中儲存,而是在反向傳遞期間重新計算。(像這樣的策略在 TPU 上可能很有效,其中元素級計算實際上是免費的,但來自矩陣單元的結果值得儲存。)
能夠重物化常數,而不僅限於對引數具有資料依賴性的運算#
舊的 jax.checkpoint
實作實際上無法重物化與裝飾函式的引數沒有資料依賴性的計算。考慮這個玩具範例:
@jax.checkpoint
def f(x):
a = some_function(jnp.arange(10_000_000)) # `a` does not depend on `x`
return a * x
舊的 jax.checkpoint
實作被迫儲存 a
的值,這可能需要大量記憶體。新的 jax.checkpoint
實作可以重物化而不是儲存 a
的值。
在某些情況下,Python 額外負擔顯著減少#
新的 jax.checkpoint
在某些情況下產生的 Python 額外負擔顯著減少。簡單的額外負擔基準測試速度提高了 10 倍。這些額外負擔僅在 eager op-by-op 執行中出現,因此在使用 jax.checkpoint
在 jax.jit
或類似情況下的常見情況下,速度提升並不相關。但儘管如此,還是很棒!
透過簡化內部機制來啟用新的 JAX 功能#
此變更也解鎖了未來使用者的大量好處,例如自訂批次處理規則 (vmap
類似於 custom_vjp
) 和 custom_vjp
的前向可微分升級。它還顯著降低了 JAX 程式碼庫某些部分的複雜性,這對於整體可維護性和錯誤修復很有幫助。
升級後可能出現哪些問題?#
無害的數值變更#
由於新的實作可以重物化更多計算,包括潛在的大型常數的計算,因此某些程式碼可能會看到微小的數值變更。任何數值變更的幅度都應在我們預期從變更編譯器最佳化 (例如,浮點運算的重新排序) 中看到的範圍內。但某些過於嚴格的測試容差可能需要稍微放寬。
concrete=True
選項已移除。#
舊的 jax.checkpoint
實作有一個布林值 concrete
選項,允許在具體的 Python 值上進行追蹤 (而不是延遲所有計算,而僅在抽象值上進行追蹤)。該選項很少使用,並且在使用該選項的情況下,有更簡單的替代方案。因此,我們在新 jax.checkpoint
中移除了該選項。
例如,在 Google 程式碼中,concrete=True
最常見的用途是支援傳遞像 is_training
這樣的引數
@partial(jax.checkpoint, concrete=True) # OLD jax.checkpoint API
def foo(x, is_training):
if is_training:
return g(x)
else:
return h(x)
使用新的 jax.checkpoint
實作,我們可以使用 static_argnums
選項來完成相同的操作
@partial(jax.checkpoint, static_argnums=(1,)) # NEW jax.checkpoint API
def foo(x, is_training):
if is_training:
...
如果需要在靜態引數上執行 jax.numpy
運算,並在 Python 追蹤期間而不是延遲計算其數值結果,我們可以使用 static_argnums
和 jax.ensure_compile_time_eval()
。但您似乎不太可能需要這樣做!