jax.numpy.stack#
- jax.numpy.stack(arrays, axis=0, out=None, dtype=None)[原始碼]#
沿著新軸線加入陣列。
JAX 版本的
numpy.stack()
。- 參數:
- 回傳值:
堆疊的結果。
- 回傳類型:
另請參閱
jax.numpy.unstack()
:stack
的反向操作。jax.numpy.concatenate()
:沿著現有軸線串連。jax.numpy.vstack()
:垂直堆疊,即沿軸線 0。jax.numpy.hstack()
:水平堆疊,即沿軸線 1。jax.numpy.dstack()
:深度堆疊,即沿軸線 2。
範例
>>> x = jnp.array([1, 2, 3]) >>> y = jnp.array([4, 5, 6]) >>> jnp.stack([x, y]) Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.stack([x, y], axis=1) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
unstack()
執行反向操作>>> arr = jnp.stack([x, y], axis=1) >>> x, y = jnp.unstack(arr, axis=1) >>> x Array([1, 2, 3], dtype=int32) >>> y Array([4, 5, 6], dtype=int32)