jax.scipy.linalg.lu_factor#

jax.scipy.linalg.lu_factor(a, overwrite_a=False, check_finite=True)[source]#

用於 LU 的線性求解的因式分解

scipy.linalg.lu_factor() 的 JAX 實作。

此函數返回適用於 jax.scipy.linalg.lu_solve() 的結果。對於直接 LU 分解,請優先使用 jax.scipy.linalg.lu()

參數:
  • a (ArrayLike) – 形狀為 (..., M, N) 的輸入陣列。

  • overwrite_a (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

返回值:

一個元組 (lu, piv)

  • lu 是一個形狀為 (..., M, N) 的陣列,其下三角包含 L,上三角包含 U

  • piv 是一個形狀為 (..., K) 的陣列,其中 K = min(M, N),用於編碼樞紐。

返回類型:

tuple[Array, Array]

範例

通過 LU 分解求解小型線性系統

>>> a = jnp.array([[2., 1.],
...                [1., 2.]])

通過 lu_factor() 計算 LU 分解,並使用它通過 lu_solve() 求解線性方程式。

>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)

檢查結果是否一致

>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)