jax.scipy.linalg.lu#
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array, Array] [原始碼]#
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array]
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]
計算 LU 分解
JAX 實作的
scipy.linalg.lu()
。矩陣 A 的 LU 分解為
\[A = P L U\]其中 P 是置換矩陣,L 是下三角矩陣,而 U 是上三角矩陣。
- 參數:
a – 要分解的形狀為
(..., M, N)
的陣列。permute_l – 如果為 True,則置換
L
並傳回(P @ L, U)
(預設值:False)overwrite_a – JAX 未使用
check_finite – JAX 未使用
- 回傳:
P
是形狀為(..., M, M)
的置換矩陣L
是形狀為(... M, K)
的下三角矩陣U
是形狀為(..., K, N)
的上三角矩陣
其中
K = min(M, N)
- 回傳類型:
如果
permute_l
為 True,則為陣列的元組(P @ L, U)
,否則為(P, L, U)
另請參閱
jax.numpy.linalg.lu()
:NumPy 風格的 LU 分解 API。jax.lax.linalg.lu()
:XLA 風格的 LU 分解 API。jax.scipy.linalg.lu_solve()
:基於 LU 的線性求解器。
範例
3x3 矩陣的 LU 分解
>>> a = jnp.array([[1., 2., 3.], ... [5., 4., 2.], ... [3., 2., 1.]]) >>> P, L, U = jax.scipy.linalg.lu(a)
P
是置換矩陣:即每行和每列都有一個1
>>> P Array([[0., 1., 0.], [1., 0., 0.], [0., 0., 1.]], dtype=float32)
L
和U
是下三角矩陣和上三角矩陣>>> with jnp.printoptions(precision=3): ... print(L) ... print(U) [[ 1. 0. 0. ] [ 0.2 1. 0. ] [ 0.6 -0.333 1. ]] [[5. 4. 2. ] [0. 1.2 2.6 ] [0. 0. 0.667]]
原始矩陣可以透過將三者相乘來重建
>>> a_reconstructed = P @ L @ U >>> jnp.allclose(a, a_reconstructed) Array(True, dtype=bool)