custom_vjp
和 nondiff_argnums
更新指南#
mattjj@ 2020 年 10 月 14 日
本文假設您已熟悉 jax.custom_vjp
,如用於 JAX 可轉換 Python 函式的自訂導數規則筆記本中所述。
需要更新什麼#
在 JAX PR #4008 之後,傳遞到 custom_vjp
函式的 nondiff_argnums
中的引數不能是 Tracer
s (或是 Tracer
s 的容器),這基本上表示為了允許任意轉換的程式碼,nondiff_argnums
不應使用於陣列值引數。相反地,nondiff_argnums
應僅用於非陣列值,例如 Python 可呼叫物件或形狀元組或字串。
凡是我們過去使用 nondiff_argnums
處理陣列值的地方,我們都應該直接將它們作為常規引數傳遞。在 bwd
規則中,我們需要為它們產生值,但我們可以只產生 None
值來表示沒有對應的梯度值。
例如,以下是撰寫 clip_gradient
的舊方法,當 hi
和/或 lo
是來自某些 JAX 轉換的 Tracer
s 時,這種方法將無法運作。
from functools import partial
import jax
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, None # no residual values to save
def clip_gradient_bwd(lo, hi, _, g):
return (jnp.clip(g, lo, hi),)
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
這是新的、很棒的方法,它支援任意轉換
import jax
@jax.custom_vjp # no nondiff_argnums!
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi) # save lo and hi values as residuals
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # return None for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
如果您使用舊方法而不是新方法,在任何可能出錯的情況下 (也就是當有 Tracer
傳遞到 nondiff_argnums
引數中時),您都會收到明確的錯誤訊息。
這是一個我們實際上需要將 nondiff_argnums
與 custom_vjp
一起使用的案例
from functools import partial
import jax
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
return f(x)
def skip_app_fwd(f, x):
return skip_app(f, x), None
def skip_app_bwd(f, _, g):
return (g,)
skip_app.defvjp(skip_app_fwd, skip_app_bwd)
說明#
將 Tracer
s 傳遞到 nondiff_argnums
引數中一直都存在錯誤。雖然有些情況可以正確運作,但其他情況會導致複雜且令人困惑的錯誤訊息。
錯誤的本質在於 nondiff_argnums
的實作方式非常像詞法閉包。但是詞法閉包在 Tracer
s 上的運作方式在當時並非旨在與 custom_jvp
/custom_vjp
一起使用。以這種方式實作 nondiff_argnums
是個錯誤!
PR #4008 修正了 custom_jvp
和 custom_vjp
的所有詞法閉包問題。 耶!也就是說,現在 custom_jvp
和 custom_vjp
函式和規則可以隨心所欲地封閉 Tracer
s。對於所有非自動微分轉換,一切都會正常運作。對於自動微分轉換,我們會收到關於為什麼我們不能對 custom_jvp
或 custom_vjp
封閉的值進行微分的明確錯誤訊息
偵測到相對於封閉值的 custom_jvp 函式的微分。這是不支援的,因為自訂 JVP 規則僅指定如何相對於顯式輸入參數來微分 custom_jvp 函式。
嘗試將封閉值作為引數傳遞到 custom_jvp 函式中,並調整自訂 JVP 規則。
以這種方式收緊和強化 custom_jvp
和 custom_vjp
時,我們發現允許 custom_vjp
接受其 nondiff_argnums
中的 Tracer
s 需要大量的簿記工作:我們需要重寫使用者的 fwd
函式以將值作為殘差傳回,並重寫使用者的 bwd
函式以將它們作為常規殘差接受 (而不是像 nondiff_argnums
那樣將它們作為特殊的前導引數接受)。這看起來也許是可管理的,直到您考慮到我們必須如何處理任意 pytrees!此外,這種複雜性是不必要的:如果使用者程式碼將類似陣列的不可微分引數視為常規引數和殘差,那麼一切就已經可以正常運作了。(在 #4039 之前,JAX 可能會抱怨在自動微分中涉及整數值輸入和輸出,但在 #4039 之後,這些都能正常運作!)
與 custom_vjp
不同,讓 custom_jvp
與作為 Tracer
s 的 nondiff_argnums
引數一起運作很容易。因此,這些更新只需要在 custom_vjp
中發生。