jax.numpy.concat#
- jax.numpy.concat(arrays, /, *, axis=0)[原始碼]#
沿著現有軸連接陣列。
array_api.concat()
的 JAX 實作。- 參數:
arrays (Sequence[ArrayLike]) – 要連接的陣列序列;除了指定的軸之外,每個陣列都必須具有相同的形狀。如果給定單個陣列,它將被視為等同於 arrays = unstack(arrays),但實作將避免顯式解堆疊。
axis (int | None) – 指定要沿其連接的軸。
- 回傳:
連接的結果。
- 回傳型別:
參見
jax.lax.concatenate()
:XLA 連接 API。jax.numpy.concatenate()
:此函數的 NumPy 版本。jax.numpy.stack()
:沿新軸連接陣列。
範例
一維連接
>>> x = jnp.arange(3) >>> y = jnp.zeros(3, dtype=int) >>> jnp.concat([x, y]) Array([0, 1, 2, 0, 0, 0], dtype=int32)
二維連接
>>> x = jnp.ones((2, 3)) >>> y = jnp.zeros((2, 1)) >>> jnp.concat([x, y], axis=1) Array([[1., 1., 1., 0.], [1., 1., 1., 0.]], dtype=float32)