jax.numpy.tril#

jax.numpy.tril(m, k=0)[原始碼]#

傳回陣列的下三角矩陣。

JAX 實作的 numpy.tril()

參數:
  • m (ArrayLike) – 輸入陣列。必須具有 m.ndim >= 2

  • k (int) – k:選用,整數,預設值=0。指定次對角線,高於該次對角線的陣列元素會設為零。k=0 指的是主對角線,k<0 指的是主對角線下方的次對角線,而 k>0 指的是主對角線上方的次對角線。

傳回值:

一個與輸入形狀相同的陣列,包含給定陣列的下三角矩陣,其中高於 k 指定的次對角線之上的元素會設為零。

傳回類型:

Array

參見

範例

>>> x = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> jnp.tril(x)
Array([[ 1,  0,  0,  0],
       [ 5,  6,  0,  0],
       [ 9, 10, 11,  0]], dtype=int32)
>>> jnp.tril(x, k=1)
Array([[ 1,  2,  0,  0],
       [ 5,  6,  7,  0],
       [ 9, 10, 11, 12]], dtype=int32)
>>> jnp.tril(x, k=-1)
Array([[ 0,  0,  0,  0],
       [ 5,  0,  0,  0],
       [ 9, 10,  0,  0]], dtype=int32)

m.ndim > 2 時,jnp.tril 會在尾部軸上批次運算。

>>> x1 = jnp.array([[[1, 2],
...                  [3, 4]],
...                 [[5, 6],
...                  [7, 8]]])
>>> jnp.tril(x1)
Array([[[1, 0],
        [3, 4]],

       [[5, 0],
        [7, 8]]], dtype=int32)