jax.numpy.tri#

jax.numpy.tri(N, M=None, k=0, dtype=None)[原始碼]#

傳回在對角線及下方為 1,其他地方為 0 的陣列。

JAX 版本的 numpy.tri()

參數:
  • N (int) – int。傳回陣列的列維度。

  • M (int | None | None) – 選填,int。傳回陣列的欄維度。如果未指定,則 M = N

  • k (int) – 選填,int,預設值=0。指定陣列中對角線及下方填充 1 的子對角線。k=0 指的是主對角線,k<0 指的是主對角線下方的子對角線,而 k>0 指的是主對角線上方的子對角線。

  • dtype (DTypeLike | None | None) – 選填,傳回陣列的資料型別。預設型別為 float。

傳回值:

形狀為 (N, M) 的陣列,其中包含下三角,子對角線下方的元素由 k 指定為 1,其他地方為零。

傳回型別:

Array

另請參閱

範例

>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)

M 不等於 N

>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)

k>0

>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

k<0

>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)