手写 Transformer 完整实现¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
📌 本章定位:这是一份从零实现完整 Transformer 的全代码教程。每个组件都有数学推导 → 代码实现 → 形状验证三步走,确保你真正理解每一行代码。对标 happy-LLM 的"从零训练模型",我们的实现更深入、更贴近工程实际。
🔗 配套文件:理论详解请参见 01-Transformer 深入理解,注意力机制数学推导请参见 02-注意力机制详解。
目录¶
| 节 | 内容 | 关键代码 |
|---|---|---|
| 1 | 缩放点积注意力 | scaled_dot_product_attention() |
| 2 | 多头注意力 | MultiHeadAttention |
| 3 | 位置编码 | PositionalEncoding |
| 4 | 前馈网络 | FeedForward |
| 5 | Encoder | EncoderLayer / Encoder |
| 6 | Decoder | DecoderLayer / Decoder |
| 7 | 完整 Transformer | Transformer |
| 8 | 训练: Copy Task | 验证模型正确性 |
| 9 | 训练:字符级语言模型 | DecoderOnlyTransformer / generate() |
| 10 | 完整代码汇总与调试指南 | verify_implementation() |
1. 缩放点积注意力¶
1.1 数学回顾¶
- \(Q \in \mathbb{R}^{n \times d_k}\)(查询)
- \(K \in \mathbb{R}^{m \times d_k}\)(键)
- \(V \in \mathbb{R}^{m \times d_v}\)(值)
- 缩放因子 \(\sqrt{d_k}\) 防止点积过大导致 softmax 饱和
1.2 完整实现¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: [batch, num_heads, seq_len_q, d_k] queries
        K: [batch, num_heads, seq_len_k, d_k] keys
        V: [batch, num_heads, seq_len_k, d_v] values
        mask: [batch, 1, 1, seq_len_k] or [batch, 1, seq_len_q, seq_len_k];
              positions where mask == 0 receive zero attention.

    Returns:
        output: [batch, num_heads, seq_len_q, d_v]
        attn_weights: [batch, num_heads, seq_len_q, seq_len_k]
    """
    # Raw similarity scores, scaled by sqrt(d_k) to keep softmax out of
    # its saturated regime for large key dimensions.
    scale = math.sqrt(Q.size(-1))
    scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
    # Blocked positions (pad or future tokens) get -inf so softmax assigns
    # them zero probability mass.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Normalize over keys, then take the probability-weighted sum of values.
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
1.3 形状验证¶
# Smoke test: shape and normalization checks with random inputs.
batch, heads, seq_q, seq_k, d_k, d_v = 2, 8, 10, 10, 64, 64
Q = torch.randn(batch, heads, seq_q, d_k)
K = torch.randn(batch, heads, seq_k, d_k)
V = torch.randn(batch, heads, seq_k, d_v)
output, weights = scaled_dot_product_attention(Q, K, V)
assert output.shape == (2, 8, 10, 64), f"输出形状错误: {output.shape}"  # assert raises AssertionError when the condition is False
assert weights.shape == (2, 8, 10, 10), f"权重形状错误: {weights.shape}"
# Every attention row must sum to 1 after softmax.
assert torch.allclose(weights.sum(dim=-1), torch.ones(2, 8, 10), atol=1e-6)
print("✅ scaled_dot_product_attention 验证通过")
2. 多头注意力¶
2.1 核心思想¶
将 \(d_{model}\) 维空间拆分为 \(h\) 个子空间,每个子空间独立计算注意力:
2.2 完整实现¶
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.

    Projects queries/keys/values into `num_heads` subspaces of size
    d_k = d_model // num_heads, runs scaled dot-product attention in each
    head independently, then concatenates the heads and projects back to
    d_model.

    Fix: the original version created `self.dropout` but never applied it.
    Dropout is now applied to the attention weights (between softmax and
    the weighted sum), as in the original Transformer paper; eval-mode
    outputs are unchanged.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension
        # Q/K/V projections plus the final output projection, all d_model -> d_model.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch, seq_len_q, d_model]
            key:   [batch, seq_len_k, d_model]
            value: [batch, seq_len_k, d_model]
            mask:  [batch, 1, 1, seq_len_k] or [batch, 1, seq_q, seq_k];
                   positions where mask == 0 are ignored.

        Returns:
            output: [batch, seq_len_q, d_model]
        """
        batch_size = query.size(0)
        # Step 1+2: project and split into heads:
        # [batch, seq, d_model] -> [batch, seq, heads, d_k] -> [batch, heads, seq, d_k]
        Q = self.W_Q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Step 3: scaled dot-product attention, inlined so attention-weight
        # dropout can be applied between softmax and the weighted sum.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = self.dropout(F.softmax(scores, dim=-1))
        attn_output = torch.matmul(attn_weights, V)
        # Step 4: merge heads back:
        # [batch, heads, seq, d_k] -> [batch, seq, heads, d_k] -> [batch, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # Step 5: output projection.
        return self.W_O(attn_output)
