jax.numpy.array#

jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0, *, device=None)[原始碼]#

將物件轉換為 JAX 陣列。

JAX 實作的 numpy.array()

參數:
  • object (Any) – 可轉換為陣列的物件。這包括 JAX 陣列、NumPy 陣列、Python 純量、Python 集合 (如列表和元組)、具有 __array__ 方法的物件,以及支援 Python 緩衝區協定的物件。

  • dtype (DTypeLike | None | None) – 可選地指定輸出陣列的 dtype。如果未指定,將從輸入推斷。

  • copy (bool) – 指定是否強制複製輸入。預設值:True。

  • order (str | None) – 在 JAX 中未實作

  • ndmin (int) – 整數,指定輸出陣列中的最小維度數。

  • device (xc.Device | Sharding | None | None) – 可選的 DeviceSharding,將在其中提交建立的陣列。

傳回:

從輸入建構的 JAX 陣列。

傳回型別:

Array

另請參閱

範例

從 Python 純量建構 JAX 陣列

>>> jnp.array(True)
Array(True, dtype=bool)
>>> jnp.array(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.array(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.array(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)

從 Python 集合建構 JAX 陣列

>>> jnp.array([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.array(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)

從 NumPy 陣列建構 JAX 陣列

>>> jnp.array(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

透過 Python 緩衝區介面建構 JAX 陣列,使用 Python 內建的 array 模組。

>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.array(pybuffer)
Array([2, 3, 5, 7], dtype=int32)