jax.numpy.diagonal#

jax.numpy.diagonal(a, offset=0, axis1=0, axis2=1)[原始碼]#

傳回陣列的指定對角線。

JAX 版本的 numpy.diagonal() 實作。

JAX 版本總是傳回輸入的副本,儘管如果在 JIT 編譯中使用,編譯器可能會避免複製。

參數:
  • a (ArrayLike) – 輸入陣列。必須至少是 2 維。

  • offset (int) – 選用,預設值=0。從主對角線的對角線偏移量。必須是靜態整數值。可以是正數或負數。

  • axis1 (int) – 選用,預設值=0。要沿其取得對角線的第一個軸。

  • axis2 (int) –

    選用,預設值=1。要沿其取得對角線的第二個軸。

    傳回

    對於 2D 輸入,傳回 1D 陣列;一般而言,對於 N 維輸入,傳回 N-1 維陣列。

傳回類型:

Array

範例

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diagonal(x)
Array([1, 5, 9], dtype=int32)
>>> jnp.diagonal(x, offset=1)
Array([2, 6], dtype=int32)
>>> jnp.diagonal(x, offset=-1)
Array([4, 8], dtype=int32)