跳转至

推理优化

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

推理优化图

📌 定位说明:本章侧重推理优化的工程部署实践。 - 📖 推理优化的算法原理与技术研究请参考 LLM学习/02-大模型核心技术/02-推理优化技术

📖 章节导读

推理优化是通过各种技术和方法提高大模型推理速度、降低资源消耗的关键技术。本章将深入探讨推理优化的原理、方法和实践。

🎯 学习目标

  • 理解推理优化的核心原理
  • 掌握KV Cache等关键技术
  • 学会使用批处理和流水线优化
  • 了解编译优化技术
  • 掌握 vLLM 与 SGLang 推理框架的使用与选型
  • 理解 Speculative Decoding 的原理与加速效果
  • 掌握大厂面试中的相关问题

12.1 推理优化概述

12.1.1 为什么需要推理优化

挑战:

  1. 计算成本高:
  2. 大模型参数量大
  3. 计算复杂度高
  4. 需要大量GPU资源

  5. 延迟要求高:

  6. 实时应用需要低延迟
  7. 用户体验要求快速响应
  8. 竞争需要高性能

  9. 并发需求大:

  10. 多用户同时访问
  11. 需要处理高并发
  12. 需要弹性扩展

  13. 成本压力大:

  14. GPU成本高
  15. 能耗成本高
  16. 需要成本优化

优化目标:

  1. 降低延迟:减少推理时间
  2. 提高吞吐量:增加并发处理能力
  3. 降低成本:减少资源消耗
  4. 保持精度:优化不损失精度

12.1.2 优化层次

优化层次:

Text Only
算法层 → 模型层 → 框架层 → 硬件层

各层优化:

  1. 算法层:
  2. KV Cache
  3. Flash Attention
  4. 算法优化

  5. 模型层:

  6. 模型量化
  7. 模型剪枝
  8. 知识蒸馏

  9. 框架层:

  10. 框架优化
  11. 算子融合
  12. 内存优化

  13. 硬件层:

  14. GPU优化
  15. 专用芯片
  16. 分布式计算

12.1.3 优化指标

关键指标:

  1. 延迟(Latency):
  2. 单次推理时间
  3. 首字延迟(TTFT)
  4. 总生成时间

  5. 吞吐量(Throughput):

  6. 每秒请求数(RPS)
  7. 每秒Token数
  8. 并发处理能力

  9. 资源利用率:

  10. GPU利用率
  11. 显存占用
  12. 能耗

  13. 成本:

  14. 每次推理成本
  15. 每小时成本
  16. 总成本

12.2 KV Cache

12.2.1 KV Cache原理

定义:KV Cache是一种缓存机制,用于缓存Transformer模型中自注意力层的Key和Value矩阵,避免重复计算。

原理:

在自回归生成过程中,每个token的生成都需要计算之前所有token的注意力。KV Cache将之前计算的K和V缓存起来,后续生成时直接使用,避免重复计算。

数学表示:

对于第t个token,注意力计算为:

Text Only
Attention(Q_t, K_{1:t}, V_{1:t})

使用KV Cache后:

Text Only
Attention(Q_t, Cache_K, Cache_V)

其中: - Cache_K: 缓存的K矩阵 - Cache_V: 缓存的V矩阵

优化效果:

  • 计算量:减少约50%
  • 显存占用:增加约30%
  • 速度提升:2-3倍

12.2.2 KV Cache实现

基础实现:

Python
import torch
import torch.nn as nn

