跳转至

01 - Transformer 深入理解

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习目标:从零开始深入理解 Transformer 的每一个组件,掌握其数学原理、实现细节和设计思想。

📌 定位说明:本章侧重大模型视角下的 Transformer 深入理解( RoPE/RMSNorm/GeLU/编解码器架构对比等)。 Transformer 架构的基础教学(从零实现完整 Transformer )请参考 深度学习/04-Transformer/02-Transformer 架构


目录

  1. Transformer 架构总览
  2. 输入嵌入层深度解析
  3. 位置编码机制
  4. 自注意力机制详解
  5. 多头注意力机制
  6. 前馈神经网络
  7. 层归一化与残差连接
  8. 编解码器架构对比
  9. Transformer 的训练
  10. Transformer 的推理

Transformer 架构总览

1.1 为什么需要 Transformer

在 Transformer 出现之前,NLP 领域的序列建模主要依赖 RNN(循环神经网络)和 LSTM(长短期记忆网络)。这些模型虽然在当时取得了不错的效果,但存在两个根本性的缺陷:

缺陷一:无法并行计算。 RNN 的计算是严格串行的——必须先处理完第 \(t\) 步,才能处理第 \(t+1\) 步。这意味着即使你有成百上千块 GPU,也无法加速单个序列的处理。对于长文本(如一整篇文章),这种串行瓶颈极其严重。

缺陷二:长距离依赖困难。 理论上 LSTM 通过门控机制可以记忆远距离信息,但在实践中,当序列长度超过 100-200 时,梯度在反向传播过程中会不断衰减或爆炸(即梯度消失/爆炸问题),导致模型难以学习到远距离 token 之间的依赖关系。例如在",那个毛茸茸的、总是追着球跑的小家伙,正坐在窗台上"中,模型需要将"猫"和" sitting "关联起来,中间隔了十几个词。

2017 年,Vaswani 等人在论文《Attention Is All You Need》中提出了 Transformer,用纯注意力机制彻底替代了 RNN/LSTM:

Text Only
RNN 的问题与 Transformer 的解决方案:
┌──────────────────────┬──────────────────────────────────────┐
│ RNN/LSTM 的缺陷       │ Transformer 的解决方案                │
├──────────────────────┼──────────────────────────────────────┤
│ 顺序计算,无法并行     │ 完全基于矩阵运算,天然支持并行         │
│ 长距离依赖困难        │ 任意两个位置间的注意力计算都是 O(1)    │
│                       │ 距离不再是障碍                        │
│ 梯度消失/爆炸         │ 残差连接 + 层归一化确保梯度稳定        │
│ 固定长度瓶颈          │ 灵活处理变长序列                      │
└──────────────────────┴──────────────────────────────────────┘

Transformer 不仅解决了上述问题,更重要的是它奠定了一个可扩展的架构基础——通过简单地增加层数、参数量和训练数据,就能持续提升模型性能。这一特性直接催生了 GPT、BERT、LLaMA 等划时代的模型,使 Transformer 成为现代 AI 最重要的基础架构之一。

1.2 架构全景图

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                     Transformer 架构全景                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    输入序列 (Input)                      │    │
│  │              [我, 喜欢, 深度, 学习]                       │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │              输入嵌入 + 位置编码                          │    │
│  │         (Token Embedding + Positional Encoding)          │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│           ┌──────────────────┴──────────────────┐               │
│           ▼                                      ▼               │
│  ┌─────────────────┐                  ┌─────────────────┐       │
│  │    Encoder      │                  │    Decoder      │       │
│  │  (编码器堆叠)    │                  │  (解码器堆叠)    │       │
│  │                 │                  │                 │       │
│  │  ┌───────────┐  │                  │  ┌───────────┐  │       │
│  │  │ Block N   │  │                  │  │ Block N   │  │       │
│  │  │ ┌───────┐ │  │                  │  │ ┌───────┐ │  │       │
│  │  │ │Multi- │ │  │                  │  │ │Masked │ │  │       │
│  │  │ │Head   │ │  │                  │  │ │Multi-H│ │  │       │
│  │  │ │Attn   │ │  │                  │  │ │Attn   │ │  │       │
│  │  │ └───────┘ │  │                  │  │ └───────┘ │  │       │
│  │  │ ┌───────┐ │  │                  │  │ ┌───────┐ │  │       │
│  │  │ │ Feed  │ │  │                  │  │ │Multi-H│ │  │       │
│  │  │ │Forward│ │  │                  │  │ │Cross  │ │  │       │
│  │  │ └───────┘ │  │                  │  │ │Attn   │ │  │       │
│  │  └───────────┘  │                  │  │ └───────┘ │  │       │
│  │       ...       │                  │  │ ┌───────┐ │  │       │
│  │  ┌───────────┐  │                  │  │ │ Feed  │ │  │       │
│  │  │ Block 1   │  │                  │  │ │Forward│ │  │       │
│  │  └───────────┘  │                  │  │ └───────┘ │  │       │
│  └─────────────────┘                  │  └───────────┘  │       │
│           │                           └─────────────────┘       │
│           │                                  │                   │
│           └──────────────────┬───────────────┘                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    输出层 (Output)                       │    │
│  │              Linear + Softmax → 概率分布                 │    │
│  └─────────────────────────────────────────────────────────┘    │
│                              │                                   │
│                              ▼                                   │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │                    预测输出                              │    │
│  │              [I, like, deep, learning]                   │    │
│  └─────────────────────────────────────────────────────────┘    │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

Transformer 架构图

图源:TensorFlow 教程 - Neural machine translation with a Transformer and Keras,图片文件 transformer.png,许可 CC BY 4.0。

上图展示了 Transformer 的完整架构,包括编码器( Encoder )和解码器( Decoder )的堆叠结构,以及自注意力机制和前馈神经网络等核心组件。

1.3 核心组件一览

组件 功能 关键参数 复杂度
输入嵌入 将 token 映射为向量 vocab_size × d_model -
位置编码 注入位置信息 d_model -
多头注意力 捕捉不同子空间的关系 h heads, d_k = d_model/h O(n²·d)
前馈网络 非线性变换 d_model → 4d_model → d_model O(n·d²)
层归一化 稳定训练 d_model O(n·d)
残差连接 缓解梯度消失 - -

输入嵌入层深度解析

2.1 词嵌入的数学本质

Text Only
词嵌入层就是一个查找表(Lookup Table):

E ∈ ℝ^(V × d)

其中:
- V: 词汇表大小(如32000)
- d: 嵌入维度(如512, 768, 1024)

对于输入token id = i:
embedding = E[i, :] ∈ ℝ^d

