jax.experimental.jet 模組#

Jet 是一個用於高階自動微分的實驗性模組,不依賴重複的一階自動微分。

如何運作?透過截斷泰勒多項式的傳播。考慮一個函數 \(f = g \circ h\)、某個點 \(x\) 和某個偏移量 \(v\)。一階自動微分(例如 jax.jvp())從 \((h(x), \partial h(x)[v])\) 對計算出 \((f(x), \partial f(x)[v])\) 對。

jet() 實作了高階類比:給定元組

\[(h_0, ... h_K) := (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], ..., \partial^K h(x)[v,...,v]),\]

其表示 \(h\)\(x\) 處的 \(K\) 階泰勒近似,jet() 傳回 \(f\)\(x\) 處的 \(K\) 階泰勒近似,

\[(f_0, ..., f_K) := (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], ..., \partial^K f(x)[v,...,v]).\]

更具體地說,jet() 計算

\[f_0, (f_1, . . . , f_K) = \texttt{jet} (f, h_0, (h_1, . . . , h_K))\]

因此可以用於 \(f\) 的高階自動微分。詳細資訊請見這些筆記

注意

透過貢獻未完成的 primitive 規則來協助改進 jet()

API#

jax.experimental.jet.jet(fun, primals, series)[原始碼]#

泰勒模式高階自動微分。

參數:
  • fun – 要微分的函數。其引數應為陣列、純量或陣列或純量的標準 Python 容器。它應傳回陣列、純量或陣列或純量的標準 Python 容器。

  • primals – 應該在其中評估 fun 的泰勒近似的原始值。應為引數的元組或列表,且其長度應等於 fun 的位置參數數量。

  • series – 高階泰勒級數係數。 primalsseries 一起構成截斷泰勒多項式。應為元組或元組或列表的列表,且其長度決定截斷泰勒多項式的階數。

傳回:

一個 (primals_out, series_out) 對,其中 primals_outfun(*primals),並且 primals_outseries_out 一起是 \(f(h(\cdot))\) 的截斷泰勒多項式。primals_out 值具有與 primals 相同的 Python 樹狀結構,而 series_out 值具有與 series 相同的 Python 樹狀結構。

例如

>>> import jax
>>> import jax.numpy as np

考慮函數 \(h(z) = z^3\)\(x = 0.5\) 和前幾個泰勒係數 \(h_0=x^3\)\(h_1=3x^2\)\(h_2=6x\)。設 \(f(y) = \sin(y)\)

>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

jet() 根據 Faà di Bruno 公式傳回 \(f(h(z)) = \sin(z^3)\) 的泰勒係數

>>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
>>> print(f0,  f(h0))
0.12467473 0.12467473
>>> print(f1, df(h0) * h1)
0.7441479 0.74414825
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
2.9064622 2.9064634