class KVCache:
    """KV Cache实现"""

    def __init__(self, max_batch_size: int, max_seq_len: int, num_heads: int, head_dim: int):
        """
        初始化KV Cache

        Args:
            max_batch_size: 最大批次大小
            max_seq_len: 最大序列长度
            num_heads: 注意力头数
            head_dim: 每个头的维度
        """
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim

        # 初始化缓存
        self.cache_k = torch.zeros(
            max_batch_size,
            num_heads,
            max_seq_len,
            head_dim
        )
        self.cache_v = torch.zeros(
            max_batch_size,
            num_heads,
            max_seq_len,
            head_dim
        )

        # 当前位置
        self.current_pos = 0

    def update(self, k: torch.Tensor, v: torch.Tensor):
        """
        更新缓存

        Args:
            k: Key张量 [batch_size, num_heads, seq_len, head_dim]
            v: Value张量 [batch_size, num_heads, seq_len, head_dim]
        """
        batch_size = k.size(0)
        seq_len = k.size(2)

        # 更新缓存
        self.cache_k[:batch_size, :, self.current_pos:self.current_pos+seq_len, :] = k
        self.cache_v[:batch_size, :, self.current_pos:self.current_pos+seq_len, :] = v

        # 更新位置
        self.current_pos += seq_len

    def get(self) -> tuple:
        """
        获取缓存

        Returns:
            (k, v) 缓存的Key和Value
        """
        k = self.cache_k[:, :, :self.current_pos, :]
        v = self.cache_v[:, :, :self.current_pos, :]
        return k, v

    def clear(self):
        """清空缓存"""
        self.current_pos = 0
        self.cache_k.zero_()
        self.cache_v.zero_()

# 使用示例
kv_cache = KVCache(
    max_batch_size=1,
    max_seq_len=1024,
    num_heads=12,
    head_dim=64
)

# 生成过程
for i in range(10):
    # 模拟计算K和V
    k = torch.randn(1, 12, 1, 64)
    v = torch.randn(1, 12, 1, 64)

    # 更新缓存
    kv_cache.update(k, v)

    # 获取缓存用于注意力计算
    cached_k, cached_v = kv_cache.get()

    # 使用缓存的K和V进行注意力计算
    # ...

12.2.3 KV Cache优化

优化策略:

  1. PagedAttention:
  2. 将KV Cache分页管理
  3. 动态分配显存
  4. 减少显存浪费

  5. KV Cache压缩:

  6. 压缩KV Cache
  7. 减少显存占用
  8. 平衡精度和效率

  9. 多级缓存:

  10. 使用多级缓存
  11. 提高缓存命中率
  12. 减少计算量

代码实现:

Python
class PagedKVCache:
    """分页KV Cache"""

    def __init__(self, num_heads: int, head_dim: int, page_size: int = 128):
        """
        初始化分页KV Cache

        Args:
            num_heads: 注意力头数
            head_dim: 每个头的维度
            page_size: 页大小
        """
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.page_size = page_size

        # 页表
        self.page_table = {}

        # 页池
        self.page_pool_k = []
        self.page_pool_v = []

    def allocate_page(self) -> int:
        """分配页"""
        if not self.page_pool_k:
            # 创建新页
            page_k = torch.zeros(1, self.num_heads, self.page_size, self.head_dim)
            page_v = torch.zeros(1, self.num_heads, self.page_size, self.head_dim)
            self.page_pool_k.append(page_k)
            self.page_pool_v.append(page_v)
            return len(self.page_pool_k) - 1
        else:
            # 复用空闲页
            return self._find_free_page()

    def update(self, k: torch.Tensor, v: torch.Tensor, seq_idx: int):
        """
        更新缓存

        Args:
            k: Key张量
            v: Value张量
            seq_idx: 序列索引
        """
        # 计算页索引
        page_idx = seq_idx // self.page_size
        offset = seq_idx % self.page_size

        # 分配页
        if page_idx not in self.page_table:
            self.page_table[page_idx] = self.allocate_page()

        # 更新页
        actual_page_idx = self.page_table[page_idx]
        self.page_pool_k[actual_page_idx][:, :, offset:offset+1, :] = k
        self.page_pool_v[actual_page_idx][:, :, offset:offset+1, :] = v

    def get(self, seq_len: int) -> tuple:
        """
        获取缓存

        Args:
            seq_len: 序列长度

        Returns:
            (k, v) 缓存的Key和Value
        """
        # 计算需要的页数
        num_pages = (seq_len + self.page_size - 1) // self.page_size

        # 收集页
        k_pages = []
        v_pages = []

        for page_idx in range(num_pages):
            if page_idx in self.page_table:
                actual_page_idx = self.page_table[page_idx]
                k_pages.append(self.page_pool_k[actual_page_idx])
                v_pages.append(self.page_pool_v[actual_page_idx])

        # 拼接页
        k = torch.cat(k_pages, dim=2)[:, :, :seq_len, :]
        v = torch.cat(v_pages, dim=2)[:, :, :seq_len, :]

        return k, v

