custom_vjpnondiff_argnums 更新指南#

mattjj@ 2020 年 10 月 14 日

本文假設您已熟悉 jax.custom_vjp,如用於 JAX 可轉換 Python 函式的自訂導數規則筆記本中所述。

需要更新什麼#

在 JAX PR #4008 之後,傳遞到 custom_vjp 函式的 nondiff_argnums 中的引數不能是 Tracers (或是 Tracers 的容器),這基本上表示為了允許任意轉換的程式碼,nondiff_argnums 不應使用於陣列值引數。相反地,nondiff_argnums 應僅用於非陣列值,例如 Python 可呼叫物件或形狀元組或字串。

凡是我們過去使用 nondiff_argnums 處理陣列值的地方,我們都應該直接將它們作為常規引數傳遞。在 bwd 規則中,我們需要為它們產生值,但我們可以只產生 None 值來表示沒有對應的梯度值。

例如,以下是撰寫 clip_gradient方法,當 hi 和/或 lo 是來自某些 JAX 轉換的 Tracers 時,這種方法將無法運作。

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

這是的、很棒的方法,它支援任意轉換

import jax

@jax.custom_vjp  # no nondiff_argnums!
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save lo and hi values as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

如果您使用舊方法而不是新方法,在任何可能出錯的情況下 (也就是當有 Tracer 傳遞到 nondiff_argnums 引數中時),您都會收到明確的錯誤訊息。

這是一個我們實際上需要將 nondiff_argnumscustom_vjp 一起使用的案例

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
  return f(x)

def skip_app_fwd(f, x):
  return skip_app(f, x), None

def skip_app_bwd(f, _, g):
  return (g,)

skip_app.defvjp(skip_app_fwd, skip_app_bwd)

說明#

Tracers 傳遞到 nondiff_argnums 引數中一直都存在錯誤。雖然有些情況可以正確運作,但其他情況會導致複雜且令人困惑的錯誤訊息。

錯誤的本質在於 nondiff_argnums 的實作方式非常像詞法閉包。但是詞法閉包在 Tracers 上的運作方式在當時並非旨在與 custom_jvp/custom_vjp 一起使用。以這種方式實作 nondiff_argnums 是個錯誤!

PR #4008 修正了 custom_jvpcustom_vjp 的所有詞法閉包問題。 耶!也就是說,現在 custom_jvpcustom_vjp 函式和規則可以隨心所欲地封閉 Tracers。對於所有非自動微分轉換,一切都會正常運作。對於自動微分轉換,我們會收到關於為什麼我們不能對 custom_jvpcustom_vjp 封閉的值進行微分的明確錯誤訊息

偵測到相對於封閉值的 custom_jvp 函式的微分。這是不支援的,因為自訂 JVP 規則僅指定如何相對於顯式輸入參數來微分 custom_jvp 函式。

嘗試將封閉值作為引數傳遞到 custom_jvp 函式中,並調整自訂 JVP 規則。

以這種方式收緊和強化 custom_jvpcustom_vjp 時,我們發現允許 custom_vjp 接受其 nondiff_argnums 中的 Tracers 需要大量的簿記工作:我們需要重寫使用者的 fwd 函式以將值作為殘差傳回,並重寫使用者的 bwd 函式以將它們作為常規殘差接受 (而不是像 nondiff_argnums 那樣將它們作為特殊的前導引數接受)。這看起來也許是可管理的,直到您考慮到我們必須如何處理任意 pytrees!此外,這種複雜性是不必要的:如果使用者程式碼將類似陣列的不可微分引數視為常規引數和殘差,那麼一切就已經可以正常運作了。(在 #4039 之前,JAX 可能會抱怨在自動微分中涉及整數值輸入和輸出,但在 #4039 之後,這些都能正常運作!)

custom_vjp 不同,讓 custom_jvp 與作為 Tracers 的 nondiff_argnums 引數一起運作很容易。因此,這些更新只需要在 custom_vjp 中發生。