jax.checkpoint#

jax.checkpoint(fun, *, prevent_cse=True, policy=None, static_argnums=())[source]#

使 fun 在微分時重新計算內部線性化點。

jax.checkpoint() 裝飾器,別名為 jax.remat(),提供了一種在自動微分的背景下,特別是使用反向模式自動微分 (如 jax.grad()jax.vjp()) 以及 jax.linearize() 時,權衡計算時間和記憶體成本的方法。

當以反向模式微分函式時,預設情況下,所有線性化點 (例如,元素級非線性基本運算的輸入) 會在評估前向傳遞時儲存,以便在反向傳遞中重複使用。這種評估策略可能會導致高記憶體成本,甚至在硬體加速器上導致效能不佳,因為記憶體存取比 FLOPs 昂貴得多。

另一種評估策略是重新計算 (即重新實體化) 某些線性化點,而不是儲存它們。這種方法可以減少記憶體使用量,但會增加計算成本。

此函式裝飾器產生 fun 的新版本,該版本遵循重新實體化策略,而不是預設的儲存所有內容策略。也就是說,它傳回 fun 的新版本,該版本在微分時不會儲存任何其內部線性化點。相反地,這些線性化點會從函式的已儲存輸入重新計算。

請參閱以下範例。

參數:
  • fun (Callable) – 要變更其自動微分評估策略的函式,從預設的儲存所有中間線性化點變更為重新計算它們。其引數和傳回值應為陣列、純量或它們的 (巢狀) 標準 Python 容器 (tuple/list/dict)。

  • prevent_cse (bool) – 選用,布林關鍵字限定引數,指示是否在從微分產生的 HLO 中阻止常見子表達式消除 (CSE) 優化。這種 CSE 阻止是有代價的,因為它可能會阻礙其他優化,並且可能會在某些後端 (尤其是 GPU) 上產生高額開銷。預設值為 True,因為否則,在 jit()pmap() 下,CSE 可能會破壞此裝飾器的目的。但在某些情況下,例如在 scan() 內部使用時,這種 CSE 阻止機制是不必要的,在這種情況下,可以將 prevent_cse 設定為 False。

  • static_argnums (int | tuple[int, ...]) – 選用,整數或整數序列,關鍵字限定引數,指示要為哪些引數值進行特化,以用於追蹤和快取目的。將引數指定為靜態可以避免追蹤時發生 ConcretizationTypeErrors,但會以增加重新追蹤開銷為代價。請參閱以下範例。

  • policy (Callable[..., bool] | None | None) – 選用,可呼叫的關鍵字限定引數。它應該是 jax.checkpoint_policies 的屬性之一。可呼叫物件接受第一階基本應用程式的類型級別規格作為輸入,並傳回一個布林值,指示對應的輸出值是否可以儲存為殘差 (或者如果需要,是否必須在 (餘)切線計算中重新計算)。

傳回:

一個函式 (可呼叫物件),其輸入/輸出行為與 fun 相同,但在使用例如 jax.grad()jax.vjp()jax.linearize() 進行微分時,會重新計算而不是儲存中間線性化點,因此可能節省記憶體,但會增加額外計算。

傳回類型:

Callable

這是一個簡單的範例

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.checkpoint
... def g(x):
...   y = jnp.sin(x)
...   z = jnp.sin(y)
...   return z
...
>>> jax.value_and_grad(g)(2.0)
(Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))

在這裡,無論是否存在 jax.checkpoint() 裝飾器,都會產生相同的值。當裝飾器不存在時,值 jnp.cos(2.0)jnp.cos(jnp.sin(2.0)) 會在前向傳遞中計算,並儲存以供反向傳遞使用,因為反向傳遞需要它們,並且它們僅取決於原始輸入。當使用 jax.checkpoint() 時,前向傳遞將僅計算原始輸出,並且僅原始輸入 (2.0) 將儲存以供反向傳遞使用。屆時,值 jnp.sin(2.0) 將與值 jnp.cos(2.0)jnp.cos(jnp.sin(2.0)) 一起重新計算。

雖然 jax.checkpoint() 控制從前向傳遞儲存哪些值以在反向傳遞中使用,但評估函式或其 VJP 所需的總記憶體量取決於該函式的許多其他內部細節。這些細節包括使用了哪些數值基本運算、它們是如何組成的、在哪裡使用了 jit 和控制流程基本運算 (如 scan) 以及其他因素。

jax.checkpoint() 裝飾器可以遞迴應用,以表達複雜的自動微分重新實體化策略。例如

>>> def recursive_checkpoint(funs):
...   if len(funs) == 1:
...     return funs[0]
...   elif len(funs) == 2:
...     f1, f2 = funs
...     return lambda x: f1(f2(x))
...   else:
...     f1 = recursive_checkpoint(funs[:len(funs)//2])
...     f2 = recursive_checkpoint(funs[len(funs)//2:])
...     return lambda x: f1(jax.checkpoint(f2)(x))
...

如果 fun 涉及取決於引數值的 Python 控制流程,則可能需要使用 static_argnums 參數。例如,考慮一個布林標誌引數

from functools import partial

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, is_training):
  if is_training:
    ...
  else:
    ...

在這裡,使用 static_argnums 允許 if 語句的條件取決於 is_training 的值。static_argnums 的使用成本是它會在跨呼叫引入重新追蹤開銷:在範例中,每次使用新的 is_training 值呼叫 foo 時,都會重新追蹤它。在某些情況下,也需要 jax.ensure_compile_time_eval

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, y):
  with jax.ensure_compile_time_eval():
    y_pos = y > 0
  if y_pos:
    ...
  else:
    ...

作為使用 static_argnums (和 jax.ensure_compile_time_eval) 的替代方案,可能更容易在 jax.checkpoint()-裝飾的函式外部計算某些值,然後閉包它們。