2.3 形状验证¶
# Smoke test: self-attention (query = key = value) preserves the input shape.
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out = mha(x, x, x)
assert out.shape == (2, 10, 512)
print(f"✅ MultiHeadAttention 验证通过 | 参数量: {sum(p.numel() for p in mha.parameters()):,}")
# Expected parameter count: four Linear layers (weight + bias) = 4 * (512*512 + 512) = 1,050,624
3. 位置编码¶
3.1 数学公式¶
3.2 完整实现¶
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding (Vaswani et al., 2017).

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Fix: the original cosine assignment broke for odd d_model (the cos
    slice has floor(d/2) columns while div_term has ceil(d/2) entries);
    the cos term is now truncated so any d_model >= 1 works. Behavior for
    even d_model is unchanged.
    """

    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the encoding table once: [max_seq_len, d_model].
        pe = torch.zeros(max_seq_len, d_model)
        # position: [max_seq_len, 1]
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # div_term[i] = 10000^(-2i/d_model), computed in log space for
        # numerical stability; length is ceil(d_model/2).
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # Even channels get sin, odd channels get cos. For odd d_model there
        # is one more sin channel than cos channel, hence the truncation.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Add a batch dim and register as a buffer: saved in state_dict,
        # moved with .to(device), but excluded from gradient updates.
        pe = pe.unsqueeze(0)  # [1, max_seq_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Add positional encodings to the input embeddings.

        Args:
            x: [batch, seq_len, d_model]
        Returns:
            [batch, seq_len, d_model]
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
3.3 形状验证¶
# Smoke test: positional encoding preserves the input shape.
pe = PositionalEncoding(d_model=512)
x = torch.randn(2, 30, 512)
out = pe(x)
assert out.shape == (2, 30, 512)
# Different positions must receive different encodings.
assert not torch.allclose(pe.pe[0, 0], pe.pe[0, 1])
print("✅ PositionalEncoding 验证通过")
4. 前馈网络¶
4.1 完整实现¶
class FeedForward(nn.Module):
    """Position-wise feed-forward network: d_model -> d_ff -> d_model."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Apply the two-layer MLP independently at every position.

        Args:
            x: [batch, seq_len, d_model]
        Returns:
            [batch, seq_len, d_model]
        """
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
5. Encoder¶
5.1 单层 Encoder¶
class EncoderLayer(nn.Module):
    """
    One Transformer encoder layer (Pre-LN variant, more stable to train).

    Each sublayer normalizes its input first, applies the transform, then
    adds the result back onto the untouched residual stream:
        x -> LayerNorm -> self-attention -> dropout -> + x
          -> LayerNorm -> feed-forward   -> dropout -> + x
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        """
        Args:
            x: [batch, seq_len, d_model]
            src_mask: [batch, 1, 1, seq_len] padding mask (optional)
        Returns:
            [batch, seq_len, d_model]
        """
        # Sublayer 1: multi-head self-attention with residual connection.
        normed = self.norm1(x)
        x = x + self.dropout1(self.self_attn(normed, normed, normed, mask=src_mask))
        # Sublayer 2: position-wise feed-forward with residual connection.
        x = x + self.dropout2(self.feed_forward(self.norm2(x)))
        return x
5.2 完整 Encoder¶
class Encoder(nn.Module):
    """Transformer encoder: a stack of N EncoderLayers plus a final LayerNorm."""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        # Final norm: Pre-LN layers emit unnormalized residual sums.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask=None):
        """
        Args:
            x: [batch, seq_len, d_model] embedded source sequence
            src_mask: [batch, 1, 1, seq_len] padding mask (optional)
        Returns:
            [batch, seq_len, d_model]
        """
        for encoder_layer in self.layers:
            x = encoder_layer(x, src_mask)
        return self.norm(x)
