跳转至

08 - 长上下文技术

⚠️ 时效性说明:本章涉及前沿技术,可能随研究进展快速变化;请以论文原文和官方实现为准。

学习目标:深入理解大模型长上下文处理的核心技术,包括Ring Attention、KV Cache优化、FlashAttention演进、位置编码外推等。


1. 长上下文挑战

1.1 核心问题概述

长上下文处理是大模型面临的关键技术挑战之一。随着应用场景对更长上下文的需求(如长文档理解、多轮对话、代码分析等),传统Transformer架构面临三大瓶颈:

Text Only
┌─────────────────────────────────────────────────────────────┐
│                    长上下文三大挑战│
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 计算复杂度O(n²)                    │
│     - 注意力矩阵: n×n                                        │
│     - 100K tokens → 10^10次计算                              │
│                                                             │
│  2. 显存占用爆炸                                             │
│     - KV Cache线性增长                                       │
│     - 100K context → 数十GB显存                              │
│                                                             │
│  3. 位置编码外推                                             │
│     - 训练长度有限(如4K/8K)                                │
│     - 推理时需要处理更长序列                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

1.2 注意力复杂度O(n²)问题

标准自注意力计算

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

复杂度分析

Text Only
输入序列长度: n
隐藏维度: d

计算步骤:
1. Q = XW_Q:  O(n × d × d) = O(nd²)
2. K = XW_K:  O(n × d × d) = O(nd²)
3. V = XW_V:  O(n × d × d) = O(nd²)
4. QK^T:      O(n × d × n) = O(n²d)    ← 瓶颈!
5. Softmax:   O(n²)
6. Attention × V: O(n² × d) = O(n²d)    ← 瓶颈!

总复杂度: O(n²d)
空间复杂度: O(n²)(存储注意力矩阵)

实际影响

Python
# 不同序列长度的注意力矩阵大小
def attention_memory_mb(seq_len, dtype_bytes=2):  # FP16
    """计算注意力矩阵显存占用"""
    elements = seq_len * seq_len
    mb = elements * dtype_bytes / (1024 * 1024)
    return mb

# 示例
lengths = [1024, 4096, 16384, 32768, 100000]
for n in lengths:
    mem = attention_memory_mb(n)
    print(f"序列长度 {n:>6}: 注意力矩阵 {mem:>10.2f} MB")

# 输出:
# 序列长度1024: 注意力矩阵2.00 MB
# 序列长度4096: 注意力矩阵32.00 MB
# 序列长度  16384: 注意力矩阵512.00 MB
# 序列长度  32768: 注意力矩阵2048.00 MB
# 序列长度 100000: 注意力矩阵19073.49 MB (~19GB!)

1.3 显存占用随序列长度增长

KV Cache显存计算

Text Only
单层KV Cache大小:
  K: batch_size × num_heads × seq_len × head_dim × dtype_size
  V: batch_size × num_heads × seq_len × head_dim × dtype_size

总计 (L层):
  2 × batch_size × num_heads × seq_len × head_dim × L × dtype_size

简化公式:
  KV_Cache = 2 × b × h × s × d × L × bytes

实际案例(LLaMA-65B)

Python
def kv_cache_memory(batch_size, num_layers, num_heads, head_dim, seq_len, dtype_bytes=2):
    """计算KV Cache显存占用"""
    return 2 * batch_size * num_layers * num_heads * head_dim * seq_len * dtype_bytes

# LLaMA-65B 参数
config = {
    'batch_size': 1,
    'num_layers': 80,
    'num_heads': 64,
    'head_dim': 128,
    'seq_len': 100000  # 100K context
}

mem_bytes = kv_cache_memory(**config)
mem_gb = mem_bytes / (1024**3)
print(f"100K context KV Cache: {mem_gb:.2f} GB")
# 输出: 100K context KV Cache: 122.07 GB

1.4 位置编码外推问题

问题本质

Text Only
训练时: 最大序列长度 L_train (如 2048/4096)
推理时: 需要处理 L_infer > L_train 的序列

位置编码困境:
- 训练时未见过的位置 → 位置编码无意义或效果差
- 外推导致注意力分布异常 → 生成质量下降

不同位置编码的外推能力

位置编码类型 外推能力 问题
绝对位置编码 超出训练范围的位置无定义
RoPE 中等 远距离位置衰减过快
ALiBi 线性衰减,可外推
相对位置编码 较好 但计算开销大

2. Ring Attention:分布式长上下文

2.1 核心原理

Ring Attention是一种分布式注意力计算技术,通过分块计算环状通信实现任意长度序列的处理。

