jax.numpy.roll#

jax.numpy.roll(a, shift, axis=None)[原始碼]#

沿指定軸滾動陣列的元素。

numpy.roll() 的 JAX 實作。

參數:
  • a (ArrayLike) – 輸入陣列。

  • shift (ArrayLike | Sequence[int]) – 指定軸要移動的位置數。如果是整數,則所有軸都移動相同的量。如果是元組,則個別指定每個軸的移動量。

  • axis (int | Sequence[int] | None | None) – 要滾動的軸或軸。如果為 None,則陣列會被展平、移動,然後重新塑形為其原始形狀。

傳回:

沿指定軸或軸滾動元素的 a 副本。

傳回類型:

Array

參見

範例

>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)

沿特定軸滾動元素

>>> a = jnp.array([[ 0,  1,  2,  3],
...                [ 4,  5,  6,  7],
...                [ 8,  9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)