跳转至

第13章 Speculative Decoding与推理加速

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

前置章节第12章 FlashAttention原理与实现 中我们学习了如何通过优化注意力计算本身来降低内存占用与提升吞吐。本章聚焦另一个推理侧的核心瓶颈——自回归解码的串行性,介绍 Speculative Decoding(投机解码)及其衍生方法,这是 2025-2026 年大模型推理优化面试中的绝对热点。


13.1 自回归解码的瓶颈

13.1.1 自回归生成的本质

Transformer 语言模型的文本生成过程是严格串行的:

\[P(y_1, y_2, \ldots, y_T) = \prod_{t=1}^{T} P(y_t \mid y_1, \ldots, y_{t-1})\]

每一步只能生成 1 个 token,后一个 token 的生成依赖前一个 token 的输出。对于一个 \(T\) 个 token 的回复,模型需要执行 \(T\) 次前向传播。

Text Only
Step 1: [prompt]         → token_1
Step 2: [prompt, t1]     → token_2
Step 3: [prompt, t1, t2] → token_3
...(必须串行,无法跳步)

13.1.2 Memory-Bound 问题

这是理解 Speculative Decoding 为何有效的关键。

Prefill 阶段 vs Decode 阶段的计算特性完全不同:

特性 Prefill(预填充) Decode(逐token解码)
输入规模 整段 prompt(数百~数千 token 并行) 单个 token
计算密度 Compute-bound(矩阵-矩阵乘) Memory-bound(矩阵-向量乘)
GPU 利用率 高(40-80%) 极低(<5%)
瓶颈 计算算力(FLOPs) 显存带宽(GB/s)

逐 token 解码时,每一步需要:

  1. 读取全部模型参数——对 70B 模型约 140GB(FP16)
  2. 读写 KV-Cache——随序列长度线性增长
  3. 仅做一次矩阵-向量乘——计算量极少
Python
# 直观理解 Memory-bound
# 70B模型 Decode阶段每步:
model_params = 70e9 * 2  # 140GB (FP16)
kv_cache_per_step = 0.5  # ~0.5GB (假设,实际取决于序列长度和hidden_size)
total_read = model_params + kv_cache_per_step  # ~140.5GB

# A100 80GB GPU 带宽 = 2TB/s
bandwidth = 2e12  # bytes/s
min_time_per_token = total_read / bandwidth  # ~70ms
# 注:这是简化假设,实际读取时间受KV Cache命中、模型分片等因素影响
# 即便计算瞬间完成,仅读取权重就需要 ~70ms → ~14 tokens/s

# 但实际计算量极小:
flops_per_token = 2 * 70e9  # ~140 GFLOPs
a100_flops = 312e12  # 312 TFLOPS (FP16)
compute_time = flops_per_token / a100_flops  # ~0.45ms

# 计算时间 0.45ms vs 读取时间 70ms → 计算资源严重浪费!

13.1.3 批处理解码 vs 单 Token 解码

如果我们能同时验证多个 token(而非一个个生成),就能更好利用 GPU:

Python
import torch
import time

def benchmark_matmul():
    """对比单向量 vs 批量矩阵乘的 GPU利用率"""
    device = torch.device('cuda')
    weight = torch.randn(4096, 4096, device=device, dtype=torch.float16)

    # 单 token 解码:矩阵-向量乘 (4096, 4096) x (4096, 1)
    x_single = torch.randn(1, 4096, device=device, dtype=torch.float16)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(1000):
        _ = x_single @ weight.T
    torch.cuda.synchronize()
    t_single = (time.perf_counter() - t0) / 1000

    # 批量验证:矩阵-矩阵乘 (4096, 4096) x (4096, 8)
    x_batch = torch.randn(8, 4096, device=device, dtype=torch.float16)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(1000):
        _ = x_batch @ weight.T
    torch.cuda.synchronize()
    t_batch = (time.perf_counter() - t0) / 1000

    print(f"单token: {t_single*1e6:.1f} μs (1 token)")
    print(f"批量8:   {t_batch*1e6:.1f} μs (8 tokens)")
    print(f"批量8每token: {t_batch/8*1e6:.1f} μs")
    print(f"吞吐提升: {t_single / (t_batch/8):.1f}x")

# 典型输出(A100):
# 单token: 42.3 μs (1 token)
# 批量8:   48.7 μs (8 tokens)
# 批量8每token: 6.1 μs
# 吞吐提升: 6.9x

关键洞察:处理 8 个 token 的时间几乎等于处理 1 个 token 的时间——这正是 Speculative Decoding 加速的根本原因。

13.1.4 面试考点

Q:为什么 LLM 推理阶段 GPU 利用率很低?

自回归 Decode 阶段每步仅处理 1 个 token,是矩阵-向量乘(Memory-bound),GPU 大量时间花在读取模型权重和 KV-Cache 上,计算单元严重闲置。70B 模型在 A100 上每步计算时间不到 1ms,但读取权重就要约 70ms。


13.2 Speculative Decoding 核心原理

13.2.1 核心思想

Speculative Decoding(投机解码)的核心思想:

用一个小而快的 Draft 模型猜测未来 K 个 token,然后用大的 Target 模型一次性并行验证这 K 个 token。

  • Draft 模型\(M_q\)):小模型,速度快,准确度有限
  • Target 模型\(M_p\)):大模型,速度慢,结果准确
Text Only
传统自回归(K=5 个 token):
  Target: [fwd] → t1 → [fwd] → t2 → [fwd] → t3 → [fwd] → t4 → [fwd] → t5
  总计: 5 次 Target 前向传播

Speculative Decoding(K=5 个猜测):
  Draft:  [fwd×5 快速] → guess: t1, t2, t3, t4, t5
  Target: [1次并行fwd] → verify: ✓t1, ✓t2, ✓t3, ✗t4 → 修正 t4'
  总计: 5 次 Draft(快) + 1 次 Target 前向传播
  若 Draft 速度 = Target/10,且 5 个中接受 3 个 → 加速约 2-3x

13.2.2 Draft-then-Verify 范式

完整算法流程:

Text Only
输入: Target模型 M_p, Draft模型 M_q, 投机长度 K, 已有前缀 prefix

1. DRAFT 阶段(串行,但很快):
   for i = 1 to K:
       q_i = M_q(prefix + guesses[:i-1])  # Draft 模型预测概率分布
       guess_i ~ q_i                       # 从 Draft 分布采样

2. VERIFY 阶段(并行,一次前向传播):
   p_1, p_2, ..., p_K, p_{K+1} = M_p(prefix + [guess_1, ..., guess_K])
   # Target 模型一次处理 K 个 token,得到 K+1 个位置的概率分布

3. ACCEPT/REJECT(逐个验证):
   for i = 1 to K:
       r ~ Uniform(0, 1)
       if r < min(1, p_i(guess_i) / q_i(guess_i)):
           ACCEPT guess_i, 继续
       else:
           REJECT guess_i
           从修正分布 norm(max(0, p_i - q_i)) 采样替代 token
           丢弃 guess_{i+1}, ..., guess_K
           BREAK

   if 全部接受:
       从 p_{K+1} 额外采样 1 个 bonus token

4. 将已接受的 tokens 追加到 prefix,回到步骤 1

13.2.3 数学保证:输出分布不变

定理:Speculative Decoding 的输出分布与直接使用 Target 模型的分布完全一致

设 Draft 分布为 \(q(x)\),Target 分布为 \(p(x)\),对于候选 token \(x\)

  • 接受概率\(\alpha = \min\left(1, \frac{p(x)}{q(x)}\right)\)
  • 拒绝后修正采样:从 \(p'(x) = \frac{\max(0, p(x) - q(x))}{\sum_{x'} \max(0, p(x') - q(x'))}\) 中采样

证明最终分布为 \(p(x)\)

对任意 token \(x\),被选中的概率为:

