jax.scipy.linalg.inv#

jax.scipy.linalg.inv(a, overwrite_a=False, check_finite=True)[原始碼]#

返回方陣的反矩陣

scipy.linalg.inv() 的 JAX 實作。

參數:
  • a (ArrayLike) – 形狀為 (..., N, N) 的陣列,指定要反轉的方陣。

  • overwrite_a (bool) – 在 JAX 中未使用

  • check_finite (bool) – 在 JAX 中未使用

返回:

形狀為 (..., N, N) 的陣列,包含輸入的反矩陣。

返回型別:

陣列

注意事項

在大多數情況下,明確計算矩陣的反矩陣是不明智的。例如,若要計算 x = inv(A) @ b,使用直接求解(例如 jax.scipy.linalg.solve())效能更高且數值更精確。

另請參閱

範例

計算 3x3 矩陣的反矩陣

>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> a_inv = jax.scipy.linalg.inv(a)
>>> a_inv  
Array([[ 0.        , -0.25      ,  0.5       ],
       [-0.25      ,  0.5       , -0.25000003],
       [ 0.5       , -0.25      ,  0.        ]], dtype=float32)

檢查與反矩陣相乘是否得到單位矩陣

>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)

將反矩陣乘以向量 b,以找到 a @ x = b 的解

>>> b = jnp.array([1., 4., 2.])
>>> a_inv @ b
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)

但請注意,在這種情況下明確計算反矩陣可能會導致效能不佳和精度損失,因為問題規模會擴大。相反地,您應該使用直接求解器,例如 jax.scipy.linalg.solve()

>>> jax.scipy.linalg.solve(a, b)
 Array([ 0.  ,  1.25, -0.5 ], dtype=float32)