6. Decoder¶
6.1 单层 Decoder¶
class DecoderLayer(nn.Module):
    """
    One Transformer decoder layer (Pre-LN variant).

    Three sublayers, each normalize -> transform -> dropout -> residual:
        1. masked (causal) self-attention over the target sequence
        2. cross-attention: queries from the decoder, keys/values from
           the encoder output
        3. position-wise feed-forward network
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: [batch, tgt_seq_len, d_model] embedded target sequence
            enc_output: [batch, src_seq_len, d_model] encoder output
            src_mask: [batch, 1, 1, src_seq_len] source padding mask
            tgt_mask: [batch, 1, tgt_seq_len, tgt_seq_len] causal + padding mask
        Returns:
            [batch, tgt_seq_len, d_model]
        """
        # Sublayer 1: causal self-attention (tgt_mask hides future positions).
        normed = self.norm1(x)
        x = x + self.dropout1(self.self_attn(normed, normed, normed, mask=tgt_mask))
        # Sublayer 2: cross-attention; src_mask hides encoder padding.
        x = x + self.dropout2(
            self.cross_attn(self.norm2(x), enc_output, enc_output, mask=src_mask)
        )
        # Sublayer 3: position-wise feed-forward.
        x = x + self.dropout3(self.feed_forward(self.norm3(x)))
        return x
