jax.numpy.linalg.solve#

jax.numpy.linalg.solve(a, b)[原始碼]#

求解線性方程式系統

JAX 實作的 numpy.linalg.solve()

這會求解 (批次) 線性方程式系統 a @ x = b,針對給定的 ab 求解 x

參數:
  • a (ArrayLike) – 形狀為 (..., N, N) 的陣列。

  • b (ArrayLike) – 形狀為 (N,) (適用於 1 維右側) 或 (..., N, M) (適用於批次 2 維右側) 的陣列。

傳回:

包含線性求解結果的陣列。如果 b 的形狀為 (N,),則結果的形狀為 (..., N),否則形狀為 (..., N, M)

傳回類型:

Array

另請參閱

範例

簡單的 3x3 線性系統

>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jnp.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)

確認結果求解了系統

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)