jax.numpy.frombuffer#

jax.numpy.frombuffer(buffer, dtype=<class 'float'>, count=-1, offset=0)[source]#

將緩衝區轉換為 1 維 JAX 陣列。

numpy.frombuffer() 的 JAX 實作。

參數:
  • buffer (bytes | Any) – 包含資料的物件。它必須是位元組物件,其長度是 dtype 元素大小的整數倍數,或者它必須是匯出 Python 緩衝區介面的物件。

  • dtype (DTypeLike) – 選填。陣列所需的資料類型。預設值為 float64。這指定用於解析緩衝區的 dtype,但請注意,解析後,如果 jax_enable_x64 標誌設定為 False,則 64 位元值將被轉換為 32 位元 JAX 陣列。

  • count (int) – 選填整數,指定要從緩衝區讀取的項目數。如果為 -1 (預設值),則會讀取緩衝區中的所有項目。

  • offset (int) – 選填整數,指定在緩衝區開頭要跳過的位元組數。預設值為 0。

傳回:

代表從緩衝區解譯資料的 1 維 JAX 陣列。

傳回類型:

Array

另請參閱

範例

使用位元組緩衝區

>>> buf = b"\x00\x01\x02\x03\x04"
>>> jnp.frombuffer(buf, dtype=jnp.uint8)
Array([0, 1, 2, 3, 4], dtype=uint8)
>>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1)
Array([1, 2, 3, 4], dtype=uint8)

透過 Python 緩衝區介面建構 JAX 陣列,使用 Python 內建的 array 模組。

>>> from array import array
>>> pybuffer = array('i', [0, 1, 2, 3, 4])
>>> jnp.frombuffer(pybuffer, dtype=jnp.int32)
Array([0, 1, 2, 3, 4], dtype=int32)