jax.numpy.tril#
- jax.numpy.tril(m, k=0)[原始碼]#
傳回陣列的下三角矩陣。
JAX 實作的
numpy.tril()
- 參數:
m (ArrayLike) – 輸入陣列。必須具有
m.ndim >= 2
。k (int) – k:選用,整數,預設值=0。指定次對角線,高於該次對角線的陣列元素會設為零。
k=0
指的是主對角線,k<0
指的是主對角線下方的次對角線,而k>0
指的是主對角線上方的次對角線。
- 傳回值:
一個與輸入形狀相同的陣列,包含給定陣列的下三角矩陣,其中高於
k
指定的次對角線之上的元素會設為零。- 傳回類型:
參見
jax.numpy.triu()
:傳回陣列的上三角矩陣。jax.numpy.tri()
:傳回一個在對角線及其下方為 1,其他地方為 0 的陣列。
範例
>>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> jnp.tril(x) Array([[ 1, 0, 0, 0], [ 5, 6, 0, 0], [ 9, 10, 11, 0]], dtype=int32) >>> jnp.tril(x, k=1) Array([[ 1, 2, 0, 0], [ 5, 6, 7, 0], [ 9, 10, 11, 12]], dtype=int32) >>> jnp.tril(x, k=-1) Array([[ 0, 0, 0, 0], [ 5, 0, 0, 0], [ 9, 10, 0, 0]], dtype=int32)
當
m.ndim > 2
時,jnp.tril
會在尾部軸上批次運算。>>> x1 = jnp.array([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> jnp.tril(x1) Array([[[1, 0], [3, 4]], [[5, 0], [7, 8]]], dtype=int32)