jax.scipy.linalg.lu_factor#
- jax.scipy.linalg.lu_factor(a, overwrite_a=False, check_finite=True)[source]#
用於 LU 的線性求解的因式分解
scipy.linalg.lu_factor()
的 JAX 實作。此函數返回適用於
jax.scipy.linalg.lu_solve()
的結果。對於直接 LU 分解,請優先使用jax.scipy.linalg.lu()
。- 參數:
- 返回值:
一個元組
(lu, piv)
lu
是一個形狀為(..., M, N)
的陣列,其下三角包含L
,上三角包含U
。piv
是一個形狀為(..., K)
的陣列,其中K = min(M, N)
,用於編碼樞紐。
- 返回類型:
範例
通過 LU 分解求解小型線性系統
>>> a = jnp.array([[2., 1.], ... [1., 2.]])
通過
lu_factor()
計算 LU 分解,並使用它通過lu_solve()
求解線性方程式。>>> b = jnp.array([3., 4.]) >>> lufac = jax.scipy.linalg.lu_factor(a) >>> y = jax.scipy.linalg.lu_solve(lufac, b) >>> y Array([0.6666666, 1.6666667], dtype=float32)
檢查結果是否一致
>>> jnp.allclose(a @ y, b) Array(True, dtype=bool)