jax.numpy.vstack#
- jax.numpy.vstack(tup, dtype=None)[原始碼]#
垂直堆疊陣列。
numpy.vstack()
的 JAX 實作。對於二維或更多維度的陣列,這等同於使用
axis=0
的jax.numpy.concatenate()
。- 參數:
- 返回:
堆疊結果。
- 返回型別:
參見
jax.numpy.stack()
:沿任意軸堆疊jax.numpy.concatenate()
:沿現有軸串連。jax.numpy.hstack()
:水平堆疊,即沿軸 1。jax.numpy.dstack()
:深度方向堆疊,即沿軸 2。
範例
純量值
>>> jnp.vstack([1, 2, 3]) Array([[1], [2], [3]], dtype=int32, weak_type=True)
一維陣列
>>> x = jnp.arange(4) >>> y = jnp.ones(4) >>> jnp.vstack([x, y]) Array([[0., 1., 2., 3.], [1., 1., 1., 1.]], dtype=float32)
二維陣列
>>> x = x.reshape(1, 4) >>> y = y.reshape(1, 4) >>> jnp.vstack([x, y]) Array([[0., 1., 2., 3.], [1., 1., 1., 1.]], dtype=float32)