jax.numpy.permute_dims#
- jax.numpy.permute_dims(a, /, axes)[原始碼]#
排列陣列的軸/維度。
array_api.permute_dims()
的 JAX 實作。- 參數:
- 回傳:
排列軸後的
a
副本。- 回傳型別:
範例
>>> a = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.permute_dims(a, (1, 0)) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)