jax.numpy.linalg.solve#
- jax.numpy.linalg.solve(a, b)[原始碼]#
求解線性方程式系統
JAX 實作的
numpy.linalg.solve()
。這會求解 (批次) 線性方程式系統
a @ x = b
,針對給定的a
和b
求解x
。- 參數:
a (ArrayLike) – 形狀為
(..., N, N)
的陣列。b (ArrayLike) – 形狀為
(N,)
(適用於 1 維右側) 或(..., N, M)
(適用於批次 2 維右側) 的陣列。
- 傳回:
包含線性求解結果的陣列。如果
b
的形狀為(N,)
,則結果的形狀為(..., N)
,否則形狀為(..., N, M)
。- 傳回類型:
另請參閱
jax.scipy.linalg.solve()
:用於求解線性系統的 SciPy 風格 API。jax.lax.custom_linear_solve()
:免矩陣線性求解器。
範例
簡單的 3x3 線性系統
>>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jnp.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32)
確認結果求解了系統
>>> jnp.allclose(A @ x, b) Array(True, dtype=bool)