跳转至

08 - 推理优化技术

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

大模型推理阶段的搜索策略、提示词优化与缓存机制

📌 交叉引用:通用大模型推理优化技术(含KV-Cache、PagedAttention、vLLM、量化部署等)的系统性讲解请参考 LLM应用/12-推理优化.md,本节侧重搜索策略、提示词优化与缓存机制。

📖 章节概述

本章将深入探讨大模型推理优化技术,包括搜索策略(贪心、束搜索、采样)、提示词优化和缓存机制等内容。这些技术可以显著提升推理速度和质量,适用于DeepSeek R1等推理模型。

🎯 学习目标

完成本章后,你将能够:

  • 掌握各种搜索策略的实现
  • 了解提示词优化的方法
  • 实现高效的缓存机制
  • 能够优化DeepSeek R1的推理性能

1. 搜索策略

1.1 贪心搜索

Python
import torch
import torch.nn.functional as F

def greedy_search(model, input_ids, max_length=50):
    """
    贪心搜索

    Args:
        model: 语言模型
        input_ids: 输入token IDs
        max_length: 最大生成长度
    """
    model.eval()  # eval()评估模式

    with torch.no_grad():  # 禁用梯度计算,节省内存
        current_ids = input_ids.clone()

        for _ in range(max_length):
            # 前向传播
            outputs = model(current_ids)
            logits = outputs.logits[:, -1, :]  # 最后一个token的logits

            # 选择概率最高的token
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # 拼接到序列
            current_ids = torch.cat([current_ids, next_token], dim=-1)  # torch.cat沿已有维度拼接张量

            # 检查是否生成结束token
            if next_token.item() == model.config.eos_token_id:  # 将单元素张量转为Python数值
                break

    return current_ids

# 使用示例
# output_ids = greedy_search(model, input_ids)
Python
import torch
import torch.nn.functional as F
import heapq

class BeamSearchNode:
    """
    束搜索节点
    """
    def __init__(self, token_ids, score, log_prob):
        self.token_ids = token_ids
        self.score = score
        self.log_prob = log_prob

    def __lt__(self, other):
        return self.score < other.score

def beam_search(model, input_ids, beam_width=5, max_length=50):
    """
    束搜索

    Args:
        model: 语言模型
        input_ids: 输入token IDs
        beam_width: 束宽
        max_length: 最大生成长度
    """
    model.eval()

    with torch.no_grad():
        # 初始化束
        beams = [BeamSearchNode(input_ids, 0.0, 0.0)]

        for step in range(max_length):
            new_beams = []

            # 扩展每个束
            for beam in beams:
                # 前向传播
                outputs = model(beam.token_ids)
                logits = outputs.logits[:, -1, :]

                # 计算log概率
                log_probs = F.log_softmax(logits, dim=-1)  # F.xxx PyTorch函数式API

                # 获取top-k候选
                topk_log_probs, topk_tokens = torch.topk(log_probs, beam_width, dim=-1)

                # 创建新束
                for i in range(beam_width):
                    new_token_ids = torch.cat([
                        beam.token_ids,
                        topk_tokens[:, i:i+1]
                    ], dim=-1)

                    new_score = beam.score + topk_log_probs[0, i].item()
                    new_log_prob = beam.log_prob + topk_log_probs[0, i].item()

                    new_beams.append(BeamSearchNode(
                        new_token_ids,
                        new_score,
                        new_log_prob
                    ))

            # 选择top-k束
            beams = heapq.nlargest(beam_width, new_beams)

            # 检查是否所有束都结束
            if all(beam.token_ids[0, -1].item() == model.config.eos_token_id  # all()全部为True才返回True
                   for beam in beams):
                break

        # 返回最佳束
        best_beam = max(beams, key=lambda x: x.score)  # lambda匿名函数
        return best_beam.token_ids

# 使用示例
# output_ids = beam_search(model, input_ids, beam_width=5)

1.3 采样策略

Python
import torch
import torch.nn.functional as F

