jax.numpy.linalg.lstsq#

jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)[原始碼]#

傳回線性方程式的最小平方解。

JAX 實作的 numpy.linalg.lstsq()

參數:
  • a (ArrayLike) – 形狀為 (M, N) 的陣列,代表係數矩陣。

  • b (ArrayLike) – 形狀為 (M,)(M, K) 的陣列,代表右手邊。

  • rcond (float | None | None) – 小奇異值的截止比率。小於 rcond * largest_singular_value 的奇異值會被視為零。如果為 None (預設值),將使用最佳值來減少浮點錯誤。

  • numpy_resid (bool) – 如果為 True,則以與 NumPy 的 linalg.lstsq 相同的方式計算並傳回殘差。如果您想要精確地複製 NumPy 的行為,則這是必要的。如果為 False (預設值),則會使用更有效率的方法來計算殘差。

回傳:

陣列的元組 (x, resid, rank, s),其中

  • x 是形狀為 (N,)(N, K) 的陣列,包含最小平方解。

  • resid 是形狀為 ()(K,) 的平方殘差總和。

  • rank 是矩陣 a 的秩。

  • s 是矩陣 a 的奇異值。

回傳型別:

tuple[Array, Array, Array, Array]

範例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]