跳转至

01 - Transformer深入理解(全面版)

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习目标:从零开始深入理解Transformer的每一个组件,掌握其数学原理、实现细节和设计思想。

📌 定位说明:本章侧重大模型视角下的Transformer深入理解(RoPE/RMSNorm/GeLU/编解码器架构对比等)。Transformer架构的基础教学(从零实现完整Transformer)请参考 深度学习/04-Transformer/02-Transformer架构


目录

  1. Transformer架构总览
  2. 输入嵌入层深度解析
  3. 位置编码机制
  4. 自注意力机制详解
  5. 多头注意力机制
  6. 前馈神经网络
  7. 层归一化与残差连接
  8. 编解码器架构对比
  9. Transformer的训练
  10. Transformer的推理

Transformer架构总览

1.1 为什么需要Transformer

在Transformer出现之前,序列建模主要依赖RNN/LSTM:

Text Only
RNN的问题:
├── 顺序计算,无法并行
├── 长距离依赖困难(梯度消失/爆炸)
└── 计算复杂度与序列长度成正比

Transformer的解决方案:
├── 完全基于注意力机制
├── 完全并行计算
├── 任意位置间距离都是O(1)
└── 成为现代NLP的基础架构

1.2 架构全景图

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                     Transformer 架构全景                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    输入序列 (Input)                      │    │
│  │              [我, 喜欢, 深度, 学习]                       │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │              输入嵌入 + 位置编码                          │    │
│  │         (Token Embedding + Positional Encoding)          │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│           ┌──────────────────┴──────────────────┐               │
│           ▼                                      ▼               │
│  ┌─────────────────┐                  ┌─────────────────┐       │
│  │    Encoder      │                  │    Decoder      │       │
│  │  (编码器堆叠)    │                  │  (解码器堆叠)    │       │
│  │                 │                  │                 │       │
│  │  ┌───────────┐  │                  │  ┌───────────┐  │       │
│  │  │ Block N   │  │                  │  │ Block N   │  │       │
│  │  │ ┌───────┐ │  │                  │  │ ┌───────┐ │  │       │
│  │  │ │Multi- │ │  │                  │  │ │Masked │ │  │       │
│  │  │ │Head   │ │  │                  │  │ │Multi-H│ │  │       │
│  │  │ │Attn   │ │  │                  │  │ │Attn   │ │  │       │
│  │  │ └───────┘ │  │                  │  │ └───────┘ │  │       │
│  │  │ ┌───────┐ │  │                  │  │ ┌───────┐ │  │       │
│  │  │ │ Feed  │ │  │                  │  │ │Multi-H│ │  │       │
│  │  │ │Forward│ │  │                  │  │ │Cross  │ │  │       │
│  │  │ └───────┘ │  │                  │  │ │Attn   │ │  │       │
│  │  └───────────┘  │                  │  │ └───────┘ │  │       │
│  │       ...       │                  │  │ ┌───────┐ │  │       │
│  │  ┌───────────┐  │                  │  │ │ Feed  │ │  │       │
│  │  │ Block 1   │  │                  │  │ │Forward│ │  │       │
│  │  └───────────┘  │                  │  │ └───────┘ │  │       │
│  └─────────────────┘                  │  └───────────┘  │       │
│           │                           └─────────────────┘       │
│           │                                  │                   │
│           └──────────────────┬───────────────┘                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    输出层 (Output)                       │    │
│  │              Linear + Softmax → 概率分布                 │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    预测输出                              │    │
│  │              [I, like, deep, learning]                   │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Transformer架构图

上图展示了Transformer的完整架构,包括编码器(Encoder)和解码器(Decoder)的堆叠结构,以及自注意力机制和前馈神经网络等核心组件。

1.3 核心组件一览

组件 功能 关键参数 复杂度
输入嵌入 将token映射为向量 vocab_size × d_model -
位置编码 注入位置信息 d_model -
多头注意力 捕捉不同子空间的关系 h heads, d_k = d_model/h O(n²·d)
前馈网络 非线性变换 d_model → 4d_model → d_model O(n·d²)
层归一化 稳定训练 d_model O(n·d)
残差连接 缓解梯度消失 - -

输入嵌入层深度解析

2.1 词嵌入的数学本质

Text Only
词嵌入层就是一个查找表(Lookup Table):

E ∈ ℝ^(V × d)

其中:
- V: 词汇表大小(如32000)
- d: 嵌入维度(如512, 768, 1024)

