jax.scipy.linalg.funm#

jax.scipy.linalg.funm(A, func, disp=True)[source]#

評估矩陣值函數

JAX 實作的 scipy.linalg.funm()

參數:
  • A (ArrayLike) – 形狀為 (N, N) 的陣列,用於計算函數。

  • func (Callable[[Array], Array]) – 可呼叫物件,接受純量引數並傳回純量結果。表示要在 A 的特徵值上評估的函數。

  • disp (bool) – 如果為 true(預設值),則不傳回錯誤資訊。與 SciPy 版本不同,JAX 不會嘗試在執行時顯示資訊。

  • compute_expm – (N, N) array_like 或 None,可選。如果提供,則為 A 的矩陣指數。當 func 為指數函數時,這用於提高效率。如果未提供,則在內部計算。預設為 None。

傳回:

形狀與 A 相同的陣列,包含在 A 的特徵值上評估 func 的結果。

傳回型別:

Array | tuple[Array, Array]

注意事項

JAX 實作傳回的 dtype 可能與 SciPy 的不同;特別是,在陣列值的所有虛部都接近於零的情況下,SciPy 函數可能會傳回實數值陣列,而 JAX 實作將傳回複數值陣列。

範例

應用任意矩陣函數

>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> def func(x):
...   return jnp.sin(x) + 2 * jnp.cos(x)
>>> jax.scipy.linalg.funm(A, func)  
Array([[ 1.2452652 +0.j, -0.3701772 +0.j],
       [-0.55526584+0.j,  0.6899995 +0.j]], dtype=complex64)

比較計算矩陣指數的兩種方法

>>> expA_1 = jax.scipy.linalg.funm(A, jnp.exp)
>>> expA_2 = jax.scipy.linalg.expm(A)
>>> jnp.allclose(expA_1, expA_2, rtol=1E-4)
Array(True, dtype=bool)