jax.numpy.dstack#

jax.numpy.dstack(tup, dtype=None)[原始碼]#

深度堆疊陣列。

JAX 版本的 numpy.dstack() 實作。

對於三個或更多維度的陣列,這等同於 jax.numpy.concatenate(),其中 axis=2

參數::
  • tup (np.ndarray | Array | Sequence[ArrayLike]) – 要堆疊的陣列序列;除了第三軸外,每個陣列都必須具有相同的形狀。輸入陣列將被提升到至少秩 3。如果給定單個陣列,它將被視為等同於 tup = unstack(tup),但實作將避免顯式解堆疊。

  • dtype (DTypeLike | None | None) – 結果陣列的可選 dtype。如果未指定,dtype 將通過 型別提升語意 中描述的型別提升規則來確定。

返回::

堆疊的結果。

返回型別::

Array

參見

範例

純量值

>>> jnp.dstack([1, 2, 3])
Array([[[1, 2, 3]]], dtype=int32, weak_type=True)

1D 陣列

>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)

2D 陣列

>>> x = x.reshape(1, 3)
>>> y = y.reshape(1, 3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)