12.3 批处理优化

12.3.1 批处理原理

定义:批处理是将多个请求组合成一个批次一起处理,以提高GPU利用率和吞吐量。

优势:

  1. 提高GPU利用率:
  2. GPU擅长并行计算
  3. 批处理充分利用GPU
  4. 提高计算效率

  5. 降低延迟:

  6. 分摊固定开销
  7. 减少启动时间
  8. 提高吞吐量

  9. 降低成本:

  10. 提高资源利用率
  11. 降低每次推理成本
  12. 优化整体成本

挑战:

  1. 序列长度不同:
  2. 不同请求长度不同
  3. 需要Padding
  4. 浪费计算资源

  5. 动态批次:

  6. 请求到达时间不同
  7. 需要动态组批
  8. 增加复杂度

12.3.2 连续批处理

原理:连续批处理(Continuous Batching)允许不同长度的序列在同一个批次中处理,避免Padding浪费。

实现:

Python
import torch
from collections import deque

class ContinuousBatchProcessor:
    """连续批处理器"""

    def __init__(self, max_batch_size: int, max_seq_len: int):
        """
        初始化连续批处理器

        Args:
            max_batch_size: 最大批次大小
            max_seq_len: 最大序列长度
        """
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        # 请求队列
        self.request_queue = deque()  # deque双端队列,两端操作O(1)

        # 活跃请求
        self.active_requests = {}

    def add_request(self, request_id: str, input_ids: torch.Tensor):
        """
        添加请求

        Args:
            request_id: 请求ID
            input_ids: 输入token IDs
        """
        self.request_queue.append({
            'id': request_id,
            'input_ids': input_ids,
            'seq_len': input_ids.size(0)
        })

    def get_batch(self) -> tuple[torch.Tensor, list[str]]:
        """
        获取批次

        Returns:
            (batch_input_ids, request_ids)
        """
        if not self.request_queue:
            return None, []

        # 选择请求组成批次
        batch_requests = []
        total_tokens = 0

        while self.request_queue and len(batch_requests) < self.max_batch_size:
            request = self.request_queue[0]

            # 检查是否超过最大序列长度
            if total_tokens + request['seq_len'] > self.max_seq_len:
                break

            # 添加到批次
            batch_requests.append(self.request_queue.popleft())
            total_tokens += request['seq_len']

        if not batch_requests:
            return None, []

        # 准备批次数据
        max_len = max(req['seq_len'] for req in batch_requests)
        batch_input_ids = torch.zeros(
            len(batch_requests),
            max_len,
            dtype=torch.long
        )

        request_ids = []
        for i, req in enumerate(batch_requests):  # enumerate同时获取索引和元素
            batch_input_ids[i, :req['seq_len']] = req['input_ids']
            request_ids.append(req['id'])

        return batch_input_ids, request_ids

    def complete_request(self, request_id: str):
        """
        完成请求

        Args:
            request_id: 请求ID
        """
        if request_id in self.active_requests:
            del self.active_requests[request_id]

# 使用示例
processor = ContinuousBatchProcessor(
    max_batch_size=4,
    max_seq_len=1024
)

# 添加请求
processor.add_request("req1", torch.randint(0, 1000, (50,)))
processor.add_request("req2", torch.randint(0, 1000, (100,)))
processor.add_request("req3", torch.randint(0, 1000, (75,)))

# 获取批次
batch_input_ids, request_ids = processor.get_batch()
print(f"批次大小: {len(request_ids)}")
print(f"批次形状: {batch_input_ids.shape}")

12.3.3 动态批处理

原理:动态批处理根据请求到达时间动态组批,平衡延迟和吞吐量。

实现:

Python
import time
import asyncio  # Python标准异步库

