jax.numpy.flip#

jax.numpy.flip(m, axis=None)[原始碼]#

沿著給定軸反轉陣列元素的順序。

NumPy numpy.flip() 的 JAX 實作。

參數:
  • m (ArrayLike) – 陣列。

  • axis (int | Sequence[int] | None | None) – 整數或整數序列。指定應沿哪些軸反轉陣列元素。預設值為 None,表示沿所有軸反轉。

回傳值:

沿著 axis 軸元素順序反轉的陣列。

回傳類型:

陣列

另請參閱

範例

>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> jnp.flip(x1)
Array([[4, 3],
       [2, 1]], dtype=int32)

如果 axis 指定為整數,則 jax.numpy.flip 僅沿該特定軸反轉陣列。

>>> jnp.flip(x1, axis=1)
Array([[2, 1],
       [4, 3]], dtype=int32)
>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x2
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.flip(x2)
Array([[[8, 7],
        [6, 5]],

       [[4, 3],
        [2, 1]]], dtype=int32)

axis 指定為整數序列時,則 jax.numpy.flip 沿指定軸反轉陣列。

>>> jnp.flip(x2, axis=[1, 2])
Array([[[4, 3],
        [2, 1]],

       [[8, 7],
        [6, 5]]], dtype=int32)