jax.scipy.sparse.linalg.cg#

jax.scipy.sparse.linalg.cg(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)[原始碼]#

使用共軛梯度迭代法求解 Ax = b

JAX 的 cg 數值應與 SciPy 的 cg 完全匹配 (在數值精度範圍內),但請注意介面略有不同:您需要將線性運算子 A 作為函式提供,而不是稀疏矩陣或 LinearOperator

cg 的導數是透過隱式微分和另一個 cg 求解來實現的,而不是透過微分穿過求解器。只有當兩個求解都收斂時,它們才是準確的。

參數:
  • A (ndarray, function, 或 matmul-compatible object) – 2D 陣列或函式,用於計算線性映射 (矩陣向量乘積) Ax,當像 A(x)A @ x 這樣呼叫時。A 必須表示 Hermitian、正定矩陣,並且必須傳回與其引數具有相同結構和形狀的陣列。

  • b (arraytree of arrays) – 線性系統的右側,表示單個向量。可以儲存為陣列或 Python 容器,其中包含任何形狀的陣列。

  • x0 (arraytree of arrays) – 解的起始猜測。必須具有與 b 相同的結構。

  • tol (float, optional) – 收斂容差,norm(residual) <= max(tol*norm(b), atol)。我們不實作 SciPy 的「傳統」行為,因此,除非您明確將 atol 傳遞給 SciPy 的 cg,否則 JAX 的容差將與 SciPy 不同。

  • atol (float, optional) – 收斂容差,norm(residual) <= max(tol*norm(b), atol)。我們不實作 SciPy 的「傳統」行為,因此,除非您明確將 atol 傳遞給 SciPy 的 cg,否則 JAX 的容差將與 SciPy 不同。

  • maxiter (integer) – 最大迭代次數。即使未達到指定的容差,迭代也會在 maxiter 步驟後停止。

  • M (ndarray, function, 或 matmul-compatible object) – A 的預處理器。預處理器應近似於 A 的反矩陣。有效的預處理可以顯著提高收斂速度,這表示達到給定的誤差容差所需的迭代次數更少。

傳回:

  • x (array or tree of arrays) – 收斂的解。具有與 b 相同的結構。

  • info (None) – 收斂資訊的預留位置。未來,JAX 將報告未達到收斂時的迭代次數,類似於 SciPy。