jax.numpy.column_stack#
- jax.numpy.column_stack(tup)[原始碼]#
以行堆疊陣列。
numpy.column_stack()
的 JAX 實作。對於兩個或多個維度的陣列,這等同於
jax.numpy.concatenate()
,其中axis=1
。- 參數:
- 傳回:
堆疊的結果。
- 傳回型別:
參見
jax.numpy.stack()
:沿任意軸堆疊jax.numpy.concatenate()
:沿現有軸串聯。jax.numpy.vstack()
:垂直堆疊,即沿軸 0。jax.numpy.hstack()
:水平堆疊,即沿軸 1。jax.numpy.hstack()
:深度堆疊,即沿軸 2。
範例
純量值
>>> jnp.column_stack([1, 2, 3]) Array([[1, 2, 3]], dtype=int32, weak_type=True)
1D 陣列
>>> x = jnp.arange(3) >>> y = jnp.ones(3) >>> jnp.column_stack([x, y]) Array([[0., 1.], [1., 1.], [2., 1.]], dtype=float32)
2D 陣列
>>> x = x.reshape(3, 1) >>> y = y.reshape(3, 1) >>> jnp.column_stack([x, y]) Array([[0., 1.], [1., 1.], [2., 1.]], dtype=float32)