jax.numpy.linalg.svd#

jax.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False, subset_by_index=None)[來源]#

計算奇異值分解。

JAX 實作的 numpy.linalg.svd(),以 jax.lax.linalg.svd() 實作。

矩陣 A 的 SVD 由下式給出

\[A = U\Sigma V^H\]
  • \(U\) 包含左奇異向量,並滿足 \(U^HU=I\)

  • \(V\) 包含右奇異向量,並滿足 \(V^HV=I\)

  • \(\Sigma\) 是奇異值的對角矩陣。

參數:
  • a (ArrayLike) – 輸入陣列,形狀為 (..., N, M)

  • full_matrices (bool) – 如果為 True (預設),則計算完整矩陣;即 uvh 的形狀分別為 (..., N, N)(..., M, M)。如果為 False,則形狀為 (..., N, K)(..., K, M),其中 K = min(N, M)

  • compute_uv (bool) – 如果為 True (預設),則傳回完整 SVD (u, s, vh)。如果為 False,則僅傳回奇異值 s

  • hermitian (bool) – 如果為 True,則假設矩陣為 Hermitian 矩陣,這可以更有效率地實作 (預設值=False)

  • subset_by_index (tuple[int, int] | None) – (僅限 TPU) 可選的 2 元組 [start, end],指示要計算的奇異值索引範圍。例如,如果 [n-2, n],則 svd 計算兩個最大奇異值及其奇異向量。僅與 full_matrices=False 相容。

傳回值:

如果 compute_uv 為 True,則傳回陣列元組 (u, s, vh),否則傳回陣列 s

  • u: 左奇異向量,如果 full_matrices 為 True,則形狀為 (..., N, N),否則為 (..., N, K)

  • s: 奇異值,形狀為 (..., K)

  • vh: 共軛轉置右奇異向量,如果 full_matrices 為 True,則形狀為 (..., M, M),否則為 (..., K, M)

其中 K = min(N, M)

傳回類型:

Array | SVDResult

另請參閱

範例

考慮一個小型實值陣列的 SVD

>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jnp.linalg.svd(x, full_matrices=False)
>>> s  
Array([9.361919 , 1.8315067], dtype=float32)

奇異向量位於 uv = vt.T 的列中。這些向量是單位正交的,可以透過比較矩陣乘積與單位矩陣來證明

>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)

給定 SVD,可以透過矩陣乘法重建 x

>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)