01 - Transformer 深入理解¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习目标:从零开始深入理解 Transformer 的每一个组件,掌握其数学原理、实现细节和设计思想。
📌 定位说明:本章侧重大模型视角下的 Transformer 深入理解( RoPE/RMSNorm/GeLU/编解码器架构对比等)。 Transformer 架构的基础教学(从零实现完整 Transformer )请参考 深度学习/04-Transformer/02-Transformer 架构。
目录¶
- Transformer 架构总览
- 输入嵌入层深度解析
- 位置编码机制
- 自注意力机制详解
- 多头注意力机制
- 前馈神经网络
- 层归一化与残差连接
- 编解码器架构对比
- Transformer 的训练
- Transformer 的推理
Transformer 架构总览¶
1.1 为什么需要 Transformer¶
在 Transformer 出现之前,NLP 领域的序列建模主要依赖 RNN(循环神经网络)和 LSTM(长短期记忆网络)。这些模型虽然在当时取得了不错的效果,但存在两个根本性的缺陷:
缺陷一:无法并行计算。 RNN 的计算是严格串行的——必须先处理完第 \(t\) 步,才能处理第 \(t+1\) 步。这意味着即使你有成百上千块 GPU,也无法加速单个序列的处理。对于长文本(如一整篇文章),这种串行瓶颈极其严重。
缺陷二:长距离依赖困难。 理论上 LSTM 通过门控机制可以记忆远距离信息,但在实践中,当序列长度超过 100-200 时,梯度在反向传播过程中会不断衰减或爆炸(即梯度消失/爆炸问题),导致模型难以学习到远距离 token 之间的依赖关系。例如在"猫,那个毛茸茸的、总是追着球跑的小家伙,正坐在窗台上"中,模型需要将"猫"和" sitting "关联起来,中间隔了十几个词。
2017 年,Vaswani 等人在论文《Attention Is All You Need》中提出了 Transformer,用纯注意力机制彻底替代了 RNN/LSTM:
RNN 的问题与 Transformer 的解决方案:
┌──────────────────────┬──────────────────────────────────────┐
│ RNN/LSTM 的缺陷 │ Transformer 的解决方案 │
├──────────────────────┼──────────────────────────────────────┤
│ 顺序计算,无法并行 │ 完全基于矩阵运算,天然支持并行 │
│ 长距离依赖困难 │ 任意两个位置间的注意力计算都是 O(1) │
│ │ 距离不再是障碍 │
│ 梯度消失/爆炸 │ 残差连接 + 层归一化确保梯度稳定 │
│ 固定长度瓶颈 │ 灵活处理变长序列 │
└──────────────────────┴──────────────────────────────────────┘
Transformer 不仅解决了上述问题,更重要的是它奠定了一个可扩展的架构基础——通过简单地增加层数、参数量和训练数据,就能持续提升模型性能。这一特性直接催生了 GPT、BERT、LLaMA 等划时代的模型,使 Transformer 成为现代 AI 最重要的基础架构之一。
1.2 架构全景图¶
┌─────────────────────────────────────────────────────────────────┐
│ Transformer 架构全景 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 输入序列 (Input) │ │
│ │ [我, 喜欢, 深度, 学习] │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 输入嵌入 + 位置编码 │ │
│ │ (Token Embedding + Positional Encoding) │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌──────────────────┴──────────────────┐ │
│ ▼ ▼ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ Encoder │ │ Decoder │ │
│ │ (编码器堆叠) │ │ (解码器堆叠) │ │
│ │ │ │ │ │
│ │ ┌───────────┐ │ │ ┌───────────┐ │ │
│ │ │ Block N │ │ │ │ Block N │ │ │
│ │ │ ┌───────┐ │ │ │ │ ┌───────┐ │ │ │
│ │ │ │Multi- │ │ │ │ │ │Masked │ │ │ │
│ │ │ │Head │ │ │ │ │ │Multi-H│ │ │ │
│ │ │ │Attn │ │ │ │ │ │Attn │ │ │ │
│ │ │ └───────┘ │ │ │ │ └───────┘ │ │ │
│ │ │ ┌───────┐ │ │ │ │ ┌───────┐ │ │ │
│ │ │ │ Feed │ │ │ │ │ │Multi-H│ │ │ │
│ │ │ │Forward│ │ │ │ │ │Cross │ │ │ │
│ │ │ └───────┘ │ │ │ │ │Attn │ │ │ │
│ │ └───────────┘ │ │ │ └───────┘ │ │ │
│ │ ... │ │ │ ┌───────┐ │ │ │
│ │ ┌───────────┐ │ │ │ │ Feed │ │ │ │
│ │ │ Block 1 │ │ │ │ │Forward│ │ │ │
│ │ └───────────┘ │ │ │ └───────┘ │ │ │
│ └─────────────────┘ │ └───────────┘ │ │
│ │ └─────────────────┘ │
│ │ │ │
│ └──────────────────┬───────────────┘ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 输出层 (Output) │ │
│ │ Linear + Softmax → 概率分布 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 预测输出 │ │
│ │ [I, like, deep, learning] │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
图源:TensorFlow 教程 - Neural machine translation with a Transformer and Keras,图片文件
transformer.png,许可 CC BY 4.0。
上图展示了 Transformer 的完整架构,包括编码器( Encoder )和解码器( Decoder )的堆叠结构,以及自注意力机制和前馈神经网络等核心组件。
1.3 核心组件一览¶
| 组件 | 功能 | 关键参数 | 复杂度 |
|---|---|---|---|
| 输入嵌入 | 将 token 映射为向量 | vocab_size × d_model | - |
| 位置编码 | 注入位置信息 | d_model | - |
| 多头注意力 | 捕捉不同子空间的关系 | h heads, d_k = d_model/h | O(n²·d) |
| 前馈网络 | 非线性变换 | d_model → 4d_model → d_model | O(n·d²) |
| 层归一化 | 稳定训练 | d_model | O(n·d) |
| 残差连接 | 缓解梯度消失 | - | - |
输入嵌入层深度解析¶
2.1 词嵌入的数学本质¶
词嵌入层就是一个查找表(Lookup Table):
E ∈ ℝ^(V × d)
其中:
- V: 词汇表大小(如32000)
- d: 嵌入维度(如512, 768, 1024)
对于输入token id = i:
embedding = E[i, :] ∈ ℝ^d
这相当于一个one-hot向量与嵌入矩阵的乘法:
embedding = one_hot(i) @ E
2.2 嵌入层的实现细节¶
import torch
import torch.nn as nn
import math
class TokenEmbedding(nn.Module):
"""
Token嵌入层完整实现
"""
def __init__(self, vocab_size, d_model, padding_idx=None):
super().__init__() # super()调用父类方法
self.vocab_size = vocab_size
self.d_model = d_model
# 嵌入矩阵
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=d_model,
padding_idx=padding_idx # 用于mask填充位置
)
# 缩放因子(Transformer原论文使用)
self.scale = math.sqrt(d_model)
# 初始化(重要!)
self._init_weights()
def _init_weights(self):
"""
嵌入层初始化策略
使用N(0, 1/d_model)初始化
"""
nn.init.normal_(self.embedding.weight, mean=0, std=1/math.sqrt(self.d_model))
# 如果有padding_idx,将该位置嵌入置零
if self.embedding.padding_idx is not None:
with torch.no_grad(): # 禁用梯度追踪,避免初始化操作被记入计算图
self.embedding.weight[self.embedding.padding_idx].fill_(0)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len] token IDs
Returns:
embeddings: [batch_size, seq_len, d_model]
"""
# 查找嵌入并缩放
# 缩放原因:后续要与位置编码相加,需要平衡量级
return self.embedding(x) * self.scale
# 使用示例
vocab_size = 32000
d_model = 512
batch_size = 2
seq_len = 10
embedding_layer = TokenEmbedding(vocab_size, d_model)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
embeddings = embedding_layer(input_ids)
print(f"输入形状: {input_ids.shape}")
print(f"嵌入形状: {embeddings.shape}")
print(f"嵌入范围: [{embeddings.min():.3f}, {embeddings.max():.3f}]")
2.3 子词分词与嵌入¶
现代大模型使用子词( Subword )分词,如 BPE 、 WordPiece 、 SentencePiece :
传统分词的问题:
- "playing" 和 "played" 被视为完全不同的词
- 未登录词(OOV)问题
子词分词的优势:
"playing" → ["play", "ing"]
"unhappiness" → ["un", "happiness"] 或 ["un", "happy", "ness"]
常见分词器:
├── GPT系列: BPE (Byte-Pair Encoding)
├── BERT: WordPiece
├── LLaMA/T5: SentencePiece (Unigram)
└── 中文: 字级别或BPE
# 使用Hugging Face Tokenizer示例
from transformers import AutoTokenizer
# 加载GPT-2的分词器
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Transformer架构 revolutionized NLP"
tokens = tokenizer.tokenize(text)
print(f"分词结果: {tokens}")
# 输出: ['Trans', 'former', '架构', 'Ġre', 'volution', 'ized', 'ĠN', 'LP']
# 转换为ID
input_ids = tokenizer.encode(text)
print(f"Token IDs: {input_ids}")
# 解码
decoded = tokenizer.decode(input_ids)
print(f"解码结果: {decoded}")
位置编码机制¶
3.1 为什么需要位置编码¶
自注意力的排列不变性:
输入: [A, B, C] → 自注意力 → 输出
输入: [B, A, C] → 自注意力 → 输出(只是顺序变了)
问题:模型无法区分"我打你"和"你打我"
解决方案:注入位置信息
3.2 正弦位置编码( Sinusoidal )¶
Transformer 原始论文使用的位置编码:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中:
- pos: 位置(0, 1, 2, ..., max_len-1)
- i: 维度索引(0, 1, 2, ..., d_model/2-1)
- d_model: 模型维度
class SinusoidalPositionalEncoding(nn.Module):
"""
正弦位置编码实现
"""
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 创建位置编码矩阵 [max_len, d_model]
pe = torch.zeros(max_len, d_model)
# 位置索引 [max_len, 1]
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # unsqueeze增加一个维度
# 维度索引的除数项
# 10000^(2i/d_model) = exp(2i * -log(10000) / d_model)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model)
)
# 偶数维度用sin
pe[:, 0::2] = torch.sin(position * div_term)
# 奇数维度用cos
pe[:, 1::2] = torch.cos(position * div_term)
# 注册为buffer(不参与训练)
pe = pe.unsqueeze(0) # [1, max_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
"""
seq_len = x.size(1)
# 添加位置编码
x = x + self.pe[:, :seq_len, :]
return self.dropout(x)
# 可视化位置编码
import matplotlib.pyplot as plt
def visualize_positional_encoding():
d_model = 128
max_len = 100
pe = SinusoidalPositionalEncoding(d_model, max_len)
encoding = pe.pe[0].numpy() # [max_len, d_model]
plt.figure(figsize=(12, 6))
plt.imshow(encoding, aspect='auto', cmap='viridis')
plt.colorbar()
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Sinusoidal Positional Encoding')
plt.show()
# 观察不同位置的编码
plt.figure(figsize=(12, 4))
for pos in [0, 10, 20, 50]:
plt.plot(encoding[pos], label=f'Pos {pos}')
plt.legend()
plt.xlabel('Dimension')
plt.ylabel('Value')
plt.title('Positional Encoding at Different Positions')
plt.show()
正弦位置编码的四个关键性质¶
性质1:唯一性 — 每个位置都有唯一的编码向量。由于不同维度使用不同频率的正弦/余弦函数,任意两个位置的编码都不会完全相同。
性质2:相对位置关系 — 这是正弦编码最重要的性质。可以证明 \(PE(pos+k)\) 可以表示为 \(PE(pos)\) 的线性函数:
这意味着模型不需要显式编码绝对位置,只需通过学习一组线性变换就能捕捉相对位置关系。
性质3:有界性 — 所有值都在 \([-1, 1]\) 之间,数值稳定,不会导致梯度爆炸。
性质4:外推性 — 理论上可以处理训练时未见过的更长序列(因为公式对任意 \(pos\) 都有定义),但实践中外推能力有限——这就是 RoPE 等改进方案被提出的原因之一。
3.3 可学习位置编码¶
BERT 等模型使用可学习的位置编码:
class LearnedPositionalEncoding(nn.Module):
"""
可学习位置编码
"""
def __init__(self, d_model, max_len=512, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 可学习的嵌入矩阵
self.pos_embedding = nn.Embedding(max_len, d_model)
# 初始化
nn.init.normal_(self.pos_embedding.weight, std=0.02)
def forward(self, x):
"""
Args:
x: [batch_size, seq_len, d_model]
"""
seq_len = x.size(1)
# 创建位置索引
positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
# 添加位置嵌入
x = x + self.pos_embedding(positions)
return self.dropout(x)
3.4 旋转位置编码 (RoPE)¶
现代大模型( LLaMA 、 GPT-NeoX 等)使用 RoPE :
3.4.1 RoPE 的数学推导(从复数乘法出发)¶
RoPE的核心目标:设计一种位置编码函数 f(x, pos),使得两个位置的内积
只依赖于相对位置差 (m - n),而非绝对位置 m 和 n:
<f(q, m), f(k, n)> = g(q, k, m - n)
第一步:复数表示
将向量的每两个相邻维度视为一个复数:
z = x_{2i} + j·x_{2i+1} (j 是虚数单位)
第二步:复数乘法 = 旋转
在复数平面上,乘以 e^{jθ} 等价于旋转角度 θ:
z · e^{jθ} = (x_{2i} + j·x_{2i+1})(cosθ + j·sinθ)
= (x_{2i}·cosθ - x_{2i+1}·sinθ) + j·(x_{2i}·sinθ + x_{2i+1}·cosθ)
写成矩阵形式:
[x_{2i}' ] [cosθ -sinθ] [x_{2i} ]
[x_{2i+1}'] = [sinθ cosθ] [x_{2i+1}]
第三步:位置相关的旋转角度
对位置 m 的第 i 对维度,旋转角度 θ_i(m) = m · ω_i
其中 ω_i = 1 / 10000^{2i/d}(频率随维度指数递减)
第四步:验证相对位置性质
对位置 m 的 query 和位置 n 的 key(在第 i 对维度上):
f(q, m) = q · e^{j·m·ω_i}
f(k, n) = k · e^{j·n·ω_i}
内积(复数):
Re[f(q, m) · conj(f(k, n))]
= Re[q · e^{j·m·ω_i} · conj(k · e^{j·n·ω_i})]
= Re[q · conj(k) · e^{j·(m-n)·ω_i}]
结果只依赖 (m-n),证毕。✓
第五步:推广到全维度
对 d 维向量,每两维一组,共 d/2 组,每组用不同频率 ω_i 旋转:
RoPE(x, pos) = [R(pos·ω_0)·x_{0:1}, R(pos·ω_1)·x_{2:3}, ..., R(pos·ω_{d/2-1})·x_{d-2:d-1}]
其中 R(θ) 是 2×2 旋转矩阵。
3.4.2 RoPE 实现¶
class RotaryPositionalEmbedding(nn.Module):
"""
旋转位置编码 (RoPE)
通过旋转矩阵注入相对位置信息
"""
def __init__(self, dim, max_seq_len=2048, base=10000):
super().__init__()
# 计算旋转角度
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# 预计算位置编码
t = torch.arange(max_seq_len)
freqs = torch.einsum('i,j->ij', t, inv_freq) # [max_seq_len, dim/2]
# 复数形式: cos + i*sin
self.register_buffer('cos_cached', freqs.cos()[None, None, :, :]) # [1, 1, seq_len, dim/2]
self.register_buffer('sin_cached', freqs.sin()[None, None, :, :]) # [1, 1, seq_len, dim/2]
def forward(self, x, seq_len=None):
"""
Args:
x: [batch, heads, seq_len, head_dim]
"""
if seq_len is None:
seq_len = x.shape[2]
cos = self.cos_cached[:, :, :seq_len, :]
sin = self.sin_cached[:, :, :seq_len, :]
return self.apply_rotary_pos_emb(x, cos, sin)
def apply_rotary_pos_emb(self, x, cos, sin):
"""
应用旋转位置编码
Args:
x: [batch, heads, seq_len, head_dim]
cos: [1, 1, seq_len, head_dim/2]
sin: [1, 1, seq_len, head_dim/2]
Returns:
[batch, heads, seq_len, head_dim]
"""
# 将x分成两部分(偶数维和奇数维)
x1, x2 = x[..., ::2], x[..., 1::2]
# 旋转: [x1, x2] @ [[cos, -sin], [sin, cos]]
# 即: x1' = x1 * cos - x2 * sin
# x2' = x1 * sin + x2 * cos
rotated = torch.stack([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1)
# 将最后两维展平,恢复原始形状
return rotated.flatten(-2)
RoPE (旋转位置编码)的核心是:把相邻两个维度视作一个二维平面上的坐标对,再通过位置相关的旋转矩阵把位置信息注入到查询和键向量中。与其依赖质量一般的示意图,不如直接结合上面的代码和后面的公式理解“二维旋转 + 频率分层”这两个关键点。
自注意力机制详解¶
4.1 直觉理解¶
自注意力的核心思想:让序列中每个位置都"看到"其他所有位置,并学习该关注哪些位置。
以"猫坐在垫子上,它很舒服"为例:
处理"它"时,自注意力权重分布(示意):
猫 坐 在 垫子 上 , 它 很 舒服
0.45 0.05 0.02 0.15 0.03 0.01 0.10 0.04 0.15
→ 模型学会了"它"主要指代"猫"(最高权重),次要关联"垫子"和"舒服"
4.2 Q-K-V 的数学推导¶
三个矩阵的角色: - \(Q\)( Query ):当前位置发出的"提问" - \(K\)( Key ):每个位置提供的"索引标签" - \(V\)( Value ):每个位置存储的"实际内容"
类比图书馆:
Q = 你的搜索关键词("深度学习入门")
K = 每本书的标题/关键词索引
V = 每本书的实际内容
检索流程:
1. 计算 Q @ K^T → 搜索词与每本书标题的匹配分数
2. softmax → 将分数归一化为概率
3. 概率 @ V → 加权提取最相关书的内容
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
query: torch.Tensor, # [batch, seq_q, d_k]
key: torch.Tensor, # [batch, seq_k, d_k]
value: torch.Tensor, # [batch, seq_k, d_v]
mask: torch.Tensor = None,
dropout: nn.Dropout = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
缩放点积注意力
关键点:
1. 为什么除以 sqrt(d_k)?(完整推导)
假设 q 和 k 的每个分量独立同分布,均值为0、方差为1:
q = [q_1, q_2, ..., q_{d_k}], q_i ~ (0, 1)
k = [k_1, k_2, ..., k_{d_k}], k_i ~ (0, 1)
点积 q·k = Σ_{i=1}^{d_k} q_i * k_i
每一项 q_i * k_i 的方差:
Var(q_i * k_i) = E[q_i²] * E[k_i²] - (E[q_i] * E[k_i])²
= 1 * 1 - 0 = 1
因为各项独立,所以:
Var(q·k) = Σ Var(q_i * k_i) = d_k
当 d_k=64 时,点积标准差 ≈ 8;d_k=128 时 ≈ 11.3
这些大值会将 softmax 推入饱和区(梯度接近 0)
除以 √d_k 后:
Var(q·k / √d_k) = Var(q·k) / d_k = d_k / d_k = 1
方差归一化为 1,softmax 的输入保持在合理范围,梯度可以正常流动。
2. mask 的两种用途:
- padding mask: 屏蔽填充位置
- causal mask: 屏蔽未来位置(解码器)
"""
d_k = query.size(-1)
# Step 1: QK^T / sqrt(d_k) → [batch, seq_q, seq_k]
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Step 2: 应用 mask(将需要屏蔽的位置设为 -inf)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 3: Softmax → 权重矩阵
attn_weights = F.softmax(scores, dim=-1)
# Step 4: Dropout(训练时)
if dropout is not None:
attn_weights = dropout(attn_weights)
# Step 5: 加权求和 → [batch, seq_q, d_v]
output = torch.matmul(attn_weights, value)
return output, attn_weights
# ---------- 验证 ----------
batch, seq_len, d_k = 2, 5, 64
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)
# 无 mask
out, weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {out.shape}") # [2, 5, 64]
print(f"权重形状: {weights.shape}") # [2, 5, 5]
print(f"权重行和: {weights.sum(-1)}") # 每行和为 1.0
# 因果 mask(解码器用)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)) # 下三角矩阵
out_masked, weights_masked = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(f"\n因果mask权重(第0行只看自己,第4行看所有位置):")
print(weights_masked[0].detach())
4.3 注意力复杂度分析¶
时间复杂度: O(n² · d)
- 其中 n = 序列长度, d = 模型维度
- 主要来自 QK^T 的矩阵乘法
空间复杂度: O(n²)
- 需要存储完整的注意力权重矩阵
序列长度 → 注意力计算量(d=1024)
128 → 16.7M
1,024 → 1.07B
8,192 → 68.7B
128,000 → 16.8T ← 这就是长上下文的挑战
解决方案:
├── FlashAttention: 利用GPU内存层级优化,O(n²)复杂度不变但实际速度快2-4倍
├── GQA (Grouped-Query Attention): 减少KV头数,降低内存
├── MQA (Multi-Query Attention): 所有头共享一组KV
├── Ring Attention: 分布式环形通信处理超长序列
└── 稀疏注意力: 只计算部分位置对的注意力
4.3.1 GPU内存层次结构与访问瓶颈¶
理解FlashAttention的关键在于GPU的内存层次结构:
GPU内存带宽对比(典型A100 GPU):
| 内存类型 | 带宽 | 延迟 |
|----------|------|------|
| HBM (High Bandwidth Memory) | ~1.6 TB/s | ~500ns |
| L2 Cache | ~4.6 TB/s | ~150ns |
| L1 Cache / SRAM | ~10-20 TB/s | ~10-30ns |
关键洞察:
- L1 Cache带宽是HBM的10倍以上
- 延迟差异达50倍
- 减少HBM访问是提升性能的关键
4.3.2 标准Attention的IO复杂度¶
标准Self-Attention的内存访问模式:
标准Attention计算流程:
1. 读取 Q, K, V 矩阵: O(n·d)
2. 计算 S = QK^T: O(n²·d)
3. 计算 P = softmax(S): O(n²)
4. 计算 O = PV: O(n²·d)
5. 存储完整注意力矩阵 S 和 P: O(n²)
总内存访问量: Θ(n² + n·d)
问题分析: - 当 n = 序列长度(如8192),d = 模型维度(如128)时 - n² 远大于 n·d(如8192² = 67M >> 8192×128 = 1M) - 注意力矩阵存储成为真正的瓶颈
4.3.3 FlashAttention的分块策略¶
FlashAttention核心思想:分块计算,避免存储n²矩阵
分块计算流程:
┌─────────────────────────────────────────────────────────────┐
│ FlashAttention 分块策略 │
├─────────────────────────────────────────────────────────────┤
│ │
│ HBM (全局显存) SRAM (片上缓存) │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Q[0:Br] │ load │ Q_block │ │
│ │ K[0:Bc] │ ──────→ │ K_block │ │
│ │ V[0:Bc] │ │ V_block │ │
│ │ O[0:Br] │ write │ O_block │ ← accumulate │
│ └──────────────┘ ←───── └──────────────┘ │
│ │
│ 块大小限制: SRAM大小(如A100 L1=192KB) │
│ Br × d + Bc × d + Br × Bc ≤ 192KB │
└─────────────────────────────────────────────────────────────┘
伪代码:
for block_q in blocks(Q, block_size=Br):
load Q_block to SRAM
for block_kv in blocks(K, V, block_size=Bc):
load K_block, V_block to SRAM
compute partial_attention(Q_block, K_block, V_block)
accumulate result in output_block
write output_block to HBM
关键优势:
- 每个K、V块只从HBM读取一次
- 不存储完整的S和P矩阵
- 输出直接累积在SRAM中
4.3.4 FlashAttention的IO复杂度分析¶
FlashAttention的IO复杂度推导:
设:
- M = SRAM大小(如192KB)
- B = 单块可容纳元素数,B ≈ M / (2d)(考虑Q、K、V、O同时存在)
标准Attention:
- HBM访问量: O(n² + n·d)
- 中间结果存储: O(n²)
FlashAttention:
- 将Q分成T = n/B个块
- 每个Q块处理时,需遍历所有K、V块
- 每个K、V块访问一次 → O(n²·d / B)
- Q、O访问 → O(n·d)
总HBM访问量: O(n·d + n²·d / M)
加速比估算:
- 理想情况下,FlashAttention加速比约为 B/(2d) 倍
- 实际上由于计算重叠和内存访问优化,可达3-4倍加速
4.3.5 实际速度提升数据¶
FlashAttention vs 标准Attention(A100 GPU,FP16):
| 序列长度 | 标准Attention | FlashAttention | 加速比 |
|----------|---------------|----------------|--------|
| 512 | ~15 ms | ~4 ms | 3.75x |
| 2,048 | ~180 ms | ~25 ms | 7.2x |
| 4,096 | ~680 ms | ~80 ms | 8.5x |
| 8,192 | ~2,800 ms | ~250 ms | 11.2x |
显存占用对比:
| 序列长度 | 标准Attention显存 | FlashAttention显存 | 节省比例 |
|----------|-------------------|-------------------|----------|
| 2,048 | 16 MB | 2 MB | 8x |
| 4,096 | 64 MB | 4 MB | 16x |
| 8,192 | 256 MB | 8 MB | 32x |
| 32,768 | 4 GB | 32 MB | 128x |
> 💡 **关键洞察**:序列越长,FlashAttention的优势越明显。这正是长上下文场景(如32K+上下文)必须使用FlashAttention的原因。
4.3.6 FlashAttention演进¶
FlashAttention版本演进:
FlashAttention (2022):
- 核心创新:IO-aware分块计算
- 利用SRAM加速,减少HBM访问
- 加速比:2-4倍
FlashAttention-2 (2023):
- 更好的并行化策略
- 改进工作划分,减少线程束等待
- 加速比:4-8倍
FlashAttention-3 (2024, Hopper架构):
- 利用Tensor Core异步操作
- FP8低精度支持
- H100利用率从35%提升至75%
- 加速比:8-16倍
主流大模型均采用FlashAttention:
GPT-3、Falcon、LLaMA2/3、Mistral、Qwen等
多头注意力机制¶
5.1 为什么需要多头¶
单头注意力的问题:一个注意力头只能学到一种"关注模式"。而自然语言有多种关系维度(语法、语义、共指、因果等)。
多头的直觉:用不同的"眼睛"看同一句话
Head 1: 可能学会了语法依赖(主语→动词)
Head 2: 可能学会了共指关系(代词→先行词)
Head 3: 可能学会了修饰关系(形容词→名词)
Head 4: 可能学会了长距离依赖(问题→答案)
...
5.2 多头注意力实现¶
class MultiHeadAttention(nn.Module):
"""
多头注意力机制
参数量分析(d_model=1024, h=16, d_k=64):
W_Q: 1024×1024 = 1M
W_K: 1024×1024 = 1M
W_V: 1024×1024 = 1M
W_O: 1024×1024 = 1M
总计: 4M 参数
"""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % n_heads == 0, "d_model 必须被 n_heads 整除" # assert断言:条件False时抛出AssertionError
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # 每个头的维度
# 四个投影矩阵
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(
self,
query: torch.Tensor, # [batch, seq_q, d_model]
key: torch.Tensor, # [batch, seq_k, d_model]
value: torch.Tensor, # [batch, seq_k, d_model]
mask: torch.Tensor = None,
) -> torch.Tensor:
batch_size = query.size(0)
# 1. 线性投影 → [batch, seq, d_model]
Q = self.W_Q(query)
K = self.W_K(key)
V = self.W_V(value)
# 2. 分头: [batch, seq, d_model] → [batch, n_heads, seq, d_k]
Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # view重塑张量形状
K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# 3. 缩放点积注意力
d_k = self.d_k
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 4. 加权求和 → [batch, n_heads, seq_q, d_k]
context = torch.matmul(attn_weights, V)
# 5. 合并多头: [batch, seq_q, d_model]
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 6. 输出投影
output = self.W_O(context)
return output
# 验证
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)
# Self-attention: Q=K=V=x
out = mha(x, x, x)
print(f"Self-attention 输出: {out.shape}") # [2, 10, 512]
# Cross-attention: Q来自decoder, K/V来自encoder
enc_out = torch.randn(2, 20, 512)
dec_in = torch.randn(2, 10, 512)
cross_out = mha(dec_in, enc_out, enc_out)
print(f"Cross-attention 输出: {cross_out.shape}") # [2, 10, 512]
5.3 GQA 与 MQA :现代大模型的注意力优化¶
标准 MHA (Multi-Head Attention):
Q: [batch, n_heads, seq, d_k] ← 每个头有独立的Q
K: [batch, n_heads, seq, d_k] ← 每个头有独立的K
V: [batch, n_heads, seq, d_k] ← 每个头有独立的V
KV cache 大小 = 2 × n_heads × seq × d_k
MQA (Multi-Query Attention, GPT-J):
Q: [batch, n_heads, seq, d_k] ← 每个头有独立的Q
K: [batch, 1, seq, d_k] ← 所有头共享一组K
V: [batch, 1, seq, d_k] ← 所有头共享一组V
KV cache 大小 = 2 × 1 × seq × d_k → 减少 n_heads 倍
GQA (Grouped-Query Attention, LLaMA-2/3):
Q: [batch, n_heads, seq, d_k] ← 每个头有独立的Q
K: [batch, n_kv_heads, seq, d_k] ← 每组共享一组K
V: [batch, n_kv_heads, seq, d_k] ← 每组共享一组V
KV cache 大小 = 2 × n_kv_heads × seq × d_k
示例 (LLaMA-3-70B): n_heads=64, n_kv_heads=8
→ 每8个Q头共享1组KV → KV cache减少8倍,性能接近MHA
前馈神经网络¶
6.1 标准 FFN¶
每个 Transformer 层中的全连接前馈网络( Position-wise FFN ):
class FeedForward(nn.Module):
"""标准FFN,4倍扩展"""
def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
super().__init__()
d_ff = d_ff or 4 * d_model # 默认4倍扩展
self.w1 = nn.Linear(d_model, d_ff)
self.w2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# [batch, seq, d_model] → [batch, seq, d_ff] → [batch, seq, d_model]
return self.w2(self.dropout(F.relu(self.w1(x))))
6.2 现代激活函数: GeLU 与 SwiGLU¶
# GeLU (GPT-2/BERT 使用)
# GELU(x) = x · Φ(x),其中 Φ 是标准正态分布的CDF
# 比 ReLU 更平滑,不会在 x=0 处产生不可导点
class GeLUFeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
super().__init__()
d_ff = d_ff or 4 * d_model
self.w1 = nn.Linear(d_model, d_ff)
self.w2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w2(self.dropout(F.gelu(self.w1(x))))
# SwiGLU (LLaMA/Qwen/Mistral 使用)
# SwiGLU(x, W, V, W2) = (Swish(xW) ⊙ xV) W2
# 引入门控机制,效果优于 GeLU,但参数量增加 50%
class SwiGLUFeedForward(nn.Module):
"""LLaMA-style SwiGLU FFN"""
def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
super().__init__()
# LLaMA 使用 8/3 * d_model 作为 d_ff 以保持参数量不变
d_ff = d_ff or int(8 / 3 * d_model)
# 取最接近的 256 的倍数(GPU对齐优化)
d_ff = ((d_ff + 255) // 256) * 256
self.w1 = nn.Linear(d_model, d_ff, bias=False) # gate projection
self.w3 = nn.Linear(d_model, d_ff, bias=False) # up projection
self.w2 = nn.Linear(d_ff, d_model, bias=False) # down projection
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Swish(xW1) ⊙ xW3 → 门控机制
gate = F.silu(self.w1(x)) # SiLU = Swish(x) = x * sigmoid(x)
up = self.w3(x)
return self.w2(self.dropout(gate * up))
# 对比
d_model = 1024
ffn_relu = FeedForward(d_model)
ffn_gelu = GeLUFeedForward(d_model)
ffn_swiglu = SwiGLUFeedForward(d_model)
for name, module in [("ReLU FFN", ffn_relu), ("GeLU FFN", ffn_gelu), ("SwiGLU FFN", ffn_swiglu)]:
params = sum(p.numel() for p in module.parameters())
print(f"{name}: {params:,} 参数")
# ReLU FFN: 8,393,728 参数
# GeLU FFN: 8,393,728 参数
# SwiGLU FFN: 8,650,752 参数 (d_ff调整后参数量接近)
层归一化与残差连接¶
7.1 为什么需要归一化¶
深层网络中,每层输出的分布会不断漂移( Internal Covariate Shift ),导致训练不稳定。归一化将输出拉回标准分布。
7.2 LayerNorm vs RMSNorm¶
class LayerNorm(nn.Module):
"""
标准层归一化 (BERT/GPT-2 使用)
LayerNorm(x) = γ · (x - μ) / √(σ² + ε) + β
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model)) # 可学习缩放
self.beta = nn.Parameter(torch.zeros(d_model)) # 可学习偏移
self.eps = eps
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
class RMSNorm(nn.Module):
"""
RMS层归一化 (LLaMA/Qwen/Mistral 使用)
RMSNorm(x) = γ · x / RMS(x)
其中 RMS(x) = √(1/n · Σx²)
优势:去掉了均值计算和偏移参数 β
→ 计算更快(约10-15%),效果相当
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return self.gamma * x / rms
# 对比
x = torch.randn(2, 10, 512)
ln = LayerNorm(512)
rmsn = RMSNorm(512)
out_ln = ln(x)
out_rmsn = rmsn(x)
print(f"LayerNorm 输出统计: mean={out_ln.mean():.4f}, std={out_ln.std():.4f}")
print(f"RMSNorm 输出统计: mean={out_rmsn.mean():.4f}, std={out_rmsn.std():.4f}")
7.3 Pre-LN vs Post-LN:结构对比与数学分析¶
7.3.1 两种结构的对比¶
┌─────────────────────────────────────────────────────────────────────────────┐
│ Post-LN (原始 Transformer) │
│ │
│ x → Attention → + residual → LayerNorm → FFN → + residual → LayerNorm │
│ │
│ 特点:LayerNorm 在残差连接之后(最后) │
│ 缺点:深层训练极不稳定,需要 warm-up 学习率预热 │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────┐
│ Pre-LN (现代大模型标配) │
│ │
│ x → LayerNorm → Attention → + residual │
│ x → LayerNorm → FFN → + residual │
│ │
│ 特点:LayerNorm 在残差连接之前(每个子层内部) │
│ 优点:梯度直接传播,训练稳定,无需 warm-up │
└─────────────────────────────────────────────────────────────────────────────┘
7.3.2 为什么 Pre-LN 更稳定:数学分析¶
梯度分析:
设第 l 层输入为 x_l,输出为 x_{l+1}:
Post-LN 梯度流动:
∂L/∂x_l = ∂L/∂x_{l+1} · ∂x_{l+1}/∂x_l
其中 ∂x_{l+1}/∂x_l 包含 LayerNorm 的雅可比行列式
→ 当网络加深时,LayerNorm 的梯度依赖于其输入的统计量
→ 深层梯度容易消失或爆炸,导致训练不稳定
Pre-LN 梯度流动:
每一层的梯度都有独立的"直接路径"回到输入:
∂L/∂x_l = ∂L/∂x_{l+1} · (I + J_l)
其中 I 是恒等矩阵,J_l 是子层Jacobian
→ 残差连接确保梯度可以无损传递
→ 即使网络很深,梯度也能稳定回传
训练动态对比:
| 特性 | Post-LN | Pre-LN |
|---|---|---|
| 梯度幅度 | 输出层附近梯度极大 | 各层梯度幅度均匀 |
| 训练稳定性 | 需 warm-up,避免爆炸 | 无需 warm-up,稳定训练 |
| 深层网络 | 超过6层难以训练 | 可训练数百层 |
| 典型深度 | ≤ 12 层 | 12 → 100+ 层 |
论文《On Layer Normalization in the Transformer Architecture》证明: - Post-LN 在训练初期,输出层附近参数梯度期望值很大 - Pre-LN 将 LN 置于残差分支内部,梯度直接传递
7.3.3 不同大模型的 LN 位置¶
| 模型 | LN 类型 | 位置 | 备注 |
|---|---|---|---|
| BERT | Post-LN | 残差后 | 原始 Transformer,12 层 |
| GPT-2 | Pre-LN | 残差前 | OpenAI,揭示 LN 位置重要性 |
| T5 | Pre-LN | 残差前 | Google,Encoder-Decoder |
| LLaMA | Pre-LN + RMSNorm | 残差前 | Meta AI,现代 LLM 标配 |
| GPT-3 | Pre-LN | 残差前 | 175B 参数,深层训练关键 |
| PaLM | Pre-LN + RMSNorm | 残差前 | Google,540B 参数 |
| DeepSeek-V3 | Pre-LN + RMSNorm | 残差前 | MoE + Pre-LN 支持超深网络 |
💡 规律:所有超过 30 层的大模型(GPT-3/PaLM/LLaMA/DeepSeek)都使用 Pre-LN。
7.3.4 Pre-LN 的潜在问题与解决方案¶
虽然 Pre-LN 更稳定,但研究发现浅层梯度可能大于深层(与 Post-LN 正好相反):
梯度幅度分布:
Post-LN:深层 > 浅层
Pre-LN: 浅层 > 深层
解决方案:DeepNorm (Microsoft)
在 Pre-LN 基础上引入缩放因子:
x_{l+1} = α · x_l + Sublayer( LN(x_l) )
α 的设置:
- 普通 Transformer:α = (2N)^(1/4),N 为层数
- 极深网络(如 1000 层):α 动态调整
7.3.5 残差连接的几何意义¶
残差连接不仅仅是"缓解梯度消失"的技术,从几何视角看,它有着更深层的意义。
7.3.5.1 直觉解释:梯度高速公路¶
残差连接的本质:提供一条梯度高速公路
考虑等效电路:
输入 x ──┬──→ [主路径: 多层网络] ──→ 输出
│
└──→ [跳远连接: identity] ──→ 输出 +
如果主路径学不到东西,梯度可以直接绕过去
这使得深层网络的训练更加稳定
这个"短路"设计让网络学习变得更容易: - 最坏情况:子层什么都不学习,输出为 0 → 输出 = 输入(恒等映射) - 最好情况:子层学习到有用的变换 → 输出 = 输入 + 有用变换
7.3.5.2 梯度路径分析:多条并行通道¶
没有残差:梯度必须穿过每一层网络
∂L/∂x = ∂L/∂Layer₁ → ∂L/∂Layer₂ → ... → ∂L/∂Layerₙ
有残差连接:梯度有多条路径
∂L/∂x = ∂L/∂直接路径 + ∂L/∂Layer₁路径 + ∂L/∂Layer₂路径 + ...
深层梯度消失问题被有效缓解
数学上,对于残差块:
其中 I 是单位矩阵,保证梯度下界至少为 1,有效防止梯度消失。
7.3.5.3 信息流视角:流形上的增量更新¶
残差连接确保了深层网络的信息流动性:
浅层特征 ——直接传递——→ 深层特征
↗ ↘
通过残差 通过多层变换
↘ ↗
——混合后的特征——
这使得网络可以学习恒等映射:f(x) ≈ x
从流形学习角度看: - 深层特征位于输入空间的高维流形上 - 残差连接允许在流形上"原地踏步"(恒等映射) - 浅层特征可以直接传递到深层,保持语义信息不被破坏
7.3.5.4 主成分分析视角¶
从PCA角度看残差连接的作用:
无残差网络:
- 每层都在对特征做线性/非线性变换
- 深层特征可能被"扭曲",偏离原始流形
- 重要主成分可能在多层变换中丢失
有残差网络:
- 恒等分支保持原始主成分方向
- 变换分支学习"残差"方向
- 浅层学到的语义可以无损传递到深层
这解释了为什么极深网络(如 100+ 层)可以正常训练:残差连接保持了特征空间的"骨架"结构。
7.3.6 现代大模型的共识¶
Pre-LN + RMSNorm 已成为现代大模型的标准配置:
LLaMA 架构:
x → RMSNorm → Attention → + residual
x → RMSNorm → SwiGLU FFN → + residual
LLaMA 2/3 代码示例:
```python
class TransformerBlock(nn.Module):
"""Pre-Norm Transformer Block (LLaMA-style)"""
def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
super().__init__()
self.norm1 = RMSNorm(d_model)
self.attn = MultiHeadAttention(d_model, n_heads, dropout)
self.norm2 = RMSNorm(d_model)
self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)
def forward(self, x, mask=None):
# Pre-Norm + Residual
x_norm = self.norm1(x)
x = x + self.attn(x_norm, x_norm, x_norm, mask)
x = x + self.ffn(self.norm2(x))
return x
编解码器架构对比¶
8.1 三种主流架构¶
这是理解现代大模型最关键的知识点之一:
┌────────────────────────────────────────────────────────────────┐
│ 三种 Transformer 架构 │
├──────────────┬──────────────┬──────────────┬──────────────────┤
│ │ Encoder-Only │ Decoder-Only │ Encoder-Decoder │
├──────────────┼──────────────┼──────────────┼──────────────────┤
│ 代表模型 │ BERT │ GPT/LLaMA │ T5/BART │
│ 注意力类型 │ 双向全注意力 │ 因果注意力 │ 双向+因果+交叉 │
│ 训练目标 │ MLM + NSP │ 下一token预测 │ Seq2Seq │
│ 擅长任务 │ 分类/NER/QA │ 生成/对话 │ 翻译/摘要 │
│ 上下文 │ 看到全部输入 │ 只看到左侧 │ 编码器全部,解码器左侧│
│ 当前趋势 │ 逐渐减少 │ 主流(GPT时代) │ 特定场景使用 │
└──────────────┴──────────────┴──────────────┴──────────────────┘
为什么 Decoder-Only 成为主流?
1. 统一的训练目标:所有任务都可以转化为"文本生成"
2. 规模效应更好:参数量越大,涌现能力越强
3. 少样本能力:通过 prompt 即可适配新任务,无需微调
4. 工程简单:只需一个模型处理所有任务
8.2 因果注意力掩码¶
Decoder-Only 模型的核心:确保位置 \(i\) 只能看到位置 \(0, 1, \ldots, i\)(不泄露未来信息)。
def create_causal_mask(seq_len: int) -> torch.Tensor:
"""
创建因果注意力掩码
例如 seq_len=4:
[[1, 0, 0, 0], ← 位置0只看自己
[1, 1, 0, 0], ← 位置1看0和1
[1, 1, 1, 0], ← 位置2看0、1、2
[1, 1, 1, 1]] ← 位置3看所有
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask
# 训练时:一次前向传播同时预测所有位置
# 因为有因果mask,每个位置只看到左侧,等价于自回归
Transformer 的训练¶
9.1 训练目标¶
Decoder-Only (如 GPT) 的训练:
输入: [BOS] The cat sat on the mat
标签: The cat sat on the mat [EOS]
损失函数: Cross-Entropy Loss
L = -1/T Σ_t log P(x_t | x_{<t})
即最大化每个位置正确预测下一个token的对数概率。
整个序列只需一次前向传播 (teacher forcing)。
9.2 学习率调度¶
Transformer 训练中最关键的超参数管理:
class WarmupCosineScheduler:
"""
Warmup + Cosine Decay 学习率调度
几乎所有现代大模型都使用这种方式
阶段1 (warmup): 学习率从0线性增长到max_lr
阶段2 (cosine): 学习率从max_lr余弦衰减到min_lr
"""
def __init__(
self,
optimizer,
warmup_steps: int,
total_steps: int,
max_lr: float = 3e-4,
min_lr: float = 1e-5,
):
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.max_lr = max_lr
self.min_lr = min_lr
self.step_count = 0
def step(self):
self.step_count += 1
lr = self.get_lr()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def get_lr(self):
if self.step_count < self.warmup_steps:
# 线性 warmup
return self.max_lr * self.step_count / self.warmup_steps
else:
# 余弦衰减
progress = (self.step_count - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
1 + math.cos(math.pi * progress)
)
# ---------- 完整训练循环示例 ----------
# 以下代码展示了 Decoder-Only Transformer 的标准训练流程
def train_decoder_only(model, dataloader, vocab_size, num_steps=10000, device='cpu'):
"""
Decoder-Only Transformer 训练循环
核心流程:
1. 输入序列 [BOS, w1, w2, ..., wT] 送入模型
2. 模型在每个位置预测下一个 token
3. 标签是输入右移一位: [w1, w2, ..., wT, EOS]
4. 用交叉熵损失计算误差并反向传播
"""
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=2000, total_steps=num_steps)
model.train()
for step, batch in enumerate(dataloader):
if step >= num_steps:
break
input_ids = batch['input_ids'].to(device) # [B, T]
targets = input_ids[:, 1:] # 右移一位作为标签
logits = model(input_ids[:, :-1]) # [B, T-1, vocab]
# 交叉熵损失:每个位置预测下一个 token
loss = F.cross_entropy(
logits.reshape(-1, vocab_size), # [B*(T-1), vocab]
targets.reshape(-1) # [B*(T-1)]
)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 梯度裁剪:防止梯度爆炸(Transformer 训练的关键技巧)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
if (step + 1) % 1000 == 0:
print(f"Step {step+1}/{num_steps} | Loss: {loss.item():.4f} | LR: {scheduler.get_lr():.6f}")
return model
# 注意:完整的可运行训练代码请参考 03-手写Transformer完整实现.md
Transformer 的推理¶
10.1 自回归生成¶
训练是并行的(一次处理整个序列),
推理是自回归的(一次生成一个token):
Step 0: 输入 [BOS] → 预测 "The"
Step 1: 输入 [BOS, The] → 预测 "cat"
Step 2: 输入 [BOS, The, cat] → 预测 "sat"
...
问题:每一步都要重新计算之前所有位置的注意力 → 巨大浪费
解决:KV Cache
10.2 KV Cache¶
KV Cache 是大模型推理最重要的优化技术:
"""KV Cache 原理示意"""
class CachedMultiHeadAttention(nn.Module):
"""带KV Cache的多头注意力(推理用)"""
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, kv_cache=None):
"""
Args:
x: [batch, 1, d_model] — 推理时只有当前新token
kv_cache: (cached_K, cached_V) — 之前步骤的K/V
Returns:
output, new_kv_cache
"""
B = x.size(0)
# 只为当前 token 计算 Q/K/V
Q = self.W_Q(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
K_new = self.W_K(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
V_new = self.W_V(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
if kv_cache is not None:
K_cached, V_cached = kv_cache
# 拼接历史 KV → [batch, n_heads, seq_so_far+1, d_k]
K = torch.cat([K_cached, K_new], dim=2)
V = torch.cat([V_cached, V_new], dim=2)
else:
K, V = K_new, V_new
# 只计算当前 Q 与所有 K 的注意力 → O(1×S) 而非 O(S×S)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V)
out = out.transpose(1, 2).contiguous().view(B, 1, self.d_model)
output = self.W_O(out)
return output, (K, V) # 返回更新后的 cache
"""
无 KV Cache: 生成 T 个token → 总计算量 O(T² · d)
有 KV Cache: 生成 T 个token → 总计算量 O(T · d) + 缓存内存
速度提升: ~T倍(序列越长,加速越明显)
代价: 需要额外内存存储 KV Cache
LLaMA-2-7B, seq_len=4096:
KV cache = 2(K+V) × 32(layers) × 32(heads) × 4096(seq) × 128(d_k) × 2(bytes/fp16)
≈ 2GB
"""
10.3 解码策略¶
def generate(
model, tokenizer, prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 0.9,
):
"""
完整的文本生成函数,支持多种解码策略
"""
input_ids = tokenizer.encode(prompt, return_tensors="pt")
kv_cache = None
generated = []
for _ in range(max_new_tokens):
# 只输入最新 token(有 KV Cache 时)
if kv_cache is not None:
x = input_ids[:, -1:]
else:
x = input_ids
logits, kv_cache = model(x, kv_cache=kv_cache)
logits = logits[:, -1, :] # 只取最后一个位置
# Temperature scaling
logits = logits / temperature
# Top-K: 只保留概率最高的K个token
if top_k > 0:
topk_logits, topk_indices = torch.topk(logits, top_k)
logits = torch.full_like(logits, float('-inf'))
logits.scatter_(1, topk_indices, topk_logits)
# Top-P (Nucleus Sampling)
probs = F.softmax(logits, dim=-1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# 移除累积概率超过 top_p 的token
remove_mask = cumulative_probs - sorted_probs > top_p
sorted_probs[remove_mask] = 0
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
# 采样
next_token = torch.multinomial(sorted_probs, 1)
next_token = sorted_indices.gather(1, next_token)
if next_token.item() == tokenizer.eos_token_id:
break
generated.append(next_token.item())
input_ids = torch.cat([input_ids, next_token], dim=-1)
return tokenizer.decode(generated)
📝 本章小结¶
核心知识点回顾¶
| 组件 | 核心公式/概念 | 现代大模型中的演进 |
|---|---|---|
| 输入嵌入 | \(E \in \mathbb{R}^{V \times d}\),缩放 \(\sqrt{d_{model}}\) | 子词分词(BPE/SentencePiece) |
| 位置编码 | Sinusoidal / Learnable / RoPE | RoPE 成为 LLM 标配 |
| 自注意力 | \(\text{softmax}(QK^T / \sqrt{d_k})V\) | FlashAttention ⅔ 加速 |
| 多头注意力 | \(h\) 个头并行,Concat + \(W^O\) | GQA/MQA 减少 KV Cache |
| 前馈网络 | \(\text{ReLU}(xW_1)W_2\) | SwiGLU(门控 + SiLU 激活) |
| 层归一化 | LayerNorm / RMSNorm | Pre-LN + RMSNorm 成为标配 |
| 残差连接 | \(x + \text{Sublayer}(x)\) | 确保梯度稳定传播 |
| 解码策略 | Greedy / Top-K / Top-P / Temperature | 自回归 + KV Cache |
关键设计决策总结¶
现代 LLM(如 LLaMA-3)的架构选择:
├── 位置编码: RoPE(旋转位置编码,支持相对位置)
├── 注意力: GQA(分组查询注意力,平衡性能和效率)
├── FFN: SwiGLU(门控线性单元,效果优于 ReLU/GeLU)
├── 归一化: RMSNorm(比 LayerNorm 快 10-15%)
├── LN 位置: Pre-LN(训练更稳定,无需 warm-up)
├── 加速: FlashAttention-2/3(减少 HBM 访问)
└── 架构: Decoder-Only(统一生成范式)
思考题¶
-
RoPE 的外推性:RoPE 理论上可以处理任意长度的序列,但实践中长序列性能会下降。为什么?有哪些方法可以改善 RoPE 的长度外推?(提示:NTK-aware Scaling、YaRN、Position Interpolation)
-
FlashAttention 的精度:FlashAttention 使用 online softmax 近似,与标准注意力在数值上可能有微小差异。在实际训练中,这种差异是否会影响模型质量?为什么?
-
KV Cache 压缩:除了 GQA/MQA,还有哪些方法可以减少 KV Cache 的内存占用?(提示:MQA with Quantization、Token Eviction、Cross-Layer Sharing)
-
Pre-LN 的局限:虽然 Pre-LN 训练更稳定,但有研究发现它可能导致"表征塌缩"(Representation Collapse),即深层的输出趋于相似。你如何看待这个问题?
延伸阅读¶
- 完整实现:如需从零手写完整 Transformer 并训练,请参考 03-手写 Transformer 完整实现
- 注意力机制扩展: FlashAttention 、稀疏注意力等高级主题,请参考 02-注意力机制详解
- 大模型训练进阶:分布式训练、混合精度等,请参考 大模型核心技术
最后更新日期: 2026-03-26 适用版本: LLM 学习教程 v2026
