jax.scipy.linalg.eigh#
- jax.scipy.linalg.eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: Literal[False] = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) tuple[Array, Array] [原始碼]#
- jax.scipy.linalg.eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, *, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) Array
- jax.scipy.linalg.eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) Array | tuple[Array, Array]
- jax.scipy.linalg.eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) Array | tuple[Array, Array]
計算 Hermitian 矩陣的特徵值和特徵向量
JAX 實作的
scipy.linalg.eigh()
。- 參數:
a – 形狀為
(..., N, N)
的 Hermitian 輸入陣列b – 形狀為
(..., N, N)
的可選 Hermitian 輸入。如果指定,則計算廣義特徵值問題。lower – 如果為 True (預設),則僅存取輸入矩陣的下半部分。否則僅存取上半部分。
eigvals_only – 如果為 True,則僅計算特徵值。如果為 False (預設),則同時計算特徵值和特徵向量。
type –
如果指定
b
,則type
給出要計算的廣義特徵值問題的類型。將(λ, v)
表示為特徵值、特徵向量對type = 1
求解a @ v = λ * b @ v
(預設)type = 2
求解a @ b @ v = λ * v
type = 3
求解b @ a @ v = λ * v
eigvals – 指定要計算的特徵值的
(low, high)
元組。overwrite_a – JAX 未使用。
overwrite_b – JAX 未使用。
turbo – JAX 未使用。
check_finite – JAX 未使用。
- 傳回:
如果
eigvals_only
為 False,則傳回陣列元組(eigvals, eigvecs)
,否則傳回陣列eigvals
。eigvals
: 形狀為(..., N)
的陣列,包含特徵值。eigvecs
: 形狀為(..., N, N)
的陣列,包含特徵向量。
另請參閱
jax.numpy.linalg.eigh()
:NumPy 樣式的 eigh API。jax.lax.linalg.eigh()
:XLA 樣式的 eigh API。jax.numpy.linalg.eig()
:非 Hermitian 特徵值問題。jax.scipy.linalg.eigh_tridiagonal()
:三對角線特徵值問題。
範例
計算一個簡單 2x2 矩陣的標準特徵值分解
>>> a = jnp.array([[2., 1.], ... [1., 2.]]) >>> eigvals, eigvecs = jax.scipy.linalg.eigh(a) >>> eigvals Array([1., 3.], dtype=float32) >>> eigvecs Array([[-0.70710677, 0.70710677], [ 0.70710677, 0.70710677]], dtype=float32)
特徵向量為單位正交
>>> jnp.allclose(eigvecs.T @ eigvecs, jnp.eye(2), atol=1E-5) Array(True, dtype=bool)
解滿足特徵值問題
>>> jnp.allclose(a @ eigvecs, eigvecs @ jnp.diag(eigvals)) Array(True, dtype=bool)