jax.numpy.asarray#

jax.numpy.asarray(a, dtype=None, order=None, *, copy=None, device=None)[原始碼]#

將物件轉換為 JAX 陣列。

numpy.asarray() 的 JAX 實作。

參數:
  • a (Any) – 可轉換為陣列的物件。這包括 JAX 陣列、NumPy 陣列、Python 純量、Python 集合(如列表和元組)、具有 __array__ 方法的物件,以及支援 Python buffer 協定的物件。

  • dtype (DTypeLike | None | None) – 可選地指定輸出陣列的 dtype。如果未指定,將從輸入推斷。

  • order (str | None | None) – 在 JAX 中未實作

  • copy (bool | None | None) – 可選的布林值,指定複製模式。如果為 True,則始終返回副本。如果為 False,則在必要時複製會出錯。預設值為 None,僅在必要時複製。

  • device (xc.Device | Sharding | None | None) – 可選的 DeviceSharding,將在其中提交建立的陣列。

返回:

從輸入建構的 JAX 陣列。

返回類型:

Array

參見

範例

從 Python 純量建構 JAX 陣列

>>> jnp.asarray(True)
Array(True, dtype=bool)
>>> jnp.asarray(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.asarray(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.asarray(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)

從 Python 集合建構 JAX 陣列

>>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.asarray(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)

從 NumPy 陣列建構 JAX 陣列

>>> jnp.asarray(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

透過 Python buffer 介面建構 JAX 陣列,使用 Python 內建的 array 模組。

>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.asarray(pybuffer)
Array([2, 3, 5, 7], dtype=int32)