核心思想

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    Ring Attention原理                         │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  传统方法:单GPU需要存储完整注意力矩阵           │
│                                                              │
│  Ring Attention:                                             │
│  1. 将序列分块: Q, K, V → [Block_1, Block_2, ..., Block_n]   │
│  2. 每个GPU持有部分K、V块                                    │
│  3. 环状传递K、V块,计算局部注意力                            │
│  4. 累积得到完整注意力结果                                    │
│                                                              │
│  GPU0: [Q_0] ← K_1 ← K_2 ← K_3 ← K_0 (环形)                  │
│  GPU1: [Q_1] ← K_2 ← K_3 ← K_0 ← K_1                         │
│  GPU2: [Q_2] ← K_3 ← K_0 ← K_1 ← K_2                         │
│  GPU3: [Q_3] ← K_0 ← K_1 ← K_2 ← K_3                         │
│                                                              │
└──────────────────────────────────────────────────────────────┘

2.2 数学形式化

分块注意力计算

将序列分成\(P\)个块,每个GPU处理一个块:

\[Q = [Q_1, Q_2, ..., Q_P], \quad K = [K_1, K_2, ..., K_P], \quad V = [V_1, V_2, ..., V_P]\]

\(i\)个GPU的注意力计算

\[O_i = \sum_{j=1}^{P} \text{FlashAttention}(Q_i, K_j, V_j)\]

环状通信模式

Text Only
时间步 t=0: GPU_i 持有 K_i, V_i
时间步 t=1: GPU_i 接收 K_{i-1}, V_{i-1} (从左边GPU)
时间步 t=2: GPU_i 接收 K_{i-2}, V_{i-2}
...
时间步 t=P-1: 完成所有块的注意力计算

复杂度分析

Text Only
计算复杂度: O(n²/P)  (每个GPU)
通信复杂度: O(n × d × P)  (环状传递)
总复杂度: O(n²/P + n×d×P)

当 P 个GPU时:
- 显存: O(n/P × d)  (每个GPU只存部分)
- 支持: 任意长度 (理论上)

2.3 代码实现

简化的Ring Attention实现

Python
import torch
import torch.distributed as dist

class RingAttention:
    """Ring Attention 简化实现"""

    def __init__(self, rank, world_size, head_dim, block_size=1024):
        self.rank = rank
        self.world_size = world_size
        self.head_dim = head_dim
        self.block_size = block_size

        # 环状通信的邻居
        self.send_rank = (rank + 1) % world_size
        self.recv_rank = (rank - 1) % world_size

    def compute_attention_block(self, Q_block, K_block, V_block):
        """计算单个块的注意力"""
        # Scaled Dot-Product Attention
        scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V_block)
        return output

    def ring_attention(self, Q, K, V):
        """
        Ring Attention 主函数

        Args:
            Q: [batch, heads, seq_len/P, head_dim] - 本GPU的Q块
            K: [batch, heads, seq_len/P, head_dim] - 本GPU的K块
            V: [batch, heads, seq_len/P, head_dim] - 本GPU的V块
        Returns:
            output: [batch, heads, seq_len/P, head_dim]
        """
        # 初始化输出和累积变量
        output = torch.zeros_like(Q)
        max_score = torch.full(Q.shape[:3], float('-inf'), device=Q.device)
        sum_exp = torch.zeros(Q.shape[:3], device=Q.device)

        # 当前持有的K, V块
        K_current = K.clone()
        V_current = V.clone()

        for step in range(self.world_size):
            # 计算当前块的注意力
            block_output, block_max, block_sum = self._flash_attention_block(
                Q, K_current, V_current
            )

            # 累积结果(使用数值稳定的softmax)
            new_max = torch.maximum(max_score, block_max)
            scale_old = torch.exp(max_score - new_max)
            scale_new = torch.exp(block_max - new_max)

            output = output * scale_old.unsqueeze(-1) + block_output * scale_new.unsqueeze(-1)
            sum_exp = sum_exp * scale_old + block_sum * scale_new
            max_score = new_max

            # 环状传递K, V块
            if step < self.world_size - 1:
                K_current, V_current = self._ring_send_recv(K_current, V_current)

        # 归一化
        output = output / sum_exp.unsqueeze(-1)

        return output

    def _flash_attention_block(self, Q, K, V):
        """Flash Attention风格的分块计算"""
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # 数值稳定的softmax
        max_score = scores.max(dim=-1, keepdim=True).values
        exp_scores = torch.exp(scores - max_score)
        sum_exp = exp_scores.sum(dim=-1, keepdim=True)

        output = torch.matmul(exp_scores, V)

        return output, max_score.squeeze(-1), sum_exp.squeeze(-1)

    def _ring_send_recv(self, K, V):
        """环状发送和接收K, V块"""
        K_recv = torch.empty_like(K)
        V_recv = torch.empty_like(V)

        # 异步发送和接收
        send_K = dist.isend(K.contiguous(), self.send_rank)
        send_V = dist.isend(V.contiguous(), self.send_rank)
        recv_K = dist.irecv(K_recv, self.recv_rank)
        recv_V = dist.irecv(V_recv, self.recv_rank)

        # 等待完成
        send_K.wait()
        send_V.wait()
        recv_K.wait()
        recv_V.wait()

        return K_recv, V_recv

