jax.numpy.trace#

jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)[原始碼]#

計算沿給定軸的輸入對角線總和。

JAX 版本的 numpy.trace()

參數:
  • a (ArrayLike) – 輸入陣列。必須具有 a.ndim >= 2

  • offset (int | ArrayLike) – 選用,int,預設值=0。從主對角線的對角線偏移量。可以是正數或負數。

  • axis1 (int) – 選用,預設值=0。要沿其取對角線總和的第一個軸。必須是靜態整數值。

  • axis2 (int) – 選用,預設值=1。要沿其取對角線總和的第二個軸。必須是靜態整數值。

  • dtype (DTypeLike | None) – 選用。輸出陣列的 dtype。應在 JIT 編譯中作為靜態引數提供。

  • out (None) – JAX 未使用。

傳回:

維度為 x.ndim-2 的陣列,包含沿軸 (axis1, axis2) 的對角線元素總和

傳回類型:

Array

另請參閱

範例

>>> x = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.trace(x)
Array([ 8, 10], dtype=int32)
>>> jnp.trace(x, offset=1)
Array([3, 4], dtype=int32)
>>> jnp.trace(x, axis1=1, axis2=2)
Array([ 5, 13], dtype=int32)
>>> jnp.trace(x, offset=1, axis1=1, axis2=2)
Array([2, 6], dtype=int32)