这相当于一个one-hot向量与嵌入矩阵的乘法:
embedding = one_hot(i) @ E

2.2 嵌入层的实现细节

Python
import torch
import torch.nn as nn
import math

class TokenEmbedding(nn.Module):
    """
    Token嵌入层完整实现
    """
    def __init__(self, vocab_size, d_model, padding_idx=None):
        super().__init__()  # super()调用父类方法
        self.vocab_size = vocab_size
        self.d_model = d_model

        # 嵌入矩阵
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=padding_idx  # 用于mask填充位置
        )

        # 缩放因子(Transformer原论文使用)
        self.scale = math.sqrt(d_model)

        # 初始化(重要!)
        self._init_weights()

    def _init_weights(self):
        """
        嵌入层初始化策略
        使用N(0, 1/d_model)初始化
        """
        nn.init.normal_(self.embedding.weight, mean=0, std=1/math.sqrt(self.d_model))

        # 如果有padding_idx,将该位置嵌入置零
        if self.embedding.padding_idx is not None:
            with torch.no_grad():  # 禁用梯度追踪,避免初始化操作被记入计算图
                self.embedding.weight[self.embedding.padding_idx].fill_(0)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len] token IDs
        Returns:
            embeddings: [batch_size, seq_len, d_model]
        """
        # 查找嵌入并缩放
        # 缩放原因:后续要与位置编码相加,需要平衡量级
        return self.embedding(x) * self.scale

# 使用示例
vocab_size = 32000
d_model = 512
batch_size = 2
seq_len = 10

embedding_layer = TokenEmbedding(vocab_size, d_model)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
embeddings = embedding_layer(input_ids)

print(f"输入形状: {input_ids.shape}")
print(f"嵌入形状: {embeddings.shape}")
print(f"嵌入范围: [{embeddings.min():.3f}, {embeddings.max():.3f}]")

2.3 子词分词与嵌入

现代大模型使用子词( Subword )分词,如 BPE 、 WordPiece 、 SentencePiece :

Text Only
传统分词的问题:
- "playing" 和 "played" 被视为完全不同的词
- 未登录词(OOV)问题

子词分词的优势:
"playing" → ["play", "ing"]
"unhappiness" → ["un", "happiness"] 或 ["un", "happy", "ness"]

常见分词器:
├── GPT系列: BPE (Byte-Pair Encoding)
├── BERT: WordPiece
├── LLaMA/T5: SentencePiece (Unigram)
└── 中文: 字级别或BPE
Python
# 使用Hugging Face Tokenizer示例
from transformers import AutoTokenizer

# 加载GPT-2的分词器
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "Transformer架构 revolutionized NLP"
tokens = tokenizer.tokenize(text)
print(f"分词结果: {tokens}")
# 输出: ['Trans', 'former', '架构', 'Ġre', 'volution', 'ized', 'ĠN', 'LP']

# 转换为ID
input_ids = tokenizer.encode(text)
print(f"Token IDs: {input_ids}")

# 解码
decoded = tokenizer.decode(input_ids)
print(f"解码结果: {decoded}")

位置编码机制

3.1 为什么需要位置编码

Text Only
自注意力的排列不变性:

输入: [A, B, C] → 自注意力 → 输出
输入: [B, A, C] → 自注意力 → 输出(只是顺序变了)

问题:模型无法区分"我打你"和"你打我"

解决方案:注入位置信息

3.2 正弦位置编码( Sinusoidal )

Transformer 原始论文使用的位置编码:

Text Only
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:
- pos: 位置(0, 1, 2, ..., max_len-1)
- i: 维度索引(0, 1, 2, ..., d_model/2-1)
- d_model: 模型维度
Python
class SinusoidalPositionalEncoding(nn.Module):
    """
    正弦位置编码实现
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 创建位置编码矩阵 [max_len, d_model]
        pe = torch.zeros(max_len, d_model)

        # 位置索引 [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # unsqueeze增加一个维度

        # 维度索引的除数项
        # 10000^(2i/d_model) = exp(2i * -log(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )

        # 偶数维度用sin
        pe[:, 0::2] = torch.sin(position * div_term)

        # 奇数维度用cos
        pe[:, 1::2] = torch.cos(position * div_term)

        # 注册为buffer(不参与训练)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        seq_len = x.size(1)

        # 添加位置编码
        x = x + self.pe[:, :seq_len, :]

        return self.dropout(x)

# 可视化位置编码
import matplotlib.pyplot as plt

def visualize_positional_encoding():
    d_model = 128
    max_len = 100

    pe = SinusoidalPositionalEncoding(d_model, max_len)
    encoding = pe.pe[0].numpy()  # [max_len, d_model]

    plt.figure(figsize=(12, 6))
    plt.imshow(encoding, aspect='auto', cmap='viridis')
    plt.colorbar()
    plt.xlabel('Dimension')
    plt.ylabel('Position')
    plt.title('Sinusoidal Positional Encoding')
    plt.show()

    # 观察不同位置的编码
    plt.figure(figsize=(12, 4))
    for pos in [0, 10, 20, 50]:
        plt.plot(encoding[pos], label=f'Pos {pos}')
    plt.legend()
    plt.xlabel('Dimension')
    plt.ylabel('Value')
    plt.title('Positional Encoding at Different Positions')
    plt.show()

正弦位置编码的四个关键性质

性质1:唯一性 — 每个位置都有唯一的编码向量。由于不同维度使用不同频率的正弦/余弦函数,任意两个位置的编码都不会完全相同。

性质2:相对位置关系 — 这是正弦编码最重要的性质。可以证明 \(PE(pos+k)\) 可以表示为 \(PE(pos)\) 的线性函数:

\[ \begin{bmatrix} \sin((pos+k)\omega_i) \\ \cos((pos+k)\omega_i) \end{bmatrix} = \begin{bmatrix} \cos(k\omega_i) & -\sin(k\omega_i) \\ \sin(k\omega_i) & \cos(k\omega_i) \end{bmatrix} \begin{bmatrix} \sin(pos \cdot \omega_i) \\ \cos(pos \cdot \omega_i) \end{bmatrix} \]

这意味着模型不需要显式编码绝对位置,只需通过学习一组线性变换就能捕捉相对位置关系。

性质3:有界性 — 所有值都在 \([-1, 1]\) 之间,数值稳定,不会导致梯度爆炸。

性质4:外推性 — 理论上可以处理训练时未见过的更长序列(因为公式对任意 \(pos\) 都有定义),但实践中外推能力有限——这就是 RoPE 等改进方案被提出的原因之一。

3.3 可学习位置编码

BERT 等模型使用可学习的位置编码:

