jax.numpy.linalg.diagonal#

jax.numpy.linalg.diagonal(x, /, *, offset=0)[原始碼]#

提取矩陣或矩陣堆疊的對角線。

JAX 實作的 numpy.linalg.diagonal()

參數:
  • x (ArrayLike) – 形狀為 (..., M, N) 的陣列,將從中提取對角線。

  • offset (int) – 相對於主對角線的正或負偏移量。

傳回值:

形狀為 (..., K) 的陣列,其中 K 是指定對角線的長度。

傳回型別:

Array

參見

範例

單一矩陣的對角線

>>> x = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.diagonal(x)
Array([ 1,  6, 11], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=1)
Array([ 2,  7, 12], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=-1)
Array([ 5, 10], dtype=int32)

批次對角線

>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.diagonal(x)
Array([[ 0,  5, 10],
       [12, 17, 22]], dtype=int32)