jax.numpy.triu_indices#
- jax.numpy.triu_indices(n, k=0, m=None)[原始碼]#
傳回大小為
(n, m)
的陣列的上三角形索引。numpy.triu_indices()
的 JAX 實作。- 參數:
- 傳回值:
包含上三角形索引的兩個陣列的元組,每個軸一個。
- 傳回類型:
參見
jax.numpy.tril_indices()
:傳回大小為(n, m)
的陣列的下三角形索引。jax.numpy.triu_indices_from()
:傳回給定陣列的上三角形索引。jax.numpy.tril_indices_from()
:傳回給定陣列的下三角形索引。
範例
如果輸入中僅提供
n
,則傳回大小為(n, n)
陣列的上三角形索引。>>> jnp.triu_indices(3) (Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
如果輸入中同時提供
n
和m
,則傳回(n, m)
陣列的上三角形索引。>>> jnp.triu_indices(3, m=2) (Array([0, 0, 1], dtype=int32), Array([0, 1, 1], dtype=int32))
如果
k = 1
,則傳回主對角線上方的第一個子對角線及其上方的索引。>>> jnp.triu_indices(3, k=1) (Array([0, 0, 1], dtype=int32), Array([1, 2, 2], dtype=int32))
如果
k = -1
,則傳回主對角線下方的第一個子對角線及其上方的索引。>>> jnp.triu_indices(3, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))