jax.numpy.eye#

jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)[原始碼]#

建立方形或矩形單位矩陣

JAX 版本的 numpy.eye()

參數:
  • N (DimSize) – 整數,指定陣列的第一個維度。

  • M (DimSize | None | None) – 可選的整數,指定陣列的第二個維度;預設值與 N 相同。

  • k (int | ArrayLike) – 可選的整數,指定對角線的偏移量。正值用於上對角線,負值用於下對角線。預設值為零。

  • dtype (DTypeLike | None | None) – 可選的 dtype;預設為浮點數。

  • device (xc.Device | Sharding | None | None) – 可選的 DeviceSharding,將建立的陣列提交到該裝置。

返回:

形狀為 (N, M)(N, N) 的單位陣列(如果未指定 M)。

返回類型:

Array

參見

jax.numpy.identity():用於生成方形單位矩陣的更簡單 API。

範例

一個簡單的 3x3 單位矩陣

>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

具有偏移對角線的整數單位矩陣

>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)

非方形單位矩陣

>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)