jax.scipy.sparse.linalg.cg#
- jax.scipy.sparse.linalg.cg(A, b, x0=None, *, tol=1e-05, atol=0.0, maxiter=None, M=None)[原始碼]#
使用共軛梯度迭代法求解
Ax = b
。JAX 的
cg
數值應與 SciPy 的cg
完全匹配 (在數值精度範圍內),但請注意介面略有不同:您需要將線性運算子A
作為函式提供,而不是稀疏矩陣或LinearOperator
。cg
的導數是透過隱式微分和另一個cg
求解來實現的,而不是透過微分穿過求解器。只有當兩個求解都收斂時,它們才是準確的。- 參數:
A (ndarray, function, 或 matmul-compatible object) – 2D 陣列或函式,用於計算線性映射 (矩陣向量乘積)
Ax
,當像A(x)
或A @ x
這樣呼叫時。A
必須表示 Hermitian、正定矩陣,並且必須傳回與其引數具有相同結構和形狀的陣列。b (array 或 tree of arrays) – 線性系統的右側,表示單個向量。可以儲存為陣列或 Python 容器,其中包含任何形狀的陣列。
x0 (array 或 tree of arrays) – 解的起始猜測。必須具有與
b
相同的結構。tol (float, optional) – 收斂容差,
norm(residual) <= max(tol*norm(b), atol)
。我們不實作 SciPy 的「傳統」行為,因此,除非您明確將atol
傳遞給 SciPy 的cg
,否則 JAX 的容差將與 SciPy 不同。atol (float, optional) – 收斂容差,
norm(residual) <= max(tol*norm(b), atol)
。我們不實作 SciPy 的「傳統」行為,因此,除非您明確將atol
傳遞給 SciPy 的cg
,否則 JAX 的容差將與 SciPy 不同。maxiter (integer) – 最大迭代次數。即使未達到指定的容差,迭代也會在 maxiter 步驟後停止。
M (ndarray, function, 或 matmul-compatible object) – A 的預處理器。預處理器應近似於 A 的反矩陣。有效的預處理可以顯著提高收斂速度,這表示達到給定的誤差容差所需的迭代次數更少。
- 傳回:
x (array or tree of arrays) – 收斂的解。具有與
b
相同的結構。info (None) – 收斂資訊的預留位置。未來,JAX 將報告未達到收斂時的迭代次數,類似於 SciPy。