class DynamicBatchProcessor:
    """动态批处理器"""

    def __init__(self, max_batch_size: int, max_wait_time: float = 0.01):
        """
        初始化动态批处理器

        Args:
            max_batch_size: 最大批次大小
            max_wait_time: 最大等待时间(秒)
        """
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time

        # 请求队列
        self.request_queue = []
        self.last_batch_time = time.time()

    async def add_request(self, request_id: str, input_ids: torch.Tensor):  # async def定义协程函数
        """
        添加请求

        Args:
            request_id: 请求ID
            input_ids: 输入token IDs
        """
        self.request_queue.append({
            'id': request_id,
            'input_ids': input_ids,
            'seq_len': input_ids.size(0),
            'arrival_time': time.time()
        })

    async def get_batch(self) -> tuple[torch.Tensor, list[str]]:
        """
        获取批次

        Returns:
            (batch_input_ids, request_ids)
        """
        if not self.request_queue:
            return None, []

        current_time = time.time()
        time_since_last_batch = current_time - self.last_batch_time

        # 检查是否应该组批
        should_batch = (
            len(self.request_queue) >= self.max_batch_size or
            time_since_last_batch >= self.max_wait_time
        )

        if not should_batch:
            return None, []

        # 组批
        batch_requests = self.request_queue[:self.max_batch_size]
        self.request_queue = self.request_queue[len(batch_requests):]
        self.last_batch_time = current_time

        # 准备批次数据
        max_len = max(req['seq_len'] for req in batch_requests)
        batch_input_ids = torch.zeros(
            len(batch_requests),
            max_len,
            dtype=torch.long
        )

        request_ids = []
        for i, req in enumerate(batch_requests):
            batch_input_ids[i, :req['seq_len']] = req['input_ids']
            request_ids.append(req['id'])

        return batch_input_ids, request_ids

# 使用示例
async def main():
    processor = DynamicBatchProcessor(
        max_batch_size=4,
        max_wait_time=0.01
    )

    # 添加请求
    await processor.add_request("req1", torch.randint(0, 1000, (50,)))  # await等待异步操作完成
    await asyncio.sleep(0.005)
    await processor.add_request("req2", torch.randint(0, 1000, (100,)))

    # 获取批次
    batch_input_ids, request_ids = await processor.get_batch()
    print(f"批次大小: {len(request_ids)}")

# 运行
asyncio.run(main())  # 创建事件循环运行顶层协程

12.4 编译优化

12.4.1 TorchScript

原理:TorchScript将PyTorch模型转换为静态图,优化推理性能。

实现:

Python
import torch
import torch.nn as nn

# 定义模型
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()  # super()调用父类方法
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

# 创建模型
model = SimpleModel()

# 转换为TorchScript
scripted_model = torch.jit.script(model)

# 保存模型
scripted_model.save("model_scripted.pt")

# 加载并推理
loaded_model = torch.jit.load("model_scripted.pt")
input_data = torch.randn(1, 10)
output = loaded_model(input_data)

12.4.2 TensorRT

原理:TensorRT是NVIDIA的高性能推理优化器,支持多种优化技术。

实现:

Python
import torch
from torch2trt import torch2trt
import tensorrt as trt

# 创建模型
model = SimpleModel()
model.eval()

# 示例输入
example_input = torch.randn(1, 10).cuda()

# 转换为TensorRT
trt_model = torch2trt(
    model,
    [example_input],
    fp16_mode=True,
    max_workspace_size=1 << 25
)

# 保存模型
torch.save(trt_model.state_dict(), "model_trt.pth")

# 推理
output = trt_model(example_input)

12.4.3 ONNX Runtime

原理:ONNX Runtime是一个跨平台的高性能推理引擎。

实现:

Python
import torch
import onnx
import onnxruntime as ort

# 导出模型为ONNX
model = SimpleModel()
model.eval()

dummy_input = torch.randn(1, 10)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}}
)

# 使用ONNX Runtime推理
ort_session = ort.InferenceSession("model.onnx")

input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# 推理
input_data = torch.randn(1, 10).numpy()
outputs = ort_session.run([output_name], {input_name: input_data})

12.5 推理框架:vLLM 与 SGLang

推理框架是连接模型与线上服务的关键基础设施。vLLM 和 SGLang 是当前最主流的两个开源 LLM 推理框架。

12.5.1 vLLM 简介

vLLM 由 UC Berkeley Sky Lab 团队开发,以 PagedAttention 为核心,是目前社区使用最广泛的推理引擎。

核心特性: - PagedAttention:将 KV Cache 按页管理(类似操作系统虚拟内存),消除显存碎片 - Continuous Batching:动态插入/移除请求,最大化 GPU 利用率 - Tensor Parallelism / Pipeline Parallelism:多卡推理支持 - OpenAI 兼容 API/v1/chat/completions 直接对接现有代码

