jax.scipy.linalg.qr#
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: Literal[False] = False, check_finite: bool = True) tuple[Array, Array] [source]#
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: Literal[True] = True, check_finite: bool = True) tuple[Array, Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: Literal[False] = False, check_finite: bool = True) tuple[Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: Literal[True] = True, check_finite: bool = True) tuple[Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: bool = False, check_finite: bool = True) tuple[Array] | tuple[Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = 'full', pivoting: bool = False, check_finite: bool = True) tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]
計算陣列的 QR 分解
JAX 版本的
scipy.linalg.qr()
實作。矩陣 A 的 QR 分解由下式給出
\[A = QR\]其中 Q 是么正矩陣 (即 \(Q^HQ=I\)),而 R 是上三角矩陣。
- 參數:
a – 形狀為 (…, M, N) 的陣列
mode –
計算模式。支援的值為
"full"
(預設):傳回形狀為(M, M)
的 Q 和形狀為(M, N)
的 R。"r"
:僅傳回 R"economic"
:傳回形狀為(M, K)
的 Q 和形狀為(K, N)
的 R,其中 K = min(M, N)。
pivoting – 允許 QR 分解揭示秩。如果為
True
,則計算列樞軸分解A[:, P] = Q @ R
,其中選擇P
以使R
的對角線不遞增。overwrite_a – 在 JAX 中未使用
lwork – 在 JAX 中未使用
check_finite – 在 JAX 中未使用
- 傳回:
如果
mode
不是"r"
且pivoting
分別為False
或True
,則為元組(Q, R)
或(Q, R, P)
,否則如果 mode 為"r"
且pivoting
分別為False
或True
,則為陣列R
或元組(R, P)
,其中Q
是形狀為(..., M, M)
(如果mode
為"full"
) 或(..., M, K)
(如果mode
為"economic"
) 的正交矩陣,R
是一個形狀為(..., M, N)
的上三角矩陣(如果mode
為"r"
或"full"
)或(..., K, N)
(如果mode
為"economic"
),P
是一個形狀為(..., N)
的索引向量。
其中
K = min(M, N)
。
註記
目前,軸運算僅在 CPU 後端實作。
另請參閱
jax.numpy.linalg.qr()
:NumPy 風格的 QR 分解 APIjax.lax.linalg.qr()
:XLA 風格的 QR 分解 API
範例
計算矩陣的 QR 分解
>>> a = jnp.array([[1., 2., 3., 4.], ... [5., 4., 2., 1.], ... [6., 3., 1., 5.]]) >>> Q, R = jax.scipy.linalg.qr(a) >>> Q Array([[-0.12700021, -0.7581426 , -0.6396022 ], [-0.63500065, -0.43322435, 0.63960224], [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) >>> R Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], [ 0. , -1.7870499, -2.6534991, -1.028908 ], [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32)
檢查
Q
是正交的>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
重建輸入
>>> jnp.allclose(Q @ R, a) Array(True, dtype=bool)