用於可 JAX 轉換函數的自訂 JVP/VJP 規則#
這是一份設計文件,說明 jax.custom_jvp
和 jax.custom_vjp
設計和實作背後的一些想法。如需以使用者為導向的文件,請參閱教學筆記本。
在 JAX 中有兩種定義微分規則的方法
使用
jax.custom_jvp
和jax.custom_vjp
為已可 JAX 轉換的 Python 函數定義自訂微分規則;以及定義新的
core.Primitive
實例以及其所有轉換規則,例如從其他系統(如求解器、模擬器或一般數值計算系統)呼叫函數。
本文檔僅關於 #1。
目錄#
目標#
我們希望使用者自訂其程式碼的前向和/或反向模式微分行為。此自訂
應在如何運作以及如何與其他 JAX 轉換組合方面具有清晰且一致的語意;以及
應彈性支援用例和工作流程,如 Autograd 和 PyTorch 中所示,包括涉及 Python 控制流程微分和 NaN 除錯的工作流程。
作為 JAX 開發者,我們希望編寫程式庫函數,例如 logit
和 expit
,這些函數根據其他基本運算定義,但為了微分的目的,具有類似基本運算的行為,因為我們希望為它們定義自訂微分規則,這些規則可能在數值上更穩定或效能更高。特別是,我們不希望必須為 logit
和 expit
等函數指定 vmap
或 jit
規則。
作為一個延伸目標,我們希望使 JAX 成為尋求為高階函數(如 fixed_point
、odeint
等)新增自訂微分規則的進階使用者的絕佳環境;本文檔不會解決該問題,但我們希望確信我們不會排除該問題的良好解決方案。
也就是說,我們的主要目標是
次要目標是 3. 清理和簡化使用者體驗(符號零、kwargs 等) 4. 朝向使用者可以輕鬆新增 fixed_point
、odeint
、root
等的世界邁進。
總體而言,我們希望關閉 #116、#1097、#1249、#1275、#1366、#1723、#1670、#1875、#1938,並取代 custom_transforms 機制(來自 #636、#818 和其他)。
非目標#
以下是我們不打算實現的目標
custom_transforms
機制旨在提供一種轉換通用機制來自訂行為,原則上(雖然實際上從未使用過)允許使用者自訂任何轉換的規則,同時以某種方式繼承其他轉換的「透明」行為。我們反而只會解決微分(JVP 和 VJP,分別)的自訂問題。 微分是唯一實際請求的案例,透過專注於微分,我們可以降低複雜性並提高彈性。若要控制所有規則,您可以直接編寫基本運算。我們不會將數學美學優先於使用者端的彈性和清晰度,以及實作端的簡潔性。特別是,雖然自訂 VJP 簽名
a -> (b, CT b --o CT a)
在數學上令人滿意,但如果由於回傳類型中的閉包而在 Python 機制中難以實作,我們可以更明確地處理殘差。序列化支援,形式為分段輸出的序列化程式表示可以載入並進一步進行 JAX 轉換,而不僅僅是評估,目前不在這些自訂 JVP/VJP 轉換規則的範圍內。序列化不僅對於想要儲存其計算的某些表示形式(並在載入後轉換)的研究人員可能很有用,而且對於未來的考量(例如讓 jaxpr 轉換在 Python 外部實作,或將 jaxpr 作為 MLIR 方言)也可能很有用。透過將其定義為此設計目的的非目標,我們對可以存放 Python 可呼叫物件的位置的限制較少。
主要問題描述#
vmap-移除-自訂-jvp 語意問題#
vmap-移除-自訂-jvp 語意問題是 vmap 無法與具有 custom_transforms
規則的函數微分正確組合
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
return 2. * x
# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
return f(x), lambda g: 3. * x # 3 instead of 2
jax.defvjp_all(f, f_vjp)
grad(f)(1.) # 3.
vmap(grad(f))(np.ones(4)) # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4)) # [2., 2., 2., 2.]
最後一行 grad-of-vmap 的結果出乎意料!一般而言,應用 vmap
,或實際上任何非微分轉換,都會產生移除自訂微分規則的效果。(當定義自訂 VJP 規則時,應用 jvp
會導致失敗。)
問題存在的原因是轉換就像重寫,而 vmap
轉換有效地重寫函數,使其不再呼叫新引入的基本運算,而該基本運算具有自訂規則(因此 grad
隨後不會產生自訂規則的結果)。更詳細地說,custom_transforms
機制設定了使評估 f(x)
應用函數
{ lambda ; ; a.
let b = f_primitive a
in [b] }
其中 f_primitive
是一個新的基本運算(為每個 custom_transforms
函數引入,實際上是為函數的每次呼叫引入),自訂 VJP 規則與之關聯。當我們評估 grad(f)(x)
時,微分機制會遇到 f_primitive
並使用自訂規則處理它。
但是,由於 f_primitive
對 vmap
是透明的,就 vmap
在 f_primitive
的定義上運作(有效地透過內聯)而言,函數 vmap(f)
實際上是
{ lambda ; ; a.
let b = mul 2. a
in [b] }
換句話說,vmap
根據其底層基本運算及其轉換規則重寫函數,完全移除 f_primitive
。
更一般而言,由於 vmap(f)
的語意是根據對 f 的呼叫來定義的,因此移除自訂導數規則在語意上是不一致的。也就是說,由於我們定義
vmap(f)(xs) == np.stack([f(x) for x in xs])
我們必須有
jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))
但是,當 f
定義了自訂導數規則時,不會觀察到此屬性,因為自訂導數規則用於右側版本,但不適用於左側版本。
此問題並非 vmap
特有;它適用於所有轉換,對於這些轉換,轉換函數 f
的語意是根據對函數 f
的呼叫來定義的,而不是將其重寫為另一個函數。mask
轉換也屬於此類。微分轉換和假設的所有一元函數變為餘弦轉換不屬於此類。
(其他自訂規則(如自訂 vmap
規則)之間的交互作用可能會變得更加複雜,這表明 custom_transforms
的問題框架過於廣泛。)
Python 彈性問題#
在 JAX 中,如同 Autograd 和 PyTorch 而非 TF1 中一樣,Python 函數的微分是在函數執行和追蹤時執行的。此行為讓使用者感到高興,原因如下。
首先也是最重要的,它啟用了基於 pdb 的工作流程,例如用於檢查數值或捕捉 NaN。 也就是說,使用者可以採用標準 Python 除錯器和其他 Python 原生工具來除錯其程式碼,甚至能夠檢查執行階段值,以了解範例的數值行為,並捕捉從根本上說是執行階段錯誤(如 NaN)的錯誤。實際上,就在處理與此設計對應的 PR 時,特別是在 odeint
基本運算上,我多次使用執行階段值檢查來除錯問題,這增加了我對這是 Python 中關鍵使用者工作流程的信心。一個特別方便的技巧,我在 JAX 和 Autograd 中多次使用過,就是在自訂 VJP 規則中插入除錯器中斷點,以便在反向傳播中的特定點進入除錯器。
其次,它允許對 Python 原生控制流程進行微分。 我們不確定在最終軟體成品中有多常使用此功能,但當使用者第一次接觸 JAX 或 Autograd 時,他們通常會對這種自由度印象深刻。我們將其包含在我們的 JAX 和 Autograd README、投影片組和演示文稿的頂部是有原因的。放棄此功能將是從 Autograd 向後退一步。我們希望 JAX 具有最佳的自動微分。
但是,custom_transforms
機制不提供這種 Python 支援彈性。也就是說,由於它是根據從使用者函數和自訂微分規則的 Python 程式碼預先形成 jaxpr 來實作的,因此像這樣的程式碼會導致抽象值追蹤錯誤
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
if x > 0:
return x
else:
return 0.
def f_vjp(x):
return ...
jax.defvjp_all(f, f_vjp)
grad(f)(1.) # Error!
解決方案構想#
主要想法是 dougalm@ 已經使用 core.call
解決了這些問題。也就是說,我們可以將為使用者函數指定自訂 JVP 規則的任務構建為新的 Python 級別呼叫基本運算(不會新增到 jaxpr 語言;請參閱下文)。這個新的呼叫基本運算具有與之關聯的使用者 Python 函數,就像 core.call
一樣,但另外還有一個 Python 可呼叫物件,表示 JVP 規則。讓我們將這個新的呼叫基本運算稱為 custom_jvp_call
。
像 vmap
這樣的轉換與 custom_jvp_call
的交互方式與 core.call
相同:它們有效地直接傳遞它,並應用於底層 Python 可呼叫物件。示意性地,為了方便起見,以基本運算的 curried 版本編寫,類似於 vmap
如何透過應用於要呼叫的函數來與 core.call
交互
vmap(call(f)) == call(vmap(f))
對於新的基本運算 custom_jvp_call
,我們只需將 vmap
應用於它所包含的兩個函數
vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))
此行為表示我們已解決 vmap-移除-自訂-jvp 語意問題。
jvp
轉換的交互方式正如人們可能預期的那樣:它只呼叫 f_jvp
,
jvp(call(f)) == call(jvp(f))
jvp(custom_jvp_call(f, f_jvp)) == f_jvp
由於 custom_jvp_call
的行為類似於 core.call
(而不是像 xla.xla_call
),因為它沒有提高其輸入的抽象層級(因為它沒有延遲任何內容或分段輸出任何內容),這表示我們已解決 Python 彈性問題:對使用者 Python 函數沒有任何限制(高於 jvp
或 vjp
所需的常見函數式程式設計限制)。
評估和編譯呢?這些是「退出」JAX 系統的兩種方式,因為在這些步驟之後無法應用其他轉換。因此,它們的規則很簡單
eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))
eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))
換句話說,如果 JVP 規則尚未將 custom_jvp_call(f, f_jvp)
重寫為 f_jvp
,當我們使用 eval
進行評估或使用 jit
分段輸出到 XLA 時,永遠不會應用微分,因此我們只需忽略 f_jvp
,其行為就像 core.call
。但是,由於接下來討論的細微之處,custom_jvp_call
的部分評估規則必須稍微複雜一些,因為部分評估不僅用於使用 jit
分段輸出到 XLA。
唯一剩下的細微之處與「初始樣式」jaxpr 形成基本運算(如 lax.scan
)及其轉換規則有關。這些表示與用於編譯的分段輸出到 jaxpr 不同種類的「分段輸出到 jaxpr」,因為我們可以在分段輸出的 jaxpr 上執行其他轉換。也就是說,當 lax.scan
形成 jaxpr 時,它不會退出轉換系統,因為當我們將 jvp 或 vmap 應用於 lax.scan
時,我們需要將其應用於 jaxpr 表示的函數。
陳述此細微之處的另一種方式是,初始樣式基本運算(如 lax.scan
)依賴於往返於 jaxpr 和 Python 可呼叫物件之間同時保留語意的能力。這必然也表示保留自訂微分規則語意。
解決方案是使用一點動態作用域:當我們為了初始樣式的原語(primitive)(例如 lax_control_flow.py
中的那些) 暫存(staging out)到 jaxpr 時,我們會在全域追蹤狀態(global trace state)中設定一個位元。當該位元被設定時,我們不會使用最終樣式的 custom_jvp_call
原語,而是使用初始樣式的 custom_jvp_call_jaxpr
原語,並預先將函數 f
和 f_jvp
追蹤(trace)到 jaxpr,以簡化初始樣式的處理。custom_jvp_call_jaxpr
原語在其他方面與最終樣式版本相似。
(註腳:雖然在概念上,我們在綁定 custom_jvp_call_jaxpr
之前為 f
和 f_jvp
都形成了 jaxpr,但我們需要延遲 f_jvp
的 jaxpr 的形成,因為它可能會呼叫自定義 JVP 函數,因此及早處理(eager processing)會導致無限遞迴。我們在一個 thunk 中延遲了 jaxpr 的形成。)
如果我們放棄了Python 靈活性問題,我們就可以只使用 custom_jvp_call_jaxpr
,而不需要單獨的 Python 層級原語 custom_jvp_call
。
API#
對於一個 a -> b
函數的自定義 JVP 是用一個 (a, Ta) -> (b, T b)
函數來指定的
# f :: a -> b
@jax.custom_jvp
def f(x):
return np.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), np.cos(x) * t
f.defjvp(f_jvp)
(有趣的自動微分題外話:為了使規則適用於更高階的微分,必須在 f_jvp
的主體中呼叫 f
;這排除了一些在 f
的內部運作和正切計算之間共享工作的方式。)
對於一個 a -> b
函數的自定義 VJP 是用一個 a -> (b, c)
前向傳遞函數與一個 (c, CT b) -> CT
a 後向傳遞函數配對來指定的
# f :: a -> b
@jax.custom_vjp
def f(x):
return np.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
return f(x), np.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
return (cos_x * g,)
f.defvjp(f_fwd, f_bwd)
簽名 a -> (b, CT b --o CT a)
在美學上更令人愉悅,但支持它會使實作更複雜,並可能需要妥協可表達性的期望。基本原因是 Python 可呼叫物件是不透明的(除非我們及早將它們追蹤到 jaxpr,這會帶來表達性約束),並且在這種情況下,我們可能會返回一個可呼叫物件,其閉包內部有 vmap
追蹤器,我們需要在前向傳遞期間了解這些追蹤器。
我們可以添加便利的包裝器(wrapper),例如,一次為單個參數定義 JVP 規則(就像我們在內部為原語所做的那樣)。但是因為這個提案已經夠複雜了,所以我決定反對便利層;讓我們現在保持最小化。
API 還有一些其他的花俏功能
輸入和輸出類型
a
、b
和c
可以是 jaxtypes 的任意 pytrees。當可以使用
inspect
模組將具名引數(keyword arguments)解析為位置引數時,支援具名引數。這是對 Python 3 程式化檢查引數簽名的改進能力的實驗。我相信它是合理的,但並不完整,這是一個可以接受的狀態。(另請參閱 #2069。)可以使用
nondiff_argnums
將引數標記為不可微分,並且與jit
的static_argnums
一樣,這些引數不必是 JAX 類型。我們需要為如何將這些引數傳遞給規則設定一個慣例。對於類型簽名為(d, a) -> b
的原始函數,其中d
表示不可微分類型,JVP 規則的簽名是(a, T a, d) -> T b
,而 VJP 規則的反向組件簽名是(d, c, CT b) -> CT a
。也就是說,對於自定義 JVP 規則,不可微分的引數在primals
和tangents
之後按順序傳遞,並且在自定義 VJP 規則的反向函數中,在殘差之前按順序傳遞。
實作筆記#
已更新
jax.experimental.odeint
由於
odeint
是自定義 VJP 規則的一個相當複雜的使用者,除了更新它以使其完全運作之外,我還想修改它,使其成為新的自定義 VJP API 的典型使用者,以此來測試該 API 是否良好。在此過程中,我對
odeint
的實作進行了其他改進移除展開/解開(raveling/unraveling)的樣板程式碼
使用
lax.scan
來移除索引更新邏輯在簡單擺錘基準測試中速度提高了 20% 以上
為每個轉換(transform)上的自定義導數呼叫原語
custom_jvp_call
和custom_vjp_call
添加了自定義綁定方法(bind method)。它類似於core.call_bind
,只是我們不處理環境追蹤(env traces):那些只是錯誤。添加了
custom_lin
原語,它被暫存輸出到線性 jaxpr 中,以便在使用自定義 VJP 規則時進行轉置(transpose)。由於我們的反向模式自動微分被分解為線性化、部分求值和轉置,因此我們的自定義 VJP 規則分兩個獨立的步驟處理:一個在線性化期間,另一個在轉置期間。
線性化步驟,即
custom_vjp_call
的 JVP 規則,將custom_lin
應用於正切值;custom_lin
攜帶了使用者自定義的後向傳遞函數,並且作為一個原語,它只有一個轉置規則。這個機制在 #636 中有更詳細的描述。
為了防止