Python
# vLLM 快速启动服务
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", tensor_parallel_size=1)
params = SamplingParams(temperature=0.7, max_tokens=512)

prompts = ["解释什么是Transformer", "写一首关于AI的诗"]
outputs = llm.generate(prompts, params)

for output in outputs:
    print(output.outputs[0].text)

12.5.2 SGLang 详解

SGLang(Structured Generation Language)由 UC Berkeley LMSYS 团队开发,专注于结构化生成高效前缀共享,在多轮对话和 Agent 场景中性能突出。

核心特性

特性 说明
RadixAttention 基于 Radix Tree 自动管理前缀 KV Cache,相同前缀的请求共享缓存,无需手动管理
Compressed FSM 使用压缩有限状态机约束结构化输出(JSON Schema),比逐 token 采样快 数倍
Torch.compile 集成 自动编译计算图,Llama 70B 推理速度提升 ~20%
Chunked Prefill 将长前缀分块预填充,避免长 prompt 阻塞短请求
FP8 量化推理 原生支持 FP8 W8A8 量化,兼顾精度与速度

RadixAttention 原理

Text Only
传统方式:每个请求独立计算全部 KV Cache
  请求A: [System Prompt] + [User A] → 独立 KV Cache
  请求B: [System Prompt] + [User B] → 独立 KV Cache(重复计算 System Prompt)

RadixAttention:自动识别共享前缀,复用 KV Cache
  Radix Tree:
    [System Prompt] ── KV Cache(共享)
        ├── [User A] → 增量 KV Cache
        └── [User B] → 增量 KV Cache
  → 节省 ~30-60% 显存,吞吐提升 2-5x(前缀越长效果越明显)

SGLang Engine 代码示例

Python
import sglang as sgl

# 方式1:离线批量推理
llm = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct")

prompts = ["什么是大模型推理优化?", "解释KV Cache的作用"]
outputs = llm.generate(prompts, sampling_params={"temperature": 0.7, "max_new_tokens": 256})
for out in outputs:
    print(out["text"])

llm.shutdown()
Python
# 方式2:启动 OpenAI 兼容服务器
# 终端执行:
# python -m sglang.launch_server --model Qwen/Qwen2.5-7B-Instruct --port 8000

# 客户端调用(与 vLLM/OpenAI 完全兼容)
import openai
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="none")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    messages=[{"role": "user", "content": "用JSON格式输出三个中国城市的信息"}],
    response_format={"type": "json_object"},  # 结构化输出
)
print(response.choices[0].message.content)

JSON Schema 约束输出

Python
# SGLang 原生支持 JSON Schema 约束(Compressed FSM 加速)
from sglang import Engine

llm = Engine(model_path="Qwen/Qwen2.5-7B-Instruct")

json_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer", "minimum": 0},
        "skills": {"type": "array", "items": {"type": "string"}}
    },
    "required": ["name", "age", "skills"]
}

output = llm.generate(
    "生成一个AI工程师的简介",
    sampling_params={"max_new_tokens": 200, "json_schema": json_schema}
)
print(output[0]["text"])  # 保证输出符合 JSON Schema

12.5.3 vLLM vs SGLang 对比

对比维度 vLLM SGLang
开发团队 UC Berkeley Sky Lab UC Berkeley LMSYS
核心创新 PagedAttention RadixAttention + Compressed FSM
KV Cache 管理 分页内存管理 Radix Tree 前缀共享
结构化输出 基础支持(Outlines 集成) 原生 Compressed FSM,速度快数倍
多轮对话 普通性能 前缀共享显著加速
编译优化 有限 Torch.compile 深度集成
社区生态 更成熟,文档丰富 快速增长中
模型支持 更广泛 主流模型均支持
吞吐量(共享前缀场景) 基准 1.5-5x 提升
首 Token 延迟 基准 通常更低(Chunked Prefill)
API 兼容性 OpenAI 兼容 OpenAI 兼容

选择建议

Text Only
选择 vLLM 的场景:
  ✅ 需要最广泛的模型兼容性
  ✅ 团队已有 vLLM 经验和部署
  ✅ 简单的单轮问答服务

