jax.numpy.diag_indices#
- jax.numpy.diag_indices(n, ndim=2)[原始碼]#
傳回用於存取多維陣列主對角線的索引。
JAX 版本的
numpy.diag_indices()
。- 參數:
- 傳回值:
陣列的元組,每個長度為 n,包含存取主對角線的索引。
- 傳回類型:
範例
>>> jnp.diag_indices(3) (Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32)) >>> jnp.diag_indices(4, ndim=3) (Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32))