Building a Small LLM from Scratch¶
⚠️ Timeliness note: this chapter touches on frontier models, pricing, and leaderboards, which can change quickly between versions; defer to the original papers, official release pages, and API documentation.
Learning goals: implement a complete small large language model (Mini-LLM) from scratch — the LLaMA architecture, Tokenizer training, and the pretraining pipeline — to truly understand how an LLM works internally.
📌 Positioning: this parallels happy-LLM Ch5 (building a large model by hand). Our implementation covers the full LLaMA-2 architecture (RoPE/RMSNorm/SwiGLU/GQA), BPE Tokenizer training, and end-to-end pretraining. The code runs, tensor shapes are clearly annotated, and the engineering practices are closer to real-world use.
Table of Contents¶
- 1. Why Build an LLM from Scratch
- 2. The LLaMA Architecture at a Glance
- 3. Core Component Implementations
- 4. The Complete LLaMA Model
- 5. BPE Tokenizer Training
- 6. Pretraining Pipeline
- 7. Model Evaluation and Generation
- 8. Engineering Optimization Tricks
1. Why Build an LLM from Scratch¶
Text Only
What you gain by building an LLM from scratch:
━━━━━━━━━━━━━━━━━━━━━━━━━━━
1. Understand what every single line of code does
2. See exactly where the "large" in "large model" comes from
3. Master the core pretraining workflow
4. Build the foundation for later fine-tuning / deployment / optimization
5. Explain the low-level implementation confidently in interviews
Our implementation targets:
├── Architecture: LLaMA-2 style (~15M parameters)
├── Tokenizer: train a BPE tokenizer
├── Pretraining: CLM training on a small Chinese corpus
└── Generation: autoregressive text generation + multiple sampling strategies
2. The LLaMA Architecture at a Glance¶
2.1 Architecture Overview¶
Text Only
LLaMA-2 architecture (our Mini version)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Input Token IDs: [B, L]
        ↓
┌─────────────────────────┐
│     Token Embedding     │  [B, L] → [B, L, D]
└─────────────────────────┘
        ↓
┌─────────────────────────┐  ×N layers
│  ┌───────────────────┐  │
│  │      RMSNorm      │  │  Pre-Norm
│  └───────────────────┘  │
│          ↓              │
│  ┌───────────────────┐  │
│  │   GQA Attention   │  │  + RoPE position encoding
│  │   (causal mask)   │  │
│  └───────────────────┘  │
│          ↓ + Residual   │
│  ┌───────────────────┐  │
│  │      RMSNorm      │  │  Pre-Norm
│  └───────────────────┘  │
│          ↓              │
│  ┌───────────────────┐  │
│  │    SwiGLU FFN     │  │  8/3 × D hidden
│  └───────────────────┘  │
│          ↓ + Residual   │
└─────────────────────────┘
        ↓
┌─────────────────────────┐
│         RMSNorm         │  Final Norm
└─────────────────────────┘
        ↓
┌─────────────────────────┐
│ Linear (→ vocab_size)   │  LM Head (weight tying)
└─────────────────────────┘
        ↓
Output Logits: [B, L, V]
2.2 Differences from the Original Transformer¶
| Component | Original Transformer | LLaMA-2 | Why the change |
|---|---|---|---|
| Normalization | LayerNorm (Post-Norm) | RMSNorm (Pre-Norm) | More stable training, cheaper to compute |
| Position encoding | Sinusoidal | RoPE (rotary) | Relative positions, better length extrapolation |
| Activation | ReLU | SwiGLU | Better quality (empirically) |
| Attention | MHA | GQA | Smaller KV Cache |
| FFN width | 4×D | 8/3×D (≈2.67×D) | SwiGLU adds a gate branch, so the width is rescaled for parameter parity |
| Bias | Yes | No | Fewer parameters, no quality loss |
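The 8/3×D width is not used verbatim: it is rounded up to a hardware-friendly multiple (multiple_of=256 in LLaMA-2). A minimal sketch of the rounding rule, which recovers the published FFN widths of the 7B and 13B models:

```python
def swiglu_hidden_dim(dim: int, multiple_of: int = 256) -> int:
    # 2/3 of the usual 4*dim width, rounded UP to the nearest multiple_of
    hidden = int(2 * 4 * dim / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

assert swiglu_hidden_dim(4096) == 11008  # LLaMA-2 7B FFN width
assert swiglu_hidden_dim(5120) == 13824  # LLaMA-2 13B FFN width
```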
2.3 Layer-by-Layer Parameter Count Analysis¶
Understanding where the parameters live is key to designing a model architecture. Below we break down the parameter sources layer by layer, using the LLaMA-2 family as the example:
Python
def analyze_llama_params(dim, n_layers, n_heads, n_kv_heads, vocab_size,
                         multiple_of=256, tie_weights=False):
    """
    Layer-by-layer parameter breakdown for a LLaMA-style model.
    Each Transformer layer contains:
      - Attention: wq, wk, wv, wo (4 linear layers)
      - FFN (SwiGLU): w1, w2, w3 (3 linear layers)
      - Norm: 2 RMSNorms
    Note: the official LLaMA-2 checkpoints keep a separate (untied) LM head,
    so pass tie_weights=False for them; small models often tie the LM head
    to the embedding, as our Mini model does.
    """
    head_dim = dim // n_heads
    # SwiGLU hidden_dim = 2/3 * 4 * dim, rounded up to a multiple of multiple_of
    hidden_dim = int(2 * 4 * dim / 3)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    # === Per-layer parameters ===
    # Attention: wq(dim→dim), wk(dim→n_kv_heads*head_dim), wv(dim→n_kv_heads*head_dim), wo(dim→dim)
    attn_params = dim * n_heads * head_dim + 2 * dim * n_kv_heads * head_dim + n_heads * head_dim * dim
    # FFN (SwiGLU): w1(dim→hidden), w2(hidden→dim), w3(dim→hidden)
    ffn_params = dim * hidden_dim + hidden_dim * dim + dim * hidden_dim
    # Norm: 2 RMSNorms, dim parameters each
    norm_params = 2 * dim
    per_layer = attn_params + ffn_params + norm_params
    # === Global parameters ===
    embedding_params = vocab_size * dim  # Token Embedding
    final_norm_params = dim              # Final RMSNorm
    lm_head_params = 0 if tie_weights else vocab_size * dim
    total = n_layers * per_layer + embedding_params + final_norm_params + lm_head_params
    print(f"{'='*60}")
    print(f"LLaMA parameter analysis (dim={dim}, layers={n_layers}, heads={n_heads}, kv_heads={n_kv_heads})")
    print(f"{'='*60}")
    print(f"  Attention per layer: {attn_params:>12,} ({attn_params/1e6:.2f}M)")
    print(f"  FFN per layer:       {ffn_params:>12,} ({ffn_params/1e6:.2f}M)")
    print(f"  Norm per layer:      {norm_params:>12,}")
    print(f"  Total per layer:     {per_layer:>12,} ({per_layer/1e6:.2f}M)")
    print(f"  ─────────────────────────────")
    print(f"  All N layers:        {n_layers * per_layer:>12,} ({n_layers*per_layer/1e6:.2f}M)")
    print(f"  Embedding:           {embedding_params:>12,} ({embedding_params/1e6:.2f}M)")
    print(f"  LM Head:             {lm_head_params:>12,} {'(tied)' if tie_weights else ''}")
    print(f"  Final Norm:          {final_norm_params:>12,}")
    print(f"  ═════════════════════════════")
    print(f"  Total parameters:    {total:>12,} ({total/1e6:.2f}M)")
    print(f"  Attention share: {n_layers * attn_params / total * 100:.1f}%")
    print(f"  FFN share:       {n_layers * ffn_params / total * 100:.1f}%")
    print(f"  Embedding share: {embedding_params / total * 100:.1f}%")
    return total
# Real LLaMA-2 family configurations
print(">>> LLaMA-2 7B")
analyze_llama_params(dim=4096, n_layers=32, n_heads=32, n_kv_heads=32, vocab_size=32000)
print()
print(">>> LLaMA-2 13B")
analyze_llama_params(dim=5120, n_layers=40, n_heads=40, n_kv_heads=40, vocab_size=32000)
print()
# Note: 70B also scales its FFN width (ffn_dim_multiplier=1.3, multiple_of=4096,
# actual hidden 28672), so this simple formula undercounts its FFN slightly
print(">>> LLaMA-2 70B (GQA)")
analyze_llama_params(dim=8192, n_layers=80, n_heads=64, n_kv_heads=8, vocab_size=32000)
print()
print(">>> Our Mini-LLaMA (~15M, tied LM head)")
analyze_llama_params(dim=512, n_layers=8, n_heads=8, n_kv_heads=4, vocab_size=5000, tie_weights=True)
Text Only
Key takeaways:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1. The FFN, not Attention, holds the bulk of the parameters (roughly 60-67%)
2. The Embedding's share is large in small models and small in large models
3. GQA shrinks the Attention KV projections by a factor of n_heads/n_kv_heads
   - LLaMA-2 70B: n_kv_heads=8 vs n_heads=64 → 8× fewer KV parameters
4. Weight tying (sharing the Embedding and LM Head) saves vocab_size × dim
   parameters (used in our Mini model; the official LLaMA-2 keeps them separate)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
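As a cross-check of this breakdown: the closed-form count for the 7B configuration reproduces the official figure exactly, provided the LM head is counted separately (the released LLaMA-2 checkpoints do not tie it to the embedding):

```python
dim, n_layers, vocab = 4096, 32, 32000
hidden = 11008                    # SwiGLU width after rounding to a multiple of 256
attn = 4 * dim * dim              # wq, wk, wv, wo (MHA: n_kv_heads == n_heads)
ffn = 3 * dim * hidden            # w1, w2, w3
per_layer = attn + ffn + 2 * dim  # plus two RMSNorms
total = (n_layers * per_layer     # all layers
         + vocab * dim            # token embedding
         + dim                    # final RMSNorm
         + vocab * dim)           # untied LM head
assert total == 6_738_415_616     # the well-known LLaMA-2 7B parameter count
```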
3. Core Component Implementations¶
3.1 RMSNorm¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization
    Faster than LayerNorm: skips the mean, computes only the root mean square
    Formula: RMSNorm(x) = x / RMS(x) * γ
    where: RMS(x) = sqrt(mean(x²) + ε)
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()  # initialize the parent nn.Module
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable scale γ
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args: x [batch_size, seq_len, dim]
        Returns: normalized x, same shape
        """
        # root mean square over the feature dimension
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # normalize, then scale
        return x / rms * self.weight
# Quick check
norm = RMSNorm(64)
x = torch.randn(2, 10, 64)
out = norm(x)
print(f"RMSNorm: input {x.shape} → output {out.shape}")
# the RMS of the output should be ≈ 1 (γ starts at 1)
rms = torch.sqrt(out.pow(2).mean(dim=-1))
print(f"Mean RMS after normalization: {rms.mean().item():.4f}")  # should be close to 1.0
3.2 RoPE: Rotary Position Embedding¶
Python
class RotaryPositionEmbedding:
    """
    RoPE (Rotary Position Embedding)
    Core idea: encode positions with rotation matrices
    - Rotate Q and K so their dot product naturally carries relative position
    - q_m · k_n = f(q, m) · f(k, n) → depends only on the offset m - n
    """
    @staticmethod  # callable without an instance
    def precompute_freqs(dim: int, max_seq_len: int, theta: float = 10000.0):
        """
        Precompute the frequency table
        freqs[i] = 1 / theta^(2i/dim), i = 0, 1, ..., dim/2 - 1
        """
        # [dim/2] frequencies
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # [max_seq_len] positions
        t = torch.arange(max_seq_len).float()
        # [max_seq_len, dim/2] outer product
        freqs = torch.outer(t, freqs)
        # complex form: e^(iθ) = cos(θ) + i·sin(θ)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis  # [max_seq_len, dim/2], complex tensor
    @staticmethod
    def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor,
                         freqs_cis: torch.Tensor):
        """
        Apply rotary position encoding to Q and K
        Args:
            xq: [B, L, n_heads, head_dim]
            xk: [B, L, n_kv_heads, head_dim]
            freqs_cis: [L, head_dim/2] complex frequencies
        """
        # real → complex: [B, L, H, D] → [B, L, H, D/2] (complex)
        xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # reshape freqs_cis so it broadcasts: [L, D/2] → [1, L, 1, D/2]
        freqs = freqs_cis.unsqueeze(0).unsqueeze(2)
        # rotate: complex multiplication = 2D rotation
        xq_out = torch.view_as_real(xq_complex * freqs).flatten(-2)
        xk_out = torch.view_as_real(xk_complex * freqs).flatten(-2)
        return xq_out.type_as(xq), xk_out.type_as(xk)
# Quick check
freqs = RotaryPositionEmbedding.precompute_freqs(dim=64, max_seq_len=512)
print(f"RoPE frequency table: {freqs.shape}")  # [512, 32]
q = torch.randn(2, 10, 4, 64)  # [B, L, n_heads, head_dim]
k = torch.randn(2, 10, 4, 64)
q_rot, k_rot = RotaryPositionEmbedding.apply_rotary_emb(q, k, freqs[:10])
print(f"After RoPE: Q {q_rot.shape}, K {k_rot.shape}")
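The defining property of RoPE (the Q·K dot product depends only on the relative position m - n, and rotation preserves norms) can be verified numerically. This is a standalone sketch with a single-vector re-implementation, independent of the class above:

```python
import torch

def rope(x, pos, theta=10000.0):
    """Rotate one head vector x (even dim) to position pos."""
    d = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2).float() / d))
    cis = torch.polar(torch.ones_like(freqs), pos * freqs)  # e^{i*pos*θ_j}
    xc = torch.view_as_complex(x.reshape(-1, 2))
    return torch.view_as_real(xc * cis).flatten()

torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)
dot_a = rope(q, 3) @ rope(k, 1)    # positions (3, 1): relative offset 2
dot_b = rope(q, 10) @ rope(k, 8)   # positions (10, 8): same offset 2
assert torch.allclose(dot_a, dot_b, atol=1e-3)                 # depends only on m - n
assert torch.allclose(rope(q, 5).norm(), q.norm(), atol=1e-3)  # norm preserved
```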
3.3 The SwiGLU Feed-Forward Network¶
Python
class SwiGLU_FFN(nn.Module):
    """
    SwiGLU feed-forward network (used by LLaMA)
    Formula: FFN(x) = (Swish(xW₁) ⊙ xW₃)W₂
    where:
    - Swish(x) = x * sigmoid(x) (also called SiLU)
    - ⊙ is elementwise multiplication (the gating mechanism)
    - hidden_dim = 8/3 * dim (≈2.67x, rounded up to the nearest multiple_of)
    """
    def __init__(self, dim: int, hidden_dim: int = None, multiple_of: int = 256):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(2 * (4 * dim) / 3)
            # round up to the nearest multiple of multiple_of
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args: x [B, L, D]
        Returns: [B, L, D]
        """
        # SwiGLU: swish(x @ W1) * (x @ W3) @ W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
# Quick check
ffn = SwiGLU_FFN(dim=128)
x = torch.randn(2, 10, 128)
out = ffn(x)
print(f"SwiGLU FFN: input {x.shape} → output {out.shape}")
print(f"FFN parameters: {sum(p.numel() for p in ffn.parameters()):,}")
print(f"Hidden dimension: {ffn.w1.out_features}")
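The 8/3 factor exists for parameter parity: a standard ReLU FFN has two dim×4·dim matrices (8·dim² parameters), while SwiGLU has three matrices of size dim×hidden. Setting hidden = 8/3·dim keeps the totals nearly equal:

```python
dim = 1024
standard = 2 * dim * (4 * dim)   # W_in + W_out of a classic 4x FFN
hidden = int(8 * dim / 3)        # SwiGLU width chosen for parity (before rounding)
swiglu = 3 * dim * hidden        # w1, w2, w3
assert abs(standard - swiglu) / standard < 0.01  # within 1%
```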
3.4 GQA Attention¶
Python
class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (GQA)
    MHA: n_heads Q heads, n_heads K heads, n_heads V heads
    MQA: n_heads Q heads, 1 K head, 1 V head
    GQA: n_heads Q heads, n_kv_heads K/V heads
    Every n_heads/n_kv_heads Q heads share one KV group → smaller KV Cache
    """
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int = None):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads  # defaults to MHA
        self.head_dim = dim // n_heads
        self.n_rep = n_heads // self.n_kv_heads  # how often each KV group is repeated
        assert dim % n_heads == 0
        assert n_heads % self.n_kv_heads == 0
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads to match the number of Q heads"""
        if self.n_rep == 1:
            return x
        B, L, n_kv_heads, head_dim = x.shape
        x = x.unsqueeze(3).expand(B, L, n_kv_heads, self.n_rep, head_dim)
        return x.reshape(B, L, self.n_heads, head_dim)
    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor,
                mask: torch.Tensor = None):
        """
        Args:
            x: [B, L, D]
            freqs_cis: [L, head_dim/2] RoPE frequencies
            mask: [L, L] causal mask
        Returns: [B, L, D]
        """
        B, L, _ = x.shape
        # Linear projections
        q = self.wq(x).view(B, L, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, L, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, L, self.n_kv_heads, self.head_dim)
        # Apply RoPE
        q, k = RotaryPositionEmbedding.apply_rotary_emb(q, k, freqs_cis)
        # GQA: repeat KV heads
        k = self._repeat_kv(k)  # [B, L, n_heads, head_dim]
        v = self._repeat_kv(v)
        # Transpose to [B, n_heads, L, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Scaled dot-product attention
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale  # [B, H, L, L]
        if mask is not None:
            scores = scores + mask  # masked positions hold -inf
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)  # [B, H, L, head_dim]
        # Merge heads
        output = output.transpose(1, 2).contiguous().view(B, L, -1)
        return self.wo(output)
# Quick check
gqa = GroupedQueryAttention(dim=128, n_heads=8, n_kv_heads=2)
x = torch.randn(2, 10, 128)
freqs = RotaryPositionEmbedding.precompute_freqs(dim=128 // 8, max_seq_len=512)  # per-head dim
# Causal mask
L = 10
mask = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)
out = gqa(x, freqs[:L], mask)
print(f"GQA: input {x.shape} → output {out.shape}")
print(f"  Q heads: {gqa.n_heads}, KV heads: {gqa.n_kv_heads}, repeat factor: {gqa.n_rep}")
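The KV Cache saving from GQA follows directly from the head counts. A small sketch (the MHA row is a hypothetical 70B-scale model with full KV heads, for comparison against LLaMA-2 70B's actual n_kv_heads=8):

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, fp16=2):
    # K and V, each [batch, n_kv_heads, seq_len, head_dim], 2 bytes per fp16 value
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * fp16

mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, head_dim=128, seq_len=4096)
assert mha // gqa == 8  # cache shrinks by n_heads / n_kv_heads
print(f"MHA: {mha / 2**30:.2f} GiB vs GQA: {gqa / 2**30:.2f} GiB per sequence")
```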
4. The Complete LLaMA Model¶
4.1 Transformer Block¶
Python
class TransformerBlock(nn.Module):
    """LLaMA-2 Transformer block (Pre-Norm + residual)"""
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int,
                 multiple_of: int = 256):
        super().__init__()
        self.attention = GroupedQueryAttention(dim, n_heads, n_kv_heads)
        self.feed_forward = SwiGLU_FFN(dim, multiple_of=multiple_of)
        self.attention_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor,
                mask: torch.Tensor = None):
        """
        Pre-Norm + residual connections
        x → RMSNorm → Attention → +x → RMSNorm → FFN → +x
        """
        # Self-attention with residual
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask)
        # FFN with residual
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
4.2 The Complete MiniLLaMA Model¶
Python
from dataclasses import dataclass
@dataclass  # auto-generates __init__ and friends
class MiniLLaMAConfig:
    """Mini-LLaMA model configuration"""
    vocab_size: int = 4096      # vocabulary size (small, for experiments)
    dim: int = 256              # hidden dimension
    n_layers: int = 6           # number of Transformer layers
    n_heads: int = 8            # attention heads
    n_kv_heads: int = 2         # KV heads (GQA)
    max_seq_len: int = 512      # maximum sequence length
    multiple_of: int = 64       # FFN width alignment
    rope_theta: float = 10000.0 # RoPE base frequency
    dropout: float = 0.1        # dropout rate
class MiniLLaMA(nn.Module):
    """
    Mini-LLaMA: a complete LLaMA-2-style language model
    ~15M parameters, trainable on a single GPU
    """
    def __init__(self, config: MiniLLaMAConfig):
        super().__init__()
        self.config = config
        # Token Embedding
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        # Transformer Blocks
        self.layers = nn.ModuleList([
            TransformerBlock(
                dim=config.dim,
                n_heads=config.n_heads,
                n_kv_heads=config.n_kv_heads,
                multiple_of=config.multiple_of
            )
            for _ in range(config.n_layers)
        ])
        # Final RMSNorm
        self.norm = RMSNorm(config.dim)
        # LM Head (weight-tied with the embedding; the official LLaMA-2 keeps
        # a separate head, but tying is standard practice for small models)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight  # weight tying
        # Precompute RoPE frequencies
        head_dim = config.dim // config.n_heads
        self.freqs_cis = RotaryPositionEmbedding.precompute_freqs(
            head_dim, config.max_seq_len, config.rope_theta
        )
        # Dropout
        self.dropout = nn.Dropout(config.dropout)
        # Initialize weights
        self.apply(self._init_weights)
    def _init_weights(self, module):
        """Xavier-uniform for linear layers, normal for embeddings"""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0, std=0.02)
    def forward(self, tokens: torch.Tensor, targets: torch.Tensor = None):
        """
        Forward pass
        Args:
            tokens:  [B, L] input token IDs
            targets: [B, L] target token IDs (provided during training)
        Returns:
            loss (scalar) if targets is given, else logits [B, L, V]
        """
        B, L = tokens.shape
        device = tokens.device
        # Token Embedding
        h = self.tok_embeddings(tokens)  # [B, L, D]
        h = self.dropout(h)
        # RoPE frequencies for the current length
        freqs_cis = self.freqs_cis[:L].to(device)
        # Causal mask
        mask = torch.triu(
            torch.full((L, L), float('-inf'), device=device),
            diagonal=1
        )
        # Run all Transformer layers
        for layer in self.layers:
            h = layer(h, freqs_cis, mask)
        # Final Norm
        h = self.norm(h)
        # LM Head
        logits = self.output(h)  # [B, L, V]
        if targets is not None:
            # Cross-entropy loss
            # logits: [B, L, V] → [B*L, V]; targets: [B, L] → [B*L]
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100  # skip padding positions (PyTorch's default ignore value)
            )
            return loss
        return logits
    @torch.no_grad()  # no gradient tracking during inference
    def generate(self, prompt_tokens: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
        """
        Autoregressive text generation
        Args:
            prompt_tokens: [1, prompt_len] prompt token sequence
            max_new_tokens: maximum number of tokens to generate
            temperature: higher → more random
            top_k: Top-K sampling
            top_p: nucleus sampling
        """
        self.eval()
        tokens = prompt_tokens.clone()
        for _ in range(max_new_tokens):
            # Truncate to the maximum context length
            context = tokens[:, -self.config.max_seq_len:]
            # Forward pass for logits
            logits = self(context)
            # Keep only the last position
            next_logits = logits[:, -1, :] / temperature  # [1, V]
            # Top-K filtering
            if top_k > 0:
                top_k_val = min(top_k, next_logits.size(-1))
                indices_to_remove = next_logits < torch.topk(next_logits, top_k_val)[0][:, -1:]
                next_logits[indices_to_remove] = float('-inf')
            # Top-P (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # Drop tokens once cumulative probability exceeds top_p,
                # shifted right so the first token past the threshold survives
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_logits[indices_to_remove] = float('-inf')
            # Sample
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]
            tokens = torch.cat([tokens, next_token], dim=-1)
        return tokens
# Build the model
config = MiniLLaMAConfig()
model = MiniLLaMA(config)
# Parameter counts
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Mini-LLaMA model stats:")
print(f"  Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Config: dim={config.dim}, layers={config.n_layers}, heads={config.n_heads}")
print(f"  KV heads: {config.n_kv_heads} (GQA, {config.n_heads//config.n_kv_heads}x fewer KV heads)")
print(f"  Max sequence length: {config.max_seq_len}")
# Smoke-test the forward pass
dummy_input = torch.randint(0, config.vocab_size, (2, 32))
dummy_target = torch.randint(0, config.vocab_size, (2, 32))
loss = model(dummy_input, dummy_target)
print(f"\nForward pass test: loss = {loss.item():.4f}")
print(f"Expected loss at random init: {math.log(config.vocab_size):.4f} (ln({config.vocab_size}))")
5. BPE Tokenizer Training¶
5.1 How the BPE Algorithm Works¶
Text Only
Byte Pair Encoding (BPE)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Init:    split the text into individual characters (or bytes)
Iterate: merge the most frequent adjacent pair
Repeat:  until the target vocabulary size is reached
Example (English, the classic corpus from Sennrich et al. 2016):
Corpus: "low" × 5, "lower" × 2, "newest" × 6, "widest" × 3
Initial vocab: {'l','o','w','e','r','n','s','t','i','d'}
Step 1: most frequent pair ('e','s'), freq 9 → merge into 'es'
        (tied with ('s','t'); ties broken by first occurrence)
Step 2: ('es','t'), freq 9 → 'est'
Step 3: ('l','o'), freq 7 → 'lo'
Step 4: ('lo','w'), freq 7 → 'low'
...
Result: "lowest" (an unseen word) → ['low', 'est']
        "newest"                  → ['n', 'e', 'w', 'est']
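The first merge in Sennrich et al.'s classic corpus ("low"×5, "lower"×2, "newest"×6, "widest"×3) can be reproduced in a few lines. Note that ('e','s') and ('s','t') are tied at frequency 9; Python's max keeps the first-seen key:

```python
from collections import Counter

corpus = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
pairs = Counter()
for word, freq in corpus.items():
    for a, b in zip(word, word[1:]):
        pairs[(a, b)] += freq

best = max(pairs, key=pairs.get)  # ties broken by first occurrence
assert best == ('e', 's') and pairs[best] == 9
```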
5.2 A Hand-Written BPE Tokenizer¶
Python
import re
from collections import Counter
class SimpleBPETokenizer:
    """
    A hand-written BPE tokenizer
    Handles mixed Chinese/English text
    """
    def __init__(self, vocab_size: int = 4096):
        self.target_vocab_size = vocab_size
        self.merges = {}         # merge rules: (a, b) → ab, in training order
        self.vocab = {}          # token → id
        self.inverse_vocab = {}  # id → token
        # special tokens
        self.special_tokens = {
            "<pad>": 0,
            "<unk>": 1,
            "<bos>": 2,
            "<eos>": 3,
        }
    def _get_stats(self, word_freqs):
        """Count the frequency of adjacent symbol pairs"""
        pairs = Counter()
        for word, freq in word_freqs.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i+1])] += freq
        return pairs
    def _merge_pair(self, pair, word_freqs):
        """Apply one merge everywhere"""
        new_word_freqs = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)
        for word, freq in word_freqs.items():
            new_word = word.replace(bigram, replacement)
            new_word_freqs[new_word] = freq
        return new_word_freqs
    def train(self, texts: list, verbose: bool = True):
        """
        Train the BPE tokenizer
        Args:
            texts: list of training texts
            verbose: print progress
        """
        # Step 1: pre-tokenize and count word frequencies
        word_freqs = Counter()
        for text in texts:
            # simple pre-tokenization: single CJK chars, runs of letters/digits, punctuation
            words = re.findall(r'[\u4e00-\u9fff]|[a-zA-Z]+|[0-9]+|[^\s\w]', text)
            for word in words:
                # split each word into space-separated characters
                char_word = ' '.join(list(word))
                word_freqs[char_word] += 1
        # Step 2: initialize the vocab with every character seen
        self.vocab = dict(self.special_tokens)
        chars = set()
        for word in word_freqs:
            for char in word.split():
                chars.add(char)
        for char in sorted(chars):
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)
        initial_vocab_size = len(self.vocab)
        if verbose:
            print(f"Initial vocab size: {initial_vocab_size}")
        # Step 3: merge iteratively
        num_merges = self.target_vocab_size - initial_vocab_size
        for i in range(num_merges):
            # pair frequencies
            pairs = self._get_stats(word_freqs)
            if not pairs:
                break
            # most frequent pair
            best_pair = max(pairs, key=pairs.get)
            best_freq = pairs[best_pair]
            if best_freq < 2:  # stop when pairs become too rare
                break
            # merge it
            word_freqs = self._merge_pair(best_pair, word_freqs)
            # record the merge rule
            merged_token = ''.join(best_pair)
            self.merges[best_pair] = merged_token
            if merged_token not in self.vocab:
                self.vocab[merged_token] = len(self.vocab)
            if verbose and (i + 1) % 100 == 0:
                print(f"  merge {i+1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' → '{merged_token}' (freq={best_freq})")
        # build the reverse vocab
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        if verbose:
            print(f"Final vocab size: {len(self.vocab)}")
            print(f"Merge rules: {len(self.merges)}")
    def encode(self, text: str) -> list:
        """Encode text into a sequence of token IDs"""
        # same pre-tokenization as in training
        words = re.findall(r'[\u4e00-\u9fff]|[a-zA-Z]+|[0-9]+|[^\s\w]', text)
        # merge priority = training order (dicts preserve insertion order)
        ranks = {pair: r for r, pair in enumerate(self.merges)}
        ids = []
        for word in words:
            symbols = list(word)
            # repeatedly apply the earliest-learned applicable merge
            while len(symbols) > 1:
                candidates = [(ranks[(symbols[i], symbols[i+1])], i)
                              for i in range(len(symbols) - 1)
                              if (symbols[i], symbols[i+1]) in ranks]
                if not candidates:
                    break
                _, idx = min(candidates)
                pair = (symbols[idx], symbols[idx+1])
                symbols = symbols[:idx] + [self.merges[pair]] + symbols[idx+2:]
            # map symbols to IDs
            for sym in symbols:
                ids.append(self.vocab.get(sym, self.special_tokens["<unk>"]))
        return ids
    def decode(self, ids: list) -> str:
        """Decode token IDs back into text
        (note: whitespace is not preserved, since pre-tokenization drops it)"""
        tokens = [self.inverse_vocab.get(id, "<unk>") for id in ids]
        return ''.join(tokens)
# === Train the tokenizer ===
# Toy training corpus (kept in Chinese, since that is what the tokenizer
# is being trained on)
training_texts = [
    "人工智能正在改变世界",
    "深度学习是机器学习的一个分支",
    "大语言模型具有强大的文本生成能力",
    "自然语言处理是人工智能的重要方向",
    "Transformer架构改变了深度学习的格局",
    "注意力机制是现代神经网络的核心组件",
    "预训练语言模型通过大规模数据学习语言知识",
    "GPT系列模型展示了规模扩展的威力",
    "BERT模型在自然语言理解任务上取得了突破",
    "强化学习可以用于优化语言模型的输出",
] * 50  # repeat to boost pair frequencies
tokenizer = SimpleBPETokenizer(vocab_size=500)
tokenizer.train(training_texts)
# Round-trip test
test_text = "人工智能改变世界"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
print(f"\nEncoding test:")
print(f"  original: {test_text}")
print(f"  encoded:  {encoded}")
print(f"  decoded:  {decoded}")
print(f"  tokens:   {[tokenizer.inverse_vocab.get(id, '?') for id in encoded]}")
6. Pretraining Pipeline¶
6.1 Data Preparation¶
Python
import torch
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
    """
    Text dataset for CLM pretraining
    Encodes the corpus, then slices it into fixed-length sequences
    """
    def __init__(self, texts: list, tokenizer, seq_len: int = 128):
        self.seq_len = seq_len
        # Encode all texts into one long token stream
        all_ids = []
        for text in texts:
            ids = tokenizer.encode(text)
            all_ids.extend(ids)
            all_ids.append(tokenizer.special_tokens["<eos>"])
        # Slice into fixed-length sequences; targets are inputs shifted by one
        self.data = []
        for i in range(0, len(all_ids) - seq_len - 1, seq_len):
            input_ids = all_ids[i : i + seq_len]
            target_ids = all_ids[i + 1 : i + seq_len + 1]
            self.data.append((
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(target_ids, dtype=torch.long)
            ))
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
print("Dataset ready")
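The slicing logic above can be sanity-checked in isolation: every target sequence is the input sequence shifted left by one token (using consecutive integers as a stand-in token stream so the shift is visible):

```python
all_ids = list(range(100))   # stand-in for an encoded token stream
seq_len = 16
chunks = [(all_ids[i:i + seq_len], all_ids[i + 1:i + seq_len + 1])
          for i in range(0, len(all_ids) - seq_len - 1, seq_len)]
x, y = chunks[0]
assert y == [t + 1 for t in x]   # next-token targets
assert len(chunks) == 6          # non-overlapping windows of stride seq_len
```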
6.2 Training Loop¶
Python
import time
class MiniLLaMATrainer:
    """Mini-LLaMA trainer"""
    def __init__(self, model, config, tokenizer, device='cpu'):
        self.model = model.to(device)
        self.config = config
        self.tokenizer = tokenizer
        self.device = device
        # Optimizer: AdamW
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=3e-4,
            betas=(0.9, 0.95),
            weight_decay=0.1
        )
        # LR schedule: cosine with warmup
        self.warmup_steps = 100
        self.total_steps = 0
    def _get_lr(self, step):
        """Cosine learning-rate schedule with linear warmup"""
        if step < self.warmup_steps:
            return 3e-4 * step / self.warmup_steps
        # Cosine decay
        progress = (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
        return 3e-4 * 0.5 * (1 + math.cos(math.pi * progress))
    def train(self, train_dataset, num_epochs=10, batch_size=8,
              log_interval=50, eval_interval=200):
        """Training loop"""
        dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
        )
        self.total_steps = num_epochs * len(dataloader)
        self.model.train()
        global_step = 0
        best_loss = float('inf')
        print(f"Starting Mini-LLaMA training")
        print(f"  Epochs: {num_epochs}")
        print(f"  Batch size: {batch_size}")
        print(f"  Total steps: {self.total_steps}")
        print(f"  Device: {self.device}")
        print("=" * 60)
        for epoch in range(num_epochs):
            epoch_loss = 0
            epoch_start = time.time()
            for batch_idx, (input_ids, target_ids) in enumerate(dataloader):
                input_ids = input_ids.to(self.device)
                target_ids = target_ids.to(self.device)
                # Update the learning rate
                lr = self._get_lr(global_step)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
                # Forward pass
                loss = self.model(input_ids, target_ids)
                # Backward pass
                self.optimizer.zero_grad()
                loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                epoch_loss += loss.item()
                global_step += 1
                # Logging
                if global_step % log_interval == 0:
                    avg_loss = epoch_loss / (batch_idx + 1)
                    ppl = math.exp(min(avg_loss, 20))  # cap to avoid overflow
                    elapsed = time.time() - epoch_start
                    print(f"  Step {global_step}: loss={avg_loss:.4f}, ppl={ppl:.2f}, lr={lr:.6f}, time={elapsed:.1f}s")
                # Generate a sample
                if global_step % eval_interval == 0:
                    self._generate_sample()
            # End-of-epoch stats
            avg_epoch_loss = epoch_loss / len(dataloader)
            epoch_time = time.time() - epoch_start
            print(f"\nEpoch {epoch+1}/{num_epochs}: avg_loss={avg_epoch_loss:.4f}, time={epoch_time:.1f}s")
            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                # Save the best model
                torch.save(self.model.state_dict(), 'mini_llama_best.pt')
                print(f"  ✓ Saved best model (loss={best_loss:.4f})")
            print()
        print(f"Training finished! Best loss: {best_loss:.4f}")
    def _generate_sample(self):
        """Generate a sample mid-training to monitor quality"""
        self.model.eval()
        # Start generation from the <bos> token
        prompt = torch.tensor([[self.tokenizer.special_tokens["<bos>"]]]).to(self.device)
        generated = self.model.generate(
            prompt, max_new_tokens=50, temperature=0.8, top_k=40
        )
        text = self.tokenizer.decode(generated[0].tolist())
        print(f"  [sample] {text[:100]}...")
        self.model.train()
# === End-to-end training example ===
print("End-to-end pretraining flow:")
print("1. Prepare corpus → 2. Train tokenizer → 3. Build dataset → 4. Train model → 5. Generate text")
print()
# Training corpus (use a large-scale corpus in practice; kept in Chinese here
# since we are training a Chinese tokenizer)
corpus = [
    "人工智能是计算机科学的一个重要分支,它研究如何让计算机模拟人类的智能行为。",
    "深度学习通过多层神经网络自动学习数据的特征表示,是目前最成功的机器学习方法。",
    "自然语言处理让计算机能够理解和生成人类语言,包括文本分类、机器翻译等任务。",
    "大语言模型通过在海量文本上预训练,学习到了丰富的语言知识和世界知识。",
    "Transformer架构使用自注意力机制,能够并行处理序列中的所有位置。",
    "注意力机制允许模型在处理每个位置时关注序列中的其他相关位置。",
    "预训练加微调的范式极大地提高了模型在各种下游任务上的表现。",
    "强化学习通过试错的方式让智能体学习最优策略,在游戏和机器人控制中表现出色。",
] * 100  # real training needs far more data
# 1. Train the tokenizer
tok = SimpleBPETokenizer(vocab_size=500)
tok.train(corpus, verbose=False)
# 2. Update the model configuration
train_config = MiniLLaMAConfig(
    vocab_size=len(tok.vocab),
    dim=128,
    n_layers=4,
    n_heads=4,
    n_kv_heads=2,
    max_seq_len=128,
    multiple_of=32
)
# 3. Build the model
train_model = MiniLLaMA(train_config)
params = sum(p.numel() for p in train_model.parameters())
print(f"Model parameters: {params:,} ({params/1e6:.2f}M)")
# 4. Build the dataset
train_dataset = TextDataset(corpus, tok, seq_len=64)
print(f"Training samples: {len(train_dataset)}")
# 5. Train (a real run needs more data and more time)
# trainer = MiniLLaMATrainer(train_model, train_config, tok, device='cpu')
# trainer.train(train_dataset, num_epochs=5, batch_size=8)
print("\nNote: a full training run requires:")
print("  - A corpus of several GB to tens of GB of Chinese text")
print("  - GPU acceleration (at least one RTX 3090 recommended)")
print("  - Hours to days of training time")
print("  - A larger model configuration (dim=512+, layers=12+)")
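The warmup-plus-cosine schedule used by the trainer can be checked in isolation. This standalone copy mirrors _get_lr, assuming total_steps=1000:

```python
import math

def cosine_lr(step, max_lr=3e-4, warmup=100, total=1000):
    # mirrors MiniLLaMATrainer._get_lr
    if step < warmup:
        return max_lr * step / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

assert cosine_lr(0) == 0.0                 # warmup starts at zero
assert cosine_lr(50) == 1.5e-4             # halfway through warmup
assert abs(cosine_lr(100) - 3e-4) < 1e-12  # peak at the end of warmup
assert abs(cosine_lr(1000)) < 1e-12        # decays to ~0 at the end
```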
6.3 Gradient Accumulation and Checkpoint Resumption¶
In real training runs, GPU memory often cannot hold a large batch, so gradient accumulation is standard practice; checkpoint resumption is the insurance policy for long-running jobs.
Python
import os
class RobustTrainer:
    """
    Trainer with gradient accumulation + checkpoint resumption
    Gradient accumulation:
    ┌─────────────────────────────────────────────────────────┐
    │ effective batch_size = micro_batch × accumulation_steps │
    │                                                         │
    │ e.g. you want batch=64 but the GPU only fits batch=8    │
    │ → accumulation_steps = 64/8 = 8                         │
    │ → parameters update once every 8 micro-batches          │
    └─────────────────────────────────────────────────────────┘
    """
    def __init__(self, model, config, tokenizer, device='cpu',
                 accumulation_steps=4, checkpoint_dir='checkpoints'):
        self.model = model.to(device)
        self.config = config
        self.tokenizer = tokenizer
        self.device = device
        self.accumulation_steps = accumulation_steps
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
        )
        self.global_step = 0
        self.best_loss = float('inf')
    def train_epoch(self, dataloader, epoch):
        self.model.train()
        epoch_loss = 0
        self.optimizer.zero_grad()  # clear gradients at the start of the epoch
        for step, (input_ids, target_ids) in enumerate(dataloader):
            input_ids = input_ids.to(self.device)
            target_ids = target_ids.to(self.device)
            # Forward pass (loss divided by accumulation_steps so the
            # accumulated gradient matches one large batch)
            loss = self.model(input_ids, target_ids) / self.accumulation_steps
            loss.backward()  # gradients accumulate; no zeroing here
            epoch_loss += loss.item() * self.accumulation_steps
            # Update parameters once every accumulation_steps micro-batches
            if (step + 1) % self.accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.global_step += 1
        return epoch_loss / len(dataloader)
    def save_checkpoint(self, epoch, loss):
        """Save the full training state for exact resumption"""
        checkpoint = {
            'epoch': epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'best_loss': self.best_loss,
            'config': vars(self.config),
        }
        path = os.path.join(self.checkpoint_dir, f'ckpt_epoch_{epoch}.pt')
        torch.save(checkpoint, path)
        # Also track the best model
        if loss < self.best_loss:
            self.best_loss = loss
            best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)
            print(f"  ✓ New best model (loss={loss:.4f})")
    def load_checkpoint(self, path):
        """Restore training state from a checkpoint"""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint['best_loss']
        print(f"  ✓ Resumed from epoch {checkpoint['epoch']}, "
              f"step={checkpoint['global_step']}, loss={checkpoint['loss']:.4f}")
        return checkpoint['epoch']
# === Usage ===
print("Gradient accumulation vs. one large batch:")
print("┌─────────────────────────────┬──────────────┬──────────────┬────────────┐")
print("│ Scheme                      │ micro_batch  │ accum_steps  │ eff. batch │")
print("├─────────────────────────────┼──────────────┼──────────────┼────────────┤")
print("│ Direct training             │ 64           │ 1            │ 64         │")
print("│ Accumulation (low VRAM)     │ 8            │ 8            │ 64         │")
print("│ Accumulation (extreme case) │ 1            │ 64           │ 64         │")
print("└─────────────────────────────┴──────────────┴──────────────┴────────────┘")
print()
print("Resuming from a checkpoint:")
print("  trainer = RobustTrainer(model, config, tok, accumulation_steps=8)")
print("  # to resume:")
print("  # start_epoch = trainer.load_checkpoint('checkpoints/best_model.pt')")
print("  # for epoch in range(start_epoch + 1, total_epochs):")
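The equivalence claimed in the diagram (micro-batches with scaled losses reproduce the large-batch gradient) can be verified directly on a toy model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = nn.MSELoss()

# one full batch of 8
model.zero_grad()
loss_fn(model(data), target).backward()
full_grad = model.weight.grad.clone()

# two micro-batches of 4, each loss divided by accumulation_steps=2
model.zero_grad()
for xb, yb in zip(data.split(4), target.split(4)):
    (loss_fn(model(xb), yb) / 2).backward()

assert torch.allclose(model.weight.grad, full_grad, atol=1e-6)
```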
7. Model Evaluation and Generation¶
7.1 Perplexity Evaluation¶
Python
@torch.no_grad()
def evaluate_perplexity(model, dataset, batch_size=8, device='cpu'):
    """
    Compute the model's perplexity (PPL) over a dataset
    PPL = exp(avg_loss); lower is better
    """
    model.eval()
    model = model.to(device)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    total_loss = 0
    total_tokens = 0
    for input_ids, target_ids in dataloader:
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        loss = model(input_ids, target_ids)
        num_tokens = (target_ids != -100).sum().item()  # consistent with ignore_index=-100
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return perplexity, avg_loss
# Once the model is trained:
# ppl, loss = evaluate_perplexity(train_model, train_dataset)
# print(f"Perplexity: {ppl:.2f}, average loss: {loss:.4f}")
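A useful sanity check for the perplexity code: a model that predicts the uniform distribution has PPL exactly equal to the vocabulary size, since its loss is ln(V):

```python
import math
import torch
import torch.nn.functional as F

V = 500
logits = torch.zeros(10, V)            # uniform distribution over V tokens
targets = torch.randint(0, V, (10,))
loss = F.cross_entropy(logits, targets).item()  # = ln(V)
assert abs(math.exp(loss) - V) < 1e-2
```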
7.2 Comparing Sampling Strategies¶
Python
def generate_text(model, tokenizer, prompt, max_tokens=100,
                  strategy='top_p', **kwargs):
    """Text generation with several sampling strategies"""
    device = next(model.parameters()).device
    # Encode the prompt
    ids = tokenizer.encode(prompt)
    tokens = torch.tensor([ids]).to(device)
    if strategy == 'greedy':
        # Greedy search: pick the most likely token each step
        generated = model.generate(tokens, max_tokens, temperature=0.01, top_k=1)
    elif strategy == 'temperature':
        # Temperature sampling
        temp = kwargs.get('temperature', 1.0)
        generated = model.generate(tokens, max_tokens, temperature=temp, top_k=0, top_p=1.0)
    elif strategy == 'top_k':
        # Top-K sampling
        k = kwargs.get('k', 50)
        generated = model.generate(tokens, max_tokens, temperature=0.8, top_k=k)
    elif strategy == 'top_p':
        # Nucleus (Top-P) sampling
        p = kwargs.get('p', 0.9)
        generated = model.generate(tokens, max_tokens, temperature=0.8, top_p=p)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
    return tokenizer.decode(generated[0].tolist())
print("""
Sampling strategies:
┌────────────┬────────────────────────────────┬─────────────────────┐
│ Strategy   │ Behavior                       │ Typical use         │
├────────────┼────────────────────────────────┼─────────────────────┤
│ Greedy     │ Deterministic, most cautious   │ Translation/summary │
│ Temperature│ Higher temp → more random      │ General purpose     │
│ Top-K      │ Sample from the top K only     │ Creative writing    │
│ Top-P      │ Dynamic K, cumulative prob ≤ P │ Dialogue/creative   │
└────────────┴────────────────────────────────┴─────────────────────┘
Suggested settings:
- Creative writing: temperature=1.0, top_p=0.95
- Dialogue: temperature=0.7, top_p=0.9
- Code/factual: temperature=0.2, top_k=10
""")
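The top-p branch is worth seeing on a concrete distribution. This is a 1-D sketch of the same filter used in generate() (the batched version there needs scatter instead of direct indexing):

```python
import torch
import torch.nn.functional as F

def top_p_filter(logits, p=0.9):
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum > p
    remove[1:] = remove[:-1].clone()  # shift right: keep the first token past p
    remove[0] = False
    out = logits.clone()
    out[sorted_idx[remove]] = float('-inf')
    return out

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])
filtered = top_p_filter(probs.log(), p=0.9)
# cumulative sums are 0.5, 0.8, 0.95: the 0.15 entry is kept, 0.05 is cut
assert torch.isinf(filtered[3]) and not torch.isinf(filtered[:3]).any()
```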
8. 工程优化技巧¶
8.1 KV Cache¶
Python
class MiniLLaMAWithKVCache(nn.Module):
"""
带KV Cache的推理优化
问题: 自回归生成时,每生成一个token都要重新计算所有位置的KV
优化: 缓存已计算的KV,新token只需计算增量
无Cache: 生成100个token → 计算 1+2+...+100 = 5050 次KV
有Cache: 生成100个token → 计算 100 次KV
"""
def __init__(self, config):
super().__init__()
# ... 复用MiniLLaMA的初始化 ...
self.config = config
# KV Cache的核心思想(伪代码)
def generate_with_cache(self, prompt_tokens, max_new_tokens=100):
"""
KV Cache加速生成
"""
# 初始化cache: 每层存储K和V
kv_cache = [
{"k": None, "v": None}
for _ in range(self.config.n_layers)
]
# Prefill阶段:处理整个prompt
# 此时计算所有位置的KV并缓存
logits = self.forward_with_cache(prompt_tokens, kv_cache, is_prefill=True)
tokens = prompt_tokens
for step in range(max_new_tokens):
# Decode阶段:每次只处理新增的1个token
next_token = logits[:, -1:, :].argmax(dim=-1)
tokens = torch.cat([tokens, next_token], dim=-1)
# 只传入新token,利用cache中已有的KV
logits = self.forward_with_cache(next_token, kv_cache, is_prefill=False)
return tokens
print("""
KV Cache性能提升:
┌───────────────┬──────────────────┬──────────────────┐
│ 生成长度 │ 无Cache计算量 │ 有Cache计算量 │
├───────────────┼──────────────────┼──────────────────┤
│ 100 tokens │ 5050 × KV_cost │ 100 × KV_cost │
│ 1000 tokens │ 500500 × KV_cost │ 1000 × KV_cost │
│ 加速比 │ ~50x (100) │ ~500x (1000) │
└───────────────┴──────────────────┴──────────────────┘
显存占用:
KV Cache大小 = 2 × n_layers × n_kv_heads × head_dim × seq_len × batch_size × 每元素字节数
对于LLaMA-2-7B (seq_len=4096, batch=1):
= 2 × 32 × 32 × 128 × 4096 × 2 bytes ≈ 2GB
""")
8.2 混合精度训练¶
Python
# 使用PyTorch AMP (Automatic Mixed Precision)
from torch.amp import autocast, GradScaler
def train_step_mixed_precision(model, optimizer, input_ids, target_ids, scaler):
"""
混合精度训练步骤
FP32运算 → 改为 FP16 → 速度2x, 显存减半
但某些操作(如loss计算、归一化)需要保持FP32
"""
optimizer.zero_grad()
# 自动混合精度: 前向传播中适当使用FP16
with autocast(device_type='cuda'):
loss = model(input_ids, target_ids)
# 缩放loss以防止FP16下梯度下溢
scaler.scale(loss).backward()
# 梯度裁剪(需要先unscale)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# 优化器步骤
scaler.step(optimizer)
scaler.update()
return loss.item()
print("""
混合精度训练好处:
1. 训练速度提升 ~2x(利用Tensor Cores)
2. 显存占用减少 ~50%
3. 模型精度基本不受影响
关键技术:
- GradScaler: 防止FP16梯度下溢
- autocast: 自动选择合适的精度
- 保持关键操作在FP32: LayerNorm, Loss, Softmax
""")
8.3 Flash Attention¶
Flash Attention 是现代 LLM 训练和推理的标配加速技术。它通过数学上的分块计算(tiling)避免了实例化完整的 \(N \times N\) 注意力矩阵,从而同时节省显存和加速计算。
Python
# === Flash Attention 使用方式 ===
# 方式1: PyTorch 2.0+ 内置 SDPA(推荐)
# PyTorch 的 scaled_dot_product_attention 会自动选择最优后端
import torch.nn.functional as F
def attention_with_sdpa(q, k, v, dropout_p=0.0, is_causal=True):
"""
使用 PyTorch SDPA 的注意力计算
SDPA 会自动选择后端:
1. FlashAttention-2 (GPU, 头维度 ≤ 256)
2. Memory-Efficient Attention (xformers 风格)
3. Math 实现 (fallback)
Args:
q: [batch, n_heads, seq_len, head_dim]
k: [batch, n_heads, seq_len, head_dim]
v: [batch, n_heads, seq_len, head_dim]
"""
# 一行代码替代手动实现!
output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None, # is_causal=True 时自动生成因果mask
dropout_p=dropout_p,
is_causal=is_causal
)
return output
# 方式2: 在我们的 MiniLLaMA 中集成
# 注意:以下辅助函数需要在使用 FlashAttentionLayer 之前定义
def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
"""
应用旋转位置编码(cos/sin 分离形式)
xq, xk: [batch, seq_len, n_heads, head_dim]
freqs_cos, freqs_sin: [seq_len, head_dim]
"""
# 将 cos/sin reshape 为 [1, seq_len, 1, head_dim] 以便广播
freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2) # [1, L, 1, d]
freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2)
# 将向量分成两半,应用旋转
xq_r, xq_i = xq.float().chunk(2, dim=-1) # 实部和虚部
xk_r, xk_i = xk.float().chunk(2, dim=-1)
# 旋转: [cos·x_r - sin·x_i, sin·x_r + cos·x_i]
xq_out = torch.cat([xq_r * freqs_cos - xq_i * freqs_sin,
xq_r * freqs_sin + xq_i * freqs_cos], dim=-1)
xk_out = torch.cat([xk_r * freqs_cos - xk_i * freqs_sin,
xk_r * freqs_sin + xk_i * freqs_cos], dim=-1)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x, n_rep):
"""将 KV 头重复以匹配 Q 头数量(GQA 支持)"""
if n_rep == 1:
return x
bsz, seqlen, n_kv_heads, head_dim = x.shape
# [B, L, n_kv, d] → [B, L, n_kv, n_rep, d] → [B, L, n_kv*n_rep, d]
return (
x[:, :, :, None, :]
.expand(bsz, seqlen, n_kv_heads, n_rep, head_dim)
.reshape(bsz, seqlen, n_kv_heads * n_rep, head_dim)
)
class FlashAttentionLayer(nn.Module):
"""支持 Flash Attention 的注意力层"""
def __init__(self, dim, n_heads, n_kv_heads):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = dim // n_heads
self.n_rep = n_heads // n_kv_heads
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
# 检测 Flash Attention 是否可用
self.flash = hasattr(F, 'scaled_dot_product_attention')
def forward(self, x, freqs_cos, freqs_sin):
bsz, seqlen, _ = x.shape
xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
xk = self.wk(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
xv = self.wv(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim)
# 应用 RoPE(使用上面定义的辅助函数)
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
# GQA: 扩展 KV
xk = repeat_kv(xk, self.n_rep)
xv = repeat_kv(xv, self.n_rep)
# 转置为 [batch, heads, seq, dim]
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
if self.flash:
# Flash Attention: 无需显式 O(N²) 显存!
output = F.scaled_dot_product_attention(
xq, xk, xv, attn_mask=None,
dropout_p=0.0, is_causal=True
)
else:
# Fallback: 手动实现
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
mask = torch.triu(torch.full((seqlen, seqlen), float('-inf'), device=x.device), diagonal=1)
scores = scores + mask  # 因果mask: 上三角置为 -inf(注意 mask 要和 scores 在同一设备)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, xv)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
print("""
Flash Attention 核心优势:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
标准注意力 Flash Attention
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
显存复杂度 O(N²) O(N) ← 关键!
计算量 O(N²d) O(N²d) ← 相同
是否实例化 S 矩阵 是 否 ← 分块计算
支持序列长度提升 ~2K ~32K+ ← 线性显存
加速比 (vs 标准) 1× 2-4× ← IO优化
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
原理简述:
Flash Attention 将 Q, K, V 分成小块 (tiles),在 GPU SRAM 中
完成 "分块注意力",避免将 N×N 的注意力矩阵写入 HBM。
这是一种 "IO-aware" 的精确注意力(非近似),数学结果完全一致。
使用建议:
- PyTorch ≥ 2.0: 直接用 F.scaled_dot_product_attention
- 训练时: 配合 torch.compile() 进一步加速
- 推理时: 与 KV Cache 配合使用
- 注意: head_dim ≤ 256 时才能使用 Flash Attention
""")
练习题¶
练习 1 :组件验证¶
- 修改 GQA 的 `n_kv_heads`:设为 1(MQA)和设为 `n_heads`(MHA),对比参数量差异
- 将 RMSNorm 替换为 nn.LayerNorm,观察训练稳定性差异
💡 参考答案
**1. GQA → MQA / MHA 参数量对比**

以 `n_heads=32, head_dim=128, n_layers=32` 为例:

| 配置 | n_kv_heads | KV 参数/层 | 总 KV 参数 |
|------|-----------|-----------|-----------|
| MHA | 32 | $2 \times 32 \times 128 \times d_{model}$ | 最多 |
| GQA (8组) | 8 | $2 \times 8 \times 128 \times d_{model}$ | 减少 75% |
| MQA | 1 | $2 \times 1 \times 128 \times d_{model}$ | 减少 ~97% |

关键代码修改:

Python
# MQA: n_kv_heads = 1
model_mqa = MiniLLaMA(
    dim=4096, n_layers=32, n_heads=32, n_kv_heads=1,   # MQA
    vocab_size=32000, max_seq_len=2048, intermediate_dim=11008
)
# GQA: n_kv_heads = 8(8 组 KV 头)
model_gqa = MiniLLaMA(
    dim=4096, n_layers=32, n_heads=32, n_kv_heads=8,   # GQA
    vocab_size=32000, max_seq_len=2048, intermediate_dim=11008
)
# MHA: n_kv_heads = n_heads
model_mha = MiniLLaMA(
    dim=4096, n_layers=32, n_heads=32, n_kv_heads=32,  # MHA
    vocab_size=32000, max_seq_len=2048, intermediate_dim=11008
)
# 参数量对比
def count_params(model):
    return sum(p.numel() for p in model.parameters())
print(f"MQA: {count_params(model_mqa)/1e6:.1f}M")
print(f"GQA: {count_params(model_gqa)/1e6:.1f}M")
print(f"MHA: {count_params(model_mha)/1e6:.1f}M")

**2. RMSNorm vs LayerNorm 训练稳定性**

Python
# 替换为 LayerNorm
class TransformerBlockWithLN(nn.Module):
def __init__(self, ...):
# 替换 RMSNorm
self.attention_norm = nn.LayerNorm(dim) # 原来是 RMSNorm(dim)
self.ffn_norm = nn.LayerNorm(dim) # 原来是 RMSNorm(dim)
# 观察要点:
# - LayerNorm 需要计算均值和方差,RMSNorm 只计算均方根
# - 在 bf16 精度下,RMSNorm 通常更稳定(减少了减均值的数值误差)
# - 实验中观察训练 loss 曲线:RMSNorm 通常更平滑
# - 速度上 RMSNorm 略快(约 10-30%),因为计算量更少
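两种归一化的数值差异也可以脱离 PyTorch 用纯 Python 验证(示意实现,省略可学习缩放参数 γ/β):

```python
import math

def rms_norm(x, eps=1e-6):
    """RMSNorm:仅按均方根缩放,不减均值。"""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def layer_norm(x, eps=1e-6):
    """LayerNorm:先减均值,再按标准差缩放。"""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

# 对均值为 0 的输入两者结果一致;均值不为 0 时才体现差异,
# RMSNorm 省掉的正是"减均值"这一步及其数值误差
```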
练习 2 : Tokenizer 扩展¶
- 为 Tokenizer 添加 `<pad>`、`<unk>` 等特殊 token 的正确处理
- 实现 WordPiece 分词器并与 BPE 对比
💡 参考答案
**1. 添加特殊 token**

Python
class BPETokenizerWithSpecial:
    def __init__(self, vocab_path):
        self.vocab = self._load_vocab(vocab_path)  # 先加载基础BPE词表(实现略)
        self.special_tokens = {
            "<pad>": len(self.vocab),      # padding
            "<unk>": len(self.vocab) + 1,  # unknown
            "<bos>": len(self.vocab) + 2,  # beginning of sequence
            "<eos>": len(self.vocab) + 3,  # end of sequence
        }
        self.vocab.update(self.special_tokens)
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
    def encode(self, text, add_special=True):
        tokens = self._bpe_encode(text)
        ids = [self.vocab.get(t, self.special_tokens["<unk>"]) for t in tokens]
        if add_special:
            ids = [self.special_tokens["<bos>"]] + ids + [self.special_tokens["<eos>"]]
        return ids
    def pad_batch(self, batch_ids, pad_id=None):
        """将不同长度的序列 padding 到同一长度"""
        if pad_id is None:
            pad_id = self.special_tokens["<pad>"]
        max_len = max(len(ids) for ids in batch_ids)
        padded = [ids + [pad_id] * (max_len - len(ids)) for ids in batch_ids]
        mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch_ids]
        return padded, mask

**2. WordPiece 分词器核心实现**

Python
class WordPieceTokenizer:
"""WordPiece 分词:与 BPE 的区别在于合并策略"""
def __init__(self, vocab, max_word_len=200):
self.vocab = vocab
self.max_word_len = max_word_len
def tokenize_word(self, word):
"""贪心最长匹配:从左到右匹配最长子词"""
tokens = []
start = 0
while start < len(word):
end = len(word)
found = False
while end > start:
substr = word[start:end]
if start > 0:
substr = "##" + substr # 非词首加 ## 前缀
if substr in self.vocab:
tokens.append(substr)
found = True
break
end -= 1
if not found:
tokens.append("[UNK]")
start += 1
else:
start = end
return tokens
# BPE vs WordPiece 对比:
# - BPE: 合并频率最高的相邻 token 对(自底向上)
# - WordPiece: 选择使似然增长最大的 token 对(语言模型概率)
# - 实际效果:WordPiece 在处理 OOV 词时更保守,更倾向于保留常见子词
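作为对照,BPE 训练的核心循环(统计最高频相邻对并合并)可以用几行纯 Python 勾勒(示意实现):

```python
from collections import Counter

def most_frequent_pair(corpus):
    """BPE 训练一步:统计所有相邻 token 对的频率,返回最高频的一对。
    corpus: token 序列列表,如 [['l','o','w'], ['l','o','w','e','r']]"""
    pairs = Counter()
    for word in corpus:
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += 1
    return max(pairs, key=pairs.get)

def merge_pair(corpus, pair):
    """将语料中所有相邻的 pair 合并为单个新 token。"""
    merged = []
    for word in corpus:
        out, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
                out.append(word[i] + word[i + 1])
                i += 2
            else:
                out.append(word[i])
                i += 1
        merged.append(out)
    return merged
```

完整 BPE 训练就是反复执行"找最高频对 → 合并"直到词表达到目标大小;WordPiece 则把"最高频"换成"似然增益最大"。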
练习 3 :预训练实验¶
- 在 WikiText-2 数据集上预训练 Mini-LLaMA
- 对比不同超参数( dim, n_layers, lr )对困惑度的影响
- 实现学习率 Warmup + Cosine Decay 调度器
💡 参考答案
**1. WikiText-2 预训练示例**

Python
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
    def __init__(self, text, tokenizer, seq_len=256):
        tokens = tokenizer.encode(text)
        self.data = tokens
        self.seq_len = seq_len
    def __len__(self):
        return max(0, len(self.data) - self.seq_len - 1)
    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx:idx+self.seq_len], dtype=torch.long)
        y = torch.tensor(self.data[idx+1:idx+self.seq_len+1], dtype=torch.long)
        return x, y
# 加载 WikiText-2
with open("wikitext-2/train.txt", "r") as f:
    train_text = f.read()
with open("wikitext-2/valid.txt", "r") as f:
    val_text = f.read()
train_dataset = TextDataset(train_text, tokenizer, seq_len=256)
val_dataset = TextDataset(val_text, tokenizer, seq_len=256)

**2. 超参数对比实验**

Python
# 实验配置
configs = [
    {"dim": 256, "n_layers": 4, "n_heads": 4, "lr": 3e-4},    # 小模型
    {"dim": 512, "n_layers": 6, "n_heads": 8, "lr": 1e-4},    # 中模型
    {"dim": 512, "n_layers": 6, "n_heads": 8, "lr": 3e-4},    # 中模型+高lr
    {"dim": 768, "n_layers": 8, "n_heads": 12, "lr": 1e-4},   # 大模型
]
# 预期结果(WikiText-2, ~10K steps):
# | dim | n_layers | lr   | Val PPL |
# |-----|----------|------|---------|
# | 256 | 4        | 3e-4 | ~35-45  |
# | 512 | 6        | 1e-4 | ~20-28  |
# | 512 | 6        | 3e-4 | ~22-30  | (可能不稳定)
# | 768 | 8        | 1e-4 | ~15-22  |

**3. Warmup + Cosine Decay 调度器**

Python
import math
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr=1e-5):
"""Warmup + Cosine Decay 学习率调度"""
if step < warmup_steps:
# 线性 warmup
return max_lr * step / warmup_steps
# Cosine decay
progress = (step - warmup_steps) / (total_steps - warmup_steps)
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
# 使用示例
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
for step in range(total_steps):
lr = get_cosine_lr(step, warmup_steps=500, total_steps=10000, max_lr=3e-4)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# ... training step ...
练习 4 :生成优化¶
- 实现 Beam Search 解码
- 实现 Repetition Penalty (减少重复生成)
- 实现 KV Cache 并测量推理速度提升
💡 参考答案
**1. Beam Search 解码**

Python
def beam_search(model, tokenizer, prompt, beam_width=5, max_len=100, length_penalty=1.0):
    """Beam Search 解码"""
    input_ids = tokenizer.encode(prompt)
    # beams: [(累计log概率, token_ids)]
    beams = [(0.0, input_ids)]
    for _ in range(max_len):
        candidates = []
        for log_prob, seq in beams:
            if seq[-1] == tokenizer.eos_id:
                candidates.append((log_prob, seq))
                continue
            with torch.no_grad():
                logits = model(torch.tensor([seq]))[0, -1]  # 最后一个token的logits
            log_probs = torch.log_softmax(logits, dim=-1)
            topk_log_probs, topk_ids = log_probs.topk(beam_width)
            for i in range(beam_width):
                # 始终保存原始累计log概率,排序时才做长度归一化,
                # 避免长度惩罚被逐步重复累积
                new_log_prob = log_prob + topk_log_probs[i].item()
                new_seq = seq + [topk_ids[i].item()]
                candidates.append((new_log_prob, new_seq))
        # 按长度归一化分数保留top-k(长度惩罚:避免beam偏向短序列)
        candidates.sort(key=lambda x: x[0] / (len(x[1]) ** length_penalty), reverse=True)
        beams = candidates[:beam_width]
        # 所有beam都结束
        if all(seq[-1] == tokenizer.eos_id for _, seq in beams):
            break
    return max(beams, key=lambda x: x[0] / (len(x[1]) ** length_penalty))[1]

**2. Repetition Penalty**

Python
def generate_with_repetition_penalty(model, tokenizer, prompt, max_len=100,
                                     penalty=1.2, temperature=0.8):
    """带重复惩罚的生成"""
    input_ids = tokenizer.encode(prompt)
    generated = input_ids.copy()
    for _ in range(max_len):
        with torch.no_grad():
            logits = model(torch.tensor([generated]))[0, -1]
        # 对已生成的 token 施加惩罚
        for token_id in set(generated):
            if logits[token_id] > 0:
                logits[token_id] /= penalty   # 正logits除以penalty
            else:
                logits[token_id] *= penalty   # 负logits乘以penalty
        # Temperature sampling
        logits = logits / temperature
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        if next_token == tokenizer.eos_id:
            break
        generated.append(next_token)
    return generated

**3. KV Cache 推理加速测量**

Python
import time
def measure_kv_cache_speedup(model, tokenizer, prompt, max_new_tokens=100):
"""测量 KV Cache 带来的推理加速"""
input_ids = tokenizer.encode(prompt)
input_tensor = torch.tensor([input_ids])
# 1. 无 KV Cache(每步重新计算所有token)
model.use_kv_cache = False
start = time.time()
generated = input_ids.copy()
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(torch.tensor([generated]))[0, -1]
next_token = logits.argmax().item()
generated.append(next_token)
time_no_cache = time.time() - start
# 2. 有 KV Cache(只计算新token)
model.use_kv_cache = True
model.clear_kv_cache()
start = time.time()
with torch.no_grad():
# Prefill: 处理所有输入token
logits = model(input_tensor)
past_kv = model.get_kv_cache()
generated = input_ids.copy()
next_token = logits[0, -1].argmax().item()
generated.append(next_token)
for _ in range(max_new_tokens - 1):
with torch.no_grad():
# Decode: 只处理1个新token + 缓存的KV
logits = model(torch.tensor([[next_token]]), past_kv=past_kv)
past_kv = model.get_kv_cache()
next_token = logits[0, -1].argmax().item()
generated.append(next_token)
time_with_cache = time.time() - start
speedup = time_no_cache / time_with_cache
print(f"无 KV Cache: {time_no_cache:.2f}s")
print(f"有 KV Cache: {time_with_cache:.2f}s")
print(f"加速比: {speedup:.1f}x")
return speedup
# 预期结果(prompt=50 tokens, max_new_tokens=100):
# 无 KV Cache: ~5-10s
# 有 KV Cache: ~0.5-2s
# 加速比: ~3-10x(序列越长加速越明显)
📝 本章小结¶
| 知识点 | 掌握程度检查 |
|---|---|
| LLaMA 架构特点 | 能否说出与原始 Transformer 的 5 个主要区别? |
| RMSNorm | 能否手写实现并解释为什么比 LayerNorm 好? |
| RoPE | 能否解释旋转位置编码的核心思想? |
| SwiGLU | 能否解释门控机制的作用? |
| GQA | 能否解释 MHA→GQA→MQA 的演进逻辑? |
| BPE Tokenizer | 能否手写 BPE 训练过程? |
| CLM 预训练 | 能否解释 Next Token Prediction 的训练流程? |
| KV Cache | 能否解释 KV Cache 加速推理的原理? |
🔗 后续学习路径¶
- Transformer 深入理解 → 01-Transformer 深入理解
- 手写 Transformer → 03-手写 Transformer 完整实现
- 高效微调 → 01-高效微调技术
- LoRA 从零实现 → 06-LoRA 从零实现
- 大模型预训练理论 → 03-大模型预训练
📚 参考资料¶
- Touvron et al. "LLaMA: Open and Efficient Foundation Language Models" (2023)
- Touvron et al. "Llama 2: Open Foundation and Fine-Tuned Chat Models" (2023)
- Su et al. "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021)
- Zhang & Sennrich "Root Mean Square Layer Normalization" (2019)
- Shazeer "GLU Variants Improve Transformer" (2020) — SwiGLU
- Ainslie et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023)
- Sennrich et al. "Neural Machine Translation of Rare Words with Subword Units" (2016) — BPE
- Kaplan et al. "Scaling Laws for Neural Language Models" (2020)
- Korthikanti et al. "Reducing Activation Recomputation in Large Transformer Models" (2022)
- Dao et al. "FlashAttention: Fast and Memory-Efficient Exact Attention" (2022)
最后更新日期: 2026-04-20 适用版本: LLM 学习教程 v2026