jax.numpy.tri#
- jax.numpy.tri(N, M=None, k=0, dtype=None)[原始碼]#
傳回在對角線及下方為 1,其他地方為 0 的陣列。
JAX 版本的
numpy.tri()
- 參數:
N (int) – int。傳回陣列的列維度。
M (int | None | None) – 選填,int。傳回陣列的欄維度。如果未指定,則
M = N
。k (int) – 選填,int,預設值=0。指定陣列中對角線及下方填充 1 的子對角線。
k=0
指的是主對角線,k<0
指的是主對角線下方的子對角線,而k>0
指的是主對角線上方的子對角線。dtype (DTypeLike | None | None) – 選填,傳回陣列的資料型別。預設型別為 float。
- 傳回值:
形狀為
(N, M)
的陣列,其中包含下三角,子對角線下方的元素由k
指定為 1,其他地方為零。- 傳回型別:
另請參閱
jax.numpy.tril()
:傳回陣列的下三角。jax.numpy.triu()
:傳回陣列的上三角。
範例
>>> jnp.tri(3) Array([[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]], dtype=float32)
當
M
不等於N
時>>> jnp.tri(3, 4) Array([[1., 0., 0., 0.], [1., 1., 0., 0.], [1., 1., 1., 0.]], dtype=float32)
當
k>0
時>>> jnp.tri(3, k=1) Array([[1., 1., 0.], [1., 1., 1.], [1., 1., 1.]], dtype=float32)
當
k<0
時>>> jnp.tri(3, 4, k=-1) Array([[0., 0., 0., 0.], [1., 0., 0., 0.], [1., 1., 0., 0.]], dtype=float32)