选择 SGLang 的场景:
  ✅ Agent 多轮调用(大量共享 System Prompt)
  ✅ 需要 JSON/结构化输出约束
  ✅ 共享前缀的批量请求(如同一文档的多个问题)
  ✅ 追求极致吞吐和低延迟

12.6 Speculative Decoding(投机解码)

Speculative Decoding 是一种利用小模型加速大模型推理的技术,在不损失输出质量的前提下显著降低延迟。

12.6.1 核心原理

思路:用一个小而快的 Draft Model 先「猜」多个 token,再用大模型一次性验证,接受正确的 token、拒绝错误的 token。

Text Only
传统自回归解码(每步 1 token):
  大模型 → t1 → 大模型 → t2 → 大模型 → t3 → ... (N 步)

Speculative Decoding(每步可能接受 γ 个 token):
  小模型 → [t1, t2, t3, t4, t5](猜 γ=5 个)
  大模型 → 一次前向传播验证 → 接受 [t1, t2, t3],拒绝 t4
  → 平均每步 >1 token,加速 2-3x

12.6.2 数学保证

Speculative Decoding 的关键优势是无损——输出分布与直接使用大模型完全一致

验证策略基于拒绝采样: - 对 Draft Model 生成的 token \(t_i\),计算接受概率:

\[P_{\text{accept}} = \min\left(1, \frac{p_{\text{target}}(t_i)}{p_{\text{draft}}(t_i)}\right)\]
  • 若拒绝,从修正分布中重新采样:
\[p_{\text{resample}}(t) = \text{norm}\left(\max\left(0, p_{\text{target}}(t) - p_{\text{draft}}(t)\right)\right)\]

12.6.3 实现示例

Python
import torch

def speculative_decode(target_model, draft_model, input_ids, gamma=5, temperature=1.0):
    """
    Speculative Decoding 简化实现
    target_model: 大模型(如 70B)
    draft_model: 小模型(如 7B,同系列)
    gamma: 每轮猜测的 token 数
    """
    generated = input_ids.clone()

    while len(generated[0]) < input_ids.shape[1] + 100:  # 简化:生成100个token
        # Step 1: Draft Model 自回归生成 γ 个候选 token
        draft_tokens = []
        draft_probs = []
        draft_input = generated.clone()

        for _ in range(gamma):
            with torch.no_grad():  # 禁用梯度计算,节省内存(推理时使用)
                logits = draft_model(draft_input).logits[:, -1, :] / temperature
                probs = torch.softmax(logits, dim=-1)
                token = torch.multinomial(probs, 1)
                draft_tokens.append(token)
                draft_probs.append(probs)
                draft_input = torch.cat([draft_input, token], dim=-1)

        # Step 2: Target Model 一次前向传播验证所有候选
        with torch.no_grad():
            all_input = torch.cat([generated] + draft_tokens, dim=-1)
            target_logits = target_model(all_input).logits / temperature

        # Step 3: 逐个验证,基于拒绝采样
        accepted = 0
        for i in range(gamma):
            target_prob = torch.softmax(target_logits[:, -(gamma - i + 1), :], dim=-1)
            draft_prob = draft_probs[i]
            token = draft_tokens[i]

            # 接受概率
            accept_prob = torch.min(
                torch.ones_like(target_prob),
                target_prob / (draft_prob + 1e-10)
            )

            if torch.rand(1).item() < accept_prob[0, token[0, 0]].item():
                generated = torch.cat([generated, token], dim=-1)
                accepted += 1
            else:
                # 从修正分布采样
                residual = torch.clamp(target_prob - draft_prob, min=0)
                residual = residual / residual.sum(dim=-1, keepdim=True)
                new_token = torch.multinomial(residual, 1)
                generated = torch.cat([generated, new_token], dim=-1)
                break

        if accepted == gamma:
            # 所有猜测都被接受,额外从 target 采样一个
            bonus_prob = torch.softmax(target_logits[:, -1, :], dim=-1)
            bonus_token = torch.multinomial(bonus_prob, 1)
            generated = torch.cat([generated, bonus_token], dim=-1)

    return generated

12.6.4 加速效果与适用场景

场景 加速比 说明
Draft = 同系列小模型 2-3x 如 Llama-70B + Llama-7B
Draft = 自身浅层层 1.5-2x Self-Speculative Decoding
高温度采样(创意写作) 1.3-1.5x 接受率低,加速有限
低温度/贪心解码(代码/翻译) 2.5-3.5x 接受率高,加速显著

