jax.numpy.identity#
- jax.numpy.identity(n, dtype=None)[原始碼]#
建立方形單位矩陣
JAX 版本的
numpy.identity()
實作。- 參數:
n (DimSize) – 整數,指定每個陣列維度的大小。
dtype (DTypeLike | None | None) – 選擇性 dtype;預設為浮點數。
- 傳回值:
形狀為
(n, n)
的單位陣列。- 傳回型別:
另請參閱
jax.numpy.eye()
:非方形和/或偏移單位矩陣。範例
一個簡單的 3x3 單位矩陣
>>> jnp.identity(3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
一個 2x2 整數單位矩陣
>>> jnp.identity(2, dtype=int) Array([[1, 0], [0, 1]], dtype=int32)