对于输入token id = i:
embedding = E[i, :] ∈ ℝ^d

这相当于一个one-hot向量与嵌入矩阵的乘法:
embedding = one_hot(i) @ E

2.2 嵌入层的实现细节

Python
import torch
import torch.nn as nn
import math

class TokenEmbedding(nn.Module):
    """
    Token嵌入层完整实现
    """
    def __init__(self, vocab_size, d_model, padding_idx=None):
        super().__init__()  # super()调用父类方法
        self.vocab_size = vocab_size
        self.d_model = d_model

        # 嵌入矩阵
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=padding_idx  # 用于mask填充位置
        )

        # 缩放因子(Transformer原论文使用)
        self.scale = math.sqrt(d_model)

        # 初始化(重要!)
        self._init_weights()

    def _init_weights(self):
        """
        嵌入层初始化策略
        使用N(0, 1/d_model)初始化
        """
        nn.init.normal_(self.embedding.weight, mean=0, std=1/math.sqrt(self.d_model))

        # 如果有padding_idx,将该位置嵌入置零
        if self.embedding.padding_idx is not None:
            with torch.no_grad():  # 禁用梯度追踪,避免初始化操作被记入计算图
                self.embedding.weight[self.embedding.padding_idx].fill_(0)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len] token IDs
        Returns:
            embeddings: [batch_size, seq_len, d_model]
        """
        # 查找嵌入并缩放
        # 缩放原因:后续要与位置编码相加,需要平衡量级
        return self.embedding(x) * self.scale

# 使用示例
vocab_size = 32000
d_model = 512
batch_size = 2
seq_len = 10

embedding_layer = TokenEmbedding(vocab_size, d_model)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
embeddings = embedding_layer(input_ids)

print(f"输入形状: {input_ids.shape}")
print(f"嵌入形状: {embeddings.shape}")
print(f"嵌入范围: [{embeddings.min():.3f}, {embeddings.max():.3f}]")

2.3 子词分词与嵌入

现代大模型使用子词(Subword)分词,如BPE、WordPiece、SentencePiece:

Text Only
传统分词的问题:
- "playing" 和 "played" 被视为完全不同的词
- 未登录词(OOV)问题

子词分词的优势:
"playing" → ["play", "ing"]
"unhappiness" → ["un", "happiness"] 或 ["un", "happy", "ness"]

常见分词器:
├── GPT系列: BPE (Byte-Pair Encoding)
├── BERT: WordPiece
├── LLaMA/T5: SentencePiece (Unigram)
└── 中文: 字级别或BPE
Python
# 使用Hugging Face Tokenizer示例
from transformers import AutoTokenizer

# 加载GPT-2的分词器
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "Transformer架构 revolutionized NLP"
tokens = tokenizer.tokenize(text)
print(f"分词结果: {tokens}")
# 输出: ['Trans', 'former', '架构', 'Ġre', 'volution', 'ized', 'ĠN', 'LP']

# 转换为ID
input_ids = tokenizer.encode(text)
print(f"Token IDs: {input_ids}")

# 解码
decoded = tokenizer.decode(input_ids)
print(f"解码结果: {decoded}")

位置编码机制

3.1 为什么需要位置编码

Text Only
自注意力的排列不变性:

输入: [A, B, C] → 自注意力 → 输出
输入: [B, A, C] → 自注意力 → 输出(只是顺序变了)

问题:模型无法区分"我打你"和"你打我"

解决方案:注入位置信息

3.2 正弦位置编码(Sinusoidal)

Transformer原始论文使用的位置编码:

Text Only
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:
- pos: 位置(0, 1, 2, ..., max_len-1)
- i: 维度索引(0, 1, 2, ..., d_model/2-1)
- d_model: 模型维度
Python
class SinusoidalPositionalEncoding(nn.Module):
    """
    正弦位置编码实现
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 创建位置编码矩阵 [max_len, d_model]
        pe = torch.zeros(max_len, d_model)

        # 位置索引 [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # unsqueeze增加一个维度

        # 维度索引的除数项
        # 10000^(2i/d_model) = exp(2i * -log(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )

        # 偶数维度用sin
        pe[:, 0::2] = torch.sin(position * div_term)

        # 奇数维度用cos
        pe[:, 1::2] = torch.cos(position * div_term)

        # 注册为buffer(不参与训练)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        seq_len = x.size(1)

        # 添加位置编码
        x = x + self.pe[:, :seq_len, :]

        return self.dropout(x)

