jax.scipy.linalg.sqrtm#
- jax.scipy.linalg.sqrtm(A, blocksize=1)[原始碼]#
計算矩陣平方根
JAX 版本的
scipy.linalg.sqrtm()
。- 參數:
A (ArrayLike) – 形狀為
(N, N)
的陣列blocksize (int) – JAX 不支援;JAX 始終使用
blocksize=1
。
- 返回:
形狀為
(N, N)
的陣列,包含A
的矩陣平方根- 返回型別:
範例
>>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> sqrt_a = jax.scipy.linalg.sqrtm(a) >>> with jnp.printoptions(precision=2, suppress=True): ... print(sqrt_a) [[0.92+0.71j 0.54+0.j 0.92-0.71j] [0.54+0.j 1.85+0.j 0.54-0.j ] [0.92-0.71j 0.54-0.j 0.92+0.71j]]
根據定義,矩陣平方根與自身相乘的結果應等於輸入
>>> jnp.allclose(a, sqrt_a @ sqrt_a) Array(True, dtype=bool)
筆記
此函數實作了 [1] 中描述的複數 Schur 方法。它不使用遞迴分塊來加速計算,因為 JAX 中尚無 Sylvester 方程式求解器。
參考文獻