第13章 Speculative Decoding与推理加速¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
前置章节:第12章 FlashAttention原理与实现 中我们学习了如何通过优化注意力计算本身来降低内存占用与提升吞吐。本章聚焦另一个推理侧的核心瓶颈——自回归解码的串行性,介绍 Speculative Decoding(投机解码)及其衍生方法,这是 2025-2026 年大模型推理优化面试中的绝对热点。
13.1 自回归解码的瓶颈¶
13.1.1 自回归生成的本质¶
Transformer 语言模型的文本生成过程是严格串行的:
每一步只能生成 1 个 token,后一个 token 的生成依赖前一个 token 的输出。对于一个 \(T\) 个 token 的回复,模型需要执行 \(T\) 次前向传播。
Step 1: [prompt] → token_1
Step 2: [prompt, t1] → token_2
Step 3: [prompt, t1, t2] → token_3
...(必须串行,无法跳步)
13.1.2 Memory-Bound 问题¶
这是理解 Speculative Decoding 为何有效的关键。
Prefill 阶段 vs Decode 阶段的计算特性完全不同:
| 特性 | Prefill(预填充) | Decode(逐token解码) |
|---|---|---|
| 输入规模 | 整段 prompt(数百~数千 token 并行) | 单个 token |
| 计算密度 | Compute-bound(矩阵-矩阵乘) | Memory-bound(矩阵-向量乘) |
| GPU 利用率 | 高(40-80%) | 极低(<5%) |
| 瓶颈 | 计算算力(FLOPs) | 显存带宽(GB/s) |
逐 token 解码时,每一步需要:
- 读取全部模型参数——对 70B 模型约 140GB(FP16)
- 读写 KV-Cache——随序列长度线性增长
- 仅做一次矩阵-向量乘——计算量极少
# 直观理解 Memory-bound
# 70B模型 Decode阶段每步:
model_params = 70e9 * 2 # 140GB (FP16)
kv_cache_per_step = 0.5 # ~0.5GB (假设,实际取决于序列长度和hidden_size)
total_read = model_params + kv_cache_per_step # ~140.5GB
# A100 80GB GPU 带宽 = 2TB/s
bandwidth = 2e12 # bytes/s
min_time_per_token = total_read / bandwidth # ~70ms
# 注:这是简化假设,实际读取时间受KV Cache命中、模型分片等因素影响
# 即便计算瞬间完成,仅读取权重就需要 ~70ms → ~14 tokens/s
# 但实际计算量极小:
flops_per_token = 2 * 70e9 # ~140 GFLOPs
a100_flops = 312e12 # 312 TFLOPS (FP16)
compute_time = flops_per_token / a100_flops # ~0.45ms
# 计算时间 0.45ms vs 读取时间 70ms → 计算资源严重浪费!
13.1.3 批处理解码 vs 单 Token 解码¶
如果我们能同时验证多个 token(而非一个个生成),就能更好利用 GPU:
import torch
import time
def benchmark_matmul():
"""对比单向量 vs 批量矩阵乘的 GPU利用率"""
device = torch.device('cuda')
weight = torch.randn(4096, 4096, device=device, dtype=torch.float16)
# 单 token 解码:矩阵-向量乘 (4096, 4096) x (4096, 1)
x_single = torch.randn(1, 4096, device=device, dtype=torch.float16)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(1000):
_ = x_single @ weight.T
torch.cuda.synchronize()
t_single = (time.perf_counter() - t0) / 1000
# 批量验证:矩阵-矩阵乘 (4096, 4096) x (4096, 8)
x_batch = torch.randn(8, 4096, device=device, dtype=torch.float16)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(1000):
_ = x_batch @ weight.T
torch.cuda.synchronize()
t_batch = (time.perf_counter() - t0) / 1000
print(f"单token: {t_single*1e6:.1f} μs (1 token)")
print(f"批量8: {t_batch*1e6:.1f} μs (8 tokens)")
print(f"批量8每token: {t_batch/8*1e6:.1f} μs")
print(f"吞吐提升: {t_single / (t_batch/8):.1f}x")
# 典型输出(A100):
# 单token: 42.3 μs (1 token)
# 批量8: 48.7 μs (8 tokens)
# 批量8每token: 6.1 μs
# 吞吐提升: 6.9x
关键洞察:处理 8 个 token 的时间几乎等于处理 1 个 token 的时间——这正是 Speculative Decoding 加速的根本原因。
13.1.4 面试考点¶
Q:为什么 LLM 推理阶段 GPU 利用率很低?
自回归 Decode 阶段每步仅处理 1 个 token,是矩阵-向量乘(Memory-bound),GPU 大量时间花在读取模型权重和 KV-Cache 上,计算单元严重闲置。70B 模型在 A100 上每步计算时间不到 1ms,但读取权重就要约 70ms。
13.2 Speculative Decoding 核心原理¶
13.2.1 核心思想¶
Speculative Decoding(投机解码)的核心思想:
用一个小而快的 Draft 模型猜测未来 K 个 token,然后用大的 Target 模型一次性并行验证这 K 个 token。
- Draft 模型(\(M_q\)):小模型,速度快,准确度有限
- Target 模型(\(M_p\)):大模型,速度慢,结果准确
传统自回归(K=5 个 token):
Target: [fwd] → t1 → [fwd] → t2 → [fwd] → t3 → [fwd] → t4 → [fwd] → t5
总计: 5 次 Target 前向传播
Speculative Decoding(K=5 个猜测):
Draft: [fwd×5 快速] → guess: t1, t2, t3, t4, t5
Target: [1次并行fwd] → verify: ✓t1, ✓t2, ✓t3, ✗t4 → 修正 t4'
总计: 5 次 Draft(快) + 1 次 Target 前向传播
若 Draft 速度 = Target/10,且 5 个中接受 3 个 → 加速约 2-3x
13.2.2 Draft-then-Verify 范式¶
完整算法流程:
输入: Target模型 M_p, Draft模型 M_q, 投机长度 K, 已有前缀 prefix
1. DRAFT 阶段(串行,但很快):
for i = 1 to K:
q_i = M_q(prefix + guesses[:i-1]) # Draft 模型预测概率分布
guess_i ~ q_i # 从 Draft 分布采样
2. VERIFY 阶段(并行,一次前向传播):
p_1, p_2, ..., p_K, p_{K+1} = M_p(prefix + [guess_1, ..., guess_K])
# Target 模型一次处理 K 个 token,得到 K+1 个位置的概率分布
3. ACCEPT/REJECT(逐个验证):
for i = 1 to K:
r ~ Uniform(0, 1)
if r < min(1, p_i(guess_i) / q_i(guess_i)):
ACCEPT guess_i, 继续
else:
REJECT guess_i
从修正分布 norm(max(0, p_i - q_i)) 采样替代 token
丢弃 guess_{i+1}, ..., guess_K
BREAK
if 全部接受:
从 p_{K+1} 额外采样 1 个 bonus token
4. 将已接受的 tokens 追加到 prefix,回到步骤 1
13.2.3 数学保证:输出分布不变¶
定理:Speculative Decoding 的输出分布与直接使用 Target 模型的分布完全一致。
设 Draft 分布为 \(q(x)\),Target 分布为 \(p(x)\),对于候选 token \(x\):
- 接受概率:\(\alpha = \min\left(1, \frac{p(x)}{q(x)}\right)\)
- 拒绝后修正采样:从 \(p'(x) = \frac{\max(0, p(x) - q(x))}{\sum_{x'} \max(0, p(x') - q(x'))}\) 中采样
证明最终分布为 \(p(x)\):
对任意 token \(x\),被选中的概率为:
分两种情况讨论:
情况 1:当 \(p(x) \geq q(x)\) 时: $\(q(x) \cdot 1 + \beta \cdot \frac{p(x) - q(x)}{\sum_{x'} \max(0, p(x') - q(x'))} = q(x) + (p(x) - q(x)) = p(x)\)$
(其中 \(\beta = \sum_{x'} \max(0, p(x') - q(x'))\))
情况 2:当 \(p(x) < q(x)\) 时: $\(q(x) \cdot \frac{p(x)}{q(x)} + \beta \cdot 0 = p(x)\)$
两种情况下最终分布都等于 \(p(x)\)。✓
13.2.4 验证算法详解¶
import torch
import torch.nn.functional as F
def speculative_sampling(
target_logits: torch.Tensor, # (K+1, vocab_size) - Target模型输出
draft_probs: torch.Tensor, # (K, vocab_size) - Draft模型采样时的概率
draft_tokens: torch.Tensor, # (K,) - Draft模型采样的token
temperature: float = 1.0
) -> tuple[torch.Tensor, int]:
"""
执行 Speculative Decoding 的验证步骤。
Returns:
accepted_tokens: 最终接受的 token 序列
n_accepted: 接受的 Draft token 数量
"""
K = draft_tokens.shape[0]
vocab_size = target_logits.shape[-1] # [-1]负索引取最后元素
# Target 模型的概率分布
target_probs = F.softmax(target_logits[:K] / temperature, dim=-1) # (K, V) # F.xxx PyTorch函数式API
accepted_tokens = []
n_accepted = 0
for i in range(K):
token = draft_tokens[i]
p_i = target_probs[i, token].item() # Target 概率
q_i = draft_probs[i, token].item() # Draft 概率
# 接受概率
accept_prob = min(1.0, p_i / (q_i + 1e-10))
r = torch.rand(1).item()
if r < accept_prob:
# 接受 Draft 的猜测
accepted_tokens.append(token.item())
n_accepted += 1
else:
# 拒绝:从修正分布采样
# p'(x) = norm(max(0, p(x) - q(x)))
corrected = torch.clamp(target_probs[i] - draft_probs[i], min=0)
corrected = corrected / (corrected.sum() + 1e-10)
replacement = torch.multinomial(corrected, 1).item()
accepted_tokens.append(replacement)
break
# 如果全部接受,从 p_{K+1} 采样 bonus token
if n_accepted == K:
bonus_probs = F.softmax(target_logits[K] / temperature, dim=-1)
bonus_token = torch.multinomial(bonus_probs, 1).item()
accepted_tokens.append(bonus_token)
return torch.tensor(accepted_tokens), n_accepted
13.2.5 加速原理分析¶
设: - \(c\):Draft 模型与 Target 模型的速度比(\(c = t_{\text{draft}} / t_{\text{target}}\),通常 \(c \ll 1\)) - \(K\):投机长度(Draft 一次猜测的 token 数) - \(\alpha\):平均接受率
每轮墙钟时间:\(T_{\text{round}} = K \cdot c \cdot t_{\text{target}} + t_{\text{target}} = (Kc + 1) \cdot t_{\text{target}}\)
每轮平均产出 token 数: $\(E[\text{tokens}] = \sum_{i=0}^{K-1} \alpha^i (1-\alpha) \cdot (i+1) + \alpha^K \cdot (K+1) = \frac{1 - \alpha^{K+1}}{1 - \alpha}\)$
加速比: $\(\text{Speedup} = \frac{E[\text{tokens}]}{Kc + 1} = \frac{1 - \alpha^{K+1}}{(1 - \alpha)(Kc + 1)}\)$
import numpy as np
def compute_speedup(alpha: float, K: int, c: float) -> float:
"""计算 Speculative Decoding 的理论加速比"""
expected_tokens = (1 - alpha**(K+1)) / (1 - alpha)
wall_time_ratio = K * c + 1
return expected_tokens / wall_time_ratio
# 不同参数下的加速比
print("=" * 60)
print(f"{'接受率α':<10} {'K':<5} {'速度比c':<10} {'加速比':<10}")
print("=" * 60)
for alpha in [0.6, 0.7, 0.8, 0.9]:
for K in [3, 5, 8]:
for c in [0.05, 0.1]:
s = compute_speedup(alpha, K, c)
print(f"{alpha:<10} {K:<5} {c:<10} {s:<10.2f}")
print("-" * 60)
# 典型输出:
# α=0.8, K=5, c=0.05 → 加速比 ≈ 2.95x
# α=0.9, K=5, c=0.05 → 加速比 ≈ 3.69x
# α=0.7, K=5, c=0.10 → 加速比 ≈ 1.94x
13.2.6 完整 PyTorch 实现¶
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional
class SpeculativeDecoder:
"""
完整的 Speculative Decoding 实现。
使用小模型 (draft) 猜测 + 大模型 (target) 验证。
"""
def __init__(
self,
target_model_name: str,
draft_model_name: str,
device: str = "cuda",
dtype: torch.dtype = torch.float16,
):
print(f"Loading target model: {target_model_name}")
self.target_model = AutoModelForCausalLM.from_pretrained(
target_model_name, torch_dtype=dtype, device_map=device
)
self.target_model.eval() # eval()评估模式
print(f"Loading draft model: {draft_model_name}")
self.draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name, torch_dtype=dtype, device_map=device
)
self.draft_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
self.device = device
@torch.no_grad() # 禁用梯度计算,节省内存
def draft_generate(
self,
input_ids: torch.Tensor,
K: int,
temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Draft 模型串行生成 K 个候选 token。
Returns:
draft_tokens: (K,) 候选 token ids
draft_probs: (K, vocab_size) 候选 token 的概率分布
"""
draft_tokens = []
draft_probs = []
current_ids = input_ids.clone()
for _ in range(K):
outputs = self.draft_model(current_ids)
logits = outputs.logits[:, -1, :] # (1, vocab_size)
probs = F.softmax(logits / temperature, dim=-1)
token = torch.multinomial(probs, 1) # (1, 1)
draft_tokens.append(token.squeeze()) # squeeze压缩维度
draft_probs.append(probs.squeeze())
current_ids = torch.cat([current_ids, token], dim=-1) # torch.cat沿已有维度拼接张量
return torch.stack(draft_tokens), torch.stack(draft_probs) # torch.stack沿新维度拼接张量
@torch.no_grad()
def target_verify(
self,
input_ids: torch.Tensor,
draft_tokens: torch.Tensor,
) -> torch.Tensor:
"""
Target 模型一次前向传播验证 K 个候选 token。
Returns:
logits: (K+1, vocab_size) — K个验证位置 + 1个bonus位置
"""
# 将 draft tokens 拼接到输入
candidate_ids = torch.cat([
input_ids,
draft_tokens.unsqueeze(0) # unsqueeze增加一个维度
], dim=-1) # (1, seq_len + K)
outputs = self.target_model(candidate_ids)
# 取最后 K+1 个位置的 logits
# 位置 -(K+1): 验证第1个draft token
# 位置 -1: bonus token 位置
logits = outputs.logits[0, -(len(draft_tokens)+1):, :]
return logits
@torch.no_grad()
def speculative_sample(
self,
target_logits: torch.Tensor,
draft_probs: torch.Tensor,
draft_tokens: torch.Tensor,
temperature: float = 1.0,
) -> tuple[torch.Tensor, int]:
"""接受-拒绝采样"""
K = draft_tokens.shape[0]
target_probs = F.softmax(target_logits[:K] / temperature, dim=-1)
accepted = []
n_accepted = 0
for i in range(K):
tok = draft_tokens[i]
p = target_probs[i, tok].item()
q = draft_probs[i, tok].item()
if torch.rand(1).item() < min(1.0, p / (q + 1e-10)):
accepted.append(tok.item())
n_accepted += 1
else:
# 修正采样
corrected = torch.clamp(target_probs[i] - draft_probs[i], min=0)
corrected /= corrected.sum() + 1e-10
replacement = torch.multinomial(corrected, 1).item()
accepted.append(replacement)
break
if n_accepted == K:
bonus_probs = F.softmax(target_logits[K] / temperature, dim=-1)
bonus = torch.multinomial(bonus_probs, 1).item()
accepted.append(bonus)
return torch.tensor(accepted, device=self.device), n_accepted
@torch.no_grad()
def generate(
self,
prompt: str,
max_new_tokens: int = 128,
K: int = 5,
temperature: float = 1.0,
) -> str:
"""
使用 Speculative Decoding 生成文本。
Args:
prompt: 输入提示
max_new_tokens: 最大生成 token 数
K: 投机长度(每轮猜测的 token 数)
temperature: 采样温度
"""
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
generated_tokens = 0
total_draft = 0
total_accepted = 0
while generated_tokens < max_new_tokens:
# Step 1: Draft 模型猜测 K 个 token
draft_tokens, draft_probs = self.draft_generate(
input_ids, K, temperature
)
# Step 2: Target 模型一次性验证
target_logits = self.target_verify(input_ids, draft_tokens)
# Step 3: 接受-拒绝采样
accepted_tokens, n_accepted = self.speculative_sample(
target_logits, draft_probs, draft_tokens, temperature
)
# Step 4: 更新输入序列
input_ids = torch.cat([
input_ids,
accepted_tokens.unsqueeze(0)
], dim=-1)
generated_tokens += len(accepted_tokens)
total_draft += K
total_accepted += n_accepted
# 检查是否生成了 EOS
if self.tokenizer.eos_token_id in accepted_tokens.tolist():
break
# 统计信息
accept_rate = total_accepted / total_draft if total_draft > 0 else 0
print(f"\n[Stats] 生成 {generated_tokens} tokens | "
f"接受率 {accept_rate:.1%} | "
f"Draft {total_draft} tokens, 接受 {total_accepted}")
return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
# 使用示例
if __name__ == "__main__":
# 注意:meta-llama 模型需要在 HuggingFace 上申请访问权限
# 替代方案:使用公开模型如 "openlm-research/open_llama_7b"
decoder = SpeculativeDecoder(
target_model_name="meta-llama/Llama-3.1-8B-Instruct", # 需要HuggingFace访问权限
draft_model_name="meta-llama/Llama-3.2-1B-Instruct", # 需要HuggingFace访问权限
)
result = decoder.generate(
prompt="Explain quantum computing in simple terms:",
max_new_tokens=200,
K=5,
temperature=0.7,
)
print(result)
13.3 Draft 模型选择策略¶
Draft 模型的质量直接决定接受率和最终加速效果。理想 Draft 模型需要满足:
- 快:推理速度远快于 Target(通常 5-20x)
- 准:与 Target 分布尽可能接近(接受率高)
- 对齐:共享相同的 tokenizer 和词表
13.3.1 同架构小模型¶
最朴素的策略:使用同系列的小规模模型作为 Draft。
| Target 模型 | Draft 模型 | 参数比 | 典型接受率 |
|---|---|---|---|
| Llama-3.1-70B | Llama-3.2-1B | 70:1 | 60-75% |
| Llama-3.1-70B | Llama-3.2-3B | 23:1 | 70-82% |
| Qwen2.5-72B | Qwen2.5-0.5B | 144:1 | 55-70% |
| Qwen2.5-72B | Qwen2.5-1.5B | 48:1 | 65-78% |
| DeepSeek-V3 (671B) | DeepSeek-V3-Lite | ~50:1 | 65-80% |
优点:实现简单,无需额外训练,词表天然一致。
缺点:小模型与大模型能力差距大时接受率低,尤其在复杂推理任务上。
13.3.2 自蒸馏 Draft 模型¶
从 Target 模型蒸馏出专门用于 Draft 的小模型:
import torch
import torch.nn as nn
import torch.nn.functional as F
class DraftDistillationTrainer:
"""
从 Target 模型知识蒸馏训练 Draft 模型。
使用 KL 散度让 Draft 尽可能模仿 Target 的输出分布。
"""
def __init__(self, target_model, draft_model, temperature=2.0):
self.target = target_model
self.draft = draft_model
self.temperature = temperature
def distillation_loss(self, input_ids, attention_mask=None):
"""KL(target || draft) — 让 Draft 模拟 Target 的分布"""
with torch.no_grad():
target_logits = self.target(
input_ids, attention_mask=attention_mask
).logits
target_probs = F.softmax(
target_logits / self.temperature, dim=-1
)
draft_logits = self.draft(
input_ids, attention_mask=attention_mask
).logits
draft_log_probs = F.log_softmax(
draft_logits / self.temperature, dim=-1
)
# KL 散度
kl_loss = F.kl_div(
draft_log_probs, target_probs, reduction='batchmean'
) * (self.temperature ** 2)
return kl_loss
优点:接受率比通用小模型高 5-15%。
缺点:需要额外的训练流程和计算资源。
13.3.3 N-gram 模型作为 Draft¶
一种极端轻量级策略——使用统计 N-gram 模型或基于上下文的词表匹配:
from collections import defaultdict
class NGramDraft:
"""
基于 N-gram 统计的超轻量 Draft。
几乎零延迟,但接受率较低。
适合重复性高的场景(代码生成、模板文本)。
"""
def __init__(self, n: int = 3):
self.n = n
# 两层嵌套默认字典:外层key不存在时自动创建新的defaultdict(int),内层key不存在时自动初始化为0,用于N-gram→next_token→计数
self.counts = defaultdict(lambda: defaultdict(int)) # defaultdict访问不存在的键时返回默认值
def fit(self, token_sequences: list[list[int]]):
"""从训练语料构建 N-gram 统计"""
for seq in token_sequences:
for i in range(len(seq) - self.n):
context = tuple(seq[i:i+self.n])
next_token = seq[i+self.n]
self.counts[context][next_token] += 1
def draft(self, context_tokens: list[int], K: int) -> list[int]:
"""基于 N-gram 猜测接下来 K 个 token"""
guesses = []
ctx = list(context_tokens)
for _ in range(K):
key = tuple(ctx[-self.n:])
candidates = self.counts.get(key, {})
if candidates:
# 选最高频的 token
best = max(candidates, key=candidates.get)
guesses.append(best)
ctx.append(best)
else:
break
return guesses
13.3.4 Prompt Lookup Decoding¶
核心思想:直接从 prompt 中寻找可能的续写内容,无需任何 Draft 模型。
def prompt_lookup_draft(
input_ids: list[int],
prompt_length: int,
K: int = 5,
n_match: int = 3,
) -> list[int]:
"""
Prompt Lookup Decoding: 在 prompt 中寻找匹配当前后缀的片段,
将其后续 token 作为 draft 候选。
特别适合:摘要、翻译、重写等与输入高度相关的任务。
"""
if len(input_ids) < n_match:
return []
# 当前生成序列的最后 n_match 个 token
suffix = input_ids[-n_match:]
# 在 prompt 中搜索相同的 n-gram
candidates = []
for i in range(prompt_length - n_match):
if input_ids[i:i+n_match] == suffix:
# 找到匹配!取后续 K 个 token 作为猜测
end = min(i + n_match + K, prompt_length)
draft = input_ids[i+n_match:end]
if draft:
candidates.append(draft)
# 返回最长的匹配候选
if candidates:
return max(candidates, key=len)[:K]
return []
优点:零额外模型开销,无需 GPU 显存。
缺点:仅在 prompt 与输出有大量重叠时有效。
13.3.5 各策略对比¶
| 策略 | 延迟开销 | 典型接受率 | 额外显存 | 适用场景 |
|---|---|---|---|---|
| 同架构小模型 | 低 | 60-80% | 需加载小模型 | 通用场景 |
| 自蒸馏 Draft | 低 | 70-85% | 需加载小模型 | 追求极致加速 |
| N-gram | 极低 | 30-50% | 几乎为零 | 高重复场景(代码) |
| Prompt Lookup | 零 | 0-70%(波动大) | 零 | 摘要/翻译/重写 |
| Medusa 多头 | 微增 | 60-80% | 微增(几层MLP) | 不想额外加载模型 |
13.4 Medusa:多头投机解码¶
13.4.1 Medusa 核心思想¶
Medusa(ICML 2024)提出了一种无需独立 Draft 模型的投机解码方案:
在 Target 模型的最后一层隐藏状态上,添加多个独立的预测头(Medusa Heads),每个头预测未来第 \(i\) 个位置的 token。
传统 LM Head: hidden_state → head_0 → token_{t+1}
Medusa: hidden_state → head_0 → token_{t+1} (原始预测)
→ head_1 → token_{t+2} (提前1步)
→ head_2 → token_{t+3} (提前2步)
→ head_3 → token_{t+4} (提前3步)
→ head_4 → token_{t+5} (提前4步)
13.4.2 Medusa Head 实现¶
import torch
import torch.nn as nn
class MedusaHead(nn.Module): # 继承nn.Module定义网络层
"""
单个 Medusa 预测头。
结构:1层残差MLP + 共享的LM Head权重。
"""
def __init__(self, hidden_size: int, vocab_size: int):
super().__init__() # super()调用父类方法
self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
self.act = nn.SiLU()
self.ln = nn.LayerNorm(hidden_size)
# 通常复用原始模型的 lm_head 权重
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states: (batch, seq_len, hidden_size)
Returns:
logits: (batch, seq_len, vocab_size)
"""
x = self.act(self.fc(hidden_states))
x = self.ln(x + hidden_states) # 残差连接
return self.lm_head(x)
class MedusaModel(nn.Module):
"""
在基础模型上添加 K 个 Medusa 预测头。
"""
def __init__(self, base_model, num_heads: int = 4, hidden_size: int = 4096, vocab_size: int = 32000):
super().__init__()
self.base_model = base_model
self.medusa_heads = nn.ModuleList([
MedusaHead(hidden_size, vocab_size)
for _ in range(num_heads)
])
# 共享原始模型的 lm_head 权重
for head in self.medusa_heads:
head.lm_head.weight = base_model.lm_head.weight
@torch.no_grad()
def forward(self, input_ids, attention_mask=None):
outputs = self.base_model.model(
input_ids, attention_mask=attention_mask
)
hidden = outputs.last_hidden_state # (B, L, D)
# 原始 LM Head 预测 t+1
base_logits = self.base_model.lm_head(hidden)
# Medusa Heads 预测 t+2, t+3, ...
medusa_logits = [head(hidden) for head in self.medusa_heads]
return base_logits, medusa_logits
13.4.3 Tree-Structured Attention¶
Medusa 的关键创新:使用树状结构组织候选 token,而非线性序列。
线性猜测: t1 → t2 → t3 → t4 → t5
(一个被拒绝,后续全部丢弃)
树状猜测: t1
/ | \
t2a t2b t2c
/| | \
t3a t3b t3c t3d
(t2a被拒绝时,t2b/t2c 分支仍可能被接受)
def build_medusa_tree(
medusa_logits: list[torch.Tensor],
top_k: int = 3,
) -> list[list[int]]:
"""
构建 Medusa 候选树。
Args:
medusa_logits: K 个 Medusa Head 的输出 logits
top_k: 每个 Head 取 top-k 候选
Returns:
candidates: 所有候选路径 [[t1, t2, t3], [t1, t2', t3'], ...]
"""
K = len(medusa_logits)
# 每个 Head 取 top-k
topk_per_head = []
for logits in medusa_logits:
probs = torch.softmax(logits[:, -1, :], dim=-1)
topk = torch.topk(probs, top_k, dim=-1)
topk_per_head.append(topk.indices[0].tolist()) # (top_k,)
# 构建所有路径组合(可用剪枝策略减少路径数)
# 简化示例:仅取联合概率最高的若干路径
candidates = []
for t1 in topk_per_head[0]:
for t2 in topk_per_head[1]:
for t3 in topk_per_head[2] if K > 2 else [None]:
path = [t1, t2]
if t3 is not None:
path.append(t3)
candidates.append(path)
return candidates
def tree_attention_mask(num_candidates: int, seq_len: int) -> torch.Tensor:
"""
构建 Tree Attention 的因果掩码。
每条候选路径只能看到自己的祖先节点和共享前缀。
"""
# 简化示例:具体实现需根据树的拓扑结构
mask = torch.zeros(num_candidates, seq_len + num_candidates)
# 所有候选都能看到完整 prefix
mask[:, :seq_len] = 1
# 每个候选只能看到自己路径上的祖先
for i in range(num_candidates):
mask[i, seq_len + i] = 1
return mask
13.4.4 Medusa-1 vs Medusa-2¶
| 特性 | Medusa-1 | Medusa-2 |
|---|---|---|
| 训练方式 | 冻结基础模型,只训练 Medusa Heads | 联合微调基础模型 + Medusa Heads |
| 训练成本 | 低(仅几层 MLP) | 高(全模型微调) |
| 预测质量 | 中等(接受率 ~60%) | 高(接受率 ~75-80%) |
| 是否保持原模型能力 | 是 | 可能略有变化 |
| 适用场景 | 快速部署、不想改基础模型 | 追求最大加速比 |
13.4.5 Medusa 训练代码¶
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
def train_medusa_heads(
model: MedusaModel,
dataloader: DataLoader,
num_epochs: int = 3,
lr: float = 1e-4,
):
"""
Medusa-1 训练:冻结基础模型,仅训练 Medusa Heads。
训练目标:让 Head_i 预测位置 t+i+2 的 token。
"""
# 冻结基础模型
for param in model.base_model.parameters():
param.requires_grad = False
# 只优化 Medusa Heads(不含共享的 lm_head.weight)
medusa_params = []
for head in model.medusa_heads:
medusa_params.extend([head.fc.weight, head.ln.weight, head.ln.bias])
optimizer = torch.optim.AdamW(medusa_params, lr=lr)
for epoch in range(num_epochs):
total_loss = 0
for batch in dataloader:
input_ids = batch["input_ids"].cuda()
labels = batch["labels"].cuda() # 与 input_ids 对齐的标签
base_logits, medusa_logits = model(input_ids)
loss = 0
for i, m_logits in enumerate(medusa_logits): # enumerate同时获取索引和元素
# Head_i 预测 t+i+2 的 token
shift = i + 2 # head_0 预测 t+2, head_1 预测 t+3, ...
if shift < labels.shape[1]:
pred = m_logits[:, :-shift, :] # 预测部分
target = labels[:, shift:] # 目标部分
loss += F.cross_entropy( # F.cross_entropy PyTorch函数式交叉熵损失
pred.reshape(-1, pred.shape[-1]), # 重塑张量形状
target.reshape(-1),
ignore_index=-100,
)
optimizer.zero_grad() # 清零梯度
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新参数
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}")
13.5 EAGLE 与 EAGLE-2¶
13.5.1 EAGLE:特征级别 Draft¶
EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency)的核心创新:
不在 token 级别做 Draft(预测 token),而在 Target 模型的隐藏特征级别做 Draft(预测特征向量),大幅提高 Draft 准确性。
传统 Speculative Decoding:
Draft 模型: [tokens] → Draft LM → 预测下一个 token ID
准确率受限于 Draft 模型能力
EAGLE:
Target 模型第1步: [tokens] → Target LM → hidden features (f1)
EAGLE Draft: f1 → 轻量网络 → 预测 f2 (Target的下一步特征)
f2 → 轻量网络 → 预测 f3
...
f_i → Target LM Head → token_i
关键优势: 在特征空间做外推,比在 token 空间做预测容易得多
为什么特征级别更好?
- Token 预测是离散的高维分类问题(32K-128K 个类别),容易出错
- 特征向量是连续的低维空间(4096 维),Auto-regressive 外推更准确
- EAGLE Draft 网络直接利用 Target 模型的中间特征,信息量更丰富
import torch
import torch.nn as nn
class EAGLEDraftHead(nn.Module):
"""
EAGLE 的特征级别 Draft 网络。
输入: Target 模型的隐藏特征 + token embedding
输出: 预测下一步的隐藏特征
"""
def __init__(self, hidden_size: int = 4096):
super().__init__()
# 融合 token embedding 和 hidden state
self.fc_in = nn.Linear(hidden_size * 2, hidden_size, bias=False)
# 轻量级 Transformer 层(通常1层)
self.transformer_layer = nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=32,
dim_feedforward=hidden_size * 4,
batch_first=True,
norm_first=True,
)
self.ln = nn.LayerNorm(hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
token_embeddings: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states: Target 模型的隐藏特征 (B, L, D)
token_embeddings: 对应 token 的 embedding (B, L, D)
Returns:
预测的下一步隐藏特征 (B, L, D)
"""
# 拼接特征和 embedding
fused = torch.cat([hidden_states, token_embeddings], dim=-1)
x = self.fc_in(fused)
x = self.transformer_layer(x)
return self.ln(x)
class EAGLEDecoder:
"""EAGLE Speculative Decoding 推理流程"""
def __init__(self, target_model, eagle_head, lm_head, embed_tokens):
self.target = target_model
self.eagle_head = eagle_head
self.lm_head = lm_head # 复用 Target 的 LM Head
self.embed = embed_tokens # 复用 Target 的 Embedding
@torch.no_grad()
def draft_with_features(
self,
last_hidden: torch.Tensor, # Target 模型最后一层输出
last_token_id: int,
K: int = 5,
) -> tuple[list[int], list[torch.Tensor]]:
"""
使用 EAGLE Head 在特征空间做 K 步外推。
"""
draft_tokens = []
draft_features = []
current_hidden = last_hidden[:, -1:, :] # (1, 1, D)
current_embed = self.embed(
torch.tensor([[last_token_id]], device=last_hidden.device)
) # (1, 1, D)
for _ in range(K):
# 预测下一步特征
predicted_hidden = self.eagle_head(current_hidden, current_embed)
draft_features.append(predicted_hidden)
# 用 Target 的 LM Head 将特征映射到 token
logits = self.lm_head(predicted_hidden) # (1, 1, V)
token = logits.argmax(dim=-1).item()
draft_tokens.append(token)
# 更新 Draft 状态
current_hidden = predicted_hidden
current_embed = self.embed(
torch.tensor([[token]], device=last_hidden.device)
)
return draft_tokens, draft_features
13.5.2 EAGLE-2:Context-Aware 动态 Draft 树¶
EAGLE-2 在 EAGLE 基础上引入两个关键改进:
- 动态 Draft 树:根据当前上下文动态调整投机树的形状和深度
- 置信度感知采样:利用 Draft Head 的置信度信号决定是否继续展开分支
EAGLE-1(静态树):
所有输入使用相同的树结构(如固定 K=5, top-k=3)
EAGLE-2(动态树):
简单token("the", "is")→ 深树, 少分支 (K=8, top-k=1)
复杂token(推理、代码)→ 浅树, 多分支 (K=3, top-k=5)
置信度低时 → 提前终止 Draft,避免浪费
class DynamicDraftTree:
"""EAGLE-2 的动态 Draft 树策略"""
def __init__(
self,
max_depth: int = 8,
max_width: int = 5,
confidence_threshold: float = 0.4,
max_candidates: int = 60,
):
self.max_depth = max_depth
self.max_width = max_width
self.conf_threshold = confidence_threshold
self.max_candidates = max_candidates
def build_tree(self, eagle_head, hidden, embed, lm_head):
"""
动态构建 Draft 树。
高置信度分支: 继续深入
低置信度分支: 剪枝停止
"""
tree_nodes = [] # [(token, parent_idx, depth, confidence)]
active_leaves = [(hidden[:, -1:, :], embed[:, -1:, :], -1, 0)]
while active_leaves and len(tree_nodes) < self.max_candidates:
new_leaves = []
for h, e, parent_idx, depth in active_leaves:
if depth >= self.max_depth:
continue
# EAGLE Head 预测
pred_hidden = eagle_head(h, e)
logits = lm_head(pred_hidden).squeeze()
probs = torch.softmax(logits, dim=-1)
# 取 top-k
topk_probs, topk_ids = probs.topk(self.max_width)
for prob, tid in zip(topk_probs, topk_ids): # zip按位置配对
conf = prob.item()
if conf < self.conf_threshold:
continue # 置信度低,剪枝
node_idx = len(tree_nodes)
tree_nodes.append((tid.item(), parent_idx, depth + 1, conf))
# 准备该节点的特征用于下一层
new_embed = embed.new_zeros(1, 1, embed.shape[-1])
new_leaves.append((pred_hidden, new_embed, node_idx, depth + 1))
active_leaves = new_leaves
return tree_nodes
13.5.3 性能对比¶
在多个基准上的典型加速比(相比原始自回归解码,在 A100 GPU、batch_size=1 条件下):
| 方法 | MT-Bench 加速 | HumanEval 加速 | 额外显存 | 是否需要训练 |
|---|---|---|---|---|
| Speculative Decoding (Llama-3.2-1B → 70B) | 2.0-2.5x | 1.8-2.2x | +1B 模型 | 否 |
| Medusa-1 | 2.0-2.3x | 1.7-2.0x | +0.5% | 是(仅 Head) |
| Medusa-2 | 2.3-2.8x | 2.0-2.5x | +0.5% | 是(全模型) |
| EAGLE | 2.5-3.2x | 2.8-3.5x | +2-5% | 是(Head) |
| EAGLE-2 | 2.8-3.6x | 3.0-4.0x | +2-5% | 是(Head) |
面试要点:EAGLE 系列在代码生成(HumanEval)上加速更明显,因为代码的自回归可预测性更高(语法约束强),Draft 接受率更高。
13.6 其他推理加速方法¶
13.6.1 Lookahead Decoding:Jacobi 迭代并行解码¶
核心思想:将自回归生成看作一个不动点方程,使用 Jacobi 迭代并行求解。
import torch
def lookahead_decoding_step(model, input_ids, n_lookahead=5, n_iterations=3):
"""
Lookahead Decoding 的核心步骤。
维护一个 N-token 的猜测窗口,通过 Jacobi 迭代并行求解。
优点: 不需要 Draft 模型
缺点: 需要多次迭代,加速比通常低于 Speculative Decoding
Args:
model: 语言模型
input_ids: 当前序列 (1, seq_len)
n_lookahead: 前瞻窗口大小
n_iterations: Jacobi 迭代次数
"""
device = input_ids.device
seq_len = input_ids.shape[1]
# 初始化猜测窗口(随机或启发式)
guess = torch.randint(0, model.config.vocab_size, (1, n_lookahead), device=device)
for iteration in range(n_iterations):
# 将猜测拼接到序列
full_ids = torch.cat([input_ids, guess], dim=-1)
# 构造因果掩码(标准自回归)
outputs = model(full_ids)
logits = outputs.logits # (1, seq_len + n_lookahead, vocab_size)
# 并行更新所有猜测位置
new_guess = logits[0, seq_len-1:seq_len+n_lookahead-1, :].argmax(dim=-1)
new_guess = new_guess.unsqueeze(0)
# 检查是否收敛
if torch.equal(new_guess, guess):
break
guess = new_guess
# 验证:找到第一个不一致的位置
final_ids = torch.cat([input_ids, guess], dim=-1)
verify_outputs = model(final_ids)
verify_logits = verify_outputs.logits
n_accepted = 0
for i in range(n_lookahead):
predicted = verify_logits[0, seq_len + i - 1, :].argmax()
if predicted == guess[0, i]:
n_accepted += 1
else:
break
return guess[0, :n_accepted+1], n_accepted
13.6.2 Continuous Batching(动态批处理)¶
传统 Static Batching 的问题:批内所有请求必须等最长序列完成后才能释放资源。
Static Batching:
Req1: [====生成中====] [等待]
Req2: [======生成中======] [等待]
Req3: [===========生成中===========]
↑ Req1完成后仍需等待 Req3
Continuous Batching (Iteration-level scheduling):
Req1: [====完成====][Req4开始===]
Req2: [======完成======][Req5=]
Req3: [===========完成===========]
↑ Req1完成后立即插入 Req4
class ContinuousBatchingScheduler:
"""
连续批处理调度器(简化示意)。
每次 Decode 迭代后检查是否有请求完成或新请求到达。
"""
def __init__(self, model, max_batch_size=64):
self.model = model
self.max_batch = max_batch_size
self.active_requests = [] # 正在生成的请求
self.waiting_queue = [] # 等待队列
def iteration_step(self):
"""执行一次 Decode 迭代"""
# 1. 移除已完成的请求
finished = [r for r in self.active_requests if r.is_done()]
for r in finished:
self.active_requests.remove(r)
r.callback(r.generated_text)
# 2. 从等待队列填充空位
while (len(self.active_requests) < self.max_batch
and self.waiting_queue):
new_req = self.waiting_queue.pop(0)
self.active_requests.append(new_req)
if not self.active_requests:
return
# 3. 组装 batch,执行一次前向传播
batch_ids = self._pad_batch()
with torch.no_grad():
outputs = self.model(batch_ids)
# 4. 每个请求采样下一个 token
for i, req in enumerate(self.active_requests):
next_token = outputs.logits[i, -1, :].argmax().item()
req.append_token(next_token)
def _pad_batch(self):
"""将不同长度的请求 padding 到相同长度"""
max_len = max(len(r.tokens) for r in self.active_requests)
batch = torch.zeros(len(self.active_requests), max_len, dtype=torch.long)
for i, r in enumerate(self.active_requests):
batch[i, :len(r.tokens)] = torch.tensor(r.tokens)
return batch.cuda()
13.6.3 Prefix Caching¶
核心思想:当多个请求共享相同前缀(如 system prompt)时,缓存该前缀的 KV-Cache,避免重复计算。
class PrefixCacheManager:
"""
前缀 KV-Cache 管理器。
适用场景:所有请求共享相同 system prompt。
"""
def __init__(self, model, max_cache_entries=100):
self.model = model
self.cache = {} # hash(prefix_tokens) → KV-Cache
self.max_entries = max_cache_entries
def get_or_compute_prefix_cache(self, prefix_tokens: tuple):
"""查找或计算前缀的 KV-Cache"""
cache_key = hash(prefix_tokens)
if cache_key in self.cache:
return self.cache[cache_key] # 缓存命中!
# 缓存未命中,计算 KV-Cache
input_ids = torch.tensor([list(prefix_tokens)]).cuda()
with torch.no_grad():
outputs = self.model(input_ids, use_cache=True)
kv_cache = outputs.past_key_values
# LRU 淘汰
if len(self.cache) >= self.max_entries:
oldest = next(iter(self.cache))
del self.cache[oldest]
self.cache[cache_key] = kv_cache
return kv_cache
def generate_with_cached_prefix(self, prefix_tokens, continuation_tokens):
"""使用缓存的前缀 KV-Cache 直接从续写部分开始计算"""
kv_cache = self.get_or_compute_prefix_cache(tuple(prefix_tokens))
# 仅需处理续写部分的 token
cont_ids = torch.tensor([continuation_tokens]).cuda()
with torch.no_grad():
outputs = self.model(
cont_ids,
past_key_values=kv_cache,
use_cache=True,
)
return outputs
13.6.4 Chunked Prefill(分块预填充)¶
解决长 prompt 的 Prefill 阶段占用过多时间、阻塞 Decode 请求的问题:
传统 Prefill(prompt=4096 tokens):
[===========Prefill 4096 tokens===========]
↑ 这段时间其他 Decode 请求全部被阻塞
Chunked Prefill:
[Prefill chunk1=1024][Decode batch][Prefill chunk2=1024][Decode batch]...
↑ 将 Prefill 分块,与 Decode 请求交替执行
class ChunkedPrefillScheduler:
"""分块预填充调度器"""
def __init__(self, model, chunk_size=1024):
self.model = model
self.chunk_size = chunk_size # 每块处理的 token 数
def chunked_prefill(self, prompt_tokens, decode_requests=None):
"""
将长 prompt 的 Prefill 拆分为多个 chunk 交替执行。
减少 Prefill 对 Decode 延迟的影响 (TTFT ↓, ITL 稳定)。
"""
total_len = len(prompt_tokens)
num_chunks = (total_len + self.chunk_size - 1) // self.chunk_size
kv_cache = None
for i in range(num_chunks):
start = i * self.chunk_size
end = min(start + self.chunk_size, total_len)
chunk = prompt_tokens[start:end]
# 处理一块 Prefill
chunk_ids = torch.tensor([chunk]).cuda()
with torch.no_grad():
outputs = self.model(
chunk_ids,
past_key_values=kv_cache,
use_cache=True,
)
kv_cache = outputs.past_key_values
# 在 chunk 之间插入 Decode 步骤
if decode_requests:
self._serve_decode_batch(decode_requests)
return kv_cache
def _serve_decode_batch(self, requests):
"""服务当前等待中的 Decode 请求"""
# 对活跃请求执行一步 Decode
pass
13.6.5 SGLang 的 RadixAttention¶
SGLang(Structured Generation Language)的 RadixAttention 是 Prefix Caching 的高级形式:
传统 Prefix Cache: 仅匹配完全相同的前缀
RadixAttention(基数树缓存):
请求1: [System] [User: 什么是AI?]
请求2: [System] [User: 什么是ML?]
请求3: [System] [User: 什么是AI?] [Assistant: AI是...] [User: 详细说说]
Radix Tree:
[System] ──┬── [User: 什么是AI?] ──── [Assistant: AI是...] ── [User: 详细说说]
└── [User: 什么是ML?]
最长前缀匹配: 请求3 完全命中请求1的缓存,直接从"详细说说"开始计算
核心优势: - 自动管理任意前缀的 KV-Cache - 支持多轮对话的增量缓存 - 与 Continuous Batching 和 Chunked Prefill 正交,可组合使用
13.7 工程实践¶
13.7.1 vLLM 中的 Speculative Decoding¶
vLLM(v0.6+)原生支持多种 Speculative Decoding 策略:
from vllm import LLM, SamplingParams
# ===== 方法1:使用独立 Draft 模型 =====
llm = LLM(
model="meta-llama/Llama-3.1-70B-Instruct",
speculative_model="meta-llama/Llama-3.2-1B-Instruct",
num_speculative_tokens=5, # 每轮猜测 5 个 token
max_model_len=4096,
tensor_parallel_size=4, # Target 模型 4 卡并行
# Draft 模型自动放在单卡上
)
# ===== 方法2:Prompt Lookup Decoding(不需要 Draft 模型) =====
llm = LLM(
model="meta-llama/Llama-3.1-70B-Instruct",
speculative_model="[ngram]", # 使用 N-gram 作为 Draft
num_speculative_tokens=5,
ngram_prompt_lookup_max=4, # N-gram 匹配的最大长度
ngram_prompt_lookup_min=2,
)
# ===== 方法3:EAGLE 投机解码 =====
llm = LLM(
model="meta-llama/Llama-3.1-70B-Instruct",
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-70B",
num_speculative_tokens=5,
max_model_len=4096,
)
# 使用方式与普通 vLLM 完全一致
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
outputs = llm.generate(["Explain quantum computing:"], sampling_params)
print(outputs[0].outputs[0].text)
vLLM Speculative Decoding 配置要点:
| 参数 | 说明 | 建议值 |
|---|---|---|
num_speculative_tokens | 每轮投机长度 K | 3-8(取决于接受率) |
speculative_disable_by_batch_size | 当 batch size 超过阈值时禁用投机 | 8-16 |
speculative_max_model_len | Draft 模型最大序列长度 | 与 Target 一致或更小 |
speculative_draft_tensor_parallel_size | Draft 模型张量并行度 | 1(通常单卡) |
13.7.2 TensorRT-LLM 中的投机采样¶
# TensorRT-LLM Speculative Decoding 配置示例
import tensorrt_llm
from tensorrt_llm import BuildConfig
from tensorrt_llm.executor import ExecutorConfig
# 构建 Target 引擎
target_build_config = BuildConfig(
max_batch_size=64,
max_input_len=2048,
max_seq_len=4096,
speculative_decoding_mode="DRAFT_TOKENS_EXTERNAL",
)
# 构建 Draft 引擎
draft_build_config = BuildConfig(
max_batch_size=64,
max_input_len=2048,
max_seq_len=4096,
)
# 运行时配置
executor_config = ExecutorConfig(
speculative_config={
"num_draft_tokens": 5,
"acceptance_threshold": 0.0, # 使用标准 Speculative Sampling
}
)
# 在 trtllm-build 命令行中:
# trtllm-build --checkpoint_dir ./target_ckpt \
# --output_dir ./target_engine \
# --speculative_decoding_mode draft_tokens_external \
# --max_draft_len 5
13.7.3 不同场景下的实测加速比¶
以下为典型实测数据(A100 80GB,batch_size=1,FP16):
| Target 模型 | 方法 | 代码生成 | 对话 | 摘要 | 数学推理 |
|---|---|---|---|---|---|
| Llama-3.1-70B | 标准解码(基线) | 1.0x | 1.0x | 1.0x | 1.0x |
| Llama-3.1-70B | SD (Draft=1B, K=5) | 2.3x | 2.0x | 2.5x | 1.6x |
| Llama-3.1-70B | Prompt Lookup (K=5) | 1.2x | 1.1x | 2.8x | 1.0x |
| Llama-3.1-70B | EAGLE-2 | 3.2x | 2.8x | 2.5x | 2.1x |
| Llama-3.1-70B | Medusa-2 (4 heads) | 2.5x | 2.3x | 2.0x | 1.7x |
关键观察:
- 代码生成加速效果最好——代码语法约束强,Draft 准确率高
- 数学推理加速效果最差——推理链不可预测,Draft 频繁被拒绝
- 摘要任务用 Prompt Lookup 效果最好——输出大量复用输入内容
- EAGLE-2 在多数场景综合表现最佳
13.7.4 何时适合/不适合使用 Speculative Decoding¶
适合的场景:
| 条件 | 原因 |
|---|---|
| Batch size = 1 或很小 | GPU 利用率低,有空间做投机计算 |
| 低延迟(TTFT↓)需求 | 减少 Decode 步数,加速首 token 后的生成 |
| 输出可预测性高 | 代码、翻译、模板化文本→高接受率 |
| 有现成的好 Draft 模型 | 同系列小模型或已训练的 EAGLE Head |
不适合的场景:
| 条件 | 原因 |
|---|---|
| Batch size 很大(>16) | GPU 已被充分利用,投机计算反而增加开销 |
| 显存紧张 | Draft 模型或额外 Head 占用显存 |
| 创意性/开放式生成 | 接受率低,投机频繁失败,可能反而变慢 |
| 高温度采样(\(T > 1.5\)) | 概率分布分散,Draft 越难命中 |
def should_use_speculative_decoding(
batch_size: int,
task_type: str,
temperature: float,
gpu_memory_free_gb: float,
draft_model_size_gb: float,
) -> tuple[bool, str]:
"""
决策函数:是否应该使用 Speculative Decoding
"""
reasons = []
if batch_size > 16:
reasons.append(f"batch_size={batch_size} 过大,GPU已充分利用")
return False, "; ".join(reasons)
if temperature > 1.5:
reasons.append(f"temperature={temperature} 过高,接受率将很低")
return False, "; ".join(reasons)
if gpu_memory_free_gb < draft_model_size_gb * 1.2:
reasons.append("显存不足以加载 Draft 模型")
return False, "; ".join(reasons)
# 按任务类型预估加速比
expected_speedup = {
"code": 2.5, "translation": 2.2, "summary": 2.0,
"chat": 1.8, "creative": 1.2, "math": 1.5,
}.get(task_type, 1.5)
if expected_speedup < 1.3:
reasons.append(f"预估加速比 {expected_speedup}x 不值得额外复杂度")
return False, "; ".join(reasons)
return True, f"建议使用,预估加速 {expected_speedup}x"
13.8 面试高频题¶
题目 1:Speculative Decoding 为什么能加速,但不改变输出分布?¶
答:
加速原因: - 自回归 Decode 是 Memory-bound 的,处理 1 个 token 和处理 K 个 token 的 wall-clock time 几乎相同 - Speculative Decoding 用快速 Draft 模型猜 K 个 token,然后 Target 模型一次前向传播并行验证 K 个 token - 相当于用"1 次 Draft ×K + 1 次 Target"替代"K 次 Target"
分布不变的保证: - 使用接受-拒绝采样:接受概率为 \(\min(1, p(x)/q(x))\) - 拒绝时从修正分布 \(\text{norm}(\max(0, p-q))\) 采样 - 可以数学证明最终每个 token 的边际分布恰好等于 Target 分布 \(p(x)\) - 这不是近似,而是精确保证
题目 2:Speculative Decoding 的加速效果受哪些因素影响?¶
答:
影响加速比的三个关键因素:
- 接受率 \(\alpha\)(最重要)
- 取决于 Draft 模型与 Target 模型的分布差异
- \(\alpha\) 越高加速越大;\(\alpha < 0.5\) 时可能不如不用
-
受任务类型影响:代码/翻译 > 对话 > 创意写作
-
速度比 \(c = t_{\text{draft}} / t_{\text{target}}\)
- Draft 越快越好,通常 \(c < 0.1\) 才有意义
-
包括:模型加载延迟、KV-Cache 管理开销
-
投机长度 \(K\)
- \(K\) 过小→每轮产出少,频繁切换 Draft/Target 开销大
- \(K\) 过大→后面的猜测越来越不准,接受率指数下降
- 最优 \(K\) 通常 3-8,取决于 \(\alpha\) 和 \(c\)
理论加速比公式为:\(\text{Speedup} = \frac{1 - \alpha^{K+1}}{(1-\alpha)(Kc+1)}\)
题目 3:比较 Medusa 和 EAGLE 的思路差异¶
答:
| 维度 | Medusa | EAGLE |
|---|---|---|
| Draft 方式 | 在 token 级别:多个独立 Head 各自预测未来第 i 个 token | 在特征级别:预测 Target 模型的隐藏特征,再转 token |
| 独立性假设 | 各 Head 独立预测(不考虑 token 间依赖) | 逐步外推特征(保留自回归依赖) |
| 验证结构 | 树结构组合所有 Head 的 top-k 候选 | 动态树(EAGLE-2),根据置信度调整 |
| 额外参数 | 极少(每个 Head 仅一层 MLP) | 稍多(一层 Transformer + 投影层) |
| 训练数据 | 需要训练数据对齐 Head | 需要训练数据 + Target 模型的中间特征 |
| 加速效果 | 中等(受独立性假设限制) | 更高(特征外推准确率高) |
核心差异:Medusa 的各 Head 独立预测违反了语言的自回归本质(第 3 个 token 的预测应该依赖第 2 个),EAGLE 通过在特征空间做序列外推保留了这种依赖,因此 Draft 质量更高。
题目 4:Speculative Decoding 在大 batch size 下为什么效果不好?¶
答:
-
GPU 利用率已经高:大 batch 时 Decode 的矩阵乘变成了矩阵-矩阵乘(而非矩阵-向量乘),从 Memory-bound 变为 Compute-bound,GPU 计算单元已被充分利用
-
验证开销增大:Target 模型需要同时处理 batch 中每个请求的 K 个候选 token,显存和计算开销线性增长
-
Draft 开销无法忽略:batch=1 时 Draft 成本可忽略,但 batch=32 时 Draft 模型也需要处理 32 个请求,成本不可忽略
-
边际收益递减:Continuous Batching 本身已提供了很好的吞吐,Speculative Decoding 的额外收益被稀释
经验法则:batch_size > 8-16 时,关闭 Speculative Decoding 通常更优。vLLM 提供 speculative_disable_by_batch_size 参数来自动处理。
题目 5:解释 Prompt Lookup Decoding 的原理和适用场景¶
答:
原理: - 不使用任何 Draft 模型 - 直接在输入 prompt 中寻找与当前生成后缀匹配的 N-gram - 将匹配位置之后的 token 作为 Draft 候选交给 Target 验证 - 时间复杂度 \(O(L \cdot n)\)(\(L\) 为 prompt 长度,\(n\) 为匹配长度),几乎零延迟
适用场景(接受率高): - 摘要任务:输出大量复用 prompt 中的原文 - 翻译:源语言和目标语言有共享 token(数字、专有名词) - 代码重构:输出与输入代码高度相似 - 多轮对话:重复引用上文内容
不适用场景: - 创意写作(输出与输入无关) - 知识问答(答案不在 prompt 中) - 数学推理(推导过程与 prompt 无重叠)
题目 6:Continuous Batching 和 Speculative Decoding 能否结合使用?¶
答:
可以结合,但有工程挑战:
-
兼容性:两者正交——Continuous Batching 管理请求调度,Speculative Decoding 管理单个请求的加速。理论上完全兼容。
-
工程挑战:
- 批内不同请求的 Draft 长度可能不同(因为接受率不同),需要 padding 对齐
- Draft 阶段的 batching 和 Target 验证的 batching 需要分别管理
-
KV-Cache 管理更复杂:被拒绝的 token 需要回退 Cache
-
动态策略:
- batch 小时开启 Speculative Decoding
- batch 大时自动关闭
-
vLLM 的
speculative_disable_by_batch_size就是这个策略 -
实践结论:在请求稀疏时(batch 小),两者结合效果最好——Continuous Batching 保证吞吐,Speculative Decoding 降低延迟。
题目 7:如何选择最优投机长度 K?¶
答:
理论分析:
加速比 \(S(K) = \frac{1 - \alpha^{K+1}}{(1-\alpha)(Kc+1)}\)
对 \(K\) 求导并令其为零,可得理论最优 \(K^*\):
实践指导:
| 接受率 \(\alpha\) | 速度比 \(c\) | 建议 \(K\) | 预估加速 |
|---|---|---|---|
| 0.9 | 0.05 | 6-8 | 3.5-4.0x |
| 0.8 | 0.05 | 4-6 | 2.5-3.0x |
| 0.7 | 0.10 | 3-4 | 1.7-2.0x |
| 0.6 | 0.10 | 2-3 | 1.3-1.5x |
动态调整策略: - 监控运行时接受率 - 接受率高时增大 \(K\)(多猜一些),低时减小 \(K\)(少猜避免浪费) - EAGLE-2 的动态树本质上就是在自动做这件事
class AdaptiveSpeculativeLength:
"""运行时自适应调整投机长度"""
def __init__(self, K_min=2, K_max=10, target_accept_rate=0.75):
self.K = 5 # 初始值
self.K_min = K_min
self.K_max = K_max
self.target_rate = target_accept_rate
self.accept_history = []
def update(self, n_accepted, K_used):
rate = n_accepted / K_used
self.accept_history.append(rate)
# 用最近 10 轮的平均接受率调整
if len(self.accept_history) >= 10:
avg_rate = sum(self.accept_history[-10:]) / 10
if avg_rate > self.target_rate + 0.1:
self.K = min(self.K + 1, self.K_max)
elif avg_rate < self.target_rate - 0.1:
self.K = max(self.K - 1, self.K_min)
return self.K
题目 8:未来推理加速的趋势是什么?¶
答:
- 硬件-算法协同
- Speculative Decoding 与专用推理芯片(Groq、Cerebras)结合
- 硬件原生支持 Draft-Verify 流水线
-
充分利用 HBM3e 带宽(>4TB/s)缩小 Memory-bound 瓶颈
-
自适应投机策略
- 根据输入动态选择 Draft 方法(N-gram vs 小模型 vs EAGLE)
- 多级 Draft 级联:先 N-gram,失败再用小模型
-
在线学习接受率,实时调整参数
-
与 MoE 结合
- MoE 模型(如 DeepSeek-V3)的非活跃专家天然可作为 Draft
-
路由预测实现推测执行
-
与量化互补
- Draft 模型使用 INT4/INT2 极致量化
- Target 模型使用 FP8
-
量化带来的额外速度比 \(c\) 进一步降低
-
端侧推理
- 手机/边缘设备上内存带宽更受限,Speculative Decoding 价值更大
-
超轻量 Draft(N-gram + 小型 MLP)适合端侧
-
多模态投机解码
- 将 Speculative Decoding 扩展到视觉-语言模型
- 图像理解任务中文本部分仍可投机加速
本章总结¶
自回归解码瓶颈 ──→ Memory-bound, GPU利用率低
│
▼
Speculative Decoding ──→ Draft猜测 + Target验证
│ │ │
▼ ▼ ▼
独立Draft Medusa EAGLE/EAGLE-2
(小模型) (多头) (特征级别)
│ │ │
▼ ▼ ▼
工程实践: vLLM / TensorRT-LLM / SGLang
│
▼
互补方法: Continuous Batching + Prefix Caching + Chunked Prefill
核心知识点: 1. Decode 阶段是 Memory-bound → 处理 K 个 token 和 1 个 token 耗时相近 2. Speculative Decoding 通过接受-拒绝采样保证输出分布精确不变 3. EAGLE 在特征空间做 Draft 优于 token 空间(Medusa) 4. 大 batch 时 Speculative Decoding 效果递减,需要动态开关 5. 与 Continuous Batching、Prefix Caching 等方法正交互补
延伸阅读:关于 Speculative Decoding 在完整推理服务中的集成(包括 API 层、负载均衡、多模型部署),请参阅 LLM应用教程第12章-推理优化。
参考文献: 1. Leviathan et al., "Fast Inference from Transformers via Speculative Decoding", ICML 2023 2. Chen et al., "Accelerating Large Language Model Decoding with Speculative Sampling", 2023 3. Cai et al., "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads", ICML 2024 4. Li et al., "EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty", ICML 2024 5. Li et al., "EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees", 2024 6. Fu et al., "Lookahead Decoding: Breaking the Sequential Dependency of LLM Decoding", 2024 7. Saxena, "Prompt Lookup Decoding", 2023 8. Yu et al., "vLLM: Efficient Memory Management for Large Language Model Serving", SOSP 2023 9. Zheng et al., "SGLang: Efficient Execution of Structured Language Model Programs", 2024