6.2 完整 Decoder¶
class Decoder(nn.Module):
    """Transformer decoder: a stack of N DecoderLayers plus a final LayerNorm."""

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: [batch, tgt_len, d_model] embedded target sequence
            enc_output: [batch, src_len, d_model] encoder output
            src_mask: source padding mask
            tgt_mask: combined causal + padding mask for the target
        Returns:
            [batch, tgt_len, d_model]
        """
        for decoder_layer in self.layers:
            x = decoder_layer(x, enc_output, src_mask, tgt_mask)
        return self.norm(x)
7. 完整 Transformer¶
7.1 Mask 生成¶
def make_pad_mask(seq, pad_idx=0):
    """
    Build a padding mask for attention.

    Args:
        seq: [batch, seq_len] token indices
        pad_idx: index of the PAD token
    Returns:
        Bool mask of shape [batch, 1, 1, seq_len]; True = attend, False = ignore.
        The two singleton dims broadcast over heads and query positions.
    """
    keep = seq.ne(pad_idx)           # [batch, seq_len]
    return keep[:, None, None, :]    # [batch, 1, 1, seq_len]
def make_causal_mask(seq_len, device):
    """
    Build a causal (look-ahead) mask so position i may attend only to j <= i.

    Args:
        seq_len: sequence length
        device: device to create the mask on
    Returns:
        Float mask of shape [1, 1, seq_len, seq_len]: 1.0 on and below the
        diagonal (visible), 0.0 above it (future positions, hidden).
    """
    lower_triangle = torch.ones(seq_len, seq_len, device=device).tril()
    return lower_triangle.reshape(1, 1, seq_len, seq_len)
def make_tgt_mask(tgt, pad_idx=0):
    """
    Combine the padding mask and the causal mask for the decoder input.

    Args:
        tgt: [batch, tgt_seq_len] target token indices
        pad_idx: index of the PAD token
    Returns:
        Bool mask [batch, 1, tgt_seq_len, tgt_seq_len]; True = attend.
    """
    tgt_len = tgt.size(1)
    pad_mask = make_pad_mask(tgt, pad_idx)                 # [batch, 1, 1, tgt_len]
    causal = make_causal_mask(tgt_len, tgt.device).bool()  # [1, 1, tgt_len, tgt_len]
    # Broadcasting AND -> [batch, 1, tgt_len, tgt_len]
    return pad_mask & causal
7.2 完整模型¶
class Transformer(nn.Module):
    """
    Complete Transformer (encoder-decoder architecture) for sequence-to-
    sequence tasks such as translation or the Copy Task.

    Fix: `_init_parameters` reads `self.src_vocab_size` / `self.tgt_vocab_size`
    for the weight-tying check, but the original `__init__` never stored them,
    so merely constructing the model raised AttributeError. Both are now saved
    before `_init_parameters` runs.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
                 num_heads=8, num_encoder_layers=6, num_decoder_layers=6,
                 d_ff=2048, max_seq_len=5000, dropout=0.1, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model
        # Needed by the weight-tying check in _init_parameters (bug fix).
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        # Token embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Positional encoding (shared by encoder and decoder inputs)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        # Encoder & decoder stacks
        self.encoder = Encoder(num_encoder_layers, d_model, num_heads, d_ff, dropout)
        self.decoder = Decoder(num_decoder_layers, d_model, num_heads, d_ff, dropout)
        # Output projection to target-vocabulary logits
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        # Parameter initialization (and weight tying, see _init_parameters)
        self._init_parameters()

    def _init_parameters(self):
        """Xavier-initialize all weight matrices, then apply weight tying."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Weight tying: the output projection shares the target-embedding
        # matrix (nn.Linear stores weight as [tgt_vocab, d_model], matching
        # the embedding). Standard practice in GPT-2/LLaMA-style models;
        # saves parameters and tends to improve quality. Only possible when
        # source and target vocabularies have the same size.
        if self.src_vocab_size == self.tgt_vocab_size:
            self.output_projection.weight = self.tgt_embedding.weight

    def encode(self, src):
        """
        Encode the source sequence.

        Args:
            src: [batch, src_len] source token indices
        Returns:
            enc_output: [batch, src_len, d_model]
            src_mask: [batch, 1, 1, src_len]
        """
        src_mask = make_pad_mask(src, self.pad_idx)
        # Embedding scaled by sqrt(d_model), then positional encoding.
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
        src_emb = self.positional_encoding(src_emb)
        enc_output = self.encoder(src_emb, src_mask)
        return enc_output, src_mask

    def decode(self, tgt, enc_output, src_mask):
        """
        Decode the target sequence given the encoder output.

        Args:
            tgt: [batch, tgt_len] target token indices
            enc_output: [batch, src_len, d_model]
            src_mask: [batch, 1, 1, src_len]
        Returns:
            logits: [batch, tgt_len, tgt_vocab_size]
        """
        tgt_mask = make_tgt_mask(tgt, self.pad_idx)
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.positional_encoding(tgt_emb)
        dec_output = self.decoder(tgt_emb, enc_output, src_mask, tgt_mask)
        logits = self.output_projection(dec_output)
        return logits

    def forward(self, src, tgt):
        """
        Full forward pass: encode the source, then decode the target.

        Args:
            src: [batch, src_len] source token indices
            tgt: [batch, tgt_len] target token indices
        Returns:
            logits: [batch, tgt_len, tgt_vocab_size]
        """
        enc_output, src_mask = self.encode(src)
        logits = self.decode(tgt, enc_output, src_mask)
        return logits
7.3 形状验证¶
# Build a small Transformer for a quick shape check.
model = Transformer(
    src_vocab_size=100,
    tgt_vocab_size=100,
    d_model=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    d_ff=128,
    dropout=0.1
)
# Dummy token-index inputs (indices start at 1 to avoid PAD=0).
src = torch.randint(1, 100, (2, 10))  # batch=2, src_len=10
tgt = torch.randint(1, 100, (2, 8))   # batch=2, tgt_len=8
logits = model(src, tgt)
assert logits.shape == (2, 8, 100), f"输出形状错误: {logits.shape}"
total_params = sum(p.numel() for p in model.parameters())
print(f"✅ Transformer 验证通过 | 参数量: {total_params:,}")
8. 训练: Copy Task¶
Copy Task 是验证 Transformer 正确性的经典测试——模型只需学会复制输入序列。如果模型连 Copy 都学不会,说明实现有 bug 。
8.1 数据集¶
from torch.utils.data import Dataset, DataLoader
class CopyDataset(Dataset):
    """
    Copy Task dataset: the model must reproduce its input sequence.

    Token layout: 0 = PAD, 1 = BOS, 2 = EOS, 3 .. vocab_size+2 = ordinary
    tokens. For a random sequence x1..xn:
        src        = [BOS, x1, ..., xn]
        tgt_input  = [BOS, x1, ..., xn]   (decoder input, teacher forcing)
        tgt_output = [x1, ..., xn, EOS]   (labels, shifted by one)
    """

    def __init__(self, num_samples=10000, seq_len=10, vocab_size=10):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.pad_idx = 0
        self.bos_idx = 1
        self.eos_idx = 2
        # Total vocabulary = ordinary tokens + the 3 special tokens.
        self.total_vocab = vocab_size + 3

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Fresh random sequence of ordinary tokens (indices 3..total_vocab-1).
        seq = torch.randint(3, self.total_vocab, (self.seq_len,))
        bos = torch.tensor([self.bos_idx])
        eos = torch.tensor([self.eos_idx])
        src = torch.cat([bos, seq])
        tgt_input = torch.cat([bos, seq])
        tgt_output = torch.cat([seq, eos])
        return src, tgt_input, tgt_output
8.2 完整训练循环¶
def train_copy_task():
    """Train on the Copy Task to sanity-check the Transformer implementation."""
    # Hyperparameters
    VOCAB_SIZE = 10
    SEQ_LEN = 10
    D_MODEL = 64
    NUM_HEADS = 4
    NUM_LAYERS = 2
    D_FF = 128
    BATCH_SIZE = 64
    NUM_EPOCHS = 20
    LR = 1e-3
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Data
    dataset = CopyDataset(num_samples=10000, seq_len=SEQ_LEN, vocab_size=VOCAB_SIZE)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Model
    model = Transformer(
        src_vocab_size=dataset.total_vocab,
        tgt_vocab_size=dataset.total_vocab,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_encoder_layers=NUM_LAYERS,
        num_decoder_layers=NUM_LAYERS,
        d_ff=D_FF,
        pad_idx=0
    ).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # PAD positions contribute no loss
    total_params = sum(p.numel() for p in model.parameters())
    print(f"设备: {DEVICE} | 参数量: {total_params:,}")
    # Training loop
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for src, tgt_input, tgt_output in dataloader:
            src = src.to(DEVICE)
            tgt_input = tgt_input.to(DEVICE)
            tgt_output = tgt_output.to(DEVICE)
            # Forward pass (teacher forcing: the decoder sees the gold prefix).
            logits = model(src, tgt_input)  # [batch, tgt_len, vocab]
            # Cross-entropy over flattened (position, vocab) predictions.
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),  # [batch*tgt_len, vocab]
                tgt_output.reshape(-1)                # [batch*tgt_len]
            )
            # Backward pass with gradient clipping for stability.
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            # Token-level accuracy over non-PAD positions.
            preds = logits.argmax(dim=-1)  # [batch, tgt_len]
            mask = tgt_output != 0
            correct += (preds[mask] == tgt_output[mask]).sum().item()
            total += mask.sum().item()
        avg_loss = total_loss / len(dataloader)
        accuracy = correct / total * 100
        print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | Loss: {avg_loss:.4f} | Acc: {accuracy:.1f}%")
        # Stop early once accuracy is high enough.
        if accuracy > 99.5:
            print(f"🎉 在第{epoch+1}轮达到 {accuracy:.1f}% 准确率!")
            break
    # Generation test
    print("\n--- 生成测试 ---")
    model.eval()
    test_src = torch.randint(3, dataset.total_vocab, (1, SEQ_LEN + 1))
    test_src[0, 0] = dataset.bos_idx
    test_src = test_src.to(DEVICE)
    # Greedy autoregressive decoding, one token at a time.
    generated = [dataset.bos_idx]
    for _ in range(SEQ_LEN + 1):
        tgt_tensor = torch.tensor([generated], device=DEVICE)
        with torch.no_grad():  # inference only: no gradients needed
            logits = model(test_src, tgt_tensor)
        next_token = logits[0, -1].argmax().item()
        generated.append(next_token)
        if next_token == dataset.eos_idx:
            break
    print(f"输入: {test_src[0].tolist()}")
    print(f"生成: {generated}")
    print(f"期望: {test_src[0, 1:].tolist()} + [EOS={dataset.eos_idx}]")
    return model

# Script entry point
if __name__ == "__main__":
    model = train_copy_task()
8.3 预期输出¶
设备: cpu | 参数量: 217,805
Epoch 1/20 | Loss: 2.4312 | Acc: 12.5%
Epoch 2/20 | Loss: 1.8023 | Acc: 35.2%
Epoch 3/20 | Loss: 0.9847 | Acc: 68.7%
...
Epoch 8/20 | Loss: 0.0234 | Acc: 99.2%
Epoch 9/20 | Loss: 0.0089 | Acc: 99.8%
🎉 在第9轮达到 99.8% 准确率!
--- 生成测试 ---
输入: [1, 5, 8, 3, 12, 7, 4, 9, 6, 11, 10]
生成: [1, 5, 8, 3, 12, 7, 4, 9, 6, 11, 10, 2]
期望: [5, 8, 3, 12, 7, 4, 9, 6, 11, 10] + [EOS=2]
如果你的模型能在 10-15 个 epoch 内达到 99%+ 准确率,说明实现是正确的。如果学不会,检查以下常见错误:1. mask 生成是否正确(因果掩码是否正确遮挡了未来位置);2. 是否忘记了 * math.sqrt(d_model) 缩放 embedding;3. 残差连接的位置是否正确。
9. 训练:字符级语言模型¶
Copy Task 验证了模型正确性后,我们来做一个更有趣的任务——字符级语言模型。这里我们使用 Decoder-Only 架构(类似 GPT )。
9.1 Decoder-Only 模型¶
class DecoderOnlyTransformer(nn.Module):
    """
    GPT-style decoder-only model.

    Unlike the encoder-decoder Transformer there is no encoder and no
    cross-attention: each layer applies only causal self-attention plus a
    feed-forward network over the token sequence.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=8,
                 num_layers=4, d_ff=512, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        # Decoder layers (self-attention only; no cross-attention since
        # there is no encoder to provide context).
        self.layers = nn.ModuleList([
            DecoderOnlyLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, vocab_size)
        # Parameter initialization (also performs weight tying, see below).
        self._init_params()

    def _init_params(self):
        """Xavier-initialize weight matrices, then tie output/embedding weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Weight tying (GPT-2 trick): the output projection shares its weight
        # matrix with the token embedding, reducing parameters and improving
        # generalization. Applied AFTER the init loop so the tie is not
        # clobbered by xavier_uniform_.
        self.output_proj.weight = self.embedding.weight

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len] token indices
        Returns:
            logits: [batch, seq_len, vocab_size]
        """
        seq_len = x.size(1)
        # Causal mask: position i may only attend to positions <= i.
        causal_mask = make_causal_mask(seq_len, x.device)
        # Embedding scaled by sqrt(d_model), plus positional encoding.
        h = self.embedding(x) * math.sqrt(self.d_model)
        h = self.pos_encoding(h)
        # N decoder layers.
        for layer in self.layers:
            h = layer(h, causal_mask)
        h = self.norm(h)
        logits = self.output_proj(h)
        return logits

    @torch.no_grad()  # inference only: no gradients needed
    def generate(self, start_tokens, max_new_tokens=100, temperature=1.0):
        """
        Autoregressive text generation by temperature sampling.

        Args:
            start_tokens: [1, start_len] prompt token indices
            max_new_tokens: maximum number of tokens to append
            temperature: softmax temperature (lower = more deterministic)
        Returns:
            [1, start_len + num_generated] token indices (prompt included)
        """
        self.eval()
        tokens = start_tokens.clone()
        for _ in range(max_new_tokens):
            logits = self(tokens)  # [1, current_len, vocab]
            next_logits = logits[0, -1] / temperature  # [vocab]
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).unsqueeze(0)  # [1, 1]
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens
class DecoderOnlyLayer(nn.Module):
    """Decoder-only block: causal self-attention + feed-forward (Pre-LN)."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, causal_mask):
        """
        Args:
            x: [batch, seq_len, d_model]
            causal_mask: [1, 1, seq_len, seq_len] look-ahead mask
        Returns:
            [batch, seq_len, d_model]
        """
        # Causal self-attention with residual connection.
        normed = self.norm1(x)
        x = x + self.dropout1(self.self_attn(normed, normed, normed, mask=causal_mask))
        # Feed-forward with residual connection.
        x = x + self.dropout2(self.feed_forward(self.norm2(x)))
        return x
9.2 字符级数据集与训练¶
class CharDataset(Dataset):
    """
    Character-level language-modeling dataset.

    Each sample pairs a window of `seq_len` character indices with the
    same window shifted one character to the right (next-char targets).
    """

    def __init__(self, text, seq_len):
        self.seq_len = seq_len
        self.chars = sorted(set(text))
        self.vocab_size = len(self.chars)
        # Bidirectional char <-> index lookup tables.
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = dict(enumerate(self.chars))
        self.data = [self.char_to_idx[ch] for ch in text]

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        window = self.data[idx:idx + self.seq_len]
        shifted = self.data[idx + 1:idx + self.seq_len + 1]
        return (torch.tensor(window, dtype=torch.long),
                torch.tensor(shifted, dtype=torch.long))

    def encode(self, text):
        """Map a string to a list of indices, skipping unknown characters."""
        return [self.char_to_idx[ch] for ch in text if ch in self.char_to_idx]

    def decode(self, indices):
        """Map indices back to a string; unknown indices become '?'."""
        return ''.join(self.idx_to_char.get(i, '?') for i in indices)
def train_char_lm():
    """Train a character-level language model (decoder-only, GPT-style)."""
    # --- Data ---
    text = """
