jax.numpy.linalg.pinv#

jax.numpy.linalg.pinv(a, rtol=None, hermitian=False, *, rcond=Deprecated)[原始碼]#

計算矩陣的 (Moore-Penrose) 偽反矩陣。

JAX 實作的 numpy.linalg.pinv()

參數:
  • a (ArrayLike) – 形狀為 (..., M, N) 的陣列,包含要計算偽反矩陣的矩陣。

  • rtol (ArrayLike | None | None) – 浮點數或形狀為 a.shape[:-2] 的類陣列 (array_like)。指定小奇異值的截止值。形狀為 (...,)。小奇異值的截止值;小於 rtol * largest_singular_value 的奇異值會被視為零。預設值根據 dtype 的浮點精度決定。

  • hermitian (bool) – 若為 True,則輸入會被假定為 Hermitian 矩陣,並使用更有效率的演算法 (預設值:False)

  • rcond (ArrayLike | DeprecatedArg | None) – 已棄用的 rtol 參數別名。若使用,將會導致 DeprecationWarning

傳回:

形狀為 (..., N, M) 的陣列,包含 a 的偽反矩陣。

傳回型別:

Array

另請參閱

注意事項

jax.numpy.linalg.pinv()numpy.linalg.pinv()rcond` 的預設值上有所不同:在 NumPy 中,預設值為 1e-15。在 JAX 中,預設值為 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps

範例

>>> a = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> a_pinv = jnp.linalg.pinv(a)
>>> a_pinv  
Array([[-1.333332  , -0.33333257,  0.6666657 ],
       [ 1.0833322 ,  0.33333272, -0.41666582]], dtype=float32)

只要輸出不是秩虧 (rank-deficient),偽反矩陣的作用就像乘法反矩陣

>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
Array(True, dtype=bool)