jax.numpy.rollaxis#
- jax.numpy.rollaxis(a, axis, start=0)[source]#
將指定的軸滾動到給定的位置。
JAX 實作的
numpy.rollaxis()
。此函數的存在是為了與 NumPy 相容,但在大多數情況下,更新的
jax.numpy.moveaxis()
會是更好的選擇,因為它的參數意義更直觀。- 參數:
- 回傳:
滾動軸後的
a
副本。- 回傳型別:
筆記
不同於
numpy.rollaxis()
,jax.numpy.rollaxis()
將回傳輸入陣列的副本,而不是視圖。然而,在 JIT 下,編譯器將在可能的情況下優化掉這些副本,因此這在實務上不會對效能造成影響。另請參閱
jax.numpy.moveaxis()
:語意比rollaxis
更清晰的新 API;在大多數情況下,應優先選擇此 API 而非rollaxis
。jax.numpy.swapaxes()
:交換兩個軸。jax.numpy.transpose()
:軸的一般排列。
範例
>>> a = jnp.ones((2, 3, 4, 5))
將軸 2 滾動到陣列的開頭
>>> jnp.rollaxis(a, 2).shape (4, 2, 3, 5)
將軸 1 滾動到陣列的結尾
>>> jnp.rollaxis(a, 1, a.ndim).shape (2, 4, 5, 3)
這兩個操作與
moveaxis()
等效>>> jnp.moveaxis(a, 2, 0).shape (4, 2, 3, 5) >>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3)