jax.numpy.moveaxis#
- jax.numpy.moveaxis(a, source, destination)[原始碼]#
將陣列軸移動到新位置
numpy.moveaxis()
的 JAX 實作,以jax.lax.transpose()
實作。- 參數:
- 傳回:
從
source
移動到destination
軸的a
副本。- 傳回類型:
筆記
與
numpy.moveaxis()
不同,jax.numpy.moveaxis()
將傳回輸入陣列的副本而不是視圖。然而,在 JIT 下,編譯器將在可能的情況下最佳化掉此類副本,因此這在實務上不會對效能產生影響。參見
jax.numpy.swapaxes()
:交換兩個軸。jax.numpy.rollaxis()
:用於移動軸的舊版 API。jax.numpy.transpose()
:一般軸置換。
範例
>>> a = jnp.ones((2, 3, 4, 5))
將軸
1
移動到陣列末尾>>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3)
將最後一個軸移動到位置 1
>>> jnp.moveaxis(a, -1, 1).shape (2, 5, 3, 4)
移動多個軸
>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape (4, 5, 3, 2)
這也可以透過
transpose()
完成>>> a.transpose(2, 3, 1, 0).shape (4, 5, 3, 2)