jax.numpy.triu_indices#

jax.numpy.triu_indices(n, k=0, m=None)[原始碼]#

傳回大小為 (n, m) 的陣列的上三角形索引。

numpy.triu_indices() 的 JAX 實作。

參數:
  • n (int) – int。傳回索引的陣列列數。

  • k (int) – 選用,int,預設值=0。指定要傳回上三角形索引的子對角線和其上方。k=0 指的是主對角線,k<0 指的是主對角線下方的子對角線,而 k>0 指的是主對角線上方的子對角線。

  • m (int | None | None) – 選用,int。傳回索引的陣列行數。如果未指定,則 m = n

傳回值:

包含上三角形索引的兩個陣列的元組,每個軸一個。

傳回類型:

tuple[Array, Array]

參見

範例

如果輸入中僅提供 n,則傳回大小為 (n, n) 陣列的上三角形索引。

>>> jnp.triu_indices(3)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))

如果輸入中同時提供 nm,則傳回 (n, m) 陣列的上三角形索引。

>>> jnp.triu_indices(3, m=2)
(Array([0, 0, 1], dtype=int32), Array([0, 1, 1], dtype=int32))

如果 k = 1,則傳回主對角線上方的第一個子對角線及其上方的索引。

>>> jnp.triu_indices(3, k=1)
(Array([0, 0, 1], dtype=int32), Array([1, 2, 2], dtype=int32))

如果 k = -1,則傳回主對角線下方的第一個子對角線及其上方的索引。

>>> jnp.triu_indices(3, k=-1)
(Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))