jax.numpy.eye#
- jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)[原始碼]#
建立方形或矩形單位矩陣
JAX 版本的
numpy.eye()
。- 參數:
- 返回:
形狀為
(N, M)
或(N, N)
的單位陣列(如果未指定M
)。- 返回類型:
參見
jax.numpy.identity()
:用於生成方形單位矩陣的更簡單 API。範例
一個簡單的 3x3 單位矩陣
>>> jnp.eye(3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
具有偏移對角線的整數單位矩陣
>>> jnp.eye(3, k=1, dtype=int) Array([[0, 1, 0], [0, 0, 1], [0, 0, 0]], dtype=int32) >>> jnp.eye(3, k=-1, dtype=int) Array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=int32)
非方形單位矩陣
>>> jnp.eye(3, 5, k=1) Array([[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.], [0., 0., 0., 1., 0.]], dtype=float32)