jax.numpy.argsort#
- jax.numpy.argsort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)[source]#
返回排序陣列的索引。
JAX 實作的
numpy.argsort()
。- 參數:
- 返回:
排序陣列的索引陣列。返回的陣列形狀將為
a.shape
(如果axis
是整數)或(a.size,)
(如果axis
為 None)。- 返回類型:
範例
簡單的一維排序
>>> x = jnp.array([1, 3, 5, 4, 2, 1]) >>> indices = jnp.argsort(x) >>> indices Array([0, 5, 4, 1, 3, 2], dtype=int32) >>> x[indices] Array([1, 1, 2, 3, 4, 5], dtype=int32)
沿陣列的最後一個軸排序
>>> x = jnp.array([[2, 1, 3], ... [6, 4, 3]]) >>> indices = jnp.argsort(x, axis=1) >>> indices Array([[1, 0, 2], [2, 1, 0]], dtype=int32) >>> jnp.take_along_axis(x, indices, axis=1) Array([[1, 2, 3], [3, 4, 6]], dtype=int32)
另請參閱
jax.numpy.sort()
:直接返回排序後的值。jax.numpy.lexsort()
:多個陣列的字典式排序。jax.lax.sort()
:包裝 XLA Sort 運算符的底層函數。