用於可 JAX 轉換函數的自訂 JVP/VJP 規則#

這是一份設計文件,說明 jax.custom_jvpjax.custom_vjp 設計和實作背後的一些想法。如需以使用者為導向的文件,請參閱教學筆記本

在 JAX 中有兩種定義微分規則的方法

  1. 使用 jax.custom_jvpjax.custom_vjp 為已可 JAX 轉換的 Python 函數定義自訂微分規則;以及

  2. 定義新的 core.Primitive 實例以及其所有轉換規則,例如從其他系統(如求解器、模擬器或一般數值計算系統)呼叫函數。

本文檔僅關於 #1。

目錄#

目標#

我們希望使用者自訂其程式碼的前向和/或反向模式微分行為。此自訂

  1. 應在如何運作以及如何與其他 JAX 轉換組合方面具有清晰且一致的語意;以及

  2. 彈性支援用例和工作流程,如 AutogradPyTorch 中所示,包括涉及 Python 控制流程微分和 NaN 除錯的工作流程。

作為 JAX 開發者,我們希望編寫程式庫函數,例如 logitexpit,這些函數根據其他基本運算定義,但為了微分的目的,具有類似基本運算的行為,因為我們希望為它們定義自訂微分規則,這些規則可能在數值上更穩定或效能更高。特別是,我們不希望必須為 logitexpit 等函數指定 vmapjit 規則。

作為一個延伸目標,我們希望使 JAX 成為尋求為高階函數(如 fixed_pointodeint 等)新增自訂微分規則的進階使用者的絕佳環境;本文檔不會解決該問題,但我們希望確信我們不會排除該問題的良好解決方案。

也就是說,我們的主要目標是

  1. 解決 vmap-移除-自訂-jvp 語意問題 (#1249),以及

  2. 允許在自訂 VJP 中使用 Python,例如除錯 NaN (#1275)。

次要目標是 3. 清理和簡化使用者體驗(符號零、kwargs 等) 4. 朝向使用者可以輕鬆新增 fixed_pointodeintroot 等的世界邁進。

總體而言,我們希望關閉 #116#1097#1249#1275#1366#1723#1670#1875#1938,並取代 custom_transforms 機制(來自 #636#818 和其他)。

非目標#

以下是我們打算實現的目標

  1. custom_transforms 機制旨在提供一種轉換通用機制來自訂行為,原則上(雖然實際上從未使用過)允許使用者自訂任何轉換的規則,同時以某種方式繼承其他轉換的「透明」行為。我們反而只會解決微分(JVP 和 VJP,分別)的自訂問題。 微分是唯一實際請求的案例,透過專注於微分,我們可以降低複雜性並提高彈性。若要控制所有規則,您可以直接編寫基本運算。

  2. 我們不會將數學美學優先於使用者端的彈性和清晰度,以及實作端的簡潔性。特別是,雖然自訂 VJP 簽名 a -> (b, CT b --o CT a) 在數學上令人滿意,但如果由於回傳類型中的閉包而在 Python 機制中難以實作,我們可以更明確地處理殘差。

  3. 序列化支援,形式為分段輸出的序列化程式表示可以載入並進一步進行 JAX 轉換,而不僅僅是評估,目前不在這些自訂 JVP/VJP 轉換規則的範圍內。序列化不僅對於想要儲存其計算的某些表示形式(並在載入後轉換)的研究人員可能很有用,而且對於未來的考量(例如讓 jaxpr 轉換在 Python 外部實作,或將 jaxpr 作為 MLIR 方言)也可能很有用。透過將其定義為此設計目的的非目標,我們對可以存放 Python 可呼叫物件的位置的限制較少。

主要問題描述#

vmap-移除-自訂-jvp 語意問題#

vmap-移除-自訂-jvp 語意問題是 vmap 無法與具有 custom_transforms 規則的函數微分正確組合

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

最後一行 grad-of-vmap 的結果出乎意料!一般而言,應用 vmap,或實際上任何非微分轉換,都會產生移除自訂微分規則的效果。(當定義自訂 VJP 規則時,應用 jvp 會導致失敗。)

問題存在的原因是轉換就像重寫,而 vmap 轉換有效地重寫函數,使其不再呼叫新引入的基本運算,而該基本運算具有自訂規則(因此 grad 隨後不會產生自訂規則的結果)。更詳細地說,custom_transforms 機制設定了使評估 f(x) 應用函數

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

其中 f_primitive 是一個新的基本運算(為每個 custom_transforms 函數引入,實際上是為函數的每次呼叫引入),自訂 VJP 規則與之關聯。當我們評估 grad(f)(x) 時,微分機制會遇到 f_primitive 並使用自訂規則處理它。

但是,由於 f_primitivevmap透明的,就 vmapf_primitive 的定義上運作(有效地透過內聯)而言,函數 vmap(f) 實際上是

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

換句話說,vmap 根據其底層基本運算及其轉換規則重寫函數,完全移除 f_primitive

更一般而言,由於 vmap(f) 的語意是根據對 f 的呼叫來定義的,因此移除自訂導數規則在語意上是不一致的。也就是說,由於我們定義

vmap(f)(xs) == np.stack([f(x) for x in xs])

我們必須有

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))

