jax.scipy.linalg.expm#

jax.scipy.linalg.expm(A, *, upper_triangular=False, max_squarings=16)[原始碼]#

計算矩陣指數

JAX 實現的 scipy.linalg.expm()

參數:
  • A (ArrayLike) – 形狀為 (..., N, N) 的陣列

  • upper_triangular (bool) – 如果為 True,則假設 A 是上三角矩陣。預設值=False。

  • max_squarings (int) – 縮放和平方近似方法中的平方次數 (預設值:16)。

傳回:

形狀為 (..., N, N) 的陣列,包含 A 的矩陣指數。

傳回類型:

Array

筆記

這使用了縮放和平方近似方法,其計算複雜度由可選的 max_squarings 參數控制。理論上,所需的平方次數為 max(0, ceil(log2(norm(A))) - c),其中 norm(A) 是 L1 範數,而對於 float64/complex128,c=2.42,或者對於 float32/complex64,c=1.97

範例

expm 是矩陣指數,並且具有與更熟悉的純量指數相似的屬性。對於純量 ab\(e^{a + b} = e^a e^b\)。但是,對於矩陣,此屬性僅在 AB 可交換 (AB = BA) 時成立。在這種情況下,expm(A+B) = expm(A) @ expm(B)

>>> A = jnp.array([[2, 0],
...                [0, 1]])
>>> B = jnp.array([[3, 0],
...                [0, 4]])
>>> jnp.allclose(jax.scipy.linalg.expm(A+B),
...              jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B),
...              rtol=0.0001)
Array(True, dtype=bool)

如果矩陣 X 是可逆的,則 expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)

>>> X = jnp.array([[3, 1],
...                [2, 5]])
>>> X_inv = jax.scipy.linalg.inv(X)
>>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv),
...              X @ jax.scipy.linalg.expm(A) @ X_inv)
Array(True, dtype=bool)