# 可视化位置编码
import matplotlib.pyplot as plt

def visualize_positional_encoding():
    d_model = 128
    max_len = 100

    pe = SinusoidalPositionalEncoding(d_model, max_len)
    encoding = pe.pe[0].numpy()  # [max_len, d_model]

    plt.figure(figsize=(12, 6))
    plt.imshow(encoding, aspect='auto', cmap='viridis')
    plt.colorbar()
    plt.xlabel('Dimension')
    plt.ylabel('Position')
    plt.title('Sinusoidal Positional Encoding')
    plt.show()

    # 观察不同位置的编码
    plt.figure(figsize=(12, 4))
    for pos in [0, 10, 20, 50]:
        plt.plot(encoding[pos], label=f'Pos {pos}')
    plt.legend()
    plt.xlabel('Dimension')
    plt.ylabel('Value')
    plt.title('Positional Encoding at Different Positions')
    plt.show()

# 正弦位置编码的性质
"""
性质1: 唯一性
每个位置都有唯一的编码

性质2: 相对位置关系
PE(pos+k) 可以表示为 PE(pos) 的线性函数
这意味着模型可以学习相对位置

性质3: 有界性
所有值都在[-1, 1]之间,数值稳定

性质4: 外推性
可以处理训练时未见过的更长序列
"""

3.3 可学习位置编码

BERT等模型使用可学习的位置编码:

Python
class LearnedPositionalEncoding(nn.Module):
    """
    可学习位置编码
    """
    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 可学习的嵌入矩阵
        self.pos_embedding = nn.Embedding(max_len, d_model)

        # 初始化
        nn.init.normal_(self.pos_embedding.weight, std=0.02)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        seq_len = x.size(1)

        # 创建位置索引
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)

        # 添加位置嵌入
        x = x + self.pos_embedding(positions)

        return self.dropout(x)

3.4 旋转位置编码 (RoPE)

现代大模型(LLaMA、GPT-NeoX等)使用RoPE:

3.4.1 RoPE的数学推导(从复数乘法出发)

Text Only
RoPE的核心目标:设计一种位置编码函数 f(x, pos),使得两个位置的内积
只依赖于相对位置差 (m - n),而非绝对位置 m 和 n:

    <f(q, m), f(k, n)> = g(q, k, m - n)

第一步:复数表示
  将向量的每两个相邻维度视为一个复数:
    z = x_{2i} + j·x_{2i+1}  (j 是虚数单位)

第二步:复数乘法 = 旋转
  在复数平面上,乘以 e^{jθ} 等价于旋转角度 θ:
    z · e^{jθ} = (x_{2i} + j·x_{2i+1})(cosθ + j·sinθ)
              = (x_{2i}·cosθ - x_{2i+1}·sinθ) + j·(x_{2i}·sinθ + x_{2i+1}·cosθ)

  写成矩阵形式:
    [x_{2i}' ]   [cosθ  -sinθ] [x_{2i}  ]
    [x_{2i+1}'] = [sinθ   cosθ] [x_{2i+1}]

第三步:位置相关的旋转角度
  对位置 m 的第 i 对维度,旋转角度 θ_i(m) = m · ω_i
  其中 ω_i = 1 / 10000^{2i/d}(频率随维度指数递减)

第四步:验证相对位置性质
  对位置 m 的 query 和位置 n 的 key(在第 i 对维度上):
    f(q, m) = q · e^{j·m·ω_i}
    f(k, n) = k · e^{j·n·ω_i}

  内积(复数):
    Re[f(q, m) · conj(f(k, n))]
    = Re[q · e^{j·m·ω_i} · conj(k · e^{j·n·ω_i})]
    = Re[q · conj(k) · e^{j·(m-n)·ω_i}]

  结果只依赖 (m-n),证毕。✓

第五步:推广到全维度
  对 d 维向量,每两维一组,共 d/2 组,每组用不同频率 ω_i 旋转:

    RoPE(x, pos) = [R(pos·ω_0)·x_{0:1}, R(pos·ω_1)·x_{2:3}, ..., R(pos·ω_{d/2-1})·x_{d-2:d-1}]

  其中 R(θ) 是 2×2 旋转矩阵。

3.4.2 RoPE实现

Python
class RotaryPositionalEmbedding(nn.Module):
    """
    旋转位置编码 (RoPE)
    通过旋转矩阵注入相对位置信息
    """
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()

        # 计算旋转角度
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # 预计算位置编码
        t = torch.arange(max_seq_len)
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # [max_seq_len, dim/2]

        # 复数形式: cos + i*sin
        self.register_buffer('cos_cached', freqs.cos()[None, None, :, :])  # [1, 1, seq_len, dim/2]
        self.register_buffer('sin_cached', freqs.sin()[None, None, :, :])  # [1, 1, seq_len, dim/2]

    def forward(self, x, seq_len=None):
        """
        Args:
            x: [batch, heads, seq_len, head_dim]
        """
        if seq_len is None:
            seq_len = x.shape[2]

        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]

        return self.apply_rotary_pos_emb(x, cos, sin)

    def apply_rotary_pos_emb(self, x, cos, sin):
        """
        应用旋转位置编码

        Args:
            x: [batch, heads, seq_len, head_dim]
            cos: [1, 1, seq_len, head_dim/2]
            sin: [1, 1, seq_len, head_dim/2]

        Returns:
            [batch, heads, seq_len, head_dim]
        """
        # 将x分成两部分(偶数维和奇数维)
        x1, x2 = x[..., ::2], x[..., 1::2]

        # 旋转: [x1, x2] @ [[cos, -sin], [sin, cos]]
        # 即: x1' = x1 * cos - x2 * sin
        #     x2' = x1 * sin + x2 * cos
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1)

        # 将最后两维展平,恢复原始形状
        return rotated.flatten(-2)