与FlashAttention结合

Python
def ring_flash_attention(Q, K, V, rank, world_size):
    """
    Ring Attention + Flash Attention 组合

    结合两者优势:
    - Ring: 分布式处理超长序列
    - Flash: 单GPU内高效计算
    """
    from flash_attn import flash_attn_func

    output = torch.zeros_like(Q)

    K_current = K.clone()
    V_current = V.clone()

    for step in range(world_size):
        # 使用Flash Attention计算当前块
        block_output = flash_attn_func(Q, K_current, V_current, causal=True)

        # 累积结果...
        output += block_output  # 简化版,实际需要正确的softmax累积

        # 环状传递
        if step < world_size - 1:
            K_current, V_current = ring_send_recv(K_current, V_current, rank, world_size)

    return output

2.4 应用案例

支持的超长上下文模型

模型 上下文长度 使用技术
Claude 3 200K Ring Attention
Gemini 1.5 Pro 1M-2M Ring Attention + 优化
GPT-4-Turbo 128K 类似技术
LLaMA 3.1 128K Ring Attention

3. KV Cache优化技术

3.1 PagedAttention (vLLM)

核心思想:将KV Cache按页管理,类似操作系统的虚拟内存。

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    PagedAttention原理                         │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  传统KV Cache:                                               │
│  ┌────────────────────────────────────────┐                  │
│  │  连续预分配,大量浪费│                  │
│  └────────────────────────────────────────┘                  │
│                                                              │
│  PagedAttention:                                             │
│  ┌────┬────┬────┬────┬────┬────┬────┬────┐                  │
│  │Page│Page│Page│Page│Page│Page│Page│Page│                  │
│  │ 0  │ 1  │ 2  │ 3  │ 4  │ 5  │ 6  │ 7  │                  │
│  └────┴────┴────┴────┴────┴────┴────┴────┘                  │
│                                                              │
│  Sequence A: Page 0 → Page 2 → Page 5 (逻辑连续)             │
│  Sequence B: Page 1 → Page 3 → Page 4 (逻辑连续)             │
│  Sequence C: Page 6 → Page 7 (共享beam search)               │
│                                                              │
└──────────────────────────────────────────────────────────────┘

内存效率对比

Python
# 传统预分配 vs PagedAttention
def memory_comparison():
    # 假设: batch_size=32, max_seq_len=2048, 实际平均长度=512

    # 传统方法: 按最大长度预分配
    traditional_memory = 32 * 2048 * kv_per_token

    # PagedAttention: 按实际需要分配
    paged_memory = 32 * 512 * kv_per_token

    # 加上少量页表开销 (约5%)
    paged_memory *= 1.05

    savings = (traditional_memory - paged_memory) / traditional_memory
    print(f"内存节省: {savings:.1%}")  # 约75%节省

vLLM核心实现

Python
class BlockManager:
    """PagedAttention 块管理器"""

    def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int):
        self.num_blocks = num_blocks
        self.block_size = block_size  # 每页token数
        self.num_heads = num_heads
        self.head_dim = head_dim

        # 物理块池
        self.k_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
        self.v_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)

        # 空闲块列表
        self.free_blocks = list(range(num_blocks))

        # 序列到块的映射
        self.seq_to_blocks = {}  # seq_id -> [block_ids]
        self.seq_lengths = {}    # seq_id -> current_length

    def allocate(self, seq_id: int) -> int:
        """为序列分配新块"""
        if not self.free_blocks:
            raise RuntimeError("Out of memory: no free blocks")

        block_id = self.free_blocks.pop(0)
        if seq_id not in self.seq_to_blocks:
            self.seq_to_blocks[seq_id] = []
            self.seq_lengths[seq_id] = 0

        self.seq_to_blocks[seq_id].append(block_id)
        return block_id

    def append_token(self, seq_id: int, k: torch.Tensor, v: torch.Tensor):
        """追加一个token的KV"""
        if seq_id not in self.seq_to_blocks:
            self.allocate(seq_id)

        current_len = self.seq_lengths[seq_id]
        block_idx = current_len // self.block_size
        token_idx = current_len % self.block_size

        # 如果当前块满了,分配新块
        if token_idx == 0 and current_len > 0:
            self.allocate(seq_id)
            block_idx += 1

        block_id = self.seq_to_blocks[seq_id][block_idx]

        # 写入KV Cache
        self.k_cache[block_id, token_idx] = k
        self.v_cache[block_id, token_idx] = v

        self.seq_lengths[seq_id] += 1

    def get_kv_cache(self, seq_id: int):
        """获取序列的完整KV Cache"""
        block_ids = self.seq_to_blocks[seq_id]
        length = self.seq_lengths[seq_id]

        # 收集所有块的KV
        k_blocks = self.k_cache[block_ids]  # [num_blocks, block_size, heads, dim]
        v_blocks = self.v_cache[block_ids]

        # 展平并截取实际长度
        k = k_blocks.view(-1, self.num_heads, self.head_dim)[:length]
        v = v_blocks.view(-1, self.num_heads, self.head_dim)[:length]

        return k, v

    def fork(self, parent_seq: int, child_seq: int):
        """复制序列(用于beam search),共享块"""
        parent_blocks = self.seq_to_blocks[parent_seq]
        self.seq_to_blocks[child_seq] = parent_blocks.copy()  # Copy-on-write
        self.seq_lengths[child_seq] = self.seq_lengths[parent_seq]

    def free(self, seq_id: int):
        """释放序列的所有块"""
        if seq_id in self.seq_to_blocks:
            self.free_blocks.extend(self.seq_to_blocks[seq_id])
            del self.seq_to_blocks[seq_id]
            del self.seq_lengths[seq_id]

