jax.device_get#
- jax.device_get(x)[source]#
將
x
傳輸到主機。如果
x
是 pytree,則個別緩衝區會平行複製。- 參數::
x (Any) – 代表要傳輸到主機的陣列的陣列、純量、Array 或 (巢狀) 標準 Python 容器。
- 返回::
代表
x
值的陣列或 (巢狀) Python 容器。
範例
傳遞 Array
>>> import jax >>> x = jax.numpy.array([1., 2., 3.]) >>> jax.device_get(x) array([1., 2., 3.], dtype=float32)
傳遞純量 (沒有效果)
>>> jax.device_get(1) 1
另請參閱
device_put
device_put_sharded
device_put_replicated