RoPE旋转位置编码可视化

上图展示了RoPE(旋转位置编码)的工作原理。RoPE通过旋转矩阵将位置信息注入到查询和键向量中,每个维度对在不同的旋转速度下工作,类似于不同速度的时钟指针。

RoPE 2D投影可视化

这张图展示了RoPE在2D平面上的投影,不同颜色代表不同位置的向量,它们从原点发出,展示了实部和虚部之间的关系。RoPE的优势在于能够自然地编码相对位置信息,并且具有良好的外推性能。


自注意力机制详解

4.1 直觉理解

自注意力的核心思想:让序列中每个位置都"看到"其他所有位置,并学习该关注哪些位置。

Text Only
以"猫坐在垫子上,它很舒服"为例:

处理"它"时,自注意力权重分布(示意):
猫  坐  在  垫子 上  ,  它  很 舒服
0.45 0.05 0.02 0.15 0.03 0.01 0.10 0.04 0.15

→ 模型学会了"它"主要指代"猫"(最高权重),次要关联"垫子"和"舒服"

4.2 Q-K-V 的数学推导

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

三个矩阵的角色: - \(Q\)(Query):当前位置发出的"提问" - \(K\)(Key):每个位置提供的"索引标签" - \(V\)(Value):每个位置存储的"实际内容"

Text Only
类比图书馆:
Q = 你的搜索关键词("深度学习入门")
K = 每本书的标题/关键词索引
V = 每本书的实际内容

