01 - Transformer深入理解(全面版)¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习目标:从零开始深入理解Transformer的每一个组件,掌握其数学原理、实现细节和设计思想。
📌 定位说明:本章侧重大模型视角下的Transformer深入理解(RoPE/RMSNorm/GeLU/编解码器架构对比等)。Transformer架构的基础教学(从零实现完整Transformer)请参考 深度学习/04-Transformer/02-Transformer架构。
目录¶
- Transformer架构总览
- 输入嵌入层深度解析
- 位置编码机制
- 自注意力机制详解
- 多头注意力机制
- 前馈神经网络
- 层归一化与残差连接
- 编解码器架构对比
- Transformer的训练
- Transformer的推理
Transformer架构总览¶
1.1 为什么需要Transformer¶
在Transformer出现之前,序列建模主要依赖RNN/LSTM:
RNN的问题:
├── 顺序计算,无法并行
├── 长距离依赖困难(梯度消失/爆炸)
└── 计算复杂度与序列长度成正比
Transformer的解决方案:
├── 完全基于注意力机制
├── 完全并行计算
├── 任意位置间距离都是O(1)
└── 成为现代NLP的基础架构
1.2 架构全景图¶
┌─────────────────────────────────────────────────────────────────┐
│ Transformer 架构全景 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 输入序列 (Input) │ │
│ │ [我, 喜欢, 深度, 学习] │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 输入嵌入 + 位置编码 │ │
│ │ (Token Embedding + Positional Encoding) │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌──────────────────┴──────────────────┐ │
│ ▼ ▼ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ Encoder │ │ Decoder │ │
│ │ (编码器堆叠) │ │ (解码器堆叠) │ │
│ │ │ │ │ │
│ │ ┌───────────┐ │ │ ┌───────────┐ │ │
│ │ │ Block N │ │ │ │ Block N │ │ │
│ │ │ ┌───────┐ │ │ │ │ ┌───────┐ │ │ │
│ │ │ │Multi- │ │ │ │ │ │Masked │ │ │ │
│ │ │ │Head │ │ │ │ │ │Multi-H│ │ │ │
│ │ │ │Attn │ │ │ │ │ │Attn │ │ │ │
│ │ │ └───────┘ │ │ │ │ └───────┘ │ │ │
│ │ │ ┌───────┐ │ │ │ │ ┌───────┐ │ │ │
│ │ │ │ Feed │ │ │ │ │ │Multi-H│ │ │ │
│ │ │ │Forward│ │ │ │ │ │Cross │ │ │ │
│ │ │ └───────┘ │ │ │ │ │Attn │ │ │ │
│ │ └───────────┘ │ │ │ └───────┘ │ │ │
│ │ ... │ │ │ ┌───────┐ │ │ │
│ │ ┌───────────┐ │ │ │ │ Feed │ │ │ │
│ │ │ Block 1 │ │ │ │ │Forward│ │ │ │
│ │ └───────────┘ │ │ │ └───────┘ │ │ │
│ └─────────────────┘ │ └───────────┘ │ │
│ │ └─────────────────┘ │
│ │ │ │
│ └──────────────────┬───────────────┘ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 输出层 (Output) │ │
│ │ Linear + Softmax → 概率分布 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 预测输出 │ │
│ │ [I, like, deep, learning] │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
上图展示了Transformer的完整架构,包括编码器(Encoder)和解码器(Decoder)的堆叠结构,以及自注意力机制和前馈神经网络等核心组件。
1.3 核心组件一览¶
| 组件 | 功能 | 关键参数 | 复杂度 |
|---|---|---|---|
| 输入嵌入 | 将token映射为向量 | vocab_size × d_model | - |
| 位置编码 | 注入位置信息 | d_model | - |
| 多头注意力 | 捕捉不同子空间的关系 | h heads, d_k = d_model/h | O(n²·d) |
| 前馈网络 | 非线性变换 | d_model → 4d_model → d_model | O(n·d²) |
| 层归一化 | 稳定训练 | d_model | O(n·d) |
| 残差连接 | 缓解梯度消失 | - | - |
输入嵌入层深度解析¶
2.1 词嵌入的数学本质¶
词嵌入层就是一个查找表(Lookup Table):
E ∈ ℝ^(V × d)
其中:
- V: 词汇表大小(如32000)
- d: 嵌入维度(如512, 768, 1024)
对于输入token id = i:
embedding = E[i, :] ∈ ℝ^d
这相当于一个one-hot向量与嵌入矩阵的乘法:
embedding = one_hot(i) @ E
2.2 嵌入层的实现细节¶
import torch
import torch.nn as nn
import math
class TokenEmbedding(nn.Module):
"""
Token嵌入层完整实现
"""
def __init__(self, vocab_size, d_model, padding_idx=None):
super().__init__() # super()调用父类方法
self.vocab_size = vocab_size
self.d_model = d_model
# 嵌入矩阵
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=d_model,
padding_idx=padding_idx # 用于mask填充位置
)
# 缩放因子(Transformer原论文使用)
self.scale = math.sqrt(d_model)
# 初始化(重要!)
self._init_weights()
def _init_weights(self):
"""
嵌入层初始化策略
使用N(0, 1/d_model)初始化
"""
nn.init.normal_(self.embedding.weight, mean=0, std=1/math.sqrt(self.d_model))
# 如果有padding_idx,将该位置嵌入置零
if self.embedding.padding_idx is not None:
with torch.no_grad(): # 禁用梯度追踪,避免初始化操作被记入计算图
self.embedding.weight[self.embedding.padding_idx].fill_(0)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len] token IDs
Returns:
embeddings: [batch_size, seq_len, d_model]
"""
# 查找嵌入并缩放
# 缩放原因:后续要与位置编码相加,需要平衡量级
return self.embedding(x) * self.scale
# 使用示例
vocab_size = 32000
d_model = 512
batch_size = 2
seq_len = 10
embedding_layer = TokenEmbedding(vocab_size, d_model)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
embeddings = embedding_layer(input_ids)
print(f"输入形状: {input_ids.shape}")
print(f"嵌入形状: {embeddings.shape}")
print(f"嵌入范围: [{embeddings.min():.3f}, {embeddings.max():.3f}]")
2.3 子词分词与嵌入¶
现代大模型使用子词(Subword)分词,如BPE、WordPiece、SentencePiece:
传统分词的问题:
- "playing" 和 "played" 被视为完全不同的词
- 未登录词(OOV)问题
子词分词的优势:
"playing" → ["play", "ing"]
"unhappiness" → ["un", "happiness"] 或 ["un", "happy", "ness"]
常见分词器:
├── GPT系列: BPE (Byte-Pair Encoding)
├── BERT: WordPiece
├── LLaMA/T5: SentencePiece (Unigram)
└── 中文: 字级别或BPE
# 使用Hugging Face Tokenizer示例
from transformers import AutoTokenizer
# 加载GPT-2的分词器
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Transformer架构 revolutionized NLP"
tokens = tokenizer.tokenize(text)
print(f"分词结果: {tokens}")
# 输出: ['Trans', 'former', '架构', 'Ġre', 'volution', 'ized', 'ĠN', 'LP']
# 转换为ID
input_ids = tokenizer.encode(text)
print(f"Token IDs: {input_ids}")
# 解码
decoded = tokenizer.decode(input_ids)
print(f"解码结果: {decoded}")
位置编码机制¶
3.1 为什么需要位置编码¶
自注意力的排列不变性:
输入: [A, B, C] → 自注意力 → 输出
输入: [B, A, C] → 自注意力 → 输出(只是顺序变了)
问题:模型无法区分"我打你"和"你打我"
解决方案:注入位置信息
3.2 正弦位置编码(Sinusoidal)¶
Transformer原始论文使用的位置编码:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中:
- pos: 位置(0, 1, 2, ..., max_len-1)
- i: 维度索引(0, 1, 2, ..., d_model/2-1)
- d_model: 模型维度
class SinusoidalPositionalEncoding(nn.Module):
"""
正弦位置编码实现
"""
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 创建位置编码矩阵 [max_len, d_model]
pe = torch.zeros(max_len, d_model)
# 位置索引 [max_len, 1]
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # unsqueeze增加一个维度
# 维度索引的除数项
# 10000^(2i/d_model) = exp(2i * -log(10000) / d_model)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model)
)
# 偶数维度用sin
pe[:, 0::2] = torch.sin(position * div_term)
# 奇数维度用cos
pe[:, 1::2] = torch.cos(position * div_term)
# 注册为buffer(不参与训练)
pe = pe.unsqueeze(0) # [1, max_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
"""
seq_len = x.size(1)
# 添加位置编码
x = x + self.pe[:, :seq_len, :]
return self.dropout(x)
# 可视化位置编码
import matplotlib.pyplot as plt
def visualize_positional_encoding():
d_model = 128
max_len = 100
pe = SinusoidalPositionalEncoding(d_model, max_len)
encoding = pe.pe[0].numpy() # [max_len, d_model]
plt.figure(figsize=(12, 6))
plt.imshow(encoding, aspect='auto', cmap='viridis')
plt.colorbar()
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Sinusoidal Positional Encoding')
plt.show()
# 观察不同位置的编码
plt.figure(figsize=(12, 4))
for pos in [0, 10, 20, 50]:
plt.plot(encoding[pos], label=f'Pos {pos}')
plt.legend()
plt.xlabel('Dimension')
plt.ylabel('Value')
plt.title('Positional Encoding at Different Positions')
plt.show()
# 正弦位置编码的性质
"""
性质1: 唯一性
每个位置都有唯一的编码
性质2: 相对位置关系
PE(pos+k) 可以表示为 PE(pos) 的线性函数
这意味着模型可以学习相对位置
性质3: 有界性
所有值都在[-1, 1]之间,数值稳定
性质4: 外推性
可以处理训练时未见过的更长序列
"""
3.3 可学习位置编码¶
BERT等模型使用可学习的位置编码:
class LearnedPositionalEncoding(nn.Module):
"""
可学习位置编码
"""
def __init__(self, d_model, max_len=512, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 可学习的嵌入矩阵
self.pos_embedding = nn.Embedding(max_len, d_model)
# 初始化
nn.init.normal_(self.pos_embedding.weight, std=0.02)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
"""
seq_len = x.size(1)
# 创建位置索引
positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
# 添加位置嵌入
x = x + self.pos_embedding(positions)
return self.dropout(x)
3.4 旋转位置编码 (RoPE)¶
现代大模型(LLaMA、GPT-NeoX等)使用RoPE:
3.4.1 RoPE的数学推导(从复数乘法出发)¶
RoPE的核心目标:设计一种位置编码函数 f(x, pos),使得两个位置的内积
只依赖于相对位置差 (m - n),而非绝对位置 m 和 n:
<f(q, m), f(k, n)> = g(q, k, m - n)
第一步:复数表示
将向量的每两个相邻维度视为一个复数:
z = x_{2i} + j·x_{2i+1} (j 是虚数单位)
第二步:复数乘法 = 旋转
在复数平面上,乘以 e^{jθ} 等价于旋转角度 θ:
z · e^{jθ} = (x_{2i} + j·x_{2i+1})(cosθ + j·sinθ)
= (x_{2i}·cosθ - x_{2i+1}·sinθ) + j·(x_{2i}·sinθ + x_{2i+1}·cosθ)
写成矩阵形式:
[x_{2i}' ] [cosθ -sinθ] [x_{2i} ]
[x_{2i+1}'] = [sinθ cosθ] [x_{2i+1}]
第三步:位置相关的旋转角度
对位置 m 的第 i 对维度,旋转角度 θ_i(m) = m · ω_i
其中 ω_i = 1 / 10000^{2i/d}(频率随维度指数递减)
第四步:验证相对位置性质
对位置 m 的 query 和位置 n 的 key(在第 i 对维度上):
f(q, m) = q · e^{j·m·ω_i}
f(k, n) = k · e^{j·n·ω_i}
内积(复数):
Re[f(q, m) · conj(f(k, n))]
= Re[q · e^{j·m·ω_i} · conj(k · e^{j·n·ω_i})]
= Re[q · conj(k) · e^{j·(m-n)·ω_i}]
结果只依赖 (m-n),证毕。✓
第五步:推广到全维度
对 d 维向量,每两维一组,共 d/2 组,每组用不同频率 ω_i 旋转:
RoPE(x, pos) = [R(pos·ω_0)·x_{0:1}, R(pos·ω_1)·x_{2:3}, ..., R(pos·ω_{d/2-1})·x_{d-2:d-1}]
其中 R(θ) 是 2×2 旋转矩阵。
3.4.2 RoPE实现¶
class RotaryPositionalEmbedding(nn.Module):
"""
旋转位置编码 (RoPE)
通过旋转矩阵注入相对位置信息
"""
def __init__(self, dim, max_seq_len=2048, base=10000):
super().__init__()
# 计算旋转角度
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# 预计算位置编码
t = torch.arange(max_seq_len)
freqs = torch.einsum('i,j->ij', t, inv_freq) # [max_seq_len, dim/2]
# 复数形式: cos + i*sin
self.register_buffer('cos_cached', freqs.cos()[None, None, :, :]) # [1, 1, seq_len, dim/2]
self.register_buffer('sin_cached', freqs.sin()[None, None, :, :]) # [1, 1, seq_len, dim/2]
def forward(self, x, seq_len=None):
"""
Args:
x: [batch, heads, seq_len, head_dim]
"""
if seq_len is None:
seq_len = x.shape[2]
cos = self.cos_cached[:, :, :seq_len, :]
sin = self.sin_cached[:, :, :seq_len, :]
return self.apply_rotary_pos_emb(x, cos, sin)
def apply_rotary_pos_emb(self, x, cos, sin):
"""
应用旋转位置编码
Args:
x: [batch, heads, seq_len, head_dim]
cos: [1, 1, seq_len, head_dim/2]
sin: [1, 1, seq_len, head_dim/2]
Returns:
[batch, heads, seq_len, head_dim]
"""
# 将x分成两部分(偶数维和奇数维)
x1, x2 = x[..., ::2], x[..., 1::2]
# 旋转: [x1, x2] @ [[cos, -sin], [sin, cos]]
# 即: x1' = x1 * cos - x2 * sin
# x2' = x1 * sin + x2 * cos
rotated = torch.stack([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1)
# 将最后两维展平,恢复原始形状
return rotated.flatten(-2)
上图展示了RoPE(旋转位置编码)的工作原理。RoPE通过旋转矩阵将位置信息注入到查询和键向量中,每个维度对在不同的旋转速度下工作,类似于不同速度的时钟指针。
这张图展示了RoPE在2D平面上的投影,不同颜色代表不同位置的向量,它们从原点发出,展示了实部和虚部之间的关系。RoPE的优势在于能够自然地编码相对位置信息,并且具有良好的外推性能。
自注意力机制详解¶
4.1 直觉理解¶
自注意力的核心思想:让序列中每个位置都"看到"其他所有位置,并学习该关注哪些位置。
以"猫坐在垫子上,它很舒服"为例:
处理"它"时,自注意力权重分布(示意):
猫 坐 在 垫子 上 , 它 很 舒服
0.45 0.05 0.02 0.15 0.03 0.01 0.10 0.04 0.15
→ 模型学会了"它"主要指代"猫"(最高权重),次要关联"垫子"和"舒服"
4.2 Q-K-V 的数学推导¶
三个矩阵的角色: - \(Q\)(Query):当前位置发出的"提问" - \(K\)(Key):每个位置提供的"索引标签" - \(V\)(Value):每个位置存储的"实际内容"
类比图书馆:
Q = 你的搜索关键词("深度学习入门")
K = 每本书的标题/关键词索引
V = 每本书的实际内容
检索流程:
1. 计算 Q @ K^T → 搜索词与每本书标题的匹配分数
2. softmax → 将分数归一化为概率
3. 概率 @ V → 加权提取最相关书的内容
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
query: torch.Tensor, # [batch, seq_q, d_k]
key: torch.Tensor, # [batch, seq_k, d_k]
value: torch.Tensor, # [batch, seq_k, d_v]
mask: torch.Tensor = None,
dropout: nn.Dropout = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
缩放点积注意力
关键点:
1. 为什么除以 sqrt(d_k)?(完整推导)
假设 q 和 k 的每个分量独立同分布,均值为0、方差为1:
q = [q_1, q_2, ..., q_{d_k}], q_i ~ (0, 1)
k = [k_1, k_2, ..., k_{d_k}], k_i ~ (0, 1)
点积 q·k = Σ_{i=1}^{d_k} q_i * k_i
每一项 q_i * k_i 的方差:
Var(q_i * k_i) = E[q_i²] * E[k_i²] - (E[q_i] * E[k_i])²
= 1 * 1 - 0 = 1
因为各项独立,所以:
Var(q·k) = Σ Var(q_i * k_i) = d_k
当 d_k=64 时,点积标准差 ≈ 8;d_k=128 时 ≈ 11.3
这些大值会将 softmax 推入饱和区(梯度接近 0)
除以 √d_k 后:
Var(q·k / √d_k) = Var(q·k) / d_k = d_k / d_k = 1
方差归一化为 1,softmax 的输入保持在合理范围,梯度可以正常流动。
2. mask 的两种用途:
- padding mask: 屏蔽填充位置
- causal mask: 屏蔽未来位置(解码器)
"""
d_k = query.size(-1)
# Step 1: QK^T / sqrt(d_k) → [batch, seq_q, seq_k]
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Step 2: 应用 mask(将需要屏蔽的位置设为 -inf)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 3: Softmax → 权重矩阵
attn_weights = F.softmax(scores, dim=-1)
# Step 4: Dropout(训练时)
if dropout is not None:
attn_weights = dropout(attn_weights)
# Step 5: 加权求和 → [batch, seq_q, d_v]
output = torch.matmul(attn_weights, value)
return output, attn_weights
# ---------- 验证 ----------
batch, seq_len, d_k = 2, 5, 64
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)
# 无 mask
out, weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {out.shape}") # [2, 5, 64]
print(f"权重形状: {weights.shape}") # [2, 5, 5]
print(f"权重行和: {weights.sum(-1)}") # 每行和为 1.0
# 因果 mask(解码器用)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)) # 下三角矩阵
out_masked, weights_masked = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(f"\n因果mask权重(第0行只看自己,第4行看所有位置):")
print(weights_masked[0].detach())
4.3 注意力复杂度分析¶
时间复杂度: O(n² · d)
- 其中 n = 序列长度, d = 模型维度
- 主要来自 QK^T 的矩阵乘法
空间复杂度: O(n²)
- 需要存储完整的注意力权重矩阵
序列长度 → 注意力计算量(d=1024)
128 → 16.7M
1,024 → 1.07B
8,192 → 68.7B
128,000 → 16.8T ← 这就是长上下文的挑战
解决方案:
├── FlashAttention: 利用GPU内存层级优化,O(n²)复杂度不变但实际速度快2-4倍
├── GQA (Grouped-Query Attention): 减少KV头数,降低内存
├── MQA (Multi-Query Attention): 所有头共享一组KV
├── Ring Attention: 分布式环形通信处理超长序列
└── 稀疏注意力: 只计算部分位置对的注意力
多头注意力机制¶
5.1 为什么需要多头¶
单头注意力的问题:一个注意力头只能学到一种"关注模式"。而自然语言有多种关系维度(语法、语义、共指、因果等)。
多头的直觉:用不同的"眼睛"看同一句话
Head 1: 可能学会了语法依赖(主语→动词)
Head 2: 可能学会了共指关系(代词→先行词)
Head 3: 可能学会了修饰关系(形容词→名词)
Head 4: 可能学会了长距离依赖(问题→答案)
...
5.2 多头注意力实现¶
class MultiHeadAttention(nn.Module):
"""
多头注意力机制
参数量分析(d_model=1024, h=16, d_k=64):
W_Q: 1024×1024 = 1M
W_K: 1024×1024 = 1M
W_V: 1024×1024 = 1M
W_O: 1024×1024 = 1M
总计: 4M 参数
"""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % n_heads == 0, "d_model 必须被 n_heads 整除" # assert断言:条件False时抛出AssertionError
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # 每个头的维度
# 四个投影矩阵
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(
self,
query: torch.Tensor, # [batch, seq_q, d_model]
key: torch.Tensor, # [batch, seq_k, d_model]
value: torch.Tensor, # [batch, seq_k, d_model]
mask: torch.Tensor = None,
) -> torch.Tensor:
batch_size = query.size(0)
# 1. 线性投影 → [batch, seq, d_model]
Q = self.W_Q(query)
K = self.W_K(key)
V = self.W_V(value)
# 2. 分头: [batch, seq, d_model] → [batch, n_heads, seq, d_k]
Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # view重塑张量形状
K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# 3. 缩放点积注意力
d_k = self.d_k
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 4. 加权求和 → [batch, n_heads, seq_q, d_k]
context = torch.matmul(attn_weights, V)
# 5. 合并多头: [batch, seq_q, d_model]
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 6. 输出投影
output = self.W_O(context)
return output
# 验证
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)
# Self-attention: Q=K=V=x
out = mha(x, x, x)
print(f"Self-attention 输出: {out.shape}") # [2, 10, 512]
# Cross-attention: Q来自decoder, K/V来自encoder
enc_out = torch.randn(2, 20, 512)
dec_in = torch.randn(2, 10, 512)
cross_out = mha(dec_in, enc_out, enc_out)
print(f"Cross-attention 输出: {cross_out.shape}") # [2, 10, 512]
5.3 GQA 与 MQA:现代大模型的注意力优化¶
标准 MHA (Multi-Head Attention):
Q: [batch, n_heads, seq, d_k] ← 每个头有独立的Q
K: [batch, n_heads, seq, d_k] ← 每个头有独立的K
V: [batch, n_heads, seq, d_k] ← 每个头有独立的V
KV cache 大小 = 2 × n_heads × seq × d_k
MQA (Multi-Query Attention, GPT-J):
Q: [batch, n_heads, seq, d_k] ← 每个头有独立的Q
K: [batch, 1, seq, d_k] ← 所有头共享一组K
V: [batch, 1, seq, d_k] ← 所有头共享一组V
KV cache 大小 = 2 × 1 × seq × d_k → 减少 n_heads 倍
GQA (Grouped-Query Attention, LLaMA-2/3):
Q: [batch, n_heads, seq, d_k] ← 每个头有独立的Q
K: [batch, n_kv_heads, seq, d_k] ← 每组共享一组K
V: [batch, n_kv_heads, seq, d_k] ← 每组共享一组V
KV cache 大小 = 2 × n_kv_heads × seq × d_k
示例 (LLaMA-3-70B): n_heads=64, n_kv_heads=8
→ 每8个Q头共享1组KV → KV cache减少8倍,性能接近MHA
前馈神经网络¶
6.1 标准 FFN¶
每个Transformer层中的全连接前馈网络(Position-wise FFN):
class FeedForward(nn.Module):
"""标准FFN,4倍扩展"""
def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
super().__init__()
d_ff = d_ff or 4 * d_model # 默认4倍扩展
self.w1 = nn.Linear(d_model, d_ff)
self.w2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# [batch, seq, d_model] → [batch, seq, d_ff] → [batch, seq, d_model]
return self.w2(self.dropout(F.relu(self.w1(x))))
6.2 现代激活函数:GeLU 与 SwiGLU¶
# GeLU (GPT-2/BERT 使用)
# GELU(x) = x · Φ(x),其中 Φ 是标准正态分布的CDF
# 比 ReLU 更平滑,不会在 x=0 处产生不可导点
class GeLUFeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
super().__init__()
d_ff = d_ff or 4 * d_model
self.w1 = nn.Linear(d_model, d_ff)
self.w2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w2(self.dropout(F.gelu(self.w1(x))))
# SwiGLU (LLaMA/Qwen/Mistral 使用)
# SwiGLU(x, W, V, W2) = (Swish(xW) ⊙ xV) W2
# 引入门控机制,效果优于 GeLU,但参数量增加 50%
class SwiGLUFeedForward(nn.Module):
"""LLaMA-style SwiGLU FFN"""
def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
super().__init__()
# LLaMA 使用 8/3 * d_model 作为 d_ff 以保持参数量不变
d_ff = d_ff or int(8 / 3 * d_model)
# 取最接近的 256 的倍数(GPU对齐优化)
d_ff = ((d_ff + 255) // 256) * 256
self.w1 = nn.Linear(d_model, d_ff, bias=False) # gate projection
self.w3 = nn.Linear(d_model, d_ff, bias=False) # up projection
self.w2 = nn.Linear(d_ff, d_model, bias=False) # down projection
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Swish(xW1) ⊙ xW3 → 门控机制
gate = F.silu(self.w1(x)) # SiLU = Swish(x) = x * sigmoid(x)
up = self.w3(x)
return self.w2(self.dropout(gate * up))
# 对比
d_model = 1024
ffn_relu = FeedForward(d_model)
ffn_gelu = GeLUFeedForward(d_model)
ffn_swiglu = SwiGLUFeedForward(d_model)
for name, module in [("ReLU FFN", ffn_relu), ("GeLU FFN", ffn_gelu), ("SwiGLU FFN", ffn_swiglu)]:
params = sum(p.numel() for p in module.parameters())
print(f"{name}: {params:,} 参数")
# ReLU FFN: 8,393,728 参数
# GeLU FFN: 8,393,728 参数
# SwiGLU FFN: 8,650,752 参数 (d_ff调整后参数量接近)
层归一化与残差连接¶
7.1 为什么需要归一化¶
深层网络中,每层输出的分布会不断漂移(Internal Covariate Shift),导致训练不稳定。归一化将输出拉回标准分布。
7.2 LayerNorm vs RMSNorm¶
class LayerNorm(nn.Module):
"""
标准层归一化 (BERT/GPT-2 使用)
LayerNorm(x) = γ · (x - μ) / √(σ² + ε) + β
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model)) # 可学习缩放
self.beta = nn.Parameter(torch.zeros(d_model)) # 可学习偏移
self.eps = eps
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
class RMSNorm(nn.Module):
"""
RMS层归一化 (LLaMA/Qwen/Mistral 使用)
RMSNorm(x) = γ · x / RMS(x)
其中 RMS(x) = √(1/n · Σx²)
优势:去掉了均值计算和偏移参数 β
→ 计算更快(约10-15%),效果相当
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return self.gamma * x / rms
# 对比
x = torch.randn(2, 10, 512)
ln = LayerNorm(512)
rmsn = RMSNorm(512)
out_ln = ln(x)
out_rmsn = rmsn(x)
print(f"LayerNorm 输出统计: mean={out_ln.mean():.4f}, std={out_ln.std():.4f}")
print(f"RMSNorm 输出统计: mean={out_rmsn.mean():.4f}, std={out_rmsn.std():.4f}")
7.3 Pre-Norm vs Post-Norm¶
Post-Norm (原始Transformer, BERT):
x → Attention → Add(x, ·) → LayerNorm → FFN → Add(·, ·) → LayerNorm
优点:最终层的表示经过归一化,理论性质更好
缺点:深层训练不稳定,需要 warm-up
Pre-Norm (GPT-2, LLaMA, 现代模型):
x → LayerNorm → Attention → Add(x, ·) → LayerNorm → FFN → Add(·, ·)
优点:梯度直接通过残差传播,训练更稳定
缺点:理论上收敛性稍差,但实践中更常用
现代大模型几乎都使用 Pre-Norm + RMSNorm 的组合。
class TransformerBlock(nn.Module):
"""Pre-Norm Transformer Block (LLaMA-style)"""
def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
super().__init__()
self.norm1 = RMSNorm(d_model)
self.attn = MultiHeadAttention(d_model, n_heads, dropout)
self.norm2 = RMSNorm(d_model)
self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)
def forward(self, x, mask=None):
# Pre-Norm + Residual
x_norm = self.norm1(x)
x = x + self.attn(x_norm, x_norm, x_norm, mask)
x = x + self.ffn(self.norm2(x))
return x
编解码器架构对比¶
8.1 三种主流架构¶
这是理解现代大模型最关键的知识点之一:
┌────────────────────────────────────────────────────────────────┐
│ 三种 Transformer 架构 │
├──────────────┬──────────────┬──────────────┬──────────────────┤
│ │ Encoder-Only │ Decoder-Only │ Encoder-Decoder │
├──────────────┼──────────────┼──────────────┼──────────────────┤
│ 代表模型 │ BERT │ GPT/LLaMA │ T5/BART │
│ 注意力类型 │ 双向全注意力 │ 因果注意力 │ 双向+因果+交叉 │
│ 训练目标 │ MLM + NSP │ 下一token预测 │ Seq2Seq │
│ 擅长任务 │ 分类/NER/QA │ 生成/对话 │ 翻译/摘要 │
│ 上下文 │ 看到全部输入 │ 只看到左侧 │ 编码器全部,解码器左侧│
│ 当前趋势 │ 逐渐减少 │ 主流(GPT时代) │ 特定场景使用 │
└──────────────┴──────────────┴──────────────┴──────────────────┘
为什么 Decoder-Only 成为主流?
1. 统一的训练目标:所有任务都可以转化为"文本生成"
2. 规模效应更好:参数量越大,涌现能力越强
3. 少样本能力:通过 prompt 即可适配新任务,无需微调
4. 工程简单:只需一个模型处理所有任务
8.2 因果注意力掩码¶
Decoder-Only 模型的核心:确保位置 \(i\) 只能看到位置 \(0, 1, \ldots, i\)(不泄露未来信息)。
def create_causal_mask(seq_len: int) -> torch.Tensor:
"""
创建因果注意力掩码
例如 seq_len=4:
[[1, 0, 0, 0], ← 位置0只看自己
[1, 1, 0, 0], ← 位置1看0和1
[1, 1, 1, 0], ← 位置2看0、1、2
[1, 1, 1, 1]] ← 位置3看所有
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask
# 训练时:一次前向传播同时预测所有位置
# 因为有因果mask,每个位置只看到左侧,等价于自回归
Transformer的训练¶
9.1 训练目标¶
Decoder-Only (如 GPT) 的训练:
输入: [BOS] The cat sat on the mat
标签: The cat sat on the mat [EOS]
损失函数: Cross-Entropy Loss
L = -1/T Σ_t log P(x_t | x_{<t})
即最大化每个位置正确预测下一个token的对数概率。
整个序列只需一次前向传播 (teacher forcing)。
9.2 学习率调度¶
Transformer 训练中最关键的超参数管理:
class WarmupCosineScheduler:
"""
Warmup + Cosine Decay 学习率调度
几乎所有现代大模型都使用这种方式
阶段1 (warmup): 学习率从0线性增长到max_lr
阶段2 (cosine): 学习率从max_lr余弦衰减到min_lr
"""
def __init__(
self,
optimizer,
warmup_steps: int,
total_steps: int,
max_lr: float = 3e-4,
min_lr: float = 1e-5,
):
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.max_lr = max_lr
self.min_lr = min_lr
self.step_count = 0
def step(self):
self.step_count += 1
lr = self.get_lr()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def get_lr(self):
if self.step_count < self.warmup_steps:
# 线性 warmup
return self.max_lr * self.step_count / self.warmup_steps
else:
# 余弦衰减
progress = (self.step_count - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
1 + math.cos(math.pi * progress)
)
# 训练循环骨架
"""
model = DecoderOnlyTransformer(...)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=2000, total_steps=100000)
for batch in dataloader:
input_ids = batch['input_ids'] # [B, T]
targets = input_ids[:, 1:] # 右移一位
logits = model(input_ids[:, :-1]) # [B, T-1, vocab]
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 梯度裁剪
optimizer.step()
scheduler.step()
optimizer.zero_grad()
"""
Transformer的推理¶
10.1 自回归生成¶
训练是并行的(一次处理整个序列),
推理是自回归的(一次生成一个token):
Step 0: 输入 [BOS] → 预测 "The"
Step 1: 输入 [BOS, The] → 预测 "cat"
Step 2: 输入 [BOS, The, cat] → 预测 "sat"
...
问题:每一步都要重新计算之前所有位置的注意力 → 巨大浪费
解决:KV Cache
10.2 KV Cache¶
KV Cache 是大模型推理最重要的优化技术:
"""KV Cache 原理示意"""
class CachedMultiHeadAttention(nn.Module):
"""带KV Cache的多头注意力(推理用)"""
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, kv_cache=None):
"""
Args:
x: [batch, 1, d_model] — 推理时只有当前新token
kv_cache: (cached_K, cached_V) — 之前步骤的K/V
Returns:
output, new_kv_cache
"""
B = x.size(0)
# 只为当前 token 计算 Q/K/V
Q = self.W_Q(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
K_new = self.W_K(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
V_new = self.W_V(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
if kv_cache is not None:
K_cached, V_cached = kv_cache
# 拼接历史 KV → [batch, n_heads, seq_so_far+1, d_k]
K = torch.cat([K_cached, K_new], dim=2)
V = torch.cat([V_cached, V_new], dim=2)
else:
K, V = K_new, V_new
# 只计算当前 Q 与所有 K 的注意力 → O(1×S) 而非 O(S×S)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V)
out = out.transpose(1, 2).contiguous().view(B, 1, self.d_model)
output = self.W_O(out)
return output, (K, V) # 返回更新后的 cache
"""
无 KV Cache: 生成 T 个token → 总计算量 O(T² · d)
有 KV Cache: 生成 T 个token → 总计算量 O(T · d) + 缓存内存
速度提升: ~T倍(序列越长,加速越明显)
代价: 需要额外内存存储 KV Cache
LLaMA-2-7B, seq_len=4096:
KV cache = 2(K+V) × 32(layers) × 32(heads) × 4096(seq) × 128(d_k) × 2(bytes/fp16)
≈ 2GB
"""
10.3 解码策略¶
def generate(
model, tokenizer, prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 0.9,
):
"""
完整的文本生成函数,支持多种解码策略
"""
input_ids = tokenizer.encode(prompt, return_tensors="pt")
kv_cache = None
generated = []
for _ in range(max_new_tokens):
# 只输入最新 token(有 KV Cache 时)
if kv_cache is not None:
x = input_ids[:, -1:]
else:
x = input_ids
logits, kv_cache = model(x, kv_cache=kv_cache)
logits = logits[:, -1, :] # 只取最后一个位置
# Temperature scaling
logits = logits / temperature
# Top-K: 只保留概率最高的K个token
if top_k > 0:
topk_logits, topk_indices = torch.topk(logits, top_k)
logits = torch.full_like(logits, float('-inf'))
logits.scatter_(1, topk_indices, topk_logits)
# Top-P (Nucleus Sampling)
probs = F.softmax(logits, dim=-1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# 移除累积概率超过 top_p 的token
remove_mask = cumulative_probs - sorted_probs > top_p
sorted_probs[remove_mask] = 0
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
# 采样
next_token = torch.multinomial(sorted_probs, 1)
next_token = sorted_indices.gather(1, next_token)
if next_token.item() == tokenizer.eos_token_id:
break
generated.append(next_token.item())
input_ids = torch.cat([input_ids, next_token], dim=-1)
return tokenizer.decode(generated)
延伸阅读¶
- 完整实现:如需从零手写完整 Transformer 并训练,请参考 03-手写Transformer完整实现
- 注意力机制扩展:FlashAttention、稀疏注意力等高级主题,请参考 02-注意力机制详解
- 大模型训练进阶:分布式训练、混合精度等,请参考 大模型核心技术
最后更新日期:2025-07-11 适用版本:LLM学习教程 v2025


