jax.numpy.diagflat#

jax.numpy.diagflat(v, k=0)[原始碼]#

傳回一個 2 維陣列,其平面化的輸入陣列佈置在對角線上。

JAX 實作的 numpy.diagflat()

對於 v 的某些純量值,這與 np.diagflat 不同。JAX 總是傳回二維陣列,而 NumPy 可能會根據 v 的類型傳回純量。

參數:
  • v (ArrayLike) – 輸入陣列。可以是 N 維,但會被平面化為 1 維。

  • k (int) – 選用,預設值=0。對角線偏移量。正值將對角線放置在主對角線上方,負值將其放置在主對角線下方。

傳回:

一個 2 維陣列,其輸入元素沿著具有指定偏移量 (k) 的對角線放置。剩餘的條目將填充為零。

回傳類型:

Array

範例

>>> jnp.diagflat(jnp.array([1, 2, 3]))
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)
>>> jnp.diagflat(jnp.array([1, 2, 3]), k=1)
Array([[0, 1, 0, 0],
       [0, 0, 2, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 0]], dtype=int32)
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.diagflat(a)
Array([[1, 0, 0, 0],
       [0, 2, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)