jax.lax.custom_linear_solve#

jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False)[原始碼]#

執行矩陣無關的線性求解,具有隱式定義的梯度。

此函式允許直接透過解的隱式微分來覆寫或定義線性求解的梯度,而不是透過求解運算進行微分。有時這樣做可以更快或數值上更穩定,或者甚至可能未實作透過求解運算的微分(例如,如果 solve 使用 lax.while_loop)。

必要不變量

x = solve(matvec, b)  # solve the linear equation
assert matvec(x) == b  # not checked
參數:
  • matvec – 要反轉的線性函數。必須是可微分的。

  • b – 方程式的常數右側項。可以是陣列的任何巢狀結構。

  • solve – 更高層級的函式,用於求解線性方程式的解,即對於所有與 b 形式相同的 xsolve(matvec, x) == x。此函式不需要是可微分的。

  • transpose_solve – 用於求解轉置線性方程式的更高層級函式,即 transpose_solve(vecmat, x) == x,其中 vecmat 是線性映射 matvec 的轉置(使用自動微分自動計算)。反向模式自動微分是必需的,除非 symmetric=True,在這種情況下,solve 提供預設值。

  • symmetric – 布林值,指示是否可以安全地假設線性映射對應於對稱矩陣,即 matvec == vecmat

  • has_aux – 布林值,指示 solvetranspose_solve 函式是否傳回輔助資料,例如作為第二個引數的求解器診斷資訊。

傳回:

solve(matvec, b) 的結果,梯度定義假設

x 滿足線性方程式 matvec(x) == b