jax.numpy.astype#
- jax.numpy.astype(x, dtype, /, *, copy=False, device=None)[原始碼]#
將陣列轉換為指定的 dtype。
JAX 實作的
numpy.astype()
。這是透過
jax.lax.convert_element_type()
實作,在某些情況下,其行為可能與numpy.astype()
略有不同。特別是,浮點數到整數和整數到浮點數轉換的細節取決於實作。- 參數:
- 回傳:
與
x
具有相同形狀的陣列,包含指定 dtype 的值。- 回傳類型:
另請參閱
jax.lax.convert_element_type()
:用於 XLA 風格 dtype 轉換的底層函式。
範例
>>> x = jnp.array([0, 1, 2, 3]) >>> x Array([0, 1, 2, 3], dtype=int32) >>> x.astype('float32') Array([0.0, 1.0, 2.0, 3.0], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0]) >>> y.astype(int) # truncates fractional values Array([0, 0, 1], dtype=int32)