jax.numpy.searchsorted#

jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[原始碼]#

在已排序陣列中執行二元搜尋。

JAX 實現的 numpy.searchsorted()

這將傳回已排序陣列 a 內的索引,其中可以插入 v 中的值以維持其排序順序。

參數:
  • a (ArrayLike) – 一維陣列,除非指定 sorter,否則假定為已排序。

  • v (ArrayLike) – 查詢值的 N 維陣列

  • side (str) – 'left' (預設) 或 'right';指定在平局的情況下,插入索引將位於左側還是右側。

  • sorter (ArrayLike | None) – 可選的索引陣列,用於指定 a 的排序順序。如果指定,則演算法假定 a[sorter] 處於已排序狀態。

  • method (str) – 'scan' (預設)、'scan_unrolled''sort''compare_all' 之一。請參閱下面的Note

傳回:

形狀為 v.shape 的插入索引陣列。

傳回類型:

Array

Note

method 參數控制用於計算插入索引的演算法。

  • 'scan' (預設) 在 CPU 上往往效能更高,尤其當 a 非常大時。

  • 'scan_unrolled' 在 GPU 上效能更高,但會犧牲額外的編譯時間。

  • 'sort' 在 GPU 和 TPU 等加速器後端上通常效能更高,尤其當 v 非常大時。

  • 'compare_all'a 非常小時往往效能最高。

範例

搜尋單個值

>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)

搜尋一批值

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)

或者,可以使用 sorter 參數來查找插入到通過 jax.numpy.argsort() 排序的陣列中的索引

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

結果等效於傳遞已排序的陣列

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)