从零搭建小型LLM¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习目标:从零实现一个完整的小型大语言模型(Mini-LLM),包括LLaMA架构实现、Tokenizer训练、预训练流程,真正理解LLM的内部工作原理。
📌 定位说明:对标 happy-llm Ch5(动手搭建大模型),我们的实现覆盖完整的LLaMA2架构(RoPE/RMSNorm/SwiGLU/GQA)、BPE Tokenizer训练、以及端到端预训练。代码可运行、形状标注清晰、工程实践更贴近实际。
目录¶
- 1. 为什么要从零搭建LLM
- 2. LLaMA架构全貌
- 3. 核心组件实现
- 4. 完整LLaMA模型
- 5. BPE Tokenizer训练
- 6. 预训练Pipeline
- 7. 模型评估与生成
- 8. 工程优化技巧
1. 为什么要从零搭建LLM¶
Text Only
从零搭建LLM的学习价值:
━━━━━━━━━━━━━━━━━━━━━━━━━━━
1. 彻底理解每一行代码的含义
2. 明白"大模型"到底"大"在哪里
3. 掌握预训练的核心流程
4. 为后续的微调/部署/优化打下基础
5. 面试时能从容讲解底层实现
我们的实现目标:
├── 模型架构: LLaMA-2风格 (默认配置约5M参数)
├── Tokenizer: 训练一个BPE分词器
├── 预训练: 在小规模中文语料上CLM训练
└── 生成: 实现自回归文本生成 + 多种采样策略
2. LLaMA架构全貌¶
2.1 架构概览¶
Text Only
LLaMA-2 架构(我们的Mini版本)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
输入 Token IDs: [B, L]
↓
┌─────────────────────────┐
│ Token Embedding │ [B, L] → [B, L, D]
└─────────────────────────┘
↓
┌─────────────────────────┐ ×N layers
│ ┌───────────────────┐ │
│ │ RMSNorm │ │ Pre-Norm
│ └───────────────────┘ │
│ ↓ │
│ ┌───────────────────┐ │
│ │ GQA Attention │ │ + RoPE位置编码
│ │ (因果Mask) │ │
│ └───────────────────┘ │
│ ↓ + Residual │
│ ┌───────────────────┐ │
│ │ RMSNorm │ │ Pre-Norm
│ └───────────────────┘ │
│ ↓ │
│ ┌───────────────────┐ │
│ │ SwiGLU FFN │ │ 8/3 × D hidden
│ └───────────────────┘ │
│ ↓ + Residual │
└─────────────────────────┘
↓
┌─────────────────────────┐
│ RMSNorm │ Final Norm
└─────────────────────────┘
↓
┌─────────────────────────┐
│ Linear (→ vocab_size) │ LM Head (weight sharing)
└─────────────────────────┘
↓
Output Logits: [B, L, V]
2.2 与原始Transformer的区别¶
| 组件 | 原始Transformer | LLaMA-2 | 为什么改 |
|---|---|---|---|
| 归一化 | LayerNorm (Post-Norm) | RMSNorm (Pre-Norm) | 训练更稳定,计算更快 |
| 位置编码 | 正弦位置编码 | RoPE旋转位置编码 | 支持外推,相对位置信息 |
| 激活函数 | ReLU | SwiGLU | 效果更好(经验验证) |
| 注意力 | MHA | GQA | KV Cache更省显存 |
| FFN维度 | 4×D | 8/3×D(约2.67×D) | SwiGLU有门控,需要调整 |
| Bias | 有 | 无 | 减少参数,效果不降 |
3. 核心组件实现¶
3.1 RMSNorm¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization
比LayerNorm更快:不需要计算均值,只计算均方根
公式: RMSNorm(x) = x / RMS(x) * γ
其中: RMS(x) = sqrt(mean(x²) + ε)
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__() # super()调用父类方法
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) # 可学习的缩放参数 γ
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args: x [batch_size, seq_len, dim]
Returns: normalized x, same shape
"""
# 计算均方根
rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
# 归一化并缩放
return x / rms * self.weight
# 验证
norm = RMSNorm(64)
x = torch.randn(2, 10, 64)
out = norm(x)
print(f"RMSNorm: 输入 {x.shape} → 输出 {out.shape}")
# 验证输出的RMS约为1
rms = torch.sqrt(out.pow(2).mean(dim=-1))
print(f"归一化后RMS均值: {rms.mean().item():.4f}") # 应接近1.0
3.2 RoPE旋转位置编码¶
Python
class RotaryPositionEmbedding:
"""
RoPE (Rotary Position Embedding)
核心思想: 用旋转矩阵编码位置信息
- 对Q和K分别施加旋转,使得内积自然包含相对位置信息
    - 旋转后满足 ⟨f(q, m), f(k, n)⟩ = g(q, k, m-n),即注意力内积只依赖相对位置 m-n
"""
@staticmethod # @staticmethod无需实例即可调用
def precompute_freqs(dim: int, max_seq_len: int, theta: float = 10000.0):
"""
预计算频率参数
freqs[i] = 1 / theta^(2i/dim), i = 0, 1, ..., dim/2 - 1
"""
# [dim/2] 频率
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
# [max_seq_len] 位置
t = torch.arange(max_seq_len).float()
# [max_seq_len, dim/2] 外积
freqs = torch.outer(t, freqs)
# 复数形式: e^(i*θ) = cos(θ) + i*sin(θ)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis # [max_seq_len, dim/2] 复数张量
@staticmethod
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor,
freqs_cis: torch.Tensor):
"""
对Q和K施加旋转位置编码
Args:
xq: [B, L, n_heads, head_dim]
xk: [B, L, n_kv_heads, head_dim]
freqs_cis: [L, head_dim/2] 复数频率
"""
# 将实数张量转为复数: [B, L, H, D] → [B, L, H, D/2] (复数)
xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# 调整freqs_cis的形状以匹配
# freqs_cis: [L, D/2] → [1, L, 1, D/2]
freqs = freqs_cis.unsqueeze(0).unsqueeze(2) # unsqueeze增加一个维度
# 旋转: 复数乘法 = 旋转
xq_out = torch.view_as_real(xq_complex * freqs).flatten(-2)
xk_out = torch.view_as_real(xk_complex * freqs).flatten(-2)
return xq_out.type_as(xq), xk_out.type_as(xk)
# 验证RoPE
freqs = RotaryPositionEmbedding.precompute_freqs(dim=64, max_seq_len=512)
print(f"RoPE频率表: {freqs.shape}") # [512, 32]
q = torch.randn(2, 10, 4, 64) # [B, L, n_heads, head_dim]
k = torch.randn(2, 10, 4, 64)
q_rot, k_rot = RotaryPositionEmbedding.apply_rotary_emb(q, k, freqs[:10])
print(f"RoPE编码后: Q {q_rot.shape}, K {k_rot.shape}")
3.3 SwiGLU前馈网络¶
Python
class SwiGLU_FFN(nn.Module):
"""
SwiGLU前馈网络 (LLaMA使用)
公式: FFN(x) = (Swish(xW₁) ⊙ xW₃)W₂
其中:
- Swish(x) = x * sigmoid(x) (也叫SiLU)
- ⊙ 表示逐元素乘法(门控机制)
    - hidden_dim = 8/3 * dim(约2.67倍,再向上对齐到multiple_of的倍数)
"""
def __init__(self, dim: int, hidden_dim: int = None, multiple_of: int = 256):
super().__init__()
if hidden_dim is None:
hidden_dim = int(2 * (4 * dim) / 3)
            # 向上取整到multiple_of的倍数
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False) # Gate projection
self.w2 = nn.Linear(hidden_dim, dim, bias=False) # Down projection
self.w3 = nn.Linear(dim, hidden_dim, bias=False) # Up projection
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args: x [B, L, D]
Returns: [B, L, D]
"""
# SwiGLU: swish(x @ W1) * (x @ W3) @ W2
return self.w2(F.silu(self.w1(x)) * self.w3(x))
# 验证
ffn = SwiGLU_FFN(dim=128)
x = torch.randn(2, 10, 128)
out = ffn(x)
print(f"SwiGLU FFN: 输入 {x.shape} → 输出 {out.shape}")
print(f"FFN参数量: {sum(p.numel() for p in ffn.parameters()):,}")
print(f"隐藏层维度: {ffn.w1.out_features}")
3.4 GQA注意力¶
Python
class GroupedQueryAttention(nn.Module):
"""
Grouped-Query Attention (GQA)
MHA: n_heads个Q, n_heads个K, n_heads个V
MQA: n_heads个Q, 1个K, 1个V
GQA: n_heads个Q, n_kv_heads个K, n_kv_heads个V
每n_heads/n_kv_heads个Q头共享同一组KV → 减少KV Cache
"""
def __init__(self, dim: int, n_heads: int, n_kv_heads: int = None):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads or n_heads # 默认MHA
self.head_dim = dim // n_heads
self.n_rep = n_heads // self.n_kv_heads # 每组KV重复次数
assert dim % n_heads == 0 # assert断言:条件False时抛出AssertionError
assert n_heads % self.n_kv_heads == 0
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
"""将KV头重复以匹配Q头数量"""
if self.n_rep == 1:
return x
B, L, n_kv_heads, head_dim = x.shape
x = x.unsqueeze(3).expand(B, L, n_kv_heads, self.n_rep, head_dim)
return x.reshape(B, L, self.n_heads, head_dim)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor,
mask: torch.Tensor = None):
"""
Args:
x: [B, L, D]
freqs_cis: [L, head_dim/2] RoPE频率
mask: [L, L] 因果mask
Returns: [B, L, D]
"""
B, L, _ = x.shape
# 线性投影
q = self.wq(x).view(B, L, self.n_heads, self.head_dim) # view重塑张量形状
k = self.wk(x).view(B, L, self.n_kv_heads, self.head_dim)
v = self.wv(x).view(B, L, self.n_kv_heads, self.head_dim)
# 应用RoPE
q, k = RotaryPositionEmbedding.apply_rotary_emb(q, k, freqs_cis)
# GQA: 重复KV头
k = self._repeat_kv(k) # [B, L, n_heads, head_dim]
v = self._repeat_kv(v)
# 转置: [B, n_heads, L, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Scaled Dot-Product Attention
scale = math.sqrt(self.head_dim)
scores = torch.matmul(q, k.transpose(-2, -1)) / scale # [B, H, L, L]
if mask is not None:
scores = scores + mask # mask中被遮蔽位置为-inf
attn = F.softmax(scores, dim=-1)
output = torch.matmul(attn, v) # [B, H, L, head_dim]
# 合并多头
output = output.transpose(1, 2).contiguous().view(B, L, -1)
return self.wo(output)
# 验证
gqa = GroupedQueryAttention(dim=128, n_heads=8, n_kv_heads=2)
x = torch.randn(2, 10, 128)
freqs = RotaryPositionEmbedding.precompute_freqs(dim=128//8, max_seq_len=512)
# 因果mask
L = 10
mask = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)
out = gqa(x, freqs[:L], mask)
print(f"GQA: 输入 {x.shape} → 输出 {out.shape}")
print(f" Q头数: {gqa.n_heads}, KV头数: {gqa.n_kv_heads}, 重复倍数: {gqa.n_rep}")
4. 完整LLaMA模型¶
4.1 Transformer Block¶
Python
class TransformerBlock(nn.Module):
"""LLaMA-2 Transformer Block (Pre-Norm + Residual)"""
def __init__(self, dim: int, n_heads: int, n_kv_heads: int,
ffn_dim_multiplier: float = None, multiple_of: int = 256):
super().__init__()
self.attention = GroupedQueryAttention(dim, n_heads, n_kv_heads)
self.feed_forward = SwiGLU_FFN(dim, multiple_of=multiple_of)
self.attention_norm = RMSNorm(dim)
self.ffn_norm = RMSNorm(dim)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor,
mask: torch.Tensor = None):
"""
Pre-Norm + Residual连接
        x → RMSNorm → Attention → (+x 残差) → RMSNorm → FFN → (+残差) → 输出
"""
# Self-Attention with residual
h = x + self.attention(self.attention_norm(x), freqs_cis, mask)
# FFN with residual
out = h + self.feed_forward(self.ffn_norm(h))
return out
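按照前面各组件的验证风格,补充一个TransformerBlock的形状检查(演示用):

Python
# 验证
block = TransformerBlock(dim=128, n_heads=8, n_kv_heads=2, multiple_of=64)
x = torch.randn(2, 10, 128)
freqs = RotaryPositionEmbedding.precompute_freqs(dim=128 // 8, max_seq_len=512)
mask = torch.triu(torch.full((10, 10), float('-inf')), diagonal=1)
out = block(x, freqs[:10], mask)
print(f"TransformerBlock: 输入 {x.shape} → 输出 {out.shape}")  # 形状应保持不变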
4.2 完整MiniLLaMA模型¶
Python
from dataclasses import dataclass
@dataclass # @dataclass自动生成__init__等方法
class MiniLLaMAConfig:
"""Mini-LLaMA模型配置"""
vocab_size: int = 4096 # 词表大小(小型实验用)
dim: int = 256 # 隐藏维度
n_layers: int = 6 # Transformer层数
n_heads: int = 8 # 注意力头数
n_kv_heads: int = 2 # KV头数(GQA)
max_seq_len: int = 512 # 最大序列长度
multiple_of: int = 64 # FFN维度对齐
rope_theta: float = 10000.0 # RoPE基础频率
dropout: float = 0.1 # Dropout率
class MiniLLaMA(nn.Module):
"""
Mini-LLaMA: 完整的LLaMA-2风格语言模型
    按默认配置约5M参数,可在单GPU甚至CPU上训练
"""
def __init__(self, config: MiniLLaMAConfig):
super().__init__()
self.config = config
# Token Embedding
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
# Transformer Blocks
self.layers = nn.ModuleList([
TransformerBlock(
dim=config.dim,
n_heads=config.n_heads,
n_kv_heads=config.n_kv_heads,
multiple_of=config.multiple_of
)
for _ in range(config.n_layers)
])
# Final RMSNorm
self.norm = RMSNorm(config.dim)
# LM Head (与embedding共享权重)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.tok_embeddings.weight = self.output.weight # weight tying
# 预计算RoPE频率
head_dim = config.dim // config.n_heads
self.freqs_cis = RotaryPositionEmbedding.precompute_freqs(
head_dim, config.max_seq_len, config.rope_theta
)
# Dropout
self.dropout = nn.Dropout(config.dropout)
# 初始化权重
self.apply(self._init_weights)
def _init_weights(self, module):
"""Xavier均匀初始化"""
if isinstance(module, nn.Linear): # isinstance检查类型
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0, std=0.02)
def forward(self, tokens: torch.Tensor, targets: torch.Tensor = None):
"""
前向传播
Args:
tokens: [B, L] 输入token IDs
targets: [B, L] 目标token IDs(训练时提供)
Returns:
如果有targets: loss (标量)
否则: logits [B, L, V]
"""
B, L = tokens.shape
device = tokens.device
# Token Embedding
h = self.tok_embeddings(tokens) # [B, L, D]
h = self.dropout(h)
# 获取RoPE频率
freqs_cis = self.freqs_cis[:L].to(device) # .to(device)将数据移至GPU/CPU
# 构造因果mask
mask = torch.triu(
torch.full((L, L), float('-inf'), device=device),
diagonal=1
)
# 通过所有Transformer层
for layer in self.layers:
h = layer(h, freqs_cis, mask)
# Final Norm
h = self.norm(h)
# LM Head
logits = self.output(h) # [B, L, V]
if targets is not None:
# 计算交叉熵损失
# logits: [B, L, V] → [B*L, V]
# targets: [B, L] → [B*L]
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100 # 忽略padding(PyTorch默认值)
)
return loss
return logits
@torch.no_grad() # 禁用梯度计算,节省内存(推理时使用)
def generate(self, prompt_tokens: torch.Tensor, max_new_tokens: int = 100,
temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
"""
自回归文本生成
Args:
prompt_tokens: [1, prompt_len] 提示token序列
max_new_tokens: 最大生成token数
temperature: 温度参数(越高越随机)
top_k: Top-K采样
top_p: Nucleus采样
"""
self.eval()
tokens = prompt_tokens.clone()
for _ in range(max_new_tokens):
# 截断到最大长度
context = tokens[:, -self.config.max_seq_len:]
# 前向传播获取logits
logits = self(context)
# 只取最后一个位置的logits
next_logits = logits[:, -1, :] / temperature # [1, V]
# Top-K过滤
if top_k > 0:
top_k_val = min(top_k, next_logits.size(-1))
indices_to_remove = next_logits < torch.topk(next_logits, top_k_val)[0][:, -1:]
next_logits[indices_to_remove] = float('-inf')
# Top-P (Nucleus)过滤
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# 移除累积概率超过top_p的token
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
next_logits[indices_to_remove] = float('-inf')
# 采样
probs = F.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1) # [1, 1]
tokens = torch.cat([tokens, next_token], dim=-1)
return tokens
# 创建模型
config = MiniLLaMAConfig()
model = MiniLLaMA(config)
# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Mini-LLaMA 模型统计:")
print(f" 总参数量: {total_params:,} ({total_params/1e6:.1f}M)")
print(f" 可训练参数: {trainable_params:,}")
print(f" 配置: dim={config.dim}, layers={config.n_layers}, heads={config.n_heads}")
print(f" KV头数: {config.n_kv_heads} (GQA, {config.n_heads//config.n_kv_heads}x压缩)")
print(f" 最大序列长度: {config.max_seq_len}")
# 测试前向传播
dummy_input = torch.randint(0, config.vocab_size, (2, 32))
dummy_target = torch.randint(0, config.vocab_size, (2, 32))
loss = model(dummy_input, dummy_target)
print(f"\n前向传播测试: loss = {loss.item():.4f}")
print(f"随机初始化loss理论值: {math.log(config.vocab_size):.4f} (ln({config.vocab_size}))")
5. BPE Tokenizer训练¶
5.1 BPE算法原理¶
Text Only
Byte Pair Encoding (BPE) 算法
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
初始: 将文本拆分为单个字符(或字节)
迭代: 每次合并出现频率最高的相邻对
重复: 直到达到目标词表大小
示例 (英文,经典语料):
初始词表: ['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd']
语料: "low" × 5, "lower" × 2, "newest" × 6, "widest" × 3
Step 1: 最频繁对 ('e','s') 出现9次 → 合并为 'es'
Step 2: 最频繁对 ('es','t') 出现9次 → 合并为 'est'
Step 3: 最频繁对 ('l','o') 出现7次 → 合并为 'lo'
Step 4: 最频繁对 ('lo','w') 出现7次 → 合并为 'low'
...
最终(仅应用上述4条规则): 未见过的词 "lowest" → ['low', 'est']
                          "newest" → ['n', 'e', 'w', 'est']
5.2 手写BPE Tokenizer¶
Python
import re
from collections import Counter, defaultdict
class SimpleBPETokenizer:
"""
手写BPE分词器
支持中英文混合文本
"""
def __init__(self, vocab_size: int = 4096):
self.target_vocab_size = vocab_size
self.merges = {} # 合并规则: (a, b) → ab
self.vocab = {} # token → id
self.inverse_vocab = {} # id → token
# 特殊token
self.special_tokens = {
"<pad>": 0,
"<unk>": 1,
"<bos>": 2,
"<eos>": 3,
}
def _get_stats(self, word_freqs):
"""统计相邻pair的频率"""
pairs = Counter() # Counter统计元素出现次数
for word, freq in word_freqs.items():
symbols = word.split()
for i in range(len(symbols) - 1):
pairs[(symbols[i], symbols[i+1])] += freq
return pairs
def _merge_pair(self, pair, word_freqs):
"""合并一个pair"""
new_word_freqs = {}
bigram = ' '.join(pair)
replacement = ''.join(pair)
for word, freq in word_freqs.items():
new_word = word.replace(bigram, replacement)
new_word_freqs[new_word] = freq
return new_word_freqs
def train(self, texts: list, verbose: bool = True):
"""
训练BPE分词器
Args:
texts: 训练文本列表
verbose: 是否打印训练过程
"""
# Step 1: 分词并统计词频
word_freqs = Counter()
for text in texts:
# 简单分词:按空格和标点拆分
words = re.findall(r'[\u4e00-\u9fff]|[a-zA-Z]+|[0-9]+|[^\s\w]', text) # re.findall正则查找所有匹配项
for word in words:
# 将每个词拆分为字符(空格分隔)
char_word = ' '.join(list(word))
word_freqs[char_word] += 1
# Step 2: 初始化词表(所有出现的字符)
self.vocab = dict(self.special_tokens)
chars = set()
for word in word_freqs:
for char in word.split():
chars.add(char)
for char in sorted(chars):
if char not in self.vocab:
self.vocab[char] = len(self.vocab)
initial_vocab_size = len(self.vocab)
if verbose:
print(f"初始词表大小: {initial_vocab_size}")
# Step 3: 迭代合并
num_merges = self.target_vocab_size - initial_vocab_size
for i in range(num_merges):
# 统计pair频率
pairs = self._get_stats(word_freqs)
if not pairs:
break
# 找出最频繁的pair
best_pair = max(pairs, key=pairs.get)
best_freq = pairs[best_pair]
if best_freq < 2: # 频率太低则停止
break
# 合并
word_freqs = self._merge_pair(best_pair, word_freqs)
# 记录合并规则
merged_token = ''.join(best_pair)
self.merges[best_pair] = merged_token
if merged_token not in self.vocab:
self.vocab[merged_token] = len(self.vocab)
if verbose and (i + 1) % 100 == 0:
print(f" 合并 {i+1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' → '{merged_token}' (freq={best_freq})")
# 构建反向词表
self.inverse_vocab = {v: k for k, v in self.vocab.items()}
if verbose:
print(f"最终词表大小: {len(self.vocab)}")
print(f"合并规则数: {len(self.merges)}")
def encode(self, text: str) -> list:
"""将文本编码为token ID序列"""
# 分词
words = re.findall(r'[\u4e00-\u9fff]|[a-zA-Z]+|[0-9]+|[^\s\w]', text)
ids = []
for word in words:
# 拆分为字符
symbols = list(word)
# 应用合并规则(贪婪)
while len(symbols) > 1:
# 找出可以合并的pair
best_pair = None
best_idx = -1
for i in range(len(symbols) - 1):
pair = (symbols[i], symbols[i+1])
if pair in self.merges:
best_pair = pair
best_idx = i
                        break  # 简化实现:合并最左侧可匹配的pair;严格BPE应按合并规则学到的先后顺序选择
if best_pair is None:
break
# 执行合并
symbols = (symbols[:best_idx] +
[self.merges[best_pair]] +
symbols[best_idx+2:])
# 转换为ID
for sym in symbols:
ids.append(self.vocab.get(sym, self.special_tokens["<unk>"]))
return ids
def decode(self, ids: list) -> str:
"""将token ID序列解码为文本"""
tokens = [self.inverse_vocab.get(id, "<unk>") for id in ids]
return ''.join(tokens)
# === 训练Tokenizer ===
# 中文训练语料(示例)
training_texts = [
"人工智能正在改变世界",
"深度学习是机器学习的一个分支",
"大语言模型具有强大的文本生成能力",
"自然语言处理是人工智能的重要方向",
"Transformer架构改变了深度学习的格局",
"注意力机制是现代神经网络的核心组件",
"预训练语言模型通过大规模数据学习语言知识",
"GPT系列模型展示了规模扩展的威力",
"BERT模型在自然语言理解任务上取得了突破",
"强化学习可以用于优化语言模型的输出",
] * 50 # 重复增加频率
tokenizer = SimpleBPETokenizer(vocab_size=500)
tokenizer.train(training_texts)
# 测试
test_text = "人工智能改变世界"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
print(f"\n编码测试:")
print(f" 原文: {test_text}")
print(f" 编码: {encoded}")
print(f" 解码: {decoded}")
print(f" tokens: {[tokenizer.inverse_vocab.get(id, '?') for id in encoded]}")
6. 预训练Pipeline¶
6.1 数据准备¶
Python
import torch
from torch.utils.data import Dataset, DataLoader
import os
class TextDataset(Dataset):
"""
用于CLM预训练的文本数据集
将文本编码后切分为固定长度的序列
"""
def __init__(self, texts: list, tokenizer, seq_len: int = 128):
self.seq_len = seq_len
# 编码所有文本
all_ids = []
for text in texts:
ids = tokenizer.encode(text)
all_ids.extend(ids)
all_ids.append(tokenizer.special_tokens["<eos>"])
# 切分为固定长度序列
self.data = []
for i in range(0, len(all_ids) - seq_len - 1, seq_len):
input_ids = all_ids[i : i + seq_len]
target_ids = all_ids[i + 1 : i + seq_len + 1]
self.data.append((
torch.tensor(input_ids, dtype=torch.long),
torch.tensor(target_ids, dtype=torch.long)
))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
print("数据集准备完成")
6.2 训练循环¶
Python
import time
class MiniLLaMATrainer:
"""Mini-LLaMA训练器"""
def __init__(self, model, config, tokenizer, device='cpu'):
self.model = model.to(device)
self.config = config
self.tokenizer = tokenizer
self.device = device
# 优化器: AdamW
self.optimizer = torch.optim.AdamW(
model.parameters(),
lr=3e-4,
betas=(0.9, 0.95),
weight_decay=0.1
)
# 学习率调度: Cosine with warmup
self.warmup_steps = 100
self.total_steps = 0
def _get_lr(self, step):
"""Cosine学习率调度 with warmup"""
if step < self.warmup_steps:
return 3e-4 * step / self.warmup_steps
# Cosine decay
progress = (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
return 3e-4 * 0.5 * (1 + math.cos(math.pi * progress))
def train(self, train_dataset, num_epochs=10, batch_size=8,
log_interval=50, eval_interval=200):
"""
训练循环
"""
dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
self.total_steps = num_epochs * len(dataloader)
self.model.train()
global_step = 0
best_loss = float('inf')
print(f"开始训练 Mini-LLaMA")
print(f" Epochs: {num_epochs}")
print(f" Batch size: {batch_size}")
print(f" Total steps: {self.total_steps}")
print(f" Device: {self.device}")
print("=" * 60)
for epoch in range(num_epochs):
epoch_loss = 0
epoch_start = time.time()
for batch_idx, (input_ids, target_ids) in enumerate(dataloader): # enumerate同时获取索引和元素
input_ids = input_ids.to(self.device)
target_ids = target_ids.to(self.device)
# 更新学习率
lr = self._get_lr(global_step)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
# 前向传播
loss = self.model(input_ids, target_ids)
# 反向传播
self.optimizer.zero_grad()
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
epoch_loss += loss.item()
global_step += 1
# 日志
if global_step % log_interval == 0:
avg_loss = epoch_loss / (batch_idx + 1)
ppl = math.exp(min(avg_loss, 20)) # 限制防止溢出
elapsed = time.time() - epoch_start
print(f" Step {global_step}: loss={avg_loss:.4f}, ppl={ppl:.2f}, lr={lr:.6f}, time={elapsed:.1f}s")
# 生成样本
if global_step % eval_interval == 0:
self._generate_sample()
# Epoch结束统计
avg_epoch_loss = epoch_loss / len(dataloader)
epoch_time = time.time() - epoch_start
print(f"\nEpoch {epoch+1}/{num_epochs}: avg_loss={avg_epoch_loss:.4f}, time={epoch_time:.1f}s")
if avg_epoch_loss < best_loss:
best_loss = avg_epoch_loss
# 保存最佳模型
torch.save(self.model.state_dict(), 'mini_llama_best.pt')
print(f" ✓ 保存最佳模型 (loss={best_loss:.4f})")
print()
print(f"训练完成! 最佳loss: {best_loss:.4f}")
def _generate_sample(self):
"""训练中生成样本以观察效果"""
self.model.eval()
# 用几个起始token生成
prompt = torch.tensor([[self.tokenizer.special_tokens["<bos>"]]]).to(self.device)
generated = self.model.generate(
prompt, max_new_tokens=50, temperature=0.8, top_k=40
)
text = self.tokenizer.decode(generated[0].tolist())
print(f" [生成样本] {text[:100]}...")
self.model.train()
# === 端到端训练示例 ===
print("端到端预训练流程:")
print("1. 准备语料 → 2. 训练Tokenizer → 3. 构建数据集 → 4. 训练模型 → 5. 生成文本")
print()
# 准备训练语料(实际应用中使用大规模语料)
corpus = [
"人工智能是计算机科学的一个重要分支,它研究如何让计算机模拟人类的智能行为。",
"深度学习通过多层神经网络自动学习数据的特征表示,是目前最成功的机器学习方法。",
"自然语言处理让计算机能够理解和生成人类语言,包括文本分类、机器翻译等任务。",
"大语言模型通过在海量文本上预训练,学习到了丰富的语言知识和世界知识。",
"Transformer架构使用自注意力机制,能够并行处理序列中的所有位置。",
"注意力机制允许模型在处理每个位置时关注序列中的其他相关位置。",
"预训练加微调的范式极大地提高了模型在各种下游任务上的表现。",
"强化学习通过试错的方式让智能体学习最优策略,在游戏和机器人控制中表现出色。",
] * 100 # 实际需要大量数据
# 1. 训练Tokenizer
tok = SimpleBPETokenizer(vocab_size=500)
tok.train(corpus, verbose=False)
# 2. 更新模型配置
train_config = MiniLLaMAConfig(
vocab_size=len(tok.vocab),
dim=128,
n_layers=4,
n_heads=4,
n_kv_heads=2,
max_seq_len=128,
multiple_of=32
)
# 3. 创建模型
train_model = MiniLLaMA(train_config)
params = sum(p.numel() for p in train_model.parameters())
print(f"模型参数量: {params:,} ({params/1e6:.2f}M)")
# 4. 创建数据集
train_dataset = TextDataset(corpus, tok, seq_len=64)
print(f"训练样本数: {len(train_dataset)}")
# 5. 训练(实际训练需要更多数据和更长时间)
# trainer = MiniLLaMATrainer(train_model, train_config, tok, device='cpu')
# trainer.train(train_dataset, num_epochs=5, batch_size=8)
print("\n注意: 完整训练需要:")
print(" - 几GB~几十GB的中文语料")
print(" - GPU加速(推荐至少一张RTX 3090)")
print(" - 数小时到数天的训练时间")
print(" - 更大的模型配置(dim=512+, layers=12+)")
7. 模型评估与生成¶
7.1 困惑度评估¶
Python
@torch.no_grad()
def evaluate_perplexity(model, dataset, batch_size=8, device='cpu'):
"""
计算模型在数据集上的困惑度(Perplexity)
PPL = exp(avg_loss)
PPL越低,模型越好
"""
model.eval()
model = model.to(device)
dataloader = DataLoader(dataset, batch_size=batch_size)
total_loss = 0
total_tokens = 0
for input_ids, target_ids in dataloader:
input_ids = input_ids.to(device)
target_ids = target_ids.to(device)
loss = model(input_ids, target_ids)
num_tokens = (target_ids != -100).sum().item() # 与 ignore_index=-100 保持一致
total_loss += loss.item() * num_tokens
total_tokens += num_tokens
avg_loss = total_loss / total_tokens
perplexity = math.exp(avg_loss)
return perplexity, avg_loss
# 如果模型已训练好:
# ppl, loss = evaluate_perplexity(train_model, train_dataset)
# print(f"困惑度: {ppl:.2f}, 平均Loss: {loss:.4f}")
7.2 多种采样策略对比¶
Python
def generate_text(model, tokenizer, prompt, max_tokens=100,
strategy='top_p', **kwargs):
"""
文本生成:支持多种采样策略
"""
device = next(model.parameters()).device
# 编码prompt
ids = tokenizer.encode(prompt)
tokens = torch.tensor([ids]).to(device)
if strategy == 'greedy':
# 贪婪搜索:每步选概率最高的
generated = model.generate(tokens, max_tokens, temperature=0.01, top_k=1)
elif strategy == 'temperature':
# 温度采样
temp = kwargs.get('temperature', 1.0)
generated = model.generate(tokens, max_tokens, temperature=temp, top_k=0, top_p=1.0)
elif strategy == 'top_k':
# Top-K采样
k = kwargs.get('k', 50)
generated = model.generate(tokens, max_tokens, temperature=0.8, top_k=k)
elif strategy == 'top_p':
# Nucleus (Top-P) 采样
p = kwargs.get('p', 0.9)
generated = model.generate(tokens, max_tokens, temperature=0.8, top_p=p)
else:
raise ValueError(f"Unknown strategy: {strategy}")
return tokenizer.decode(generated[0].tolist())
print("""
采样策略对比:
┌────────────┬────────────────────┬───────────────┐
│ 策略 │ 特点 │ 适用场景 │
├────────────┼────────────────────┼───────────────┤
│ Greedy │ 确定性,最保守 │ 翻译/摘要 │
│ Temperature│ 温度高→随机 │ 通用 │
│ Top-K │ 只从前K个候选采样 │ 创作 │
│ Top-P │ 动态K, 累积概率≤P │ 对话/创作 │
└────────────┴────────────────────┴───────────────┘
推荐设置:
- 创意写作: temperature=1.0, top_p=0.95
- 对话: temperature=0.7, top_p=0.9
- 代码/事实: temperature=0.2, top_k=10
""")
8. 工程优化技巧¶
8.1 KV Cache¶
Python
class MiniLLaMAWithKVCache(nn.Module):
"""
带KV Cache的推理优化
问题: 自回归生成时,每生成一个token都要重新计算所有位置的KV
优化: 缓存已计算的KV,新token只需计算增量
无Cache: 生成100个token → 计算 1+2+...+100 = 5050 次KV
有Cache: 生成100个token → 计算 100 次KV
"""
def __init__(self, config):
super().__init__()
# ... 复用MiniLLaMA的初始化 ...
self.config = config
# KV Cache的核心思想(伪代码)
def generate_with_cache(self, prompt_tokens, max_new_tokens=100):
"""
KV Cache加速生成
"""
# 初始化cache: 每层存储K和V
kv_cache = [
{"k": None, "v": None}
for _ in range(self.config.n_layers)
]
# Prefill阶段:处理整个prompt
# 此时计算所有位置的KV并缓存
logits = self.forward_with_cache(prompt_tokens, kv_cache, is_prefill=True)
tokens = prompt_tokens
for step in range(max_new_tokens):
# Decode阶段:每次只处理新增的1个token
next_token = logits[:, -1:, :].argmax(dim=-1)
tokens = torch.cat([tokens, next_token], dim=-1)
# 只传入新token,利用cache中已有的KV
logits = self.forward_with_cache(next_token, kv_cache, is_prefill=False)
return tokens
print("""
KV Cache性能提升:
┌───────────────┬──────────────────┬──────────────────┐
│ 生成长度 │ 无Cache计算量 │ 有Cache计算量 │
├───────────────┼──────────────────┼──────────────────┤
│ 100 tokens │ 5050 × KV_cost │ 100 × KV_cost │
│ 1000 tokens │ 500500 × KV_cost │ 1000 × KV_cost │
│ 加速比 │ ~50x (100) │ ~500x (1000) │
└───────────────┴──────────────────┴──────────────────┘
显存占用:
KV Cache大小 = 2 × n_layers × n_kv_heads × head_dim × seq_len × batch_size × 每元素字节数
对于LLaMA-2-7B (MHA即n_kv_heads=32, seq_len=4096, batch=1, FP16每元素2 bytes):
= 2 × 32 × 32 × 128 × 4096 × 1 × 2 bytes ≈ 2GB
""")
8.2 混合精度训练¶
Python
# 使用PyTorch AMP (Automatic Mixed Precision)
from torch.amp import autocast, GradScaler
def train_step_mixed_precision(model, optimizer, input_ids, target_ids, scaler):
"""
混合精度训练步骤
FP32运算 → 改为 FP16 → 速度2x, 显存减半
但某些操作(如loss计算、归一化)需要保持FP32
"""
optimizer.zero_grad()
# 自动混合精度: 前向传播中适当使用FP16
with autocast(device_type='cuda'):
loss = model(input_ids, target_ids)
# 缩放loss以防止FP16下梯度下溢
scaler.scale(loss).backward()
# 梯度裁剪(需要先unscale)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# 优化器步骤
scaler.step(optimizer)
scaler.update()
return loss.item()
print("""
混合精度训练好处:
1. 训练速度提升 ~2x(利用Tensor Cores)
2. 显存占用减少 ~50%
3. 模型精度基本不受影响
关键技术:
- GradScaler: 防止FP16梯度下溢
- autocast: 自动选择合适的精度
- 保持关键操作在FP32: LayerNorm, Loss, Softmax
""")
练习题¶
练习1:组件验证¶
- 修改GQA的 n_kv_heads 为1(MQA)和 n_heads(MHA),对比参数量差异
- 将RMSNorm替换为 nn.LayerNorm,观察训练稳定性差异
练习2:Tokenizer扩展¶
- 为Tokenizer添加 <pad>、<unk> 等特殊token的正确处理
- 实现WordPiece分词器并与BPE对比
练习3:预训练实验¶
- 在WikiText-2数据集上预训练Mini-LLaMA
- 对比不同超参数(dim, n_layers, lr)对困惑度的影响
- 实现学习率Warmup + Cosine Decay调度器
练习4:生成优化¶
- 实现Beam Search解码
- 实现Repetition Penalty(减少重复生成)
- 实现KV Cache并测量推理速度提升
📝 本章小结¶
| 知识点 | 掌握程度检查 |
|---|---|
| LLaMA架构特点 | 能否说出与原始Transformer的5个主要区别? |
| RMSNorm | 能否手写实现并解释为什么比LayerNorm好? |
| RoPE | 能否解释旋转位置编码的核心思想? |
| SwiGLU | 能否解释门控机制的作用? |
| GQA | 能否解释MHA→GQA→MQA的演进逻辑? |
| BPE Tokenizer | 能否手写BPE训练过程? |
| CLM预训练 | 能否解释Next Token Prediction的训练流程? |
| KV Cache | 能否解释KV Cache加速推理的原理? |
🔗 后续学习路径¶
- Transformer深入理解 → 01-Transformer深入理解
- 手写Transformer → 03-手写Transformer完整实现
- 高效微调 → 01-高效微调技术
- LoRA从零实现 → 06-LoRA从零实现
- 大模型预训练理论 → 03-大模型预训练
📚 参考资料¶
- Touvron et al. "LLaMA: Open and Efficient Foundation Language Models" (2023)
- Touvron et al. "Llama 2: Open Foundation and Fine-Tuned Chat Models" (2023)
- Su et al. "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021)
- Zhang & Sennrich "Root Mean Square Layer Normalization" (2019)
- Shazeer "GLU Variants Improve Transformer" (2020) — SwiGLU
- Ainslie et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023)
- Sennrich et al. "Neural Machine Translation of Rare Words with Subword Units" (2016) — BPE
- Kaplan et al. "Scaling Laws for Neural Language Models" (2020)
- Korthikanti et al. "Reducing Activation Recomputation in Large Transformer Models" (2022)
- Dao et al. "FlashAttention: Fast and Memory-Efficient Exact Attention" (2022)