jax.numpy.tril_indices#
- jax.numpy.tril_indices(n, k=0, m=None)[source]#
傳回大小為
(n, m)
陣列之下三角形的索引。JAX 版本的
numpy.tril_indices()
實作。- 參數:
- 傳回值:
包含下三角形索引的兩個陣列的元組,每個軸各一個。
- 傳回類型:
另請參閱
jax.numpy.triu_indices()
:傳回大小為(n, m)
陣列之上三角形的索引。jax.numpy.triu_indices_from()
:傳回給定陣列之上三角形的索引。jax.numpy.tril_indices_from()
:傳回給定陣列之下三角形的索引。
範例
如果輸入中僅提供
n
,則會傳回大小為(n, n)
陣列之下三角形的索引。>>> jnp.tril_indices(3) (Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
如果輸入中同時提供
n
和m
,則會傳回(n, m)
陣列之下三角形的索引。>>> jnp.tril_indices(3, m=2) (Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1], dtype=int32))
如果
k = 1
,則會傳回主對角線上方的第一個次對角線及其下方的索引。>>> jnp.tril_indices(3, k=1) (Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))
如果
k = -1
,則會傳回主對角線下方的第一個次對角線及其下方的索引。>>> jnp.tril_indices(3, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))