
从零搭建小型LLM

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习目标:从零实现一个完整的小型大语言模型(Mini-LLM),包括LLaMA架构实现、Tokenizer训练、预训练流程,真正理解LLM的内部工作原理。

📌 定位说明:对标 happy-llm Ch5(动手搭建大模型),我们的实现覆盖完整的LLaMA2架构(RoPE/RMSNorm/SwiGLU/GQA)、BPE Tokenizer训练、以及端到端预训练。代码可运行、形状标注清晰、工程实践更贴近实际。


1. 为什么要从零搭建LLM

Text Only
从零搭建LLM的学习价值:
━━━━━━━━━━━━━━━━━━━━━━━━━━━
1. 彻底理解每一行代码的含义
2. 明白"大模型"到底"大"在哪里
3. 掌握预训练的核心流程
4. 为后续的微调/部署/优化打下基础
5. 面试时能从容讲解底层实现

我们的实现目标:
├── 模型架构: LLaMA-2风格 (默认配置约5M参数)
├── Tokenizer: 训练一个BPE分词器
├── 预训练: 在小规模中文语料上CLM训练
└── 生成: 实现自回归文本生成 + 多种采样策略

2. LLaMA架构全貌

2.1 架构概览

Text Only
LLaMA-2 架构(我们的Mini版本)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

输入 Token IDs: [B, L]
┌─────────────────────────┐
│   Token Embedding       │  [B, L] → [B, L, D]
└─────────────────────────┘
┌─────────────────────────┐ ×N layers
│  ┌───────────────────┐  │
│  │     RMSNorm       │  │ Pre-Norm
│  └───────────────────┘  │
│        ↓                │
│  ┌───────────────────┐  │
│  │  GQA Attention    │  │ + RoPE位置编码
│  │  (因果Mask)       │  │
│  └───────────────────┘  │
│        ↓ + Residual     │
│  ┌───────────────────┐  │
│  │     RMSNorm       │  │ Pre-Norm
│  └───────────────────┘  │
│        ↓                │
│  ┌───────────────────┐  │
│  │   SwiGLU FFN      │  │ 8/3 × D hidden
│  └───────────────────┘  │
│        ↓ + Residual     │
└─────────────────────────┘
┌─────────────────────────┐
│   RMSNorm               │ Final Norm
└─────────────────────────┘
┌─────────────────────────┐
│   Linear (→ vocab_size) │ LM Head (weight sharing)
└─────────────────────────┘
Output Logits: [B, L, V]

2.2 与原始Transformer的区别

| 组件 | 原始Transformer | LLaMA-2 | 为什么改 |
| --- | --- | --- | --- |
| 归一化 | LayerNorm (Post-Norm) | RMSNorm (Pre-Norm) | 训练更稳定,计算更快 |
| 位置编码 | 正弦位置编码 | RoPE旋转位置编码 | 编码相对位置信息,支持一定外推 |
| 激活函数 | ReLU | SwiGLU | 效果更好(经验验证) |
| 注意力 | MHA | GQA | KV Cache更省显存 |
| FFN维度 | 4×D | 8/3×D(约2.67×D) | SwiGLU多一个门控矩阵(3个而非2个),降低隐藏维以保持总参数量相当 |
| Bias | Linear层带bias | Linear层全部去掉bias | 减少参数,效果基本不降 |
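
下面用几行Python做一个数量级上的直观对比(示意,取 dim=256、8个Q头为例,并非正文模型的最终配置),感受GQA/MQA相对MHA在KV投影参数量和每token KV缓存上的节省:

Python
# MHA / GQA / MQA 的KV开销对比(示意)
dim, n_heads = 256, 8
head_dim = dim // n_heads
for name, n_kv in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    kv_proj_params = 2 * dim * n_kv * head_dim   # wk + wv 两个投影矩阵的参数量
    kv_cache_per_token = 2 * n_kv * head_dim     # 每层每个token需要缓存的KV元素数
    print(f"{name}: KV投影参数 {kv_proj_params:,}, 每层每token KV缓存 {kv_cache_per_token} 个元素")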

3. 核心组件实现

3.1 RMSNorm

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization
    比LayerNorm更快:不需要计算均值,只计算均方根

    公式: RMSNorm(x) = x / RMS(x) * γ
    其中: RMS(x) = sqrt(mean(x²) + ε)
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()  # super()调用父类方法
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # 可学习的缩放参数 γ

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args: x [batch_size, seq_len, dim]
        Returns: normalized x, same shape
        """
        # 计算均方根
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # 归一化并缩放
        return x / rms * self.weight

# 验证
norm = RMSNorm(64)
x = torch.randn(2, 10, 64)
out = norm(x)
print(f"RMSNorm: 输入 {x.shape} → 输出 {out.shape}")
# 验证输出的RMS约为1
rms = torch.sqrt(out.pow(2).mean(dim=-1))
print(f"归一化后RMS均值: {rms.mean().item():.4f}")  # 应接近1.0

3.2 RoPE旋转位置编码

Python
class RotaryPositionEmbedding:
    """
    RoPE (Rotary Position Embedding)

    核心思想: 用旋转矩阵编码位置信息
    - 对Q和K分别施加旋转,使得内积自然包含相对位置信息
    - 旋转后内积满足 ⟨f(q, m), f(k, n)⟩ = g(q, k, m-n) → 只依赖q、k和相对位置 m-n
    """

    @staticmethod  # @staticmethod无需实例即可调用
    def precompute_freqs(dim: int, max_seq_len: int, theta: float = 10000.0):
        """
        预计算频率参数

        freqs[i] = 1 / theta^(2i/dim), i = 0, 1, ..., dim/2 - 1
        """
        # [dim/2] 频率
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # [max_seq_len] 位置
        t = torch.arange(max_seq_len).float()
        # [max_seq_len, dim/2] 外积
        freqs = torch.outer(t, freqs)
        # 复数形式: e^(i*θ) = cos(θ) + i*sin(θ)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis  # [max_seq_len, dim/2] 复数张量

    @staticmethod
    def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor,
                          freqs_cis: torch.Tensor):
        """
        对Q和K施加旋转位置编码

        Args:
            xq: [B, L, n_heads, head_dim]
            xk: [B, L, n_kv_heads, head_dim]
            freqs_cis: [L, head_dim/2] 复数频率
        """
        # 将实数张量转为复数: [B, L, H, D] → [B, L, H, D/2] (复数)
        xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

        # 调整freqs_cis的形状以匹配
        # freqs_cis: [L, D/2] → [1, L, 1, D/2]
        freqs = freqs_cis.unsqueeze(0).unsqueeze(2)  # unsqueeze增加一个维度

        # 旋转: 复数乘法 = 旋转
        xq_out = torch.view_as_real(xq_complex * freqs).flatten(-2)
        xk_out = torch.view_as_real(xk_complex * freqs).flatten(-2)

        return xq_out.type_as(xq), xk_out.type_as(xk)

# 验证RoPE
freqs = RotaryPositionEmbedding.precompute_freqs(dim=64, max_seq_len=512)
print(f"RoPE频率表: {freqs.shape}")  # [512, 32]

q = torch.randn(2, 10, 4, 64)  # [B, L, n_heads, head_dim]
k = torch.randn(2, 10, 4, 64)
q_rot, k_rot = RotaryPositionEmbedding.apply_rotary_emb(q, k, freqs[:10])
print(f"RoPE编码后: Q {q_rot.shape}, K {k_rot.shape}")

3.3 SwiGLU前馈网络

Python
class SwiGLU_FFN(nn.Module):
    """
    SwiGLU前馈网络 (LLaMA使用)

    公式: FFN(x) = (Swish(xW₁) ⊙ xW₃)W₂

    其中:
    - Swish(x) = x * sigmoid(x) (也叫SiLU)
    - ⊙ 表示逐元素乘法(门控机制)
    - hidden_dim = 8/3 * dim(约2.67倍,取最近的multiple_of对齐)
    """
    def __init__(self, dim: int, hidden_dim: int = None, multiple_of: int = 256):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(2 * (4 * dim) / 3)
            # 取最近的multiple_of的倍数
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)   # Gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)    # Down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)    # Up projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args: x [B, L, D]
        Returns: [B, L, D]
        """
        # SwiGLU: swish(x @ W1) * (x @ W3) @ W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# 验证
ffn = SwiGLU_FFN(dim=128)
x = torch.randn(2, 10, 128)
out = ffn(x)
print(f"SwiGLU FFN: 输入 {x.shape} → 输出 {out.shape}")
print(f"FFN参数量: {sum(p.numel() for p in ffn.parameters()):,}")
print(f"隐藏层维度: {ffn.w1.out_features}")

3.4 GQA注意力

Python
class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (GQA)

    MHA: n_heads个Q, n_heads个K, n_heads个V
    MQA: n_heads个Q, 1个K, 1个V
    GQA: n_heads个Q, n_kv_heads个K, n_kv_heads个V

    每n_heads/n_kv_heads个Q头共享同一组KV → 减少KV Cache
    """
    def __init__(self, dim: int, n_heads: int, n_kv_heads: int = None):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads  # 默认MHA
        self.head_dim = dim // n_heads
        self.n_rep = n_heads // self.n_kv_heads  # 每组KV重复次数

        assert dim % n_heads == 0  # assert断言:条件False时抛出AssertionError
        assert n_heads % self.n_kv_heads == 0

        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """将KV头重复以匹配Q头数量"""
        if self.n_rep == 1:
            return x
        B, L, n_kv_heads, head_dim = x.shape
        x = x.unsqueeze(3).expand(B, L, n_kv_heads, self.n_rep, head_dim)
        return x.reshape(B, L, self.n_heads, head_dim)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor,
                mask: torch.Tensor = None):
        """
        Args:
            x: [B, L, D]
            freqs_cis: [L, head_dim/2] RoPE频率
            mask: [L, L] 因果mask
        Returns: [B, L, D]
        """
        B, L, _ = x.shape

        # 线性投影
        q = self.wq(x).view(B, L, self.n_heads, self.head_dim)  # view重塑张量形状
        k = self.wk(x).view(B, L, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, L, self.n_kv_heads, self.head_dim)

        # 应用RoPE
        q, k = RotaryPositionEmbedding.apply_rotary_emb(q, k, freqs_cis)

        # GQA: 重复KV头
        k = self._repeat_kv(k)  # [B, L, n_heads, head_dim]
        v = self._repeat_kv(v)

        # 转置: [B, n_heads, L, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled Dot-Product Attention
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale  # [B, H, L, L]

        if mask is not None:
            scores = scores + mask  # mask中被遮蔽位置为-inf

        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)  # [B, H, L, head_dim]

        # 合并多头
        output = output.transpose(1, 2).contiguous().view(B, L, -1)
        return self.wo(output)

# 验证
gqa = GroupedQueryAttention(dim=128, n_heads=8, n_kv_heads=2)
x = torch.randn(2, 10, 128)
freqs = RotaryPositionEmbedding.precompute_freqs(dim=128//8, max_seq_len=512)
# 因果mask
L = 10
mask = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)
out = gqa(x, freqs[:L], mask)
print(f"GQA: 输入 {x.shape} → 输出 {out.shape}")
print(f"  Q头数: {gqa.n_heads}, KV头数: {gqa.n_kv_heads}, 重复倍数: {gqa.n_rep}")

4. 完整LLaMA模型

4.1 Transformer Block

Python
class TransformerBlock(nn.Module):
    """LLaMA-2 Transformer Block (Pre-Norm + Residual)"""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int,
                 ffn_dim_multiplier: float = None, multiple_of: int = 256):
        super().__init__()
        self.attention = GroupedQueryAttention(dim, n_heads, n_kv_heads)
        self.feed_forward = SwiGLU_FFN(dim, multiple_of=multiple_of)
        self.attention_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor,
                mask: torch.Tensor = None):
        """
        Pre-Norm + Residual连接
        x → RMSNorm → Attention → 残差相加(+x) → RMSNorm → FFN → 残差相加(+h)
        """
        # Self-Attention with residual
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask)
        # FFN with residual
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
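
一个快速的形状检查(示意):单个Block的输入输出形状一致,因此可以任意堆叠。

Python
# 验证单个Transformer Block(示意)
block = TransformerBlock(dim=128, n_heads=8, n_kv_heads=2, multiple_of=64)
x = torch.randn(2, 10, 128)
freqs = RotaryPositionEmbedding.precompute_freqs(dim=128 // 8, max_seq_len=512)
mask = torch.triu(torch.full((10, 10), float('-inf')), diagonal=1)
print(f"TransformerBlock: 输入 {x.shape} → 输出 {block(x, freqs[:10], mask).shape}")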

4.2 完整MiniLLaMA模型

Python
from dataclasses import dataclass

@dataclass  # @dataclass自动生成__init__等方法
class MiniLLaMAConfig:
    """Mini-LLaMA模型配置"""
    vocab_size: int = 4096       # 词表大小(小型实验用)
    dim: int = 256               # 隐藏维度
    n_layers: int = 6            # Transformer层数
    n_heads: int = 8             # 注意力头数
    n_kv_heads: int = 2          # KV头数(GQA)
    max_seq_len: int = 512       # 最大序列长度
    multiple_of: int = 64        # FFN维度对齐
    rope_theta: float = 10000.0  # RoPE基础频率
    dropout: float = 0.1         # Dropout率

class MiniLLaMA(nn.Module):
    """
    Mini-LLaMA: 完整的LLaMA-2风格语言模型

    默认配置约5M参数,可在单张GPU上快速训练
    """
    def __init__(self, config: MiniLLaMAConfig):
        super().__init__()
        self.config = config

        # Token Embedding
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)

        # Transformer Blocks
        self.layers = nn.ModuleList([
            TransformerBlock(
                dim=config.dim,
                n_heads=config.n_heads,
                n_kv_heads=config.n_kv_heads,
                multiple_of=config.multiple_of
            )
            for _ in range(config.n_layers)
        ])

        # Final RMSNorm
        self.norm = RMSNorm(config.dim)

        # LM Head (与embedding共享权重)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight  # weight tying

        # 预计算RoPE频率
        head_dim = config.dim // config.n_heads
        self.freqs_cis = RotaryPositionEmbedding.precompute_freqs(
            head_dim, config.max_seq_len, config.rope_theta
        )

        # Dropout
        self.dropout = nn.Dropout(config.dropout)

        # 初始化权重
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier均匀初始化"""
        if isinstance(module, nn.Linear):  # isinstance检查类型
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: torch.Tensor = None):
        """
        前向传播

        Args:
            tokens: [B, L] 输入token IDs
            targets: [B, L] 目标token IDs(训练时提供)
        Returns:
            如果有targets: loss (标量)
            否则: logits [B, L, V]
        """
        B, L = tokens.shape
        device = tokens.device

        # Token Embedding
        h = self.tok_embeddings(tokens)  # [B, L, D]
        h = self.dropout(h)

        # 获取RoPE频率
        freqs_cis = self.freqs_cis[:L].to(device)  # .to(device)将数据移至GPU/CPU

        # 构造因果mask
        mask = torch.triu(
            torch.full((L, L), float('-inf'), device=device),
            diagonal=1
        )

        # 通过所有Transformer层
        for layer in self.layers:
            h = layer(h, freqs_cis, mask)

        # Final Norm
        h = self.norm(h)

        # LM Head
        logits = self.output(h)  # [B, L, V]

        if targets is not None:
            # 计算交叉熵损失
            # logits: [B, L, V] → [B*L, V]
            # targets: [B, L] → [B*L]
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100  # 忽略padding(PyTorch默认值)
            )
            return loss

        return logits

    @torch.no_grad()  # 禁用梯度计算,节省内存(推理时使用)
    def generate(self, prompt_tokens: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
        """
        自回归文本生成

        Args:
            prompt_tokens: [1, prompt_len] 提示token序列
            max_new_tokens: 最大生成token数
            temperature: 温度参数(越高越随机)
            top_k: Top-K采样
            top_p: Nucleus采样
        """
        self.eval()
        tokens = prompt_tokens.clone()

        for _ in range(max_new_tokens):
            # 截断到最大长度
            context = tokens[:, -self.config.max_seq_len:]

            # 前向传播获取logits
            logits = self(context)

            # 只取最后一个位置的logits
            next_logits = logits[:, -1, :] / temperature  # [1, V]

            # Top-K过滤
            if top_k > 0:
                top_k_val = min(top_k, next_logits.size(-1))
                indices_to_remove = next_logits < torch.topk(next_logits, top_k_val)[0][:, -1:]
                next_logits[indices_to_remove] = float('-inf')

            # Top-P (Nucleus)过滤
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # 移除累积概率超过top_p的token
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_logits[indices_to_remove] = float('-inf')

            # 采样
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]

            tokens = torch.cat([tokens, next_token], dim=-1)

        return tokens

# 创建模型
config = MiniLLaMAConfig()
model = MiniLLaMA(config)

# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Mini-LLaMA 模型统计:")
print(f"  总参数量: {total_params:,} ({total_params/1e6:.1f}M)")
print(f"  可训练参数: {trainable_params:,}")
print(f"  配置: dim={config.dim}, layers={config.n_layers}, heads={config.n_heads}")
print(f"  KV头数: {config.n_kv_heads} (GQA, {config.n_heads//config.n_kv_heads}x压缩)")
print(f"  最大序列长度: {config.max_seq_len}")

# 测试前向传播
dummy_input = torch.randint(0, config.vocab_size, (2, 32))
dummy_target = torch.randint(0, config.vocab_size, (2, 32))
loss = model(dummy_input, dummy_target)
print(f"\n前向传播测试: loss = {loss.item():.4f}")
print(f"随机初始化loss理论值: {math.log(config.vocab_size):.4f} (ln({config.vocab_size}))")

5. BPE Tokenizer训练

5.1 BPE算法原理

Text Only
Byte Pair Encoding (BPE) 算法
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

初始: 将文本拆分为单个字符(或字节)
迭代: 每次合并出现频率最高的相邻对
重复: 直到达到目标词表大小

示例(经典英文BPE示例):
初始词表: ['l', 'o', 'w', 'e', 'r', 'n', 's', 't', 'i', 'd']
语料: "low" × 5, "lower" × 2, "newest" × 6, "widest" × 3

Step 1: 最频繁对 ('e','s') → 合并为 'es'
Step 2: 最频繁对 ('es','t') → 合并为 'est'
Step 3: 最频繁对 ('l','o') → 合并为 'lo'
Step 4: 最频繁对 ('lo','w') → 合并为 'low'
...

用学到的4条规则编码:
     "lowest" → ['low', 'est']   (训练中未出现过的新词)
     "newest" → ['n', 'e', 'w', 'est']

5.2 手写BPE Tokenizer

Python
import re
from collections import Counter, defaultdict

class SimpleBPETokenizer:
    """
    手写BPE分词器
    支持中英文混合文本
    """

    def __init__(self, vocab_size: int = 4096):
        self.target_vocab_size = vocab_size
        self.merges = {}      # 合并规则: (a, b) → ab
        self.vocab = {}       # token → id
        self.inverse_vocab = {}  # id → token

        # 特殊token
        self.special_tokens = {
            "<pad>": 0,
            "<unk>": 1,
            "<bos>": 2,
            "<eos>": 3,
        }

    def _get_stats(self, word_freqs):
        """统计相邻pair的频率"""
        pairs = Counter()  # Counter统计元素出现次数
        for word, freq in word_freqs.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i+1])] += freq
        return pairs

    def _merge_pair(self, pair, word_freqs):
        """合并一个pair"""
        new_word_freqs = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)

        for word, freq in word_freqs.items():
            new_word = word.replace(bigram, replacement)
            new_word_freqs[new_word] = freq

        return new_word_freqs

    def train(self, texts: list, verbose: bool = True):
        """
        训练BPE分词器

        Args:
            texts: 训练文本列表
            verbose: 是否打印训练过程
        """
        # Step 1: 分词并统计词频
        word_freqs = Counter()
        for text in texts:
            # 简单预分词: 连续汉字、英文单词、数字各成一个"词",标点单独拆分
            words = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z]+|[0-9]+|[^\s\w]', text)  # re.findall返回所有匹配项;连续汉字归为一个"词",BPE才能学到多字中文token
            for word in words:
                # 将每个词拆分为字符(空格分隔)
                char_word = ' '.join(list(word))
                word_freqs[char_word] += 1

        # Step 2: 初始化词表(所有出现的字符)
        self.vocab = dict(self.special_tokens)
        chars = set()
        for word in word_freqs:
            for char in word.split():
                chars.add(char)

        for char in sorted(chars):
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)

        initial_vocab_size = len(self.vocab)
        if verbose:
            print(f"初始词表大小: {initial_vocab_size}")

        # Step 3: 迭代合并
        num_merges = self.target_vocab_size - initial_vocab_size

        for i in range(num_merges):
            # 统计pair频率
            pairs = self._get_stats(word_freqs)
            if not pairs:
                break

            # 找出最频繁的pair
            best_pair = max(pairs, key=pairs.get)
            best_freq = pairs[best_pair]

            if best_freq < 2:  # 频率太低则停止
                break

            # 合并
            word_freqs = self._merge_pair(best_pair, word_freqs)

            # 记录合并规则
            merged_token = ''.join(best_pair)
            self.merges[best_pair] = merged_token

            if merged_token not in self.vocab:
                self.vocab[merged_token] = len(self.vocab)

            if verbose and (i + 1) % 100 == 0:
                print(f"  合并 {i+1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' → '{merged_token}' (freq={best_freq})")

        # 构建反向词表
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}

        if verbose:
            print(f"最终词表大小: {len(self.vocab)}")
            print(f"合并规则数: {len(self.merges)}")

    def encode(self, text: str) -> list:
        """将文本编码为token ID序列"""
        # 分词
        words = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z]+|[0-9]+|[^\s\w]', text)  # 与训练时相同的预分词规则

        ids = []
        for word in words:
            # 拆分为字符
            symbols = list(word)

            # 应用合并规则(贪婪)
            while len(symbols) > 1:
                # 找出可以合并的pair
                best_pair = None
                best_idx = -1

                for i in range(len(symbols) - 1):
                    pair = (symbols[i], symbols[i+1])
                    if pair in self.merges:
                        best_pair = pair
                        best_idx = i
                        break  # 简化实现: 每轮合并最左侧可合并的pair(未严格按训练时的合并优先顺序)

                if best_pair is None:
                    break

                # 执行合并
                symbols = (symbols[:best_idx] +
                          [self.merges[best_pair]] +
                          symbols[best_idx+2:])

            # 转换为ID
            for sym in symbols:
                ids.append(self.vocab.get(sym, self.special_tokens["<unk>"]))

        return ids

    def decode(self, ids: list) -> str:
        """将token ID序列解码为文本"""
        tokens = [self.inverse_vocab.get(id, "<unk>") for id in ids]
        return ''.join(tokens)

# === 训练Tokenizer ===
# 中文训练语料(示例)
training_texts = [
    "人工智能正在改变世界",
    "深度学习是机器学习的一个分支",
    "大语言模型具有强大的文本生成能力",
    "自然语言处理是人工智能的重要方向",
    "Transformer架构改变了深度学习的格局",
    "注意力机制是现代神经网络的核心组件",
    "预训练语言模型通过大规模数据学习语言知识",
    "GPT系列模型展示了规模扩展的威力",
    "BERT模型在自然语言理解任务上取得了突破",
    "强化学习可以用于优化语言模型的输出",
] * 50  # 重复增加频率

tokenizer = SimpleBPETokenizer(vocab_size=500)
tokenizer.train(training_texts)

# 测试
test_text = "人工智能改变世界"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
print(f"\n编码测试:")
print(f"  原文: {test_text}")
print(f"  编码: {encoded}")
print(f"  解码: {decoded}")
print(f"  tokens: {[tokenizer.inverse_vocab.get(id, '?') for id in encoded]}")

6. 预训练Pipeline

6.1 数据准备

Python
import torch
from torch.utils.data import Dataset, DataLoader
import os

class TextDataset(Dataset):
    """
    用于CLM预训练的文本数据集
    将文本编码后切分为固定长度的序列
    """
    def __init__(self, texts: list, tokenizer, seq_len: int = 128):
        self.seq_len = seq_len

        # 编码所有文本
        all_ids = []
        for text in texts:
            ids = tokenizer.encode(text)
            all_ids.extend(ids)
            all_ids.append(tokenizer.special_tokens["<eos>"])

        # 切分为固定长度序列
        self.data = []
        for i in range(0, len(all_ids) - seq_len - 1, seq_len):
            input_ids = all_ids[i : i + seq_len]
            target_ids = all_ids[i + 1 : i + seq_len + 1]
            self.data.append((
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(target_ids, dtype=torch.long)
            ))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

print("数据集准备完成")

6.2 训练循环

Python
import time

class MiniLLaMATrainer:
    """Mini-LLaMA训练器"""

    def __init__(self, model, config, tokenizer, device='cpu'):
        self.model = model.to(device)
        self.config = config
        self.tokenizer = tokenizer
        self.device = device

        # 优化器: AdamW
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=3e-4,
            betas=(0.9, 0.95),
            weight_decay=0.1
        )

        # 学习率调度: Cosine with warmup
        self.warmup_steps = 100
        self.total_steps = 0

    def _get_lr(self, step):
        """Cosine学习率调度 with warmup"""
        if step < self.warmup_steps:
            return 3e-4 * step / self.warmup_steps
        # Cosine decay
        progress = (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
        return 3e-4 * 0.5 * (1 + math.cos(math.pi * progress))

    def train(self, train_dataset, num_epochs=10, batch_size=8,
              log_interval=50, eval_interval=200):
        """
        训练循环
        """
        dataloader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
        )
        self.total_steps = num_epochs * len(dataloader)

        self.model.train()
        global_step = 0
        best_loss = float('inf')

        print(f"开始训练 Mini-LLaMA")
        print(f"  Epochs: {num_epochs}")
        print(f"  Batch size: {batch_size}")
        print(f"  Total steps: {self.total_steps}")
        print(f"  Device: {self.device}")
        print("=" * 60)

        for epoch in range(num_epochs):
            epoch_loss = 0
            epoch_start = time.time()

            for batch_idx, (input_ids, target_ids) in enumerate(dataloader):  # enumerate同时获取索引和元素
                input_ids = input_ids.to(self.device)
                target_ids = target_ids.to(self.device)

                # 更新学习率
                lr = self._get_lr(global_step)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr

                # 前向传播
                loss = self.model(input_ids, target_ids)

                # 反向传播
                self.optimizer.zero_grad()
                loss.backward()

                # 梯度裁剪
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                self.optimizer.step()

                epoch_loss += loss.item()
                global_step += 1

                # 日志
                if global_step % log_interval == 0:
                    avg_loss = epoch_loss / (batch_idx + 1)
                    ppl = math.exp(min(avg_loss, 20))  # 限制防止溢出
                    elapsed = time.time() - epoch_start
                    print(f"  Step {global_step}: loss={avg_loss:.4f}, ppl={ppl:.2f}, lr={lr:.6f}, time={elapsed:.1f}s")

                # 生成样本
                if global_step % eval_interval == 0:
                    self._generate_sample()

            # Epoch结束统计
            avg_epoch_loss = epoch_loss / len(dataloader)
            epoch_time = time.time() - epoch_start
            print(f"\nEpoch {epoch+1}/{num_epochs}: avg_loss={avg_epoch_loss:.4f}, time={epoch_time:.1f}s")

            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                # 保存最佳模型
                torch.save(self.model.state_dict(), 'mini_llama_best.pt')
                print(f"  ✓ 保存最佳模型 (loss={best_loss:.4f})")
            print()

        print(f"训练完成! 最佳loss: {best_loss:.4f}")

    def _generate_sample(self):
        """训练中生成样本以观察效果"""
        self.model.eval()

        # 用几个起始token生成
        prompt = torch.tensor([[self.tokenizer.special_tokens["<bos>"]]]).to(self.device)

        generated = self.model.generate(
            prompt, max_new_tokens=50, temperature=0.8, top_k=40
        )

        text = self.tokenizer.decode(generated[0].tolist())
        print(f"  [生成样本] {text[:100]}...")

        self.model.train()

# ===  端到端训练示例 ===
print("端到端预训练流程:")
print("1. 准备语料 → 2. 训练Tokenizer → 3. 构建数据集 → 4. 训练模型 → 5. 生成文本")
print()

# 准备训练语料(实际应用中使用大规模语料)
corpus = [
    "人工智能是计算机科学的一个重要分支,它研究如何让计算机模拟人类的智能行为。",
    "深度学习通过多层神经网络自动学习数据的特征表示,是目前最成功的机器学习方法。",
    "自然语言处理让计算机能够理解和生成人类语言,包括文本分类、机器翻译等任务。",
    "大语言模型通过在海量文本上预训练,学习到了丰富的语言知识和世界知识。",
    "Transformer架构使用自注意力机制,能够并行处理序列中的所有位置。",
    "注意力机制允许模型在处理每个位置时关注序列中的其他相关位置。",
    "预训练加微调的范式极大地提高了模型在各种下游任务上的表现。",
    "强化学习通过试错的方式让智能体学习最优策略,在游戏和机器人控制中表现出色。",
] * 100  # 实际需要大量数据

# 1. 训练Tokenizer
tok = SimpleBPETokenizer(vocab_size=500)
tok.train(corpus, verbose=False)

# 2. 更新模型配置
train_config = MiniLLaMAConfig(
    vocab_size=len(tok.vocab),
    dim=128,
    n_layers=4,
    n_heads=4,
    n_kv_heads=2,
    max_seq_len=128,
    multiple_of=32
)

# 3. 创建模型
train_model = MiniLLaMA(train_config)
params = sum(p.numel() for p in train_model.parameters())
print(f"模型参数量: {params:,} ({params/1e6:.2f}M)")

# 4. 创建数据集
train_dataset = TextDataset(corpus, tok, seq_len=64)
print(f"训练样本数: {len(train_dataset)}")

# 5. 训练(实际训练需要更多数据和更长时间)
# trainer = MiniLLaMATrainer(train_model, train_config, tok, device='cpu')
# trainer.train(train_dataset, num_epochs=5, batch_size=8)

print("\n注意: 完整训练需要:")
print("  - 几GB~几十GB的中文语料")
print("  - GPU加速(推荐至少一张RTX 3090)")
print("  - 数小时到数天的训练时间")
print("  - 更大的模型配置(dim=512+, layers=12+)")

7. 模型评估与生成

7.1 困惑度评估

Python
@torch.no_grad()
def evaluate_perplexity(model, dataset, batch_size=8, device='cpu'):
    """
    计算模型在数据集上的困惑度(Perplexity)

    PPL = exp(avg_loss)
    PPL越低,模型越好
    """
    model.eval()
    model = model.to(device)

    dataloader = DataLoader(dataset, batch_size=batch_size)
    total_loss = 0
    total_tokens = 0

    for input_ids, target_ids in dataloader:
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        loss = model(input_ids, target_ids)

        num_tokens = (target_ids != -100).sum().item()  # 与 ignore_index=-100 保持一致
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)

    return perplexity, avg_loss

# 如果模型已训练好:
# ppl, loss = evaluate_perplexity(train_model, train_dataset)
# print(f"困惑度: {ppl:.2f}, 平均Loss: {loss:.4f}")

7.2 多种采样策略对比

Python
def generate_text(model, tokenizer, prompt, max_tokens=100,
                  strategy='top_p', **kwargs):
    """
    文本生成:支持多种采样策略
    """
    device = next(model.parameters()).device

    # 编码prompt
    ids = tokenizer.encode(prompt)
    tokens = torch.tensor([ids]).to(device)

    if strategy == 'greedy':
        # 贪婪搜索:每步选概率最高的
        generated = model.generate(tokens, max_tokens, temperature=0.01, top_k=1)

    elif strategy == 'temperature':
        # 温度采样
        temp = kwargs.get('temperature', 1.0)
        generated = model.generate(tokens, max_tokens, temperature=temp, top_k=0, top_p=1.0)

    elif strategy == 'top_k':
        # Top-K采样
        k = kwargs.get('k', 50)
        generated = model.generate(tokens, max_tokens, temperature=0.8, top_k=k)

    elif strategy == 'top_p':
        # Nucleus (Top-P) 采样
        p = kwargs.get('p', 0.9)
        generated = model.generate(tokens, max_tokens, temperature=0.8, top_p=p)

    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    return tokenizer.decode(generated[0].tolist())

print("""
采样策略对比:
┌────────────┬────────────────────┬───────────────┐
│ 策略       │ 特点               │ 适用场景      │
├────────────┼────────────────────┼───────────────┤
│ Greedy     │ 确定性,最保守     │ 翻译/摘要     │
│ Temperature│ 温度高→随机        │ 通用          │
│ Top-K      │ 只从前K个候选采样  │ 创作          │
│ Top-P      │ 动态K, 累积概率≤P  │ 对话/创作     │
└────────────┴────────────────────┴───────────────┘

推荐设置:
- 创意写作: temperature=1.0, top_p=0.95
- 对话: temperature=0.7, top_p=0.9
- 代码/事实: temperature=0.2, top_k=10
""")

8. 工程优化技巧

8.1 KV Cache

Python
class MiniLLaMAWithKVCache(nn.Module):
    """
    带KV Cache的推理优化

    问题: 自回归生成时,每生成一个token都要重新计算所有位置的KV
    优化: 缓存已计算的KV,新token只需计算增量

    无Cache: 生成100个token → 计算 1+2+...+100 = 5050 次KV
    有Cache: 生成100个token → 计算 100 次KV
    """

    def __init__(self, config):
        super().__init__()
        # ... 复用MiniLLaMA的初始化 ...
        self.config = config

    # KV Cache的核心思想(伪代码)
    def generate_with_cache(self, prompt_tokens, max_new_tokens=100):
        """
        KV Cache加速生成
        """
        # 初始化cache: 每层存储K和V
        kv_cache = [
            {"k": None, "v": None}
            for _ in range(self.config.n_layers)
        ]

        # Prefill阶段:处理整个prompt
        # 此时计算所有位置的KV并缓存
        logits = self.forward_with_cache(prompt_tokens, kv_cache, is_prefill=True)

        tokens = prompt_tokens
        for step in range(max_new_tokens):
            # Decode阶段:每次只处理新增的1个token
            next_token = logits[:, -1:, :].argmax(dim=-1)
            tokens = torch.cat([tokens, next_token], dim=-1)

            # 只传入新token,利用cache中已有的KV
            logits = self.forward_with_cache(next_token, kv_cache, is_prefill=False)

        return tokens

print("""
KV Cache性能提升:
┌───────────────┬──────────────────┬──────────────────┐
│ 生成长度      │ 无Cache计算量    │ 有Cache计算量    │
├───────────────┼──────────────────┼──────────────────┤
│ 100 tokens    │ 5050 × KV_cost   │ 100 × KV_cost    │
│ 1000 tokens   │ 500500 × KV_cost │ 1000 × KV_cost   │
│ 加速比        │ ~50x (100)       │ ~500x (1000)     │
└───────────────┴──────────────────┴──────────────────┘

显存占用:
  KV Cache大小 = 2 × n_layers × n_kv_heads × head_dim × seq_len × batch_size × 每元素字节数(FP16为2字节)
  对于LLaMA-2-7B (seq_len=4096, batch=1):
    = 2 × 32 × 32 × 128 × 4096 × 2 bytes ≈ 2GB
""")

8.2 混合精度训练

Python
# 使用PyTorch AMP (Automatic Mixed Precision)
from torch.amp import autocast, GradScaler

def train_step_mixed_precision(model, optimizer, input_ids, target_ids, scaler):
    """
    混合精度训练步骤

    大部分前向/反向运算由FP32改为FP16 → 速度提升约2x, 显存约减半
    但某些操作(如loss计算、归一化)需要保持FP32
    """
    optimizer.zero_grad()

    # 自动混合精度: 前向传播中适当使用FP16
    with autocast(device_type='cuda'):
        loss = model(input_ids, target_ids)

    # 缩放loss以防止FP16下梯度下溢
    scaler.scale(loss).backward()

    # 梯度裁剪(需要先unscale)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # 优化器步骤
    scaler.step(optimizer)
    scaler.update()

    return loss.item()

print("""
混合精度训练好处:
1. 训练速度提升 ~2x(利用Tensor Cores)
2. 显存占用减少 ~50%
3. 模型精度基本不受影响

关键技术:
- GradScaler: 防止FP16梯度下溢
- autocast: 自动选择合适的精度
- 保持关键操作在FP32: LayerNorm, Loss, Softmax
""")

练习题

练习1:组件验证

  1. 修改GQA的n_kv_heads为1(MQA)和n_heads(MHA),对比参数量差异
  2. 将RMSNorm替换为nn.LayerNorm,观察训练稳定性差异

练习2:Tokenizer扩展

  1. 为Tokenizer添加<pad>, <unk>等特殊token的正确处理
  2. 实现WordPiece分词器并与BPE对比

练习3:预训练实验

  1. 在WikiText-2数据集上预训练Mini-LLaMA
  2. 对比不同超参数(dim, n_layers, lr)对困惑度的影响
  3. 实现学习率Warmup + Cosine Decay调度器

练习4:生成优化

  1. 实现Beam Search解码
  2. 实现Repetition Penalty(减少重复生成)
  3. 实现KV Cache并测量推理速度提升

📝 本章小结

| 知识点 | 掌握程度检查 |
| --- | --- |
| LLaMA架构特点 | 能否说出与原始Transformer的5个主要区别? |
| RMSNorm | 能否手写实现并解释为什么比LayerNorm好? |
| RoPE | 能否解释旋转位置编码的核心思想? |
| SwiGLU | 能否解释门控机制的作用? |
| GQA | 能否解释MHA→GQA→MQA的演进逻辑? |
| BPE Tokenizer | 能否手写BPE训练过程? |
| CLM预训练 | 能否解释Next Token Prediction的训练流程? |
| KV Cache | 能否解释KV Cache加速推理的原理? |

🔗 后续学习路径

📚 参考资料

  1. Touvron et al. "LLaMA: Open and Efficient Foundation Language Models" (2023)
  2. Touvron et al. "Llama 2: Open Foundation and Fine-Tuned Chat Models" (2023)
  3. Su et al. "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021)
  4. Zhang & Sennrich "Root Mean Square Layer Normalization" (2019)
  5. Shazeer "GLU Variants Improve Transformer" (2020) — SwiGLU
  6. Ainslie et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023)
  7. Sennrich et al. "Neural Machine Translation of Rare Words with Subword Units" (2016) — BPE
  8. Kaplan et al. "Scaling Laws for Neural Language Models" (2020)
  9. Korthikanti et al. "Reducing Activation Recomputation in Large Transformer Models" (2022)
  10. Dao et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022)