检索流程:
1. 计算 Q @ K^T → 搜索词与每本书标题的匹配分数
2. softmax → 将分数归一化为概率
3. 概率 @ V → 加权提取最相关书的内容
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    query: torch.Tensor,   # [batch, seq_q, d_k]
    key: torch.Tensor,     # [batch, seq_k, d_k]
    value: torch.Tensor,   # [batch, seq_k, d_v]
    mask: torch.Tensor = None,
    dropout: nn.Dropout = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    缩放点积注意力

    关键点:
    1. 为什么除以 sqrt(d_k)?(完整推导)

       假设 q 和 k 的每个分量独立同分布,均值为0、方差为1:
         q = [q_1, q_2, ..., q_{d_k}],  q_i ~ (0, 1)
         k = [k_1, k_2, ..., k_{d_k}],  k_i ~ (0, 1)

       点积 q·k = Σ_{i=1}^{d_k} q_i * k_i

       每一项 q_i * k_i 的方差:
         Var(q_i * k_i) = E[q_i²] * E[k_i²] - (E[q_i] * E[k_i])²
                        = 1 * 1 - 0 = 1

       因为各项独立,所以:
         Var(q·k) = Σ Var(q_i * k_i) = d_k

       当 d_k=64 时,点积标准差 ≈ 8;d_k=128 时 ≈ 11.3
       这些大值会将 softmax 推入饱和区(梯度接近 0)

       除以 √d_k 后:
         Var(q·k / √d_k) = Var(q·k) / d_k = d_k / d_k = 1

       方差归一化为 1,softmax 的输入保持在合理范围,梯度可以正常流动。

    2. mask 的两种用途:
       - padding mask: 屏蔽填充位置
       - causal mask: 屏蔽未来位置(解码器)
    """
    d_k = query.size(-1)

    # Step 1: QK^T / sqrt(d_k) → [batch, seq_q, seq_k]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: 应用 mask(将需要屏蔽的位置设为 -inf)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax → 权重矩阵
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Dropout(训练时)
    if dropout is not None:
        attn_weights = dropout(attn_weights)

    # Step 5: 加权求和 → [batch, seq_q, d_v]
    output = torch.matmul(attn_weights, value)

    return output, attn_weights

# ---------- 验证 ----------
batch, seq_len, d_k = 2, 5, 64
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

# 无 mask
out, weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {out.shape}")           # [2, 5, 64]
print(f"权重形状: {weights.shape}")        # [2, 5, 5]
print(f"权重行和: {weights.sum(-1)}")      # 每行和为 1.0

# 因果 mask(解码器用)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))  # 下三角矩阵
out_masked, weights_masked = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(f"\n因果mask权重(第0行只看自己,第4行看所有位置):")
print(weights_masked[0].detach())

4.3 注意力复杂度分析

Text Only
时间复杂度: O(n² · d)
  - 其中 n = 序列长度, d = 模型维度
  - 主要来自 QK^T 的矩阵乘法

空间复杂度: O(n²)
  - 需要存储完整的注意力权重矩阵

序列长度 →  注意力计算量(d=1024)
128         → 16.7M
1,024       → 1.07B
8,192       → 68.7B
128,000     → 16.8T  ← 这就是长上下文的挑战

解决方案:
├── FlashAttention: 利用GPU内存层级优化,O(n²)复杂度不变但实际速度快2-4倍
├── GQA (Grouped-Query Attention): 减少KV头数,降低内存
├── MQA (Multi-Query Attention): 所有头共享一组KV
├── Ring Attention: 分布式环形通信处理超长序列
└── 稀疏注意力: 只计算部分位置对的注意力

多头注意力机制

5.1 为什么需要多头

单头注意力的问题:一个注意力头只能学到一种"关注模式"。而自然语言有多种关系维度(语法、语义、共指、因果等)。

Text Only
多头的直觉:用不同的"眼睛"看同一句话

Head 1: 可能学会了语法依赖(主语→动词)
Head 2: 可能学会了共指关系(代词→先行词)
Head 3: 可能学会了修饰关系(形容词→名词)
Head 4: 可能学会了长距离依赖(问题→答案)
...

5.2 多头注意力实现

\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \]
\[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]
Python
class MultiHeadAttention(nn.Module):
    """
    多头注意力机制

    参数量分析(d_model=1024, h=16, d_k=64):
    W_Q: 1024×1024 = 1M
    W_K: 1024×1024 = 1M
    W_V: 1024×1024 = 1M
    W_O: 1024×1024 = 1M
    总计: 4M 参数
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model 必须被 n_heads 整除"  # assert断言:条件False时抛出AssertionError

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 每个头的维度

        # 四个投影矩阵
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,   # [batch, seq_q, d_model]
        key: torch.Tensor,     # [batch, seq_k, d_model]
        value: torch.Tensor,   # [batch, seq_k, d_model]
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        batch_size = query.size(0)

        # 1. 线性投影 → [batch, seq, d_model]
        Q = self.W_Q(query)
        K = self.W_K(key)
        V = self.W_V(value)

        # 2. 分头: [batch, seq, d_model] → [batch, n_heads, seq, d_k]
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # view重塑张量形状
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 3. 缩放点积注意力
        d_k = self.d_k
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 4. 加权求和 → [batch, n_heads, seq_q, d_k]
        context = torch.matmul(attn_weights, V)

        # 5. 合并多头: [batch, seq_q, d_model]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 6. 输出投影
        output = self.W_O(context)

        return output

# 验证
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)

# Self-attention: Q=K=V=x
out = mha(x, x, x)
print(f"Self-attention 输出: {out.shape}")  # [2, 10, 512]

# Cross-attention: Q来自decoder, K/V来自encoder
enc_out = torch.randn(2, 20, 512)
dec_in = torch.randn(2, 10, 512)
cross_out = mha(dec_in, enc_out, enc_out)
print(f"Cross-attention 输出: {cross_out.shape}")  # [2, 10, 512]