def sample_with_temperature(model, input_ids, temperature=1.0, max_length=50):
    """
    温度采样

    Args:
        model: 语言模型
        input_ids: 输入token IDs
        temperature: 温度参数
        max_length: 最大生成长度
    """
    model.eval()

    with torch.no_grad():
        current_ids = input_ids.clone()

        for _ in range(max_length):
            # 前向传播
            outputs = model(current_ids)
            logits = outputs.logits[:, -1, :]

            # 应用温度
            logits = logits / temperature

            # 计算概率
            probs = F.softmax(logits, dim=-1)

            # 采样
            next_token = torch.multinomial(probs, num_samples=1)

            # 拼接到序列
            current_ids = torch.cat([current_ids, next_token], dim=-1)

            # 检查是否生成结束token
            if next_token.item() == model.config.eos_token_id:
                break

    return current_ids

def top_k_sampling(model, input_ids, top_k=50, max_length=50):
    """
    Top-K采样

    Args:
        model: 语言模型
        input_ids: 输入token IDs
        top_k: 保留的top-k数量
        max_length: 最大生成长度
    """
    model.eval()

    with torch.no_grad():
        current_ids = input_ids.clone()

        for _ in range(max_length):
            # 前向传播
            outputs = model(current_ids)
            logits = outputs.logits[:, -1, :]

            # Top-K过滤
            topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)

            # 计算概率
            probs = F.softmax(topk_logits, dim=-1)

            # 采样
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = topk_indices.gather(-1, next_token_idx)

            # 拼接到序列
            current_ids = torch.cat([current_ids, next_token], dim=-1)

            # 检查是否生成结束token
            if next_token.item() == model.config.eos_token_id:
                break

    return current_ids

def top_p_sampling(model, input_ids, top_p=0.9, max_length=50):
    """
    Top-P(核采样)

    Args:
        model: 语言模型
        input_ids: 输入token IDs
        top_p: 累积概率阈值
        max_length: 最大生成长度
    """
    model.eval()

    with torch.no_grad():
        current_ids = input_ids.clone()

        for _ in range(max_length):
            # 前向传播
            outputs = model(current_ids)
            logits = outputs.logits[:, -1, :]

            # 计算概率
            probs = F.softmax(logits, dim=-1)

            # 按概率排序
            sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

            # 计算累积概率
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # 找到累积概率超过top_p的位置
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # 移除低概率token
            indices_to_remove = sorted_indices_to_remove.scatter(
                1, sorted_indices, sorted_indices_to_remove
            )
            probs[indices_to_remove] = 0.0

            # 重新归一化
            probs = probs / probs.sum(dim=-1, keepdim=True)

            # 采样
            next_token = torch.multinomial(probs, num_samples=1)

            # 拼接到序列
            current_ids = torch.cat([current_ids, next_token], dim=-1)

            # 检查是否生成结束token
            if next_token.item() == model.config.eos_token_id:
                break

    return current_ids

# 使用示例
# output_ids = sample_with_temperature(model, input_ids, temperature=0.7)
# output_ids = top_k_sampling(model, input_ids, top_k=50)
# output_ids = top_p_sampling(model, input_ids, top_p=0.9)

2. 提示词优化

2.1 提示词模板

Python
class PromptTemplate:
    """
    提示词模板
    """
    def __init__(self, template):
        self.template = template

    def format(self, **kwargs):  # *args接收任意位置参数,**kwargs接收任意关键字参数
        """
        格式化提示词

        Args:
            **kwargs: 模板变量
        """
        return self.template.format(**kwargs)

# 预定义模板
TEMPLATES = {
    "qa": PromptTemplate(
        "问题:{question}\n答案:"
    ),
    "cot": PromptTemplate(
        "问题:{question}\n让我们一步步思考:\n"
    ),
    "few_shot": PromptTemplate(
        "例子1:{example1}\n例子2:{example2}\n问题:{question}\n答案:"
    ),
    "instruction": PromptTemplate(
        "指令:{instruction}\n输入:{input}\n输出:"
    )
}

