跳转至

项目1: 文本摘要系统

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

难度: ⭐⭐ 中等 时间: 8-12小时 涉及知识: Transformer、Seq2Seq、文本预处理、模型微调


📖 项目概述

项目背景

随着信息爆炸时代的到来,每天产生海量的文本数据。人们需要快速从长篇文章中提取关键信息,文本摘要技术应运而生。传统的摘要方法主要基于规则和统计,难以理解文本语义。基于深度学习的文本摘要系统,特别是使用Transformer架构的模型,能够生成更加准确和流畅的摘要。

项目目标

构建一个完整的文本摘要系统,能够: - 支持多种文本格式的输入(新闻、论文、博客等) - 生成高质量的摘要(抽取式和生成式) - 支持摘要长度控制 - 提供摘要质量评估 - 支持批量处理 - 提供友好的Web界面

技术栈

  • 深度学习框架: PyTorch
  • 模型架构: Transformer / BART / T5
  • 数据处理: NLTK, spaCy
  • 评估指标: ROUGE, BLEU
  • Web框架: Streamlit
  • 预训练模型: HuggingFace Transformers

🏗️ 项目结构

Text Only
text-summarization-system/
├── app/                      # 应用主目录
│   ├── __init__.py
│   ├── config.py            # 配置文件
│   ├── models/              # 模型定义
│   │   ├── __init__.py
│   │   ├── transformer.py  # Transformer模型
│   │   └── preprocess.py    # 预处理模块
│   ├── data/                # 数据处理
│   │   ├── __init__.py
│   │   ├── dataset.py       # 数据集类
│   │   └── loader.py        # 数据加载器
│   ├── train.py             # 训练脚本
│   ├── evaluate.py          # 评估脚本
│   └── api/                 # API接口
│       ├── __init__.py
│       └── summarize.py     # 摘要API
├── frontend/                 # 前端目录
│   ├── app.py              # Streamlit应用
│   └── components/         # UI组件
├── data/                    # 数据目录
│   ├── raw/                # 原始数据
│   ├── processed/          # 处理后数据
│   └── cache/              # 缓存数据
├── models/                  # 模型目录
│   ├── checkpoints/        # 模型检查点
│   └── pretrained/         # 预训练模型
├── tests/                   # 测试目录
│   ├── test_models.py
│   ├── test_data.py
│   └── test_evaluation.py
├── utils/                   # 工具函数
│   ├── __init__.py
│   ├── metrics.py          # 评估指标
│   └── text_utils.py       # 文本工具
├── requirements.txt         # Python依赖
└── README.md               # 项目说明

🎯 核心功能

1. 文本预处理

  • 文本清洗: 去除HTML标签、特殊字符
  • 分词处理: 使用分词器进行文本切分
  • 长度控制: 截断或填充到指定长度
  • 数据增强: 同义词替换、回译等

2. 摘要生成

  • 抽取式摘要: 提取原文中的重要句子
  • 生成式摘要: 使用模型生成新的摘要
  • 长度控制: 支持指定摘要长度
  • 多样性控制: 控制摘要的多样性

3. 模型训练

  • 预训练模型加载: 加载BART/T5等预训练模型
  • 微调训练: 在特定数据集上微调
  • 学习率调度: 使用学习率调度器
  • 早停机制: 防止过拟合

4. 质量评估

  • ROUGE指标: ROUGE-1, ROUGE-2, ROUGE-L
  • BLEU指标: 评估摘要质量
  • 人工评估: 提供人工评估接口
  • 对比分析: 与其他方法对比

5. Web界面

  • 文本输入: 支持文本框和文件上传
  • 摘要展示: 清晰展示生成的摘要
  • 参数调节: 实时调节摘要参数
  • 批量处理: 支持批量文本处理

💻 代码实现

1. 配置文件 (app/config.py)

Python
"""
文本摘要系统配置文件
"""
import os
from typing import Optional
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    """应用配置"""

    # 模型配置
    MODEL_NAME: str = "facebook/bart-large-cnn"
    MAX_SOURCE_LENGTH: int = 1024
    MAX_TARGET_LENGTH: int = 256
    MIN_TARGET_LENGTH: int = 56

    # 训练配置
    BATCH_SIZE: int = 4
    LEARNING_RATE: float = 3e-5
    NUM_EPOCHS: int = 3
    WARMUP_STEPS: int = 500
    GRADIENT_ACCUMULATION_STEPS: int = 2

    # 数据配置
    TRAIN_DATA_PATH: str = "./data/raw/train.json"
    VAL_DATA_PATH: str = "./data/raw/val.json"
    TEST_DATA_PATH: str = "./data/raw/test.json"

    # 评估配置
    EVAL_STEPS: int = 500
    SAVE_STEPS: int = 1000

    # API配置
    API_HOST: str = "0.0.0.0"
    API_PORT: int = 8000

    # 缓存配置
    CACHE_DIR: str = "./data/cache"

    class Config:
        env_file = ".env"