3.2 Multi-Query Attention (MQA) / Grouped-Query Attention (GQA)

架构对比

Text Only
┌──────────────────────────────────────────────────────────────┐
│              MHA vs MQA vs GQA对比                            │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Multi-Head Attention (MHA):                                 │
│  每个头有独立的K, V                                           │
│  KV Cache: [batch, num_heads, seq_len, head_dim]             │
│  内存: O(h × s × d)                                          │
│                                                              │
│  Multi-Query Attention (MQA):                                │
│  所有头共享一组K, V                                           │
│  KV Cache: [batch, 1, seq_len, head_dim]                     │
│  内存: O(s × d),节省h倍!                                    │
│                                                              │
│  Grouped-Query Attention (GQA):                              │
│  头分成g组,每组共享K, V                                      │
│  KV Cache: [batch, g, seq_len, head_dim]                     │
│  内存: O(g × s × d),平衡质量和效率                           │
│                                                              │
└──────────────────────────────────────────────────────────────┘

GQA实现

Python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention 实现"""

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int):
        """
        Args:
            hidden_size: 隐藏维度
            num_heads: 查询头数 (Q)
            num_kv_heads: KV头数 (通常 < num_heads)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        self.head_dim = hidden_size // num_heads
        self.num_heads_per_group = num_heads // num_kv_heads

        # Q有num_heads个头
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False)
        # K和V只有num_kv_heads个头
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)

        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)

    def forward(self, hidden_states, attention_mask=None, kv_cache=None):
        batch_size, seq_len, _ = hidden_states.shape

        # 计算Q, K, V
        Q = self.q_proj(hidden_states)
        K = self.k_proj(hidden_states)
        V = self.v_proj(hidden_states)

        # 重塑形状
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # 处理KV Cache
        if kv_cache is not None:
            K = torch.cat([kv_cache['k'], K], dim=2)
            V = torch.cat([kv_cache['v'], V], dim=2)

        # 更新缓存
        new_kv_cache = {'k': K, 'v': V}

        # 扩展K, V以匹配Q的头数
        # [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim]
        K = self._repeat_kv(K)
        V = self._repeat_kv(V)

        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            scores = scores + attention_mask

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # 重塑并投影
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, -1)
        output = self.o_proj(output)

        return output, new_kv_cache

    def _repeat_kv(self, x):
        """将KV头复制以匹配Q头数"""
        if self.num_heads_per_group == 1:
            return x

        batch, num_kv_heads, seq_len, head_dim = x.shape
        x = x[:, :, None, :, :].expand(batch, num_kv_heads, self.num_heads_per_group, seq_len, head_dim)
        return x.reshape(batch, num_kv_heads * self.num_heads_per_group, seq_len, head_dim)

KV Cache节省计算

Python
def kv_cache_savings():
    """计算GQA的KV Cache节省"""
    # 以LLaMA-2-70B为例
    num_heads = 64
    num_kv_heads = 8  # GQA with 8 groups
    seq_len = 100000  # 100K context
    head_dim = 128
    dtype_bytes = 2   # FP16

    # MHA
    mha_kv = 2 * num_heads * seq_len * head_dim * dtype_bytes
    mha_gb = mha_kv / (1024**3)

    # GQA
    gqa_kv = 2 * num_kv_heads * seq_len * head_dim * dtype_bytes
    gqa_gb = gqa_kv / (1024**3)

    print(f"MHA KV Cache: {mha_gb:.2f} GB")
    print(f"GQA KV Cache: {gqa_gb:.2f} GB")
    print(f"节省: {(1 - gqa_gb/mha_gb):.1%}")

    # 输出:
    # MHA KV Cache: 3.05 GB
    # GQA KV Cache: 0.38 GB
    # 节省: 87.5%

3.3 KV Cache量化

量化策略

