jax.numpy.concatenate#

jax.numpy.concatenate(arrays, axis=0, dtype=None)[原始碼]#

沿現有軸線連接陣列。

numpy.concatenate() 的 JAX 實作。

參數::
  • arrays (np.ndarray | Array | Sequence[ArrayLike]) – 要連接的陣列序列;除了指定的軸線外,每個陣列都必須具有相同的形狀。如果給定單一陣列,則會將其視為等同於 arrays = unstack(arrays),但實作將避免明確的 unstacking。

  • axis (int | None) – 指定要沿其連接的軸線。

  • dtype (DTypeLike | None | None) – 結果陣列的選填資料型別。如果未指定,資料型別將透過型別提升語意中描述的型別提升規則來決定。

傳回::

已連接的結果。

回傳型別::

Array

另請參閱

範例

一維連接

>>> x = jnp.arange(3)
>>> y = jnp.zeros(3, dtype=int)
>>> jnp.concatenate([x, y])
Array([0, 1, 2, 0, 0, 0], dtype=int32)

二維連接

>>> x = jnp.ones((2, 3))
>>> y = jnp.zeros((2, 1))
>>> jnp.concatenate([x, y], axis=1)
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.]], dtype=float32)