jax.numpy.linalg.det#

jax.numpy.linalg.det(a)[原始碼]#

計算陣列的行列式。

JAX 版本的 numpy.linalg.det()

參數:

a (類陣列) – 形狀為 (..., M, M) 的陣列,用於計算行列式。

回傳:

形狀為 a.shape[:-2] 的行列式陣列。

回傳型別:

Array

另請參閱

jax.scipy.linalg.det():用於行列式的 Scipy 風格 API。

範例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.det(a)
Array(-2., dtype=float32)