# 使用示例
# template = TEMPLATES["cot"]
# prompt = template.format(question="什么是机器学习?")

2.2 提示词工程

Python
class PromptEngineer:
    """
    提示词工程师
    """
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def optimize_prompt(self, base_prompt, target_output,
                      iterations=10, learning_rate=0.01):
        """
        优化提示词

        Args:
            base_prompt: 基础提示词
            target_output: 目标输出
            iterations: 迭代次数
            learning_rate: 学习率
        """
        # 编码提示词
        prompt_ids = self.tokenizer.encode(base_prompt, return_tensors="pt")

        # 转换为可学习参数
        prompt_embeddings = self.model.get_input_embeddings()(prompt_ids)
        prompt_embeddings = prompt_embeddings.requires_grad_(True)

        # 优化器
        optimizer = torch.optim.Adam([prompt_embeddings], lr=learning_rate)

        # 编码目标输出
        target_ids = self.tokenizer.encode(target_output, return_tensors="pt")

        for iteration in range(iterations):
            # 前向传播
            outputs = self.model(inputs_embeds=prompt_embeddings)
            logits = outputs.logits

            # 计算损失
            loss = F.cross_entropy(  # F.cross_entropy PyTorch函数式交叉熵损失
                logits[:, -target_ids.shape[1]-1:-1, :],
                target_ids
            )

            # 反向传播
            optimizer.zero_grad()  # 清零梯度
            loss.backward()  # 反向传播计算梯度
            optimizer.step()  # 更新参数

            if iteration % 5 == 0:
                print(f"Iteration {iteration}, Loss: {loss.item():.4f}")

        # 解码优化后的提示词
        optimized_prompt_ids = torch.argmax(
            self.model.lm_head(prompt_embeddings),
            dim=-1
        )
        optimized_prompt = self.tokenizer.decode(optimized_prompt_ids[0])

        return optimized_prompt

# 使用示例
# engineer = PromptEngineer(model, tokenizer)
# optimized_prompt = engineer.optimize_prompt(
#     "解释什么是机器学习",
#     "机器学习是人工智能的一个分支...",
#     iterations=20
# )

2.3 提示词缓存

Python
from functools import lru_cache
import hashlib
import json

class PromptCache:
    """
    提示词缓存
    """
    def __init__(self, max_size=1000):
        self.cache = {}
        self.max_size = max_size
        self.access_count = {}

    def _generate_key(self, prompt: str, **kwargs) -> str:
        """
        生成缓存键
        """
        data = {"prompt": prompt, **kwargs}
        data_str = json.dumps(data, sort_keys=True)  # json.dumps将Python对象序列化为JSON字符串
        return hashlib.md5(data_str.encode()).hexdigest()

    def get(self, prompt: str, **kwargs):
        """
        获取缓存
        """
        key = self._generate_key(prompt, **kwargs)

        if key in self.cache:
            self.access_count[key] = self.access_count.get(key, 0) + 1
            return self.cache[key]

        return None

    def set(self, prompt: str, result: str, **kwargs):
        """
        设置缓存
        """
        key = self._generate_key(prompt, **kwargs)

        # 如果缓存已满,删除最少使用的条目
        if len(self.cache) >= self.max_size:
            # lambda作为min()的key函数:按访问计数找出最少使用的缓存条目进行LRU淘汰
            lru_key = min(self.access_count.keys(),
                         key=lambda k: self.access_count[k])
            del self.cache[lru_key]
            del self.access_count[lru_key]

        self.cache[key] = result
        self.access_count[key] = 1

    def clear(self):
        """
        清空缓存
        """
        self.cache.clear()
        self.access_count.clear()

# 使用示例
# cache = PromptCache(max_size=1000)
# result = cache.get("什么是机器学习?")
# if result is None:
#     result = generate_response("什么是机器学习?")
#     cache.set("什么是机器学习?", result)

3. 缓存机制

3.1 KV Cache

Python
import torch

