jax.numpy.partition#

jax.numpy.partition(a, kth, axis=-1)[原始碼]#

傳回陣列的部分排序副本。

JAX 版本的 numpy.partition() 實作。JAX 版本與 NumPy 在 NaN 條目的處理上有所不同:設定了負位元的 NaN 會排序到陣列的開頭。

參數:
  • a (ArrayLike) – 要分割的陣列。

  • kth (int) – 靜態整數索引,陣列將圍繞此索引分割。

  • axis (int) – 靜態整數軸,陣列將沿此軸分割;預設值為 -1。

傳回:

沿 axis 在第 kth 個值分割的 a 副本。 kth 之前的條目是小於 take(a, kth, axis) 的值,kth 之後的條目是值大於 take(a, kth, axis) 的索引

傳回型別:

陣列

注意

JAX 版本要求 kth 引數為靜態整數,而不是一般陣列。這是透過兩次呼叫 jax.lax.top_k() 實作的。如果您只存取輸出的頂部或底部 k 個值,則直接呼叫 jax.lax.top_k() 可能更有效率。

另請參閱

範例

>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)

結果是輸入的部分排序副本。 kth 之前的所有值都小於樞紐值,而 kth 之後的所有值都大於樞紐值

>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]

請注意,在 smallest_valueslargest_values 中,傳回的順序是任意的,並且取決於實作。