jax.numpy.triu_indices_from#
- jax.numpy.triu_indices_from(arr, k=0)[原始碼]#
傳回給定陣列的上三角形索引。
JAX 實作的
numpy.triu_indices_from()
。- 參數:
arr (ArrayLike) – 輸入陣列。必須具有
arr.ndim == 2
。k (int) – 選用,int,預設值=0。指定次對角線及其上方要傳回上三角形索引的位置。
k=0
指的是主對角線,k<0
指的是主對角線下方的次對角線,而k>0
指的是主對角線上方的次對角線。
- 傳回:
包含上三角形索引的兩個陣列的元組,每個軸各一個。
- 傳回類型:
另請參閱
jax.numpy.tril_indices_from()
:傳回給定陣列的下三角形索引。jax.numpy.triu_indices()
:傳回大小為(n, m)
的陣列的上三角形索引。jax.numpy.triu()
:傳回陣列的上三角形。
範例
>>> arr = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.triu_indices_from(arr) (Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
由
jnp.triu_indices_from
索引的元素對應於jnp.triu
輸出的元素。>>> ind = jnp.triu_indices_from(arr) >>> arr[ind] Array([1, 2, 3, 5, 6, 9], dtype=int32) >>> jnp.triu(arr) Array([[1, 2, 3], [0, 5, 6], [0, 0, 9]], dtype=int32)
當
k > 0
時>>> jnp.triu_indices_from(arr, k=1) (Array([0, 0, 1], dtype=int32), Array([1, 2, 2], dtype=int32))
當
k < 0
時>>> jnp.triu_indices_from(arr, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))