jax.numpy.linalg.pinv#
- jax.numpy.linalg.pinv(a, rtol=None, hermitian=False, *, rcond=Deprecated)[原始碼]#
計算矩陣的 (Moore-Penrose) 偽反矩陣。
JAX 實作的
numpy.linalg.pinv()
。- 參數:
a (ArrayLike) – 形狀為
(..., M, N)
的陣列,包含要計算偽反矩陣的矩陣。rtol (ArrayLike | None | None) – 浮點數或形狀為
a.shape[:-2]
的類陣列 (array_like)。指定小奇異值的截止值。形狀為(...,)
。小奇異值的截止值;小於rtol * largest_singular_value
的奇異值會被視為零。預設值根據 dtype 的浮點精度決定。hermitian (bool) – 若為 True,則輸入會被假定為 Hermitian 矩陣,並使用更有效率的演算法 (預設值:False)
rcond (ArrayLike | DeprecatedArg | None) – 已棄用的
rtol
參數別名。若使用,將會導致DeprecationWarning
。
- 傳回:
形狀為
(..., N, M)
的陣列,包含a
的偽反矩陣。- 傳回型別:
另請參閱
jax.numpy.linalg.inv()
:方陣的乘法反矩陣。
注意事項
jax.numpy.linalg.pinv()
與numpy.linalg.pinv()
在 rcond` 的預設值上有所不同:在 NumPy 中,預設值為 1e-15。在 JAX 中,預設值為10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps
。範例
>>> a = jnp.array([[1, 2], ... [3, 4], ... [5, 6]]) >>> a_pinv = jnp.linalg.pinv(a) >>> a_pinv Array([[-1.333332 , -0.33333257, 0.6666657 ], [ 1.0833322 , 0.33333272, -0.41666582]], dtype=float32)
只要輸出不是秩虧 (rank-deficient),偽反矩陣的作用就像乘法反矩陣
>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4) Array(True, dtype=bool)