Python
class LearnedPositionalEncoding(nn.Module):
    """
    可学习位置编码
    """
    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 可学习的嵌入矩阵
        self.pos_embedding = nn.Embedding(max_len, d_model)

        # 初始化
        nn.init.normal_(self.pos_embedding.weight, std=0.02)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        """
        seq_len = x.size(1)

        # 创建位置索引
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)

        # 添加位置嵌入
        x = x + self.pos_embedding(positions)

        return self.dropout(x)

3.4 旋转位置编码 (RoPE)

现代大模型( LLaMA 、 GPT-NeoX 等)使用 RoPE :

3.4.1 RoPE 的数学推导(从复数乘法出发)

Text Only
RoPE的核心目标:设计一种位置编码函数 f(x, pos),使得两个位置的内积
只依赖于相对位置差 (m - n),而非绝对位置 m 和 n:

    <f(q, m), f(k, n)> = g(q, k, m - n)

第一步:复数表示
  将向量的每两个相邻维度视为一个复数:
    z = x_{2i} + j·x_{2i+1}  (j 是虚数单位)

第二步:复数乘法 = 旋转
  在复数平面上,乘以 e^{jθ} 等价于旋转角度 θ:
    z · e^{jθ} = (x_{2i} + j·x_{2i+1})(cosθ + j·sinθ)
              = (x_{2i}·cosθ - x_{2i+1}·sinθ) + j·(x_{2i}·sinθ + x_{2i+1}·cosθ)

  写成矩阵形式:
    [x_{2i}' ]   [cosθ  -sinθ] [x_{2i}  ]
    [x_{2i+1}'] = [sinθ   cosθ] [x_{2i+1}]

第三步:位置相关的旋转角度
  对位置 m 的第 i 对维度,旋转角度 θ_i(m) = m · ω_i
  其中 ω_i = 1 / 10000^{2i/d}(频率随维度指数递减)

第四步:验证相对位置性质
  对位置 m 的 query 和位置 n 的 key(在第 i 对维度上):
    f(q, m) = q · e^{j·m·ω_i}
    f(k, n) = k · e^{j·n·ω_i}

  内积(复数):
    Re[f(q, m) · conj(f(k, n))]
    = Re[q · e^{j·m·ω_i} · conj(k · e^{j·n·ω_i})]
    = Re[q · conj(k) · e^{j·(m-n)·ω_i}]

  结果只依赖 (m-n),证毕。✓

第五步:推广到全维度
  对 d 维向量,每两维一组,共 d/2 组,每组用不同频率 ω_i 旋转:

    RoPE(x, pos) = [R(pos·ω_0)·x_{0:1}, R(pos·ω_1)·x_{2:3}, ..., R(pos·ω_{d/2-1})·x_{d-2:d-1}]

  其中 R(θ) 是 2×2 旋转矩阵。

3.4.2 RoPE 实现

Python
class RotaryPositionalEmbedding(nn.Module):
    """
    旋转位置编码 (RoPE)
    通过旋转矩阵注入相对位置信息
    """
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()

        # 计算旋转角度
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # 预计算位置编码
        t = torch.arange(max_seq_len)
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # [max_seq_len, dim/2]

        # 复数形式: cos + i*sin
        self.register_buffer('cos_cached', freqs.cos()[None, None, :, :])  # [1, 1, seq_len, dim/2]
        self.register_buffer('sin_cached', freqs.sin()[None, None, :, :])  # [1, 1, seq_len, dim/2]

    def forward(self, x, seq_len=None):
        """
        Args:
            x: [batch, heads, seq_len, head_dim]
        """
        if seq_len is None:
            seq_len = x.shape[2]

        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]

        return self.apply_rotary_pos_emb(x, cos, sin)

    def apply_rotary_pos_emb(self, x, cos, sin):
        """
        应用旋转位置编码

        Args:
            x: [batch, heads, seq_len, head_dim]
            cos: [1, 1, seq_len, head_dim/2]
            sin: [1, 1, seq_len, head_dim/2]

        Returns:
            [batch, heads, seq_len, head_dim]
        """
        # 将x分成两部分(偶数维和奇数维)
        x1, x2 = x[..., ::2], x[..., 1::2]

        # 旋转: [x1, x2] @ [[cos, -sin], [sin, cos]]
        # 即: x1' = x1 * cos - x2 * sin
        #     x2' = x1 * sin + x2 * cos
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1)

        # 将最后两维展平,恢复原始形状
        return rotated.flatten(-2)

RoPE (旋转位置编码)的核心是:把相邻两个维度视作一个二维平面上的坐标对,再通过位置相关的旋转矩阵把位置信息注入到查询和键向量中。与其依赖质量一般的示意图,不如直接结合上面的代码和后面的公式理解“二维旋转 + 频率分层”这两个关键点。


自注意力机制详解

4.1 直觉理解

自注意力的核心思想:让序列中每个位置都"看到"其他所有位置,并学习该关注哪些位置。

Text Only
以"猫坐在垫子上,它很舒服"为例:

处理"它"时,自注意力权重分布(示意):
猫  坐  在  垫子 上  ,  它  很 舒服
0.45 0.05 0.02 0.15 0.03 0.01 0.10 0.04 0.15

→ 模型学会了"它"主要指代"猫"(最高权重),次要关联"垫子"和"舒服"

4.2 Q-K-V 的数学推导

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

三个矩阵的角色: - \(Q\)( Query ):当前位置发出的"提问" - \(K\)( Key ):每个位置提供的"索引标签" - \(V\)( Value ):每个位置存储的"实际内容"

Text Only
类比图书馆:
Q = 你的搜索关键词("深度学习入门")
K = 每本书的标题/关键词索引
V = 每本书的实际内容

检索流程:
1. 计算 Q @ K^T → 搜索词与每本书标题的匹配分数
2. softmax → 将分数归一化为概率
3. 概率 @ V → 加权提取最相关书的内容
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    query: torch.Tensor,   # [batch, seq_q, d_k]
    key: torch.Tensor,     # [batch, seq_k, d_k]
    value: torch.Tensor,   # [batch, seq_k, d_v]
    mask: torch.Tensor = None,
    dropout: nn.Dropout = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    缩放点积注意力

    关键点:
    1. 为什么除以 sqrt(d_k)?(完整推导)

       假设 q 和 k 的每个分量独立同分布,均值为0、方差为1:
         q = [q_1, q_2, ..., q_{d_k}],  q_i ~ (0, 1)
         k = [k_1, k_2, ..., k_{d_k}],  k_i ~ (0, 1)

       点积 q·k = Σ_{i=1}^{d_k} q_i * k_i

       每一项 q_i * k_i 的方差:
         Var(q_i * k_i) = E[q_i²] * E[k_i²] - (E[q_i] * E[k_i])²
                        = 1 * 1 - 0 = 1

       因为各项独立,所以:
         Var(q·k) = Σ Var(q_i * k_i) = d_k

       当 d_k=64 时,点积标准差 ≈ 8;d_k=128 时 ≈ 11.3
       这些大值会将 softmax 推入饱和区(梯度接近 0)

       除以 √d_k 后:
         Var(q·k / √d_k) = Var(q·k) / d_k = d_k / d_k = 1

       方差归一化为 1,softmax 的输入保持在合理范围,梯度可以正常流动。

    2. mask 的两种用途:
       - padding mask: 屏蔽填充位置
       - causal mask: 屏蔽未来位置(解码器)
    """
    d_k = query.size(-1)

    # Step 1: QK^T / sqrt(d_k) → [batch, seq_q, seq_k]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: 应用 mask(将需要屏蔽的位置设为 -inf)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax → 权重矩阵
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Dropout(训练时)
    if dropout is not None:
        attn_weights = dropout(attn_weights)

    # Step 5: 加权求和 → [batch, seq_q, d_v]
    output = torch.matmul(attn_weights, value)

    return output, attn_weights

