jax.numpy.astype#

jax.numpy.astype(x, dtype, /, *, copy=False, device=None)[原始碼]#

將陣列轉換為指定的 dtype。

JAX 實作的 numpy.astype()

這是透過 jax.lax.convert_element_type() 實作,在某些情況下,其行為可能與 numpy.astype() 略有不同。特別是,浮點數到整數和整數到浮點數轉換的細節取決於實作。

參數:
  • x (ArrayLike) – 要轉換的輸入陣列

  • dtype (DTypeLike | None) – 輸出 dtype

  • copy (bool) – 如果為 True,則始終回傳副本。如果為 False (預設),則僅在必要時回傳副本。

  • device (xc.Device | Sharding | None | None) – 可選地指定輸出將提交到的裝置。

回傳:

x 具有相同形狀的陣列,包含指定 dtype 的值。

回傳類型:

Array

另請參閱

範例

>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0.0, 1.0, 2.0, 3.0], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int)  # truncates fractional values
Array([0, 0, 1], dtype=int32)