项目1: 文本摘要系统¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
难度: ⭐⭐ 中等 时间: 8-12小时 涉及知识: Transformer、Seq2Seq、文本预处理、模型微调
📖 项目概述¶
项目背景¶
随着信息爆炸时代的到来,每天产生海量的文本数据。人们需要快速从长篇文章中提取关键信息,文本摘要技术应运而生。传统的摘要方法主要基于规则和统计,难以理解文本语义。基于深度学习的文本摘要系统,特别是使用Transformer架构的模型,能够生成更加准确和流畅的摘要。
项目目标¶
构建一个完整的文本摘要系统,能够: - 支持多种文本格式的输入(新闻、论文、博客等) - 生成高质量的摘要(抽取式和生成式) - 支持摘要长度控制 - 提供摘要质量评估 - 支持批量处理 - 提供友好的Web界面
技术栈¶
- 深度学习框架: PyTorch
- 模型架构: Transformer / BART / T5
- 数据处理: NLTK, spaCy
- 评估指标: ROUGE, BLEU
- Web框架: Streamlit
- 预训练模型: HuggingFace Transformers
🏗️ 项目结构¶
Text Only
text-summarization-system/
├── app/ # 应用主目录
│ ├── __init__.py
│ ├── config.py # 配置文件
│ ├── models/ # 模型定义
│ │ ├── __init__.py
│ │ ├── transformer.py # Transformer模型
│ │ └── preprocess.py # 预处理模块
│ ├── data/ # 数据处理
│ │ ├── __init__.py
│ │ ├── dataset.py # 数据集类
│ │ └── loader.py # 数据加载器
│ ├── train.py # 训练脚本
│ ├── evaluate.py # 评估脚本
│ └── api/ # API接口
│ ├── __init__.py
│ └── summarize.py # 摘要API
├── frontend/ # 前端目录
│ ├── app.py # Streamlit应用
│ └── components/ # UI组件
├── data/ # 数据目录
│ ├── raw/ # 原始数据
│ ├── processed/ # 处理后数据
│ └── cache/ # 缓存数据
├── models/ # 模型目录
│ ├── checkpoints/ # 模型检查点
│ └── pretrained/ # 预训练模型
├── tests/ # 测试目录
│ ├── test_models.py
│ ├── test_data.py
│ └── test_evaluation.py
├── utils/ # 工具函数
│ ├── __init__.py
│ ├── metrics.py # 评估指标
│ └── text_utils.py # 文本工具
├── requirements.txt # Python依赖
└── README.md # 项目说明
🎯 核心功能¶
1. 文本预处理¶
- 文本清洗: 去除HTML标签、特殊字符
- 分词处理: 使用分词器进行文本切分
- 长度控制: 截断或填充到指定长度
- 数据增强: 同义词替换、回译等
2. 摘要生成¶
- 抽取式摘要: 提取原文中的重要句子
- 生成式摘要: 使用模型生成新的摘要
- 长度控制: 支持指定摘要长度
- 多样性控制: 控制摘要的多样性
3. 模型训练¶
- 预训练模型加载: 加载BART/T5等预训练模型
- 微调训练: 在特定数据集上微调
- 学习率调度: 使用学习率调度器
- 早停机制: 防止过拟合
4. 质量评估¶
- ROUGE指标: ROUGE-1, ROUGE-2, ROUGE-L
- BLEU指标: 评估摘要质量
- 人工评估: 提供人工评估接口
- 对比分析: 与其他方法对比
5. Web界面¶
- 文本输入: 支持文本框和文件上传
- 摘要展示: 清晰展示生成的摘要
- 参数调节: 实时调节摘要参数
- 批量处理: 支持批量文本处理
💻 代码实现¶
1. 配置文件 (app/config.py)¶
Python
"""
文本摘要系统配置文件
"""
import os
from typing import Optional
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""应用配置"""
# 模型配置
MODEL_NAME: str = "facebook/bart-large-cnn"
MAX_SOURCE_LENGTH: int = 1024
MAX_TARGET_LENGTH: int = 256
MIN_TARGET_LENGTH: int = 56
# 训练配置
BATCH_SIZE: int = 4
LEARNING_RATE: float = 3e-5
NUM_EPOCHS: int = 3
WARMUP_STEPS: int = 500
GRADIENT_ACCUMULATION_STEPS: int = 2
# 数据配置
TRAIN_DATA_PATH: str = "./data/raw/train.json"
VAL_DATA_PATH: str = "./data/raw/val.json"
TEST_DATA_PATH: str = "./data/raw/test.json"
# 评估配置
EVAL_STEPS: int = 500
SAVE_STEPS: int = 1000
# API配置
API_HOST: str = "0.0.0.0"
API_PORT: int = 8000
# 缓存配置
CACHE_DIR: str = "./data/cache"
class Config:
env_file = ".env"
settings = Settings()
2. 数据集类 (app/data/dataset.py)¶
Python
"""
文本摘要数据集类
"""
import json
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from typing import List, Dict, Any
class SummarizationDataset(Dataset):
"""文本摘要数据集"""
def __init__(
self,
data_path: str,
tokenizer: AutoTokenizer,
max_source_length: int = 1024,
max_target_length: int = 256,
):
"""
初始化数据集
Args:
data_path: 数据文件路径
tokenizer: 分词器
max_source_length: 源文本最大长度
max_target_length: 目标文本最大长度
"""
self.data = self._load_data(data_path)
self.tokenizer = tokenizer
self.max_source_length = max_source_length
self.max_target_length = max_target_length
def _load_data(self, data_path: str) -> List[Dict[str, Any]]:
"""加载数据"""
with open(data_path, 'r', encoding='utf-8') as f: # with自动管理文件关闭
data = [json.loads(line) for line in f] # json.loads将JSON字符串→Python对象
return data
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""获取数据项"""
item = self.data[idx]
# 编码源文本
source_encoding = self.tokenizer(
item['article'],
max_length=self.max_source_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# 编码目标文本
target_encoding = self.tokenizer(
item['summary'],
max_length=self.max_target_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': source_encoding['input_ids'].flatten(),
'attention_mask': source_encoding['attention_mask'].flatten(),
'labels': target_encoding['input_ids'].flatten(),
}
3. Transformer模型 (app/models/transformer.py)¶
Python
"""
Transformer文本摘要模型
"""
import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import Dict, Optional
class SummarizationModel:
"""文本摘要模型"""
def __init__(
self,
model_name: str = "facebook/bart-large-cnn",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
初始化模型
Args:
model_name: 预训练模型名称
device: 设备
"""
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
self.model.to(device) # .to(device)将数据移至GPU/CPU
self.model.eval()
def summarize(
self,
text: str,
max_length: int = 256,
min_length: int = 56,
num_beams: int = 4,
no_repeat_ngram_size: int = 3,
length_penalty: float = 1.0,
early_stopping: bool = True,
) -> str:
"""
生成摘要
Args:
text: 输入文本
max_length: 最大长度
min_length: 最小长度
num_beams: beam search数量
no_repeat_ngram_size: 不重复n-gram大小
length_penalty: 长度惩罚
early_stopping: 是否早停
Returns:
摘要文本
"""
# 编码输入
inputs = self.tokenizer(
text,
max_length=1024,
truncation=True,
return_tensors='pt'
).to(self.device)
# 生成摘要
with torch.no_grad(): # 禁用梯度计算,节省内存(推理时使用)
summary_ids = self.model.generate(
inputs['input_ids'],
attention_mask=inputs['attention_mask'],
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
no_repeat_ngram_size=no_repeat_ngram_size,
length_penalty=length_penalty,
early_stopping=early_stopping,
)
# 解码摘要
summary = self.tokenizer.decode(
summary_ids[0],
skip_special_tokens=True
)
return summary
def batch_summarize(
self,
texts: list,
batch_size: int = 8,
**kwargs
) -> list:
"""
批量生成摘要
Args:
texts: 输入文本列表
batch_size: 批次大小
**kwargs: 其他参数
Returns:
摘要列表
"""
summaries = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
batch_summaries = [
self.summarize(text, **kwargs)
for text in batch_texts
]
summaries.extend(batch_summaries)
return summaries
4. 训练脚本 (app/train.py)¶
Python
"""
模型训练脚本
"""
import os
import torch
from torch.utils.data import DataLoader
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
get_linear_schedule_with_warmup,
)
from tqdm import tqdm
from typing import Dict
from app.data.dataset import SummarizationDataset
from app.config import settings
def train_epoch(
model: AutoModelForSeq2SeqLM,
dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
scheduler,
device: str,
) -> float:
"""
训练一个epoch
Args:
model: 模型
dataloader: 数据加载器
optimizer: 优化器
scheduler: 学习率调度器
device: 设备
Returns:
平均损失
"""
model.train()
total_loss = 0
progress_bar = tqdm(dataloader, desc="Training")
for batch in progress_bar:
# 移动数据到设备
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
# 前向传播
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
loss = outputs.loss
total_loss += loss.item()
# 反向传播
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=1.0
)
# 更新参数
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# 更新进度条
progress_bar.set_postfix({'loss': loss.item()})
return total_loss / len(dataloader)
def train():
"""训练模型"""
# 加载分词器
tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
# 加载数据集
train_dataset = SummarizationDataset(
settings.TRAIN_DATA_PATH,
tokenizer,
settings.MAX_SOURCE_LENGTH,
settings.MAX_TARGET_LENGTH,
)
val_dataset = SummarizationDataset(
settings.VAL_DATA_PATH,
tokenizer,
settings.MAX_SOURCE_LENGTH,
settings.MAX_TARGET_LENGTH,
)
# 创建数据加载器
train_dataloader = DataLoader(
train_dataset,
batch_size=settings.BATCH_SIZE,
shuffle=True,
)
val_dataloader = DataLoader(
val_dataset,
batch_size=settings.BATCH_SIZE,
shuffle=False,
)
# 加载模型
model = AutoModelForSeq2SeqLM.from_pretrained(settings.MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# 创建优化器
optimizer = torch.optim.AdamW(
model.parameters(),
lr=settings.LEARNING_RATE,
)
# 创建学习率调度器
total_steps = len(train_dataloader) * settings.NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=settings.WARMUP_STEPS,
num_training_steps=total_steps,
)
# 训练循环
best_val_loss = float('inf')
for epoch in range(settings.NUM_EPOCHS):
print(f"\nEpoch {epoch + 1}/{settings.NUM_EPOCHS}")
# 训练
train_loss = train_epoch(
model,
train_dataloader,
optimizer,
scheduler,
device,
)
print(f"Train Loss: {train_loss:.4f}")
# 验证
val_loss = evaluate(model, val_dataloader, device)
print(f"Val Loss: {val_loss:.4f}")
# 保存最佳模型
if val_loss < best_val_loss:
best_val_loss = val_loss
save_path = f"./models/checkpoints/best_model.pt"
torch.save(model.state_dict(), save_path)
print(f"Saved best model to {save_path}")
def evaluate(
model: AutoModelForSeq2SeqLM,
dataloader: DataLoader,
device: str,
) -> float:
"""评估模型"""
model.eval()
total_loss = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating"):
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
total_loss += outputs.loss.item()
return total_loss / len(dataloader)
if __name__ == "__main__":
train()
5. 评估指标 (utils/metrics.py)¶
Python
"""
摘要评估指标
"""
from rouge_score import rouge_scorer
from typing import List, Dict
import numpy as np
class SummarizationEvaluator:
"""摘要评估器"""
def __init__(self):
"""初始化评估器"""
self.scorer = rouge_scorer.RougeScorer(
['rouge1', 'rouge2', 'rougeL'],
use_stemmer=True
)
def compute_rouge(
self,
predictions: List[str],
references: List[str],
) -> Dict[str, float]:
"""
计算ROUGE分数
Args:
predictions: 预测摘要列表
references: 参考摘要列表
Returns:
ROUGE分数字典
"""
rouge1_scores = []
rouge2_scores = []
rougeL_scores = []
for pred, ref in zip(predictions, references): # zip按位置配对多个可迭代对象
scores = self.scorer.score(ref, pred)
rouge1_scores.append(scores['rouge1'].fmeasure)
rouge2_scores.append(scores['rouge2'].fmeasure)
rougeL_scores.append(scores['rougeL'].fmeasure)
return {
'rouge1': np.mean(rouge1_scores),
'rouge2': np.mean(rouge2_scores),
'rougeL': np.mean(rougeL_scores),
}
def evaluate(
self,
predictions: List[str],
references: List[str],
) -> Dict[str, float]:
"""
综合评估
Args:
predictions: 预测摘要列表
references: 参考摘要列表
Returns:
评估结果字典
"""
rouge_scores = self.compute_rouge(predictions, references)
# 计算平均长度
pred_lengths = [len(pred.split()) for pred in predictions]
ref_lengths = [len(ref.split()) for ref in references]
return {
**rouge_scores,
'avg_pred_length': np.mean(pred_lengths),
'avg_ref_length': np.mean(ref_lengths),
}
6. Streamlit前端 (frontend/app.py)¶
Python
"""
文本摘要系统Web界面
"""
import streamlit as st
from app.models.transformer import SummarizationModel
from app.config import settings
# 页面配置
st.set_page_config(
page_title="文本摘要系统",
page_icon="📝",
layout="wide"
)
# 标题
st.title("📝 文本摘要系统")
st.markdown("---")
# 侧边栏
st.sidebar.header("参数设置")
# 摘要长度
max_length = st.sidebar.slider(
"最大长度",
min_value=50,
max_value=500,
value=256,
step=10
)
min_length = st.sidebar.slider(
"最小长度",
min_value=10,
max_value=100,
value=56,
step=5
)
# Beam数量
num_beams = st.sidebar.slider(
"Beam Search数量",
min_value=1,
max_value=10,
value=4,
step=1
)
# 长度惩罚
length_penalty = st.sidebar.slider(
"长度惩罚",
min_value=0.5,
max_value=2.0,
value=1.0,
step=0.1
)
# 加载模型
@st.cache_resource
def load_model():
"""加载模型"""
return SummarizationModel(
model_name=settings.MODEL_NAME,
)
model = load_model()
# 输入方式
input_method = st.radio(
"选择输入方式",
["文本输入", "文件上传"],
horizontal=True
)
# 文本输入
if input_method == "文本输入":
text = st.text_area(
"输入文本",
height=300,
placeholder="请输入需要摘要的文本..."
)
# 生成摘要按钮
if st.button("生成摘要", type="primary"):
if text.strip(): # 链式调用:strip去除空白
with st.spinner("正在生成摘要..."):
summary = model.summarize(
text,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
length_penalty=length_penalty,
)
# 显示摘要
st.success("摘要生成成功!")
st.subheader("生成的摘要")
st.write(summary)
# 统计信息
col1, col2 = st.columns(2)
with col1:
st.metric("原文长度", len(text.split()))
with col2:
st.metric("摘要长度", len(summary.split()))
else:
st.warning("请输入文本!")
# 文件上传
else:
uploaded_file = st.file_uploader(
"上传文件",
type=['txt', 'md'],
)
if uploaded_file is not None:
text = uploaded_file.read().decode('utf-8')
st.text_area("文件内容", text, height=200)
if st.button("生成摘要", type="primary"):
with st.spinner("正在生成摘要..."):
summary = model.summarize(
text,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
length_penalty=length_penalty,
)
st.success("摘要生成成功!")
st.subheader("生成的摘要")
st.write(summary)
🧪 测试方法¶
1. 单元测试¶
Python
"""
单元测试示例
"""
import pytest
from app.models.transformer import SummarizationModel
def test_summarization():
"""测试摘要生成"""
model = SummarizationModel()
text = """
人工智能是计算机科学的一个分支,它企图了解智能的实质,
并生产出一种新的能以人类智能相似的方式做出反应的智能机器。
该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。
"""
summary = model.summarize(text)
assert len(summary) > 0 # assert断言:条件False时抛出AssertionError
assert len(summary) < len(text)
print(f"Summary: {summary}")
2. 集成测试¶
Python
"""
集成测试示例
"""
def test_end_to_end():
"""端到端测试"""
# 测试完整流程
text = load_test_text()
summary = model.summarize(text)
scores = evaluator.evaluate([summary], [reference])
assert scores['rouge1'] > 0.3
3. 性能测试¶
Python
"""
性能测试示例
"""
import time
def test_performance():
"""测试性能"""
texts = [load_test_text() for _ in range(100)]
start_time = time.time()
summaries = model.batch_summarize(texts)
end_time = time.time()
avg_time = (end_time - start_time) / len(texts)
print(f"Average time per text: {avg_time:.2f}s")
📊 扩展建议¶
1. 功能扩展¶
- 多语言支持: 支持多语言文本摘要
- 多文档摘要: 合并多个文档生成摘要
- 关键词提取: 提取文本关键词
- 摘要风格控制: 控制摘要的风格(正式/非正式)
- 实时摘要: 支持流式文本实时摘要
2. 模型优化¶
- 模型压缩: 使用知识蒸馏、量化等技术压缩模型
- 加速推理: 使用ONNX、TensorRT等加速推理
- 模型集成: 集成多个模型提升性能
- 自适应摘要: 根据文本类型自动调整摘要策略
3. 数据增强¶
- 数据扩充: 使用回译、同义词替换等方法扩充数据
- 领域适应: 在特定领域数据上微调
- 主动学习: 主动选择最有价值的样本进行标注
4. 系统优化¶
- 缓存机制: 缓存常见文本的摘要
- 异步处理: 使用异步处理提升性能
- 分布式部署: 支持分布式部署
- API优化: 优化API接口性能
📚 学习收获¶
完成本项目后,你将掌握:
- ✅ Transformer架构的理解和应用
- ✅ Seq2Seq模型的实现和训练
- ✅ 文本预处理和数据增强技术
- ✅ 模型微调和评估方法
- ✅ ROUGE等评估指标的使用
- ✅ Streamlit Web应用开发
- ✅ 完整的NLP项目开发流程
🔗 参考资源¶
最后更新日期:2026-02-12 适用版本:LLM学习教程 v2026
项目完成时间: 8-12小时 难度等级: ⭐⭐ 中等 推荐指数: ⭐⭐⭐⭐⭐