jax.nn.dot_product_attention#
- jax.nn.dot_product_attention(query, key, value, bias=None, mask=None, *, scale=None, is_causal=False, query_seq_lengths=None, key_value_seq_lengths=None, local_window_size=None, implementation=None)[原始碼]#
縮放點積注意力函數。
計算 Query、Key 和 Value 張量上的注意力函數
\[\mathrm{Attention}(Q, K, V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V\]如果我們將
logits
定義為 \(QK^T\) 的輸出,而將probs
定義為 \(softmax\) 的輸出。在本函數中,我們使用以下大寫字母來表示陣列的形狀
B = batch size S = length of the key/value (source) T = length of the query (target) N = number of attention heads H = dimensions of each attention head K = number of key/value heads G = number of groups, which equals to N // K
- 參數:
query (ArrayLike) – 查詢陣列;形狀
(BTNH|TNH)
key (ArrayLike) – 鍵值陣列:形狀
(BSKH|SKH)
。當 K 等於 N 時,執行多頭注意力 (MHA https://arxiv.org/abs/1706.03762)。否則,如果 N 是 K 的倍數,則執行分組查詢注意力 (GQA https://arxiv.org/abs/2305.13245);如果 K == 1 (GQA 的特殊情況),則執行多查詢注意力 (MQA https://arxiv.org/abs/1911.02150)。value (ArrayLike) – 值陣列,應具有與 key 陣列相同的形狀。
bias (ArrayLike | None | None) – 選項,要新增到 logits 的偏差陣列;形狀必須為 4D 且可廣播到
(BNTS|NTS)
。mask (ArrayLike | None | None) – 選項,用於過濾 logits 的遮罩陣列。它是一個布林遮罩,其中 True 表示元素應參與注意力機制。對於加法遮罩,使用者應將其傳遞給 bias。形狀必須為 4D 且可廣播到
(BNTS|NTS)
。scale (float | None | None) – logits 的縮放比例。如果為 None,則縮放比例將設定為 1 除以查詢頭部維度(即 H)的平方根。
is_causal (bool) – 如果為 true,將應用因果注意力機制。請注意,某些實作(如 xla)將產生遮罩張量並將其應用於 logits,以遮罩掉注意力矩陣的非因果部分,但其他實作(如 cudnn)將避免計算非因果區域,從而提供加速。
query_seq_lengths (ArrayLike | None | None) – int32 查詢的序列長度陣列;形狀
(B)
key_value_seq_lengths (ArrayLike | None | None) – int32 鍵值和值的序列長度陣列;形狀
(B)
local_window_size (int | tuple[int, int] | None | None) – 使自我注意力機制關注每個 token 的局部視窗的視窗大小。如果設定,這將指定每個 token 的 (left_window_size, right_window_size)。例如,如果 local_window_size == (3, 2) 且序列為 [0, 1, 2, 3, 4, 5, c, 7, 8, 9],則 token c 可以關注 [3, 4, 5, c, 7, 8]。如果給定單個整數,它將被解釋為對稱視窗 (window_size, window_size)。
implementation (Literal['xla', 'cudnn'] | None | None) – 控制要使用哪個實作後端的字串。支援的字串為 xla、cudnn (cuDNN 快閃注意力)。預設值為 None,這將自動選擇最佳可用後端。請注意,cudnn 僅支援形狀/dtypes 的子集,如果不支持,則會擲出例外。
- 返回:
與
query
形狀相同的注意力輸出陣列。- 返回類型: