jax.numpy.full#

jax.numpy.full(shape, fill_value, dtype=None, *, device=None)[原始碼]#

建立一個充滿指定值的陣列。

JAX 實作的 numpy.full()

參數:
  • shape (Any) – 指定建立陣列形狀的整數或整數序列。

  • fill_value (ArrayLike) – 用於填充建立陣列的純量或陣列。

  • dtype (DTypeLike | None | None) – 建立陣列的可選 dtype;預設為填充值的 dtype。

  • device (xc.Device | Sharding | None | None) – (可選) 將建立的陣列提交到的 DeviceSharding

返回:

具有指定形狀和 dtype 的陣列,如果指定,則在指定的裝置上。

返回類型:

Array

範例

>>> jnp.full(4, 2, dtype=float)
Array([2., 2., 2., 2.], dtype=float32)
>>> jnp.full((2, 3), 0, dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)

fill_value 也可以是廣播到指定形狀的陣列

>>> jnp.full((2, 3), fill_value=jnp.arange(3))
Array([[0, 1, 2],
       [0, 1, 2]], dtype=int32)