jax.numpy.argpartition#
- jax.numpy.argpartition(a, kth, axis=-1)[原始碼]#
傳回部分排序陣列的索引。
JAX 版本的
numpy.argpartition()
實作。JAX 版本與 NumPy 在 NaN 條目的處理方式上有所不同:設定負位元的 NaN 會排序到陣列的開頭。- 參數:
- 傳回:
沿
axis
在第kth
個值分割a
的索引。kth
之前的條目是小於take(a, kth, axis)
值的索引,而kth
之後的條目是大於take(a, kth, axis)
值的索引- 傳回類型:
注意
JAX 版本要求
kth
引數為靜態整數,而不是一般陣列。這是透過兩次呼叫jax.lax.top_k()
實作的。如果您只存取輸出的頂部或底部 k 個值,則直接呼叫jax.lax.top_k()
可能更有效率。另請參閱
jax.numpy.partition()
:直接部分排序jax.numpy.argsort()
:完整間接排序jax.lax.top_k()
:直接尋找前 k 個條目jax.lax.approx_max_k()
:計算近似前 k 個條目jax.lax.approx_min_k()
:計算近似後 k 個條目
範例
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) >>> kth = 4 >>> idx = jnp.argpartition(x, kth) >>> idx Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
結果是部分排序輸入的索引序列。
kth
之前的所有索引都小於樞紐值,而kth
之後的所有索引都大於樞紐值>>> x_partitioned = x[idx] >>> smallest_values = x_partitioned[:kth] >>> pivot_value = x_partitioned[kth] >>> largest_values = x_partitioned[kth + 1:] >>> print(smallest_values, pivot_value, largest_values) [1 2 3 3] 4 [6 8 9 7 5]
請注意,在
smallest_values
和largest_values
之間,傳回的順序是任意的且取決於實作。