适用条件: - ✅ Draft Model 与 Target Model 分布相近(同系列效果最好) - ✅ 解码为主的任务(非长 prefill) - ✅ 模型越大加速越明显(大模型前向传播开销大) - ❌ 不适合 batch 很大的高吞吐场景(额外的 draft 计算抵消收益)

12.7 练习题

练习题1:KV Cache

题目:实现一个简单的KV Cache机制。

参考答案:

Python
class SimpleKVCache:
    def __init__(self, num_heads, head_dim, max_seq_len):
        self.cache_k = torch.zeros(num_heads, max_seq_len, head_dim)
        self.cache_v = torch.zeros(num_heads, max_seq_len, head_dim)
        self.current_pos = 0

    def update(self, k, v):
        self.cache_k[:, self.current_pos, :] = k
        self.cache_v[:, self.current_pos, :] = v
        self.current_pos += 1

    def get(self):
        return self.cache_k[:, :self.current_pos, :], self.cache_v[:, :self.current_pos, :]

练习题2:批处理

题目:实现一个简单的批处理器。

参考答案:

Python
class SimpleBatchProcessor:
    def __init__(self, max_batch_size):
        self.max_batch_size = max_batch_size
        self.request_queue = []

    def add_request(self, request_id, input_ids):
        self.request_queue.append((request_id, input_ids))

    def get_batch(self):
        if not self.request_queue:
            return None, []

        batch_requests = self.request_queue[:self.max_batch_size]
        self.request_queue = self.request_queue[len(batch_requests):]

        request_ids = [req[0] for req in batch_requests]
        batch_data = torch.stack([req[1] for req in batch_requests])

        return batch_data, request_ids

12.8 面试准备

12.8.1 大厂面试题

字节跳动面试题:

  1. 问题:什么是KV Cache?它有什么优势?

参考答案: - KV Cache缓存自注意力的Key和Value - 优势: - 避免重复计算 - 减少计算量约50% - 提升推理速度2-3倍 - 增加显存占用约30%

  1. 问题:连续批处理相比传统批处理有什么优势?

参考答案: - 避免Padding浪费 - 提高GPU利用率 - 降低计算成本 - 提高吞吐量

腾讯面试题:

  1. 问题:如何优化大模型的推理性能?

参考答案: - KV Cache:缓存注意力计算 - 批处理:提高GPU利用率 - 模型量化:减少计算量 - 编译优化:优化执行图 - 硬件优化:使用专用硬件

  1. 问题:PagedAttention有什么优势?

参考答案: - 动态分配显存 - 减少显存浪费 - 支持更长的序列 - 提高显存利用率

阿里巴巴面试题:

  1. 问题:推理优化有哪些层次?

参考答案: - 算法层:KV Cache、Flash Attention - 模型层:量化、剪枝、蒸馏 - 框架层:算子融合、内存优化 - 硬件层:GPU优化、专用芯片

  1. 问题:在实际项目中如何优化推理性能?

参考答案: - 性能分析:识别瓶颈 - 选择优化:选择合适的优化技术 - 实现优化:实现优化方案 - 测试验证:测试优化效果 - 持续监控:监控性能指标 - 持续优化:根据反馈持续优化

12.8.2 面试技巧

技巧1:理论联系实际

结合实际项目经验,说明如何应用推理优化。

技巧2:性能分析

展示性能分析的方法和工具。

技巧3:优化选择

说明如何选择合适的优化技术。

技巧4:效果评估

说明如何评估优化效果。

📝 本章小结

本章系统介绍了推理优化的核心内容:

  1. ✅ 推理优化概述:为什么需要、优化层次、优化指标
  2. ✅ KV Cache:原理、实现、优化
  3. ✅ 批处理优化:批处理原理、连续批处理、动态批处理
  4. ✅ 编译优化:TorchScript、TensorRT、ONNX Runtime
  5. ✅ 推理框架:vLLM 与 SGLang 对比及选型
  6. ✅ Speculative Decoding:投机解码原理与实现
  7. ✅ 练习题:KV Cache、批处理
  8. ✅ 面试准备:大厂面试题和解答技巧

