jax.device_get#

jax.device_get(x)[source]#

x 傳輸到主機。

如果 x 是 pytree,則個別緩衝區會平行複製。

參數::

x (Any) – 代表要傳輸到主機的陣列的陣列、純量、Array 或 (巢狀) 標準 Python 容器。

返回::

代表 x 值的陣列或 (巢狀) Python 容器。

範例

傳遞 Array

>>> import jax
>>> x = jax.numpy.array([1., 2., 3.])
>>> jax.device_get(x)
array([1., 2., 3.], dtype=float32)

傳遞純量 (沒有效果)

>>> jax.device_get(1)
1

另請參閱

  • device_put

  • device_put_sharded

  • device_put_replicated