# ---------- 验证 ----------
batch, seq_len, d_k = 2, 5, 64
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

# 无 mask
out, weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {out.shape}")           # [2, 5, 64]
print(f"权重形状: {weights.shape}")        # [2, 5, 5]
print(f"权重行和: {weights.sum(-1)}")      # 每行和为 1.0

# 因果 mask(解码器用)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))  # 下三角矩阵
out_masked, weights_masked = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(f"\n因果mask权重(第0行只看自己,第4行看所有位置):")
print(weights_masked[0].detach())

4.3 注意力复杂度分析

Text Only
时间复杂度: O(n² · d)
  - 其中 n = 序列长度, d = 模型维度
  - 主要来自 QK^T 的矩阵乘法

空间复杂度: O(n²)
  - 需要存储完整的注意力权重矩阵

序列长度 →  注意力计算量(d=1024)
128         → 16.7M
1,024       → 1.07B
8,192       → 68.7B
128,000     → 16.8T  ← 这就是长上下文的挑战

解决方案:
├── FlashAttention: 利用GPU内存层级优化,O(n²)复杂度不变但实际速度快2-4倍
├── GQA (Grouped-Query Attention): 减少KV头数,降低内存
├── MQA (Multi-Query Attention): 所有头共享一组KV
├── Ring Attention: 分布式环形通信处理超长序列
└── 稀疏注意力: 只计算部分位置对的注意力

4.3.1 GPU内存层次结构与访问瓶颈

理解FlashAttention的关键在于GPU的内存层次结构:

Text Only
GPU内存带宽对比(典型A100 GPU):

| 内存类型 | 带宽 | 延迟 |
|----------|------|------|
| HBM (High Bandwidth Memory) | ~1.6 TB/s | ~500ns |
| L2 Cache | ~4.6 TB/s | ~150ns |
| L1 Cache / SRAM | ~10-20 TB/s | ~10-30ns |

关键洞察:
- L1 Cache带宽是HBM的10倍以上
- 延迟差异达50倍
- 减少HBM访问是提升性能的关键

4.3.2 标准Attention的IO复杂度

标准Self-Attention的内存访问模式:

Text Only
标准Attention计算流程:
1. 读取 Q, K, V 矩阵: O(n·d)
2. 计算 S = QK^T: O(n²·d)
3. 计算 P = softmax(S): O(n²)
4. 计算 O = PV: O(n²·d)
5. 存储完整注意力矩阵 S 和 P: O(n²)

总内存访问量: Θ(n² + n·d)

问题分析: - 当 n = 序列长度(如8192),d = 模型维度(如128)时 - n² 远大于 n·d(如8192² = 67M >> 8192×128 = 1M) - 注意力矩阵存储成为真正的瓶颈

4.3.3 FlashAttention的分块策略

FlashAttention核心思想:分块计算,避免存储n²矩阵

Text Only
分块计算流程:

┌─────────────────────────────────────────────────────────────┐
│                     FlashAttention 分块策略                    │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  HBM (全局显存)              SRAM (片上缓存)                  │
│  ┌──────────────┐           ┌──────────────┐               │
│  │  Q[0:Br]     │  load    │  Q_block     │               │
│  │  K[0:Bc]     │ ──────→ │  K_block     │               │
│  │  V[0:Bc]     │          │  V_block     │               │
│  │  O[0:Br]     │  write   │  O_block     │ ← accumulate │
│  └──────────────┘  ←─────  └──────────────┘               │
│                                                              │
│  块大小限制: SRAM大小(如A100 L1=192KB)                     │
│  Br × d + Bc × d + Br × Bc ≤ 192KB                         │
└─────────────────────────────────────────────────────────────┘

伪代码:
for block_q in blocks(Q, block_size=Br):
    load Q_block to SRAM
    for block_kv in blocks(K, V, block_size=Bc):
        load K_block, V_block to SRAM
        compute partial_attention(Q_block, K_block, V_block)
        accumulate result in output_block

    write output_block to HBM

关键优势:
- 每个K、V块只从HBM读取一次
- 不存储完整的S和P矩阵
- 输出直接累积在SRAM中

4.3.4 FlashAttention的IO复杂度分析

Text Only
FlashAttention的IO复杂度推导:

设:
- M = SRAM大小(如192KB)
- B = 单块可容纳元素数,B ≈ M / (2d)(考虑Q、K、V、O同时存在)

标准Attention:
- HBM访问量: O(n² + n·d)
- 中间结果存储: O(n²)

FlashAttention:
- 将Q分成T = n/B个块
- 每个Q块处理时,需遍历所有K、V块
- 每个K、V块访问一次 → O(n²·d / B)
- Q、O访问 → O(n·d)

总HBM访问量: O(n·d + n²·d / M)

加速比估算:
- 理想情况下,FlashAttention加速比约为 B/(2d) 倍
- 实际上由于计算重叠和内存访问优化,可达3-4倍加速

4.3.5 实际速度提升数据

Text Only
FlashAttention vs 标准Attention(A100 GPU,FP16):

| 序列长度 | 标准Attention | FlashAttention | 加速比 |
|----------|---------------|----------------|--------|
| 512      | ~15 ms        | ~4 ms          | 3.75x  |
| 2,048    | ~180 ms       | ~25 ms         | 7.2x   |
| 4,096    | ~680 ms       | ~80 ms         | 8.5x   |
| 8,192    | ~2,800 ms     | ~250 ms        | 11.2x  |