但是,當 f 定義了自訂導數規則時,不會觀察到此屬性,因為自訂導數規則用於右側版本,但不適用於左側版本。

此問題並非 vmap 特有;它適用於所有轉換,對於這些轉換,轉換函數 f 的語意是根據對函數 f 的呼叫來定義的,而不是將其重寫為另一個函數。mask 轉換也屬於此類。微分轉換和假設的所有一元函數變為餘弦轉換不屬於此類。

(其他自訂規則(如自訂 vmap 規則)之間的交互作用可能會變得更加複雜,這表明 custom_transforms 的問題框架過於廣泛。)

Python 彈性問題#

在 JAX 中,如同 AutogradPyTorch 而非 TF1 中一樣,Python 函數的微分是在函數執行和追蹤時執行的。此行為讓使用者感到高興,原因如下。

首先也是最重要的,它啟用了基於 pdb 的工作流程,例如用於檢查數值或捕捉 NaN。 也就是說,使用者可以採用標準 Python 除錯器和其他 Python 原生工具來除錯其程式碼,甚至能夠檢查執行階段值,以了解範例的數值行為,並捕捉從根本上說是執行階段錯誤(如 NaN)的錯誤。實際上,就在處理與此設計對應的 PR 時,特別是在 odeint 基本運算上,我多次使用執行階段值檢查來除錯問題,這增加了我對這是 Python 中關鍵使用者工作流程的信心。一個特別方便的技巧,我在 JAX 和 Autograd 中多次使用過,就是在自訂 VJP 規則中插入除錯器中斷點,以便在反向傳播中的特定點進入除錯器。

其次,它允許對 Python 原生控制流程進行微分。 我們不確定在最終軟體成品中有多常使用此功能,但當使用者第一次接觸 JAX 或 Autograd 時,他們通常會對這種自由度印象深刻。我們將其包含在我們的 JAX 和 Autograd README、投影片組和演示文稿的頂部是有原因的。放棄此功能將是從 Autograd 向後退一步。我們希望 JAX 具有最佳的自動微分。

但是,custom_transforms 機制不提供這種 Python 支援彈性。也就是說,由於它是根據從使用者函數和自訂微分規則的 Python 程式碼預先形成 jaxpr 來實作的,因此像這樣的程式碼會導致抽象值追蹤錯誤

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!

解決方案構想#

主要想法是 dougalm@ 已經使用 core.call 解決了這些問題。也就是說,我們可以將為使用者函數指定自訂 JVP 規則的任務構建為新的 Python 級別呼叫基本運算(不會新增到 jaxpr 語言;請參閱下文)。這個新的呼叫基本運算具有與之關聯的使用者 Python 函數,就像 core.call 一樣,但另外還有一個 Python 可呼叫物件,表示 JVP 規則。讓我們將這個新的呼叫基本運算稱為 custom_jvp_call

vmap 這樣的轉換與 custom_jvp_call 的交互方式與 core.call 相同:它們有效地直接傳遞它,並應用於底層 Python 可呼叫物件。示意性地,為了方便起見,以基本運算的 curried 版本編寫,類似於 vmap 如何透過應用於要呼叫的函數來與 core.call 交互

vmap(call(f)) == call(vmap(f))

對於新的基本運算 custom_jvp_call,我們只需將 vmap 應用於它所包含的兩個函數

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

此行為表示我們已解決 vmap-移除-自訂-jvp 語意問題

jvp 轉換的交互方式正如人們可能預期的那樣:它只呼叫 f_jvp

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

由於 custom_jvp_call 的行為類似於 core.call(而不是像 xla.xla_call),因為它沒有提高其輸入的抽象層級(因為它沒有延遲任何內容或分段輸出任何內容),這表示我們已解決 Python 彈性問題:對使用者 Python 函數沒有任何限制(高於 jvpvjp 所需的常見函數式程式設計限制)。

評估和編譯呢?這些是「退出」JAX 系統的兩種方式,因為在這些步驟之後無法應用其他轉換。因此,它們的規則很簡單

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

換句話說,如果 JVP 規則尚未將 custom_jvp_call(f, f_jvp) 重寫為 f_jvp,當我們使用 eval 進行評估或使用 jit 分段輸出到 XLA 時,永遠不會應用微分,因此我們只需忽略 f_jvp,其行為就像 core.call。但是,由於接下來討論的細微之處,custom_jvp_call 的部分評估規則必須稍微複雜一些,因為部分評估不僅用於使用 jit 分段輸出到 XLA。

唯一剩下的細微之處與「初始樣式」jaxpr 形成基本運算(如 lax.scan)及其轉換規則有關。這些表示與用於編譯的分段輸出到 jaxpr 不同種類的「分段輸出到 jaxpr」,因為我們可以在分段輸出的 jaxpr 上執行其他轉換。也就是說,當 lax.scan 形成 jaxpr 時,它不會退出轉換系統,因為當我們將 jvp 或 vmap 應用於 lax.scan 時,我們需要將其應用於 jaxpr 表示的函數。