\[P(\text{选中 } x) = q(x) \cdot \min\left(1, \frac{p(x)}{q(x)}\right) + \left(1 - \sum_{x'} q(x') \cdot \min\left(1, \frac{p(x')}{q(x')}\right)\right) \cdot p'(x)\]

分两种情况讨论:

情况 1:当 \(p(x) \geq q(x)\) 时: $\(q(x) \cdot 1 + \beta \cdot \frac{p(x) - q(x)}{\sum_{x'} \max(0, p(x') - q(x'))} = q(x) + (p(x) - q(x)) = p(x)\)$

(其中 \(\beta = \sum_{x'} \max(0, p(x') - q(x'))\)

情况 2:当 \(p(x) < q(x)\) 时: $\(q(x) \cdot \frac{p(x)}{q(x)} + \beta \cdot 0 = p(x)\)$

两种情况下最终分布都等于 \(p(x)\)。✓

13.2.4 验证算法详解

Python
import torch
import torch.nn.functional as F

def speculative_sampling(
    target_logits: torch.Tensor,  # (K+1, vocab_size) - Target模型输出
    draft_probs: torch.Tensor,    # (K, vocab_size)   - Draft模型采样时的概率
    draft_tokens: torch.Tensor,   # (K,)              - Draft模型采样的token
    temperature: float = 1.0
) -> tuple[torch.Tensor, int]:
    """
    执行 Speculative Decoding 的验证步骤。

    Returns:
        accepted_tokens: 最终接受的 token 序列
        n_accepted: 接受的 Draft token 数量
    """
    K = draft_tokens.shape[0]
    vocab_size = target_logits.shape[-1]  # [-1]负索引取最后元素

    # Target 模型的概率分布
    target_probs = F.softmax(target_logits[:K] / temperature, dim=-1)  # (K, V)  # F.xxx PyTorch函数式API

    accepted_tokens = []
    n_accepted = 0

    for i in range(K):
        token = draft_tokens[i]
        p_i = target_probs[i, token].item()   # Target 概率
        q_i = draft_probs[i, token].item()    # Draft 概率

        # 接受概率
        accept_prob = min(1.0, p_i / (q_i + 1e-10))

        r = torch.rand(1).item()
        if r < accept_prob:
            # 接受 Draft 的猜测
            accepted_tokens.append(token.item())
            n_accepted += 1
        else:
            # 拒绝:从修正分布采样
            # p'(x) = norm(max(0, p(x) - q(x)))
            corrected = torch.clamp(target_probs[i] - draft_probs[i], min=0)
            corrected = corrected / (corrected.sum() + 1e-10)
            replacement = torch.multinomial(corrected, 1).item()
            accepted_tokens.append(replacement)
            break

    # 如果全部接受,从 p_{K+1} 采样 bonus token
    if n_accepted == K:
        bonus_probs = F.softmax(target_logits[K] / temperature, dim=-1)
        bonus_token = torch.multinomial(bonus_probs, 1).item()
        accepted_tokens.append(bonus_token)

    return torch.tensor(accepted_tokens), n_accepted

13.2.5 加速原理分析

设: - \(c\):Draft 模型与 Target 模型的速度比(\(c = t_{\text{draft}} / t_{\text{target}}\),通常 \(c \ll 1\)) - \(K\):投机长度(Draft 一次猜测的 token 数) - \(\alpha\):平均接受率

每轮墙钟时间\(T_{\text{round}} = K \cdot c \cdot t_{\text{target}} + t_{\text{target}} = (Kc + 1) \cdot t_{\text{target}}\)

每轮平均产出 token 数: $\(E[\text{tokens}] = \sum_{i=0}^{K-1} \alpha^i (1-\alpha) \cdot (i+1) + \alpha^K \cdot (K+1) = \frac{1 - \alpha^{K+1}}{1 - \alpha}\)$

加速比: $\(\text{Speedup} = \frac{E[\text{tokens}]}{Kc + 1} = \frac{1 - \alpha^{K+1}}{(1 - \alpha)(Kc + 1)}\)$

Python
import numpy as np

def compute_speedup(alpha: float, K: int, c: float) -> float:
    """计算 Speculative Decoding 的理论加速比"""
    expected_tokens = (1 - alpha**(K+1)) / (1 - alpha)
    wall_time_ratio = K * c + 1
    return expected_tokens / wall_time_ratio

# 不同参数下的加速比
print("=" * 60)
print(f"{'接受率α':<10} {'K':<5} {'速度比c':<10} {'加速比':<10}")
print("=" * 60)
for alpha in [0.6, 0.7, 0.8, 0.9]:
    for K in [3, 5, 8]:
        for c in [0.05, 0.1]:
            s = compute_speedup(alpha, K, c)
            print(f"{alpha:<10} {K:<5} {c:<10} {s:<10.2f}")
    print("-" * 60)

# 典型输出:
# α=0.8, K=5, c=0.05 → 加速比 ≈ 2.95x
# α=0.9, K=5, c=0.05 → 加速比 ≈ 3.69x
# α=0.7, K=5, c=0.10 → 加速比 ≈ 1.94x

13.2.6 完整 PyTorch 实现

Python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional

class SpeculativeDecoder:
    """
    完整的 Speculative Decoding 实现。
    使用小模型 (draft) 猜测 + 大模型 (target) 验证。
    """
    def __init__(
        self,
        target_model_name: str,
        draft_model_name: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ):
        print(f"Loading target model: {target_model_name}")
        self.target_model = AutoModelForCausalLM.from_pretrained(
            target_model_name, torch_dtype=dtype, device_map=device
        )
        self.target_model.eval()  # eval()评估模式

        print(f"Loading draft model: {draft_model_name}")
        self.draft_model = AutoModelForCausalLM.from_pretrained(
            draft_model_name, torch_dtype=dtype, device_map=device
        )
        self.draft_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
        self.device = device

    @torch.no_grad()  # 禁用梯度计算,节省内存
    def draft_generate(
        self,
        input_ids: torch.Tensor,
        K: int,
        temperature: float = 1.0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Draft 模型串行生成 K 个候选 token。
        Returns:
            draft_tokens: (K,) 候选 token ids
            draft_probs:  (K, vocab_size) 候选 token 的概率分布
        """
        draft_tokens = []
        draft_probs = []
        current_ids = input_ids.clone()

        for _ in range(K):
            outputs = self.draft_model(current_ids)
            logits = outputs.logits[:, -1, :]  # (1, vocab_size)
            probs = F.softmax(logits / temperature, dim=-1)
            token = torch.multinomial(probs, 1)  # (1, 1)

            draft_tokens.append(token.squeeze())  # squeeze压缩维度
            draft_probs.append(probs.squeeze())
            current_ids = torch.cat([current_ids, token], dim=-1)  # torch.cat沿已有维度拼接张量

        return torch.stack(draft_tokens), torch.stack(draft_probs)  # torch.stack沿新维度拼接张量

    @torch.no_grad()
    def target_verify(
        self,
        input_ids: torch.Tensor,
        draft_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """
        Target 模型一次前向传播验证 K 个候选 token。
        Returns:
            logits: (K+1, vocab_size) — K个验证位置 + 1个bonus位置
        """
        # 将 draft tokens 拼接到输入
        candidate_ids = torch.cat([
            input_ids,
            draft_tokens.unsqueeze(0)  # unsqueeze增加一个维度
        ], dim=-1)  # (1, seq_len + K)

        outputs = self.target_model(candidate_ids)
        # 取最后 K+1 个位置的 logits
        # 位置 -(K+1): 验证第1个draft token
        # 位置 -1: bonus token 位置
        logits = outputs.logits[0, -(len(draft_tokens)+1):, :]
        return logits

    @torch.no_grad()
    def speculative_sample(
        self,
        target_logits: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
        temperature: float = 1.0,
    ) -> tuple[torch.Tensor, int]:
        """接受-拒绝采样"""
        K = draft_tokens.shape[0]
        target_probs = F.softmax(target_logits[:K] / temperature, dim=-1)

        accepted = []
        n_accepted = 0

        for i in range(K):
            tok = draft_tokens[i]
            p = target_probs[i, tok].item()
            q = draft_probs[i, tok].item()

            if torch.rand(1).item() < min(1.0, p / (q + 1e-10)):
                accepted.append(tok.item())
                n_accepted += 1
            else:
                # 修正采样
                corrected = torch.clamp(target_probs[i] - draft_probs[i], min=0)
                corrected /= corrected.sum() + 1e-10
                replacement = torch.multinomial(corrected, 1).item()
                accepted.append(replacement)
                break

        if n_accepted == K:
            bonus_probs = F.softmax(target_logits[K] / temperature, dim=-1)
            bonus = torch.multinomial(bonus_probs, 1).item()
            accepted.append(bonus)

        return torch.tensor(accepted, device=self.device), n_accepted

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 128,
        K: int = 5,
        temperature: float = 1.0,
    ) -> str:
        """
        使用 Speculative Decoding 生成文本。

        Args:
            prompt: 输入提示
            max_new_tokens: 最大生成 token 数
            K: 投机长度(每轮猜测的 token 数)
            temperature: 采样温度
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated_tokens = 0
        total_draft = 0
        total_accepted = 0

        while generated_tokens < max_new_tokens:
            # Step 1: Draft 模型猜测 K 个 token
            draft_tokens, draft_probs = self.draft_generate(
                input_ids, K, temperature
            )

            # Step 2: Target 模型一次性验证
            target_logits = self.target_verify(input_ids, draft_tokens)

            # Step 3: 接受-拒绝采样
            accepted_tokens, n_accepted = self.speculative_sample(
                target_logits, draft_probs, draft_tokens, temperature
            )

            # Step 4: 更新输入序列
            input_ids = torch.cat([
                input_ids,
                accepted_tokens.unsqueeze(0)
            ], dim=-1)

            generated_tokens += len(accepted_tokens)
            total_draft += K
            total_accepted += n_accepted

            # 检查是否生成了 EOS
            if self.tokenizer.eos_token_id in accepted_tokens.tolist():
                break

        # 统计信息
        accept_rate = total_accepted / total_draft if total_draft > 0 else 0
        print(f"\n[Stats] 生成 {generated_tokens} tokens | "
              f"接受率 {accept_rate:.1%} | "
              f"Draft {total_draft} tokens, 接受 {total_accepted}")

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

# 使用示例
if __name__ == "__main__":
    # 注意:meta-llama 模型需要在 HuggingFace 上申请访问权限
    # 替代方案:使用公开模型如 "openlm-research/open_llama_7b"
    decoder = SpeculativeDecoder(
        target_model_name="meta-llama/Llama-3.1-8B-Instruct",  # 需要HuggingFace访问权限
        draft_model_name="meta-llama/Llama-3.2-1B-Instruct",   # 需要HuggingFace访问权限
    )

    result = decoder.generate(
        prompt="Explain quantum computing in simple terms:",
        max_new_tokens=200,
        K=5,
        temperature=0.7,
    )
    print(result)

13.3 Draft 模型选择策略

Draft 模型的质量直接决定接受率和最终加速效果。理想 Draft 模型需要满足:

  1. :推理速度远快于 Target(通常 5-20x)
  2. :与 Target 分布尽可能接近(接受率高)
  3. 对齐:共享相同的 tokenizer 和词表

13.3.1 同架构小模型

最朴素的策略:使用同系列的小规模模型作为 Draft。

Target 模型 Draft 模型 参数比 典型接受率
Llama-3.1-70B Llama-3.2-1B 70:1 60-75%
Llama-3.1-70B Llama-3.2-3B 23:1 70-82%
Qwen2.5-72B Qwen2.5-0.5B 144:1 55-70%
Qwen2.5-72B Qwen2.5-1.5B 48:1 65-78%
DeepSeek-V3 (671B) DeepSeek-V3-Lite ~50:1 65-80%

优点:实现简单,无需额外训练,词表天然一致。

缺点:小模型与大模型能力差距大时接受率低,尤其在复杂推理任务上。

13.3.2 自蒸馏 Draft 模型

从 Target 模型蒸馏出专门用于 Draft 的小模型:

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DraftDistillationTrainer:
    """
    从 Target 模型知识蒸馏训练 Draft 模型。
    使用 KL 散度让 Draft 尽可能模仿 Target 的输出分布。
    """
    def __init__(self, target_model, draft_model, temperature=2.0):
        self.target = target_model
        self.draft = draft_model
        self.temperature = temperature

    def distillation_loss(self, input_ids, attention_mask=None):
        """KL(target || draft) — 让 Draft 模拟 Target 的分布"""
        with torch.no_grad():
            target_logits = self.target(
                input_ids, attention_mask=attention_mask
            ).logits
            target_probs = F.softmax(
                target_logits / self.temperature, dim=-1
            )

        draft_logits = self.draft(
            input_ids, attention_mask=attention_mask
        ).logits
        draft_log_probs = F.log_softmax(
            draft_logits / self.temperature, dim=-1
        )

        # KL 散度
        kl_loss = F.kl_div(
            draft_log_probs, target_probs, reduction='batchmean'
        ) * (self.temperature ** 2)

        return kl_loss

优点:接受率比通用小模型高 5-15%。

缺点:需要额外的训练流程和计算资源。

13.3.3 N-gram 模型作为 Draft

一种极端轻量级策略——使用统计 N-gram 模型或基于上下文的词表匹配:

Python
from collections import defaultdict

class NGramDraft:
    """
    基于 N-gram 统计的超轻量 Draft。
    几乎零延迟,但接受率较低。
    适合重复性高的场景(代码生成、模板文本)。
    """
    def __init__(self, n: int = 3):
        self.n = n
        # 两层嵌套默认字典:外层key不存在时自动创建新的defaultdict(int),内层key不存在时自动初始化为0,用于N-gram→next_token→计数
        self.counts = defaultdict(lambda: defaultdict(int))  # defaultdict访问不存在的键时返回默认值

    def fit(self, token_sequences: list[list[int]]):
        """从训练语料构建 N-gram 统计"""
        for seq in token_sequences:
            for i in range(len(seq) - self.n):
                context = tuple(seq[i:i+self.n])
                next_token = seq[i+self.n]
                self.counts[context][next_token] += 1

    def draft(self, context_tokens: list[int], K: int) -> list[int]:
        """基于 N-gram 猜测接下来 K 个 token"""
        guesses = []
        ctx = list(context_tokens)
        for _ in range(K):
            key = tuple(ctx[-self.n:])
            candidates = self.counts.get(key, {})
            if candidates:
                # 选最高频的 token
                best = max(candidates, key=candidates.get)
                guesses.append(best)
                ctx.append(best)
            else:
                break
        return guesses

13.3.4 Prompt Lookup Decoding

核心思想:直接从 prompt 中寻找可能的续写内容,无需任何 Draft 模型。

Python
def prompt_lookup_draft(
    input_ids: list[int],
    prompt_length: int,
    K: int = 5,
    n_match: int = 3,
) -> list[int]:
    """
    Prompt Lookup Decoding: 在 prompt 中寻找匹配当前后缀的片段,
    将其后续 token 作为 draft 候选。

    特别适合:摘要、翻译、重写等与输入高度相关的任务。
    """
    if len(input_ids) < n_match:
        return []

    # 当前生成序列的最后 n_match 个 token
    suffix = input_ids[-n_match:]

    # 在 prompt 中搜索相同的 n-gram
    candidates = []
    for i in range(prompt_length - n_match):
        if input_ids[i:i+n_match] == suffix:
            # 找到匹配!取后续 K 个 token 作为猜测
            end = min(i + n_match + K, prompt_length)
            draft = input_ids[i+n_match:end]
            if draft:
                candidates.append(draft)

    # 返回最长的匹配候选
    if candidates:
        return max(candidates, key=len)[:K]
    return []

优点:零额外模型开销,无需 GPU 显存。

缺点:仅在 prompt 与输出有大量重叠时有效。

13.3.5 各策略对比

策略 延迟开销 典型接受率 额外显存 适用场景
同架构小模型 60-80% 需加载小模型 通用场景
自蒸馏 Draft 70-85% 需加载小模型 追求极致加速
N-gram 极低 30-50% 几乎为零 高重复场景(代码)
Prompt Lookup 0-70%(波动大) 摘要/翻译/重写
Medusa 多头 微增 60-80% 微增(几层MLP) 不想额外加载模型

13.4 Medusa:多头投机解码

13.4.1 Medusa 核心思想

Medusa(ICML 2024)提出了一种无需独立 Draft 模型的投机解码方案:

在 Target 模型的最后一层隐藏状态上,添加多个独立的预测头(Medusa Heads),每个头预测未来第 \(i\) 个位置的 token。

Text Only
传统 LM Head:       hidden_state → head_0 → token_{t+1}

Medusa:             hidden_state → head_0 → token_{t+1}  (原始预测)
                                 → head_1 → token_{t+2}  (提前1步)
                                 → head_2 → token_{t+3}  (提前2步)
                                 → head_3 → token_{t+4}  (提前3步)
                                 → head_4 → token_{t+5}  (提前4步)

13.4.2 Medusa Head 实现

Python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):  # 继承nn.Module定义网络层
    """
    单个 Medusa 预测头。
    结构:1层残差MLP + 共享的LM Head权重。
    """
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()  # super()调用父类方法
        self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
        self.act = nn.SiLU()
        self.ln = nn.LayerNorm(hidden_size)
        # 通常复用原始模型的 lm_head 权重
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_size)
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        x = self.act(self.fc(hidden_states))
        x = self.ln(x + hidden_states)  # 残差连接
        return self.lm_head(x)

class MedusaModel(nn.Module):
    """
    在基础模型上添加 K 个 Medusa 预测头。
    """
    def __init__(self, base_model, num_heads: int = 4, hidden_size: int = 4096, vocab_size: int = 32000):
        super().__init__()
        self.base_model = base_model
        self.medusa_heads = nn.ModuleList([
            MedusaHead(hidden_size, vocab_size)
            for _ in range(num_heads)
        ])
        # 共享原始模型的 lm_head 权重
        for head in self.medusa_heads:
            head.lm_head.weight = base_model.lm_head.weight

    @torch.no_grad()
    def forward(self, input_ids, attention_mask=None):
        outputs = self.base_model.model(
            input_ids, attention_mask=attention_mask
        )
        hidden = outputs.last_hidden_state  # (B, L, D)

        # 原始 LM Head 预测 t+1
        base_logits = self.base_model.lm_head(hidden)

        # Medusa Heads 预测 t+2, t+3, ...
        medusa_logits = [head(hidden) for head in self.medusa_heads]

        return base_logits, medusa_logits

13.4.3 Tree-Structured Attention

Medusa 的关键创新:使用树状结构组织候选 token,而非线性序列。

Text Only
线性猜测: t1 → t2 → t3 → t4 → t5
  (一个被拒绝,后续全部丢弃)

树状猜测:           t1
                  / | \
                t2a t2b t2c
               /|    |   \
            t3a t3b  t3c  t3d
  (t2a被拒绝时,t2b/t2c 分支仍可能被接受)
Python
def build_medusa_tree(
    medusa_logits: list[torch.Tensor],
    top_k: int = 3,
) -> list[list[int]]:
    """
    构建 Medusa 候选树。

    Args:
        medusa_logits: K 个 Medusa Head 的输出 logits
        top_k: 每个 Head 取 top-k 候选

    Returns:
        candidates: 所有候选路径 [[t1, t2, t3], [t1, t2', t3'], ...]
    """
    K = len(medusa_logits)
    # 每个 Head 取 top-k
    topk_per_head = []
    for logits in medusa_logits:
        probs = torch.softmax(logits[:, -1, :], dim=-1)
        topk = torch.topk(probs, top_k, dim=-1)
        topk_per_head.append(topk.indices[0].tolist())  # (top_k,)

    # 构建所有路径组合(可用剪枝策略减少路径数)
    # 简化示例:仅取联合概率最高的若干路径
    candidates = []
    for t1 in topk_per_head[0]:
        for t2 in topk_per_head[1]:
            for t3 in topk_per_head[2] if K > 2 else [None]:
                path = [t1, t2]
                if t3 is not None:
                    path.append(t3)
                candidates.append(path)

    return candidates

def tree_attention_mask(num_candidates: int, seq_len: int) -> torch.Tensor:
    """
    构建 Tree Attention 的因果掩码。
    每条候选路径只能看到自己的祖先节点和共享前缀。
    """
    # 简化示例:具体实现需根据树的拓扑结构
    mask = torch.zeros(num_candidates, seq_len + num_candidates)
    # 所有候选都能看到完整 prefix
    mask[:, :seq_len] = 1
    # 每个候选只能看到自己路径上的祖先
    for i in range(num_candidates):
        mask[i, seq_len + i] = 1
    return mask

13.4.4 Medusa-1 vs Medusa-2

特性 Medusa-1 Medusa-2
训练方式 冻结基础模型,只训练 Medusa Heads 联合微调基础模型 + Medusa Heads
训练成本 低(仅几层 MLP) 高(全模型微调)
预测质量 中等(接受率 ~60%) 高(接受率 ~75-80%)
是否保持原模型能力 可能略有变化
适用场景 快速部署、不想改基础模型 追求最大加速比

13.4.5 Medusa 训练代码

Python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def train_medusa_heads(
    model: MedusaModel,
    dataloader: DataLoader,
    num_epochs: int = 3,
    lr: float = 1e-4,
):
    """
    Medusa-1 训练:冻结基础模型,仅训练 Medusa Heads。
    训练目标:让 Head_i 预测位置 t+i+2 的 token。
    """
    # 冻结基础模型
    for param in model.base_model.parameters():
        param.requires_grad = False

    # 只优化 Medusa Heads(不含共享的 lm_head.weight)
    medusa_params = []
    for head in model.medusa_heads:
        medusa_params.extend([head.fc.weight, head.ln.weight, head.ln.bias])

    optimizer = torch.optim.AdamW(medusa_params, lr=lr)

    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            input_ids = batch["input_ids"].cuda()
            labels = batch["labels"].cuda()  # 与 input_ids 对齐的标签

            base_logits, medusa_logits = model(input_ids)

            loss = 0
            for i, m_logits in enumerate(medusa_logits):  # enumerate同时获取索引和元素
                # Head_i 预测 t+i+2 的 token
                shift = i + 2  # head_0 预测 t+2, head_1 预测 t+3, ...
                if shift < labels.shape[1]:
                    pred = m_logits[:, :-shift, :]  # 预测部分
                    target = labels[:, shift:]       # 目标部分
                    loss += F.cross_entropy(  # F.cross_entropy PyTorch函数式交叉熵损失
                        pred.reshape(-1, pred.shape[-1]),  # 重塑张量形状
                        target.reshape(-1),
                        ignore_index=-100,
                    )

            optimizer.zero_grad()  # 清零梯度
            loss.backward()  # 反向传播计算梯度
            optimizer.step()  # 更新参数
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}")

13.5 EAGLE 与 EAGLE-2

13.5.1 EAGLE:特征级别 Draft

EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency)的核心创新:

不在 token 级别做 Draft(预测 token),而在 Target 模型的隐藏特征级别做 Draft(预测特征向量),大幅提高 Draft 准确性。

Text Only
传统 Speculative Decoding:
  Draft 模型: [tokens] → Draft LM → 预测下一个 token ID
  准确率受限于 Draft 模型能力

EAGLE:
  Target 模型第1步: [tokens] → Target LM → hidden features (f1)
  EAGLE Draft:      f1 → 轻量网络 → 预测 f2 (Target的下一步特征)
                    f2 → 轻量网络 → 预测 f3
                    ...
                    f_i → Target LM Head → token_i
  关键优势: 在特征空间做外推,比在 token 空间做预测容易得多

为什么特征级别更好?

  1. Token 预测是离散的高维分类问题(32K-128K 个类别),容易出错
  2. 特征向量是连续的低维空间(4096 维),Auto-regressive 外推更准确
  3. EAGLE Draft 网络直接利用 Target 模型的中间特征,信息量更丰富
Python
import torch
import torch.nn as nn

class EAGLEDraftHead(nn.Module):
    """
    EAGLE 的特征级别 Draft 网络。
    输入: Target 模型的隐藏特征 + token embedding
    输出: 预测下一步的隐藏特征
    """
    def __init__(self, hidden_size: int = 4096):
        super().__init__()
        # 融合 token embedding 和 hidden state
        self.fc_in = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        # 轻量级 Transformer 层(通常1层)
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=32,
            dim_feedforward=hidden_size * 4,
            batch_first=True,
            norm_first=True,
        )
        self.ln = nn.LayerNorm(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        token_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Target 模型的隐藏特征 (B, L, D)
            token_embeddings: 对应 token 的 embedding (B, L, D)
        Returns:
            预测的下一步隐藏特征 (B, L, D)
        """
        # 拼接特征和 embedding
        fused = torch.cat([hidden_states, token_embeddings], dim=-1)
        x = self.fc_in(fused)
        x = self.transformer_layer(x)
        return self.ln(x)

class EAGLEDecoder:
    """EAGLE Speculative Decoding 推理流程"""

    def __init__(self, target_model, eagle_head, lm_head, embed_tokens):
        self.target = target_model
        self.eagle_head = eagle_head
        self.lm_head = lm_head        # 复用 Target 的 LM Head
        self.embed = embed_tokens      # 复用 Target 的 Embedding

    @torch.no_grad()
    def draft_with_features(
        self,
        last_hidden: torch.Tensor,  # Target 模型最后一层输出
        last_token_id: int,
        K: int = 5,
    ) -> tuple[list[int], list[torch.Tensor]]:
        """
        使用 EAGLE Head 在特征空间做 K 步外推。
        """
        draft_tokens = []
        draft_features = []

        current_hidden = last_hidden[:, -1:, :]  # (1, 1, D)
        current_embed = self.embed(
            torch.tensor([[last_token_id]], device=last_hidden.device)
        )  # (1, 1, D)

        for _ in range(K):
            # 预测下一步特征
            predicted_hidden = self.eagle_head(current_hidden, current_embed)
            draft_features.append(predicted_hidden)

            # 用 Target 的 LM Head 将特征映射到 token
            logits = self.lm_head(predicted_hidden)  # (1, 1, V)
            token = logits.argmax(dim=-1).item()
            draft_tokens.append(token)

            # 更新 Draft 状态
            current_hidden = predicted_hidden
            current_embed = self.embed(
                torch.tensor([[token]], device=last_hidden.device)
            )

        return draft_tokens, draft_features

13.5.2 EAGLE-2:Context-Aware 动态 Draft 树

EAGLE-2 在 EAGLE 基础上引入两个关键改进:

  1. 动态 Draft 树:根据当前上下文动态调整投机树的形状和深度
  2. 置信度感知采样:利用 Draft Head 的置信度信号决定是否继续展开分支
Text Only
EAGLE-1(静态树):
  所有输入使用相同的树结构(如固定 K=5, top-k=3)

EAGLE-2(动态树):
  简单token("the", "is")→ 深树, 少分支 (K=8, top-k=1)
  复杂token(推理、代码)→ 浅树, 多分支 (K=3, top-k=5)
  置信度低时 → 提前终止 Draft,避免浪费
Python
class DynamicDraftTree:
    """EAGLE-2 的动态 Draft 树策略"""

    def __init__(
        self,
        max_depth: int = 8,
        max_width: int = 5,
        confidence_threshold: float = 0.4,
        max_candidates: int = 60,
    ):
        self.max_depth = max_depth
        self.max_width = max_width
        self.conf_threshold = confidence_threshold
        self.max_candidates = max_candidates

    def build_tree(self, eagle_head, hidden, embed, lm_head):
        """
        动态构建 Draft 树。
        高置信度分支: 继续深入
        低置信度分支: 剪枝停止
        """
        tree_nodes = []  # [(token, parent_idx, depth, confidence)]
        active_leaves = [(hidden[:, -1:, :], embed[:, -1:, :], -1, 0)]

        while active_leaves and len(tree_nodes) < self.max_candidates:
            new_leaves = []
            for h, e, parent_idx, depth in active_leaves:
                if depth >= self.max_depth:
                    continue

                # EAGLE Head 预测
                pred_hidden = eagle_head(h, e)
                logits = lm_head(pred_hidden).squeeze()
                probs = torch.softmax(logits, dim=-1)

                # 取 top-k
                topk_probs, topk_ids = probs.topk(self.max_width)

                for prob, tid in zip(topk_probs, topk_ids):  # zip按位置配对
                    conf = prob.item()
                    if conf < self.conf_threshold:
                        continue  # 置信度低,剪枝

                    node_idx = len(tree_nodes)
                    tree_nodes.append((tid.item(), parent_idx, depth + 1, conf))

                    # 准备该节点的特征用于下一层
                    new_embed = embed.new_zeros(1, 1, embed.shape[-1])
                    new_leaves.append((pred_hidden, new_embed, node_idx, depth + 1))

            active_leaves = new_leaves

        return tree_nodes

13.5.3 性能对比

在多个基准上的典型加速比(相比原始自回归解码,在 A100 GPU、batch_size=1 条件下):

方法 MT-Bench 加速 HumanEval 加速 额外显存 是否需要训练
Speculative Decoding (Llama-3.2-1B → 70B) 2.0-2.5x 1.8-2.2x +1B 模型
Medusa-1 2.0-2.3x 1.7-2.0x +0.5% 是(仅 Head)
Medusa-2 2.3-2.8x 2.0-2.5x +0.5% 是(全模型)
EAGLE 2.5-3.2x 2.8-3.5x +2-5% 是(Head)
EAGLE-2 2.8-3.6x 3.0-4.0x +2-5% 是(Head)

面试要点:EAGLE 系列在代码生成(HumanEval)上加速更明显,因为代码的自回归可预测性更高(语法约束强),Draft 接受率更高。


13.6 其他推理加速方法

13.6.1 Lookahead Decoding:Jacobi 迭代并行解码

核心思想:将自回归生成看作一个不动点方程,使用 Jacobi 迭代并行求解。

\[Y^{(k+1)}_t = \arg\max P(y_t \mid y_1^{(k)}, y_2^{(k)}, \ldots, y_{t-1}^{(k)})\]
Python
import torch

def lookahead_decoding_step(model, input_ids, n_lookahead=5, n_iterations=3):
    """
    Lookahead Decoding 的核心步骤。
    维护一个 N-token 的猜测窗口,通过 Jacobi 迭代并行求解。

    优点: 不需要 Draft 模型
    缺点: 需要多次迭代,加速比通常低于 Speculative Decoding

    Args:
        model: 语言模型
        input_ids: 当前序列 (1, seq_len)
        n_lookahead: 前瞻窗口大小
        n_iterations: Jacobi 迭代次数
    """
    device = input_ids.device
    seq_len = input_ids.shape[1]

    # 初始化猜测窗口(随机或启发式)
    guess = torch.randint(0, model.config.vocab_size, (1, n_lookahead), device=device)

    for iteration in range(n_iterations):
        # 将猜测拼接到序列
        full_ids = torch.cat([input_ids, guess], dim=-1)

        # 构造因果掩码(标准自回归)
        outputs = model(full_ids)
        logits = outputs.logits  # (1, seq_len + n_lookahead, vocab_size)

        # 并行更新所有猜测位置
        new_guess = logits[0, seq_len-1:seq_len+n_lookahead-1, :].argmax(dim=-1)
        new_guess = new_guess.unsqueeze(0)

        # 检查是否收敛
        if torch.equal(new_guess, guess):
            break
        guess = new_guess

    # 验证:找到第一个不一致的位置
    final_ids = torch.cat([input_ids, guess], dim=-1)
    verify_outputs = model(final_ids)
    verify_logits = verify_outputs.logits

    n_accepted = 0
    for i in range(n_lookahead):
        predicted = verify_logits[0, seq_len + i - 1, :].argmax()
        if predicted == guess[0, i]:
            n_accepted += 1
        else:
            break

    return guess[0, :n_accepted+1], n_accepted

13.6.2 Continuous Batching(动态批处理)

传统 Static Batching 的问题:批内所有请求必须等最长序列完成后才能释放资源。

Text Only
Static Batching:
  Req1: [====生成中====]         [等待]
  Req2: [======生成中======]     [等待]
  Req3: [===========生成中===========]
  ↑ Req1完成后仍需等待 Req3

Continuous Batching (Iteration-level scheduling):
  Req1: [====完成====][Req4开始===]
  Req2: [======完成======][Req5=]
  Req3: [===========完成===========]
  ↑ Req1完成后立即插入 Req4
Python
class ContinuousBatchingScheduler:
    """
    连续批处理调度器(简化示意)。
    每次 Decode 迭代后检查是否有请求完成或新请求到达。
    """
    def __init__(self, model, max_batch_size=64):
        self.model = model
        self.max_batch = max_batch_size
        self.active_requests = []     # 正在生成的请求
        self.waiting_queue = []       # 等待队列

    def iteration_step(self):
        """执行一次 Decode 迭代"""
        # 1. 移除已完成的请求
        finished = [r for r in self.active_requests if r.is_done()]
        for r in finished:
            self.active_requests.remove(r)
            r.callback(r.generated_text)

        # 2. 从等待队列填充空位
        while (len(self.active_requests) < self.max_batch
               and self.waiting_queue):
            new_req = self.waiting_queue.pop(0)
            self.active_requests.append(new_req)

        if not self.active_requests:
            return

        # 3. 组装 batch,执行一次前向传播
        batch_ids = self._pad_batch()
        with torch.no_grad():
            outputs = self.model(batch_ids)

        # 4. 每个请求采样下一个 token
        for i, req in enumerate(self.active_requests):
            next_token = outputs.logits[i, -1, :].argmax().item()
            req.append_token(next_token)

    def _pad_batch(self):
        """将不同长度的请求 padding 到相同长度"""
        max_len = max(len(r.tokens) for r in self.active_requests)
        batch = torch.zeros(len(self.active_requests), max_len, dtype=torch.long)
        for i, r in enumerate(self.active_requests):
            batch[i, :len(r.tokens)] = torch.tensor(r.tokens)
        return batch.cuda()

13.6.3 Prefix Caching

核心思想:当多个请求共享相同前缀(如 system prompt)时,缓存该前缀的 KV-Cache,避免重复计算。

Python
class PrefixCacheManager:
    """
    前缀 KV-Cache 管理器。
    适用场景:所有请求共享相同 system prompt。
    """
    def __init__(self, model, max_cache_entries=100):
        self.model = model
        self.cache = {}  # hash(prefix_tokens) → KV-Cache
        self.max_entries = max_cache_entries

    def get_or_compute_prefix_cache(self, prefix_tokens: tuple):
        """查找或计算前缀的 KV-Cache"""
        cache_key = hash(prefix_tokens)

        if cache_key in self.cache:
            return self.cache[cache_key]  # 缓存命中!

        # 缓存未命中,计算 KV-Cache
        input_ids = torch.tensor([list(prefix_tokens)]).cuda()
        with torch.no_grad():
            outputs = self.model(input_ids, use_cache=True)

        kv_cache = outputs.past_key_values

        # LRU 淘汰
        if len(self.cache) >= self.max_entries:
            oldest = next(iter(self.cache))
            del self.cache[oldest]

        self.cache[cache_key] = kv_cache
        return kv_cache

    def generate_with_cached_prefix(self, prefix_tokens, continuation_tokens):
        """使用缓存的前缀 KV-Cache 直接从续写部分开始计算"""
        kv_cache = self.get_or_compute_prefix_cache(tuple(prefix_tokens))

        # 仅需处理续写部分的 token
        cont_ids = torch.tensor([continuation_tokens]).cuda()
        with torch.no_grad():
            outputs = self.model(
                cont_ids,
                past_key_values=kv_cache,
                use_cache=True,
            )
        return outputs

13.6.4 Chunked Prefill(分块预填充)

解决长 prompt 的 Prefill 阶段占用过多时间、阻塞 Decode 请求的问题:

Text Only
传统 Prefill(prompt=4096 tokens):
  [===========Prefill 4096 tokens===========]
  ↑ 这段时间其他 Decode 请求全部被阻塞

Chunked Prefill:
  [Prefill chunk1=1024][Decode batch][Prefill chunk2=1024][Decode batch]...
  ↑ 将 Prefill 分块,与 Decode 请求交替执行
Python
class ChunkedPrefillScheduler:
    """分块预填充调度器"""

    def __init__(self, model, chunk_size=1024):
        self.model = model
        self.chunk_size = chunk_size  # 每块处理的 token 数

    def chunked_prefill(self, prompt_tokens, decode_requests=None):
        """
        将长 prompt 的 Prefill 拆分为多个 chunk 交替执行。
        减少 Prefill 对 Decode 延迟的影响 (TTFT ↓, ITL 稳定)。
        """
        total_len = len(prompt_tokens)
        num_chunks = (total_len + self.chunk_size - 1) // self.chunk_size
        kv_cache = None

        for i in range(num_chunks):
            start = i * self.chunk_size
            end = min(start + self.chunk_size, total_len)
            chunk = prompt_tokens[start:end]

            # 处理一块 Prefill
            chunk_ids = torch.tensor([chunk]).cuda()
            with torch.no_grad():
                outputs = self.model(
                    chunk_ids,
                    past_key_values=kv_cache,
                    use_cache=True,
                )
            kv_cache = outputs.past_key_values

            # 在 chunk 之间插入 Decode 步骤
            if decode_requests:
                self._serve_decode_batch(decode_requests)

        return kv_cache

    def _serve_decode_batch(self, requests):
        """服务当前等待中的 Decode 请求"""
        # 对活跃请求执行一步 Decode
        pass

13.6.5 SGLang 的 RadixAttention

SGLang(Structured Generation Language)的 RadixAttention 是 Prefix Caching 的高级形式:

Text Only
传统 Prefix Cache: 仅匹配完全相同的前缀

RadixAttention(基数树缓存):
  请求1: [System] [User: 什么是AI?]
  请求2: [System] [User: 什么是ML?]
  请求3: [System] [User: 什么是AI?] [Assistant: AI是...]  [User: 详细说说]

  Radix Tree:
    [System] ──┬── [User: 什么是AI?] ──── [Assistant: AI是...] ── [User: 详细说说]
               └── [User: 什么是ML?]

  最长前缀匹配: 请求3 完全命中请求1的缓存,直接从"详细说说"开始计算

核心优势: - 自动管理任意前缀的 KV-Cache - 支持多轮对话的增量缓存 - 与 Continuous Batching 和 Chunked Prefill 正交,可组合使用


13.7 工程实践

13.7.1 vLLM 中的 Speculative Decoding

vLLM(v0.6+)原生支持多种 Speculative Decoding 策略:

Python
from vllm import LLM, SamplingParams

# ===== 方法1:使用独立 Draft 模型 =====
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    speculative_model="meta-llama/Llama-3.2-1B-Instruct",
    num_speculative_tokens=5,      # 每轮猜测 5 个 token
    max_model_len=4096,
    tensor_parallel_size=4,         # Target 模型 4 卡并行
    # Draft 模型自动放在单卡上
)

# ===== 方法2:Prompt Lookup Decoding(不需要 Draft 模型) =====
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    speculative_model="[ngram]",    # 使用 N-gram 作为 Draft
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,      # N-gram 匹配的最大长度
    ngram_prompt_lookup_min=2,
)

# ===== 方法3:EAGLE 投机解码 =====
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-70B",
    num_speculative_tokens=5,
    max_model_len=4096,
)

# 使用方式与普通 vLLM 完全一致
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
outputs = llm.generate(["Explain quantum computing:"], sampling_params)
print(outputs[0].outputs[0].text)

vLLM Speculative Decoding 配置要点

参数 说明 建议值
num_speculative_tokens 每轮投机长度 K 3-8(取决于接受率)
speculative_disable_by_batch_size 当 batch size 超过阈值时禁用投机 8-16
speculative_max_model_len Draft 模型最大序列长度 与 Target 一致或更小
speculative_draft_tensor_parallel_size Draft 模型张量并行度 1(通常单卡)

13.7.2 TensorRT-LLM 中的投机采样

Python
# TensorRT-LLM Speculative Decoding 配置示例
import tensorrt_llm
from tensorrt_llm import BuildConfig
from tensorrt_llm.executor import ExecutorConfig

# 构建 Target 引擎
target_build_config = BuildConfig(
    max_batch_size=64,
    max_input_len=2048,
    max_seq_len=4096,
    speculative_decoding_mode="DRAFT_TOKENS_EXTERNAL",
)

# 构建 Draft 引擎
draft_build_config = BuildConfig(
    max_batch_size=64,
    max_input_len=2048,
    max_seq_len=4096,
)

# 运行时配置
executor_config = ExecutorConfig(
    speculative_config={
        "num_draft_tokens": 5,
        "acceptance_threshold": 0.0,  # 使用标准 Speculative Sampling
    }
)

# 在 trtllm-build 命令行中:
# trtllm-build --checkpoint_dir ./target_ckpt \
#              --output_dir ./target_engine \
#              --speculative_decoding_mode draft_tokens_external \
#              --max_draft_len 5

13.7.3 不同场景下的实测加速比

以下为典型实测数据(A100 80GB,batch_size=1,FP16):

Target 模型 方法 代码生成 对话 摘要 数学推理
Llama-3.1-70B 标准解码(基线) 1.0x 1.0x 1.0x 1.0x
Llama-3.1-70B SD (Draft=1B, K=5) 2.3x 2.0x 2.5x 1.6x
Llama-3.1-70B Prompt Lookup (K=5) 1.2x 1.1x 2.8x 1.0x
Llama-3.1-70B EAGLE-2 3.2x 2.8x 2.5x 2.1x
Llama-3.1-70B Medusa-2 (4 heads) 2.5x 2.3x 2.0x 1.7x

关键观察

  1. 代码生成加速效果最好——代码语法约束强,Draft 准确率高
  2. 数学推理加速效果最差——推理链不可预测,Draft 频繁被拒绝
  3. 摘要任务用 Prompt Lookup 效果最好——输出大量复用输入内容
  4. EAGLE-2 在多数场景综合表现最佳

13.7.4 何时适合/不适合使用 Speculative Decoding

适合的场景

条件 原因
Batch size = 1 或很小 GPU 利用率低,有空间做投机计算
低延迟(TTFT↓)需求 减少 Decode 步数,加速首 token 后的生成
输出可预测性高 代码、翻译、模板化文本→高接受率
有现成的好 Draft 模型 同系列小模型或已训练的 EAGLE Head

不适合的场景

条件 原因
Batch size 很大(>16) GPU 已被充分利用,投机计算反而增加开销
显存紧张 Draft 模型或额外 Head 占用显存
创意性/开放式生成 接受率低,投机频繁失败,可能反而变慢
高温度采样(\(T > 1.5\) 概率分布分散,Draft 越难命中
Python
def should_use_speculative_decoding(
    batch_size: int,
    task_type: str,
    temperature: float,
    gpu_memory_free_gb: float,
    draft_model_size_gb: float,
) -> tuple[bool, str]:
    """
    决策函数:是否应该使用 Speculative Decoding
    """
    reasons = []

    if batch_size > 16:
        reasons.append(f"batch_size={batch_size} 过大,GPU已充分利用")
        return False, "; ".join(reasons)

    if temperature > 1.5:
        reasons.append(f"temperature={temperature} 过高,接受率将很低")
        return False, "; ".join(reasons)

    if gpu_memory_free_gb < draft_model_size_gb * 1.2:
        reasons.append("显存不足以加载 Draft 模型")
        return False, "; ".join(reasons)

    # 按任务类型预估加速比
    expected_speedup = {
        "code": 2.5, "translation": 2.2, "summary": 2.0,
        "chat": 1.8, "creative": 1.2, "math": 1.5,
    }.get(task_type, 1.5)

    if expected_speedup < 1.3:
        reasons.append(f"预估加速比 {expected_speedup}x 不值得额外复杂度")
        return False, "; ".join(reasons)

    return True, f"建议使用,预估加速 {expected_speedup}x"

13.8 面试高频题

题目 1:Speculative Decoding 为什么能加速,但不改变输出分布?

加速原因: - 自回归 Decode 是 Memory-bound 的,处理 1 个 token 和处理 K 个 token 的 wall-clock time 几乎相同 - Speculative Decoding 用快速 Draft 模型猜 K 个 token,然后 Target 模型一次前向传播并行验证 K 个 token - 相当于用"1 次 Draft ×K + 1 次 Target"替代"K 次 Target"

分布不变的保证: - 使用接受-拒绝采样:接受概率为 \(\min(1, p(x)/q(x))\) - 拒绝时从修正分布 \(\text{norm}(\max(0, p-q))\) 采样 - 可以数学证明最终每个 token 的边际分布恰好等于 Target 分布 \(p(x)\) - 这不是近似,而是精确保证


题目 2:Speculative Decoding 的加速效果受哪些因素影响?

影响加速比的三个关键因素:

  1. 接受率 \(\alpha\)(最重要)
  2. 取决于 Draft 模型与 Target 模型的分布差异
  3. \(\alpha\) 越高加速越大;\(\alpha < 0.5\) 时可能不如不用
  4. 受任务类型影响:代码/翻译 > 对话 > 创意写作

  5. 速度比 \(c = t_{\text{draft}} / t_{\text{target}}\)

  6. Draft 越快越好,通常 \(c < 0.1\) 才有意义
  7. 包括:模型加载延迟、KV-Cache 管理开销

  8. 投机长度 \(K\)

  9. \(K\) 过小→每轮产出少,频繁切换 Draft/Target 开销大
  10. \(K\) 过大→后面的猜测越来越不准,接受率指数下降
  11. 最优 \(K\) 通常 3-8,取决于 \(\alpha\)\(c\)

理论加速比公式为:\(\text{Speedup} = \frac{1 - \alpha^{K+1}}{(1-\alpha)(Kc+1)}\)


题目 3:比较 Medusa 和 EAGLE 的思路差异

维度 Medusa EAGLE
Draft 方式 在 token 级别:多个独立 Head 各自预测未来第 i 个 token 在特征级别:预测 Target 模型的隐藏特征,再转 token
独立性假设 各 Head 独立预测(不考虑 token 间依赖) 逐步外推特征(保留自回归依赖)
验证结构 树结构组合所有 Head 的 top-k 候选 动态树(EAGLE-2),根据置信度调整
额外参数 极少(每个 Head 仅一层 MLP) 稍多(一层 Transformer + 投影层)
训练数据 需要训练数据对齐 Head 需要训练数据 + Target 模型的中间特征
加速效果 中等(受独立性假设限制) 更高(特征外推准确率高)

核心差异:Medusa 的各 Head 独立预测违反了语言的自回归本质(第 3 个 token 的预测应该依赖第 2 个),EAGLE 通过在特征空间做序列外推保留了这种依赖,因此 Draft 质量更高。


题目 4:Speculative Decoding 在大 batch size 下为什么效果不好?

  1. GPU 利用率已经高:大 batch 时 Decode 的矩阵乘变成了矩阵-矩阵乘(而非矩阵-向量乘),从 Memory-bound 变为 Compute-bound,GPU 计算单元已被充分利用

  2. 验证开销增大:Target 模型需要同时处理 batch 中每个请求的 K 个候选 token,显存和计算开销线性增长

  3. Draft 开销无法忽略:batch=1 时 Draft 成本可忽略,但 batch=32 时 Draft 模型也需要处理 32 个请求,成本不可忽略

  4. 边际收益递减:Continuous Batching 本身已提供了很好的吞吐,Speculative Decoding 的额外收益被稀释

经验法则:batch_size > 8-16 时,关闭 Speculative Decoding 通常更优。vLLM 提供 speculative_disable_by_batch_size 参数来自动处理。


题目 5:解释 Prompt Lookup Decoding 的原理和适用场景

原理: - 不使用任何 Draft 模型 - 直接在输入 prompt 中寻找与当前生成后缀匹配的 N-gram - 将匹配位置之后的 token 作为 Draft 候选交给 Target 验证 - 时间复杂度 \(O(L \cdot n)\)\(L\) 为 prompt 长度,\(n\) 为匹配长度),几乎零延迟

适用场景(接受率高): - 摘要任务:输出大量复用 prompt 中的原文 - 翻译:源语言和目标语言有共享 token(数字、专有名词) - 代码重构:输出与输入代码高度相似 - 多轮对话:重复引用上文内容

不适用场景: - 创意写作(输出与输入无关) - 知识问答(答案不在 prompt 中) - 数学推理(推导过程与 prompt 无重叠)


题目 6:Continuous Batching 和 Speculative Decoding 能否结合使用?

可以结合,但有工程挑战

  1. 兼容性:两者正交——Continuous Batching 管理请求调度,Speculative Decoding 管理单个请求的加速。理论上完全兼容。

  2. 工程挑战

  3. 批内不同请求的 Draft 长度可能不同(因为接受率不同),需要 padding 对齐
  4. Draft 阶段的 batching 和 Target 验证的 batching 需要分别管理
  5. KV-Cache 管理更复杂:被拒绝的 token 需要回退 Cache

  6. 动态策略

  7. batch 小时开启 Speculative Decoding
  8. batch 大时自动关闭
  9. vLLM 的 speculative_disable_by_batch_size 就是这个策略

  10. 实践结论:在请求稀疏时(batch 小),两者结合效果最好——Continuous Batching 保证吞吐,Speculative Decoding 降低延迟。


题目 7:如何选择最优投机长度 K?

理论分析

加速比 \(S(K) = \frac{1 - \alpha^{K+1}}{(1-\alpha)(Kc+1)}\)

\(K\) 求导并令其为零,可得理论最优 \(K^*\)

\[K^* \approx \frac{1}{c} \cdot \frac{\alpha}{1-\alpha} \quad (\text{近似})\]

实践指导

接受率 \(\alpha\) 速度比 \(c\) 建议 \(K\) 预估加速
0.9 0.05 6-8 3.5-4.0x
0.8 0.05 4-6 2.5-3.0x
0.7 0.10 3-4 1.7-2.0x
0.6 0.10 2-3 1.3-1.5x

动态调整策略: - 监控运行时接受率 - 接受率高时增大 \(K\)(多猜一些),低时减小 \(K\)(少猜避免浪费) - EAGLE-2 的动态树本质上就是在自动做这件事

Python
class AdaptiveSpeculativeLength:
    """运行时自适应调整投机长度"""
    def __init__(self, K_min=2, K_max=10, target_accept_rate=0.75):
        self.K = 5  # 初始值
        self.K_min = K_min
        self.K_max = K_max
        self.target_rate = target_accept_rate
        self.accept_history = []

    def update(self, n_accepted, K_used):
        rate = n_accepted / K_used
        self.accept_history.append(rate)

        # 用最近 10 轮的平均接受率调整
        if len(self.accept_history) >= 10:
            avg_rate = sum(self.accept_history[-10:]) / 10
            if avg_rate > self.target_rate + 0.1:
                self.K = min(self.K + 1, self.K_max)
            elif avg_rate < self.target_rate - 0.1:
                self.K = max(self.K - 1, self.K_min)

        return self.K

题目 8:未来推理加速的趋势是什么?

  1. 硬件-算法协同
  2. Speculative Decoding 与专用推理芯片(Groq、Cerebras)结合
  3. 硬件原生支持 Draft-Verify 流水线
  4. 充分利用 HBM3e 带宽(>4TB/s)缩小 Memory-bound 瓶颈

  5. 自适应投机策略

  6. 根据输入动态选择 Draft 方法(N-gram vs 小模型 vs EAGLE)
  7. 多级 Draft 级联:先 N-gram,失败再用小模型
  8. 在线学习接受率,实时调整参数

  9. 与 MoE 结合

  10. MoE 模型(如 DeepSeek-V3)的非活跃专家天然可作为 Draft
  11. 路由预测实现推测执行

  12. 与量化互补

  13. Draft 模型使用 INT4/INT2 极致量化
  14. Target 模型使用 FP8
  15. 量化带来的额外速度比 \(c\) 进一步降低

  16. 端侧推理

  17. 手机/边缘设备上内存带宽更受限,Speculative Decoding 价值更大
  18. 超轻量 Draft(N-gram + 小型 MLP)适合端侧

  19. 多模态投机解码

  20. 将 Speculative Decoding 扩展到视觉-语言模型
  21. 图像理解任务中文本部分仍可投机加速

本章总结

Text Only
自回归解码瓶颈 ──→ Memory-bound, GPU利用率低
Speculative Decoding ──→ Draft猜测 + Target验证
   │       │       │
   ▼       ▼       ▼
独立Draft  Medusa   EAGLE/EAGLE-2
(小模型)  (多头)   (特征级别)
   │       │       │
   ▼       ▼       ▼
工程实践: vLLM / TensorRT-LLM / SGLang
互补方法: Continuous Batching + Prefix Caching + Chunked Prefill

核心知识点: 1. Decode 阶段是 Memory-bound → 处理 K 个 token 和 1 个 token 耗时相近 2. Speculative Decoding 通过接受-拒绝采样保证输出分布精确不变 3. EAGLE 在特征空间做 Draft 优于 token 空间(Medusa) 4. 大 batch 时 Speculative Decoding 效果递减,需要动态开关 5. 与 Continuous Batching、Prefix Caching 等方法正交互补

延伸阅读:关于 Speculative Decoding 在完整推理服务中的集成(包括 API 层、负载均衡、多模型部署),请参阅 LLM应用教程第12章-推理优化


参考文献: 1. Leviathan et al., "Fast Inference from Transformers via Speculative Decoding", ICML 2023 2. Chen et al., "Accelerating Large Language Model Decoding with Speculative Sampling", 2023 3. Cai et al., "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads", ICML 2024 4. Li et al., "EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty", ICML 2024 5. Li et al., "EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees", 2024 6. Fu et al., "Lookahead Decoding: Breaking the Sequential Dependency of LLM Decoding", 2024 7. Saxena, "Prompt Lookup Decoding", 2023 8. Yu et al., "vLLM: Efficient Memory Management for Large Language Model Serving", SOSP 2023 9. Zheng et al., "SGLang: Efficient Execution of Structured Language Model Programs", 2024