class KVCache:
    """
    KV缓存
    """
    def __init__(self, max_batch_size, max_seq_len, num_heads, head_dim):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim

        # 初始化缓存
        # 注意:维度顺序为 (batch, num_heads, seq_len, head_dim)
        # 不同框架可能使用不同顺序:
        # - PyTorch nn.MultiheadAttention: (seq_len, batch, embed_dim)
        # - HuggingFace Transformers: (batch, num_heads, seq_len, head_dim)
        # - xformers: (batch, num_heads, seq_len, head_dim)
        self.k_cache = torch.zeros(
            max_batch_size, num_heads, max_seq_len, head_dim
        )
        self.v_cache = torch.zeros(
            max_batch_size, num_heads, max_seq_len, head_dim
        )

        # 当前序列长度
        self.current_seq_len = 0

    def update(self, k, v, batch_size, seq_len):
        """
        更新缓存

        Args:
            k: 键张量 [batch_size, num_heads, seq_len, head_dim]
            v: 值张量 [batch_size, num_heads, seq_len, head_dim]
            batch_size: 批次大小
            seq_len: 序列长度
        """
        # 更新缓存
        self.k_cache[:batch_size, :, self.current_seq_len:self.current_seq_len+seq_len, :] = k
        self.v_cache[:batch_size, :, self.current_seq_len:self.current_seq_len+seq_len, :] = v

        # 更新当前序列长度
        self.current_seq_len += seq_len

    def get(self, batch_size, seq_len):
        """
        获取缓存

        Returns:
            k_cache: 键缓存
            v_cache: 值缓存
        """
        return (
            self.k_cache[:batch_size, :, :self.current_seq_len, :],
            self.v_cache[:batch_size, :, :self.current_seq_len, :]
        )

    def reset(self):
        """
        重置缓存
        """
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.current_seq_len = 0

# 使用示例
# kv_cache = KVCache(max_batch_size=32, max_seq_len=2048, num_heads=12, head_dim=64)
# kv_cache.update(k, v, batch_size=32, seq_len=128)
# k_cached, v_cached = kv_cache.get(batch_size=32, seq_len=128)

3.2 推理缓存

Python
import torch

class InferenceCache:
    """
    推理缓存
    """
    def __init__(self, max_size=1000):
        self.cache: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
        self.max_size = max_size
        self.access_order = []

    def _generate_key(self, input_ids: torch.Tensor) -> str:
        """
        生成缓存键
        """
        return input_ids.cpu().numpy().tobytes().hex()

    def get(self, input_ids: torch.Tensor):
        """
        获取缓存

        Args:
            input_ids: 输入token IDs
        """
        key = self._generate_key(input_ids)

        if key in self.cache:
            # 更新访问顺序
            self.access_order.remove(key)
            self.access_order.append(key)

            return self.cache[key]

        return None

    def set(self, input_ids: torch.Tensor,
            k_cache: torch.Tensor, v_cache: torch.Tensor):
        """
        设置缓存

        Args:
            input_ids: 输入token IDs
            k_cache: 键缓存
            v_cache: 值缓存
        """
        key = self._generate_key(input_ids)

        # 如果缓存已满,删除最旧的条目
        if len(self.cache) >= self.max_size:
            oldest_key = self.access_order.pop(0)
            del self.cache[oldest_key]

        self.cache[key] = (k_cache, v_cache)
        self.access_order.append(key)

    def clear(self):
        """
        清空缓存
        """
        self.cache.clear()
        self.access_order.clear()

# 使用示例
# inference_cache = InferenceCache(max_size=1000)
# cached = inference_cache.get(input_ids)
# if cached is None:
#     k_cache, v_cache = compute_kv_cache(model, input_ids)
#     inference_cache.set(input_ids, k_cache, v_cache)

3.3 结果缓存

Python
import redis
import json
import hashlib
from datetime import datetime, timedelta

