jax.lax.approx_max_k#

jax.lax.approx_max_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)[原始碼]#

以近似的方式傳回 operand 的 max k 值及其索引。

演算法詳細資訊請參閱 https://arxiv.org/abs/2206.14286

參數:
  • operand (Array) – 搜尋 max-k 的陣列。必須是浮點數類型。

  • k (int) – 指定 max-k 的數量。

  • reduction_dimension (int) – 沿著哪個整數維度搜尋。預設值:-1。

  • recall_target (float) – 近似的召回目標。

  • reduction_input_size_override (int) – 當設定為正值時,它會覆寫由 operand[reduction_dim] 決定的尺寸,以評估召回率。當給定的 operand 只是 SPMD 或分散式管線中整體計算的子集時,此選項非常有用,在這些情況下,真實的輸入尺寸無法由 operand 形狀延遲決定。

  • aggregate_to_topk (bool) – 為 true 時,將近似結果聚合到已排序順序中的 top-k。為 false 時,傳回未排序的近似結果。在這種情況下,近似結果的數量是實作定義的,並且大於或等於指定的 k

返回值:

兩個陣列的元組。這些陣列是輸入 operand 的 max k 值以及沿著 reduction_dimension 的對應索引。這些陣列的維度與輸入 operand 相同,但 reduction_dimension 除外:當 aggregate_to_topk 為 true 時,reduction dimension 為 k;否則,它大於等於 k,其中大小是實作定義的。

傳回類型:

tuple[Array, Array]

我們鼓勵使用者使用 jit 包裝 approx_max_k。請參閱以下最大化內部乘積搜尋 (MIPS) 的範例

>>> import functools
>>> import jax
>>> import numpy as np
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def mips(qy, db, k=10, recall_target=0.95):
...   dists = jax.lax.dot(qy, db.transpose())
...   # returns (f32[qy_size, k], i32[qy_size, k])
...   return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
>>> dot_products, neighbors = mips(qy, db, k=10)