显存占用对比:

| 序列长度 | 标准Attention显存 | FlashAttention显存 | 节省比例 |
|----------|-------------------|-------------------|----------|
| 2,048    | 16 MB             | 2 MB              | 8x       |
| 4,096    | 64 MB             | 4 MB              | 16x      |
| 8,192    | 256 MB            | 8 MB              | 32x      |
| 32,768   | 4 GB              | 32 MB             | 128x     |

> 💡 **关键洞察**:序列越长,FlashAttention的优势越明显。这正是长上下文场景(如32K+上下文)必须使用FlashAttention的原因。

4.3.6 FlashAttention演进

Text Only
FlashAttention版本演进:

FlashAttention (2022):
- 核心创新:IO-aware分块计算
- 利用SRAM加速,减少HBM访问
- 加速比:2-4倍

FlashAttention-2 (2023):
- 更好的并行化策略
- 改进工作划分,减少线程束等待
- 加速比:4-8倍

FlashAttention-3 (2024, Hopper架构):
- 利用Tensor Core异步操作
- FP8低精度支持
- H100利用率从35%提升至75%
- 加速比:8-16倍

主流大模型均采用FlashAttention:
GPT-3、Falcon、LLaMA2/3、Mistral、Qwen等

多头注意力机制

5.1 为什么需要多头

单头注意力的问题:一个注意力头只能学到一种"关注模式"。而自然语言有多种关系维度(语法、语义、共指、因果等)。

Text Only
多头的直觉:用不同的"眼睛"看同一句话

Head 1: 可能学会了语法依赖(主语→动词)
Head 2: 可能学会了共指关系(代词→先行词)
Head 3: 可能学会了修饰关系(形容词→名词)
Head 4: 可能学会了长距离依赖(问题→答案)
...

5.2 多头注意力实现

