jax.numpy.permute_dims#

jax.numpy.permute_dims(a, /, axes)[原始碼]#

排列陣列的軸/維度。

array_api.permute_dims() 的 JAX 實作。

參數:
  • a (ArrayLike) – 輸入陣列

  • axes (tuple[int, ...]) – 整數 tuple,範圍在 [0, a.ndim) 內,指定軸的排列方式。

回傳:

排列軸後的 a 副本。

回傳型別:

Array

範例

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.permute_dims(a, (1, 0))
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)