jax.numpy.diag_indices#

jax.numpy.diag_indices(n, ndim=2)[原始碼]#

傳回用於存取多維陣列主對角線的索引。

JAX 版本的 numpy.diag_indices()

參數:
  • n (int) – int。方形陣列每個維度的大小。

  • ndim (int) – 選填,int,預設值=2。陣列的維度數量。

傳回值:

陣列的元組,每個長度為 n,包含存取主對角線的索引。

傳回類型:

tuple[Array, …]

範例

>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32))