jax.numpy.diag#

jax.numpy.diag(v, k=0)[原始碼]#

傳回指定的對角線或建構對角陣列。

JAX 版本的 numpy.diag() 實作。

JAX 版本總是傳回輸入的副本,儘管如果在 JIT 編譯中使用,編譯器可能會避免複製。

參數:
  • v (ArrayLike) – 輸入陣列。可以是 1 維陣列以建立對角矩陣,或是 2 維陣列以提取對角線。

  • k (int) – 選擇性,預設值=0。對角線偏移量。正值將對角線放置在主對角線上方,負值將其放置在主對角線下方。

傳回值:

如果 v 是 2 維陣列,則為包含對角線元素的 1 維陣列。如果 v 是 1 維陣列,則為沿指定對角線放置輸入元素的 2 維陣列。

傳回型別:

Array

範例

從 1 維陣列建立對角矩陣

>>> jnp.diag(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)

指定對角線偏移量

>>> jnp.diag(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)

從 2 維陣列提取對角線

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diag(x)
Array([1, 5, 9], dtype=int32)