jax.numpy.load#
- jax.numpy.load(file, *args, **kwargs)[原始碼]#
從 npy 檔案載入 JAX 陣列。
的 JAX 包裝器。numpy.load()
此函數是
的簡單包裝器,但在使用numpy.load()
或numpy.save()
建立的jax.numpy.save()
.npy
檔案的情況下,輸出將以
傳回,並且會還原jax.Array
bfloat16
資料類型。對於.npz
檔案,結果將以一般的 NumPy 陣列傳回。此函數需要具體的陣列輸入,並且與
或jax.jit()
等轉換不相容。jax.vmap()
- 參數:
file (IO[bytes] | str | os.PathLike[Any]) – 字串、位元組或包含陣列資料的路徑類物件。
args (Any) – 如需其他引數,請參閱
numpy.load()
kwargs (Any) – 如需其他引數,請參閱
numpy.load()
- 傳回:
儲存在檔案中的陣列。
- 傳回類型:
參見
:將陣列儲存到檔案。jax.numpy.save()
範例
>>> import io >>> f = io.BytesIO() # use an in-memory file-like object. >>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16') >>> jnp.save(f, x) >>> f.seek(0) 0 >>> jnp.load(f) Array([2, 4, 6, 8], dtype=bfloat16)