jax.numpy.stack#

jax.numpy.stack(arrays, axis=0, out=None, dtype=None)[原始碼]#

沿著新軸線加入陣列。

JAX 版本的 numpy.stack()

參數:
  • arrays (np.ndarray | Array | Sequence[ArrayLike]) – 要堆疊的陣列序列;每個陣列都必須具有相同的形狀。如果給定單個陣列,則其處理方式等同於 arrays = unstack(arrays),但實作將避免顯式解堆疊。

  • axis (int) – 指定要沿其堆疊的軸線。

  • out (None | None) – JAX 未使用

  • dtype (DTypeLike | None | None) – 結果陣列的可選 dtype。如果未指定,則將透過類型提升語義中描述的類型提升規則來確定 dtype。

回傳值:

堆疊的結果。

回傳類型:

Array

另請參閱

範例

>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

unstack() 執行反向操作

>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)