通过本章学习,你应该能够: - 理解推理优化的核心原理 - 掌握KV Cache等关键技术 - 学会使用批处理和流水线优化 - 了解编译优化技术 - 能够根据场景选择 vLLM 或 SGLang - 理解 Speculative Decoding 的数学保证与工程实现 - 准备好应对大厂面试

🔗 下一步

下一章我们将深入学习多模态应用,掌握如何处理和生成多模态数据。

继续学习: 13-多模态应用.md

💡 思考题

  1. 什么是KV Cache?它有什么优势?

    缓存已计算的Key/Value向量,避免每次生成新Token时重复计算前文的Attention。优势:生成速度提升数十倍(从O(n²)降为O(n))。挑战:KV Cache显存随序列长度线性增长(128K context可占数十GB),PagedAttention/GQA可缓解。

  2. 连续批处理相比传统批处理有什么优势?

    传统(Static Batching):等所有请求生成完才返回,最慢请求拖慢全扑0。连续(Continuous Batching):请求完成即时释放资源并插入新请求,GPU利用率提升10-20x。vLLM/TGI均默认启用。核心思想:让GPU永远不闲置。

  3. 如何优化大模型的推理性能?

    分层优化:①模型层(量化、蒸馏、剪枝) ②注意力层(FlashAttention、GQA、MQA、Sliding Window) ③解码层(KV Cache、Speculative Decoding、Early Exit) ④服务层(连续批处理、PagedAttention、请求调度) ⑤硬件层(TensorRT编译、Tensor并行)。开箱即用:vLLM已集成多数优化。

  4. PagedAttention有什么优势?

    借Linux虚拟内存分页思想管理KV Cache:将连续的KV Cache分为固定大小的Block(页),按需分配。优势:①显存浪费从60-80%降低刼<4% ②尾部填充减少 ③支持共享Prefix(多请求复用系统提示词的KV Cache) ④同等显存下并发请求数提升5x+。这是vLLM的核心创新。

  5. 在实际项目中如何优化推理性能?

    路线:①先用vLLM部署(开箱即用大部分优化) ②测量throughput和P99 latency基线 ③尝试量化(AWQ/GPTQ)观察性能/质量权衡 ④开启Speculative Decoding(若时延敏感) ⑤调整batch_size/max_num_seqs找最优点 ⑥TensorRT编译(NVIDIA平台)。监控:token/s、首Token延迟(TTFT)、GPU利用率。

  6. SGLang的RadixAttention相比vLLM的PagedAttention有什么优势?适用于什么场景?

    RadixAttention用Radix Tree(基数树)管理KV Cache,实现前缀级别的自动复用。优势:多轮对话/多路并行场景下KV Cache命中率更高,LRU自动淘汰无需手动管理。适用场景:大量Shared Prefix(同一系统提示词)、Tree of Thoughts推理、多轮Agent对话。PagedAttention更通用,RadixAttention在特定模式下更高效。

  7. Speculative Decoding为什么能保证输出分布无损?它的加速比受哪些因素影响?

    采用拒绝采样(rejection sampling)策略:草稿模型生成候选token,目标模型并行验证,按min(1, p_target/p_draft)概率接受。数学证明接受的token服从目标模型分布。加速比受:①草稿模型接受率(与目标模型越接近越好) ②草稿模型速度(越小越快) ③推测token数(通常4-8) ④任务确定性(确定性任务加速比更高)。典型加速2-3x。

📚 参考资料

  1. "Efficient Attention: Attention with Linear Complexities" - Katharopoulos et al.
  2. "FlashAttention: Fast and Memory-Efficient Exact Attention" - Dao et al.
  3. "PagedAttention: Efficient Attention for Long Sequences" - Kwon et al.
  4. vLLM Documentation
  5. SGLang Documentation - https://sgl-project.github.io/
  6. "Efficient Memory Management for Large Language Model Serving with PagedAttention" - Kwon et al.
  7. "Fast Inference of Mixture-of-Experts Language Models with Offloading" - Eliseev & Mazur
  8. "Accelerating Large Language Model Decoding with Speculative Sampling" - Leviathan et al.
  9. TensorRT Documentation

最后更新日期:2026-02-12 适用版本:LLM应用指南 v2026