jax.lax.approx_min_k#
- jax.lax.approx_min_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)[原始碼]#
以近似方式傳回
operand
的最小k
值及其索引。演算法詳細資訊請參閱 https://arxiv.org/abs/2206.14286。
- 參數:
operand (Array) – 搜尋最小 k 值的陣列。必須為浮點數類型。
k (int) – 指定最小 k 值的數量。
reduction_dimension (int) – 沿哪個整數維度搜尋。預設值:-1。
recall_target (float) – 近似的召回目標。
reduction_input_size_override (int) – 當設定為正值時,它會覆寫由
operand[reduction_dim]
決定的尺寸,以評估召回率。當給定的運算元只是 SPMD 或分散式管線中整體計算的子集時,此選項非常有用,在這些情況下,真實的輸入尺寸無法由operand
形狀延遲決定。aggregate_to_topk (bool) – 若為 true,則將近似結果彙總到排序順序中的前 k 個。若為 false,則傳回未排序的近似結果。在這種情況下,近似結果的數量是實作定義的,且大於或等於指定的
k
。
- 傳回:
包含兩個陣列的元組。這些陣列是最小的
k
值,以及沿輸入operand
的reduction_dimension
的對應索引。陣列的維度與輸入operand
相同,但reduction_dimension
除外:當aggregate_to_topk
為 true 時,縮減維度為k
;否則,它大於等於k
,其中尺寸是實作定義的。- 傳回類型:
我們鼓勵使用者使用 jit 包裝
approx_min_k
。請參閱以下範例,了解平方 l2 距離上的最近鄰搜尋>>> import functools >>> import jax >>> import numpy as np >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"]) ... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95): ... dists = half_db_norms - jax.lax.dot(qy, db.transpose()) ... return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target) >>> >>> qy = jax.numpy.array(np.random.rand(50, 64)) >>> db = jax.numpy.array(np.random.rand(1024, 64)) >>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2 >>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)
在上面的範例中,我們計算
db^2/2 - dot(qy, db^T)
而不是qy^2 - 2 dot(qy, db^T) + db^2
是為了效能考量。前者使用的算術運算較少,並產生相同的鄰居集合。