jax.scipy.linalg.expm_frechet#

jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) tuple[Array, Array][source]#
jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) Array
jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) Array | tuple[Array, Array]

計算矩陣指數的 Frechet 導數。

scipy.linalg.expm_frechet() 的 JAX 實作

參數:
  • A – 形狀為 (..., N, N) 的陣列

  • E – 形狀為 (..., N, N) 的陣列;指定導數的方向。

  • compute_expm – 如果為 True (預設值),則計算並傳回 expm(A)

  • method – JAX 忽略

傳回:

如果 compute_expm 為 True,則傳回元組 (expm_A, expm_frechet_AE),否則傳回陣列 expm_frechet_AE。兩個傳回的陣列都具有形狀 (..., N, N)

範例

我們可以使用此 API 來計算 A 的矩陣指數,以及其在 E 方向上的導數

>>> key1, key2 = jax.random.split(jax.random.key(3372))
>>> A = jax.random.normal(key1, (3, 3))
>>> E = jax.random.normal(key2, (3, 3))
>>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)

這可以使用 JAX 的自動微分方法等效地計算;在這裡,我們將使用 jax.jvp() 計算 expm()E 方向上的導數,並找到相同的結果

>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,))
>>> jnp.allclose(expmA, expmA2)
Array(True, dtype=bool)
>>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2)
Array(True, dtype=bool)