jax.scipy.linalg.solve_triangular#

jax.scipy.linalg.solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)[原始碼]#

解三角線性方程式系統

scipy.linalg.solve_triangular() 的 JAX 實作。

這為給定的三角矩陣 a 和向量或矩陣 b 解 (批次) 線性方程式系統 a @ x = b 中的 x

參數:
  • a (ArrayLike) – 形狀為 (..., N, N) 的陣列。 只有陣列的一部分會被存取,取決於 lowerunit_diagonal 參數。

  • b (ArrayLike) – 形狀為 (..., N)(..., N, M) 的陣列

  • lower (bool) – 如果為 True,則僅使用輸入的下三角部分,如果為 False (預設),則僅使用上三角部分。

  • unit_diagonal (bool) – 如果為 True,則忽略 a 的對角線元素並假設它們為 1 (預設:False)。

  • trans (int | str) –

    指定可以假設 a 的哪些屬性。選項為

    • 0'N':解 \(Ax=b\)

    • 1'T':解 \(A^Tx=b\)

    • 2'C':解 \(A^Hx=b\)

  • overwrite_b (bool) – JAX 未使用

  • debug (Any | None) – JAX 未使用

  • check_finite (bool) – JAX 未使用

傳回:

b 形狀相同的陣列,其中包含線性系統的解。

傳回類型:

Array

參見

jax.scipy.linalg.solve():解一般線性系統。

範例

一個簡單的 3x3 三角線性系統

>>> A = jnp.array([[1., 2., 3.],
...                [0., 3., 2.],
...                [0., 0., 5.]])
>>> b = jnp.array([10., 8., 5.])
>>> x = jax.scipy.linalg.solve_triangular(A, b)
>>> x
Array([3., 2., 1.], dtype=float32)

確認結果能解系統

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)

計算轉置問題

>>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T')
>>> x
Array([10. , -4. , -3.4], dtype=float32)

確認結果能解系統

>>> jnp.allclose(A.T @ x, b)
Array(True, dtype=bool)