jax.lax.custom_root#
- jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)[原始碼]#
可微分地求解函數的根。
這是一個底層常式,主要用於 JAX 內部。
custom_root()
的梯度是根據隱函數定理,相對於所提供函數f
中封閉的變數所定義:https://en.wikipedia.org/wiki/Implicit_function_theorem- 參數:
f – 要尋找根的函數。應接受單一引數,並傳回與其輸入結構相同的陣列樹狀結構。
initial_guess – f 零點的初始猜測值。
solve –
用於求解 f 根的函數。應接受兩個位置引數,f 和 initial_guess,並傳回與 initial_guess 結構相同的解,使得 func(solution) = 0。換句話說,假設以下為真(但未檢查)
solution = solve(f, initial_guess) error = f(solution) assert all(error == 0)
tangent_solve –
用於求解切線系統的函數。應接受兩個位置引數,一個線性函數
g
(函數f
在其根處線性化)和一個與 initial_guess 結構相同的陣列樹狀結構y
,並傳回一個解x
,使得g(x)=y
對於純量
y
,使用lambda g, y: y / g(1.0)
。對於向量
y
,如果y
的維度不太大,則可以使用 Jacobian 的線性求解:lambda g, y: np.linalg.solve(jacobian(g)(y), y)
。
has_aux – 布林值,指示
solve
函數是否傳回輔助資料,例如求解器診斷資訊作為第二個引數。
- 傳回值:
假設
f(solve(f, initial_guess)) == 0
,透過隱式微分定義梯度來呼叫 solve(f, initial_guess) 的結果。