jax.numpy.concat#

jax.numpy.concat(arrays, /, *, axis=0)[原始碼]#

沿著現有軸連接陣列。

array_api.concat() 的 JAX 實作。

參數:
  • arrays (Sequence[ArrayLike]) – 要連接的陣列序列;除了指定的軸之外,每個陣列都必須具有相同的形狀。如果給定單個陣列,它將被視為等同於 arrays = unstack(arrays),但實作將避免顯式解堆疊。

  • axis (int | None) – 指定要沿其連接的軸。

回傳:

連接的結果。

回傳型別:

Array

參見

範例

一維連接

>>> x = jnp.arange(3)
>>> y = jnp.zeros(3, dtype=int)
>>> jnp.concat([x, y])
Array([0, 1, 2, 0, 0, 0], dtype=int32)

二維連接

>>> x = jnp.ones((2, 3))
>>> y = jnp.zeros((2, 1))
>>> jnp.concat([x, y], axis=1)
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.]], dtype=float32)