jax.scipy.linalg.expm_frechet#
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) tuple[Array, Array] [source]#
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) Array
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) Array | tuple[Array, Array]
計算矩陣指數的 Frechet 導數。
scipy.linalg.expm_frechet()
的 JAX 實作- 參數:
A – 形狀為
(..., N, N)
的陣列E – 形狀為
(..., N, N)
的陣列;指定導數的方向。compute_expm – 如果為 True (預設值),則計算並傳回
expm(A)
。method – JAX 忽略
- 傳回:
如果
compute_expm
為 True,則傳回元組(expm_A, expm_frechet_AE)
,否則傳回陣列expm_frechet_AE
。兩個傳回的陣列都具有形狀(..., N, N)
。
範例
我們可以使用此 API 來計算
A
的矩陣指數,以及其在E
方向上的導數>>> key1, key2 = jax.random.split(jax.random.key(3372)) >>> A = jax.random.normal(key1, (3, 3)) >>> E = jax.random.normal(key2, (3, 3)) >>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)
這可以使用 JAX 的自動微分方法等效地計算;在這裡,我們將使用
jax.jvp()
計算expm()
在E
方向上的導數,並找到相同的結果>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,)) >>> jnp.allclose(expmA, expmA2) Array(True, dtype=bool) >>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2) Array(True, dtype=bool)