Twinkle, twinkle, little star,
How I wonder what you are!
Up above the world so high,
Like a diamond in the sky.
Twinkle, twinkle, little star,
How I wonder what you are!
When the blazing sun is gone,
When he nothing shines upon,
Then you show your little light,
Twinkle, twinkle, all the night.
Twinkle, twinkle, little star,
How I wonder what you are!
"""
    SEQ_LEN = 64
    BATCH_SIZE = 32
    NUM_EPOCHS = 100
    LR = 3e-4
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = CharDataset(text, SEQ_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    print(f"词表大小: {dataset.vocab_size} | 字符: {''.join(dataset.chars)}")
    print(f"训练样本数: {len(dataset)}")
    # --- Model ---
    model = DecoderOnlyTransformer(
        vocab_size=dataset.vocab_size,
        d_model=64,
        num_heads=4,
        num_layers=2,
        d_ff=128,
        max_seq_len=SEQ_LEN + 1,
        dropout=0.1
    ).to(DEVICE)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"参数量: {total_params:,}")
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    # --- Training ---
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0
        for x, y in dataloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            # Next-character cross-entropy over all positions.
            loss = criterion(logits.view(-1, dataset.vocab_size), y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1:3d}/{NUM_EPOCHS} | Loss: {avg_loss:.4f}")
            # Sample some text to monitor progress qualitatively.
            start = "Twinkle"
            start_ids = torch.tensor([dataset.encode(start)], device=DEVICE)
            generated_ids = model.generate(start_ids, max_new_tokens=80, temperature=0.8)
            generated_text = dataset.decode(generated_ids[0].tolist())
            print(f" 生成: {generated_text[:100]}")
    return model, dataset

# Script entry point
if __name__ == "__main__":
    model, dataset = train_char_lm()
9.3 预期输出¶
词表大小: 36 | 字符: !\n ,.HILTUWabdeghiklnorstuwy
训练样本数: 389
参数量: 115,236
Epoch 20/100 | Loss: 2.1432
生成: Twinkle, twinkle, lithen yhe sond so
Epoch 40/100 | Loss: 1.5234
生成: Twinkle, twinkle, little star, how I wonder wh
Epoch 60/100 | Loss: 1.0891
生成: Twinkle, twinkle, little star,\nHow I wonder what you are!
Epoch 100/100 | Loss: 0.7234
生成: Twinkle, twinkle, little star,\nHow I wonder what you are!\nUp above the world so high,
10. 完整代码汇总与调试指南¶
10.1 如何使用本教程的代码¶
本教程的所有代码段可以按顺序组合成一个完整的 Python 文件。组合顺序如下:
代码文件结构(按本文顺序组合):
┌─────────────────────────────────────────────────────────────┐
│ 1. imports(torch, math, F, nn, Dataset, DataLoader) │
│ 2. scaled_dot_product_attention() ← 第1节 │
│ 3. MultiHeadAttention ← 第2节 │
│ 4. PositionalEncoding ← 第3节 │
│ 5. FeedForward ← 第4节 │
│ 6. EncoderLayer + Encoder ← 第5节 │
│ 7. DecoderLayer + Decoder ← 第6节 │
│ 8. make_pad_mask / make_causal_mask / make_tgt_mask │
│ ← 第7节 │
│ 9. Transformer ← 第7节 │
│ 10. DecoderOnlyLayer + DecoderOnlyTransformer │
│ ← 第9节 │
│ 11. CopyDataset + train_copy_task ← 第8节 │
│ 12. CharDataset + train_char_lm ← 第9节 │
│ 13. if __name__ == "__main__" 入口 ← 运行入口 │
└─────────────────────────────────────────────────────────────┘
将所有代码段按顺序复制到一个 .py 文件中,即可直接运行:
10.2 常见 Bug 与调试清单¶
实现 Transformer 时最容易出错的地方:
常见 Bug 清单:
┌───┬──────────────────────────────────────────────────────────┐
│ # │ Bug 描述 │ 症状 │
├───┼──────────────────────────────────────────────────────────┤
│ 1 │ 因果掩码方向错误 │ 模型能"偷看"未来 │
│ │ (用了上三角而非下三角) │ 训练 loss 不降 │
├───┼──────────────────────────────────────────────────────────┤
│ 2 │ 忘记 * sqrt(d_model) 缩放 │ 训练不稳定 │
│ │ │ 收敛极慢 │
├───┼──────────────────────────────────────────────────────────┤
│ 3 │ view/reshape 维度顺序错误 │ 输出形状不对 │
│ │ (分头时 batch 和 heads 维度搞反) │ 或运行时报错 │
├───┼──────────────────────────────────────────────────────────┤
│ 4 │ 残差连接位置错误 │ 深层网络无法训练 │
│ │ (Post-LN vs Pre-LN 搞混) │ 梯度消失 │
├───┼──────────────────────────────────────────────────────────┤
│ 5 │ contiguous() 缺失 │ view() 报错 │
│ │ (transpose 后必须 contiguous 才能 view)│ RuntimeError │
├───┼──────────────────────────────────────────────────────────┤
│ 6 │ mask 值类型不匹配 │ 注意力权重异常 │
│ │ (float mask vs bool mask) │ 或 NaN 出现 │
└───┴──────────────────────────────────────────────────────────┘
10.3 验证代码正确性的方法¶
def verify_implementation():
    """Quick correctness checks for the Transformer implementation."""
    # Test 1: output shape
    model = Transformer(src_vocab_size=50, tgt_vocab_size=50,
                        d_model=64, num_heads=4, num_encoder_layers=2,
                        num_decoder_layers=2, d_ff=128)
    src = torch.randint(1, 50, (2, 8))
    tgt = torch.randint(1, 50, (2, 6))
    logits = model(src, tgt)
    assert logits.shape == (2, 6, 50), f"形状错误: {logits.shape}"
    print("✅ 测试1通过: 输出形状正确")
    # Test 2: causal-mask effectiveness
    # The output at position i must not depend on inputs at positions > i.
    # Fix: switch to eval mode first — a freshly built model is in train
    # mode, so dropout (default 0.1) would make the two forward passes
    # differ and the allclose check below would fail nondeterministically.
    model.eval()
    tgt_test = torch.randint(1, 50, (1, 5))
    with torch.no_grad():
        out1 = model(src[:1], tgt_test)
        # Modify the last token; earlier outputs must stay unchanged.
        tgt_modified = tgt_test.clone()
        tgt_modified[0, -1] = 0
        out2 = model(src[:1], tgt_modified)
    # All positions before the last must match (guaranteed by the causal mask).
    assert torch.allclose(out1[0, :-1], out2[0, :-1], atol=1e-5), "因果掩码可能有问题!"
    print("✅ 测试2通过: 因果掩码有效")
    # Test 3: gradient flow
    model.train()
    logits = model(src, tgt)
    loss = logits.sum()
    loss.backward()
    # Every parameter must receive a gradient.
    for name, param in model.named_parameters():
        assert param.grad is not None, f"参数 {name} 没有梯度!"
    print("✅ 测试3通过: 梯度流通正常")
    # Test 4: weight tying
    # The output projection and target embedding must share the same storage.
    assert model.output_projection.weight.data_ptr() == model.tgt_embedding.weight.data_ptr()
    print("✅ 测试4通过: 权重共享已启用")
    print("\n🎉 所有验证测试通过!实现正确。")

verify_implementation()
📝 自测检查清单¶
- 能解释缩放因子 \(\sqrt{d_k}\) 的作用(防止 softmax 饱和)
- 能手绘多头注意力的分头-计算-拼接流程
- Copy Task 在 10 个 epoch 内达到 99%+准确率
- 字符级 LM 在 100 个 epoch 后能生成有意义的文本
- 能独立调试 mask 相关的 bug (最常见的错误来源)
- 理解 Pre-LN vs Post-LN 的区别和原因
思考题¶
-
权重共享的影响:本实现中使用了 Weight Tying(输出投影层与嵌入层共享权重)。如果去掉权重共享,模型的参数量会如何变化?对训练效果有什么影响?尝试修改代码去掉权重共享,对比训练曲线。
-
字符级 vs 子词级:本教程使用字符级语言模型。如果改为 BPE 子词分词,词表大小从 ~40 增加到 ~32000,模型需要哪些调整?(提示:embedding 层、输出层、序列长度、学习率)
-
Decoder-Only vs Encoder-Decoder:在 Copy Task 中,Encoder-Decoder 和 Decoder-Only 哪个更容易学习?为什么?尝试用 Decoder-Only 模型做 Copy Task,对比收敛速度。
-
训练稳定性:如果将模型深度从 2 层增加到 12 层,不做任何其他修改,训练还能稳定吗?需要哪些额外措施?(提示:学习率 warm-up、梯度裁剪、残差连接的缩放)
📚 参考¶
- Vaswani et al., "Attention Is All You Need" (2017)
- Andrej Karpathy, "Let's build GPT: from scratch, in code" (2023)
- Harvard NLP, "The Annotated Transformer" (2018)
最后更新日期: 2026-03-26 适用版本: LLM 学习教程 v2026