Python
class KVCacheQuantizer:
    """KV Cache 量化器"""

    def __init__(self, bits: int = 8, group_size: int = 128):
        """
        Args:
            bits: 量化位数 (4 或 8)
            group_size: 分组量化的大小
        """
        self.bits = bits
        self.group_size = group_size
        self.scale = None
        self.zero_point = None

    def quantize(self, x: torch.Tensor) -> tuple:
        """
        将FP16 KV Cache量化到低精度

        Args:
            x: [batch, heads, seq_len, head_dim]
        Returns:
            quantized: 量化后的数据
            scale: 缩放因子
            zero_point: 零点
        """
        # 分组量化
        x_grouped = x.reshape(-1, self.group_size)

        # 计算每组的min/max
        x_min = x_grouped.min(dim=-1, keepdim=True).values
        x_max = x_grouped.max(dim=-1, keepdim=True).values

        # 计算scale和zero_point
        qmin = 0
        qmax = 2 ** self.bits - 1

        scale = (x_max - x_min) / (qmax - qmin)
        scale = scale.clamp(min=1e-8)
        zero_point = qmin - x_min / scale
        zero_point = zero_point.round().clamp(qmin, qmax).to(torch.int32)

        # 量化
        quantized = ((x_grouped / scale) + zero_point).round().clamp(qmin, qmax)
        quantized = quantized.to(torch.uint8 if self.bits == 8 else torch.int8)

        return quantized, scale, zero_point

    def dequantize(self, quantized: torch.Tensor, scale: torch.Tensor, 
                   zero_point: torch.Tensor) -> torch.Tensor:
        """反量化"""
        return (quantized.float() - zero_point) * scale

    def memory_savings(self):
        """计算内存节省"""
        original_bits = 16  # FP16
        savings = 1 - self.bits / original_bits
        return savings

# 使用示例
quantizer = KVCacheQuantizer(bits=4)  # 4-bit量化
print(f"内存节省: {quantizer.memory_savings():.1%}")  # 75%

4. FlashAttention演进

4.1 FlashAttention-1

核心创新:分块计算 + 在线Softmax

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    FlashAttention-1原理                       │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  标准Attention:                                              │
│  1. 计算 S = QK^T           [n×n矩阵,存HBM]                 │
│  2. 计算 P = softmax(S)     [n×n矩阵,存HBM]                 │
│  3. 计算 O = PV             [n×n矩阵操作]                    │
│                                                              │
│  HBM访问: O(n²) - 瓶颈!                                     │
│                                                              │
│  FlashAttention:                                             │
│  1. 分块: Q, K, V分成小块 [B_r, B_c]                         │
│  2. 在SRAM中计算:                                            │
│     - 加载Q块、K块到SRAM                                     │
│     - 计算局部S_ij = Q_i @ K_j^T                             │
│     - 在线softmax更新                                        │
│     - 累积O_i                                                │
│  3. 只写回最终O                                              │
│                                                              │
│  HBM访问: O(n) - 线性!                                      │
│                                                              │
└──────────────────────────────────────────────────────────────┘

在线Softmax算法

Python
def online_softmax(a: torch.Tensor, new_block: torch.Tensor):
    """
    在线Softmax: 增量更新softmax结果

    已知: 之前块的max值m_old, sum(exp)值d_old
    新块: 需要融合new_block

    算法:
    1. m_new = max(m_old, max(new_block))
    2. d_new = d_old * exp(m_old - m_new) + sum(exp(new_block - m_new))
    3. O_new = O_old * (d_old * exp(m_old - m_new) / d_new) + 
               new_output * (sum(exp(new_block - m_new)) / d_new)
    """
    m_old, d_old, O_old = a
    m_new_block = new_block.max(dim=-1, keepdim=True).values

    # 更新全局max
    m_new = torch.maximum(m_old, m_new_block)

    # 更新归一化因子
    d_new_block = torch.exp(new_block - m_new).sum(dim=-1, keepdim=True)
    d_new = d_old * torch.exp(m_old - m_new) + d_new_block

    # 更新输出
    scale_old = d_old * torch.exp(m_old - m_new) / d_new
    scale_new = d_new_block / d_new
    O_new = O_old * scale_old + new_block * scale_new

    return m_new, d_new, O_new

性能提升

指标 标准Attention FlashAttention-1
HBM访问 O(n²) O(n)
显存占用 O(n²) O(n)
速度 1x 2-4x
支持序列长度 ~2K ~16K (同显存)

4.2 FlashAttention-2

主要改进

Text Only
FlashAttention-2 相比 v1 的优化:

1. 减少非矩阵乘法操作
   - 优化softmax计算
   - 减少原子操作

2. 并行化改进
   - 序列长度维度并行
   - 更好的GPU利用率

3. 支持更多数据类型
   - FP16, BF16
   - FP8 (Hopper架构)

并行策略

Python
# FlashAttention-2 的并行策略
"""
FlashAttention-1: 
  - 在batch和head维度并行
  - 长序列时GPU利用率低

FlashAttention-2:
  - 额外在seq_len维度并行
  - 即使batch=1也能充分利用GPU

示例: seq_len=16K, batch=1
  - v1: 只有32个线程块 (32 heads)
  - v2: 128个线程块 (额外在seq_len上切分)
"""

性能对比

