jax.numpy.linalg.svd#
- jax.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False, subset_by_index=None)[來源]#
計算奇異值分解。
JAX 實作的
numpy.linalg.svd()
,以jax.lax.linalg.svd()
實作。矩陣 A 的 SVD 由下式給出
\[A = U\Sigma V^H\]\(U\) 包含左奇異向量,並滿足 \(U^HU=I\)
\(V\) 包含右奇異向量,並滿足 \(V^HV=I\)
\(\Sigma\) 是奇異值的對角矩陣。
- 參數:
a (ArrayLike) – 輸入陣列,形狀為
(..., N, M)
full_matrices (bool) – 如果為 True (預設),則計算完整矩陣;即
u
和vh
的形狀分別為(..., N, N)
和(..., M, M)
。如果為 False,則形狀為(..., N, K)
和(..., K, M)
,其中K = min(N, M)
。compute_uv (bool) – 如果為 True (預設),則傳回完整 SVD
(u, s, vh)
。如果為 False,則僅傳回奇異值s
。hermitian (bool) – 如果為 True,則假設矩陣為 Hermitian 矩陣,這可以更有效率地實作 (預設值=False)
subset_by_index (tuple[int, int] | None) – (僅限 TPU) 可選的 2 元組 [start, end],指示要計算的奇異值索引範圍。例如,如果
[n-2, n]
,則svd
計算兩個最大奇異值及其奇異向量。僅與full_matrices=False
相容。
- 傳回值:
如果
compute_uv
為 True,則傳回陣列元組(u, s, vh)
,否則傳回陣列s
。u
: 左奇異向量,如果full_matrices
為 True,則形狀為(..., N, N)
,否則為(..., N, K)
。s
: 奇異值,形狀為(..., K)
vh
: 共軛轉置右奇異向量,如果full_matrices
為 True,則形狀為(..., M, M)
,否則為(..., K, M)
。
其中
K = min(N, M)
。- 傳回類型:
Array | SVDResult
另請參閱
jax.scipy.linalg.svd()
: SciPy 風格的 SVD APIjax.lax.linalg.svd()
: XLA 風格的 SVD API
範例
考慮一個小型實值陣列的 SVD
>>> x = jnp.array([[1., 2., 3.], ... [6., 5., 4.]]) >>> u, s, vt = jnp.linalg.svd(x, full_matrices=False) >>> s Array([9.361919 , 1.8315067], dtype=float32)
奇異向量位於
u
和v = vt.T
的列中。這些向量是單位正交的,可以透過比較矩陣乘積與單位矩陣來證明>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5) Array(True, dtype=bool) >>> v = vt.T >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5) Array(True, dtype=bool)
給定 SVD,可以透過矩陣乘法重建
x
>>> x_reconstructed = u @ jnp.diag(s) @ vt >>> jnp.allclose(x_reconstructed, x) Array(True, dtype=bool)