jax.numpy.from_dlpack#
- jax.numpy.from_dlpack(x, /, *, device=None, copy=None)[原始碼]#
透過 DLPack 建構 JAX 陣列。
numpy.from_dlpack()
的 JAX 實作。- 參數:
x (Any) – 透過
__dlpack__
和__dlpack_device__
方法實作 DLPack 協定的物件,或 CPU 或 GPU 上的舊版 DLPack 張量。device (xc.Device | Sharding | None | None) – 可選的
Device
或Sharding
,表示應放置傳回陣列的單一裝置。如果給定,則結果會提交到裝置。如果未指定,則結果陣列將解壓縮到其原始裝置上。將device
設定為與external_array
來源不同的裝置將需要複製,這表示copy
必須設定為True
或None
。copy (bool | None | None) – 可選的布林值,控制是否執行複製。如果
copy=True
則始終執行複製,即使解壓縮到同一裝置上也是如此。如果copy=False
則永遠不會執行複製,並且在必要時會引發錯誤。當copy=None
(預設) 時,如果裝置傳輸需要,則可能會執行複製。
- 傳回值:
輸入緩衝區的 JAX 陣列。
- 傳回型別:
注意
雖然 JAX 陣列始終是不可變的,但 dlpack 緩衝區無法標記為不可變,並且 JAX 外部的進程可能會就地變更它們。如果從未複製的 dlpack 緩衝區建構 JAX 陣列,並且來源緩衝區稍後就地修改,則在使用相關聯的 JAX 陣列時可能會導致未定義的行為。
範例
透過 DLPack 在 NumPy 和 JAX 之間傳遞資料
>>> import numpy as np >>> rng = np.random.default_rng(42) >>> x_numpy = rng.random(4, dtype='float32') >>> print(x_numpy) [0.08925092 0.773956 0.6545715 0.43887842] >>> hasattr(x_numpy, "__dlpack__") # NumPy supports the DLPack interface True
>>> import jax.numpy as jnp >>> x_jax = jnp.from_dlpack(x_numpy) >>> print(x_jax) [0.08925092 0.773956 0.6545715 0.43887842] >>> hasattr(x_jax, "__dlpack__") # JAX supports the DLPack interface True
>>> x_numpy_round_trip = np.from_dlpack(x_jax) >>> print(x_numpy_round_trip) [0.08925092 0.773956 0.6545715 0.43887842]