jax.checkpoint#
- jax.checkpoint(fun, *, prevent_cse=True, policy=None, static_argnums=())[source]#
使
fun
在微分時重新計算內部線性化點。jax.checkpoint()
裝飾器,別名為jax.remat()
,提供了一種在自動微分的背景下,特別是使用反向模式自動微分 (如jax.grad()
和jax.vjp()
) 以及jax.linearize()
時,權衡計算時間和記憶體成本的方法。當以反向模式微分函式時,預設情況下,所有線性化點 (例如,元素級非線性基本運算的輸入) 會在評估前向傳遞時儲存,以便在反向傳遞中重複使用。這種評估策略可能會導致高記憶體成本,甚至在硬體加速器上導致效能不佳,因為記憶體存取比 FLOPs 昂貴得多。
另一種評估策略是重新計算 (即重新實體化) 某些線性化點,而不是儲存它們。這種方法可以減少記憶體使用量,但會增加計算成本。
此函式裝飾器產生
fun
的新版本,該版本遵循重新實體化策略,而不是預設的儲存所有內容策略。也就是說,它傳回fun
的新版本,該版本在微分時不會儲存任何其內部線性化點。相反地,這些線性化點會從函式的已儲存輸入重新計算。請參閱以下範例。
- 參數:
fun (Callable) – 要變更其自動微分評估策略的函式,從預設的儲存所有中間線性化點變更為重新計算它們。其引數和傳回值應為陣列、純量或它們的 (巢狀) 標準 Python 容器 (tuple/list/dict)。
prevent_cse (bool) – 選用,布林關鍵字限定引數,指示是否在從微分產生的 HLO 中阻止常見子表達式消除 (CSE) 優化。這種 CSE 阻止是有代價的,因為它可能會阻礙其他優化,並且可能會在某些後端 (尤其是 GPU) 上產生高額開銷。預設值為 True,因為否則,在
jit()
或pmap()
下,CSE 可能會破壞此裝飾器的目的。但在某些情況下,例如在scan()
內部使用時,這種 CSE 阻止機制是不必要的,在這種情況下,可以將prevent_cse
設定為 False。static_argnums (int | tuple[int, ...]) – 選用,整數或整數序列,關鍵字限定引數,指示要為哪些引數值進行特化,以用於追蹤和快取目的。將引數指定為靜態可以避免追蹤時發生 ConcretizationTypeErrors,但會以增加重新追蹤開銷為代價。請參閱以下範例。
policy (Callable[..., bool] | None | None) – 選用,可呼叫的關鍵字限定引數。它應該是
jax.checkpoint_policies
的屬性之一。可呼叫物件接受第一階基本應用程式的類型級別規格作為輸入,並傳回一個布林值,指示對應的輸出值是否可以儲存為殘差 (或者如果需要,是否必須在 (餘)切線計算中重新計算)。
- 傳回:
一個函式 (可呼叫物件),其輸入/輸出行為與
fun
相同,但在使用例如jax.grad()
、jax.vjp()
或jax.linearize()
進行微分時,會重新計算而不是儲存中間線性化點,因此可能節省記憶體,但會增加額外計算。- 傳回類型:
Callable
這是一個簡單的範例
>>> import jax >>> import jax.numpy as jnp
>>> @jax.checkpoint ... def g(x): ... y = jnp.sin(x) ... z = jnp.sin(y) ... return z ... >>> jax.value_and_grad(g)(2.0) (Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
在這裡,無論是否存在
jax.checkpoint()
裝飾器,都會產生相同的值。當裝飾器不存在時,值jnp.cos(2.0)
和jnp.cos(jnp.sin(2.0))
會在前向傳遞中計算,並儲存以供反向傳遞使用,因為反向傳遞需要它們,並且它們僅取決於原始輸入。當使用jax.checkpoint()
時,前向傳遞將僅計算原始輸出,並且僅原始輸入 (2.0
) 將儲存以供反向傳遞使用。屆時,值jnp.sin(2.0)
將與值jnp.cos(2.0)
和jnp.cos(jnp.sin(2.0))
一起重新計算。雖然
jax.checkpoint()
控制從前向傳遞儲存哪些值以在反向傳遞中使用,但評估函式或其 VJP 所需的總記憶體量取決於該函式的許多其他內部細節。這些細節包括使用了哪些數值基本運算、它們是如何組成的、在哪裡使用了 jit 和控制流程基本運算 (如 scan) 以及其他因素。jax.checkpoint()
裝飾器可以遞迴應用,以表達複雜的自動微分重新實體化策略。例如>>> def recursive_checkpoint(funs): ... if len(funs) == 1: ... return funs[0] ... elif len(funs) == 2: ... f1, f2 = funs ... return lambda x: f1(f2(x)) ... else: ... f1 = recursive_checkpoint(funs[:len(funs)//2]) ... f2 = recursive_checkpoint(funs[len(funs)//2:]) ... return lambda x: f1(jax.checkpoint(f2)(x)) ...
如果
fun
涉及取決於引數值的 Python 控制流程,則可能需要使用static_argnums
參數。例如,考慮一個布林標誌引數from functools import partial @partial(jax.checkpoint, static_argnums=(1,)) def foo(x, is_training): if is_training: ... else: ...
在這裡,使用
static_argnums
允許if
語句的條件取決於is_training
的值。static_argnums
的使用成本是它會在跨呼叫引入重新追蹤開銷:在範例中,每次使用新的is_training
值呼叫foo
時,都會重新追蹤它。在某些情況下,也需要jax.ensure_compile_time_eval
@partial(jax.checkpoint, static_argnums=(1,)) def foo(x, y): with jax.ensure_compile_time_eval(): y_pos = y > 0 if y_pos: ... else: ...
作為使用
static_argnums
(和jax.ensure_compile_time_eval
) 的替代方案,可能更容易在jax.checkpoint()
-裝飾的函式外部計算某些值,然後閉包它們。