jax.lax.approx_min_k#

jax.lax.approx_min_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)[原始碼]#

以近似方式傳回 operand 的最小 k 值及其索引。

演算法詳細資訊請參閱 https://arxiv.org/abs/2206.14286

參數:
  • operand (Array) – 搜尋最小 k 值的陣列。必須為浮點數類型。

  • k (int) – 指定最小 k 值的數量。

  • reduction_dimension (int) – 沿哪個整數維度搜尋。預設值:-1。

  • recall_target (float) – 近似的召回目標。

  • reduction_input_size_override (int) – 當設定為正值時,它會覆寫由 operand[reduction_dim] 決定的尺寸,以評估召回率。當給定的運算元只是 SPMD 或分散式管線中整體計算的子集時,此選項非常有用,在這些情況下,真實的輸入尺寸無法由 operand 形狀延遲決定。

  • aggregate_to_topk (bool) – 若為 true,則將近似結果彙總到排序順序中的前 k 個。若為 false,則傳回未排序的近似結果。在這種情況下,近似結果的數量是實作定義的,且大於或等於指定的 k

傳回:

包含兩個陣列的元組。這些陣列是最小的 k 值,以及沿輸入 operandreduction_dimension 的對應索引。陣列的維度與輸入 operand 相同,但 reduction_dimension 除外:當 aggregate_to_topk 為 true 時,縮減維度為 k;否則,它大於等於 k,其中尺寸是實作定義的。

傳回類型:

tuple[Array, Array]

我們鼓勵使用者使用 jit 包裝 approx_min_k。請參閱以下範例,了解平方 l2 距離上的最近鄰搜尋

>>> import functools
>>> import jax
>>> import numpy as np
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
...   dists = half_db_norms - jax.lax.dot(qy, db.transpose())
...   return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
>>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2
>>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)

在上面的範例中,我們計算 db^2/2 - dot(qy, db^T) 而不是 qy^2 - 2 dot(qy, db^T) + db^2 是為了效能考量。前者使用的算術運算較少,並產生相同的鄰居集合。