jax.numpy.copy#

jax.numpy.copy(a, order=None)[原始碼]#

傳回陣列的副本。

JAX 實作的 numpy.copy()

參數:
  • a (ArrayLike) – 要複製的類陣列物件

  • order (str | None | None) – 在 JAX 中未實作

傳回:

輸入陣列 a 的副本。

傳回類型:

Array

另請參閱

範例

由於 JAX 陣列是不可變的,因此在大多數情況下,顯式陣列副本不是必要的。一種例外情況是當使用具有捐贈引數的函式時 (請參閱 jax.jit()donate_argnums 引數)。

>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]

因為我們將 x 標記為已捐贈,所以原始陣列不再可用

>>> print(x)  
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].

在這種情況下,顯式副本可讓您保留對原始緩衝區的存取權

>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]