5.3 GQA 与 MQA:现代大模型的注意力优化

Text Only
标准 MHA (Multi-Head Attention):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, n_heads, seq, d_k]    ← 每个头有独立的K
  V: [batch, n_heads, seq, d_k]    ← 每个头有独立的V
  KV cache 大小 = 2 × n_heads × seq × d_k

MQA (Multi-Query Attention, GPT-J):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, 1, seq, d_k]          ← 所有头共享一组K
  V: [batch, 1, seq, d_k]          ← 所有头共享一组V
  KV cache 大小 = 2 × 1 × seq × d_k  → 减少 n_heads 倍

GQA (Grouped-Query Attention, LLaMA-2/3):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, n_kv_heads, seq, d_k] ← 每组共享一组K
  V: [batch, n_kv_heads, seq, d_k] ← 每组共享一组V
  KV cache 大小 = 2 × n_kv_heads × seq × d_k

示例 (LLaMA-3-70B): n_heads=64, n_kv_heads=8
  → 每8个Q头共享1组KV → KV cache减少8倍,性能接近MHA

前馈神经网络

6.1 标准 FFN

每个Transformer层中的全连接前馈网络(Position-wise FFN):

\[ \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 \]
Python
class FeedForward(nn.Module):
    """标准FFN,4倍扩展"""
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model  # 默认4倍扩展

        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # [batch, seq, d_model] → [batch, seq, d_ff] → [batch, seq, d_model]
        return self.w2(self.dropout(F.relu(self.w1(x))))

6.2 现代激活函数:GeLU 与 SwiGLU

Python
# GeLU (GPT-2/BERT 使用)
# GELU(x) = x · Φ(x),其中 Φ 是标准正态分布的CDF
# 比 ReLU 更平滑,不会在 x=0 处产生不可导点

class GeLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w2(self.dropout(F.gelu(self.w1(x))))

# SwiGLU (LLaMA/Qwen/Mistral 使用)
# SwiGLU(x, W, V, W2) = (Swish(xW) ⊙ xV) W2
# 引入门控机制,效果优于 GeLU,但参数量增加 50%

class SwiGLUFeedForward(nn.Module):
    """LLaMA-style SwiGLU FFN"""
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        # LLaMA 使用 8/3 * d_model 作为 d_ff 以保持参数量不变
        d_ff = d_ff or int(8 / 3 * d_model)
        # 取最接近的 256 的倍数(GPU对齐优化)
        d_ff = ((d_ff + 255) // 256) * 256

        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # up projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Swish(xW1) ⊙ xW3 → 门控机制
        gate = F.silu(self.w1(x))  # SiLU = Swish(x) = x * sigmoid(x)
        up = self.w3(x)
        return self.w2(self.dropout(gate * up))

# 对比
d_model = 1024
ffn_relu = FeedForward(d_model)
ffn_gelu = GeLUFeedForward(d_model)
ffn_swiglu = SwiGLUFeedForward(d_model)

for name, module in [("ReLU FFN", ffn_relu), ("GeLU FFN", ffn_gelu), ("SwiGLU FFN", ffn_swiglu)]:
    params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {params:,} 参数")
# ReLU FFN: 8,393,728 参数
# GeLU FFN: 8,393,728 参数
# SwiGLU FFN: 8,650,752 参数 (d_ff调整后参数量接近)

层归一化与残差连接

7.1 为什么需要归一化

深层网络中,每层输出的分布会不断漂移(Internal Covariate Shift),导致训练不稳定。归一化将输出拉回标准分布。

7.2 LayerNorm vs RMSNorm

Python
class LayerNorm(nn.Module):
    """
    标准层归一化 (BERT/GPT-2 使用)
    LayerNorm(x) = γ · (x - μ) / √(σ² + ε) + β
    """
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # 可学习缩放
        self.beta = nn.Parameter(torch.zeros(d_model))    # 可学习偏移
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

class RMSNorm(nn.Module):
    """
    RMS层归一化 (LLaMA/Qwen/Mistral 使用)
    RMSNorm(x) = γ · x / RMS(x)
    其中 RMS(x) = √(1/n · Σx²)

    优势:去掉了均值计算和偏移参数 β
    → 计算更快(约10-15%),效果相当
    """
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return self.gamma * x / rms

# 对比
x = torch.randn(2, 10, 512)
ln = LayerNorm(512)
rmsn = RMSNorm(512)

out_ln = ln(x)
out_rmsn = rmsn(x)
print(f"LayerNorm 输出统计: mean={out_ln.mean():.4f}, std={out_ln.std():.4f}")
print(f"RMSNorm 输出统计:   mean={out_rmsn.mean():.4f}, std={out_rmsn.std():.4f}")

7.3 Pre-Norm vs Post-Norm

Text Only
Post-Norm (原始Transformer, BERT):
  x → Attention → Add(x, ·) → LayerNorm → FFN → Add(·, ·) → LayerNorm

  优点:最终层的表示经过归一化,理论性质更好
  缺点:深层训练不稳定,需要 warm-up

Pre-Norm (GPT-2, LLaMA, 现代模型):
  x → LayerNorm → Attention → Add(x, ·) → LayerNorm → FFN → Add(·, ·)

  优点:梯度直接通过残差传播,训练更稳定
  缺点:理论上收敛性稍差,但实践中更常用

现代大模型几乎都使用 Pre-Norm + RMSNorm 的组合。
Python
class TransformerBlock(nn.Module):
    """Pre-Norm Transformer Block (LLaMA-style)"""

    def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm2 = RMSNorm(d_model)
        self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)

    def forward(self, x, mask=None):
        # Pre-Norm + Residual
        x_norm = self.norm1(x)
        x = x + self.attn(x_norm, x_norm, x_norm, mask)
        x = x + self.ffn(self.norm2(x))
        return x