配置 FlashAttention-1 FlashAttention-2 提升
A100, seq=2K 185 TFLOPS 215 TFLOPS 16%
A100, seq=8K 150 TFLOPS 190 TFLOPS 27%
H100, seq=16K 210 TFLOPS 280 TFLOPS 33%

4.3 FlashAttention-3

针对Hopper架构的优化

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    FlashAttention-3特性                       │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  1. 异步计算                                                 │
│     - Tensor Core: GEMM操作                                  │
│     - CUDA Core: Softmax操作                                 │
│     - 两者并行执行!                                          │
│                                                              │
│  2. FP8支持                                                  │
│     - H100原生FP8 Tensor Core                                │
│     - 比FP16快2倍                                            │
│                                                              │
│  3. 低精度优化                                               │
│     - 块量化技术                                             │
│     - 动态缩放因子                                           │
│                                                              │
└──────────────────────────────────────────────────────────────┘

异步流水线

Python
# FlashAttention-3 异步执行伪代码
def flash_attention_3_async(Q, K, V):
    """
    异步流水线执行:
    - GEMM (Tensor Core) 和 Softmax (CUDA Core) 并行
    """
    # 流水线阶段
    for block_i in range(num_blocks):
        # 阶段1: Tensor Core计算QK^T (异步)
        s_ij = async_gemm(Q[block_i], K[block_j].T)  # Tensor Core

        # 阶段2: 同时,CUDA Core处理上一个块的softmax
        if block_i > 0:
            p_prev = softmax(s_prev)  # CUDA Core (与上面并行)
            o_prev = async_gemm(p_prev, V[block_j-1])

        # 保存当前块用于下一轮
        s_prev = s_ij

    return output

4.4 版本对比总结

特性 FlashAttention-1 FlashAttention-2 FlashAttention-3
发布时间 2022 2023 2024
HBM访问 O(n) O(n) O(n)
并行维度 batch, head batch, head, seq batch, head, seq
数据类型 FP16, BF16 FP16, BF16 FP16, BF16, FP8
硬件优化 Ampere Ampere, Hopper Hopper专用
相对速度 2-4x 3-5x 5-8x
最长序列 ~16K ~64K ~128K+

5. 位置编码外推技术

5.1 RoPE外推改进

RoPE基础回顾

\[f(q, m) = q \cdot e^{i m \theta}\]

其中\(m\)是位置,\(\theta\)是频率。

外推问题

Text Only
训练时: 位置 m ∈ [0, L_train-1]
推理时: 位置 m ∈ [0, L_infer-1], L_infer > L_train

问题:
1. 高频分量在远距离衰减过快
2. 位置插值导致信息损失

5.2 Position Interpolation (PI)

核心思想:将长序列的位置压缩到训练范围内。

