jax.closure_convert#
- jax.closure_convert(fun, *example_args)[原始碼]#
閉包轉換工具,用於高階自訂導數。
若要使用
jax.custom_vjp(f)
等定義自訂導數,目標函式f
必須將所有與微分相關的值作為形式引數。如果f
是高階函式,也就是說它接受 Python 函式g
作為引數,則儲存在g
的閉包中的值將無法被自訂導數規則看到,並且涉及這些值的 AD 嘗試將會失敗。一種解決方法是轉換閉包,方法是提取這些值,並在自訂導數邊界上將它們作為明確的形式引數傳遞。此工具執行該轉換。更精確地說,它會閉包轉換函式fun
,該函式專門用於example_args
中給定的引數類型。當我們在此處提及
fun
的「閉包中的值」時,我們並非指直接在定義fun
時由 Python 捕獲的值(例如,fun.__closure__
中的 Python 物件,如果該屬性存在)。相反地,我們指的是在example_args
上執行fun
期間遇到的值,這些值決定了其輸出。這可能包括例如以遞移方式在 Python 閉包中捕獲的陣列,即由fun
呼叫的函式的 Python 閉包、它們呼叫的函式的閉包等等。函式
fun
必須是純函式。使用範例
def minimize(objective_fn, x0): converted_fn, aux_args = closure_convert(objective_fn, x0) return _minimize(converted_fn, x0, *aux_args) @partial(custom_vjp, nondiff_argnums=(0,)) def _minimize(objective_fn, x0, *args): z = objective_fn(x0, *args) # ... find minimizer x_opt ... return x_opt def fwd(objective_fn, x0, *args): y = _minimize(objective_fn, x0, *args) return y, (y, args) def rev(objective_fn, res, g): y, args = res y_bar = g # ... custom reverse-mode AD ... return x0_bar, *args_bars _minimize.defvjp(fwd, rev)