jax.numpy.swapaxes#
- jax.numpy.swapaxes(a, axis1, axis2)[source]#
交換陣列的兩個軸。
JAX 實現的
numpy.swapaxes()
,以jax.lax.transpose()
實作。筆記
與
numpy.swapaxes()
不同,jax.numpy.swapaxes()
將返回輸入陣列的副本,而不是視圖。然而,在 JIT 下,編譯器將在可能的情況下優化掉這些副本,因此這在實踐中不會對效能產生影響。參見
jax.numpy.moveaxis()
:移動陣列的單個軸。jax.numpy.rollaxis()
:moveaxis
的較舊 API。jax.lax.transpose()
:更通用的軸置換。jax.Array.swapaxes()
:通過陣列方法實現相同的功能。
範例
>>> a = jnp.ones((2, 3, 4, 5)) >>> jnp.swapaxes(a, 1, 3).shape (2, 5, 4, 3)
通過
swapaxes
陣列方法實現等效輸出>>> a.swapaxes(1, 3).shape (2, 5, 4, 3)
通過
transpose()
實現等效輸出>>> a.transpose(0, 3, 2, 1).shape (2, 5, 4, 3)