jax.scipy.linalg.qr#

jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: Literal[False] = False, check_finite: bool = True) tuple[Array, Array][source]#
jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: Literal[True] = True, check_finite: bool = True) tuple[Array, Array, Array]
jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]
jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: Literal[False] = False, check_finite: bool = True) tuple[Array]
jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: Literal[True] = True, check_finite: bool = True) tuple[Array, Array]
jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: bool = False, check_finite: bool = True) tuple[Array] | tuple[Array, Array]
jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = 'full', pivoting: bool = False, check_finite: bool = True) tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]

計算陣列的 QR 分解

JAX 版本的 scipy.linalg.qr() 實作。

矩陣 A 的 QR 分解由下式給出

\[A = QR\]

其中 Q 是么正矩陣 (即 \(Q^HQ=I\)),而 R 是上三角矩陣。

參數:
  • a – 形狀為 (…, M, N) 的陣列

  • mode

    計算模式。支援的值為

    • "full" (預設):傳回形狀為 (M, M)Q 和形狀為 (M, N)R

    • "r":僅傳回 R

    • "economic":傳回形狀為 (M, K)Q 和形狀為 (K, N)R,其中 K = min(M, N)。

  • pivoting – 允許 QR 分解揭示秩。如果為 True,則計算列樞軸分解 A[:, P] = Q @ R,其中選擇 P 以使 R 的對角線不遞增。

  • overwrite_a – 在 JAX 中未使用

  • lwork – 在 JAX 中未使用

  • check_finite – 在 JAX 中未使用

傳回:

如果 mode 不是 "r"pivoting 分別為 FalseTrue,則為元組 (Q, R)(Q, R, P),否則如果 mode 為 "r"pivoting 分別為 FalseTrue,則為陣列 R 或元組 (R, P),其中

  • Q 是形狀為 (..., M, M) (如果 mode"full") 或 (..., M, K) (如果 mode"economic") 的正交矩陣,

  • R 是一個形狀為 (..., M, N) 的上三角矩陣(如果 mode"r""full")或 (..., K, N)(如果 mode"economic"),

  • P 是一個形狀為 (..., N) 的索引向量。

其中 K = min(M, N)

註記

  • 目前,軸運算僅在 CPU 後端實作。

另請參閱

範例

計算矩陣的 QR 分解

>>> a = jnp.array([[1., 2., 3., 4.],
...                [5., 4., 2., 1.],
...                [6., 3., 1., 5.]])
>>> Q, R = jax.scipy.linalg.qr(a)
>>> Q  
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
       [-0.63500065, -0.43322435,  0.63960224],
       [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
>>> R  
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
       [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
       [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)

檢查 Q 是正交的

>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)

重建輸入

>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)