jax.scipy.linalg.svd#

jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: Literal[True] = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') tuple[Array, Array, Array][原始碼]#
jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array
jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array
jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array | tuple[Array, Array, Array]

計算奇異值分解。

JAX 實作的 scipy.linalg.svd()

矩陣 A 的 SVD 由下式給出

\[A = U\Sigma V^H\]
  • \(U\) 包含左奇異向量,並滿足 \(U^HU=I\)

  • \(V\) 包含右奇異向量,並滿足 \(V^HV=I\)

  • \(\Sigma\) 是奇異值的對角矩陣。

參數:
  • a – 輸入陣列,形狀為 (..., N, M)

  • full_matrices – 如果為 True (預設),則計算完整矩陣;即 uvh 的形狀分別為 (..., N, N)(..., M, M)。如果為 False,則形狀為 (..., N, K)(..., K, M),其中 K = min(N, M)

  • compute_uv – 如果為 True (預設),則傳回完整 SVD (u, s, vh)。如果為 False,則僅傳回奇異值 s

  • overwrite_a – JAX 未使用

  • check_finite – JAX 未使用

  • lapack_driver – JAX 未使用

傳回:

如果 compute_uv 為 True,則傳回陣列元組 (u, s, vh),否則傳回陣列 s

  • u:左奇異向量,如果 full_matrices 為 True,則形狀為 (..., N, N),否則為 (..., N, K)

  • s:奇異值,形狀為 (..., K)

  • vh:共軛轉置右奇異向量,如果 full_matrices 為 True,則形狀為 (..., M, M),否則為 (..., K, M)

其中 K = min(N, M)

另請參閱

範例

考慮小型實值陣列的 SVD

>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jax.scipy.linalg.svd(x, full_matrices=False)
>>> s  
Array([9.361919 , 1.8315067], dtype=float32)

奇異向量位於 uv = vt.T 的列中。這些向量是單位正交的,可以透過將矩陣乘積與單位矩陣進行比較來證明

>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)

給定 SVD,可以透過矩陣乘法重建 x

>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)