jax.scipy.linalg.expm#
- jax.scipy.linalg.expm(A, *, upper_triangular=False, max_squarings=16)[原始碼]#
計算矩陣指數
JAX 實現的
scipy.linalg.expm()
。- 參數:
- 傳回:
形狀為
(..., N, N)
的陣列,包含A
的矩陣指數。- 傳回類型:
筆記
這使用了縮放和平方近似方法,其計算複雜度由可選的
max_squarings
參數控制。理論上,所需的平方次數為max(0, ceil(log2(norm(A))) - c)
,其中norm(A)
是 L1 範數,而對於 float64/complex128,c=2.42
,或者對於 float32/complex64,c=1.97
。範例
expm
是矩陣指數,並且具有與更熟悉的純量指數相似的屬性。對於純量a
和b
,\(e^{a + b} = e^a e^b\)。但是,對於矩陣,此屬性僅在A
和B
可交換 (AB = BA
) 時成立。在這種情況下,expm(A+B) = expm(A) @ expm(B)
>>> A = jnp.array([[2, 0], ... [0, 1]]) >>> B = jnp.array([[3, 0], ... [0, 4]]) >>> jnp.allclose(jax.scipy.linalg.expm(A+B), ... jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B), ... rtol=0.0001) Array(True, dtype=bool)
如果矩陣
X
是可逆的,則expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)
>>> X = jnp.array([[3, 1], ... [2, 5]]) >>> X_inv = jax.scipy.linalg.inv(X) >>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv), ... X @ jax.scipy.linalg.expm(A) @ X_inv) Array(True, dtype=bool)