03-文本生成实战¶
学习时间: 约4-6小时 难度级别: ⭐⭐⭐ 中级 前置知识: RNN/LSTM基础、PyTorch 项目目标: 基于字符级LSTM实现文本生成,体验完整的生成模型开发流程
目录¶
1. 项目概述¶
1.1 任务描述¶
训练一个字符级语言模型,学习文本中的模式和规律,然后自动生成新的文本。
1.2 技术路线¶
2. 数据准备¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
class TextDataset(Dataset):
"""字符级文本数据集"""
def __init__(self, text, seq_length=100): # __init__构造方法,创建对象时自动调用
self.seq_length = seq_length
# 构建字符表
self.chars = sorted(set(text))
self.vocab_size = len(self.chars)
self.char2idx = {ch: i for i, ch in enumerate(self.chars)} # enumerate同时获取索引和元素
self.idx2char = {i: ch for i, ch in enumerate(self.chars)}
# 编码文本
self.encoded = [self.char2idx[ch] for ch in text]
self.data = torch.LongTensor(self.encoded)
print(f"文本长度: {len(text):,} 字符")
print(f"词汇表大小: {self.vocab_size}")
print(f"样本数: {len(self)}")
def __len__(self): # __len__定义len()的行为
return len(self.data) - self.seq_length
def __getitem__(self, idx): # __getitem__定义索引访问行为
x = self.data[idx:idx + self.seq_length]
y = self.data[idx + 1:idx + self.seq_length + 1]
return x, y
# 示例文本(实际项目中用更大的语料)
sample_text = """
Deep learning is a subset of machine learning that uses neural networks with many layers.
These deep neural networks can learn complex patterns in data, making them powerful tools
for tasks like image recognition, natural language processing, and speech recognition.
The field has seen remarkable progress in recent years, driven by advances in computing
power, data availability, and algorithmic innovations.
""" * 100 # 重复以增加数据量
dataset = TextDataset(sample_text, seq_length=100)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True) # DataLoader批量加载数据,支持shuffle和多进程
3. 模型设计¶
Python
class CharLSTM(nn.Module): # 继承nn.Module定义神经网络层
"""字符级 LSTM 语言模型"""
def __init__(self, vocab_size, embed_dim=128, hidden_dim=512,
num_layers=2, dropout=0.3):
super().__init__() # super()调用父类方法
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
embed_dim, hidden_dim, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0
)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(hidden_dim, vocab_size)
# 权重绑定(可选)
if embed_dim == hidden_dim:
self.fc.weight = self.embedding.weight
def forward(self, x, hidden=None):
"""
x: (batch, seq_len)
hidden: (h_0, c_0) 或 None
"""
# 也可以使用Transformer架构替代LSTM
# 
embed = self.dropout(self.embedding(x))
output, hidden = self.lstm(embed, hidden)
output = self.dropout(output)
logits = self.fc(output) # (batch, seq_len, vocab_size)
return logits, hidden
def init_hidden(self, batch_size, device):
h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
return (h0, c0)
4. 训练流程¶
Python
def train_model(model, dataloader, epochs=50, lr=2e-3, device='cuda', clip=1.0):
"""训练字符级语言模型"""
model = model.to(device) # .to(device)将数据移至GPU/CPU
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
model.train() # train()开启训练模式
total_loss = 0
num_batches = 0
for x, y in dataloader:
x, y = x.to(device), y.to(device)
hidden = model.init_hidden(x.size(0), device)
hidden = tuple(h.detach() for h in hidden) # detach()从计算图分离,不参与梯度计算
logits, hidden = model(x, hidden)
loss = criterion(logits.view(-1, model.vocab_size), y.view(-1)) # view重塑张量形状(要求内存连续)
optimizer.zero_grad() # 清零梯度,防止梯度累积
loss.backward() # 反向传播计算梯度
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step() # 根据梯度更新模型参数
total_loss += loss.item() # .item()将单元素张量转为Python数值
num_batches += 1
scheduler.step()
avg_loss = total_loss / num_batches
perplexity = np.exp(avg_loss)
if (epoch + 1) % 5 == 0:
print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Perplexity={perplexity:.2f}")
# 生成样本
sample = generate_text(model, dataset, seed_text="Deep ",
length=200, device=device)
print(f" 生成文本: {sample[:100]}...")
return model
5. 文本生成与采样策略¶
Python
def generate_text(model, dataset, seed_text="The ", length=500,
temperature=0.8, device='cuda', strategy='temperature'):
"""生成文本"""
model.eval() # eval()开启评估模式(关闭Dropout等)
# 编码种子文本
chars = [dataset.char2idx.get(ch, 0) for ch in seed_text]
input_seq = torch.LongTensor([chars]).to(device) # 链式调用,连续执行多个方法
hidden = model.init_hidden(1, device)
generated = list(seed_text)
# 先处理种子文本
with torch.no_grad():
logits, hidden = model(input_seq, hidden)
# 逐字符生成
last_char = input_seq[:, -1:]
with torch.no_grad():
for _ in range(length):
logits, hidden = model(last_char, hidden)
logits = logits[:, -1, :] # 最后一个位置的输出
if strategy == 'temperature':
next_char = temperature_sampling(logits, temperature)
elif strategy == 'top_k':
next_char = top_k_sampling(logits, k=10, temperature=temperature)
elif strategy == 'top_p':
next_char = top_p_sampling(logits, p=0.9, temperature=temperature)
elif strategy == 'greedy':
next_char = logits.argmax(dim=-1, keepdim=True)
else:
next_char = temperature_sampling(logits, temperature)
char = dataset.idx2char[next_char.item()]
generated.append(char)
last_char = next_char.unsqueeze(0) # unsqueeze增加一个维度
return ''.join(generated)
def temperature_sampling(logits, temperature=1.0):
"""温度采样"""
probs = F.softmax(logits / temperature, dim=-1)
return torch.multinomial(probs, 1)
def top_k_sampling(logits, k=10, temperature=1.0):
"""Top-K 采样"""
logits = logits / temperature
top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
probs = F.softmax(top_k_logits, dim=-1)
idx = torch.multinomial(probs, 1)
return top_k_indices.gather(-1, idx)
def top_p_sampling(logits, p=0.9, temperature=1.0):
"""Top-P (Nucleus) 采样"""
logits = logits / temperature
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# 找到累积概率超过 p 的位置
sorted_indices_to_remove = cumulative_probs > p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
sorted_logits[sorted_indices_to_remove] = -float('inf')
probs = F.softmax(sorted_logits, dim=-1)
idx = torch.multinomial(probs, 1)
return sorted_indices.gather(-1, idx)
6. 完整项目代码¶
Python
def main():
"""完整的文本生成项目"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 1. 加载文本数据
# 实际项目中可以使用:
# - 莎士比亚全集
# - 唐诗/宋词
# - Python代码
# - 维基百科
text = open('input.txt', 'r', encoding='utf-8').read() if False else sample_text
# 2. 创建数据集
dataset = TextDataset(text, seq_length=128)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
# 3. 创建模型
model = CharLSTM(
vocab_size=dataset.vocab_size,
embed_dim=128,
hidden_dim=512,
num_layers=2,
dropout=0.3
)
total_params = sum(p.numel() for p in model.parameters())
print(f"模型参数量: {total_params:,}")
# 4. 训练
model = train_model(model, dataloader, epochs=50, lr=2e-3, device=device)
# 5. 生成文本(不同采样策略对比)
print("\n" + "=" * 60)
print("文本生成对比")
print("=" * 60)
strategies = {
'greedy': {'strategy': 'greedy'},
'temperature=0.5': {'strategy': 'temperature', 'temperature': 0.5},
'temperature=1.0': {'strategy': 'temperature', 'temperature': 1.0},
'top_k=10': {'strategy': 'top_k', 'temperature': 0.8},
'top_p=0.9': {'strategy': 'top_p', 'temperature': 0.8},
}
for name, kwargs in strategies.items():
text = generate_text(model, dataset, seed_text="Deep learning ",
length=200, device=device, **kwargs) # *args接收任意位置参数,**kwargs接收任意关键字参数
print(f"\n[{name}]:")
print(text[:200])
# 6. 保存模型
torch.save({
'model_state': model.state_dict(),
'char2idx': dataset.char2idx,
'idx2char': dataset.idx2char,
'vocab_size': dataset.vocab_size,
}, 'char_lstm_model.pt')
print("\n模型已保存!")
# main()
7. 扩展思考¶
- 词级别模型:将字符级改为词级别,使用预训练词嵌入。
- 注意力增强:在 LSTM 上添加注意力机制。
- Transformer 替代:用 Transformer Decoder 替代 LSTM。
- 条件生成:加入控制条件(如风格、主题)。
- 中文生成:处理中文文本,学习唐诗或歌词生成。
采样策略总结¶
| 策略 | 多样性 | 质量 | 适用场景 |
|---|---|---|---|
| Greedy | 最低 | 重复严重 | 不推荐 |
| Temperature < 1 | 低 | 高 | 需要确定性文本 |
| Temperature > 1 | 高 | 可能不连贯 | 需要创意文本 |
| Top-K | 中等 | 中等 | 通用 |
| Top-P | 自适应 | 较高 | 推荐默认策略 |
下一个实战项目: 04-GAN图像生成实战