jax.numpy.rollaxis#

jax.numpy.rollaxis(a, axis, start=0)[source]#

將指定的軸滾動到給定的位置。

JAX 實作的 numpy.rollaxis()

此函數的存在是為了與 NumPy 相容,但在大多數情況下,更新的 jax.numpy.moveaxis() 會是更好的選擇,因為它的參數意義更直觀。

參數:
  • a (ArrayLike) – 輸入陣列。

  • axis (int) – 要向前滾動的軸的索引。

  • start (int) – 軸將滾動到的索引 (預設值 = 0)。在正規化負軸之後,如果 start <= axis,軸會滾動到 start 索引;如果 start > axis,軸會滾動到 start 之前的位置。

回傳:

滾動軸後的 a 副本。

回傳型別:

陣列

筆記

不同於 numpy.rollaxis()jax.numpy.rollaxis() 將回傳輸入陣列的副本,而不是視圖。然而,在 JIT 下,編譯器將在可能的情況下優化掉這些副本,因此這在實務上不會對效能造成影響。

另請參閱

範例

>>> a = jnp.ones((2, 3, 4, 5))

將軸 2 滾動到陣列的開頭

>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)

將軸 1 滾動到陣列的結尾

>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)

這兩個操作與 moveaxis() 等效

>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)