jax.numpy.dsplit#

jax.numpy.dsplit(ary, indices_or_sections)[原始碼]#

深度方向分割陣列為子陣列。

numpy.dsplit() 的 JAX 實作。

詳細資訊請參閱 jax.numpy.split() 的文件。dsplit 等同於 axis=2split

範例

>>> x = jnp.arange(12).reshape(3, 1, 4)
>>> print(x)
[[[ 0  1  2  3]]

 [[ 4  5  6  7]]

 [[ 8  9 10 11]]]
>>> x1, x2 = jnp.dsplit(x, 2)
>>> print(x1)
[[[0 1]]

 [[4 5]]

 [[8 9]]]
>>> print(x2)
[[[ 2  3]]

 [[ 6  7]]

 [[10 11]]]

另請參閱

參數:
  • ary (ArrayLike)

  • indices_or_sections (int | Sequence[int] | ArrayLike)

返回類型:

list[Array]