settings = Settings()

2. 数据集类 (app/data/dataset.py)

Python
"""
文本摘要数据集类
"""
import json
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from typing import List, Dict, Any

class SummarizationDataset(Dataset):
    """文本摘要数据集"""

    def __init__(
        self,
        data_path: str,
        tokenizer: AutoTokenizer,
        max_source_length: int = 1024,
        max_target_length: int = 256,
    ):
        """
        初始化数据集

        Args:
            data_path: 数据文件路径
            tokenizer: 分词器
            max_source_length: 源文本最大长度
            max_target_length: 目标文本最大长度
        """
        self.data = self._load_data(data_path)
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def _load_data(self, data_path: str) -> List[Dict[str, Any]]:
        """加载数据"""
        with open(data_path, 'r', encoding='utf-8') as f:  # with自动管理文件关闭
            data = [json.loads(line) for line in f]  # json.loads将JSON字符串→Python对象
        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """获取数据项"""
        item = self.data[idx]

        # 编码源文本
        source_encoding = self.tokenizer(
            item['article'],
            max_length=self.max_source_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # 编码目标文本
        target_encoding = self.tokenizer(
            item['summary'],
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': source_encoding['input_ids'].flatten(),
            'attention_mask': source_encoding['attention_mask'].flatten(),
            'labels': target_encoding['input_ids'].flatten(),
        }

3. Transformer模型 (app/models/transformer.py)

Python
"""
Transformer文本摘要模型
"""
import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import Dict, Optional

class SummarizationModel:
    """文本摘要模型"""

    def __init__(
        self,
        model_name: str = "facebook/bart-large-cnn",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        初始化模型

        Args:
            model_name: 预训练模型名称
            device: 设备
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(device)  # .to(device)将数据移至GPU/CPU
        self.model.eval()

    def summarize(
        self,
        text: str,
        max_length: int = 256,
        min_length: int = 56,
        num_beams: int = 4,
        no_repeat_ngram_size: int = 3,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
    ) -> str:
        """
        生成摘要

        Args:
            text: 输入文本
            max_length: 最大长度
            min_length: 最小长度
            num_beams: beam search数量
            no_repeat_ngram_size: 不重复n-gram大小
            length_penalty: 长度惩罚
            early_stopping: 是否早停

        Returns:
            摘要文本
        """
        # 编码输入
        inputs = self.tokenizer(
            text,
            max_length=1024,
            truncation=True,
            return_tensors='pt'
        ).to(self.device)

        # 生成摘要
        with torch.no_grad():  # 禁用梯度计算,节省内存(推理时使用)
            summary_ids = self.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                no_repeat_ngram_size=no_repeat_ngram_size,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
            )

        # 解码摘要
        summary = self.tokenizer.decode(
            summary_ids[0],
            skip_special_tokens=True
        )

        return summary

    def batch_summarize(
        self,
        texts: list,
        batch_size: int = 8,
        **kwargs
    ) -> list:
        """
        批量生成摘要

        Args:
            texts: 输入文本列表
            batch_size: 批次大小
            **kwargs: 其他参数

        Returns:
            摘要列表
        """
        summaries = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_summaries = [
                self.summarize(text, **kwargs)
                for text in batch_texts
            ]
            summaries.extend(batch_summaries)
        return summaries

4. 训练脚本 (app/train.py)

Python
"""
模型训练脚本
"""
import os
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    get_linear_schedule_with_warmup,
)
from tqdm import tqdm
from typing import Dict

from app.data.dataset import SummarizationDataset
from app.config import settings

def train_epoch(
    model: AutoModelForSeq2SeqLM,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: str,
) -> float:
    """
    训练一个epoch

    Args:
        model: 模型
        dataloader: 数据加载器
        optimizer: 优化器
        scheduler: 学习率调度器
        device: 设备

    Returns:
        平均损失
    """
    model.train()
    total_loss = 0

    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        # 移动数据到设备
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # 前向传播
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        loss = outputs.loss
        total_loss += loss.item()

        # 反向传播
        loss.backward()

        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0
        )

        # 更新参数
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # 更新进度条
        progress_bar.set_postfix({'loss': loss.item()})

    return total_loss / len(dataloader)

