jax.numpy.sort#

jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)[原始碼]#

傳回陣列的排序副本。

JAX 版本的 numpy.sort()

參數:
  • a (ArrayLike) – 要排序的陣列

  • axis (int | None) – 沿著排序的整數軸。預設為 -1,即最後一個軸。如果為 None,則在排序前將 a 展平。

  • stable (bool) – 布林值,指定是否應使用穩定排序。預設值=True。

  • descending (bool) – 布林值,指定是否應以降序排序。預設值=False。

  • kind (None) – 已棄用;請改用 stable=True 或 stable=False 指定排序演算法。

  • order (None) – JAX 不支援

傳回值:

形狀為 a.shape 的已排序陣列(如果 axis 是整數),或形狀為 (a.size,) 的已排序陣列(如果 axis 為 None)。

傳回型別:

Array

範例

簡單的一維排序

>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)

沿陣列的最後一個軸排序

>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)

另請參閱