jax.nn.dot_product_attention#

jax.nn.dot_product_attention(query, key, value, bias=None, mask=None, *, scale=None, is_causal=False, query_seq_lengths=None, key_value_seq_lengths=None, local_window_size=None, implementation=None)[原始碼]#

縮放點積注意力函數。

計算 Query、Key 和 Value 張量上的注意力函數

\[\mathrm{Attention}(Q, K, V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V\]

如果我們將 logits 定義為 \(QK^T\) 的輸出,而將 probs 定義為 \(softmax\) 的輸出。

在本函數中,我們使用以下大寫字母來表示陣列的形狀

B = batch size
S = length of the key/value (source)
T = length of the query (target)
N = number of attention heads
H = dimensions of each attention head
K = number of key/value heads
G = number of groups, which equals to N // K
參數:
  • query (ArrayLike) – 查詢陣列;形狀 (BTNH|TNH)

  • key (ArrayLike) – 鍵值陣列:形狀 (BSKH|SKH)。當 K 等於 N 時,執行多頭注意力 (MHA https://arxiv.org/abs/1706.03762)。否則,如果 NK 的倍數,則執行分組查詢注意力 (GQA https://arxiv.org/abs/2305.13245);如果 K == 1 (GQA 的特殊情況),則執行多查詢注意力 (MQA https://arxiv.org/abs/1911.02150)。

  • value (ArrayLike) – 值陣列,應具有與 key 陣列相同的形狀。

  • bias (ArrayLike | None | None) – 選項,要新增到 logits 的偏差陣列;形狀必須為 4D 且可廣播到 (BNTS|NTS)

  • mask (ArrayLike | None | None) – 選項,用於過濾 logits 的遮罩陣列。它是一個布林遮罩,其中 True 表示元素應參與注意力機制。對於加法遮罩,使用者應將其傳遞給 bias。形狀必須為 4D 且可廣播到 (BNTS|NTS)

  • scale (float | None | None) – logits 的縮放比例。如果為 None,則縮放比例將設定為 1 除以查詢頭部維度(即 H)的平方根。

  • is_causal (bool) – 如果為 true,將應用因果注意力機制。請注意,某些實作(如 xla)將產生遮罩張量並將其應用於 logits,以遮罩掉注意力矩陣的非因果部分,但其他實作(如 cudnn)將避免計算非因果區域,從而提供加速。

  • query_seq_lengths (ArrayLike | None | None) – int32 查詢的序列長度陣列;形狀 (B)

  • key_value_seq_lengths (ArrayLike | None | None) – int32 鍵值和值的序列長度陣列;形狀 (B)

  • local_window_size (int | tuple[int, int] | None | None) – 使自我注意力機制關注每個 token 的局部視窗的視窗大小。如果設定,這將指定每個 token 的 (left_window_size, right_window_size)。例如,如果 local_window_size == (3, 2) 且序列為 [0, 1, 2, 3, 4, 5, c, 7, 8, 9],則 token c 可以關注 [3, 4, 5, c, 7, 8]。如果給定單個整數,它將被解釋為對稱視窗 (window_size, window_size)。

  • implementation (Literal['xla', 'cudnn'] | None | None) – 控制要使用哪個實作後端的字串。支援的字串為 xlacudnn (cuDNN 快閃注意力)。預設值為 None,這將自動選擇最佳可用後端。請注意,cudnn 僅支援形狀/dtypes 的子集,如果不支持,則會擲出例外。

返回:

query 形狀相同的注意力輸出陣列。

返回類型:

Array