jax.numpy.append#

jax.numpy.append(arr, values, axis=None)[source]#

傳回一個新陣列,其中值已附加到原始陣列的末尾。

JAX 版本的 numpy.append() 實作。

參數:
  • arr (ArrayLike) – 原始陣列。

  • values (ArrayLike) – 要附加到陣列的值。values 必須與 arr 具有相同的維度數,並且除了指定的軸之外,所有維度都必須匹配。

  • axis (int | None) – 沿著哪個軸附加值。如果為 None(預設值),則在附加之前,arrvalues 都將被展平。

傳回值:

一個新的陣列,其中值已附加到 arr

傳回型別:

陣列 (Array)

範例

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.append(a, b)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

沿特定軸附加

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6]])
>>> jnp.append(a, b, axis=0)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

沿尾軸附加

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[7], [8]])
>>> jnp.append(a, b, axis=1)
Array([[1, 2, 3, 7],
       [4, 5, 6, 8]], dtype=int32)