01-注意力机制详解¶
学习时间: 约6-8小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: 线性代数、神经网络基础、RNN/LSTM基础 学习目标: 深入理解注意力机制的各种形式,掌握自注意力和多头注意力的完整推导与实现
📌 定位说明:本章从深度学习基础视角讲解注意力机制(从Seq2Seq出发的动机、加性/点积/多头注意力推导)。面向大模型的注意力优化(FlashAttention/GQA/MQA/稀疏注意力等)请参考 LLM学习/01-基础巩固/02-注意力机制详解。
目录¶
- 1. 注意力机制动机
- 2. 注意力机制的一般框架
- 3. 加性注意力
- 4. 点积注意力与缩放点积注意力
- 5. 自注意力(Self-Attention)
- 6. 多头注意力(Multi-Head Attention)
- 7. 交叉注意力(Cross-Attention)
- 8. 注意力可视化
- 9. 注意力的变体与优化
- 10. 练习与自我检查
1. 注意力机制动机¶
1.1 信息瓶颈问题¶
在经典 Seq2Seq 模型中,编码器将整个输入序列压缩为一个固定维度的向量 \(c\)。当输入序列很长时,这个向量无法承载所有信息 — 这就是信息瓶颈。
注意力机制让解码器在每一步都能"回看"编码器的所有输出,根据当前需要动态选择关注哪些输入位置。
1.2 人类注意力的类比¶
当你阅读一个英文句子来翻译时,翻译每个词时会"注意"源句子中不同的部分。注意力机制就是让模型学会这种"选择性关注"。
2. 注意力机制的一般框架¶
2.1 Query-Key-Value 框架¶
注意力机制的核心是三元组 (Query, Key, Value):
- Query (Q): 当前要关注什么("我在找什么?")
- Key (K): 各个位置的索引("我有什么可以提供的?")
- Value (V): 各个位置的实际内容("我的实际信息是什么?")
步骤: 1. 计算 Q 与每个 K 之间的相似度分数 2. 通过 Softmax 将分数转换为注意力权重(和为 1) 3. 用权重对 V 做加权求和,得到上下文向量
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def general_attention(query, keys, values, score_fn, mask=None):
"""
通用注意力框架
query: (batch, 1, d_q) 或 (batch, seq_q, d_q)
keys: (batch, seq_k, d_k)
values: (batch, seq_k, d_v)
"""
# 1. 计算注意力分数
scores = score_fn(query, keys) # (batch, seq_q, seq_k)
# 2. 掩码(可选)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 3. Softmax 归一化
weights = F.softmax(scores, dim=-1) # (batch, seq_q, seq_k)
# 4. 加权求和
context = torch.bmm(weights, values) # (batch, seq_q, d_v)
return context, weights
3. 加性注意力¶
3.1 Bahdanau 注意力¶
Bahdanau et al.(2015)提出的加性注意力:
其中 \(W_q \in \mathbb{R}^{d_a \times d_q}\), \(W_k \in \mathbb{R}^{d_a \times d_k}\), \(v \in \mathbb{R}^{d_a}\) 是可学习参数,\(d_a\) 是注意力维度。
class AdditiveAttention(nn.Module): # 继承nn.Module定义神经网络层
"""加性(Bahdanau)注意力"""
def __init__(self, query_dim, key_dim, attn_dim): # __init__构造方法,创建对象时自动调用
super().__init__() # super()调用父类方法
self.W_q = nn.Linear(query_dim, attn_dim, bias=False)
self.W_k = nn.Linear(key_dim, attn_dim, bias=False)
self.v = nn.Linear(attn_dim, 1, bias=False)
def forward(self, query, keys, values, mask=None):
"""
query: (batch, 1, query_dim) — 单个查询位置
keys: (batch, seq_len, key_dim)
values: (batch, seq_len, value_dim)
"""
# (batch, 1, attn_dim) + (batch, seq_len, attn_dim) → 广播相加
scores = self.v(torch.tanh(self.W_q(query) + self.W_k(keys)))
scores = scores.squeeze(-1) # (batch, seq_len) # squeeze去除大小为1的维度
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
weights = F.softmax(scores, dim=-1)
context = torch.bmm(weights.unsqueeze(1), values) # (batch, 1, value_dim)
return context, weights
# 测试
attn = AdditiveAttention(query_dim=256, key_dim=512, attn_dim=128)
q = torch.randn(4, 1, 256)
k = torch.randn(4, 20, 512)
v = torch.randn(4, 20, 512)
context, weights = attn(q, k, v)
print(f"Context: {context.shape}, Weights: {weights.shape}")
4. 点积注意力与缩放点积注意力¶
4.1 点积注意力¶
简单但要求 \(d_q = d_k\)。
4.2 缩放点积注意力(Scaled Dot-Product Attention)¶
Vaswani et al.(2017)在 Transformer 中使用的注意力形式:
为什么要除以 \(\sqrt{d_k}\)?
当 \(d_k\) 较大时,\(q^Tk = \sum_{i=1}^{d_k} q_i k_i\) 的方差约为 \(d_k\)(假设 \(q_i, k_i\) 独立同分布,均值 0,方差 1),这使得某些 softmax 输出极端趋近 0 或 1,梯度几乎为零。除以 \(\sqrt{d_k}\) 将方差恢复到 1。
4.3 完整推导与实现¶
class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力 — Transformer 的核心"""
def __init__(self, dropout=0.0):
super().__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
"""
Q: (batch, num_heads, seq_q, d_k)
K: (batch, num_heads, seq_k, d_k)
V: (batch, num_heads, seq_k, d_v)
mask: (batch, 1, 1, seq_k) 或 (batch, 1, seq_q, seq_k)
"""
d_k = Q.size(-1)
# Step 1: 计算注意力分数
# QK^T: (batch, num_heads, seq_q, seq_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# Step 2: 应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Step 3: Softmax 得到注意力权重
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# Step 4: 加权求和
context = torch.matmul(attn_weights, V)
return context, attn_weights
# 测试
sdpa = ScaledDotProductAttention()
Q = torch.randn(2, 8, 10, 64) # batch=2, 8 heads, seq=10, d_k=64
K = torch.randn(2, 8, 20, 64) # seq_k=20
V = torch.randn(2, 8, 20, 64)
context, weights = sdpa(Q, K, V)
print(f"Context: {context.shape}") # (2, 8, 10, 64)
print(f"Weights: {weights.shape}") # (2, 8, 10, 20)
5. 自注意力(Self-Attention)¶
5.1 原理¶
自注意力中 Q、K、V 都来自同一个序列。每个位置都能关注序列中的所有位置(包括自身),捕获序列内部的依赖关系。
对于输入序列 \(X \in \mathbb{R}^{n \times d}\):
5.2 自注意力 vs RNN¶
| 特性 | 自注意力 | RNN |
|---|---|---|
| 并行计算 | ✅ 完全并行 | ❌ 必须顺序 |
| 长距离依赖 | O(1) 连接路径 | O(n) 连接路径 |
| 计算复杂度 | \(O(n^2 d)\) | \(O(n d^2)\) |
| 适合长序列 | \(n < d\) 时高效 | \(n > d\) 时高效 |
5.3 实现¶
class SelfAttention(nn.Module):
"""单头自注意力"""
def __init__(self, d_model, d_k=None, d_v=None, dropout=0.1):
super().__init__()
d_k = d_k or d_model
d_v = d_v or d_model
self.W_Q = nn.Linear(d_model, d_k)
self.W_K = nn.Linear(d_model, d_k)
self.W_V = nn.Linear(d_model, d_v)
self.scale = math.sqrt(d_k)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
x: (batch, seq_len, d_model)
"""
Q = self.W_Q(x) # (batch, seq_len, d_k)
K = self.W_K(x) # (batch, seq_len, d_k)
V = self.W_V(x) # (batch, seq_len, d_v)
# 注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
weights = self.dropout(F.softmax(scores, dim=-1))
context = torch.matmul(weights, V)
return context, weights
# 测试
self_attn = SelfAttention(d_model=512)
x = torch.randn(4, 20, 512)
context, weights = self_attn(x)
print(f"自注意力: context={context.shape}, weights={weights.shape}")
# weights[0, i, j] 表示位置 i 对位置 j 的注意力权重
6. 多头注意力(Multi-Head Attention)¶
6.1 动机¶
单头注意力只能学习一种注意力模式。多头注意力将注意力分成多个"头",每个头可以关注不同的表示子空间和不同的位置关系。
6.2 数学公式¶
其中: $\(\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\)$
参数:\(W_i^Q \in \mathbb{R}^{d_{model} \times d_k}\), \(W_i^K \in \mathbb{R}^{d_{model} \times d_k}\), \(W_i^V \in \mathbb{R}^{d_{model} \times d_v}\), \(W^O \in \mathbb{R}^{hd_v \times d_{model}}\)
通常 \(d_k = d_v = d_{model} / h\)。
6.3 完整实现¶
class MultiHeadAttention(nn.Module):
"""多头注意力 — 从零实现"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 线性投影(所有头合并在一起计算更高效)
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.scale = math.sqrt(self.d_k)
def split_heads(self, x):
"""将 (batch, seq, d_model) 重塑为 (batch, num_heads, seq, d_k)"""
batch_size, seq_len, _ = x.shape
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2)
def merge_heads(self, x):
"""将 (batch, num_heads, seq, d_k) 重塑为 (batch, seq, d_model)"""
batch_size, _, seq_len, _ = x.shape
x = x.transpose(1, 2).contiguous() # 链式调用,连续执行多个方法
return x.view(batch_size, seq_len, self.d_model)
def forward(self, query, key, value, mask=None):
"""
query: (batch, seq_q, d_model)
key: (batch, seq_k, d_model)
value: (batch, seq_k, d_model)
mask: (batch, 1, 1, seq_k) 或 (batch, 1, seq_q, seq_k)
"""
# 1. 线性投影
Q = self.split_heads(self.W_Q(query)) # (batch, heads, seq_q, d_k)
K = self.split_heads(self.W_K(key)) # (batch, heads, seq_k, d_k)
V = self.split_heads(self.W_V(value)) # (batch, heads, seq_k, d_k)
# 2. 缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = self.dropout(F.softmax(scores, dim=-1))
context = torch.matmul(attn_weights, V)
# 3. 合并多头
context = self.merge_heads(context) # (batch, seq_q, d_model)
# 4. 输出投影
output = self.W_O(context)
return output, attn_weights
# 测试
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(4, 20, 512)
output, weights = mha(x, x, x) # 自注意力:Q=K=V=x
print(f"多头注意力: output={output.shape}, weights={weights.shape}")
# output: (4, 20, 512), weights: (4, 8, 20, 20)
7. 交叉注意力(Cross-Attention)¶
交叉注意力中 Q 来自一个序列(通常是解码器),K 和 V 来自另一个序列(通常是编码器)。这是 Transformer 解码器中连接编码器和解码器的关键机制。
# 交叉注意力的使用(在 Transformer 解码器中)
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# 自注意力(带因果掩码)
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.norm1 = nn.LayerNorm(d_model)
# 交叉注意力 — Q 来自解码器,K/V 来自编码器
self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.norm2 = nn.LayerNorm(d_model)
# 前馈网络
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
# 1. 自注意力
attn_out, _ = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_out))
# 2. 交叉注意力 — Q=x(解码器), K=V=enc_output(编码器)
cross_out, cross_weights = self.cross_attn(x, enc_output, enc_output, src_mask)
x = self.norm2(x + self.dropout(cross_out))
# 3. 前馈网络
ffn_out = self.ffn(x)
x = self.norm3(x + self.dropout(ffn_out))
return x, cross_weights
8. 注意力可视化¶
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def visualize_attention(attention_weights, src_tokens, tgt_tokens=None, head_idx=0):
"""
可视化注意力权重
attention_weights: (num_heads, tgt_len, src_len)
"""
if tgt_tokens is None:
tgt_tokens = src_tokens
# 选择某个头
weights = attention_weights[head_idx].detach().cpu().numpy()
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(weights, xticklabels=src_tokens, yticklabels=tgt_tokens,
cmap='Blues', annot=True, fmt='.2f', ax=ax)
ax.set_title(f'Attention Weights (Head {head_idx})')
ax.set_xlabel('Source')
ax.set_ylabel('Target')
plt.tight_layout()
plt.show()
def visualize_all_heads(attention_weights, src_tokens, tgt_tokens=None, ncols=4):
"""可视化所有头的注意力"""
num_heads = attention_weights.shape[0]
nrows = (num_heads + ncols - 1) // ncols
fig, axes = plt.subplots(nrows, ncols, figsize=(5*ncols, 5*nrows))
if tgt_tokens is None:
tgt_tokens = src_tokens
for idx in range(num_heads):
ax = axes[idx // ncols, idx % ncols] if nrows > 1 else axes[idx % ncols]
weights = attention_weights[idx].detach().cpu().numpy()
sns.heatmap(weights, ax=ax, cmap='Blues', xticklabels=src_tokens,
yticklabels=tgt_tokens if idx % ncols == 0 else False)
ax.set_title(f'Head {idx}')
plt.tight_layout()
plt.show()
# 使用示例
# tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat']
# weights = torch.randn(8, 6, 6).softmax(dim=-1) # 8 heads
# visualize_attention(weights, tokens, head_idx=0)
# visualize_all_heads(weights, tokens)
9. 注意力的变体与优化¶
9.0 缩放点积注意力的梯度推导¶
设 \(S = \frac{QK^T}{\sqrt{d_k}}\),\(A = \text{softmax}(S)\)(逐行),\(O = AV\)。已知上游梯度 \(\frac{\partial \mathcal{L}}{\partial O}\)(记为 \(dO\)),反向传播求 \(dQ, dK, dV\)。
Step 1: 对 V 的梯度
Step 2: 对 A 的梯度
Step 3: 通过 softmax 反向传播(逐行,设第 \(i\) 行的 softmax 输出为 \(a\),上游梯度为 \(da\)):
即 \(ds_{ij} = a_{ij}(da_{ij} - \sum_k a_{ik} \cdot da_{ik})\)。这是 softmax 反向传播的标准公式。
Step 4: 对 Q 和 K 的梯度
实践要点:标准实现中 softmax 反向传播的中间矩阵 \(dS\) 大小为 \(O(n^2)\),是注意力的显存瓶颈。Flash Attention 通过分块计算和在线 softmax 重计算避免了显式存储 \(S\) 和 \(dS\),将显存从 \(O(n^2)\) 降到 \(O(n)\)。
9.1 常见变体¶
| 变体 | 复杂度 | 思想 |
|---|---|---|
| 标准注意力 | \(O(n^2)\) | 全注意力 |
| 稀疏注意力 | \(O(n\sqrt{n})\) | 只关注固定模式的位置 |
| 线性注意力 | \(O(n)\) | 用核函数近似 softmax |
| Flash Attention | \(O(n^2)\), 低显存 | IO 感知的精确注意力 |
| 滑动窗口 | \(O(nw)\) | 只关注局部窗口 |
| 分组查询注意力(GQA) | \(O(n^2)\), 低显存 | K/V 共享组 |
9.2 Flash Attention¶
# PyTorch 2.0+ 内置了 Flash Attention
# 使用 scaled_dot_product_attention(自动选择最优实现)
from torch.nn.functional import scaled_dot_product_attention
Q = torch.randn(2, 8, 1024, 64, device='cuda')
K = torch.randn(2, 8, 1024, 64, device='cuda')
V = torch.randn(2, 8, 1024, 64, device='cuda')
# 自动使用 Flash Attention(如果满足条件)
output = scaled_dot_product_attention(Q, K, V, is_causal=True)
print(f"Flash Attention 输出: {output.shape}")
10. 练习与自我检查¶
练习题¶
- 实现对比:分别实现加性注意力和缩放点积注意力,在同一任务上对比效率。
- 自注意力:从零实现单头自注意力,加入掩码支持,在简单序列分类上测试。
- 多头注意力:从零实现多头注意力,验证不同头学到了不同的注意力模式。
- 注意力可视化:训练一个带注意力的翻译模型,可视化并解读注意力热力图。
- 缩放分析:实验验证不缩放 (\(1/\sqrt{d_k}\)) 时梯度会趋向零的现象。
自我检查清单¶
- 理解注意力机制解决了什么问题
- 能区分 Q、K、V 的各自角色
- 理解为什么需要 \(1/\sqrt{d_k}\) 缩放
- 能从零实现缩放点积注意力
- 理解多头注意力的动机和具体操作
- 能区分自注意力和交叉注意力
- 了解 Flash Attention 等效率优化方法
下一篇: 02-Transformer架构 — 完整的Transformer架构详解




