jax.scipy.linalg.lu#

jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array, Array][原始碼]#
jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array]
jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]

計算 LU 分解

JAX 實作的 scipy.linalg.lu()

矩陣 A 的 LU 分解為

\[A = P L U\]

其中 P 是置換矩陣,L 是下三角矩陣,而 U 是上三角矩陣。

參數:
  • a – 要分解的形狀為 (..., M, N) 的陣列。

  • permute_l – 如果為 True,則置換 L 並傳回 (P @ L, U) (預設值:False)

  • overwrite_a – JAX 未使用

  • check_finite – JAX 未使用

回傳:

  • P 是形狀為 (..., M, M) 的置換矩陣

  • L 是形狀為 (... M, K) 的下三角矩陣

  • U 是形狀為 (..., K, N) 的上三角矩陣

其中 K = min(M, N)

回傳類型:

如果 permute_l 為 True,則為陣列的元組 (P @ L, U),否則為 (P, L, U)

另請參閱

範例

3x3 矩陣的 LU 分解

>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> P, L, U = jax.scipy.linalg.lu(a)

P 是置換矩陣:即每行和每列都有一個 1

>>> P
Array([[0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)

LU 是下三角矩陣和上三角矩陣

>>> with jnp.printoptions(precision=3):
...   print(L)
...   print(U)
[[ 1.     0.     0.   ]
 [ 0.2    1.     0.   ]
 [ 0.6   -0.333  1.   ]]
[[5.    4.    2.   ]
 [0.    1.2   2.6  ]
 [0.    0.    0.667]]

原始矩陣可以透過將三者相乘來重建

>>> a_reconstructed = P @ L @ U
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)