\[m' = \frac{m \times L_{train}}{L_{infer}}\]
Python
def position_interpolation(pos: int, max_train_len: int, max_infer_len: int) -> float:
    """位置插值"""
    return pos * max_train_len / max_infer_len

# 示例: 训练4K,推理16K
# 原始位置 16000 -> 插值后 4000

问题:压缩导致分辨率下降,性能损失。

5.3 NTK-aware Interpolation

核心洞察:不同频率分量需要不同的缩放。

Python
def ntk_aware_rope(base: float, max_train_len: int, max_infer_len: int) -> float:
    """
    NTK-aware位置插值

    原理: 高频分量不插值,低频分量插值
    通过调整base来实现
    """
    scale = max_infer_len / max_train_len
    new_base = base * (scale ** (dim / (dim - 2)))
    return new_base

# 示例
original_base = 10000
train_len = 4096
infer_len = 32768

new_base = ntk_aware_rope(original_base, train_len, infer_len)
# new_base ≈ 1250000 (大幅增加)

数学解释

Text Only
RoPE频率: θ_i = base^(-2i/d)

NTK-aware调整:
- 高频 (i小): 保持原频率,不插值
- 低频 (i大): 频率降低,允许外推

新base计算:
  base_new = base_old × scale^(d/(d-2))

5.4 YaRN (Yet another RoPE extension)

YaRN = NTK-aware + 温度缩放

Python
def yarn_rope_scaling(
    pos: int,
    dim: int,
    base: float,
    max_train_len: int,
    max_infer_len: int,
    temperature: float = 1.0
):
    """
    YaRN: 结合NTK-aware和温度缩放

    温度缩放: 调整softmax前的注意力分数
    """
    scale = max_infer_len / max_train_len

    # NTK-aware base调整
    ntk_base = base * (scale ** (dim / (dim - 2)))

    # 计算频率
    freq = 1.0 / (ntk_base ** (torch.arange(0, dim, 2).float() / dim))

    # 应用位置
    angles = pos * freq

    # 温度缩放
    angles = angles / temperature

    return angles

# YaRN推荐温度
def get_yarn_temperature(scale: float, dim: int) -> float:
    """计算YaRN温度参数"""
    if scale <= 1:
        return 1.0
    return 1.0 + 0.1 * torch.log(scale)  # 经验公式

性能对比

方法 4K→8K 4K→16K 4K→32K
直接外推 65% 45% 25%
Position Interpolation 80% 70% 55%
NTK-aware 85% 78% 65%
YaRN 90% 85% 75%

5.5 ALiBi (Attention with Linear Biases)

核心思想:不用位置编码,而是在注意力分数上添加线性偏置。

\[\text{Attention}(q, k, m) = \text{softmax}(qk^T - m \cdot |i-j|)\]

其中\(m\)是每个头的斜率,\(|i-j|\)是相对距离。

Python
class ALiBiAttention(nn.Module):
    """ALiBi Attention 实现"""

    def __init__(self, num_heads: int):
        super().__init__()
        self.num_heads = num_heads

        # 每个头的斜率: m_i = 1 / 2^(i * 8 / num_heads)
        # 或者使用几何序列
        slopes = self._get_slopes(num_heads)
        self.register_buffer('slopes', slopes)

    def _get_slopes(self, num_heads: int) -> torch.Tensor:
        """计算ALiBi斜率"""
        # 方法1: 几何序列
        n = 2 ** torch.floor(torch.log2(torch.tensor(num_heads)))
        m0 = 2 ** (-8 / n)
        slopes = m0 ** torch.arange(1, num_heads + 1)

        # 如果num_heads不是2的幂,需要额外处理
        if num_heads > n:
            extra_slopes = 2 ** (-4 / n)
            extra = extra_slopes ** torch.arange(1, 2 * (num_heads - n) + 1, 2)
            slopes = torch.cat([slopes, extra])

        return slopes

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        """
        Args:
            Q: [batch, num_heads, seq_len, head_dim]
            K: [batch, num_heads, seq_len, head_dim]
            V: [batch, num_heads, seq_len, head_dim]
        """
        batch_size, num_heads, seq_len, head_dim = Q.shape

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)

        # 添加ALiBi偏置
        # 相对位置矩阵: [seq_len, seq_len]
        positions = torch.arange(seq_len, device=Q.device)
        relative_pos = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))

        # ALiBi偏置: [num_heads, seq_len, seq_len]
        # slopes: [num_heads] -> [num_heads, 1, 1]
        alibi_bias = -self.slopes.view(num_heads, 1, 1) * relative_pos.unsqueeze(0)

        # 添加到分数
        scores = scores + alibi_bias

        # Softmax和输出
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        return output

ALiBi优势

Text Only
1. 无需训练位置编码
   - 位置信息通过偏置隐式编码
   - 减少参数量

2. 天然支持外推
   - 线性衰减不依赖于训练长度
   - 可以处理任意长度序列

3. 计算高效
   - 偏置可以预计算
   - 无额外参数

4. 性能稳定
   - 长序列性能下降缓慢
   - 比RoPE更适合超长上下文

外推能力对比

方法 训练长度 推理2x 推理4x 推理8x
RoPE 2K 95% 75% 50%
RoPE + PI 2K 92% 85% 70%
RoPE + YaRN 2K 96% 90% 80%
ALiBi 2K 98% 95% 90%

6. 实践指南

6.1 技术选型决策树

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    长上下文技术选型                           │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Q: 需要多长的上下文?                                        │
│                                                              │
│  ├─ < 8K:                                                    │
│  │   └─ 标准Transformer + FlashAttention-2                   │
│  │                                                          │
│  ├─ 8K - 32K:                                               │
│  │   ├─ 单GPU:                                              │
│  │   │   └─ FlashAttention-2 + GQA + KV Cache量化           │
│  │   └─ 多GPU:                                              │
│  │       └─ Ring Attention                                  │
│  │                                                          │
│  ├─ 32K - 128K:                                             │
│  │   ├─ 训练: Ring Attention + YaRN                         │
│  │   └─ 推理: vLLM (PagedAttention) + GQA                   │
│  │                                                          │
│  └─ > 128K (100K - 1M+):                                    │
│      └─ Ring Attention + KV Cache量化 + 分布式推理           │
│                                                              │
└──────────────────────────────────────────────────────────────┘

6.2 不同场景的推荐配置

场景1:长文档问答(32K-128K)

Python
# 推荐配置
config = {
    'attention': 'FlashAttention-2',
    'position_encoding': 'YaRN',
    'kv_cache': 'PagedAttention (vLLM)',
    'attention_variant': 'GQA (8 groups)',
    'kv_quantization': 'FP8',
    'max_context': 128000,
}

# vLLM启动命令
# python -m vllm.entrypoints.api_server \
#     --model meta-llama/Llama-3.1-70B \
#     --max-context-length 128000 \
#     --kv-cache-dtype fp8 \
#     --tensor-parallel-size 4

场景2:代码分析(100K+)

