jax.lax.approx_max_k#
- jax.lax.approx_max_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)[原始碼]#
以近似的方式傳回
operand
的 maxk
值及其索引。演算法詳細資訊請參閱 https://arxiv.org/abs/2206.14286。
- 參數:
operand (Array) – 搜尋 max-k 的陣列。必須是浮點數類型。
k (int) – 指定 max-k 的數量。
reduction_dimension (int) – 沿著哪個整數維度搜尋。預設值:-1。
recall_target (float) – 近似的召回目標。
reduction_input_size_override (int) – 當設定為正值時,它會覆寫由
operand[reduction_dim]
決定的尺寸,以評估召回率。當給定的operand
只是 SPMD 或分散式管線中整體計算的子集時,此選項非常有用,在這些情況下,真實的輸入尺寸無法由 operand 形狀延遲決定。aggregate_to_topk (bool) – 為 true 時,將近似結果聚合到已排序順序中的 top-k。為 false 時,傳回未排序的近似結果。在這種情況下,近似結果的數量是實作定義的,並且大於或等於指定的
k
。
- 返回值:
兩個陣列的元組。這些陣列是輸入
operand
的 maxk
值以及沿著reduction_dimension
的對應索引。這些陣列的維度與輸入operand
相同,但reduction_dimension
除外:當aggregate_to_topk
為 true 時,reduction dimension 為k
;否則,它大於等於k
,其中大小是實作定義的。- 傳回類型:
我們鼓勵使用者使用 jit 包裝
approx_max_k
。請參閱以下最大化內部乘積搜尋 (MIPS) 的範例>>> import functools >>> import jax >>> import numpy as np >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"]) ... def mips(qy, db, k=10, recall_target=0.95): ... dists = jax.lax.dot(qy, db.transpose()) ... # returns (f32[qy_size, k], i32[qy_size, k]) ... return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target) >>> >>> qy = jax.numpy.array(np.random.rand(50, 64)) >>> db = jax.numpy.array(np.random.rand(1024, 64)) >>> dot_products, neighbors = mips(qy, db, k=10)