jax.numpy.linalg.trace#
- jax.numpy.linalg.trace(x, /, *, offset=0, dtype=None)[原始碼]#
計算矩陣的跡。
JAX 實作的
numpy.linalg.trace()
。- 參數:
x (ArrayLike) – 形狀為
(..., M, N)
的陣列,其最內層的兩個維度形成 MxN 矩陣,以計算跡。offset (int) – 從主對角線的正或負偏移量(預設值:0)。
dtype (DTypeLike | None | None) – 傳回陣列的資料類型(預設值:
None
)。如果為None
,則輸出 dtype 將與x
的 dtype 相符,並在整數類型的情況下提升為預設精度。
- 傳回值:
具有形狀
x.shape[:-2]
的批次跡陣列- 傳回類型:
參見
jax.numpy.trace()
:jax.numpy
命名空間中類似的 API。
範例
單一矩陣的跡
>>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> jnp.linalg.trace(x) Array(18, dtype=int32) >>> jnp.linalg.trace(x, offset=1) Array(21, dtype=int32) >>> jnp.linalg.trace(x, offset=-1, dtype="float32") Array(15., dtype=float32)
批次跡
>>> x = jnp.arange(24).reshape(2, 3, 4) >>> jnp.linalg.trace(x) Array([15, 51], dtype=int32)