jax.scipy.linalg.lu_solve#

jax.scipy.linalg.lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True)[原始碼]#

使用 LU 分解法解線性系統

JAX 實作的 scipy.linalg.lu_solve()。使用 jax.scipy.linalg.lu_factor() 的輸出。

參數:
  • lu_and_piv (tuple[Array, ArrayLike]) – (lu, piv)lu_factor() 的輸出。lu 是形狀為 (..., M, N) 的陣列,在其下三角包含 L,上三角包含 Upiv 是形狀為 (..., K) 的陣列,其中 K = min(M, N),用於編碼樞軸。

  • b (ArrayLike) – 線性系統的右側。必須具有形狀 (..., M)

  • trans (int) –

    要解的系統類型。選項為

    • 0: \(A x = b\)

    • 1: \(A^Tx = b\)

    • 2: \(A^Hx = b\)

  • overwrite_b (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

回傳:

形狀為 (..., N) 的陣列,表示線性系統的解。

回傳類型:

Array

範例

透過 LU 分解法解小型線性系統

>>> a = jnp.array([[2., 1.],
...                [1., 2.]])

透過 lu_factor() 計算 LU 分解,並使用它透過 lu_solve() 解線性方程式。

>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)

檢查結果是否一致

>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)