jax.numpy.roll#
- jax.numpy.roll(a, shift, axis=None)[原始碼]#
沿指定軸滾動陣列的元素。
numpy.roll()
的 JAX 實作。- 參數:
- 傳回:
沿指定軸或軸滾動元素的
a
副本。- 傳回類型:
參見
jax.numpy.rollaxis()
:將指定軸滾動到給定位置。
範例
>>> a = jnp.array([0, 1, 2, 3, 4, 5]) >>> jnp.roll(a, 2) Array([4, 5, 0, 1, 2, 3], dtype=int32)
沿特定軸滾動元素
>>> a = jnp.array([[ 0, 1, 2, 3], ... [ 4, 5, 6, 7], ... [ 8, 9, 10, 11]]) >>> jnp.roll(a, 1, axis=0) Array([[ 8, 9, 10, 11], [ 0, 1, 2, 3], [ 4, 5, 6, 7]], dtype=int32) >>> jnp.roll(a, [2, 3], axis=[0, 1]) Array([[ 5, 6, 7, 4], [ 9, 10, 11, 8], [ 1, 2, 3, 0]], dtype=int32)