项目2: 代码生成系统¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
难度: ⭐⭐⭐ 中等偏难 时间: 12-18小时 涉及知识: Code LLM、Fine-tuning、代码理解、AST解析
📖 项目概述¶
项目背景¶
代码生成是AI辅助编程的核心能力之一。基于大语言模型的代码生成系统能够根据自然语言描述或部分代码,自动生成完整的代码片段、函数甚至整个模块。这不仅能提高开发效率,还能帮助新手学习编程,减少重复劳动。
项目目标¶
构建一个完整的代码生成系统,能够: - 根据自然语言描述生成代码 - 根据函数签名和文档字符串生成实现 - 支持多种编程语言 - 生成可运行的、符合规范的代码 - 提供代码解释和注释 - 支持代码补全和优化建议
技术栈¶
- 深度学习框架: PyTorch
- 预训练模型: CodeLlama / StarCoder / GPT-NeoX
- 代码处理: tree-sitter, AST
- 评估工具: CodeBLEU, Pass@k
- Web框架: Streamlit
- IDE集成: VS Code Extension API
🏗️ 项目结构¶
Text Only
code-generation-system/
├── app/ # 应用主目录
│ ├── __init__.py
│ ├── config.py # 配置文件
│ ├── models/ # 模型定义
│ │ ├── __init__.py
│ │ ├── code_llm.py # 代码生成模型
│ │ └── preprocess.py # 预处理模块
│ ├── data/ # 数据处理
│ │ ├── __init__.py
│ │ ├── dataset.py # 数据集类
│ │ └── loader.py # 数据加载器
│ ├── train.py # 训练脚本
│ ├── evaluate.py # 评估脚本
│ └── api/ # API接口
│ ├── __init__.py
│ └── generate.py # 生成API
├── frontend/ # 前端目录
│ ├── app.py # Streamlit应用
│ └── components/ # UI组件
├── data/ # 数据目录
│ ├── raw/ # 原始数据
│ ├── processed/ # 处理后数据
│ └── cache/ # 缓存数据
├── models/ # 模型目录
│ ├── checkpoints/ # 模型检查点
│ └── pretrained/ # 预训练模型
├── tests/ # 测试目录
│ ├── test_models.py
│ ├── test_data.py
│ └── test_evaluation.py
├── utils/ # 工具函数
│ ├── __init__.py
│ ├── ast_utils.py # AST工具
│ ├── metrics.py # 评估指标
│ └── code_utils.py # 代码工具
├── requirements.txt # Python依赖
└── README.md # 项目说明
🎯 核心功能¶
1. 代码生成¶
- 自然语言到代码: 根据描述生成代码
- 函数实现: 根据签名和文档生成实现
- 代码补全: 根据上下文补全代码
- 代码转换: 在不同语言间转换代码
2. 代码理解¶
- 代码解释: 生成代码的自然语言解释
- 代码注释: 自动生成代码注释
- 代码摘要: 生成代码功能摘要
- Bug检测: 识别代码中的潜在问题
3. 代码优化¶
- 性能优化: 优化代码性能
- 代码重构: 重构代码结构
- 代码审查: 提供代码审查建议
- 最佳实践: 应用编程最佳实践
4. 多语言支持¶
- Python: Python代码生成
- JavaScript: JavaScript/TypeScript代码生成
- Java: Java代码生成
- C++: C++代码生成
5. 质量评估¶
- 语法检查: 检查代码语法正确性
- 可执行性: 验证代码可执行性
- 代码质量: 评估代码质量
- Pass@k: 评估生成成功率
💻 代码实现¶
1. 配置文件 (app/config.py)¶
Python
"""
代码生成系统配置文件
"""
import os
from typing import Optional
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""应用配置"""
# 模型配置
MODEL_NAME: str = "bigcode/starcoder2-15b"
MAX_LENGTH: int = 2048
TEMPERATURE: float = 0.2
TOP_P: float = 0.95
TOP_K: int = 50
# 训练配置
BATCH_SIZE: int = 2
LEARNING_RATE: float = 5e-5
NUM_EPOCHS: int = 3
WARMUP_STEPS: int = 100
GRADIENT_ACCUMULATION_STEPS: int = 4
# 数据配置
TRAIN_DATA_PATH: str = "./data/raw/train.json"
VAL_DATA_PATH: str = "./data/raw/val.json"
TEST_DATA_PATH: str = "./data/raw/test.json"
# 生成配置
MAX_NEW_TOKENS: int = 512
NUM_RETURN_SEQUENCES: int = 1
DO_SAMPLE: bool = True
# 评估配置
EVAL_STEPS: int = 500
SAVE_STEPS: int = 1000
K_VALUES: list = [1, 10, 100] # Pass@k的k值
# API配置
API_HOST: str = "0.0.0.0"
API_PORT: int = 8000
# 缓存配置
CACHE_DIR: str = "./data/cache"
class Config:
env_file = ".env"
settings = Settings()
2. 代码生成模型 (app/models/code_llm.py)¶
Python
"""
代码生成模型
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, Optional, List
class CodeGenerationModel:
"""代码生成模型"""
def __init__(
self,
model_name: str = "bigcode/starcoder2-15b",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
初始化模型
Args:
model_name: 预训练模型名称
device: 设备
"""
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None,
)
if device != "cuda":
self.model.to(device)
self.model.eval()
def generate(
self,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 50,
num_return_sequences: int = 1,
do_sample: bool = True,
stop_tokens: Optional[List[str]] = None,
) -> str:
"""
生成代码
Args:
prompt: 输入提示
max_new_tokens: 最大生成token数
temperature: 温度参数
top_p: nucleus sampling参数
top_k: top-k sampling参数
num_return_sequences: 返回序列数量
do_sample: 是否采样
stop_tokens: 停止token列表
Returns:
生成的代码
"""
# 编码输入
inputs = self.tokenizer(
prompt,
return_tensors='pt',
truncation=True,
max_length=2048 - max_new_tokens,
).to(self.device)
# 生成代码
with torch.no_grad():
outputs = self.model.generate(
inputs['input_ids'],
attention_mask=inputs['attention_mask'],
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_return_sequences=num_return_sequences,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
# 解码输出
generated_code = self.tokenizer.decode(
outputs[0],
skip_special_tokens=True
)
# 提取生成的部分
generated_code = generated_code[len(prompt):].strip()
# 处理停止token
if stop_tokens:
for stop_token in stop_tokens:
if stop_token in generated_code:
generated_code = generated_code.split(stop_token)[0]
return generated_code
def generate_function(
self,
function_signature: str,
docstring: str = "",
language: str = "python",
) -> str:
"""
生成函数实现
Args:
function_signature: 函数签名
docstring: 函数文档字符串
language: 编程语言
Returns:
生成的函数实现
"""
# 构建提示
if language == "python":
prompt = f"""```python
{function_signature}
\"\"\"
{docstring}
\"\"\"
```"""
else:
prompt = f"```{language}\n{function_signature}\n```\n"
# 生成代码
generated_code = self.generate(prompt)
return generated_code
def generate_from_description(
self,
description: str,
language: str = "python",
) -> str:
"""
根据描述生成代码
Args:
description: 自然语言描述
language: 编程语言
Returns:
生成的代码
"""
# 构建提示
prompt = f"""Write {language} code to:
{description}
```{language}
"""
# 生成代码
generated_code = self.generate(prompt)
return generated_code
def code_completion(
self,
prefix: str,
max_new_tokens: int = 256,
) -> str:
"""
代码补全
Args:
prefix: 代码前缀
max_new_tokens: 最大生成token数
Returns:
补全的代码
"""
return self.generate(
prefix,
max_new_tokens=max_new_tokens,
temperature=0.1, # 较低的温度以获得更确定的补全
)
def explain_code(
self,
code: str,
language: str = "python",
) -> str:
"""
解释代码
Args:
code: 代码
language: 编程语言
Returns:
代码解释
"""
prompt = f"""Explain the following {language} code:
```{language}
{code}
```text
Explanation:"""
explanation = self.generate(
prompt,
max_new_tokens=512,
temperature=0.3,
)
return explanation
def add_comments(
self,
code: str,
language: str = "python",
) -> str:
"""
添加代码注释
Args:
code: 代码
language: 编程语言
Returns:
带注释的代码
"""
prompt = f"""Add comments to the following {language} code:
```{language}
{code}
```"""
commented_code = self.generate(
prompt,
max_new_tokens=1024,
temperature=0.1,
)
return commented_code
3. 数据集类 (app/data/dataset.py)¶
Python
{item['code']} """
代码生成数据集类
"""
import json
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from typing import List, Dict, Any
class CodeGenerationDataset(Dataset):
"""代码生成数据集"""
def __init__(
self,
data_path: str,
tokenizer: AutoTokenizer,
max_length: int = 2048,
):
"""
初始化数据集
Args:
data_path: 数据文件路径
tokenizer: 分词器
max_length: 最大长度
"""
self.data = self._load_data(data_path)
self.tokenizer = tokenizer
self.max_length = max_length
def _load_data(self, data_path: str) -> List[Dict[str, Any]]:
"""加载数据"""
with open(data_path, 'r', encoding='utf-8') as f: # with自动管理文件关闭
data = [json.loads(line) for line in f] # json.loads将JSON字符串→Python对象
return data
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""获取数据项"""
item = self.data[idx]
# 构建提示
prompt = self._build_prompt(item)
# 编码
encoding = self.tokenizer(
prompt,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': encoding['input_ids'].flatten(),
}
def _build_prompt(self, item: Dict[str, Any]) -> str:
"""构建提示"""
if 'description' in item:
# 自然语言到代码
return f"""Write code to: {item['description']}
""" elif 'signature' in item: # 函数签名到实现 return f""" {item['signature']} \"\"\" {item.get('docstring', '')} \"\"\" {item['code']} 4. 评估指标 (utils/metrics.py)¶
Python
"""
代码生成评估指标
"""
import ast
import subprocess
import tempfile
import os
from typing import List, Dict, Tuple
import numpy as np
class CodeEvaluator:
"""代码评估器"""
def __init__(self, language: str = "python"):
"""
初始化评估器
Args:
language: 编程语言
"""
self.language = language
def check_syntax(self, code: str) -> bool:
"""
检查代码语法
Args:
code: 代码
Returns:
语法是否正确
"""
try: # try/except捕获异常,防止程序崩溃
if self.language == "python":
ast.parse(code)
return True
except:
return False
def check_executable(
self,
code: str,
timeout: int = 5,
) -> bool:
"""
检查代码可执行性
Args:
code: 代码
timeout: 超时时间
Returns:
代码是否可执行
"""
if self.language != "python":
return False
try:
# 创建临时文件
with tempfile.NamedTemporaryFile(
mode='w',
suffix='.py',
delete=False
) as f:
f.write(code)
temp_file = f.name
# 执行代码
result = subprocess.run(
['python', temp_file],
timeout=timeout,
capture_output=True,
)
# 删除临时文件
os.unlink(temp_file)
return result.returncode == 0
except:
return False
def compute_pass_at_k(
self,
predictions: List[List[str]],
references: List[str],
k: int = 1,
) -> float:
"""
计算Pass@k指标
Args:
predictions: 预测列表,每个元素是k个候选
references: 参考代码列表
k: k值
Returns:
Pass@k分数
"""
n = len(predictions)
c = 0
for preds, ref in zip(predictions, references): # zip按位置配对多个可迭代对象
# 检查k个候选中是否有正确的
correct = False
for pred in preds[:k]:
if self._is_correct(pred, ref):
correct = True
break
if correct:
c += 1
return c / n
def _is_correct(
self,
prediction: str,
reference: str,
) -> bool:
"""
检查预测是否正确
Args:
prediction: 预测代码
reference: 参考代码
Returns:
是否正确
"""
# 检查语法
if not self.check_syntax(prediction):
return False
# 检查可执行性
if not self.check_executable(prediction):
return False
# 这里可以添加更多的检查逻辑
# 例如:运行测试用例、比较输出等
return True
def evaluate(
self,
predictions: List[List[str]],
references: List[str],
k_values: List[int] = [1, 10, 100],
) -> Dict[str, float]:
"""
综合评估
Args:
predictions: 预测列表
references: 参考列表
k_values: k值列表
Returns:
评估结果
"""
results = {}
for k in k_values:
pass_at_k = self.compute_pass_at_k(
predictions,
references,
k=k,
)
results[f'pass_at_{k}'] = pass_at_k
# 计算语法正确率
syntax_correct = 0
for preds in predictions:
if any(self.check_syntax(pred) for pred in preds): # any()任一为True则返回True
syntax_correct += 1
results['syntax_correct_rate'] = syntax_correct / len(predictions)
return results
5. Streamlit前端 (frontend/app.py)¶
Python
"""
代码生成系统Web界面
"""
import streamlit as st
from app.models.code_llm import CodeGenerationModel
from app.config import settings
# 页面配置
st.set_page_config(
page_title="代码生成系统",
page_icon="💻",
layout="wide"
)
# 标题
st.title("💻 代码生成系统")
st.markdown("---")
# 侧边栏
st.sidebar.header("参数设置")
# 模型选择
model_name = st.sidebar.selectbox(
"选择模型",
["bigcode/starcoder2-15b", "bigcode/starcoder2-7b"],
)
# 生成参数
max_new_tokens = st.sidebar.slider(
"最大生成token数",
min_value=128,
max_value=1024,
value=512,
step=64
)
temperature = st.sidebar.slider(
"温度",
min_value=0.0,
max_value=1.0,
value=0.2,
step=0.1
)
top_p = st.sidebar.slider(
"Top P",
min_value=0.5,
max_value=1.0,
value=0.95,
step=0.05
)
# 编程语言
language = st.sidebar.selectbox(
"编程语言",
["python", "javascript", "java", "cpp"],
)
# 加载模型
@st.cache_resource
def load_model(model_name: str):
"""加载模型"""
return CodeGenerationModel(
model_name=model_name,
)
model = load_model(model_name)
# 选择功能
function = st.radio(
"选择功能",
["自然语言生成代码", "函数实现", "代码补全", "代码解释", "添加注释"],
horizontal=True
)
# 自然语言生成代码
if function == "自然语言生成代码":
description = st.text_area(
"输入代码描述",
height=150,
placeholder="例如:实现一个快速排序算法..."
)
if st.button("生成代码", type="primary"):
if description.strip(): # 链式调用:strip去除空白
with st.spinner("正在生成代码..."):
code = model.generate_from_description(
description,
language=language,
)
st.success("代码生成成功!")
st.code(code, language=language)
else:
st.warning("请输入代码描述!")
# 函数实现
elif function == "函数实现":
col1, col2 = st.columns(2)
with col1:
function_signature = st.text_area(
"函数签名",
height=100,
placeholder="def quick_sort(arr):"
)
with col2:
docstring = st.text_area(
"文档字符串",
height=100,
placeholder="使用快速排序算法对数组进行排序..."
)
if st.button("生成实现", type="primary"):
if function_signature.strip():
with st.spinner("正在生成实现..."):
code = model.generate_function(
function_signature,
docstring,
language,
)
st.success("实现生成成功!")
st.code(code, language=language)
else:
st.warning("请输入函数签名!")
# 代码补全
elif function == "代码补全":
prefix = st.text_area(
"输入代码前缀",
height=200,
placeholder="输入部分代码..."
)
if st.button("补全代码", type="primary"):
if prefix.strip():
with st.spinner("正在补全代码..."):
completion = model.code_completion(prefix)
st.success("代码补全成功!")
st.code(prefix + completion, language=language)
else:
st.warning("请输入代码前缀!")
# 代码解释
elif function == "代码解释":
code = st.text_area(
"输入代码",
height=200,
placeholder="输入需要解释的代码..."
)
if st.button("解释代码", type="primary"):
if code.strip():
with st.spinner("正在解释代码..."):
explanation = model.explain_code(code, language)
st.success("代码解释成功!")
st.write(explanation)
else:
st.warning("请输入代码!")
# 添加注释
elif function == "添加注释":
code = st.text_area(
"输入代码",
height=200,
placeholder="输入需要添加注释的代码..."
)
if st.button("添加注释", type="primary"):
if code.strip():
with st.spinner("正在添加注释..."):
commented_code = model.add_comments(code, language)
st.success("注释添加成功!")
st.code(commented_code, language=language)
else:
st.warning("请输入代码!")
🧪 测试方法¶
1. 单元测试¶
Python
"""
单元测试示例
"""
import pytest
from app.models.code_llm import CodeGenerationModel
def test_code_generation():
"""测试代码生成"""
model = CodeGenerationModel()
description = "实现一个快速排序算法"
code = model.generate_from_description(description)
assert len(code) > 0 # assert断言:条件False时抛出AssertionError
print(f"Generated code:\n{code}")
def test_function_generation():
"""测试函数生成"""
model = CodeGenerationModel()
signature = "def quick_sort(arr):"
docstring = "使用快速排序算法对数组进行排序"
code = model.generate_function(signature, docstring)
assert "def quick_sort" in code
print(f"Generated function:\n{code}")
2. 语法检查测试¶
Python
"""
语法检查测试
"""
from utils.metrics import CodeEvaluator
def test_syntax_check():
"""测试语法检查"""
evaluator = CodeEvaluator(language="python")
# 正确的代码
valid_code = "def hello():\n print('Hello, World!')"
assert evaluator.check_syntax(valid_code) == True
# 错误的代码
invalid_code = "def hello(\n print('Hello, World!')"
assert evaluator.check_syntax(invalid_code) == False
3. Pass@k测试¶
Python
"""
Pass@k测试
"""
def test_pass_at_k():
"""测试Pass@k指标"""
evaluator = CodeEvaluator()
predictions = [
["code1", "code2", "code3"],
["code4", "code5", "code6"],
]
references = ["ref1", "ref2"]
pass_at_1 = evaluator.compute_pass_at_k(predictions, references, k=1)
pass_at_10 = evaluator.compute_pass_at_k(predictions, references, k=10)
print(f"Pass@1: {pass_at_1}")
print(f"Pass@10: {pass_at_10}")
📊 扩展建议¶
1. 功能扩展¶
- 代码审查: 提供代码审查和改进建议
- Bug修复: 自动检测和修复代码bug
- 代码重构: 自动重构代码结构
- 测试生成: 自动生成单元测试
- 文档生成: 自动生成API文档
2. 模型优化¶
- 领域适配: 在特定领域数据上微调
- 模型压缩: 使用量化、蒸馏等技术
- 推理加速: 使用ONNX、TensorRT加速
- 多任务学习: 同时训练多个任务
3. IDE集成¶
- VS Code插件: 开发VS Code扩展
- JetBrains插件: 支持JetBrains IDE
- 命令行工具: 提供CLI工具
- API服务: 提供REST API
4. 数据增强¶
- 代码混淆: 通过混淆增强数据
- 风格迁移: 改变代码风格
- 跨语言转换: 在不同语言间转换
- 合成数据: 生成合成训练数据
📚 学习收获¶
完成本项目后,你将掌握:
- ✅ Code LLM的理解和应用
- ✅ 代码生成和补全技术
- ✅ AST解析和代码理解
- ✅ Pass@k等评估指标
- ✅ 模型微调和优化
- ✅ Streamlit Web应用开发
- ✅ 完整的AI辅助编程系统开发
🔗 参考资源¶
项目完成时间: 12-18小时 难度等级: ⭐⭐⭐ 中等偏难 推荐指数: ⭐⭐⭐⭐⭐
最后更新日期:2026-02-12 适用版本:LLM学习教程 v2026