跳转至

项目2: 代码生成系统

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

难度: ⭐⭐⭐ 中等偏难 时间: 12-18小时 涉及知识: Code LLM、Fine-tuning、代码理解、AST解析


📖 项目概述

项目背景

代码生成是AI辅助编程的核心能力之一。基于大语言模型的代码生成系统能够根据自然语言描述或部分代码,自动生成完整的代码片段、函数甚至整个模块。这不仅能提高开发效率,还能帮助新手学习编程,减少重复劳动。

项目目标

构建一个完整的代码生成系统,能够: - 根据自然语言描述生成代码 - 根据函数签名和文档字符串生成实现 - 支持多种编程语言 - 生成可运行的、符合规范的代码 - 提供代码解释和注释 - 支持代码补全和优化建议

技术栈

  • 深度学习框架: PyTorch
  • 预训练模型: CodeLlama / StarCoder / GPT-NeoX
  • 代码处理: tree-sitter, AST
  • 评估工具: CodeBLEU, Pass@k
  • Web框架: Streamlit
  • IDE集成: VS Code Extension API

🏗️ 项目结构

Text Only
code-generation-system/
├── app/                      # 应用主目录
│   ├── __init__.py
│   ├── config.py            # 配置文件
│   ├── models/              # 模型定义
│   │   ├── __init__.py
│   │   ├── code_llm.py      # 代码生成模型
│   │   └── preprocess.py    # 预处理模块
│   ├── data/                # 数据处理
│   │   ├── __init__.py
│   │   ├── dataset.py       # 数据集类
│   │   └── loader.py        # 数据加载器
│   ├── train.py             # 训练脚本
│   ├── evaluate.py          # 评估脚本
│   └── api/                 # API接口
│       ├── __init__.py
│       └── generate.py      # 生成API
├── frontend/                 # 前端目录
│   ├── app.py              # Streamlit应用
│   └── components/         # UI组件
├── data/                    # 数据目录
│   ├── raw/                # 原始数据
│   ├── processed/          # 处理后数据
│   └── cache/              # 缓存数据
├── models/                  # 模型目录
│   ├── checkpoints/        # 模型检查点
│   └── pretrained/         # 预训练模型
├── tests/                   # 测试目录
│   ├── test_models.py
│   ├── test_data.py
│   └── test_evaluation.py
├── utils/                   # 工具函数
│   ├── __init__.py
│   ├── ast_utils.py        # AST工具
│   ├── metrics.py          # 评估指标
│   └── code_utils.py       # 代码工具
├── requirements.txt         # Python依赖
└── README.md               # 项目说明

🎯 核心功能

1. 代码生成

  • 自然语言到代码: 根据描述生成代码
  • 函数实现: 根据签名和文档生成实现
  • 代码补全: 根据上下文补全代码
  • 代码转换: 在不同语言间转换代码

2. 代码理解

  • 代码解释: 生成代码的自然语言解释
  • 代码注释: 自动生成代码注释
  • 代码摘要: 生成代码功能摘要
  • Bug检测: 识别代码中的潜在问题

3. 代码优化

  • 性能优化: 优化代码性能
  • 代码重构: 重构代码结构
  • 代码审查: 提供代码审查建议
  • 最佳实践: 应用编程最佳实践

4. 多语言支持

  • Python: Python代码生成
  • JavaScript: JavaScript/TypeScript代码生成
  • Java: Java代码生成
  • C++: C++代码生成

5. 质量评估

  • 语法检查: 检查代码语法正确性
  • 可执行性: 验证代码可执行性
  • 代码质量: 评估代码质量
  • Pass@k: 评估生成成功率

💻 代码实现

1. 配置文件 (app/config.py)

