jax.numpy.array#
- jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0, *, device=None)[原始碼]#
將物件轉換為 JAX 陣列。
JAX 實作的
numpy.array()
。- 參數:
object (Any) – 可轉換為陣列的物件。這包括 JAX 陣列、NumPy 陣列、Python 純量、Python 集合 (如列表和元組)、具有
__array__
方法的物件,以及支援 Python 緩衝區協定的物件。dtype (DTypeLike | None | None) – 可選地指定輸出陣列的 dtype。如果未指定,將從輸入推斷。
copy (bool) – 指定是否強制複製輸入。預設值:True。
order (str | None) – 在 JAX 中未實作
ndmin (int) – 整數,指定輸出陣列中的最小維度數。
device (xc.Device | Sharding | None | None) – 可選的
Device
或Sharding
,將在其中提交建立的陣列。
- 傳回:
從輸入建構的 JAX 陣列。
- 傳回型別:
另請參閱
jax.numpy.asarray()
:類似 array,但預設僅在必要時複製。jax.numpy.from_dlpack()
:從實作 dlpack 介面的物件建構 JAX 陣列。jax.numpy.frombuffer()
:從實作緩衝區介面的物件建構 JAX 陣列。
範例
從 Python 純量建構 JAX 陣列
>>> jnp.array(True) Array(True, dtype=bool) >>> jnp.array(42) Array(42, dtype=int32, weak_type=True) >>> jnp.array(3.5) Array(3.5, dtype=float32, weak_type=True) >>> jnp.array(1 + 1j) Array(1.+1.j, dtype=complex64, weak_type=True)
從 Python 集合建構 JAX 陣列
>>> jnp.array([1, 2, 3]) # list of ints -> 1D array Array([1, 2, 3], dtype=int32) >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.array(range(5)) Array([0, 1, 2, 3, 4], dtype=int32)
從 NumPy 陣列建構 JAX 陣列
>>> jnp.array(np.linspace(0, 2, 5)) Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
透過 Python 緩衝區介面建構 JAX 陣列,使用 Python 內建的
array
模組。>>> from array import array >>> pybuffer = array('i', [2, 3, 5, 7]) >>> jnp.array(pybuffer) Array([2, 3, 5, 7], dtype=int32)