jax.numpy.diag#
- jax.numpy.diag(v, k=0)[原始碼]#
傳回指定的對角線或建構對角陣列。
JAX 版本的
numpy.diag()
實作。JAX 版本總是傳回輸入的副本,儘管如果在 JIT 編譯中使用,編譯器可能會避免複製。
- 參數:
v (ArrayLike) – 輸入陣列。可以是 1 維陣列以建立對角矩陣,或是 2 維陣列以提取對角線。
k (int) – 選擇性,預設值=0。對角線偏移量。正值將對角線放置在主對角線上方,負值將其放置在主對角線下方。
- 傳回值:
如果 v 是 2 維陣列,則為包含對角線元素的 1 維陣列。如果 v 是 1 維陣列,則為沿指定對角線放置輸入元素的 2 維陣列。
- 傳回型別:
範例
從 1 維陣列建立對角矩陣
>>> jnp.diag(jnp.array([1, 2, 3])) Array([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=int32)
指定對角線偏移量
>>> jnp.diag(jnp.array([1, 2, 3]), k=1) Array([[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3], [0, 0, 0, 0]], dtype=int32)
從 2 維陣列提取對角線
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.diag(x) Array([1, 5, 9], dtype=int32)