编解码器架构对比

8.1 三种主流架构

这是理解现代大模型最关键的知识点之一:

Text Only
┌────────────────────────────────────────────────────────────────┐
│                    三种 Transformer 架构                        │
├──────────────┬──────────────┬──────────────┬──────────────────┤
│              │ Encoder-Only │ Decoder-Only │ Encoder-Decoder  │
├──────────────┼──────────────┼──────────────┼──────────────────┤
│ 代表模型     │ BERT         │ GPT/LLaMA    │ T5/BART          │
│ 注意力类型   │ 双向全注意力  │ 因果注意力    │ 双向+因果+交叉   │
│ 训练目标     │ MLM + NSP    │ 下一token预测 │ Seq2Seq          │
│ 擅长任务     │ 分类/NER/QA  │ 生成/对话     │ 翻译/摘要        │
│ 上下文       │ 看到全部输入  │ 只看到左侧    │ 编码器全部,解码器左侧│
│ 当前趋势     │ 逐渐减少     │ 主流(GPT时代) │ 特定场景使用     │
└──────────────┴──────────────┴──────────────┴──────────────────┘

为什么 Decoder-Only 成为主流?
1. 统一的训练目标:所有任务都可以转化为"文本生成"
2. 规模效应更好:参数量越大,涌现能力越强
3. 少样本能力:通过 prompt 即可适配新任务,无需微调
4. 工程简单:只需一个模型处理所有任务

8.2 因果注意力掩码

Decoder-Only 模型的核心:确保位置 \(i\) 只能看到位置 \(0, 1, \ldots, i\)(不泄露未来信息)。

Python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    创建因果注意力掩码

    例如 seq_len=4:
    [[1, 0, 0, 0],   ← 位置0只看自己
     [1, 1, 0, 0],   ← 位置1看0和1
     [1, 1, 1, 0],   ← 位置2看0、1、2
     [1, 1, 1, 1]]   ← 位置3看所有
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# 训练时:一次前向传播同时预测所有位置
# 因为有因果mask,每个位置只看到左侧,等价于自回归

Transformer的训练

9.1 训练目标

Text Only
Decoder-Only (如 GPT) 的训练:

输入:  [BOS] The  cat  sat  on  the  mat
标签:  The   cat  sat  on   the mat  [EOS]

损失函数: Cross-Entropy Loss
L = -1/T Σ_t log P(x_t | x_{<t})

即最大化每个位置正确预测下一个token的对数概率。
整个序列只需一次前向传播 (teacher forcing)。

9.2 学习率调度

Transformer 训练中最关键的超参数管理:

Python
class WarmupCosineScheduler:
    """
    Warmup + Cosine Decay 学习率调度
    几乎所有现代大模型都使用这种方式

    阶段1 (warmup): 学习率从0线性增长到max_lr
    阶段2 (cosine): 学习率从max_lr余弦衰减到min_lr
    """
    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        total_steps: int,
        max_lr: float = 3e-4,
        min_lr: float = 1e-5,
    ):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.step_count = 0

    def step(self):
        self.step_count += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        if self.step_count < self.warmup_steps:
            # 线性 warmup
            return self.max_lr * self.step_count / self.warmup_steps
        else:
            # 余弦衰减
            progress = (self.step_count - self.warmup_steps) / (
                self.total_steps - self.warmup_steps
            )
            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
                1 + math.cos(math.pi * progress)
            )

# 训练循环骨架
"""
model = DecoderOnlyTransformer(...)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=2000, total_steps=100000)

for batch in dataloader:
    input_ids = batch['input_ids']                      # [B, T]
    targets = input_ids[:, 1:]                          # 右移一位
    logits = model(input_ids[:, :-1])                   # [B, T-1, vocab]
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # 梯度裁剪
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
"""

Transformer的推理

10.1 自回归生成

Text Only
训练是并行的(一次处理整个序列),
推理是自回归的(一次生成一个token):

Step 0: 输入 [BOS] → 预测 "The"
Step 1: 输入 [BOS, The] → 预测 "cat"
Step 2: 输入 [BOS, The, cat] → 预测 "sat"
...

问题:每一步都要重新计算之前所有位置的注意力 → 巨大浪费
解决:KV Cache

10.2 KV Cache

KV Cache 是大模型推理最重要的优化技术:

Python
"""KV Cache 原理示意"""

class CachedMultiHeadAttention(nn.Module):
    """带KV Cache的多头注意力(推理用)"""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        """
        Args:
            x: [batch, 1, d_model] — 推理时只有当前新token
            kv_cache: (cached_K, cached_V) — 之前步骤的K/V
        Returns:
            output, new_kv_cache
        """
        B = x.size(0)

        # 只为当前 token 计算 Q/K/V
        Q = self.W_Q(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
        K_new = self.W_K(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
        V_new = self.W_V(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            K_cached, V_cached = kv_cache
            # 拼接历史 KV → [batch, n_heads, seq_so_far+1, d_k]
            K = torch.cat([K_cached, K_new], dim=2)
            V = torch.cat([V_cached, V_new], dim=2)
        else:
            K, V = K_new, V_new

        # 只计算当前 Q 与所有 K 的注意力 → O(1×S) 而非 O(S×S)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        out = out.transpose(1, 2).contiguous().view(B, 1, self.d_model)
        output = self.W_O(out)

        return output, (K, V)  # 返回更新后的 cache

"""
无 KV Cache: 生成 T 个token → 总计算量 O(T² · d)
有 KV Cache: 生成 T 个token → 总计算量 O(T · d) + 缓存内存
  速度提升: ~T倍(序列越长,加速越明显)

代价: 需要额外内存存储 KV Cache
  LLaMA-2-7B, seq_len=4096:
  KV cache = 2(K+V) × 32(layers) × 32(heads) × 4096(seq) × 128(d_k) × 2(bytes/fp16)
           ≈ 2GB
"""

10.3 解码策略

Python
def generate(
    model, tokenizer, prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
):
    """
    完整的文本生成函数,支持多种解码策略
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    kv_cache = None
    generated = []

    for _ in range(max_new_tokens):
        # 只输入最新 token(有 KV Cache 时)
        if kv_cache is not None:
            x = input_ids[:, -1:]
        else:
            x = input_ids

        logits, kv_cache = model(x, kv_cache=kv_cache)
        logits = logits[:, -1, :]  # 只取最后一个位置

        # Temperature scaling
        logits = logits / temperature

        # Top-K: 只保留概率最高的K个token
        if top_k > 0:
            topk_logits, topk_indices = torch.topk(logits, top_k)
            logits = torch.full_like(logits, float('-inf'))
            logits.scatter_(1, topk_indices, topk_logits)

        # Top-P (Nucleus Sampling)
        probs = F.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # 移除累积概率超过 top_p 的token
        remove_mask = cumulative_probs - sorted_probs > top_p
        sorted_probs[remove_mask] = 0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

        # 采样
        next_token = torch.multinomial(sorted_probs, 1)
        next_token = sorted_indices.gather(1, next_token)

        if next_token.item() == tokenizer.eos_token_id:
            break

        generated.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(generated)

延伸阅读


最后更新日期:2025-07-11 适用版本:LLM学习教程 v2025