08 - 推理优化技术¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
大模型推理阶段的搜索策略、提示词优化与缓存机制
📌 交叉引用:通用大模型推理优化技术(含KV-Cache、PagedAttention、vLLM、量化部署等)的系统性讲解请参考 LLM应用/12-推理优化.md,本节侧重搜索策略、提示词优化与缓存机制。
📖 章节概述¶
本章将深入探讨大模型推理优化技术,包括搜索策略(贪心、束搜索、采样)、提示词优化和缓存机制等内容。这些技术可以显著提升推理速度和质量,适用于DeepSeek R1等推理模型。
🎯 学习目标¶
完成本章后,你将能够:
- 掌握各种搜索策略的实现
- 了解提示词优化的方法
- 实现高效的缓存机制
- 能够优化DeepSeek R1的推理性能
1. 搜索策略¶
1.1 贪心搜索¶
Python
import torch
import torch.nn.functional as F
def greedy_search(model, input_ids, max_length=50):
"""
贪心搜索
Args:
model: 语言模型
input_ids: 输入token IDs
max_length: 最大生成长度
"""
model.eval() # eval()评估模式
with torch.no_grad(): # 禁用梯度计算,节省内存
current_ids = input_ids.clone()
for _ in range(max_length):
# 前向传播
outputs = model(current_ids)
logits = outputs.logits[:, -1, :] # 最后一个token的logits
# 选择概率最高的token
next_token = torch.argmax(logits, dim=-1, keepdim=True)
# 拼接到序列
current_ids = torch.cat([current_ids, next_token], dim=-1) # torch.cat沿已有维度拼接张量
# 检查是否生成结束token
if next_token.item() == model.config.eos_token_id: # 将单元素张量转为Python数值
break
return current_ids
# 使用示例
# output_ids = greedy_search(model, input_ids)
1.2 束搜索(Beam Search)¶
Python
import torch
import torch.nn.functional as F
import heapq
class BeamSearchNode:
"""
束搜索节点
"""
def __init__(self, token_ids, score, log_prob):
self.token_ids = token_ids
self.score = score
self.log_prob = log_prob
def __lt__(self, other):
return self.score < other.score
def beam_search(model, input_ids, beam_width=5, max_length=50):
"""
束搜索
Args:
model: 语言模型
input_ids: 输入token IDs
beam_width: 束宽
max_length: 最大生成长度
"""
model.eval()
with torch.no_grad():
# 初始化束
beams = [BeamSearchNode(input_ids, 0.0, 0.0)]
for step in range(max_length):
new_beams = []
# 扩展每个束
for beam in beams:
# 前向传播
outputs = model(beam.token_ids)
logits = outputs.logits[:, -1, :]
# 计算log概率
log_probs = F.log_softmax(logits, dim=-1) # F.xxx PyTorch函数式API
# 获取top-k候选
topk_log_probs, topk_tokens = torch.topk(log_probs, beam_width, dim=-1)
# 创建新束
for i in range(beam_width):
new_token_ids = torch.cat([
beam.token_ids,
topk_tokens[:, i:i+1]
], dim=-1)
new_score = beam.score + topk_log_probs[0, i].item()
new_log_prob = beam.log_prob + topk_log_probs[0, i].item()
new_beams.append(BeamSearchNode(
new_token_ids,
new_score,
new_log_prob
))
# 选择top-k束
beams = heapq.nlargest(beam_width, new_beams)
# 检查是否所有束都结束
if all(beam.token_ids[0, -1].item() == model.config.eos_token_id # all()全部为True才返回True
for beam in beams):
break
# 返回最佳束
best_beam = max(beams, key=lambda x: x.score) # lambda匿名函数
return best_beam.token_ids
# 使用示例
# output_ids = beam_search(model, input_ids, beam_width=5)
1.3 采样策略¶
Python
import torch
import torch.nn.functional as F
def sample_with_temperature(model, input_ids, temperature=1.0, max_length=50):
"""
温度采样
Args:
model: 语言模型
input_ids: 输入token IDs
temperature: 温度参数
max_length: 最大生成长度
"""
model.eval()
with torch.no_grad():
current_ids = input_ids.clone()
for _ in range(max_length):
# 前向传播
outputs = model(current_ids)
logits = outputs.logits[:, -1, :]
# 应用温度
logits = logits / temperature
# 计算概率
probs = F.softmax(logits, dim=-1)
# 采样
next_token = torch.multinomial(probs, num_samples=1)
# 拼接到序列
current_ids = torch.cat([current_ids, next_token], dim=-1)
# 检查是否生成结束token
if next_token.item() == model.config.eos_token_id:
break
return current_ids
def top_k_sampling(model, input_ids, top_k=50, max_length=50):
"""
Top-K采样
Args:
model: 语言模型
input_ids: 输入token IDs
top_k: 保留的top-k数量
max_length: 最大生成长度
"""
model.eval()
with torch.no_grad():
current_ids = input_ids.clone()
for _ in range(max_length):
# 前向传播
outputs = model(current_ids)
logits = outputs.logits[:, -1, :]
# Top-K过滤
topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
# 计算概率
probs = F.softmax(topk_logits, dim=-1)
# 采样
next_token_idx = torch.multinomial(probs, num_samples=1)
next_token = topk_indices.gather(-1, next_token_idx)
# 拼接到序列
current_ids = torch.cat([current_ids, next_token], dim=-1)
# 检查是否生成结束token
if next_token.item() == model.config.eos_token_id:
break
return current_ids
def top_p_sampling(model, input_ids, top_p=0.9, max_length=50):
"""
Top-P(核采样)
Args:
model: 语言模型
input_ids: 输入token IDs
top_p: 累积概率阈值
max_length: 最大生成长度
"""
model.eval()
with torch.no_grad():
current_ids = input_ids.clone()
for _ in range(max_length):
# 前向传播
outputs = model(current_ids)
logits = outputs.logits[:, -1, :]
# 计算概率
probs = F.softmax(logits, dim=-1)
# 按概率排序
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 计算累积概率
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# 找到累积概率超过top_p的位置
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# 移除低概率token
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
probs[indices_to_remove] = 0.0
# 重新归一化
probs = probs / probs.sum(dim=-1, keepdim=True)
# 采样
next_token = torch.multinomial(probs, num_samples=1)
# 拼接到序列
current_ids = torch.cat([current_ids, next_token], dim=-1)
# 检查是否生成结束token
if next_token.item() == model.config.eos_token_id:
break
return current_ids
# 使用示例
# output_ids = sample_with_temperature(model, input_ids, temperature=0.7)
# output_ids = top_k_sampling(model, input_ids, top_k=50)
# output_ids = top_p_sampling(model, input_ids, top_p=0.9)
2. 提示词优化¶
2.1 提示词模板¶
Python
class PromptTemplate:
"""
提示词模板
"""
def __init__(self, template):
self.template = template
def format(self, **kwargs): # *args接收任意位置参数,**kwargs接收任意关键字参数
"""
格式化提示词
Args:
**kwargs: 模板变量
"""
return self.template.format(**kwargs)
# 预定义模板
TEMPLATES = {
"qa": PromptTemplate(
"问题:{question}\n答案:"
),
"cot": PromptTemplate(
"问题:{question}\n让我们一步步思考:\n"
),
"few_shot": PromptTemplate(
"例子1:{example1}\n例子2:{example2}\n问题:{question}\n答案:"
),
"instruction": PromptTemplate(
"指令:{instruction}\n输入:{input}\n输出:"
)
}
# 使用示例
# template = TEMPLATES["cot"]
# prompt = template.format(question="什么是机器学习?")
2.2 提示词工程¶
Python
class PromptEngineer:
"""
提示词工程师
"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def optimize_prompt(self, base_prompt, target_output,
iterations=10, learning_rate=0.01):
"""
优化提示词
Args:
base_prompt: 基础提示词
target_output: 目标输出
iterations: 迭代次数
learning_rate: 学习率
"""
# 编码提示词
prompt_ids = self.tokenizer.encode(base_prompt, return_tensors="pt")
# 转换为可学习参数
prompt_embeddings = self.model.get_input_embeddings()(prompt_ids)
prompt_embeddings = prompt_embeddings.requires_grad_(True)
# 优化器
optimizer = torch.optim.Adam([prompt_embeddings], lr=learning_rate)
# 编码目标输出
target_ids = self.tokenizer.encode(target_output, return_tensors="pt")
for iteration in range(iterations):
# 前向传播
outputs = self.model(inputs_embeds=prompt_embeddings)
logits = outputs.logits
# 计算损失
loss = F.cross_entropy( # F.cross_entropy PyTorch函数式交叉熵损失
logits[:, -target_ids.shape[1]-1:-1, :],
target_ids
)
# 反向传播
optimizer.zero_grad() # 清零梯度
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新参数
if iteration % 5 == 0:
print(f"Iteration {iteration}, Loss: {loss.item():.4f}")
# 解码优化后的提示词
optimized_prompt_ids = torch.argmax(
self.model.lm_head(prompt_embeddings),
dim=-1
)
optimized_prompt = self.tokenizer.decode(optimized_prompt_ids[0])
return optimized_prompt
# 使用示例
# engineer = PromptEngineer(model, tokenizer)
# optimized_prompt = engineer.optimize_prompt(
# "解释什么是机器学习",
# "机器学习是人工智能的一个分支...",
# iterations=20
# )
2.3 提示词缓存¶
Python
from functools import lru_cache
import hashlib
import json
class PromptCache:
"""
提示词缓存
"""
def __init__(self, max_size=1000):
self.cache = {}
self.max_size = max_size
self.access_count = {}
def _generate_key(self, prompt: str, **kwargs) -> str:
"""
生成缓存键
"""
data = {"prompt": prompt, **kwargs}
data_str = json.dumps(data, sort_keys=True) # json.dumps将Python对象序列化为JSON字符串
return hashlib.md5(data_str.encode()).hexdigest()
def get(self, prompt: str, **kwargs):
"""
获取缓存
"""
key = self._generate_key(prompt, **kwargs)
if key in self.cache:
self.access_count[key] = self.access_count.get(key, 0) + 1
return self.cache[key]
return None
def set(self, prompt: str, result: str, **kwargs):
"""
设置缓存
"""
key = self._generate_key(prompt, **kwargs)
# 如果缓存已满,删除最少使用的条目
if len(self.cache) >= self.max_size:
# lambda作为min()的key函数:按访问计数找出最少使用的缓存条目进行LRU淘汰
lru_key = min(self.access_count.keys(),
key=lambda k: self.access_count[k])
del self.cache[lru_key]
del self.access_count[lru_key]
self.cache[key] = result
self.access_count[key] = 1
def clear(self):
"""
清空缓存
"""
self.cache.clear()
self.access_count.clear()
# 使用示例
# cache = PromptCache(max_size=1000)
# result = cache.get("什么是机器学习?")
# if result is None:
# result = generate_response("什么是机器学习?")
# cache.set("什么是机器学习?", result)
3. 缓存机制¶
3.1 KV Cache¶
Python
import torch
class KVCache:
"""
KV缓存
"""
def __init__(self, max_batch_size, max_seq_len, num_heads, head_dim):
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
# 初始化缓存
# 注意:维度顺序为 (batch, num_heads, seq_len, head_dim)
# 不同框架可能使用不同顺序:
# - PyTorch nn.MultiheadAttention: (seq_len, batch, embed_dim)
# - HuggingFace Transformers: (batch, num_heads, seq_len, head_dim)
# - xformers: (batch, num_heads, seq_len, head_dim)
self.k_cache = torch.zeros(
max_batch_size, num_heads, max_seq_len, head_dim
)
self.v_cache = torch.zeros(
max_batch_size, num_heads, max_seq_len, head_dim
)
# 当前序列长度
self.current_seq_len = 0
def update(self, k, v, batch_size, seq_len):
"""
更新缓存
Args:
k: 键张量 [batch_size, num_heads, seq_len, head_dim]
v: 值张量 [batch_size, num_heads, seq_len, head_dim]
batch_size: 批次大小
seq_len: 序列长度
"""
# 更新缓存
self.k_cache[:batch_size, :, self.current_seq_len:self.current_seq_len+seq_len, :] = k
self.v_cache[:batch_size, :, self.current_seq_len:self.current_seq_len+seq_len, :] = v
# 更新当前序列长度
self.current_seq_len += seq_len
def get(self, batch_size, seq_len):
"""
获取缓存
Returns:
k_cache: 键缓存
v_cache: 值缓存
"""
return (
self.k_cache[:batch_size, :, :self.current_seq_len, :],
self.v_cache[:batch_size, :, :self.current_seq_len, :]
)
def reset(self):
"""
重置缓存
"""
self.k_cache.zero_()
self.v_cache.zero_()
self.current_seq_len = 0
# 使用示例
# kv_cache = KVCache(max_batch_size=32, max_seq_len=2048, num_heads=12, head_dim=64)
# kv_cache.update(k, v, batch_size=32, seq_len=128)
# k_cached, v_cached = kv_cache.get(batch_size=32, seq_len=128)
3.2 推理缓存¶
Python
import torch
class InferenceCache:
"""
推理缓存
"""
def __init__(self, max_size=1000):
self.cache: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
self.max_size = max_size
self.access_order = []
def _generate_key(self, input_ids: torch.Tensor) -> str:
"""
生成缓存键
"""
return input_ids.cpu().numpy().tobytes().hex()
def get(self, input_ids: torch.Tensor):
"""
获取缓存
Args:
input_ids: 输入token IDs
"""
key = self._generate_key(input_ids)
if key in self.cache:
# 更新访问顺序
self.access_order.remove(key)
self.access_order.append(key)
return self.cache[key]
return None
def set(self, input_ids: torch.Tensor,
k_cache: torch.Tensor, v_cache: torch.Tensor):
"""
设置缓存
Args:
input_ids: 输入token IDs
k_cache: 键缓存
v_cache: 值缓存
"""
key = self._generate_key(input_ids)
# 如果缓存已满,删除最旧的条目
if len(self.cache) >= self.max_size:
oldest_key = self.access_order.pop(0)
del self.cache[oldest_key]
self.cache[key] = (k_cache, v_cache)
self.access_order.append(key)
def clear(self):
"""
清空缓存
"""
self.cache.clear()
self.access_order.clear()
# 使用示例
# inference_cache = InferenceCache(max_size=1000)
# cached = inference_cache.get(input_ids)
# if cached is None:
# k_cache, v_cache = compute_kv_cache(model, input_ids)
# inference_cache.set(input_ids, k_cache, v_cache)
3.3 结果缓存¶
Python
import redis
import json
import hashlib
from datetime import datetime, timedelta
class ResultCache:
"""
结果缓存(使用Redis)
"""
def __init__(self, redis_host='localhost', redis_port=6379):
self.redis = redis.Redis(host=redis_host, port=redis_port, db=0)
def _generate_key(self, prompt: str, **kwargs) -> str:
"""
生成缓存键
"""
data = {"prompt": prompt, **kwargs}
data_str = json.dumps(data, sort_keys=True)
return f"result:{hashlib.md5(data_str.encode()).hexdigest()}"
def get(self, prompt: str, **kwargs):
"""
获取缓存结果
Args:
prompt: 提示词
**kwargs: 其他参数
"""
key = self._generate_key(prompt, **kwargs)
cached = self.redis.get(key)
if cached:
return json.loads(cached.decode('utf-8')) # json.loads将JSON字符串解析为Python对象
return None
def set(self, prompt: str, result: str, ttl=3600, **kwargs):
"""
设置缓存结果
Args:
prompt: 提示词
result: 结果
ttl: 生存时间(秒)
**kwargs: 其他参数
"""
key = self._generate_key(prompt, **kwargs)
data = {
"result": result,
"timestamp": datetime.now().isoformat()
}
self.redis.setex(key, ttl, json.dumps(data))
def clear(self):
"""
清空缓存
"""
for key in self.redis.scan_iter("result:*"):
self.redis.delete(key)
# 使用示例
# result_cache = ResultCache()
# cached_result = result_cache.get("什么是机器学习?", max_tokens=200)
# if cached_result is None:
# result = generate_response("什么是机器学习?", max_tokens=200)
# result_cache.set("什么是机器学习?", result, ttl=3600, max_tokens=200)
4. 练习题¶
基础练习¶
-
实现束搜索
-
实现Top-K采样
进阶练习¶
-
实现提示词优化
-
实现KV缓存
项目练习¶
- 创建推理优化框架
- 支持多种搜索策略
- 实现提示词优化
- 集成缓存机制
5. 最佳实践¶
✅ 推荐做法¶
- 选择合适的搜索策略
- 贪心搜索:快速但可能次优
- 束搜索:平衡质量和速度
-
采样:增加多样性
-
优化提示词
- 使用清晰的指令
- 提供示例
-
迭代优化
-
利用缓存
- 缓存KV值
- 缓存结果
- 设置合理的TTL
❌ 避免做法¶
- 过度依赖搜索
- 不要使用过大的束宽
- 考虑计算成本
-
平衡质量和速度
-
忽略提示词质量
- 花时间优化提示词
- 测试不同变体
-
收集反馈
-
缓存失效
- 定期清理缓存
- 设置合理的TTL
- 监控缓存命中率
6. 总结¶
本章介绍了推理优化的核心技术:
- 搜索策略: 贪心、束搜索、采样
- 提示词优化: 模板、工程、缓存
- 缓存机制: KV缓存、推理缓存、结果缓存
这些技术可以显著提升DeepSeek R1的推理性能和质量。
7. 下一步¶
继续学习09-提示词工程与调优,深入了解提示词工程的实践方法。