\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \]
\[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]
Python
class MultiHeadAttention(nn.Module):
    """
    多头注意力机制

    参数量分析(d_model=1024, h=16, d_k=64):
    W_Q: 1024×1024 = 1M
    W_K: 1024×1024 = 1M
    W_V: 1024×1024 = 1M
    W_O: 1024×1024 = 1M
    总计: 4M 参数
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model 必须被 n_heads 整除"  # assert断言:条件False时抛出AssertionError

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 每个头的维度

        # 四个投影矩阵
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,   # [batch, seq_q, d_model]
        key: torch.Tensor,     # [batch, seq_k, d_model]
        value: torch.Tensor,   # [batch, seq_k, d_model]
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        batch_size = query.size(0)

        # 1. 线性投影 → [batch, seq, d_model]
        Q = self.W_Q(query)
        K = self.W_K(key)
        V = self.W_V(value)

        # 2. 分头: [batch, seq, d_model] → [batch, n_heads, seq, d_k]
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # view重塑张量形状
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 3. 缩放点积注意力
        d_k = self.d_k
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 4. 加权求和 → [batch, n_heads, seq_q, d_k]
        context = torch.matmul(attn_weights, V)

        # 5. 合并多头: [batch, seq_q, d_model]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 6. 输出投影
        output = self.W_O(context)

        return output

# 验证
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)

# Self-attention: Q=K=V=x
out = mha(x, x, x)
print(f"Self-attention 输出: {out.shape}")  # [2, 10, 512]

# Cross-attention: Q来自decoder, K/V来自encoder
enc_out = torch.randn(2, 20, 512)
dec_in = torch.randn(2, 10, 512)
cross_out = mha(dec_in, enc_out, enc_out)
print(f"Cross-attention 输出: {cross_out.shape}")  # [2, 10, 512]

5.3 GQA 与 MQA :现代大模型的注意力优化

Text Only
标准 MHA (Multi-Head Attention):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, n_heads, seq, d_k]    ← 每个头有独立的K
  V: [batch, n_heads, seq, d_k]    ← 每个头有独立的V
  KV cache 大小 = 2 × n_heads × seq × d_k

MQA (Multi-Query Attention, GPT-J):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, 1, seq, d_k]          ← 所有头共享一组K
  V: [batch, 1, seq, d_k]          ← 所有头共享一组V
  KV cache 大小 = 2 × 1 × seq × d_k  → 减少 n_heads 倍

GQA (Grouped-Query Attention, LLaMA-2/3):
  Q: [batch, n_heads, seq, d_k]    ← 每个头有独立的Q
  K: [batch, n_kv_heads, seq, d_k] ← 每组共享一组K
  V: [batch, n_kv_heads, seq, d_k] ← 每组共享一组V
  KV cache 大小 = 2 × n_kv_heads × seq × d_k

示例 (LLaMA-3-70B): n_heads=64, n_kv_heads=8
  → 每8个Q头共享1组KV → KV cache减少8倍,性能接近MHA

前馈神经网络

6.1 标准 FFN

每个 Transformer 层中的全连接前馈网络( Position-wise FFN ):

\[ \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 \]
Python
class FeedForward(nn.Module):
    """标准FFN,4倍扩展"""
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model  # 默认4倍扩展

        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # [batch, seq, d_model] → [batch, seq, d_ff] → [batch, seq, d_model]
        return self.w2(self.dropout(F.relu(self.w1(x))))

6.2 现代激活函数: GeLU 与 SwiGLU

Python
# GeLU (GPT-2/BERT 使用)
# GELU(x) = x · Φ(x),其中 Φ 是标准正态分布的CDF
# 比 ReLU 更平滑,不会在 x=0 处产生不可导点

class GeLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w2(self.dropout(F.gelu(self.w1(x))))

# SwiGLU (LLaMA/Qwen/Mistral 使用)
# SwiGLU(x, W, V, W2) = (Swish(xW) ⊙ xV) W2
# 引入门控机制,效果优于 GeLU,但参数量增加 50%

class SwiGLUFeedForward(nn.Module):
    """LLaMA-style SwiGLU FFN"""
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        # LLaMA 使用 8/3 * d_model 作为 d_ff 以保持参数量不变
        d_ff = d_ff or int(8 / 3 * d_model)
        # 取最接近的 256 的倍数(GPU对齐优化)
        d_ff = ((d_ff + 255) // 256) * 256

        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # up projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Swish(xW1) ⊙ xW3 → 门控机制
        gate = F.silu(self.w1(x))  # SiLU = Swish(x) = x * sigmoid(x)
        up = self.w3(x)
        return self.w2(self.dropout(gate * up))

# 对比
d_model = 1024
ffn_relu = FeedForward(d_model)
ffn_gelu = GeLUFeedForward(d_model)
ffn_swiglu = SwiGLUFeedForward(d_model)

for name, module in [("ReLU FFN", ffn_relu), ("GeLU FFN", ffn_gelu), ("SwiGLU FFN", ffn_swiglu)]:
    params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {params:,} 参数")
# ReLU FFN: 8,393,728 参数
# GeLU FFN: 8,393,728 参数
# SwiGLU FFN: 8,650,752 参数 (d_ff调整后参数量接近)

层归一化与残差连接

7.1 为什么需要归一化

深层网络中,每层输出的分布会不断漂移( Internal Covariate Shift ),导致训练不稳定。归一化将输出拉回标准分布。

7.2 LayerNorm vs RMSNorm

Python
class LayerNorm(nn.Module):
    """
    标准层归一化 (BERT/GPT-2 使用)
    LayerNorm(x) = γ · (x - μ) / √(σ² + ε) + β
    """
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # 可学习缩放
        self.beta = nn.Parameter(torch.zeros(d_model))    # 可学习偏移
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

class RMSNorm(nn.Module):
    """
    RMS层归一化 (LLaMA/Qwen/Mistral 使用)
    RMSNorm(x) = γ · x / RMS(x)
    其中 RMS(x) = √(1/n · Σx²)

    优势:去掉了均值计算和偏移参数 β
    → 计算更快(约10-15%),效果相当
    """
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return self.gamma * x / rms

# 对比
x = torch.randn(2, 10, 512)
ln = LayerNorm(512)
rmsn = RMSNorm(512)

out_ln = ln(x)
out_rmsn = rmsn(x)
print(f"LayerNorm 输出统计: mean={out_ln.mean():.4f}, std={out_ln.std():.4f}")
print(f"RMSNorm 输出统计:   mean={out_rmsn.mean():.4f}, std={out_rmsn.std():.4f}")

7.3 Pre-LN vs Post-LN:结构对比与数学分析

7.3.1 两种结构的对比

Text Only
┌─────────────────────────────────────────────────────────────────────────────┐
│                         Post-LN (原始 Transformer)                           │
│                                                                             │
│   x → Attention → + residual → LayerNorm → FFN → + residual → LayerNorm    │
│                                                                             │
│   特点:LayerNorm 在残差连接之后(最后)                                     │
│   缺点:深层训练极不稳定,需要 warm-up 学习率预热                            │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────────────┐
│                         Pre-LN (现代大模型标配)                              │
│                                                                             │
│   x → LayerNorm → Attention → + residual                                    │
│   x → LayerNorm → FFN      → + residual                                    │
│                                                                             │
│   特点:LayerNorm 在残差连接之前(每个子层内部)                            │
│   优点:梯度直接传播,训练稳定,无需 warm-up                                │
└─────────────────────────────────────────────────────────────────────────────┘

7.3.2 为什么 Pre-LN 更稳定:数学分析

梯度分析:

Text Only
设第 l 层输入为 x_l,输出为 x_{l+1}:

Post-LN 梯度流动:
  ∂L/∂x_l = ∂L/∂x_{l+1} · ∂x_{l+1}/∂x_l

  其中 ∂x_{l+1}/∂x_l 包含 LayerNorm 的雅可比行列式
  → 当网络加深时,LayerNorm 的梯度依赖于其输入的统计量
  → 深层梯度容易消失或爆炸,导致训练不稳定

Pre-LN 梯度流动:
  每一层的梯度都有独立的"直接路径"回到输入:
  ∂L/∂x_l = ∂L/∂x_{l+1} · (I + J_l)

  其中 I 是恒等矩阵,J_l 是子层Jacobian
  → 残差连接确保梯度可以无损传递
  → 即使网络很深,梯度也能稳定回传

训练动态对比:

特性 Post-LN Pre-LN
梯度幅度 输出层附近梯度极大 各层梯度幅度均匀
训练稳定性 需 warm-up,避免爆炸 无需 warm-up,稳定训练
深层网络 超过6层难以训练 可训练数百层
典型深度 ≤ 12 层 12 → 100+ 层

论文《On Layer Normalization in the Transformer Architecture》证明: - Post-LN 在训练初期,输出层附近参数梯度期望值很大 - Pre-LN 将 LN 置于残差分支内部,梯度直接传递

7.3.3 不同大模型的 LN 位置

模型 LN 类型 位置 备注
BERT Post-LN 残差后 原始 Transformer,12 层
GPT-2 Pre-LN 残差前 OpenAI,揭示 LN 位置重要性
T5 Pre-LN 残差前 Google,Encoder-Decoder
LLaMA Pre-LN + RMSNorm 残差前 Meta AI,现代 LLM 标配
GPT-3 Pre-LN 残差前 175B 参数,深层训练关键
PaLM Pre-LN + RMSNorm 残差前 Google,540B 参数
DeepSeek-V3 Pre-LN + RMSNorm 残差前 MoE + Pre-LN 支持超深网络

💡 规律:所有超过 30 层的大模型(GPT-3/PaLM/LLaMA/DeepSeek)都使用 Pre-LN。

7.3.4 Pre-LN 的潜在问题与解决方案

虽然 Pre-LN 更稳定,但研究发现浅层梯度可能大于深层(与 Post-LN 正好相反):

Text Only
梯度幅度分布:
  Post-LN:深层 > 浅层
  Pre-LN:  浅层 > 深层

解决方案:DeepNorm (Microsoft)
  在 Pre-LN 基础上引入缩放因子:
  x_{l+1} = α · x_l + Sublayer( LN(x_l) )

  α 的设置:
  - 普通 Transformer:α = (2N)^(1/4),N 为层数
  - 极深网络(如 1000 层):α 动态调整

7.3.5 残差连接的几何意义

残差连接不仅仅是"缓解梯度消失"的技术,从几何视角看,它有着更深层的意义。

7.3.5.1 直觉解释:梯度高速公路
Text Only
残差连接的本质:提供一条梯度高速公路

考虑等效电路:
输入 x ──┬──→ [主路径: 多层网络] ──→ 输出
        └──→ [跳远连接: identity] ──→ 输出 +

如果主路径学不到东西,梯度可以直接绕过去
这使得深层网络的训练更加稳定

这个"短路"设计让网络学习变得更容易: - 最坏情况:子层什么都不学习,输出为 0 → 输出 = 输入(恒等映射) - 最好情况:子层学习到有用的变换 → 输出 = 输入 + 有用变换

7.3.5.2 梯度路径分析:多条并行通道
Text Only
没有残差:梯度必须穿过每一层网络
∂L/∂x = ∂L/∂Layer₁ → ∂L/∂Layer₂ → ... → ∂L/∂Layerₙ

有残差连接:梯度有多条路径
∂L/∂x = ∂L/∂直接路径 + ∂L/∂Layer₁路径 + ∂L/∂Layer₂路径 + ...

深层梯度消失问题被有效缓解

数学上,对于残差块:

Text Only
x_{l+1} = x_l + F(x_l)

∂L/∂x_l = ∂L/∂x_{l+1} · (I + ∂F/∂x_l)

其中 I 是单位矩阵,保证梯度下界至少为 1,有效防止梯度消失。

7.3.5.3 信息流视角:流形上的增量更新
Text Only
残差连接确保了深层网络的信息流动性:

浅层特征 ——直接传递——→ 深层特征
         ↗               ↘
    通过残差        通过多层变换
         ↘               ↗
          ——混合后的特征——

这使得网络可以学习恒等映射:f(x) ≈ x

从流形学习角度看: - 深层特征位于输入空间的高维流形上 - 残差连接允许在流形上"原地踏步"(恒等映射) - 浅层特征可以直接传递到深层,保持语义信息不被破坏

7.3.5.4 主成分分析视角
Text Only
从PCA角度看残差连接的作用:

无残差网络:
- 每层都在对特征做线性/非线性变换
- 深层特征可能被"扭曲",偏离原始流形
- 重要主成分可能在多层变换中丢失

有残差网络:
- 恒等分支保持原始主成分方向
- 变换分支学习"残差"方向
- 浅层学到的语义可以无损传递到深层

这解释了为什么极深网络(如 100+ 层)可以正常训练:残差连接保持了特征空间的"骨架"结构。

7.3.6 现代大模型的共识

Text Only
Pre-LN + RMSNorm 已成为现代大模型的标准配置:

LLaMA 架构:
  x → RMSNorm → Attention → + residual
  x → RMSNorm → SwiGLU FFN → + residual

LLaMA 2/3 代码示例:
```python
class TransformerBlock(nn.Module):
    """Pre-Norm Transformer Block (LLaMA-style)"""

    def __init__(self, d_model: int, n_heads: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm2 = RMSNorm(d_model)
        self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)

    def forward(self, x, mask=None):
        # Pre-Norm + Residual
        x_norm = self.norm1(x)
        x = x + self.attn(x_norm, x_norm, x_norm, mask)
        x = x + self.ffn(self.norm2(x))
        return x

编解码器架构对比

8.1 三种主流架构

这是理解现代大模型最关键的知识点之一:

Text Only
┌────────────────────────────────────────────────────────────────┐
│                    三种 Transformer 架构                        │
├──────────────┬──────────────┬──────────────┬──────────────────┤
│              │ Encoder-Only │ Decoder-Only │ Encoder-Decoder  │
├──────────────┼──────────────┼──────────────┼──────────────────┤
│ 代表模型     │ BERT         │ GPT/LLaMA    │ T5/BART          │
│ 注意力类型   │ 双向全注意力  │ 因果注意力    │ 双向+因果+交叉   │
│ 训练目标     │ MLM + NSP    │ 下一token预测 │ Seq2Seq          │
│ 擅长任务     │ 分类/NER/QA  │ 生成/对话     │ 翻译/摘要        │
│ 上下文       │ 看到全部输入  │ 只看到左侧    │ 编码器全部,解码器左侧│
│ 当前趋势     │ 逐渐减少     │ 主流(GPT时代) │ 特定场景使用     │
└──────────────┴──────────────┴──────────────┴──────────────────┘

为什么 Decoder-Only 成为主流?
1. 统一的训练目标:所有任务都可以转化为"文本生成"
2. 规模效应更好:参数量越大,涌现能力越强
3. 少样本能力:通过 prompt 即可适配新任务,无需微调
4. 工程简单:只需一个模型处理所有任务

8.2 因果注意力掩码

Decoder-Only 模型的核心:确保位置 \(i\) 只能看到位置 \(0, 1, \ldots, i\)(不泄露未来信息)。

Python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    创建因果注意力掩码

    例如 seq_len=4:
    [[1, 0, 0, 0],   ← 位置0只看自己
     [1, 1, 0, 0],   ← 位置1看0和1
     [1, 1, 1, 0],   ← 位置2看0、1、2
     [1, 1, 1, 1]]   ← 位置3看所有
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# 训练时:一次前向传播同时预测所有位置
# 因为有因果mask,每个位置只看到左侧,等价于自回归

Transformer 的训练

9.1 训练目标

Text Only
Decoder-Only (如 GPT) 的训练:

输入:  [BOS] The  cat  sat  on  the  mat
标签:  The   cat  sat  on   the mat  [EOS]

损失函数: Cross-Entropy Loss
L = -1/T Σ_t log P(x_t | x_{<t})

即最大化每个位置正确预测下一个token的对数概率。
整个序列只需一次前向传播 (teacher forcing)。

9.2 学习率调度

Transformer 训练中最关键的超参数管理:

Python
class WarmupCosineScheduler:
    """
    Warmup + Cosine Decay 学习率调度
    几乎所有现代大模型都使用这种方式

    阶段1 (warmup): 学习率从0线性增长到max_lr
    阶段2 (cosine): 学习率从max_lr余弦衰减到min_lr
    """
    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        total_steps: int,
        max_lr: float = 3e-4,
        min_lr: float = 1e-5,
    ):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.step_count = 0

    def step(self):
        self.step_count += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        if self.step_count < self.warmup_steps:
            # 线性 warmup
            return self.max_lr * self.step_count / self.warmup_steps
        else:
            # 余弦衰减
            progress = (self.step_count - self.warmup_steps) / (
                self.total_steps - self.warmup_steps
            )
            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
                1 + math.cos(math.pi * progress)
            )

# ---------- 完整训练循环示例 ----------
# 以下代码展示了 Decoder-Only Transformer 的标准训练流程

def train_decoder_only(model, dataloader, vocab_size, num_steps=10000, device='cpu'):
    """
    Decoder-Only Transformer 训练循环

    核心流程:
    1. 输入序列 [BOS, w1, w2, ..., wT] 送入模型
    2. 模型在每个位置预测下一个 token
    3. 标签是输入右移一位: [w1, w2, ..., wT, EOS]
    4. 用交叉熵损失计算误差并反向传播
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
    scheduler = WarmupCosineScheduler(optimizer, warmup_steps=2000, total_steps=num_steps)

    model.train()
    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break

        input_ids = batch['input_ids'].to(device)       # [B, T]
        targets = input_ids[:, 1:]                       # 右移一位作为标签
        logits = model(input_ids[:, :-1])                # [B, T-1, vocab]

        # 交叉熵损失:每个位置预测下一个 token
        loss = F.cross_entropy(
            logits.reshape(-1, vocab_size),               # [B*(T-1), vocab]
            targets.reshape(-1)                           # [B*(T-1)]
        )

        # 反向传播
        optimizer.zero_grad()
        loss.backward()

        # 梯度裁剪:防止梯度爆炸(Transformer 训练的关键技巧)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        if (step + 1) % 1000 == 0:
            print(f"Step {step+1}/{num_steps} | Loss: {loss.item():.4f} | LR: {scheduler.get_lr():.6f}")

    return model