Python
"""
代码生成系统配置文件
"""
import os
from typing import Optional
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    """应用配置"""

    # 模型配置
    MODEL_NAME: str = "bigcode/starcoder2-15b"
    MAX_LENGTH: int = 2048
    TEMPERATURE: float = 0.2
    TOP_P: float = 0.95
    TOP_K: int = 50

    # 训练配置
    BATCH_SIZE: int = 2
    LEARNING_RATE: float = 5e-5
    NUM_EPOCHS: int = 3
    WARMUP_STEPS: int = 100
    GRADIENT_ACCUMULATION_STEPS: int = 4

    # 数据配置
    TRAIN_DATA_PATH: str = "./data/raw/train.json"
    VAL_DATA_PATH: str = "./data/raw/val.json"
    TEST_DATA_PATH: str = "./data/raw/test.json"

    # 生成配置
    MAX_NEW_TOKENS: int = 512
    NUM_RETURN_SEQUENCES: int = 1
    DO_SAMPLE: bool = True

    # 评估配置
    EVAL_STEPS: int = 500
    SAVE_STEPS: int = 1000
    K_VALUES: list = [1, 10, 100]  # Pass@k的k值

    # API配置
    API_HOST: str = "0.0.0.0"
    API_PORT: int = 8000

    # 缓存配置
    CACHE_DIR: str = "./data/cache"

    class Config:
        env_file = ".env"

settings = Settings()

2. 代码生成模型 (app/models/code_llm.py)

Python
"""
代码生成模型
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, Optional, List

class CodeGenerationModel:
    """代码生成模型"""

    def __init__(
        self,
        model_name: str = "bigcode/starcoder2-15b",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        初始化模型

        Args:
            model_name: 预训练模型名称
            device: 设备
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
        )

        if device != "cuda":
            self.model.to(device)
        self.model.eval()

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 50,
        num_return_sequences: int = 1,
        do_sample: bool = True,
        stop_tokens: Optional[List[str]] = None,
    ) -> str:
        """
        生成代码

        Args:
            prompt: 输入提示
            max_new_tokens: 最大生成token数
            temperature: 温度参数
            top_p: nucleus sampling参数
            top_k: top-k sampling参数
            num_return_sequences: 返回序列数量
            do_sample: 是否采样
            stop_tokens: 停止token列表

        Returns:
            生成的代码
        """
        # 编码输入
        inputs = self.tokenizer(
            prompt,
            return_tensors='pt',
            truncation=True,
            max_length=2048 - max_new_tokens,
        ).to(self.device)

        # 生成代码
        with torch.no_grad():
            outputs = self.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # 解码输出
        generated_code = self.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True
        )

        # 提取生成的部分
        generated_code = generated_code[len(prompt):].strip()

        # 处理停止token
        if stop_tokens:
            for stop_token in stop_tokens:
                if stop_token in generated_code:
                    generated_code = generated_code.split(stop_token)[0]

        return generated_code

    def generate_function(
        self,
        function_signature: str,
        docstring: str = "",
        language: str = "python",
    ) -> str:
        """
        生成函数实现

        Args:
            function_signature: 函数签名
            docstring: 函数文档字符串
            language: 编程语言

        Returns:
            生成的函数实现
        """
        # 构建提示
        if language == "python":
            prompt = f"""```python
{function_signature}
    \"\"\"
    {docstring}
    \"\"\"
```"""
        else:
            prompt = f"```{language}\n{function_signature}\n```\n"

        # 生成代码
        generated_code = self.generate(prompt)

        return generated_code

    def generate_from_description(
        self,
        description: str,
        language: str = "python",
    ) -> str:
        """
        根据描述生成代码

        Args:
            description: 自然语言描述
            language: 编程语言

        Returns:
            生成的代码
        """
        # 构建提示
        prompt = f"""Write {language} code to:
{description}

