jax.numpy.linalg.diagonal#
- jax.numpy.linalg.diagonal(x, /, *, offset=0)[原始碼]#
提取矩陣或矩陣堆疊的對角線。
JAX 實作的
numpy.linalg.diagonal()
。- 參數:
x (ArrayLike) – 形狀為
(..., M, N)
的陣列,將從中提取對角線。offset (int) – 相對於主對角線的正或負偏移量。
- 傳回值:
形狀為
(..., K)
的陣列,其中K
是指定對角線的長度。- 傳回型別:
參見
jax.numpy.diagonal()
:用於提取對角線的更通用功能。jax.numpy.diag()
:從值建立對角矩陣。
範例
單一矩陣的對角線
>>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> jnp.linalg.diagonal(x) Array([ 1, 6, 11], dtype=int32) >>> jnp.linalg.diagonal(x, offset=1) Array([ 2, 7, 12], dtype=int32) >>> jnp.linalg.diagonal(x, offset=-1) Array([ 5, 10], dtype=int32)
批次對角線
>>> x = jnp.arange(24).reshape(2, 3, 4) >>> jnp.linalg.diagonal(x) Array([[ 0, 5, 10], [12, 17, 22]], dtype=int32)