jax.numpy.searchsorted#
- jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[原始碼]#
在已排序陣列中執行二元搜尋。
JAX 實現的
numpy.searchsorted()
。這將傳回已排序陣列
a
內的索引,其中可以插入v
中的值以維持其排序順序。- 參數:
- 傳回:
形狀為
v.shape
的插入索引陣列。- 傳回類型:
Note
method
參數控制用於計算插入索引的演算法。'scan'
(預設) 在 CPU 上往往效能更高,尤其當a
非常大時。'scan_unrolled'
在 GPU 上效能更高,但會犧牲額外的編譯時間。'sort'
在 GPU 和 TPU 等加速器後端上通常效能更高,尤其當v
非常大時。'compare_all'
在a
非常小時往往效能最高。
範例
搜尋單個值
>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5]) >>> jnp.searchsorted(a, 2) Array(1, dtype=int32) >>> jnp.searchsorted(a, 2, side='right') Array(3, dtype=int32)
搜尋一批值
>>> vals = jnp.array([0, 3, 8, 1.5, 2]) >>> jnp.searchsorted(a, vals) Array([0, 3, 7, 1, 1], dtype=int32)
或者,可以使用
sorter
參數來查找插入到通過jax.numpy.argsort()
排序的陣列中的索引>>> a = jnp.array([4, 3, 5, 1, 2]) >>> sorter = jnp.argsort(a) >>> jnp.searchsorted(a, vals, sorter=sorter) Array([0, 2, 5, 1, 1], dtype=int32)
結果等效於傳遞已排序的陣列
>>> jnp.searchsorted(jnp.sort(a), vals) Array([0, 2, 5, 1, 1], dtype=int32)