jax.numpy.argpartition#

jax.numpy.argpartition(a, kth, axis=-1)[原始碼]#

傳回部分排序陣列的索引。

JAX 版本的 numpy.argpartition() 實作。JAX 版本與 NumPy 在 NaN 條目的處理方式上有所不同:設定負位元的 NaN 會排序到陣列的開頭。

參數:
  • a (ArrayLike) – 要分割的陣列。

  • kth (int) – 靜態整數索引,陣列將圍繞此索引進行分割。

  • axis (int) – 靜態整數軸,沿此軸分割陣列;預設值為 -1。

傳回:

沿 axis 在第 kth 個值分割 a 的索引。 kth 之前的條目是小於 take(a, kth, axis) 值的索引,而 kth 之後的條目是大於 take(a, kth, axis) 值的索引

傳回類型:

Array

注意

JAX 版本要求 kth 引數為靜態整數,而不是一般陣列。這是透過兩次呼叫 jax.lax.top_k() 實作的。如果您只存取輸出的頂部或底部 k 個值,則直接呼叫 jax.lax.top_k() 可能更有效率。

另請參閱

範例

>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)

結果是部分排序輸入的索引序列。 kth 之前的所有索引都小於樞紐值,而 kth 之後的所有索引都大於樞紐值

>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]

請注意,在 smallest_valueslargest_values 之間,傳回的順序是任意的且取決於實作。