jax.scipy.linalg.det#

jax.scipy.linalg.det(a, overwrite_a=False, check_finite=True)[原始碼]#

計算矩陣的行列式

JAX 實作的 scipy.linalg.det()

參數:
  • a (ArrayLike) – 輸入陣列,形狀為 (..., N, N)

  • overwrite_a (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

回傳型別:

Array

回傳值

行列式,形狀為 a.shape[:-2]

參見

jax.numpy.linalg.det():NumPy 風格的行列式 API

範例

小型 2D 陣列的行列式

>>> x = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jax.scipy.linalg.det(x)
Array(-2., dtype=float32)

多個 2D 陣列的批次式行列式

>>> x = jnp.array([[[1., 2.],
...                 [3., 4.]],
...                [[8., 5.],
...                 [7., 9.]]])
>>> jax.scipy.linalg.det(x)
Array([-2., 37.], dtype=float32)