jax.numpy.load#

jax.numpy.load(file, *args, **kwargs)[原始碼]#

從 npy 檔案載入 JAX 陣列。

numpy.load() 的 JAX 包裝器。

此函數是 numpy.load() 的簡單包裝器,但在使用 numpy.save()jax.numpy.save() 建立的 .npy 檔案的情況下,輸出將以 jax.Array 傳回,並且會還原 bfloat16 資料類型。對於 .npz 檔案,結果將以一般的 NumPy 陣列傳回。

此函數需要具體的陣列輸入,並且與 jax.jit()jax.vmap() 等轉換不相容。

參數:
  • file (IO[bytes] | str | os.PathLike[Any]) – 字串、位元組或包含陣列資料的路徑類物件。

  • args (Any) – 如需其他引數,請參閱 numpy.load()

  • kwargs (Any) – 如需其他引數,請參閱 numpy.load()

傳回:

儲存在檔案中的陣列。

傳回類型:

Array

參見

範例

>>> import io
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)