class ResultCache:
    """
    结果缓存(使用Redis)
    """
    def __init__(self, redis_host='localhost', redis_port=6379):
        self.redis = redis.Redis(host=redis_host, port=redis_port, db=0)

    def _generate_key(self, prompt: str, **kwargs) -> str:
        """
        生成缓存键
        """
        data = {"prompt": prompt, **kwargs}
        data_str = json.dumps(data, sort_keys=True)
        return f"result:{hashlib.md5(data_str.encode()).hexdigest()}"

    def get(self, prompt: str, **kwargs):
        """
        获取缓存结果

        Args:
            prompt: 提示词
            **kwargs: 其他参数
        """
        key = self._generate_key(prompt, **kwargs)

        cached = self.redis.get(key)
        if cached:
            return json.loads(cached.decode('utf-8'))  # json.loads将JSON字符串解析为Python对象

        return None

    def set(self, prompt: str, result: str, ttl=3600, **kwargs):
        """
        设置缓存结果

        Args:
            prompt: 提示词
            result: 结果
            ttl: 生存时间(秒)
            **kwargs: 其他参数
        """
        key = self._generate_key(prompt, **kwargs)

        data = {
            "result": result,
            "timestamp": datetime.now().isoformat()
        }

        self.redis.setex(key, ttl, json.dumps(data))

    def clear(self):
        """
        清空缓存
        """
        for key in self.redis.scan_iter("result:*"):
            self.redis.delete(key)

# 使用示例
# result_cache = ResultCache()
# cached_result = result_cache.get("什么是机器学习?", max_tokens=200)
# if cached_result is None:
#     result = generate_response("什么是机器学习?", max_tokens=200)
#     result_cache.set("什么是机器学习?", result, ttl=3600, max_tokens=200)

4. 练习题

基础练习

  1. 实现束搜索

    Python
    # TODO: 实现束搜索
    def beam_search(model, input_ids, beam_width=5, max_length=50):
        # 你的代码
        pass
    

  2. 实现Top-K采样

    Python
    # TODO: 实现Top-K采样
    def top_k_sampling(model, input_ids, top_k=50, max_length=50):
        # 你的代码
        pass
    

进阶练习

  1. 实现提示词优化

    Python
    # TODO: 实现提示词优化
    class PromptOptimizer:
        def __init__(self, model, tokenizer):
            # 你的代码
            pass
    
        def optimize(self, prompt, target_output):
            # 你的代码
            pass
    

  2. 实现KV缓存

    Python
    # TODO: 实现KV缓存
    class KVCache:
        def __init__(self, max_seq_len, num_heads, head_dim):
            # 你的代码
            pass
    
        def update(self, k, v):
            # 你的代码
            pass
    
        def get(self):
            # 你的代码
            pass
    

项目练习

  1. 创建推理优化框架
  2. 支持多种搜索策略
  3. 实现提示词优化
  4. 集成缓存机制

5. 最佳实践

✅ 推荐做法

  1. 选择合适的搜索策略
  2. 贪心搜索:快速但可能次优
  3. 束搜索:平衡质量和速度
  4. 采样:增加多样性

  5. 优化提示词

  6. 使用清晰的指令
  7. 提供示例
  8. 迭代优化

  9. 利用缓存

  10. 缓存KV值
  11. 缓存结果
  12. 设置合理的TTL

❌ 避免做法

  1. 过度依赖搜索
  2. 不要使用过大的束宽
  3. 考虑计算成本
  4. 平衡质量和速度

  5. 忽略提示词质量

  6. 花时间优化提示词
  7. 测试不同变体
  8. 收集反馈

  9. 缓存失效

  10. 定期清理缓存
  11. 设置合理的TTL
  12. 监控缓存命中率

6. 总结

本章介绍了推理优化的核心技术:

  • 搜索策略: 贪心、束搜索、采样
  • 提示词优化: 模板、工程、缓存
  • 缓存机制: KV缓存、推理缓存、结果缓存

这些技术可以显著提升DeepSeek R1的推理性能和质量。

7. 下一步

继续学习09-提示词工程与调优,深入了解提示词工程的实践方法。