jax.numpy.from_dlpack#

jax.numpy.from_dlpack(x, /, *, device=None, copy=None)[原始碼]#

透過 DLPack 建構 JAX 陣列。

numpy.from_dlpack() 的 JAX 實作。

參數:
  • x (Any) – 透過 __dlpack____dlpack_device__ 方法實作 DLPack 協定的物件,或 CPU 或 GPU 上的舊版 DLPack 張量。

  • device (xc.Device | Sharding | None | None) – 可選的 DeviceSharding,表示應放置傳回陣列的單一裝置。如果給定,則結果會提交到裝置。如果未指定,則結果陣列將解壓縮到其原始裝置上。將 device 設定為與 external_array 來源不同的裝置將需要複製,這表示 copy 必須設定為 TrueNone

  • copy (bool | None | None) – 可選的布林值,控制是否執行複製。如果 copy=True 則始終執行複製,即使解壓縮到同一裝置上也是如此。如果 copy=False 則永遠不會執行複製,並且在必要時會引發錯誤。當 copy=None (預設) 時,如果裝置傳輸需要,則可能會執行複製。

傳回值:

輸入緩衝區的 JAX 陣列。

傳回型別:

Array

注意

雖然 JAX 陣列始終是不可變的,但 dlpack 緩衝區無法標記為不可變,並且 JAX 外部的進程可能會就地變更它們。如果從未複製的 dlpack 緩衝區建構 JAX 陣列,並且來源緩衝區稍後就地修改,則在使用相關聯的 JAX 陣列時可能會導致未定義的行為。

範例

透過 DLPack 在 NumPy 和 JAX 之間傳遞資料

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> x_numpy = rng.random(4, dtype='float32')
>>> print(x_numpy)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_numpy, "__dlpack__")  # NumPy supports the DLPack interface
True
>>> import jax.numpy as jnp
>>> x_jax = jnp.from_dlpack(x_numpy)
>>> print(x_jax)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_jax, "__dlpack__")  # JAX supports the DLPack interface
True
>>> x_numpy_round_trip = np.from_dlpack(x_jax)
>>> print(x_numpy_round_trip)
[0.08925092 0.773956   0.6545715  0.43887842]