02 - 注意力机制详解(全面版)¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习目标:深入理解注意力机制的各种变体、数学原理、优化技术和可视化分析方法。
📌 定位说明:本章侧重大模型中的注意力优化与前沿变体(FlashAttention/GQA/MQA/稀疏注意力等)。注意力机制的基础教学(从Seq2Seq出发的动机、加性/点积/多头注意力推导)请参考 深度学习/04-Transformer/01-注意力机制详解。
目录¶
注意力机制的统一框架¶
1.1 注意力机制的通用形式¶
所有注意力机制都可以统一表示为以下形式:
Text Only
Attention(Q, K, V) = f(Q, K) · V
其中:
- Q (Query): 查询向量,表示"我在寻找什么"
- K (Key): 键向量,表示"我有什么"
- V (Value): 值向量,表示"实际内容是什么"
- f(Q, K): 相似度函数,计算查询与键的匹配程度
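上述统一形式可以直接用几行PyTorch写出。下面是一个以最常用的缩放点积作为 f 的示意实现,张量形状均为演示用:

Python

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product(Q, K, V):
    """最小示意:f(Q, K) = softmax(QK^T / sqrt(d_k))"""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # [batch, q_len, k_len]
    weights = F.softmax(scores, dim=-1)                # 每行权重和为1
    return weights @ V, weights

Q = torch.randn(2, 4, 8)  # [batch, q_len, d_k]
K = torch.randn(2, 6, 8)  # [batch, k_len, d_k]
V = torch.randn(2, 6, 8)  # [batch, k_len, d_v]
out, w = scaled_dot_product(Q, K, V)
```

输出形状为 [batch, q_len, d_v]:每个query位置得到一个对value的加权平均。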
1.2 相似度函数对比¶
| 机制 | 相似度函数 | 时间复杂度 | 空间复杂度 | 特点 |
|---|---|---|---|---|
| 加性注意力 | score = v^T tanh(W_Q·Q + W_K·K) | O(n²d) | O(n²) | 灵活,适合不同维度 |
| 点积注意力 | score = Q·K^T | O(n²d) | O(n²) | 简单快速,需维度匹配 |
| 缩放点积 | score = (Q·K^T) / √d_k | O(n²d) | O(n²) | 最常用,数值稳定 |
| 双线性注意力 | score = Q·W·K^T | O(n²d + nd²) | O(n²) | 可学习相似度 |
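下面的小例子把表中四种相似度函数各算一遍,便于对照形状与缩放关系(维度均为示意值):

Python

```python
import math
import torch

torch.manual_seed(0)
n, d, h = 5, 8, 16
Q, K = torch.randn(n, d), torch.randn(n, d)

# 点积与缩放点积
dot = Q @ K.T                # [n, n]
scaled = dot / math.sqrt(d)  # 缩放仅改变数值尺度,不改变形状

# 双线性:可学习的W引入Q、K间的交互
W = torch.randn(d, d)
bilinear = Q @ W @ K.T

# 加性:score[i,j] = v^T tanh(W_q q_i + W_k k_j),可处理不同维度的Q和K
W_q, W_k, v = torch.randn(h, d), torch.randn(h, d), torch.randn(h)
additive = torch.tanh((Q @ W_q.T).unsqueeze(1) + (K @ W_k.T).unsqueeze(0)) @ v
```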
1.3 注意力作为软寻址¶
Text Only
注意力机制可以看作是一种"软寻址":
硬寻址(如数据库查询):
- 精确匹配某个key
- 返回对应的value
软寻址(注意力):
- 计算query与所有key的相似度
- 返回所有value的加权平均
- 权重由相似度决定
类比:
- 硬寻址:在图书馆精确找到某本书
- 软寻址:根据主题相关性,从多本书中提取信息
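软寻址可以用一个极小的数值例子直观体会:query 与 key0 最相似,输出因此最偏向 value0,但仍混合了其余 value:

Python

```python
import torch
import torch.nn.functional as F

# 三个 key-value 对;query 与第0个key方向相同
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
values = torch.tensor([[10.0], [20.0], [30.0]])
query = torch.tensor([1.0, 0.0])

sims = keys @ query               # 相似度: [1.0, 0.0, 0.7]
weights = F.softmax(sims, dim=0)  # 软寻址权重,和为1
result = weights @ values         # 所有value的加权平均
```

结果既不是精确的10(硬寻址),也不是简单平均,而是按相似度加权的混合值。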
上图展示了自注意力机制的核心流程。输入token通过权重矩阵W转换为查询(Query)、键(Key)和值(Value),然后通过计算查询和键的相似度得到注意力权重,最后用这些权重对值进行加权聚合,得到最终的输出。
自注意力深度剖析¶
2.1 自注意力的信息流动¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttentionAnalyzer:
"""
自注意力分析工具
"""
@staticmethod # @staticmethod无需实例即可调用
def analyze_information_flow(attention_weights, tokens):
"""
分析自注意力中的信息流动
Args:
attention_weights: [seq_len, seq_len]
tokens: 词列表
"""
seq_len = len(tokens)
print("=" * 60)
print("自注意力信息流动分析")
print("=" * 60)
for i, token in enumerate(tokens): # enumerate同时获取索引和元素
# 获取当前token的注意力分布
attn_dist = attention_weights[i]
# 找到最关注的token
top_k = 3
top_indices = torch.topk(attn_dist, top_k).indices
print(f"\n位置 {i} ('{token}') 最关注:")
            for idx in top_indices.tolist():  # 转为Python int列表,才能索引tokens
                print(f"  - 位置 {idx} ('{tokens[idx]}'): {attn_dist[idx]:.3f}")
# 计算熵(注意力分布的集中度)
entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum()
print(f" 注意力熵: {entropy:.3f} (越低越集中)")
@staticmethod
def compute_attention_entropy(attention_weights):
"""
计算注意力分布的熵
熵高 = 注意力分散(关注多个位置)
熵低 = 注意力集中(关注特定位置)
"""
entropy = -(attention_weights * torch.log(attention_weights + 1e-10)).sum(dim=-1)
return entropy.mean()
@staticmethod
def compute_attention_sparsity(attention_weights, threshold=0.1):
"""
计算注意力稀疏度
返回注意力权重大于threshold的比例
"""
sparsity = (attention_weights > threshold).float().mean()
return sparsity
# 示例分析
"""
输入: "The cat sat on the mat"
自注意力后的信息流动:
位置0 'The':
- 可能关注: 'cat' (确定名词)
- 熵: 中等(需要看上下文)
位置1 'cat':
- 可能关注: 'The' (主语), 'sat' (谓语)
- 熵: 较低(关系明确)
位置2 'sat':
- 可能关注: 'cat' (主语), 'on' (介词)
- 熵: 中等
位置3 'on':
- 可能关注: 'sat' (动词), 'mat' (宾语)
- 熵: 较低
位置4 'the':
- 可能关注: 'mat' (确定名词)
- 熵: 中等
位置5 'mat':
- 可能关注: 'on' (介词), 'the' (冠词)
- 熵: 较低
"""
2.2 注意力的梯度流动¶
Python
class AttentionGradientAnalyzer:
"""
分析注意力的梯度流动
"""
@staticmethod
def analyze_gradient_flow(model, input_ids, target_ids):
"""
分析注意力层的梯度
"""
model.train()
# 前向传播
outputs = model(input_ids, labels=target_ids)
loss = outputs.loss
# 反向传播
loss.backward()
# 分析各层的梯度
grad_stats = {}
for name, param in model.named_parameters():
if 'attention' in name and param.grad is not None:
grad_stats[name] = {
'mean': param.grad.mean().item(),
'std': param.grad.std().item(),
'max': param.grad.max().item(),
'min': param.grad.min().item()
}
return grad_stats
注意力变体详解¶
3.1 加性注意力(Additive Attention)¶
Python
class AdditiveAttention(nn.Module):
"""
加性注意力(Bahdanau Attention)
使用一个前馈网络计算相似度
score = v^T * tanh(W_q * Q + W_k * K)
优点:可以处理不同维度的Q和K
"""
def __init__(self, query_dim, key_dim, hidden_dim):
super().__init__() # super()调用父类方法
self.W_q = nn.Linear(query_dim, hidden_dim, bias=False)
self.W_k = nn.Linear(key_dim, hidden_dim, bias=False)
self.v = nn.Linear(hidden_dim, 1, bias=False)
def forward(self, query, key, value, mask=None):
"""
Args:
query: [batch, query_len, query_dim]
key: [batch, key_len, key_dim]
value: [batch, key_len, value_dim]
"""
# 扩展维度以便广播
# query: [batch, query_len, 1, hidden]
# key: [batch, 1, key_len, hidden]
query_proj = self.W_q(query).unsqueeze(2) # unsqueeze增加一个维度
key_proj = self.W_k(key).unsqueeze(1)
# 计算分数
scores = self.v(torch.tanh(query_proj + key_proj)).squeeze(-1)
# scores: [batch, query_len, key_len]
# 应用mask
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax
attn_weights = F.softmax(scores, dim=-1)
# 加权求和
output = torch.bmm(attn_weights, value)
return output, attn_weights
3.2 双线性注意力¶
Python
class BilinearAttention(nn.Module):
"""
双线性注意力
score = Q * W * K^T
可以学习Q和K之间的复杂交互
"""
def __init__(self, query_dim, key_dim):
super().__init__()
self.W = nn.Parameter(torch.randn(query_dim, key_dim))
nn.init.xavier_uniform_(self.W)
def forward(self, query, key, value, mask=None):
"""
Args:
query: [batch, query_len, query_dim]
key: [batch, key_len, key_dim]
"""
# scores[b, i, j] = sum_d query[b, i, d] * W[d, e] * key[b, j, e]
scores = torch.matmul(torch.matmul(query, self.W), key.transpose(-2, -1))
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = torch.bmm(attn_weights, value)
return output, attn_weights
3.3 局部注意力(Local Attention)¶
Python
class LocalAttention(nn.Module):
"""
局部注意力
只关注窗口内的位置,降低计算复杂度
"""
def __init__(self, window_size=128):
super().__init__()
self.window_size = window_size
def forward(self, query, key, value, mask=None):
"""
每个位置只关注前后window_size/2个位置
"""
batch_size, seq_len, dim = query.shape
half_window = self.window_size // 2
outputs = []
for i in range(seq_len):
# 确定窗口范围
start = max(0, i - half_window)
end = min(seq_len, i + half_window + 1)
# 提取局部key和value
local_key = key[:, start:end, :]
local_value = value[:, start:end, :]
local_query = query[:, i:i+1, :]
# 计算局部注意力
scores = torch.matmul(local_query, local_key.transpose(-2, -1))
scores = scores / math.sqrt(dim)
attn_weights = F.softmax(scores, dim=-1)
local_output = torch.matmul(attn_weights, local_value)
outputs.append(local_output)
return torch.cat(outputs, dim=1)
稀疏注意力机制¶
4.1 稀疏注意力的动机¶
Text Only
标准自注意力的复杂度是O(n²),当序列长度n很大时:
- n=1K: 1M 计算量
- n=4K: 16M 计算量
- n=32K: 1B 计算量
- n=128K: 16B 计算量
稀疏注意力的目标:
- 保持长距离依赖能力
- 将复杂度降至O(n log n)或O(n)
- 只计算"重要"的注意力对
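上面的计算量可以直接验证,并与滑动窗口注意力对比(窗口大小512为示意值):

Python

```python
ns = [1_000, 4_000, 32_000, 128_000]
window = 512  # 滑动窗口大小(示意值)

full_pairs = [n * n for n in ns]         # O(n²):全量注意力对数
window_pairs = [n * window for n in ns]  # O(n·w):滑动窗口注意力对数
for n, f, w in zip(ns, full_pairs, window_pairs):
    print(f"n={n:>7}: 全量 {f:.2e} 对, 窗口 {w:.2e} 对, 节省 {1 - w / f:.2%}")
```

可以看到序列越长,窗口注意力的相对节省越大:n=128K 时全量注意力对数达 1.6×10¹⁰,窗口方案只需其 0.4%。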
4.2 滑动窗口注意力¶
Python
class SlidingWindowAttention(nn.Module):
"""
滑动窗口注意力(Longformer)
每个token只关注窗口内的邻居
"""
def __init__(self, d_model, num_heads, window_size=512):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.window_size = window_size
self.qkv = nn.Linear(d_model, d_model * 3)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 生成QKV
qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, -1)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# 转置为 [batch, heads, seq_len, head_dim]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
# 创建滑动窗口mask
window_mask = torch.zeros(seq_len, seq_len, device=x.device)
for i in range(seq_len):
start = max(0, i - self.window_size // 2)
end = min(seq_len, i + self.window_size // 2 + 1)
window_mask[i, start:end] = 1
        # 计算注意力(演示实现:先全量计算再mask掉窗口外;高效实现应只计算窗口内的块)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
scores = scores.masked_fill(window_mask.unsqueeze(0).unsqueeze(0) == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
return self.out_proj(output)
4.3 全局-局部混合注意力¶
Python
class GlobalLocalAttention(nn.Module):
"""
全局-局部混合注意力(Longformer)
全局token: 可以关注所有位置,被所有位置关注
局部token: 只关注邻居
"""
def __init__(self, d_model, num_heads, window_size=512, num_global_tokens=16):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.window_size = window_size
self.num_global_tokens = num_global_tokens
self.qkv = nn.Linear(d_model, d_model * 3)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 假设前num_global_tokens个是全局token
global_indices = list(range(self.num_global_tokens))
# 生成QKV
qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, -1)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
# 创建注意力mask
attn_mask = torch.zeros(seq_len, seq_len, device=x.device)
for i in range(seq_len):
if i in global_indices:
# 全局token可以关注所有位置
attn_mask[i, :] = 1
else:
# 局部token只关注邻居和全局token
start = max(0, i - self.window_size // 2)
end = min(seq_len, i + self.window_size // 2 + 1)
attn_mask[i, start:end] = 1
attn_mask[i, global_indices] = 1
# 计算注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
scores = scores.masked_fill(attn_mask.unsqueeze(0).unsqueeze(0) == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
return self.out_proj(output)
4.4 稀疏Transformer模式¶
Python
class SparseTransformerAttention(nn.Module):
"""
稀疏Transformer(Sparse Transformer)
使用稀疏模式:
    - 偶数层(layer_idx % 2 == 0):行注意力(每个token关注同一行)
    - 奇数层:列注意力(每个token关注同一列)
需要将1D序列reshape为2D
"""
def __init__(self, d_model, num_heads, block_size=32):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.block_size = block_size
self.qkv = nn.Linear(d_model, d_model * 3)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, layer_idx=0):
"""
Args:
x: [batch, seq_len, d_model]
layer_idx: 当前层索引,决定使用行还是列注意力
"""
batch_size, seq_len, _ = x.shape
# 计算2D网格大小(向上取整,确保 grid_size² ≥ seq_len)
grid_size = int(math.ceil(math.sqrt(seq_len)))
padding = grid_size * grid_size - seq_len
if padding > 0:
# 填充到完全平方
x = F.pad(x, (0, 0, 0, padding))
# Reshape为2D [batch, height, width, dim]
x_2d = x.reshape(batch_size, grid_size, grid_size, -1)
# 生成QKV
qkv = self.qkv(x_2d).reshape(batch_size, grid_size, grid_size, 3, self.num_heads, -1)
q, k, v = qkv[:, :, :, 0], qkv[:, :, :, 1], qkv[:, :, :, 2]
if layer_idx % 2 == 0:
# 行注意力
q = q.permute(0, 4, 1, 2, 3) # [batch, heads, height, width, head_dim]
k = k.permute(0, 4, 1, 2, 3)
v = v.permute(0, 4, 1, 2, 3)
# 每行内部做注意力
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # 补上缩放因子
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.permute(0, 2, 3, 1, 4).reshape(batch_size, grid_size, grid_size, self.d_model)
else:
# 列注意力
q = q.permute(0, 4, 2, 1, 3) # [batch, heads, width, height, head_dim]
k = k.permute(0, 4, 2, 1, 3)
v = v.permute(0, 4, 2, 1, 3)
# 每列内部做注意力
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # 补上缩放因子
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v)
output = output.permute(0, 3, 2, 1, 4).reshape(batch_size, grid_size, grid_size, self.d_model)
# Reshape回1D
output = output.reshape(batch_size, -1, self.d_model)[:, :seq_len, :]
return self.out_proj(output)
线性注意力与高效注意力¶
5.1 线性注意力原理¶
Text Only
标准注意力: O(n²) 复杂度
Softmax(QK^T)V
线性注意力: O(n) 复杂度
使用核技巧: φ(Q)φ(K)^T V = φ(Q)(φ(K)^T V)
关键洞察:先计算 φ(K)^T V,得到一个 d×d 的矩阵,其大小与序列长度无关,整体复杂度随 n 线性增长!
Python
class LinearAttention(nn.Module):
"""
线性注意力(Katharopoulos et al.)
使用特征映射φ将softmax注意力转化为线性复杂度
"""
def __init__(self, dim, num_heads, feature_dim=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.feature_dim = feature_dim or self.head_dim
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def elu_feature_map(self, x):
"""特征映射:ELU + 1"""
return F.elu(x) + 1
def forward(self, x):
batch, seq_len, dim = x.shape
# 生成QKV
qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# 转置 [batch, heads, seq_len, head_dim]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
# 特征映射
q = self.elu_feature_map(q)
k = self.elu_feature_map(k)
# 线性注意力计算
# KV = Σ_t k_t^T v_t
KV = torch.einsum('bhsk,bhsv->bhkv', k, v)
# Z = Σ_t k_t
Z = k.sum(dim=2)
# 输出 = Q @ KV / (Q @ Z)
numerator = torch.einsum('bhnk,bhkv->bhnv', q, KV)
denominator = torch.einsum('bhnk,bhk->bhn', q, Z).unsqueeze(-1) + 1e-6
out = numerator / denominator
out = out.transpose(1, 2).reshape(batch, seq_len, dim)
return self.proj(out)
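核技巧的等价性可以用一个小例子验证:显式计算 O(n²) 权重,与先算 φ(K)^T V 的线性路径,给出完全相同的输出:

Python

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s, d = 8, 4
q = F.elu(torch.randn(s, d)) + 1  # φ(Q),特征映射后恒正
k = F.elu(torch.randn(s, d)) + 1  # φ(K)
v = torch.randn(s, d)

# 显式路径(O(n²)):先算全部权重再加权
weights = (q @ k.T) / (q @ k.sum(dim=0)).unsqueeze(-1)  # 每行和为1
out_explicit = weights @ v

# 线性路径(O(n)):先算 φ(K)^T V(d×d矩阵),再与 φ(Q) 相乘
out_linear = (q @ (k.T @ v)) / (q @ k.sum(dim=0)).unsqueeze(-1)
```

两条路径数学上等价,但线性路径从不构造 n×n 的权重矩阵,这正是线性注意力复杂度降低的来源。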
5.2 FlashAttention¶
Python
class FlashAttention(nn.Module):
"""
FlashAttention原理说明
FlashAttention不是新的注意力机制,而是IO感知的注意力实现
核心思想:
1. 将输入分块(tiling),避免加载完整的N×N注意力矩阵到GPU HBM
2. 在SRAM(高速缓存)中计算注意力
3. 使用online softmax算法,避免存储中间结果
优势:
- 减少HBM访问次数(从O(N²)到O(N))
- 内存高效(不需要存储N×N注意力矩阵)
- 计算更快(虽然FLOPs相同,但内存访问更少)
实际使用:
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)
"""
def __init__(self, dim, num_heads):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x, causal=True):
"""
这里只是接口说明,实际应使用flash_attn库
"""
batch, seq_len, dim = x.shape
# 生成QKV
qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# 转置为 [batch, heads, seq_len, head_dim]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
# 实际使用时:
# from flash_attn import flash_attn_func
# out = flash_attn_func(q, k, v, causal=causal)
# 这里使用标准注意力作为fallback
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if causal:
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(mask.to(x.device), float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).reshape(batch, seq_len, dim)
return self.proj(out)
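实践中,除 flash_attn 库外,PyTorch 2.x 内置的 F.scaled_dot_product_attention 也会在条件满足时(GPU、合适的dtype等)自动调用FlashAttention等高效kernel,CPU上则回退到数学实现。可以用它与标准实现交叉验证(需要 PyTorch ≥ 2.0,形状均为示意值):

Python

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, s, d = 2, 4, 16, 8
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

# 参考实现:标准缩放点积 + 因果mask(与上文fallback一致)
scores = q @ k.transpose(-2, -1) / math.sqrt(d)
mask = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float('-inf'))
ref = F.softmax(scores, dim=-1) @ v

# PyTorch内置实现:数值上等价,但会按硬件选择高效后端
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```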
现代大模型的注意力优化¶
6.1 MQA (Multi-Query Attention)¶
Python
class MultiQueryAttention(nn.Module):
"""
多查询注意力(Multi-Query Attention)
来自PaLM,所有头共享同一组K和V
优势:
- 大幅减少KV Cache内存占用
- 解码速度更快
劣势:
- 略微降低模型质量
"""
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# Q投影到所有头
self.W_q = nn.Linear(d_model, d_model)
# K,V只投影到一个头(共享)
self.W_k = nn.Linear(d_model, self.head_dim)
self.W_v = nn.Linear(d_model, self.head_dim)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch, seq_len, _ = x.shape
# Q: [batch, seq_len, d_model]
q = self.W_q(x)
q = q.reshape(batch, seq_len, self.num_heads, self.head_dim)
q = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
# K,V: [batch, seq_len, head_dim]
k = self.W_k(x).unsqueeze(1) # [batch, 1, seq_len, head_dim]
v = self.W_v(x).unsqueeze(1) # [batch, 1, seq_len, head_dim]
# 广播K,V到所有头
k = k.expand(-1, self.num_heads, -1, -1)
v = v.expand(-1, self.num_heads, -1, -1)
# 计算注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = out.permute(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
return self.W_o(out)
6.2 GQA (Grouped-Query Attention)¶
Python
class GroupedQueryAttention(nn.Module):
"""
分组查询注意力(Grouped-Query Attention)
    由Ainslie et al. (2023)提出,LLaMA-2等模型采用,是MQA和MHA的折中
多个查询头共享一组K,V
例如:32个头,8个K,V组,每组4个头共享
"""
def __init__(self, d_model, num_heads, num_kv_groups=8):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.num_kv_groups = num_kv_groups
self.head_dim = d_model // num_heads
assert num_heads % num_kv_groups == 0 # assert断言:条件False时抛出AssertionError
self.heads_per_group = num_heads // num_kv_groups
# Q投影到所有头
self.W_q = nn.Linear(d_model, d_model)
# K,V投影到组数
kv_dim = self.head_dim * num_kv_groups
self.W_k = nn.Linear(d_model, kv_dim)
self.W_v = nn.Linear(d_model, kv_dim)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch, seq_len, _ = x.shape
# Q: [batch, seq_len, num_heads, head_dim]
q = self.W_q(x).reshape(batch, seq_len, self.num_heads, self.head_dim)
q = q.permute(0, 2, 1, 3) # [batch, num_heads, seq_len, head_dim]
# K,V: [batch, seq_len, num_kv_groups, head_dim]
k = self.W_k(x).reshape(batch, seq_len, self.num_kv_groups, self.head_dim)
v = self.W_v(x).reshape(batch, seq_len, self.num_kv_groups, self.head_dim)
# 转置
k = k.permute(0, 2, 1, 3) # [batch, num_kv_groups, seq_len, head_dim]
v = v.permute(0, 2, 1, 3)
# 重复K,V以匹配Q的头数
k = k.repeat_interleave(self.heads_per_group, dim=1)
v = v.repeat_interleave(self.heads_per_group, dim=1)
# 计算注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = out.permute(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
return self.W_o(out)
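MQA/GQA 对 KV Cache 的节省可以直接估算。下面的参数均为示意值,不对应任何具体模型(按fp16每元素2字节计):

Python

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    # 每个token每层缓存K和V各 n_kv_heads * head_dim 个元素
    return 2 * seq_len * n_layers * n_kv_heads * head_dim * dtype_bytes

seq_len, n_layers, head_dim = 4096, 32, 128  # 示意参数
mha = kv_cache_bytes(seq_len, n_layers, 32, head_dim)  # MHA:32个KV头
gqa = kv_cache_bytes(seq_len, n_layers, 8, head_dim)   # GQA:8个KV组
mqa = kv_cache_bytes(seq_len, n_layers, 1, head_dim)   # MQA:1组共享KV
print(f"MHA: {mha / 2**30:.3f} GiB, GQA: {gqa / 2**30:.3f} GiB, MQA: {mqa / 2**30:.4f} GiB")
```

在这组示意参数下,GQA 把 KV Cache 压缩到 MHA 的 1/4,MQA 压缩到 1/32:节省比例正是 KV 头数之比。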
注意力可视化与分析¶
7.1 注意力权重可视化¶
Python
import matplotlib.pyplot as plt
import seaborn as sns
class AttentionVisualizer:
"""
注意力可视化工具
"""
@staticmethod
def plot_attention_heatmap(attention_weights, tokens, title="Attention Heatmap"):
"""
绘制注意力热力图
Args:
attention_weights: [seq_len, seq_len]
tokens: token列表
"""
plt.figure(figsize=(10, 8))
sns.heatmap(
attention_weights.cpu().numpy(),
xticklabels=tokens,
yticklabels=tokens,
cmap='viridis',
cbar_kws={'label': 'Attention Weight'}
)
plt.title(title)
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
@staticmethod
def plot_multi_head_attention(attention_weights, tokens, num_heads=8):
"""
绘制多头注意力
Args:
attention_weights: [num_heads, seq_len, seq_len]
"""
        # 图示说明:多头注意力将输入分割到多个子空间,每个头独立计算注意力,
        # 最后将所有头的输出拼接并通过线性投影;不同头可同时关注不同类型的关系和模式。
fig, axes = plt.subplots(2, num_heads // 2, figsize=(20, 8))
axes = axes.flatten()
for i in range(num_heads):
ax = axes[i]
sns.heatmap(
attention_weights[i].cpu().numpy(),
ax=ax,
cmap='viridis',
cbar=False,
xticklabels=False,
yticklabels=False
)
ax.set_title(f'Head {i+1}')
plt.suptitle('Multi-Head Attention Patterns')
plt.tight_layout()
plt.show()
@staticmethod
def plot_attention_rollout(attention_weights_list, tokens):
"""
注意力展开(Attention Rollout)
累积多层注意力,显示从输入到输出的信息流动
"""
# 初始化单位矩阵
rollout = torch.eye(len(tokens))
for attn in attention_weights_list:
# 注意力归一化
attn = attn + torch.eye(len(tokens)) # 添加残差连接
attn = attn / attn.sum(dim=-1, keepdim=True)
# 累积
rollout = torch.matmul(attn, rollout)
plt.figure(figsize=(10, 8))
sns.heatmap(
rollout.cpu().numpy(),
xticklabels=tokens,
yticklabels=tokens,
cmap='viridis'
)
plt.title('Attention Rollout')
plt.show()
7.2 注意力模式分析¶
Python
class AttentionPatternAnalyzer:
"""
注意力模式分析
"""
@staticmethod
def identify_attention_patterns(attention_weights):
"""
识别注意力模式
返回每种模式的占比
"""
seq_len = attention_weights.shape[0]
patterns = {
'diagonal': 0, # 对角线(关注邻近)
'vertical': 0, # 垂直(关注特定位置)
'block': 0, # 块(关注连续区域)
'uniform': 0, # 均匀分布
'sparse': 0 # 稀疏
}
for i in range(seq_len):
attn_dist = attention_weights[i]
# 对角线模式
if i > 0:
diagonal_score = attn_dist[i-1] + (attn_dist[i+1] if i < seq_len-1 else 0)
else:
diagonal_score = attn_dist[i+1] if i < seq_len-1 else 0
# 垂直模式(某个位置特别高)
max_val = attn_dist.max()
vertical_score = max_val
# 均匀度
entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum()
uniform_score = 1.0 / (entropy + 1)
# 稀疏度
sparse_score = (attn_dist > 0.1).float().sum() / seq_len
# 分类
scores = {
'diagonal': diagonal_score,
'vertical': vertical_score,
'uniform': uniform_score,
'sparse': 1 - sparse_score
}
best_pattern = max(scores, key=scores.get)
patterns[best_pattern] += 1
# 归一化
for key in patterns:
patterns[key] /= seq_len
return patterns
@staticmethod
def find_syntactic_patterns(attention_weights, tokens, pos_tags):
"""
发现句法模式
例如:代词-名词、动词-宾语关系
"""
patterns = []
for i, (token, tag) in enumerate(zip(tokens, pos_tags)):
attn_dist = attention_weights[i]
# 找到最关注的token
            top_idx = attn_dist.argmax().item()  # 转为Python int,才能索引tokens/pos_tags
# 分析关系
if tag in ['PRP', 'PRP$']: # 代词
if pos_tags[top_idx] in ['NN', 'NNS']: # 名词
patterns.append({
'type': 'pronoun-noun',
'from': (i, token),
'to': (top_idx, tokens[top_idx]),
'weight': attn_dist[top_idx].item()
})
elif tag in ['VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ']: # 动词
if pos_tags[top_idx] in ['NN', 'NNS', 'PRP']: # 名词或代词
patterns.append({
'type': 'verb-object',
'from': (i, token),
'to': (top_idx, tokens[top_idx]),
'weight': attn_dist[top_idx].item()
})
return patterns
注意力机制的扩展与应用¶
8.1 跨模态注意力¶
Python
class CrossModalAttention(nn.Module):
"""
跨模态注意力(用于多模态模型)
例如:视觉-语言模型中,文本查询关注图像特征
"""
def __init__(self, text_dim, image_dim, hidden_dim):
super().__init__()
self.text_proj = nn.Linear(text_dim, hidden_dim)
self.image_proj_k = nn.Linear(image_dim, hidden_dim)
self.image_proj_v = nn.Linear(image_dim, hidden_dim)
def forward(self, text_features, image_features):
"""
Args:
text_features: [batch, text_len, text_dim]
image_features: [batch, num_patches, image_dim]
"""
# Q来自文本
q = self.text_proj(text_features)
# K,V来自图像
k = self.image_proj_k(image_features)
v = self.image_proj_v(image_features)
# 计算跨模态注意力
scores = torch.matmul(q, k.transpose(-2, -1))
attn = F.softmax(scores, dim=-1)
# 加权图像特征
output = torch.matmul(attn, v)
return output, attn
8.2 图注意力(Graph Attention)¶
Python
class GraphAttentionLayer(nn.Module):
"""
图注意力层(GAT)
用于图结构数据的注意力
"""
def __init__(self, in_features, out_features, num_heads=8):
super().__init__()
self.num_heads = num_heads
self.out_features = out_features
self.W = nn.Linear(in_features, out_features * num_heads)
self.a = nn.Parameter(torch.randn(num_heads, 2 * out_features))
def forward(self, x, adj_matrix):
"""
Args:
x: [num_nodes, in_features]
adj_matrix: [num_nodes, num_nodes] 邻接矩阵
"""
num_nodes = x.size(0)
# 线性变换
h = self.W(x) # [num_nodes, out_features * num_heads]
h = h.view(num_nodes, self.num_heads, -1) # [num_nodes, heads, out_features] # view重塑张量形状
# 计算注意力系数
# 对于每条边(i,j),计算e_ij = LeakyReLU(a^T [Wh_i || Wh_j])
attn_scores = []
for head in range(self.num_heads):
h_head = h[:, head, :] # [num_nodes, out_features]
# 计算所有节点对的注意力
# [num_nodes, 1, out_features] + [1, num_nodes, out_features]
concat = torch.cat([
h_head.unsqueeze(1).expand(-1, num_nodes, -1),
h_head.unsqueeze(0).expand(num_nodes, -1, -1)
], dim=-1) # [num_nodes, num_nodes, 2*out_features]
e = torch.matmul(concat, self.a[head]) # [num_nodes, num_nodes]
e = F.leaky_relu(e)
# Mask掉不存在的边
e = e.masked_fill(adj_matrix == 0, float('-inf'))
attn_scores.append(e)
attn_scores = torch.stack(attn_scores, dim=0) # [heads, num_nodes, num_nodes]
attn = F.softmax(attn_scores, dim=-1)
# 聚合邻居特征
out = torch.matmul(attn, h.transpose(0, 1)) # [heads, num_nodes, out_features]
out = out.transpose(0, 1).reshape(num_nodes, -1) # [num_nodes, heads*out_features]
return out
总结¶
注意力机制选择指南¶
Text Only
根据场景选择注意力机制:
短序列(<1K):
├── 标准自注意力 ✅
└── FlashAttention ✅(更快)
中等序列(1K-8K):
├── 滑动窗口注意力 ✅
├── 稀疏注意力 ✅
└── FlashAttention ✅
长序列(8K-32K):
├── 全局-局部混合 ✅
├── 稀疏Transformer ✅
└── 线性注意力 ✅
超长序列(>32K):
├── 线性注意力 ✅
├── Mamba/State Space ✅
└── 分块注意力 ✅
内存受限:
├── MQA ✅(减少KV Cache)
├── GQA ✅(平衡质量和内存)
└── 线性注意力 ✅(O(n)复杂度)
关键要点¶
- 注意力是软寻址:计算相似度,加权聚合
- 复杂度是瓶颈:O(n²)限制了长序列应用
- 稀疏化是方向:只计算重要的注意力对
- 内存访问是关键:FlashAttention优化内存访问模式
- MQA/GQA是工程优化:牺牲少量质量换取大幅速度提升
下一步:继续学习实践-手写Transformer,动手实现各种注意力机制!
最后更新日期:2026-02-12 适用版本:LLM学习教程 v2026