Python
# 推荐配置
config = {
    'attention': 'Ring Attention',
    'position_encoding': 'YaRN',
    'kv_cache': 'PagedAttention + 量化',
    'attention_variant': 'GQA',
    'distributed': '张量并行',
    'max_context': 200000,
}

# 需要多GPU支持
# ring-attention --model path/to/model \
#     --sequence-length 200000 \
#     --ring-size 8

场景3:多轮对话(8K-32K)

Python
# 推荐配置
config = {
    'attention': 'FlashAttention-2',
    'position_encoding': 'RoPE + NTK-aware',
    'kv_cache': 'PagedAttention',
    'attention_variant': 'GQA',
    'continuous_batching': True,
    'max_context': 32000,
}

# vLLM配置
# vllm serve model_name \
#     --max-model-len 32768 \
#     --enable-prefix-caching \
#     --gpu-memory-utilization 0.9

6.3 性能对比表

配置 显存 (70B) 吞吐量 TTFT (1K) 最大长度
基线 (MHA + FP16) 140GB 100 tok/s 2.5s 4K
+ FlashAttention-2 80GB 250 tok/s 1.2s 16K
+ GQA 45GB 280 tok/s 1.0s 32K
+ KV量化 (INT8) 30GB 260 tok/s 1.1s 64K
+ PagedAttention 25GB 350 tok/s 0.8s 128K
Ring (4x A100) 20GB/GPU 400 tok/s 0.6s 512K

6.4 常见问题与解决方案

Q1: 显存不足怎么办?

Python
# 按优先级尝试:
solutions = [
    "1. 启用KV Cache量化 (FP8/INT8)",
    "2. 使用GQA替代MHA",
    "3. 减小batch_size",
    "4. 使用PagedAttention (vLLM)",
    "5. 多卡分布式推理",
]

Q2: 长序列推理速度慢?

Python
# 优化步骤:
optimizations = [
    "1. 确保使用FlashAttention-2/3",
    "2. 启用连续批处理",
    "3. 考虑投机采样 (Speculative Decoding)",
    "4. 使用张量并行",
]

Q3: 如何评估外推能力?

Python
def evaluate_extrapolation(model, tokenizer, test_lengths):
    """评估模型的外推能力"""
    results = {}

    for length in test_lengths:
        # 构造测试样本
        text = "..." * length  # 长文本
        inputs = tokenizer(text, return_tensors="pt")

        # 测试困惑度
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs.input_ids)
            ppl = torch.exp(outputs.loss).item()

        results[length] = {
            'perplexity': ppl,
            'success': ppl < threshold
        }

    return results

7. 总结

7.1 技术关系图

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    长上下文技术栈                             │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              应用层 (100K - 1M tokens)               │    │
│  │  长文档理解 │ 代码分析 │ 多轮对话 │ RAG              │    │
│  └─────────────────────────────────────────────────────┘    │
│                          ↓                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              系统层 (分布式 + 内存管理)               │    │
│  │  Ring Attention │ PagedAttention │ Continuous Batching│   │
│  └─────────────────────────────────────────────────────┘    │
│                          ↓                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              算法层 (注意力优化)                      │    │
│  │  FlashAttention-1/2/3 │ MQA/GQA │ KV量化            │    │
│  └─────────────────────────────────────────────────────┘    │
│                          ↓                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              基础层 (位置编码)                        │    │
│  │  RoPE │ ALiBi │ YaRN │ NTK-aware                    │    │
│  └─────────────────────────────────────────────────────┘    │
│                                                              │
└──────────────────────────────────────────────────────────────┘

7.2 关键要点

  1. Ring Attention:分布式处理超长序列的核心技术
  2. KV Cache优化:GQA + PagedAttention + 量化是标配组合
  3. FlashAttention:从算法层面解决O(n²)问题,v3是当前最优
  4. 位置编码外推:YaRN是RoPE的最佳扩展方案,ALiBi天然支持外推
  5. 实践选型:根据序列长度和资源选择合适的技术组合

8. 参考资料

8.1 核心论文

  1. Ring Attention: "Ring Attention with Blockwise Transformers for Near-Infinite Context" (Liu et al., 2023)
  2. FlashAttention-1: "FlashAttention: Fast and Memory-Efficient Exact Attention" (Dao et al., 2022)
  3. FlashAttention-2: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (Dao, 2023)
  4. FlashAttention-3: "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (Dao et al., 2024)
  5. PagedAttention: "Efficient Memory Management for Large Language Model Serving with PagedAttention" (vLLM, 2023)
  6. MQA: "Fast Transformer Decoding: One Write-Head is All You Need" (Shazeer, 2019)
  7. GQA: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023)
  8. YaRN: "YaRN: Efficient Context Window Extension of Large Language Models" (Peng et al., 2023)
  9. ALiBi: "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" (Press et al., 2022)

8.2 开源实现


最后更新日期:2026-02-21 适用版本:LLM学习教程 v2026

相关章节: - 02-推理优化技术 - 推理优化基础 - 03-大模型预训练 - 预训练中的长上下文处理