jax.numpy.ones#

jax.numpy.ones(shape, dtype=None, *, device=None)[原始碼]#

建立一個充滿 1 的陣列。

JAX 實作的 numpy.ones()

參數:
  • shape (Any) – 整數或整數序列,指定建立的陣列的形狀。

  • dtype (DTypeLike | None | None) – 建立的陣列的可選 dtype;預設為浮點數。

  • device (xc.Device | Sharding | None | None) – (可選) DeviceSharding,建立的陣列將提交到此裝置或分片策略。

返回:

指定形狀和 dtype 的陣列,如果指定裝置,則位於指定裝置上。

返回類型:

Array

範例

>>> jnp.ones(4)
Array([1., 1., 1., 1.], dtype=float32)
>>> jnp.ones((2, 3), dtype=bool)
Array([[ True,  True,  True],
       [ True,  True,  True]], dtype=bool)