jax.lax.custom_root#

jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)[原始碼]#

可微分地求解函數的根。

這是一個底層常式,主要用於 JAX 內部。custom_root() 的梯度是根據隱函數定理,相對於所提供函數 f 中封閉的變數所定義:https://en.wikipedia.org/wiki/Implicit_function_theorem

參數:
  • f – 要尋找根的函數。應接受單一引數,並傳回與其輸入結構相同的陣列樹狀結構。

  • initial_guess – f 零點的初始猜測值。

  • solve

    用於求解 f 根的函數。應接受兩個位置引數,f 和 initial_guess,並傳回與 initial_guess 結構相同的解,使得 func(solution) = 0。換句話說,假設以下為真(但未檢查)

    solution = solve(f, initial_guess)
    error = f(solution)
    assert all(error == 0)
    

  • tangent_solve

    用於求解切線系統的函數。應接受兩個位置引數,一個線性函數 g(函數 f 在其根處線性化)和一個與 initial_guess 結構相同的陣列樹狀結構 y,並傳回一個解 x,使得 g(x)=y

    • 對於純量 y,使用 lambda g, y: y / g(1.0)

    • 對於向量 y,如果 y 的維度不太大,則可以使用 Jacobian 的線性求解:lambda g, y: np.linalg.solve(jacobian(g)(y), y)

  • has_aux – 布林值,指示 solve 函數是否傳回輔助資料,例如求解器診斷資訊作為第二個引數。

傳回值:

假設 f(solve(f, initial_guess)) == 0,透過隱式微分定義梯度來呼叫 solve(f, initial_guess) 的結果。