jax.numpy.unstack#
- jax.numpy.unstack(x, /, *, axis=0)[source]#
沿著軸向展開陣列。
array_api.unstack()
的 JAX 實作。- 參數:
x (ArrayLike) – 要展開的陣列。必須具有
x.ndim >= 1
。axis (int) – 沿著此整數軸向展開。必須滿足
-x.ndim <= axis < x.ndim
。
- 回傳:
展開陣列的元組。
- 回傳類型:
另請參閱
jax.numpy.stack()
:unstack
的反向操作jax.numpy.split()
:沿著軸向將陣列分割成批次。
範例
>>> arr = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> arrs = jnp.unstack(arr) >>> print(*arrs) [1 2 3] [4 5 6]
stack()
提供此操作的反向功能>>> jnp.stack(arrs) Array([[1, 2, 3], [4, 5, 6]], dtype=int32)