jax.numpy.asarray#
- jax.numpy.asarray(a, dtype=None, order=None, *, copy=None, device=None)[原始碼]#
將物件轉換為 JAX 陣列。
numpy.asarray()
的 JAX 實作。- 參數:
a (Any) – 可轉換為陣列的物件。這包括 JAX 陣列、NumPy 陣列、Python 純量、Python 集合(如列表和元組)、具有
__array__
方法的物件,以及支援 Python buffer 協定的物件。dtype (DTypeLike | None | None) – 可選地指定輸出陣列的 dtype。如果未指定,將從輸入推斷。
order (str | None | None) – 在 JAX 中未實作
copy (bool | None | None) – 可選的布林值,指定複製模式。如果為 True,則始終返回副本。如果為 False,則在必要時複製會出錯。預設值為 None,僅在必要時複製。
device (xc.Device | Sharding | None | None) – 可選的
Device
或Sharding
,將在其中提交建立的陣列。
- 返回:
從輸入建構的 JAX 陣列。
- 返回類型:
參見
jax.numpy.array()
:類似於 asarray,但預設為 copy=True。jax.numpy.from_dlpack()
:從實作 dlpack 介面的物件建構 JAX 陣列。jax.numpy.frombuffer()
:從實作 buffer 介面的物件建構 JAX 陣列。
範例
從 Python 純量建構 JAX 陣列
>>> jnp.asarray(True) Array(True, dtype=bool) >>> jnp.asarray(42) Array(42, dtype=int32, weak_type=True) >>> jnp.asarray(3.5) Array(3.5, dtype=float32, weak_type=True) >>> jnp.asarray(1 + 1j) Array(1.+1.j, dtype=complex64, weak_type=True)
從 Python 集合建構 JAX 陣列
>>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array Array([1, 2, 3], dtype=int32) >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.asarray(range(5)) Array([0, 1, 2, 3, 4], dtype=int32)
從 NumPy 陣列建構 JAX 陣列
>>> jnp.asarray(np.linspace(0, 2, 5)) Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
透過 Python buffer 介面建構 JAX 陣列,使用 Python 內建的
array
模組。>>> from array import array >>> pybuffer = array('i', [2, 3, 5, 7]) >>> jnp.asarray(pybuffer) Array([2, 3, 5, 7], dtype=int32)