jax.scipy.linalg.sqrtm#

jax.scipy.linalg.sqrtm(A, blocksize=1)[原始碼]#

計算矩陣平方根

JAX 版本的 scipy.linalg.sqrtm()

參數:
  • A (ArrayLike) – 形狀為 (N, N) 的陣列

  • blocksize (int) – JAX 不支援;JAX 始終使用 blocksize=1

返回:

形狀為 (N, N) 的陣列,包含 A 的矩陣平方根

返回型別:

Array

範例

>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> sqrt_a = jax.scipy.linalg.sqrtm(a)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(sqrt_a)
[[0.92+0.71j 0.54+0.j   0.92-0.71j]
 [0.54+0.j   1.85+0.j   0.54-0.j  ]
 [0.92-0.71j 0.54-0.j   0.92+0.71j]]

根據定義,矩陣平方根與自身相乘的結果應等於輸入

>>> jnp.allclose(a, sqrt_a @ sqrt_a)
Array(True, dtype=bool)

筆記

此函數實作了 [1] 中描述的複數 Schur 方法。它不使用遞迴分塊來加速計算,因為 JAX 中尚無 Sylvester 方程式求解器。

參考文獻