jax.numpy.sort#
- jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)[原始碼]#
傳回陣列的排序副本。
JAX 版本的
numpy.sort()
。- 參數:
- 傳回值:
形狀為
a.shape
的已排序陣列(如果axis
是整數),或形狀為(a.size,)
的已排序陣列(如果axis
為 None)。- 傳回型別:
範例
簡單的一維排序
>>> x = jnp.array([1, 3, 5, 4, 2, 1]) >>> jnp.sort(x) Array([1, 1, 2, 3, 4, 5], dtype=int32)
沿陣列的最後一個軸排序
>>> x = jnp.array([[2, 1, 3], ... [4, 3, 6]]) >>> jnp.sort(x, axis=1) Array([[1, 2, 3], [3, 4, 6]], dtype=int32)
另請參閱
jax.numpy.argsort()
:傳回排序值的索引。jax.numpy.lexsort()
:多個陣列的詞彙排序。jax.lax.sort()
:包裝 XLA Sort 運算子的較低層級函式。