jax.numpy.trace#
- jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)[原始碼]#
計算沿給定軸的輸入對角線總和。
JAX 版本的
numpy.trace()
。- 參數:
- 傳回:
維度為 x.ndim-2 的陣列,包含沿軸 (axis1, axis2) 的對角線元素總和
- 傳回類型:
另請參閱
jax.numpy.diag()
:傳回指定的對角線或建構對角線陣列jax.numpy.diagonal()
:傳回陣列的指定對角線。jax.numpy.diagflat()
:傳回一個 2 維陣列,其中扁平化的輸入陣列佈置在對角線上。
範例
>>> x = jnp.arange(1, 9).reshape(2, 2, 2) >>> x Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=int32) >>> jnp.trace(x) Array([ 8, 10], dtype=int32) >>> jnp.trace(x, offset=1) Array([3, 4], dtype=int32) >>> jnp.trace(x, axis1=1, axis2=2) Array([ 5, 13], dtype=int32) >>> jnp.trace(x, offset=1, axis1=1, axis2=2) Array([2, 6], dtype=int32)