jax.numpy.vstack#

jax.numpy.vstack(tup, dtype=None)[原始碼]#

垂直堆疊陣列。

numpy.vstack() 的 JAX 實作。

對於二維或更多維度的陣列,這等同於使用 axis=0jax.numpy.concatenate()

參數:
  • tup (np.ndarray | Array | Sequence[ArrayLike]) – 要堆疊的陣列序列;每個陣列除了第一個軸之外,其餘軸的形狀必須相同。如果給定單一陣列,則會將其視為等同於 tup = unstack(tup),但實作將避免明確的 unstacking。

  • dtype (DTypeLike | None | None) – 結果陣列的可選 dtype。如果未指定,dtype 將透過型別提升語意中描述的型別提升規則來決定。

返回:

堆疊結果。

返回型別:

Array

參見

範例

純量值

>>> jnp.vstack([1, 2, 3])
Array([[1],
       [2],
       [3]], dtype=int32, weak_type=True)

一維陣列

>>> x = jnp.arange(4)
>>> y = jnp.ones(4)
>>> jnp.vstack([x, y])
Array([[0., 1., 2., 3.],
       [1., 1., 1., 1.]], dtype=float32)

二維陣列

>>> x = x.reshape(1, 4)
>>> y = y.reshape(1, 4)
>>> jnp.vstack([x, y])
Array([[0., 1., 2., 3.],
       [1., 1., 1., 1.]], dtype=float32)