陳述此細微之處的另一種方式是,初始樣式基本運算(如 lax.scan)依賴於往返於 jaxpr 和 Python 可呼叫物件之間同時保留語意的能力。這必然也表示保留自訂微分規則語意。

解決方案是使用一點動態作用域:當我們為了初始樣式的原語(primitive)(例如 lax_control_flow.py 中的那些) 暫存(staging out)到 jaxpr 時,我們會在全域追蹤狀態(global trace state)中設定一個位元。當該位元被設定時,我們不會使用最終樣式的 custom_jvp_call 原語,而是使用初始樣式的 custom_jvp_call_jaxpr 原語,並預先將函數 ff_jvp 追蹤(trace)到 jaxpr,以簡化初始樣式的處理。custom_jvp_call_jaxpr 原語在其他方面與最終樣式版本相似。

(註腳:雖然在概念上,我們在綁定 custom_jvp_call_jaxpr 之前為 ff_jvp 都形成了 jaxpr,但我們需要延遲 f_jvp 的 jaxpr 的形成,因為它可能會呼叫自定義 JVP 函數,因此及早處理(eager processing)會導致無限遞迴。我們在一個 thunk 中延遲了 jaxpr 的形成。)

如果我們放棄了Python 靈活性問題,我們就可以只使用 custom_jvp_call_jaxpr,而不需要單獨的 Python 層級原語 custom_jvp_call

API#

對於一個 a -> b 函數的自定義 JVP 是用一個 (a, Ta) -> (b, T b) 函數來指定的

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(有趣的自動微分題外話:為了使規則適用於更高階的微分,必須在 f_jvp 的主體中呼叫 f;這排除了一些在 f 的內部運作和正切計算之間共享工作的方式。)

對於一個 a -> b 函數的自定義 VJP 是用一個 a -> (b, c) 前向傳遞函數與一個 (c, CT b) -> CT a 後向傳遞函數配對來指定的

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

簽名 a -> (b, CT b --o CT a) 在美學上更令人愉悅,但支持它會使實作更複雜,並可能需要妥協可表達性的期望。基本原因是 Python 可呼叫物件是不透明的(除非我們及早將它們追蹤到 jaxpr,這會帶來表達性約束),並且在這種情況下,我們可能會返回一個可呼叫物件,其閉包內部有 vmap 追蹤器,我們需要在前向傳遞期間了解這些追蹤器。

我們可以添加便利的包裝器(wrapper),例如,一次為單個參數定義 JVP 規則(就像我們在內部為原語所做的那樣)。但是因為這個提案已經夠複雜了,所以我決定反對便利層;讓我們現在保持最小化。

API 還有一些其他的花俏功能

  • 輸入和輸出類型 abc 可以是 jaxtypes 的任意 pytrees。

  • 當可以使用 inspect 模組將具名引數(keyword arguments)解析為位置引數時,支援具名引數。這是對 Python 3 程式化檢查引數簽名的改進能力的實驗。我相信它是合理的,但並不完整,這是一個可以接受的狀態。(另請參閱 #2069。)

  • 可以使用 nondiff_argnums 將引數標記為不可微分,並且與 jitstatic_argnums 一樣,這些引數不必是 JAX 類型。我們需要為如何將這些引數傳遞給規則設定一個慣例。對於類型簽名為 (d, a) -> b 的原始函數,其中 d 表示不可微分類型,JVP 規則的簽名是 (a, T a, d) -> T b,而 VJP 規則的反向組件簽名是 (d, c, CT b) -> CT a。也就是說,對於自定義 JVP 規則,不可微分的引數在 primalstangents 之後按順序傳遞,並且在自定義 VJP 規則的反向函數中,在殘差之前按順序傳遞。

實作筆記#

  • 已更新 jax.experimental.odeint

    • 由於 odeint 是自定義 VJP 規則的一個相當複雜的使用者,除了更新它以使其完全運作之外,我還想修改它,使其成為新的自定義 VJP API 的典型使用者,以此來測試該 API 是否良好。

    • 在此過程中,我對 odeint 的實作進行了其他改進

      • 移除展開/解開(raveling/unraveling)的樣板程式碼

      • 使用 lax.scan 來移除索引更新邏輯

      • 在簡單擺錘基準測試中速度提高了 20% 以上

  • 為每個轉換(transform)上的自定義導數呼叫原語 custom_jvp_callcustom_vjp_call 添加了自定義綁定方法(bind method)。它類似於 core.call_bind,只是我們不處理環境追蹤(env traces):那些只是錯誤。

  • 添加了 custom_lin 原語,它被暫存輸出到線性 jaxpr 中,以便在使用自定義 VJP 規則時進行轉置(transpose)。

    • 由於我們的反向模式自動微分被分解為線性化、部分求值和轉置,因此我們的自定義 VJP 規則分兩個獨立的步驟處理:一個在線性化期間,另一個在轉置期間。

    • 線性化步驟,即 custom_vjp_call 的 JVP 規則,將 custom_lin 應用於正切值;custom_lin 攜帶了使用者自定義的後向傳遞函數,並且作為一個原語,它只有一個轉置規則。

    • 這個機制在 #636 中有更詳細的描述。

  • 為了防止