jax.numpy.tril_indices#

jax.numpy.tril_indices(n, k=0, m=None)[source]#

傳回大小為 (n, m) 陣列之下三角形的索引。

JAX 版本的 numpy.tril_indices() 實作。

參數:
  • n (int) – int。傳回索引的陣列列數。

  • k (int) – 選項,int,預設值=0。指定要傳回下三角形索引的次對角線及其下方。k=0 指的是主對角線,k<0 指的是主對角線下方的次對角線,而 k>0 指的是主對角線上方的次對角線。

  • m (int | None | None) – 選項,int。傳回索引的陣列行數。如果未指定,則 m = n

傳回值:

包含下三角形索引的兩個陣列的元組,每個軸各一個。

傳回類型:

tuple[Array, Array]

另請參閱

範例

如果輸入中僅提供 n,則會傳回大小為 (n, n) 陣列之下三角形的索引。

>>> jnp.tril_indices(3)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))

如果輸入中同時提供 nm,則會傳回 (n, m) 陣列之下三角形的索引。

>>> jnp.tril_indices(3, m=2)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1], dtype=int32))

如果 k = 1,則會傳回主對角線上方的第一個次對角線及其下方的索引。

>>> jnp.tril_indices(3, k=1)
(Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))

如果 k = -1,則會傳回主對角線下方的第一個次對角線及其下方的索引。

>>> jnp.tril_indices(3, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))