```{language}
"""

        # 生成代码
        generated_code = self.generate(prompt)

        return generated_code

    def code_completion(
        self,
        prefix: str,
        max_new_tokens: int = 256,
    ) -> str:
        """
        代码补全

        Args:
            prefix: 代码前缀
            max_new_tokens: 最大生成token数

        Returns:
            补全的代码
        """
        return self.generate(
            prefix,
            max_new_tokens=max_new_tokens,
            temperature=0.1,  # 较低的温度以获得更确定的补全
        )

    def explain_code(
        self,
        code: str,
        language: str = "python",
    ) -> str:
        """
        解释代码

        Args:
            code: 代码
            language: 编程语言

        Returns:
            代码解释
        """
        prompt = f"""Explain the following {language} code:

```{language}
{code}
```text

Explanation:"""

        explanation = self.generate(
            prompt,
            max_new_tokens=512,
            temperature=0.3,
        )

        return explanation

    def add_comments(
        self,
        code: str,
        language: str = "python",
    ) -> str:
        """
        添加代码注释

        Args:
            code: 代码
            language: 编程语言

        Returns:
            带注释的代码
        """
        prompt = f"""Add comments to the following {language} code:

```{language}
{code}
```"""

        commented_code = self.generate(
            prompt,
            max_new_tokens=1024,
            temperature=0.1,
        )

        return commented_code

3. 数据集类 (app/data/dataset.py)

Python
"""
代码生成数据集类
"""
import json
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from typing import List, Dict, Any

class CodeGenerationDataset(Dataset):
    """代码生成数据集"""

    def __init__(
        self,
        data_path: str,
        tokenizer: AutoTokenizer,
        max_length: int = 2048,
    ):
        """
        初始化数据集

        Args:
            data_path: 数据文件路径
            tokenizer: 分词器
            max_length: 最大长度
        """
        self.data = self._load_data(data_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def _load_data(self, data_path: str) -> List[Dict[str, Any]]:
        """加载数据"""
        with open(data_path, 'r', encoding='utf-8') as f:  # with自动管理文件关闭
            data = [json.loads(line) for line in f]  # json.loads将JSON字符串→Python对象
        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """获取数据项"""
        item = self.data[idx]

        # 构建提示
        prompt = self._build_prompt(item)

        # 编码
        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': encoding['input_ids'].flatten(),
        }

    def _build_prompt(self, item: Dict[str, Any]) -> str:
        """构建提示"""
        if 'description' in item:
            # 自然语言到代码
            return f"""Write code to: {item['description']}
{item['code']} """ elif 'signature' in item: # 函数签名到实现 return f""" {item['signature']} \"\"\" {item.get('docstring', '')} \"\"\" {item['code']}
Python
        else:
            # 代码补全
            return item['prefix'] + item['completion']

4. 评估指标 (utils/metrics.py)

Python
"""
代码生成评估指标
"""
import ast
import subprocess
import tempfile
import os
from typing import List, Dict, Tuple
import numpy as np

class CodeEvaluator:
    """代码评估器"""

    def __init__(self, language: str = "python"):
        """
        初始化评估器

        Args:
            language: 编程语言
        """
        self.language = language

    def check_syntax(self, code: str) -> bool:
        """
        检查代码语法

        Args:
            code: 代码

        Returns:
            语法是否正确
        """
        try:  # try/except捕获异常,防止程序崩溃
            if self.language == "python":
                ast.parse(code)
            return True
        except:
            return False

    def check_executable(
        self,
        code: str,
        timeout: int = 5,
    ) -> bool:
        """
        检查代码可执行性

        Args:
            code: 代码
            timeout: 超时时间

        Returns:
            代码是否可执行
        """
        if self.language != "python":
            return False

        try:
            # 创建临时文件
            with tempfile.NamedTemporaryFile(
                mode='w',
                suffix='.py',
                delete=False
            ) as f:
                f.write(code)
                temp_file = f.name

            # 执行代码
            result = subprocess.run(
                ['python', temp_file],
                timeout=timeout,
                capture_output=True,
            )

            # 删除临时文件
            os.unlink(temp_file)

            return result.returncode == 0
        except:
            return False

    def compute_pass_at_k(
        self,
        predictions: List[List[str]],
        references: List[str],
        k: int = 1,
    ) -> float:
        """
        计算Pass@k指标

        Args:
            predictions: 预测列表,每个元素是k个候选
            references: 参考代码列表
            k: k值

        Returns:
            Pass@k分数
        """
        n = len(predictions)
        c = 0

        for preds, ref in zip(predictions, references):  # zip按位置配对多个可迭代对象
            # 检查k个候选中是否有正确的
            correct = False
            for pred in preds[:k]:
                if self._is_correct(pred, ref):
                    correct = True
                    break
            if correct:
                c += 1

        return c / n

    def _is_correct(
        self,
        prediction: str,
        reference: str,
    ) -> bool:
        """
        检查预测是否正确

        Args:
            prediction: 预测代码
            reference: 参考代码

        Returns:
            是否正确
        """
        # 检查语法
        if not self.check_syntax(prediction):
            return False

        # 检查可执行性
        if not self.check_executable(prediction):
            return False

        # 这里可以添加更多的检查逻辑
        # 例如:运行测试用例、比较输出等

        return True

    def evaluate(
        self,
        predictions: List[List[str]],
        references: List[str],
        k_values: List[int] = [1, 10, 100],
    ) -> Dict[str, float]:
        """
        综合评估

        Args:
            predictions: 预测列表
            references: 参考列表
            k_values: k值列表

        Returns:
            评估结果
        """
        results = {}

        for k in k_values:
            pass_at_k = self.compute_pass_at_k(
                predictions,
                references,
                k=k,
            )
            results[f'pass_at_{k}'] = pass_at_k

        # 计算语法正确率
        syntax_correct = 0
        for preds in predictions:
            if any(self.check_syntax(pred) for pred in preds):  # any()任一为True则返回True
                syntax_correct += 1
        results['syntax_correct_rate'] = syntax_correct / len(predictions)

        return results

5. Streamlit前端 (frontend/app.py)

Python
"""
代码生成系统Web界面
"""
import streamlit as st
from app.models.code_llm import CodeGenerationModel
from app.config import settings

# 页面配置
st.set_page_config(
    page_title="代码生成系统",
    page_icon="💻",
    layout="wide"
)

# 标题
st.title("💻 代码生成系统")
st.markdown("---")

# 侧边栏
st.sidebar.header("参数设置")

# 模型选择
model_name = st.sidebar.selectbox(
    "选择模型",
    ["bigcode/starcoder2-15b", "bigcode/starcoder2-7b"],
)

# 生成参数
max_new_tokens = st.sidebar.slider(
    "最大生成token数",
    min_value=128,
    max_value=1024,
    value=512,
    step=64
)

temperature = st.sidebar.slider(
    "温度",
    min_value=0.0,
    max_value=1.0,
    value=0.2,
    step=0.1
)

top_p = st.sidebar.slider(
    "Top P",
    min_value=0.5,
    max_value=1.0,
    value=0.95,
    step=0.05
)

# 编程语言
language = st.sidebar.selectbox(
    "编程语言",
    ["python", "javascript", "java", "cpp"],
)

# 加载模型
@st.cache_resource
def load_model(model_name: str):
    """加载模型"""
    return CodeGenerationModel(
        model_name=model_name,
    )

model = load_model(model_name)

# 选择功能
function = st.radio(
    "选择功能",
    ["自然语言生成代码", "函数实现", "代码补全", "代码解释", "添加注释"],
    horizontal=True
)

# 自然语言生成代码
if function == "自然语言生成代码":
    description = st.text_area(
        "输入代码描述",
        height=150,
        placeholder="例如:实现一个快速排序算法..."
    )

    if st.button("生成代码", type="primary"):
        if description.strip():  # 链式调用:strip去除空白
            with st.spinner("正在生成代码..."):
                code = model.generate_from_description(
                    description,
                    language=language,
                )

                st.success("代码生成成功!")
                st.code(code, language=language)
        else:
            st.warning("请输入代码描述!")

# 函数实现
elif function == "函数实现":
    col1, col2 = st.columns(2)

    with col1:
        function_signature = st.text_area(
            "函数签名",
            height=100,
            placeholder="def quick_sort(arr):"
        )

    with col2:
        docstring = st.text_area(
            "文档字符串",
            height=100,
            placeholder="使用快速排序算法对数组进行排序..."
        )

    if st.button("生成实现", type="primary"):
        if function_signature.strip():
            with st.spinner("正在生成实现..."):
                code = model.generate_function(
                    function_signature,
                    docstring,
                    language,
                )

                st.success("实现生成成功!")
                st.code(code, language=language)
        else:
            st.warning("请输入函数签名!")

# 代码补全
elif function == "代码补全":
    prefix = st.text_area(
        "输入代码前缀",
        height=200,
        placeholder="输入部分代码..."
    )

    if st.button("补全代码", type="primary"):
        if prefix.strip():
            with st.spinner("正在补全代码..."):
                completion = model.code_completion(prefix)

                st.success("代码补全成功!")
                st.code(prefix + completion, language=language)
        else:
            st.warning("请输入代码前缀!")

# 代码解释
elif function == "代码解释":
    code = st.text_area(
        "输入代码",
        height=200,
        placeholder="输入需要解释的代码..."
    )

    if st.button("解释代码", type="primary"):
        if code.strip():
            with st.spinner("正在解释代码..."):
                explanation = model.explain_code(code, language)

                st.success("代码解释成功!")
                st.write(explanation)
        else:
            st.warning("请输入代码!")

# 添加注释
elif function == "添加注释":
    code = st.text_area(
        "输入代码",
        height=200,
        placeholder="输入需要添加注释的代码..."
    )

    if st.button("添加注释", type="primary"):
        if code.strip():
            with st.spinner("正在添加注释..."):
                commented_code = model.add_comments(code, language)

                st.success("注释添加成功!")
                st.code(commented_code, language=language)
        else:
            st.warning("请输入代码!")

🧪 测试方法

1. 单元测试

Python
"""
单元测试示例
"""
import pytest
from app.models.code_llm import CodeGenerationModel

def test_code_generation():
    """测试代码生成"""
    model = CodeGenerationModel()

    description = "实现一个快速排序算法"
    code = model.generate_from_description(description)

    assert len(code) > 0  # assert断言:条件False时抛出AssertionError
    print(f"Generated code:\n{code}")

def test_function_generation():
    """测试函数生成"""
    model = CodeGenerationModel()

    signature = "def quick_sort(arr):"
    docstring = "使用快速排序算法对数组进行排序"
    code = model.generate_function(signature, docstring)

    assert "def quick_sort" in code
    print(f"Generated function:\n{code}")

2. 语法检查测试

Python
"""
语法检查测试
"""
from utils.metrics import CodeEvaluator

def test_syntax_check():
    """测试语法检查"""
    evaluator = CodeEvaluator(language="python")

    # 正确的代码
    valid_code = "def hello():\n    print('Hello, World!')"
    assert evaluator.check_syntax(valid_code) == True

    # 错误的代码
    invalid_code = "def hello(\n    print('Hello, World!')"
    assert evaluator.check_syntax(invalid_code) == False

3. Pass@k测试

Python
"""
Pass@k测试
"""
def test_pass_at_k():
    """测试Pass@k指标"""
    evaluator = CodeEvaluator()

    predictions = [
        ["code1", "code2", "code3"],
        ["code4", "code5", "code6"],
    ]
    references = ["ref1", "ref2"]

    pass_at_1 = evaluator.compute_pass_at_k(predictions, references, k=1)
    pass_at_10 = evaluator.compute_pass_at_k(predictions, references, k=10)

    print(f"Pass@1: {pass_at_1}")
    print(f"Pass@10: {pass_at_10}")

📊 扩展建议

1. 功能扩展

  • 代码审查: 提供代码审查和改进建议
  • Bug修复: 自动检测和修复代码bug
  • 代码重构: 自动重构代码结构
  • 测试生成: 自动生成单元测试
  • 文档生成: 自动生成API文档

2. 模型优化

  • 领域适配: 在特定领域数据上微调
  • 模型压缩: 使用量化、蒸馏等技术
  • 推理加速: 使用ONNX、TensorRT加速
  • 多任务学习: 同时训练多个任务

3. IDE集成

  • VS Code插件: 开发VS Code扩展
  • JetBrains插件: 支持JetBrains IDE
  • 命令行工具: 提供CLI工具
  • API服务: 提供REST API

4. 数据增强

  • 代码混淆: 通过混淆增强数据
  • 风格迁移: 改变代码风格
  • 跨语言转换: 在不同语言间转换
  • 合成数据: 生成合成训练数据

📚 学习收获

完成本项目后,你将掌握:

  • ✅ Code LLM的理解和应用
  • ✅ 代码生成和补全技术
  • ✅ AST解析和代码理解
  • ✅ Pass@k等评估指标
  • ✅ 模型微调和优化
  • ✅ Streamlit Web应用开发
  • ✅ 完整的AI辅助编程系统开发

🔗 参考资源


项目完成时间: 12-18小时 难度等级: ⭐⭐⭐ 中等偏难 推荐指数: ⭐⭐⭐⭐⭐


最后更新日期:2026-02-12 适用版本:LLM学习教程 v2026