def train():
    """训练模型"""
    # 加载分词器
    tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)

    # 加载数据集
    train_dataset = SummarizationDataset(
        settings.TRAIN_DATA_PATH,
        tokenizer,
        settings.MAX_SOURCE_LENGTH,
        settings.MAX_TARGET_LENGTH,
    )

    val_dataset = SummarizationDataset(
        settings.VAL_DATA_PATH,
        tokenizer,
        settings.MAX_SOURCE_LENGTH,
        settings.MAX_TARGET_LENGTH,
    )

    # 创建数据加载器
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=settings.BATCH_SIZE,
        shuffle=True,
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=settings.BATCH_SIZE,
        shuffle=False,
    )

    # 加载模型
    model = AutoModelForSeq2SeqLM.from_pretrained(settings.MODEL_NAME)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # 创建优化器
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=settings.LEARNING_RATE,
    )

    # 创建学习率调度器
    total_steps = len(train_dataloader) * settings.NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=settings.WARMUP_STEPS,
        num_training_steps=total_steps,
    )

    # 训练循环
    best_val_loss = float('inf')
    for epoch in range(settings.NUM_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{settings.NUM_EPOCHS}")

        # 训练
        train_loss = train_epoch(
            model,
            train_dataloader,
            optimizer,
            scheduler,
            device,
        )
        print(f"Train Loss: {train_loss:.4f}")

        # 验证
        val_loss = evaluate(model, val_dataloader, device)
        print(f"Val Loss: {val_loss:.4f}")

        # 保存最佳模型
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_path = f"./models/checkpoints/best_model.pt"
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model to {save_path}")

def evaluate(
    model: AutoModelForSeq2SeqLM,
    dataloader: DataLoader,
    device: str,
) -> float:
    """评估模型"""
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            total_loss += outputs.loss.item()

    return total_loss / len(dataloader)

if __name__ == "__main__":
    train()

5. 评估指标 (utils/metrics.py)

Python
"""
摘要评估指标
"""
from rouge_score import rouge_scorer
from typing import List, Dict
import numpy as np

class SummarizationEvaluator:
    """摘要评估器"""

    def __init__(self):
        """初始化评估器"""
        self.scorer = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'],
            use_stemmer=True
        )

    def compute_rouge(
        self,
        predictions: List[str],
        references: List[str],
    ) -> Dict[str, float]:
        """
        计算ROUGE分数

        Args:
            predictions: 预测摘要列表
            references: 参考摘要列表

        Returns:
            ROUGE分数字典
        """
        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []

        for pred, ref in zip(predictions, references):  # zip按位置配对多个可迭代对象
            scores = self.scorer.score(ref, pred)
            rouge1_scores.append(scores['rouge1'].fmeasure)
            rouge2_scores.append(scores['rouge2'].fmeasure)
            rougeL_scores.append(scores['rougeL'].fmeasure)

        return {
            'rouge1': np.mean(rouge1_scores),
            'rouge2': np.mean(rouge2_scores),
            'rougeL': np.mean(rougeL_scores),
        }

    def evaluate(
        self,
        predictions: List[str],
        references: List[str],
    ) -> Dict[str, float]:
        """
        综合评估

        Args:
            predictions: 预测摘要列表
            references: 参考摘要列表

        Returns:
            评估结果字典
        """
        rouge_scores = self.compute_rouge(predictions, references)

        # 计算平均长度
        pred_lengths = [len(pred.split()) for pred in predictions]
        ref_lengths = [len(ref.split()) for ref in references]

        return {
            **rouge_scores,
            'avg_pred_length': np.mean(pred_lengths),
            'avg_ref_length': np.mean(ref_lengths),
        }

6. Streamlit前端 (frontend/app.py)

Python
"""
文本摘要系统Web界面
"""
import streamlit as st
from app.models.transformer import SummarizationModel
from app.config import settings

# 页面配置
st.set_page_config(
    page_title="文本摘要系统",
    page_icon="📝",
    layout="wide"
)

# 标题
st.title("📝 文本摘要系统")
st.markdown("---")

# 侧边栏
st.sidebar.header("参数设置")

# 摘要长度
max_length = st.sidebar.slider(
    "最大长度",
    min_value=50,
    max_value=500,
    value=256,
    step=10
)

min_length = st.sidebar.slider(
    "最小长度",
    min_value=10,
    max_value=100,
    value=56,
    step=5
)

# Beam数量
num_beams = st.sidebar.slider(
    "Beam Search数量",
    min_value=1,
    max_value=10,
    value=4,
    step=1
)

# 长度惩罚
length_penalty = st.sidebar.slider(
    "长度惩罚",
    min_value=0.5,
    max_value=2.0,
    value=1.0,
    step=0.1
)

