jax.numpy.linalg.trace#

jax.numpy.linalg.trace(x, /, *, offset=0, dtype=None)[原始碼]#

計算矩陣的跡。

JAX 實作的 numpy.linalg.trace()

參數:
  • x (ArrayLike) – 形狀為 (..., M, N) 的陣列,其最內層的兩個維度形成 MxN 矩陣,以計算跡。

  • offset (int) – 從主對角線的正或負偏移量(預設值:0)。

  • dtype (DTypeLike | None | None) – 傳回陣列的資料類型(預設值:None)。如果為 None,則輸出 dtype 將與 x 的 dtype 相符,並在整數類型的情況下提升為預設精度。

傳回值:

具有形狀 x.shape[:-2] 的批次跡陣列

傳回類型:

Array

參見

範例

單一矩陣的跡

>>> x = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.trace(x)
Array(18, dtype=int32)
>>> jnp.linalg.trace(x, offset=1)
Array(21, dtype=int32)
>>> jnp.linalg.trace(x, offset=-1, dtype="float32")
Array(15., dtype=float32)

批次跡

>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.trace(x)
Array([15, 51], dtype=int32)