# 注意:完整的可运行训练代码请参考 03-手写Transformer完整实现.md

Transformer 的推理

10.1 自回归生成

Text Only
训练是并行的(一次处理整个序列),
推理是自回归的(一次生成一个token):

Step 0: 输入 [BOS] → 预测 "The"
Step 1: 输入 [BOS, The] → 预测 "cat"
Step 2: 输入 [BOS, The, cat] → 预测 "sat"
...

问题:每一步都要重新计算之前所有位置的注意力 → 巨大浪费
解决:KV Cache

10.2 KV Cache

KV Cache 是大模型推理最重要的优化技术:

Python
"""KV Cache 原理示意"""

class CachedMultiHeadAttention(nn.Module):
    """带KV Cache的多头注意力(推理用)"""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        """
        Args:
            x: [batch, 1, d_model] — 推理时只有当前新token
            kv_cache: (cached_K, cached_V) — 之前步骤的K/V
        Returns:
            output, new_kv_cache
        """
        B = x.size(0)

        # 只为当前 token 计算 Q/K/V
        Q = self.W_Q(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
        K_new = self.W_K(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)
        V_new = self.W_V(x).view(B, 1, self.n_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            K_cached, V_cached = kv_cache
            # 拼接历史 KV → [batch, n_heads, seq_so_far+1, d_k]
            K = torch.cat([K_cached, K_new], dim=2)
            V = torch.cat([V_cached, V_new], dim=2)
        else:
            K, V = K_new, V_new

        # 只计算当前 Q 与所有 K 的注意力 → O(1×S) 而非 O(S×S)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        out = out.transpose(1, 2).contiguous().view(B, 1, self.d_model)
        output = self.W_O(out)

        return output, (K, V)  # 返回更新后的 cache

"""
无 KV Cache: 生成 T 个token → 总计算量 O(T² · d)
有 KV Cache: 生成 T 个token → 总计算量 O(T · d) + 缓存内存
  速度提升: ~T倍(序列越长,加速越明显)

代价: 需要额外内存存储 KV Cache
  LLaMA-2-7B, seq_len=4096:
  KV cache = 2(K+V) × 32(layers) × 32(heads) × 4096(seq) × 128(d_k) × 2(bytes/fp16)
           ≈ 2GB
"""

10.3 解码策略

Python
def generate(
    model, tokenizer, prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
):
    """
    完整的文本生成函数,支持多种解码策略
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    kv_cache = None
    generated = []

    for _ in range(max_new_tokens):
        # 只输入最新 token(有 KV Cache 时)
        if kv_cache is not None:
            x = input_ids[:, -1:]
        else:
            x = input_ids

        logits, kv_cache = model(x, kv_cache=kv_cache)
        logits = logits[:, -1, :]  # 只取最后一个位置

        # Temperature scaling
        logits = logits / temperature

        # Top-K: 只保留概率最高的K个token
        if top_k > 0:
            topk_logits, topk_indices = torch.topk(logits, top_k)
            logits = torch.full_like(logits, float('-inf'))
            logits.scatter_(1, topk_indices, topk_logits)

        # Top-P (Nucleus Sampling)
        probs = F.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # 移除累积概率超过 top_p 的token
        remove_mask = cumulative_probs - sorted_probs > top_p
        sorted_probs[remove_mask] = 0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

        # 采样
        next_token = torch.multinomial(sorted_probs, 1)
        next_token = sorted_indices.gather(1, next_token)

        if next_token.item() == tokenizer.eos_token_id:
            break

        generated.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(generated)

📝 本章小结

核心知识点回顾

组件 核心公式/概念 现代大模型中的演进
输入嵌入 \(E \in \mathbb{R}^{V \times d}\),缩放 \(\sqrt{d_{model}}\) 子词分词(BPE/SentencePiece)
位置编码 Sinusoidal / Learnable / RoPE RoPE 成为 LLM 标配
自注意力 \(\text{softmax}(QK^T / \sqrt{d_k})V\) FlashAttention ⅔ 加速
多头注意力 \(h\) 个头并行,Concat + \(W^O\) GQA/MQA 减少 KV Cache
前馈网络 \(\text{ReLU}(xW_1)W_2\) SwiGLU(门控 + SiLU 激活)
层归一化 LayerNorm / RMSNorm Pre-LN + RMSNorm 成为标配
残差连接 \(x + \text{Sublayer}(x)\) 确保梯度稳定传播
解码策略 Greedy / Top-K / Top-P / Temperature 自回归 + KV Cache

关键设计决策总结

Text Only
现代 LLM(如 LLaMA-3)的架构选择:
├── 位置编码: RoPE(旋转位置编码,支持相对位置)
├── 注意力: GQA(分组查询注意力,平衡性能和效率)
├── FFN: SwiGLU(门控线性单元,效果优于 ReLU/GeLU)
├── 归一化: RMSNorm(比 LayerNorm 快 10-15%)
├── LN 位置: Pre-LN(训练更稳定,无需 warm-up)
├── 加速: FlashAttention-2/3(减少 HBM 访问)
└── 架构: Decoder-Only(统一生成范式)

思考题

  1. RoPE 的外推性:RoPE 理论上可以处理任意长度的序列,但实践中长序列性能会下降。为什么?有哪些方法可以改善 RoPE 的长度外推?(提示:NTK-aware Scaling、YaRN、Position Interpolation)

  2. FlashAttention 的精度:FlashAttention 使用 online softmax 近似,与标准注意力在数值上可能有微小差异。在实际训练中,这种差异是否会影响模型质量?为什么?

  3. KV Cache 压缩:除了 GQA/MQA,还有哪些方法可以减少 KV Cache 的内存占用?(提示:MQA with Quantization、Token Eviction、Cross-Layer Sharing)

  4. Pre-LN 的局限:虽然 Pre-LN 训练更稳定,但有研究发现它可能导致"表征塌缩"(Representation Collapse),即深层的输出趋于相似。你如何看待这个问题?


延伸阅读


最后更新日期: 2026-03-26 适用版本: LLM 学习教程 v2026