jax.numpy.linalg.lstsq#
- jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)[原始碼]#
傳回線性方程式的最小平方解。
JAX 實作的
numpy.linalg.lstsq()
。- 參數:
a (ArrayLike) – 形狀為
(M, N)
的陣列,代表係數矩陣。b (ArrayLike) – 形狀為
(M,)
或(M, K)
的陣列,代表右手邊。rcond (float | None | None) – 小奇異值的截止比率。小於
rcond * largest_singular_value
的奇異值會被視為零。如果為 None (預設值),將使用最佳值來減少浮點錯誤。numpy_resid (bool) – 如果為 True,則以與 NumPy 的 linalg.lstsq 相同的方式計算並傳回殘差。如果您想要精確地複製 NumPy 的行為,則這是必要的。如果為 False (預設值),則會使用更有效率的方法來計算殘差。
- 回傳:
陣列的元組
(x, resid, rank, s)
,其中x
是形狀為(N,)
或(N, K)
的陣列,包含最小平方解。resid
是形狀為()
或(K,)
的平方殘差總和。rank
是矩陣a
的秩。s
是矩陣a
的奇異值。
- 回傳型別:
範例
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> b = jnp.array([5, 6]) >>> x, _, _, _ = jnp.linalg.lstsq(a, b) >>> with jnp.printoptions(precision=3): ... print(x) [-4. 4.5]