jax.numpy.identity#

jax.numpy.identity(n, dtype=None)[原始碼]#

建立方形單位矩陣

JAX 版本的 numpy.identity() 實作。

參數:
  • n (DimSize) – 整數,指定每個陣列維度的大小。

  • dtype (DTypeLike | None | None) – 選擇性 dtype;預設為浮點數。

傳回值:

形狀為 (n, n) 的單位陣列。

傳回型別:

Array

另請參閱

jax.numpy.eye():非方形和/或偏移單位矩陣。

範例

一個簡單的 3x3 單位矩陣

>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

一個 2x2 整數單位矩陣

>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)