# 加载模型
@st.cache_resource
def load_model():
    """加载模型"""
    return SummarizationModel(
        model_name=settings.MODEL_NAME,
    )

model = load_model()

# 输入方式
input_method = st.radio(
    "选择输入方式",
    ["文本输入", "文件上传"],
    horizontal=True
)

# 文本输入
if input_method == "文本输入":
    text = st.text_area(
        "输入文本",
        height=300,
        placeholder="请输入需要摘要的文本..."
    )

    # 生成摘要按钮
    if st.button("生成摘要", type="primary"):
        if text.strip():  # 链式调用:strip去除空白
            with st.spinner("正在生成摘要..."):
                summary = model.summarize(
                    text,
                    max_length=max_length,
                    min_length=min_length,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                )

                # 显示摘要
                st.success("摘要生成成功!")
                st.subheader("生成的摘要")
                st.write(summary)

                # 统计信息
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("原文长度", len(text.split()))
                with col2:
                    st.metric("摘要长度", len(summary.split()))
        else:
            st.warning("请输入文本!")

# 文件上传
else:
    uploaded_file = st.file_uploader(
        "上传文件",
        type=['txt', 'md'],
    )

    if uploaded_file is not None:
        text = uploaded_file.read().decode('utf-8')
        st.text_area("文件内容", text, height=200)

        if st.button("生成摘要", type="primary"):
            with st.spinner("正在生成摘要..."):
                summary = model.summarize(
                    text,
                    max_length=max_length,
                    min_length=min_length,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                )

                st.success("摘要生成成功!")
                st.subheader("生成的摘要")
                st.write(summary)

🧪 测试方法

1. 单元测试

Python
"""
单元测试示例
"""
import pytest
from app.models.transformer import SummarizationModel

def test_summarization():
    """测试摘要生成"""
    model = SummarizationModel()

    text = """
    人工智能是计算机科学的一个分支,它企图了解智能的实质,
    并生产出一种新的能以人类智能相似的方式做出反应的智能机器。
    该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。
    """

    summary = model.summarize(text)

    assert len(summary) > 0  # assert断言:条件False时抛出AssertionError
    assert len(summary) < len(text)
    print(f"Summary: {summary}")

2. 集成测试

Python
"""
集成测试示例
"""
def test_end_to_end():
    """端到端测试"""
    # 测试完整流程
    text = load_test_text()
    summary = model.summarize(text)
    scores = evaluator.evaluate([summary], [reference])

    assert scores['rouge1'] > 0.3

3. 性能测试

Python
"""
性能测试示例
"""
import time

def test_performance():
    """测试性能"""
    texts = [load_test_text() for _ in range(100)]

    start_time = time.time()
    summaries = model.batch_summarize(texts)
    end_time = time.time()

    avg_time = (end_time - start_time) / len(texts)
    print(f"Average time per text: {avg_time:.2f}s")

📊 扩展建议

1. 功能扩展

  • 多语言支持: 支持多语言文本摘要
  • 多文档摘要: 合并多个文档生成摘要
  • 关键词提取: 提取文本关键词
  • 摘要风格控制: 控制摘要的风格(正式/非正式)
  • 实时摘要: 支持流式文本实时摘要

2. 模型优化

  • 模型压缩: 使用知识蒸馏、量化等技术压缩模型
  • 加速推理: 使用ONNX、TensorRT等加速推理
  • 模型集成: 集成多个模型提升性能
  • 自适应摘要: 根据文本类型自动调整摘要策略

3. 数据增强

  • 数据扩充: 使用回译、同义词替换等方法扩充数据
  • 领域适应: 在特定领域数据上微调
  • 主动学习: 主动选择最有价值的样本进行标注

4. 系统优化

  • 缓存机制: 缓存常见文本的摘要
  • 异步处理: 使用异步处理提升性能
  • 分布式部署: 支持分布式部署
  • API优化: 优化API接口性能

📚 学习收获

完成本项目后,你将掌握:

  • ✅ Transformer架构的理解和应用
  • ✅ Seq2Seq模型的实现和训练
  • ✅ 文本预处理和数据增强技术
  • ✅ 模型微调和评估方法
  • ✅ ROUGE等评估指标的使用
  • ✅ Streamlit Web应用开发
  • ✅ 完整的NLP项目开发流程

🔗 参考资源


最后更新日期:2026-02-12 适用版本:LLM学习教程 v2026


项目完成时间: 8-12小时 难度等级: ⭐⭐ 中等 推荐指数: ⭐⭐⭐⭐⭐