jax.closure_convert#

jax.closure_convert(fun, *example_args)[原始碼]#

閉包轉換工具,用於高階自訂導數。

若要使用 jax.custom_vjp(f) 等定義自訂導數,目標函式 f 必須將所有與微分相關的值作為形式引數。如果 f 是高階函式,也就是說它接受 Python 函式 g 作為引數,則儲存在 g 的閉包中的值將無法被自訂導數規則看到,並且涉及這些值的 AD 嘗試將會失敗。一種解決方法是轉換閉包,方法是提取這些值,並在自訂導數邊界上將它們作為明確的形式引數傳遞。此工具執行該轉換。更精確地說,它會閉包轉換函式 fun,該函式專門用於 example_args 中給定的引數類型。

當我們在此處提及 fun 的「閉包中的值」時,我們並非指直接在定義 fun 時由 Python 捕獲的值(例如,fun.__closure__ 中的 Python 物件,如果該屬性存在)。相反地,我們指的是在 example_args 上執行 fun 期間遇到的值,這些值決定了其輸出。這可能包括例如以遞移方式在 Python 閉包中捕獲的陣列,即由 fun 呼叫的函式的 Python 閉包、它們呼叫的函式的閉包等等。

函式 fun 必須是純函式。

使用範例

def minimize(objective_fn, x0):
  converted_fn, aux_args = closure_convert(objective_fn, x0)
  return _minimize(converted_fn, x0, *aux_args)

@partial(custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, *args):
  z = objective_fn(x0, *args)
  # ... find minimizer x_opt ...
  return x_opt

def fwd(objective_fn, x0, *args):
  y = _minimize(objective_fn, x0, *args)
  return y, (y, args)

def rev(objective_fn, res, g):
  y, args = res
  y_bar = g
  # ... custom reverse-mode AD ...
  return x0_bar, *args_bars

_minimize.defvjp(fwd, rev)
參數:
  • fun (Callable) – 要轉換的 Python 可呼叫物件。必須是純函式。

  • example_args – 陣列、純量或(巢狀)標準 Python 容器(元組、列表、字典、具名元組,即 pytree)的物件,用於決定 fun 的形式引數類型。此類型專用形式的 fun 是將要進行閉包轉換的函式。

回傳:

一個包含 (i) Python 可呼叫物件的配對,接受與 fun 相同的引數,後跟對應於從其閉包中提升的值的引數,以及 (ii) 從閉包中提升的值的列表。

回傳類型:

tuple[Callable, list[Any]]