jax.numpy.moveaxis#

jax.numpy.moveaxis(a, source, destination)[原始碼]#

將陣列軸移動到新位置

numpy.moveaxis() 的 JAX 實作,以 jax.lax.transpose() 實作。

參數:
  • a (ArrayLike) – 輸入陣列

  • source (int | Sequence[int]) – 要移動的軸的索引或索引。

  • destination (int | Sequence[int]) – 軸目的地的索引或索引

傳回:

source 移動到 destination 軸的 a 副本。

傳回類型:

Array

筆記

numpy.moveaxis() 不同,jax.numpy.moveaxis() 將傳回輸入陣列的副本而不是視圖。然而,在 JIT 下,編譯器將在可能的情況下最佳化掉此類副本,因此這在實務上不會對效能產生影響。

參見

範例

>>> a = jnp.ones((2, 3, 4, 5))

將軸 1 移動到陣列末尾

>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)

將最後一個軸移動到位置 1

>>> jnp.moveaxis(a, -1, 1).shape
(2, 5, 3, 4)

移動多個軸

>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape
(4, 5, 3, 2)

這也可以透過 transpose() 完成

>>> a.transpose(2, 3, 1, 0).shape
(4, 5, 3, 2)