
手写 Transformer 完整实现

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

📌 本章定位:这是一份从零实现完整 Transformer 的全代码教程。每个组件都有数学推导 → 代码实现 → 形状验证三步走,确保你真正理解每一行代码。对标 happy-LLM 的"从零训练模型",我们的实现更深入、更贴近工程实际。

🔗 配套文件:理论详解请参见 01-Transformer 深入理解,注意力机制数学推导请参见 02-注意力机制详解


目录

# 内容 关键代码
1 缩放点积注意力 scaled_dot_product_attention()
2 多头注意力 MultiHeadAttention
3 位置编码 PositionalEncoding
4 前馈网络 FeedForward
5 Encoder EncoderLayer / Encoder
6 Decoder DecoderLayer / Decoder
7 完整 Transformer Transformer
8 训练: Copy Task 完整训练循环
9 训练:字符级语言模型 DecoderOnlyTransformer / generate()
10 完整代码汇总与调试指南 verify_implementation()

1. 缩放点积注意力

1.1 数学回顾

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \]
  • \(Q \in \mathbb{R}^{n \times d_k}\)(查询)
  • \(K \in \mathbb{R}^{m \times d_k}\)(键)
  • \(V \in \mathbb{R}^{m \times d_v}\)(值)
  • 缩放因子 \(\sqrt{d_k}\) 防止点积过大导致 softmax 饱和
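下面用一个独立的小例子直观感受缩放的作用:未缩放时点积的标准差约为 \(\sqrt{d_k}\),量级放大后 softmax 几乎退化为 one-hot(饱和),被压扁位置的梯度接近 0;这里用固定的 logits 和 \(d_k=64\) 来模拟这一量级差异:

```python
import math

import torch
import torch.nn.functional as F

# 同一组"相对差异"下,点积的绝对量级越大,softmax 越饱和。
# 未缩放时点积的标准差约为 sqrt(d_k),这里用 d_k=64 模拟这一量级差异。
scores = torch.tensor([2.0, 1.0, 0.0, -1.0])
d_k = 64

p_scaled = F.softmax(scores, dim=-1)                # 相当于已除以 sqrt(d_k)
p_raw = F.softmax(scores * math.sqrt(d_k), dim=-1)  # 相当于未缩放的量级

print(f"缩放后分布: {p_scaled.tolist()}")  # 各位置都有可观的概率和梯度
print(f"未缩放分布: {p_raw.tolist()}")     # 几乎是 one-hot,梯度接近 0
```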

1.2 完整实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    缩放点积注意力

    Args:
        Q: [batch, num_heads, seq_len_q, d_k]
        K: [batch, num_heads, seq_len_k, d_k]
        V: [batch, num_heads, seq_len_k, d_v]
        mask: [batch, 1, 1, seq_len_k] 或 [batch, 1, seq_len_q, seq_len_k]

    Returns:
        output: [batch, num_heads, seq_len_q, d_v]
        attn_weights: [batch, num_heads, seq_len_q, seq_len_k]
    """
    d_k = Q.size(-1)

    # Step 1: Q @ K^T -> [batch, heads, seq_q, seq_k]
    scores = torch.matmul(Q, K.transpose(-2, -1))

    # Step 2: 缩放
    scores = scores / math.sqrt(d_k)

    # Step 3: 掩码(将pad位置或未来位置设为-inf)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: softmax归一化
    attn_weights = F.softmax(scores, dim=-1)

    # Step 5: 加权求和
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

1.3 形状验证

Python
# 验证
batch, heads, seq_q, seq_k, d_k, d_v = 2, 8, 10, 10, 64, 64
Q = torch.randn(batch, heads, seq_q, d_k)
K = torch.randn(batch, heads, seq_k, d_k)
V = torch.randn(batch, heads, seq_k, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)
assert output.shape == (2, 8, 10, 64), f"输出形状错误: {output.shape}"  # assert断言:条件False时抛出AssertionError
assert weights.shape == (2, 8, 10, 10), f"权重形状错误: {weights.shape}"

# 验证softmax归一化
assert torch.allclose(weights.sum(dim=-1), torch.ones(2, 8, 10), atol=1e-6)
print("✅ scaled_dot_product_attention 验证通过")
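作为补充,可以单独验证 Step 3 的掩码逻辑:被屏蔽位置设为 `-inf` 后,softmax 输出的权重严格为 0,且每行仍归一化为 1。下面是一个不依赖上文函数、独立可运行的最小示例:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(1, 1, 4, 4)                        # [batch, heads, seq_q, seq_k]
causal = torch.tril(torch.ones(4, 4)).view(1, 1, 4, 4)  # 下三角因果掩码

masked = scores.masked_fill(causal == 0, float('-inf')) # 未来位置设为 -inf
weights = F.softmax(masked, dim=-1)

print(weights[0, 0])
# 上三角位置的权重全为 0;每一行权重之和仍为 1
```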

2. 多头注意力

2.1 核心思想

将 \(d_{model}\) 维空间拆分为 \(h\) 个子空间,每个子空间独立计算注意力:

\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O \]
\[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]

2.2 完整实现

Python
class MultiHeadAttention(nn.Module):
    """多头注意力机制"""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()  # super()调用父类方法
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        # 四个线性投影层
        self.W_Q = nn.Linear(d_model, d_model)  # [d_model, d_model]
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)  # 输出投影

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch, seq_len_q, d_model]
            key:   [batch, seq_len_k, d_model]
            value: [batch, seq_len_k, d_model]
            mask:  [batch, 1, 1, seq_len_k] 或 [batch, 1, seq_q, seq_k]

        Returns:
            output: [batch, seq_len_q, d_model]
        """
        batch_size = query.size(0)

        # Step 1: 线性投影
        Q = self.W_Q(query)  # [batch, seq_q, d_model]
        K = self.W_K(key)    # [batch, seq_k, d_model]
        V = self.W_V(value)  # [batch, seq_k, d_model]

        # Step 2: 分头 (reshape + transpose)
        # [batch, seq, d_model] -> [batch, seq, heads, d_k] -> [batch, heads, seq, d_k]
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # view重塑张量形状
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Step 3: 缩放点积注意力
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Step 4: 拼接 (transpose + reshape)
        # [batch, heads, seq, d_k] -> [batch, seq, heads, d_k] -> [batch, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        # Step 5: 输出投影
        output = self.W_O(attn_output)

        return output

2.3 形状验证

Python
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out = mha(x, x, x)
assert out.shape == (2, 10, 512)
print(f"✅ MultiHeadAttention 验证通过 | 参数量: {sum(p.numel() for p in mha.parameters()):,}")
# 预期参数量: 4 * (512*512 + 512) = 1,050,624
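上面注释中的参数量可以手动验算:四个 nn.Linear 各含一个 512×512 的权重矩阵和一个 512 维偏置:

```python
d_model = 512
per_linear = d_model * d_model + d_model  # 权重 + 偏置
total = 4 * per_linear                    # W_Q, W_K, W_V, W_O
print(f"{total:,}")  # 1,050,624
```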

3. 位置编码

3.1 数学公式

\[ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \]
\[ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \]

3.2 完整实现

Python
class PositionalEncoding(nn.Module):
    """正弦位置编码"""

    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 创建位置编码矩阵 [max_seq_len, d_model]
        pe = torch.zeros(max_seq_len, d_model)

        # position: [max_seq_len, 1]
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)  # unsqueeze增加一个维度

        # div_term: [d_model/2]
        # 10000^(2i/d_model) = exp(2i * (-ln(10000)/d_model))
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # 偶数位用sin,奇数位用cos
        pe[:, 0::2] = torch.sin(position * div_term)  # [max_seq_len, d_model/2]
        pe[:, 1::2] = torch.cos(position * div_term)

        # 增加batch维度并注册为buffer(不参与梯度计算)
        pe = pe.unsqueeze(0)  # [1, max_seq_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            [batch, seq_len, d_model]
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

3.3 形状验证

Python
pe = PositionalEncoding(d_model=512)
x = torch.randn(2, 30, 512)
out = pe(x)
assert out.shape == (2, 30, 512)
# 验证不同位置的编码不同
assert not torch.allclose(pe.pe[0, 0], pe.pe[0, 1])
print("✅ PositionalEncoding 验证通过")
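补充说明:代码中 div_term 用 exp-log 技巧计算(数值上更稳定),与公式里直接的幂运算形式等价,可以快速验证:

```python
import math

import torch

d_model = 512
i = torch.arange(0, d_model, 2).float()  # 对应公式中的 2i

div_term_exp = torch.exp(i * (-math.log(10000.0) / d_model))  # 代码中的写法
div_term_pow = 1.0 / (10000.0 ** (i / d_model))               # 公式的直接形式

assert torch.allclose(div_term_exp, div_term_pow)
print("两种写法等价")
```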

4. 前馈网络

4.1 完整实现

\[ \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 \]
Python
class FeedForward(nn.Module):
    """位置前馈网络: d_model -> d_ff -> d_model"""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            [batch, seq_len, d_model]
        """
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

5. Encoder

5.1 单层 Encoder

Python
class EncoderLayer(nn.Module):
    """
    Transformer Encoder层

    结构(Pre-LN变体,更稳定):
        x -> LayerNorm -> MultiHeadAttention -> + residual
          -> LayerNorm -> FeedForward          -> + residual
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        """
        Args:
            x: [batch, seq_len, d_model]
            src_mask: [batch, 1, 1, seq_len]
        """
        # 子层1: 多头自注意力 + 残差
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, x, x, mask=src_mask)
        x = self.dropout1(x) + residual

        # 子层2: 前馈网络 + 残差
        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = self.dropout2(x) + residual

        return x
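Pre-LN 与 Post-LN(原论文)的区别仅在 LayerNorm 的位置。下面用一个最小示例并排展示两种写法,子层用单个 Linear 示意(这只是结构对比的草图,不是完整实现):

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(8)
sublayer = nn.Linear(8, 8)   # 用一个线性层示意注意力/FFN 子层
x = torch.randn(2, 4, 8)

post_ln = norm(x + sublayer(x))  # Post-LN(原论文): 先残差再归一化
pre_ln = x + sublayer(norm(x))   # Pre-LN(本文实现): 先归一化再残差,残差通路保持恒等

print(post_ln.shape, pre_ln.shape)  # 两者输出形状相同
```

Pre-LN 的残差通路不经过任何归一化,梯度可以无衰减地直达底层,这也是它在深层网络中更稳定的直观原因。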

5.2 完整 Encoder

Python
class Encoder(nn.Module):
    """Transformer Encoder: N个EncoderLayer堆叠"""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)  # 最终LayerNorm

    def forward(self, x, src_mask=None):
        for layer in self.layers:
            x = layer(x, src_mask)
        return self.norm(x)

6. Decoder

6.1 单层 Decoder

Python
class DecoderLayer(nn.Module):
    """
    Transformer Decoder层

    结构:
        x -> LayerNorm -> MaskedMultiHeadAttention -> + residual
          -> LayerNorm -> CrossAttention(Q=x, KV=encoder_output) -> + residual
          -> LayerNorm -> FeedForward -> + residual
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: [batch, tgt_seq_len, d_model]
            enc_output: [batch, src_seq_len, d_model]
            src_mask: [batch, 1, 1, src_seq_len]
            tgt_mask: [batch, 1, tgt_seq_len, tgt_seq_len]
        """
        # 子层1: 掩码多头自注意力
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, x, x, mask=tgt_mask)
        x = self.dropout1(x) + residual

        # 子层2: 交叉注意力(Q来自decoder,KV来自encoder)
        residual = x
        x = self.norm2(x)
        x = self.cross_attn(x, enc_output, enc_output, mask=src_mask)
        x = self.dropout2(x) + residual

        # 子层3: 前馈网络
        residual = x
        x = self.norm3(x)
        x = self.feed_forward(x)
        x = self.dropout3(x) + residual

        return x

6.2 完整 Decoder

Python
class Decoder(nn.Module):
    """Transformer Decoder: N个DecoderLayer堆叠"""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return self.norm(x)

7. 完整 Transformer

7.1 Mask 生成

Python
def make_pad_mask(seq, pad_idx=0):
    """
    生成padding掩码

    Args:
        seq: [batch, seq_len] token索引
        pad_idx: padding的token索引

    Returns:
        mask: [batch, 1, 1, seq_len]  (True=保留,False=屏蔽)
    """
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)

def make_causal_mask(seq_len, device):
    """
    生成因果掩码(下三角矩阵),防止看到未来信息

    Args:
        seq_len: 序列长度
        device: 设备

    Returns:
        mask: [1, 1, seq_len, seq_len]
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

def make_tgt_mask(tgt, pad_idx=0):
    """
    生成目标序列掩码: padding掩码 AND 因果掩码

    Args:
        tgt: [batch, tgt_seq_len]

    Returns:
        mask: [batch, 1, tgt_seq_len, tgt_seq_len]
    """
    batch_size, tgt_len = tgt.size()

    # padding掩码: [batch, 1, 1, tgt_len]
    pad_mask = make_pad_mask(tgt, pad_idx)

    # 因果掩码: [1, 1, tgt_len, tgt_len]
    causal_mask = make_causal_mask(tgt_len, tgt.device)

    # 合并: broadcast -> [batch, 1, tgt_len, tgt_len]
    return pad_mask & causal_mask.bool()
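下面用一个含 PAD 的玩具序列直观展示 padding 掩码与因果掩码的合并效果(独立可运行的最小示例,逻辑与上面的函数一致但不复用其定义):

```python
import torch

tgt = torch.tensor([[5, 7, 9, 0]])   # 最后一位是 PAD(pad_idx=0)

pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)  # [1, 1, 1, 4]
causal = torch.tril(torch.ones(4, 4)).bool()     # [4, 4] 下三角

tgt_mask = pad_mask & causal                     # 广播 -> [1, 1, 4, 4]
print(tgt_mask[0, 0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]])   <- 最后一列被 padding 掩码清零
```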

7.2 完整模型

Python
class Transformer(nn.Module):
    """
    完整的Transformer (Encoder-Decoder架构)

    用于序列到序列任务(如翻译、Copy Task)
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
                 num_heads=8, num_encoder_layers=6, num_decoder_layers=6,
                 d_ff=2048, max_seq_len=5000, dropout=0.1, pad_idx=0):
        super().__init__()

        self.pad_idx = pad_idx
        self.d_model = d_model

        # Embedding层
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # 位置编码
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)

        # Encoder & Decoder
        self.encoder = Encoder(num_encoder_layers, d_model, num_heads, d_ff, dropout)
        self.decoder = Decoder(num_decoder_layers, d_model, num_heads, d_ff, dropout)

        # 输出投影层
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        # 初始化参数
        self._init_parameters()

    def _init_parameters(self):
        """Xavier初始化"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # 权重共享(Weight Tying):输出投影层与目标嵌入层共享权重
        # 这是 GPT-2/LLaMA 等模型的标准做法,可以减少参数量并提升性能
        # 原理:nn.Linear 的权重形状为 [out_features, in_features] = [vocab_size, d_model],
        #        与嵌入矩阵的形状 [vocab_size, d_model] 相同,可以直接共享同一份参数
        if self.src_embedding.num_embeddings == self.tgt_embedding.num_embeddings:
            self.output_projection.weight = self.tgt_embedding.weight

    def encode(self, src):
        """
        编码源序列

        Args:
            src: [batch, src_len]
        Returns:
            enc_output: [batch, src_len, d_model]
            src_mask: [batch, 1, 1, src_len]
        """
        src_mask = make_pad_mask(src, self.pad_idx)

        # Embedding * sqrt(d_model) + PE
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
        src_emb = self.positional_encoding(src_emb)

        enc_output = self.encoder(src_emb, src_mask)
        return enc_output, src_mask

    def decode(self, tgt, enc_output, src_mask):
        """
        解码目标序列

        Args:
            tgt: [batch, tgt_len]
            enc_output: [batch, src_len, d_model]
            src_mask: [batch, 1, 1, src_len]
        Returns:
            logits: [batch, tgt_len, tgt_vocab_size]
        """
        tgt_mask = make_tgt_mask(tgt, self.pad_idx)

        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.positional_encoding(tgt_emb)

        dec_output = self.decoder(tgt_emb, enc_output, src_mask, tgt_mask)
        logits = self.output_projection(dec_output)
        return logits

    def forward(self, src, tgt):
        """
        前向传播

        Args:
            src: [batch, src_len]   源序列token索引
            tgt: [batch, tgt_len]   目标序列token索引
        Returns:
            logits: [batch, tgt_len, tgt_vocab_size]
        """
        enc_output, src_mask = self.encode(src)
        logits = self.decode(tgt, enc_output, src_mask)
        return logits

7.3 形状验证

Python
# 构建一个小型Transformer
model = Transformer(
    src_vocab_size=100,
    tgt_vocab_size=100,
    d_model=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    d_ff=128,
    dropout=0.1
)

# 模拟输入
src = torch.randint(1, 100, (2, 10))   # batch=2, src_len=10
tgt = torch.randint(1, 100, (2, 8))    # batch=2, tgt_len=8

logits = model(src, tgt)
assert logits.shape == (2, 8, 100), f"输出形状错误: {logits.shape}"

total_params = sum(p.numel() for p in model.parameters())
print(f"✅ Transformer 验证通过 | 参数量: {total_params:,}")

8. 训练: Copy Task

Copy Task 是验证 Transformer 正确性的经典测试——模型只需学会复制输入序列。如果模型连 Copy 都学不会,说明实现有 bug。

8.1 数据集

Python
from torch.utils.data import Dataset, DataLoader

class CopyDataset(Dataset):
    """
    Copy Task: 输入 [1,3,5,7] -> 输出 [1,3,5,7]

    src: [BOS, x1, x2, ..., xn]
    tgt: [BOS, x1, x2, ..., xn, EOS]
    """

    def __init__(self, num_samples=10000, seq_len=10, vocab_size=10):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        # 特殊token: 0=PAD, 1=BOS, 2=EOS, 3~vocab_size+2=普通token
        self.pad_idx = 0
        self.bos_idx = 1
        self.eos_idx = 2
        self.total_vocab = vocab_size + 3

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 随机生成序列(使用3~vocab_size+2的token)
        seq = torch.randint(3, self.total_vocab, (self.seq_len,))

        # src: [BOS, x1, x2, ..., xn]
        src = torch.cat([torch.tensor([self.bos_idx]), seq])

        # tgt_input:  [BOS, x1, x2, ..., xn]     (decoder输入)
        # tgt_output: [x1, x2, ..., xn, EOS]      (标签)
        tgt_input = torch.cat([torch.tensor([self.bos_idx]), seq])
        tgt_output = torch.cat([seq, torch.tensor([self.eos_idx])])

        return src, tgt_input, tgt_output

8.2 完整训练循环

Python
def train_copy_task():
    """训练Copy Task来验证Transformer实现的正确性"""

    # 超参数
    VOCAB_SIZE = 10
    SEQ_LEN = 10
    D_MODEL = 64
    NUM_HEADS = 4
    NUM_LAYERS = 2
    D_FF = 128
    BATCH_SIZE = 64
    NUM_EPOCHS = 20
    LR = 1e-3
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 数据
    dataset = CopyDataset(num_samples=10000, seq_len=SEQ_LEN, vocab_size=VOCAB_SIZE)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # 模型
    model = Transformer(
        src_vocab_size=dataset.total_vocab,
        tgt_vocab_size=dataset.total_vocab,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_encoder_layers=NUM_LAYERS,
        num_decoder_layers=NUM_LAYERS,
        d_ff=D_FF,
        pad_idx=0
    ).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略PAD

    total_params = sum(p.numel() for p in model.parameters())
    print(f"设备: {DEVICE} | 参数量: {total_params:,}")

    # 训练
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for src, tgt_input, tgt_output in dataloader:
            src = src.to(DEVICE)
            tgt_input = tgt_input.to(DEVICE)
            tgt_output = tgt_output.to(DEVICE)

            # 前向传播
            logits = model(src, tgt_input)  # [batch, tgt_len, vocab]

            # 计算损失
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),  # [batch*tgt_len, vocab]
                tgt_output.reshape(-1)                 # [batch*tgt_len]
            )

            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

            # 计算准确率
            preds = logits.argmax(dim=-1)  # [batch, tgt_len]
            mask = tgt_output != 0
            correct += (preds[mask] == tgt_output[mask]).sum().item()
            total += mask.sum().item()

        avg_loss = total_loss / len(dataloader)
        accuracy = correct / total * 100
        print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | Loss: {avg_loss:.4f} | Acc: {accuracy:.1f}%")

        # 提前达到高准确率就停止
        if accuracy > 99.5:
            print(f"🎉 在第{epoch+1}轮达到 {accuracy:.1f}% 准确率!")
            break

    # 测试生成
    print("\n--- 生成测试 ---")
    model.eval()
    test_src = torch.randint(3, dataset.total_vocab, (1, SEQ_LEN + 1))
    test_src[0, 0] = dataset.bos_idx
    test_src = test_src.to(DEVICE)

    # 自回归生成
    generated = [dataset.bos_idx]
    for _ in range(SEQ_LEN + 1):
        tgt_tensor = torch.tensor([generated], device=DEVICE)
        with torch.no_grad():  # 禁用梯度计算,节省内存(推理时使用)
            logits = model(test_src, tgt_tensor)
        next_token = logits[0, -1].argmax().item()
        generated.append(next_token)
        if next_token == dataset.eos_idx:
            break

    print(f"输入:  {test_src[0].tolist()}")
    print(f"生成:  {generated}")
    print(f"期望:  {test_src[0, 1:].tolist()} + [EOS={dataset.eos_idx}]")

    return model

# 运行训练
if __name__ == "__main__":
    model = train_copy_task()

8.3 预期输出

Text Only
设备: cpu | 参数量: 217,805
Epoch  1/20 | Loss: 2.4312 | Acc: 12.5%
Epoch  2/20 | Loss: 1.8023 | Acc: 35.2%
Epoch  3/20 | Loss: 0.9847 | Acc: 68.7%
...
Epoch  8/20 | Loss: 0.0234 | Acc: 99.2%
Epoch  9/20 | Loss: 0.0089 | Acc: 99.8%
🎉 在第9轮达到 99.8% 准确率!

--- 生成测试 ---
输入:  [1, 5, 8, 3, 12, 7, 4, 9, 6, 11, 10]
生成:  [1, 5, 8, 3, 12, 7, 4, 9, 6, 11, 10, 2]
期望:  [5, 8, 3, 12, 7, 4, 9, 6, 11, 10] + [EOS=2]

如果你的模型能在 10-15 个 epoch 内达到 99%+ 准确率,说明实现是正确的。如果学不会,检查以下常见错误:

  1. mask 生成是否正确(因果掩码是否正确遮挡了未来位置)
  2. 是否忘记了 * math.sqrt(d_model) 缩放 embedding
  3. 残差连接的位置是否正确


9. 训练:字符级语言模型

Copy Task 验证了模型正确性后,我们来做一个更有趣的任务——字符级语言模型。这里我们使用 Decoder-Only 架构(类似 GPT)。

9.1 Decoder-Only 模型

Python
class DecoderOnlyTransformer(nn.Module):
    """
    GPT风格的Decoder-Only模型

    与Encoder-Decoder不同,这里只有Decoder,
    且没有交叉注意力(因为没有Encoder提供上下文)。
    """

    def __init__(self, vocab_size, d_model=256, num_heads=8,
                 num_layers=4, d_ff=512, max_seq_len=512, dropout=0.1):
        super().__init__()

        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)

        # Decoder层(只有自注意力,没有交叉注意力)
        self.layers = nn.ModuleList([
            DecoderOnlyLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, vocab_size)

        # 初始化
        self._init_params()

    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # 权重共享(Weight Tying):GPT-2 的关键技巧
        # 输出投影层与嵌入层共享权重,减少参数量并提升泛化
        self.output_proj.weight = self.embedding.weight

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len] token索引
        Returns:
            logits: [batch, seq_len, vocab_size]
        """
        seq_len = x.size(1)

        # 因果掩码
        causal_mask = make_causal_mask(seq_len, x.device)

        # Embedding + PE
        h = self.embedding(x) * math.sqrt(self.d_model)
        h = self.pos_encoding(h)

        # N层Decoder
        for layer in self.layers:
            h = layer(h, causal_mask)

        h = self.norm(h)
        logits = self.output_proj(h)
        return logits

    @torch.no_grad()  # 禁用梯度计算,节省内存(推理时使用)
    def generate(self, start_tokens, max_new_tokens=100, temperature=1.0):
        """
        自回归文本生成

        Args:
            start_tokens: [1, start_len] 起始token
            max_new_tokens: 最大新生成token数
            temperature: 采样温度
        """
        self.eval()
        tokens = start_tokens.clone()

        for _ in range(max_new_tokens):
            logits = self(tokens)             # [1, current_len, vocab]
            next_logits = logits[0, -1] / temperature  # [vocab]
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).unsqueeze(0)  # [1, 1]
            tokens = torch.cat([tokens, next_token], dim=1)

        return tokens

class DecoderOnlyLayer(nn.Module):
    """Decoder-Only层:只有因果自注意力 + FFN"""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, causal_mask):
        # 自注意力 + 残差
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, x, x, mask=causal_mask)
        x = self.dropout1(x) + residual

        # FFN + 残差
        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = self.dropout2(x) + residual

        return x
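generate() 中的 temperature 参数控制采样分布的尖锐程度。下面用一组固定的 logits 独立演示这一效果:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])

for t in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / t, dim=-1)
    print(f"temperature={t}: max 概率={probs.max():.3f}")
# 温度越低,分布越尖锐(接近贪心解码);温度越高,分布越平坦(采样更随机)
```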

9.2 字符级数据集与训练

Python
class CharDataset(Dataset):
    """字符级数据集"""

    def __init__(self, text, seq_len):
        self.seq_len = seq_len
        self.chars = sorted(list(set(text)))
        self.vocab_size = len(self.chars)
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
        self.data = [self.char_to_idx[ch] for ch in text]

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_len]
        y = self.data[idx + 1:idx + self.seq_len + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

    def encode(self, text):
        return [self.char_to_idx[ch] for ch in text if ch in self.char_to_idx]

    def decode(self, indices):
        return ''.join(self.idx_to_char.get(i, '?') for i in indices)

def train_char_lm():
    """训练字符级语言模型"""

    # --- 数据 ---
    text = """
    Twinkle, twinkle, little star,
    How I wonder what you are!
    Up above the world so high,
    Like a diamond in the sky.
    Twinkle, twinkle, little star,
    How I wonder what you are!
    When the blazing sun is gone,
    When he nothing shines upon,
    Then you show your little light,
    Twinkle, twinkle, all the night.
    Twinkle, twinkle, little star,
    How I wonder what you are!
    """

    SEQ_LEN = 64
    BATCH_SIZE = 32
    NUM_EPOCHS = 100
    LR = 3e-4
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = CharDataset(text, SEQ_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    print(f"词表大小: {dataset.vocab_size} | 字符: {''.join(dataset.chars)}")
    print(f"训练样本数: {len(dataset)}")

    # --- 模型 ---
    model = DecoderOnlyTransformer(
        vocab_size=dataset.vocab_size,
        d_model=64,
        num_heads=4,
        num_layers=2,
        d_ff=128,
        max_seq_len=256,  # 需覆盖生成时的最大长度(起始 token 数 + max_new_tokens),用 SEQ_LEN+1 会在生成时越界
        dropout=0.1
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"参数量: {total_params:,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    # --- 训练 ---
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0

        for x, y in dataloader:
            x, y = x.to(DEVICE), y.to(DEVICE)

            logits = model(x)
            loss = criterion(logits.view(-1, dataset.vocab_size), y.view(-1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1:3d}/{NUM_EPOCHS} | Loss: {avg_loss:.4f}")

            # 生成样例
            start = "Twinkle"
            start_ids = torch.tensor([dataset.encode(start)], device=DEVICE)
            generated_ids = model.generate(start_ids, max_new_tokens=80, temperature=0.8)
            generated_text = dataset.decode(generated_ids[0].tolist())
            print(f"  生成: {generated_text[:100]}")

    return model, dataset

if __name__ == "__main__":
    model, dataset = train_char_lm()

9.3 预期输出

Text Only
词表大小: 36 | 字符: !\n ,.HILTUWabdeghiklnorstuwy
训练样本数: 389
参数量: 115,236
Epoch  20/100 | Loss: 2.1432
  生成: Twinkle, twinkle, lithen yhe sond so
Epoch  40/100 | Loss: 1.5234
  生成: Twinkle, twinkle, little star, how I wonder wh
Epoch  60/100 | Loss: 1.0891
  生成: Twinkle, twinkle, little star,\nHow I wonder what you are!
Epoch 100/100 | Loss: 0.7234
  生成: Twinkle, twinkle, little star,\nHow I wonder what you are!\nUp above the world so high,

10. 完整代码汇总与调试指南

10.1 如何使用本教程的代码

本教程的所有代码段可以按顺序组合成一个完整的 Python 文件。组合顺序如下:

Text Only
代码文件结构(按本文顺序组合):
┌─────────────────────────────────────────────────────────────┐
│ 1. imports(torch, math, F, nn, Dataset, DataLoader)       │
│ 2. scaled_dot_product_attention()     ← 第1节               │
│ 3. MultiHeadAttention                 ← 第2节               │
│ 4. PositionalEncoding                 ← 第3节               │
│ 5. FeedForward                        ← 第4节               │
│ 6. EncoderLayer + Encoder             ← 第5节               │
│ 7. DecoderLayer + Decoder             ← 第6节               │
│ 8. make_pad_mask / make_causal_mask / make_tgt_mask          │
│                                       ← 第7节               │
│ 9. Transformer                        ← 第7节               │
│ 10. DecoderOnlyLayer + DecoderOnlyTransformer                │
│                                       ← 第9节               │
│ 11. CopyDataset + train_copy_task     ← 第8节               │
│ 12. CharDataset + train_char_lm       ← 第9节               │
│ 13. if __name__ == "__main__" 入口    ← 运行入口            │
└─────────────────────────────────────────────────────────────┘

将所有代码段按顺序复制到一个 .py 文件中,即可直接运行:

Bash
python my_transformer.py copy   # 运行 Copy Task 验证
python my_transformer.py lm     # 运行字符级语言模型
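上面的命令行用法需要一个读取 sys.argv 的入口,而文中两个 `if __name__` 块并未处理参数。一个最小的分发入口可以这样写(示意,占位 lambda 需替换为前文的 train_copy_task / train_char_lm):

```python
import sys

def main(argv):
    # 占位实现:在完整文件中替换为 train_copy_task / train_char_lm
    tasks = {
        "copy": lambda: "copy task",
        "lm": lambda: "char lm",
    }
    name = argv[1] if len(argv) > 1 else "copy"  # 默认运行 Copy Task
    if name not in tasks:
        raise SystemExit(f"未知任务: {name},可选: {sorted(tasks)}")
    return tasks[name]()

if __name__ == "__main__":
    main(sys.argv)
```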

10.2 常见 Bug 与调试清单

实现 Transformer 时最容易出错的地方:

Text Only
常见 Bug 清单:
┌───┬──────────────────────────────────────────────────────────┐
│ # │ Bug 描述                              │ 症状             │
├───┼──────────────────────────────────────────────────────────┤
│ 1 │ 因果掩码方向错误                        │ 模型能"偷看"未来  │
│   │ (用了上三角而非下三角)                  │ 训练 loss 不降   │
├───┼──────────────────────────────────────────────────────────┤
│ 2 │ 忘记 * sqrt(d_model) 缩放              │ 训练不稳定       │
│   │                                        │ 收敛极慢         │
├───┼──────────────────────────────────────────────────────────┤
│ 3 │ view/reshape 维度顺序错误               │ 输出形状不对     │
│   │ (分头时 batch 和 heads 维度搞反)       │ 或运行时报错     │
├───┼──────────────────────────────────────────────────────────┤
│ 4 │ 残差连接位置错误                        │ 深层网络无法训练  │
│   │ (Post-LN vs Pre-LN 搞混)             │ 梯度消失         │
├───┼──────────────────────────────────────────────────────────┤
│ 5 │ contiguous() 缺失                      │ view() 报错      │
│   │ (transpose 后必须 contiguous 才能 view)│ RuntimeError    │
├───┼──────────────────────────────────────────────────────────┤
│ 6 │ mask 值类型不匹配                       │ 注意力权重异常   │
│   │ (float mask vs bool mask)             │ 或 NaN 出现     │
└───┴──────────────────────────────────────────────────────────┘
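针对表中第 5 条,可以亲手复现 transpose 之后 view 失败的现象:

```python
import torch

x = torch.randn(2, 3, 4)
t = x.transpose(1, 2)   # [2, 4, 3],内存布局不再连续

try:
    t.view(2, 12)       # 非连续张量无法 view
except RuntimeError as e:
    print(f"view 失败: {e}")

y = t.contiguous().view(2, 12)  # 先 contiguous 复制为连续内存,再 view
print(y.shape)                  # torch.Size([2, 12])
```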

10.3 验证代码正确性的方法

Python
def verify_implementation():
    """快速验证 Transformer 实现的正确性"""

    # 测试1: 形状正确性
    model = Transformer(src_vocab_size=50, tgt_vocab_size=50,
                        d_model=64, num_heads=4, num_encoder_layers=2,
                        num_decoder_layers=2, d_ff=128)
    src = torch.randint(1, 50, (2, 8))
    tgt = torch.randint(1, 50, (2, 6))
    logits = model(src, tgt)
    assert logits.shape == (2, 6, 50), f"形状错误: {logits.shape}"
    print("✅ 测试1通过: 输出形状正确")

    # 测试2: 因果掩码有效性
    # 验证位置 i 的输出不依赖于位置 i+1 的输入
    # 注意: 必须先切换到 eval 模式关闭 dropout,否则两次前向的结果会随机不同
    model.eval()
    tgt_test = torch.randint(1, 50, (1, 5))
    with torch.no_grad():
        out1 = model(src[:1], tgt_test)
        # 修改最后一个 token,前面的输出不应改变
        tgt_modified = tgt_test.clone()
        tgt_modified[0, -1] = 0
        out2 = model(src[:1], tgt_modified)
    # 前面位置的输出应该相同(因果掩码保证)
    assert torch.allclose(out1[0, :-1], out2[0, :-1], atol=1e-5), "因果掩码可能有问题!"
    print("✅ 测试2通过: 因果掩码有效")

    # 测试3: 梯度流通性
    model.train()
    logits = model(src, tgt)
    loss = logits.sum()
    loss.backward()
    # 检查所有参数都有梯度
    for name, param in model.named_parameters():
        assert param.grad is not None, f"参数 {name} 没有梯度!"
    print("✅ 测试3通过: 梯度流通正常")

    # 测试4: 权重共享验证
    # 检查输出投影层和嵌入层是否共享权重
    assert model.output_projection.weight.data_ptr() == model.tgt_embedding.weight.data_ptr()
    print("✅ 测试4通过: 权重共享已启用")

    print("\n🎉 所有验证测试通过!实现正确。")

verify_implementation()

📝 自测检查清单

  • 能解释缩放因子 \(\sqrt{d_k}\) 的作用(防止 softmax 饱和)
  • 能手绘多头注意力的分头-计算-拼接流程
  • Copy Task 在 10 个 epoch 内达到 99%+准确率
  • 字符级 LM 在 100 个 epoch 后能生成有意义的文本
  • 能独立调试 mask 相关的 bug(最常见的错误来源)
  • 理解 Pre-LN vs Post-LN 的区别和原因

思考题

  1. 权重共享的影响:本实现中使用了 Weight Tying(输出投影层与嵌入层共享权重)。如果去掉权重共享,模型的参数量会如何变化?对训练效果有什么影响?尝试修改代码去掉权重共享,对比训练曲线。

  2. 字符级 vs 子词级:本教程使用字符级语言模型。如果改为 BPE 子词分词,词表大小从 ~40 增加到 ~32000,模型需要哪些调整?(提示:embedding 层、输出层、序列长度、学习率)

  3. Decoder-Only vs Encoder-Decoder:在 Copy Task 中,Encoder-Decoder 和 Decoder-Only 哪个更容易学习?为什么?尝试用 Decoder-Only 模型做 Copy Task,对比收敛速度。

  4. 训练稳定性:如果将模型深度从 2 层增加到 12 层,不做任何其他修改,训练还能稳定吗?需要哪些额外措施?(提示:学习率 warm-up、梯度裁剪、残差连接的缩放)


📚 参考

  1. Vaswani et al., "Attention Is All You Need" (2017)
  2. Andrej Karpathy, "Let's build GPT: from scratch, in code" (2023)
  3. Harvard NLP, "The Annotated Transformer" (2018)

最后更新日期: 2026-03-26 适用版本: LLM 学习教程 v2026