jax.numpy.tril_indices_from#
- jax.numpy.tril_indices_from(arr, k=0)[原始碼]#
傳回給定陣列下三角的索引。
JAX 實作的
numpy.tril_indices_from()
。- 參數:
arr (ArrayLike) – 輸入陣列。必須具有
arr.ndim == 2
。k (int) – 選用,整數,預設值=0。指定次對角線以及下方的次對角線,傳回上三角形的索引。
k=0
代表主對角線,k<0
代表主對角線下方的次對角線,而k>0
代表主對角線上方的次對角線。
- 傳回值:
包含下三角形索引的兩個陣列的元組,每個軸各一個。
- 傳回類型:
另請參閱
jax.numpy.triu_indices_from()
:傳回給定陣列上三角形的索引。jax.numpy.tril_indices()
:傳回大小為(n, m)
的陣列下三角形的索引。jax.numpy.tril()
:傳回陣列的下三角形
範例
>>> arr = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.tril_indices_from(arr) (Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
由
jnp.tril_indices_from
索引的元素對應於jnp.tril
輸出的元素。>>> ind = jnp.tril_indices_from(arr) >>> arr[ind] Array([1, 4, 5, 7, 8, 9], dtype=int32) >>> jnp.tril(arr) Array([[1, 0, 0], [4, 5, 0], [7, 8, 9]], dtype=int32)
當
k > 0
時>>> jnp.tril_indices_from(arr, k=1) (Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))
當
k < 0
時>>> jnp.tril_indices_from(arr, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))