jax.numpy.diagflat#
- jax.numpy.diagflat(v, k=0)[原始碼]#
傳回一個 2 維陣列,其平面化的輸入陣列佈置在對角線上。
JAX 實作的
numpy.diagflat()
。對於 v 的某些純量值,這與 np.diagflat 不同。JAX 總是傳回二維陣列,而 NumPy 可能會根據 v 的類型傳回純量。
- 參數:
v (ArrayLike) – 輸入陣列。可以是 N 維,但會被平面化為 1 維。
k (int) – 選用,預設值=0。對角線偏移量。正值將對角線放置在主對角線上方,負值將其放置在主對角線下方。
- 傳回:
一個 2 維陣列,其輸入元素沿著具有指定偏移量 (k) 的對角線放置。剩餘的條目將填充為零。
- 回傳類型:
範例
>>> jnp.diagflat(jnp.array([1, 2, 3])) Array([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=int32) >>> jnp.diagflat(jnp.array([1, 2, 3]), k=1) Array([[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3], [0, 0, 0, 0]], dtype=int32) >>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.diagflat(a) Array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], dtype=int32)