跳转至

08 - 长上下文技术

⚠️ 时效性说明(2026-03-27):本章涉及前沿技术,可能随研究进展快速变化;请以论文原文和官方实现为准。本轮已将长上下文模型示例收紧到更稳妥的公开窗口口径。

学习目标:深入理解大模型长上下文处理的核心技术,包括 Ring Attention 、 KV Cache 优化、 FlashAttention 演进、位置编码外推等。


1. 长上下文挑战

1.1 核心问题概述

长上下文处理是大模型面临的关键技术挑战之一。随着应用场景对更长上下文的需求(如长文档理解、多轮对话、代码分析等),传统 Transformer 架构面临三大瓶颈:

Text Only
┌─────────────────────────────────────────────────────────────┐
│                    长上下文三大挑战│
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 计算复杂度O(n²)                    │
│     - 注意力矩阵: n×n                                        │
│     - 100K tokens → 10^10次计算                              │
│                                                             │
│  2. 显存占用爆炸                                             │
│     - KV Cache线性增长                                       │
│     - 100K context → 数十GB显存                              │
│                                                             │
│  3. 位置编码外推                                             │
│     - 训练长度有限(如4K/8K)                                │
│     - 推理时需要处理更长序列                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

1.2 注意力复杂度 O(n²)问题

标准自注意力计算

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

复杂度分析

Text Only
输入序列长度: n
隐藏维度: d

计算步骤:
1. Q = XW_Q:  O(n × d × d) = O(nd²)
2. K = XW_K:  O(n × d × d) = O(nd²)
3. V = XW_V:  O(n × d × d) = O(nd²)
4. QK^T:      O(n × d × n) = O(n²d)    ← 瓶颈!
5. Softmax:   O(n²)
6. Attention × V: O(n² × d) = O(n²d)    ← 瓶颈!

总复杂度: O(n²d)
空间复杂度: O(n²)(存储注意力矩阵)

实际影响

Python
# 不同序列长度的注意力矩阵大小
def attention_memory_mb(seq_len, dtype_bytes=2):  # FP16
    """计算注意力矩阵显存占用"""
    elements = seq_len * seq_len
    mb = elements * dtype_bytes / (1024 * 1024)
    return mb

# 示例
lengths = [1024, 4096, 16384, 32768, 100000]
for n in lengths:
    mem = attention_memory_mb(n)
    print(f"序列长度 {n:>6}: 注意力矩阵 {mem:>10.2f} MB")

# 输出:
# 序列长度1024: 注意力矩阵2.00 MB
# 序列长度4096: 注意力矩阵32.00 MB
# 序列长度  16384: 注意力矩阵512.00 MB
# 序列长度  32768: 注意力矩阵2048.00 MB
# 序列长度 100000: 注意力矩阵19073.49 MB (~19GB!)

1.3 显存占用随序列长度增长

KV Cache 显存计算

Text Only
单层KV Cache大小:
  K: batch_size × num_heads × seq_len × head_dim × dtype_size
  V: batch_size × num_heads × seq_len × head_dim × dtype_size

总计 (L层):
  2 × batch_size × num_heads × seq_len × head_dim × L × dtype_size

简化公式:
  KV_Cache = 2 × b × h × s × d × L × bytes

实际案例( LLaMA-65B )

Python
def kv_cache_memory(batch_size, num_layers, num_heads, head_dim, seq_len, dtype_bytes=2):
    """计算KV Cache显存占用"""
    return 2 * batch_size * num_layers * num_heads * head_dim * seq_len * dtype_bytes

# LLaMA-65B 参数
config = {
    'batch_size': 1,
    'num_layers': 80,
    'num_heads': 64,
    'head_dim': 128,
    'seq_len': 100000  # 100K context
}

mem_bytes = kv_cache_memory(**config)
mem_gb = mem_bytes / (1024**3)
print(f"100K context KV Cache: {mem_gb:.2f} GB")
# 输出: 100K context KV Cache: 122.07 GB

1.4 位置编码外推问题

问题本质

Text Only
训练时: 最大序列长度 L_train (如 2048/4096)
推理时: 需要处理 L_infer > L_train 的序列

位置编码困境:
- 训练时未见过的位置 → 位置编码无意义或效果差
- 外推导致注意力分布异常 → 生成质量下降

不同位置编码的外推能力

位置编码类型 外推能力 问题
绝对位置编码 超出训练范围的位置无定义
RoPE 中等 远距离位置衰减过快
ALiBi 线性衰减,可外推
相对位置编码 较好 但计算开销大

2. Ring Attention :分布式长上下文

2.1 核心原理

Ring Attention 是一种分布式注意力计算技术,通过分块计算环状通信实现任意长度序列的处理。

核心思想

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    Ring Attention原理                         │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  传统方法:单GPU需要存储完整注意力矩阵           │
│                                                              │
│  Ring Attention:                                             │
│  1. 将序列分块: Q, K, V → [Block_1, Block_2, ..., Block_n]   │
│  2. 每个GPU持有部分K、V块                                    │
│  3. 环状传递K、V块,计算局部注意力                            │
│  4. 累积得到完整注意力结果                                    │
│                                                              │
│  GPU0: [Q_0] ← K_1 ← K_2 ← K_3 ← K_0 (环形)                  │
│  GPU1: [Q_1] ← K_2 ← K_3 ← K_0 ← K_1                         │
│  GPU2: [Q_2] ← K_3 ← K_0 ← K_1 ← K_2                         │
│  GPU3: [Q_3] ← K_0 ← K_1 ← K_2 ← K_3                         │
│                                                              │
└──────────────────────────────────────────────────────────────┘

2.2 数学形式化

分块注意力计算

将序列分成\(P\)个块,每个 GPU 处理一个块:

\[Q = [Q_1, Q_2, ..., Q_P], \quad K = [K_1, K_2, ..., K_P], \quad V = [V_1, V_2, ..., V_P]\]

对固定的 query 块 \(Q_i\),第 \(j\) 个 key/value 块对应的打分矩阵为:

\[ S_{ij} = \frac{Q_i K_j^T}{\sqrt{d}} \]

这里要特别注意:完整注意力结果不能写成“各块 softmax 输出的简单求和”
原因是 softmax 的分母耦合了所有 key 块:

\[ \text{softmax}\left([S_{i1}, S_{i2}, \dots, S_{iP}]\right) \]

所以 Ring Attention 的正确做法是像 FlashAttention 一样,对每个 query 行维护在线 softmax 累积量。

设第 \(j\) 轮处理完后,每个 query 行维护:

  • \(m_i^{(j)}\):到目前为止的行最大值
  • \(\ell_i^{(j)}\):到目前为止的 softmax 分母累积
  • \(n_i^{(j)}\):到目前为止的 softmax 分子累积

则对块 \(S_{ij}\) 有:

\[ m_i^{(j)} = \max \left(m_i^{(j-1)}, \operatorname{rowmax}(S_{ij})\right) \]
\[ \ell_i^{(j)} = \ell_i^{(j-1)} e^{m_i^{(j-1)} - m_i^{(j)}} + \sum \exp \left(S_{ij} - m_i^{(j)}\right) \]
\[ n_i^{(j)} = n_i^{(j-1)} e^{m_i^{(j-1)} - m_i^{(j)}} + \exp \left(S_{ij} - m_i^{(j)}\right)V_j \]

最终输出为:

\[ O_i = \frac{n_i^{(P)}}{\ell_i^{(P)}} \]

上面所有“\(\sum\)”都表示沿着 key 维度做逐行求和;多头维度和 batch 维度这里为简洁起见省略。

环状通信模式

Text Only
时间步 t=0: GPU_i 持有 K_i, V_i
时间步 t=1: GPU_i 接收 K_{i-1}, V_{i-1} (从左边GPU)
时间步 t=2: GPU_i 接收 K_{i-2}, V_{i-2}
...
时间步 t=P-1: 完成所有块的注意力计算

复杂度分析

Text Only
设总序列长度为 n,每个 GPU 持有约 n/P 个 token。

每个GPU的计算复杂度:
  O(n²d / P)

每个GPU的通信量:
  O(nd)
  (本地 KV 块大小约为 n/P × d,需要在环上走完一圈)

全系统总通信量:
  O(ndP)

当 P 个GPU时:
- 每GPU显存: O(n/P × d) + 局部tile缓冲区
- 可以把超长序列拆到多卡上处理,但效率仍受通信开销影响

2.3 代码实现

简化的 Ring Attention 实现

Python
import torch
import torch.distributed as dist

class RingAttention:
    """Ring Attention 简化实现"""

    def __init__(self, rank, world_size, head_dim, block_size=1024):
        self.rank = rank
        self.world_size = world_size
        self.head_dim = head_dim
        self.block_size = block_size

        # 环状通信的邻居
        self.send_rank = (rank + 1) % world_size
        self.recv_rank = (rank - 1) % world_size

    def compute_attention_block(self, Q_block, K_block, V_block):
        """计算单个块的注意力"""
        # Scaled Dot-Product Attention
        scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V_block)
        return output

    def ring_attention(self, Q, K, V):
        """
        Ring Attention 主函数

        Args:
            Q: [batch, heads, seq_len/P, head_dim] - 本GPU的Q块
            K: [batch, heads, seq_len/P, head_dim] - 本GPU的K块
            V: [batch, heads, seq_len/P, head_dim] - 本GPU的V块
        Returns:
            output: [batch, heads, seq_len/P, head_dim]
        """
        # 在线softmax的分子/分母累积
        numerator = torch.zeros_like(Q)
        row_max = torch.full(Q.shape[:3], float('-inf'), device=Q.device, dtype=Q.dtype)
        row_sum = torch.zeros(Q.shape[:3], device=Q.device, dtype=Q.dtype)

        # 当前持有的K, V块
        K_current = K.clone()
        V_current = V.clone()

        for step in range(self.world_size):
            # 计算当前块的注意力
            block_numerator, block_max, block_sum = self._flash_attention_block(
                Q, K_current, V_current
            )

            # 累积结果(使用数值稳定的softmax)
            new_max = torch.maximum(row_max, block_max)
            scale_old = torch.exp(row_max - new_max)
            scale_new = torch.exp(block_max - new_max)

            numerator = numerator * scale_old.unsqueeze(-1) + block_numerator * scale_new.unsqueeze(-1)
            row_sum = row_sum * scale_old + block_sum * scale_new
            row_max = new_max

            # 环状传递K, V块
            if step < self.world_size - 1:
                K_current, V_current = self._ring_send_recv(K_current, V_current)

        # 归一化得到最终attention输出
        eps = torch.finfo(Q.dtype).eps
        return numerator / row_sum.clamp_min(eps).unsqueeze(-1)

    def _flash_attention_block(self, Q, K, V):
        """Flash Attention风格的分块计算"""
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # 数值稳定的softmax
        max_score = scores.max(dim=-1, keepdim=True).values
        exp_scores = torch.exp(scores - max_score)
        sum_exp = exp_scores.sum(dim=-1, keepdim=True)

        numerator = torch.matmul(exp_scores, V)

        return numerator, max_score.squeeze(-1), sum_exp.squeeze(-1)

    def _ring_send_recv(self, K, V):
        """环状发送和接收K, V块"""
        K_recv = torch.empty_like(K)
        V_recv = torch.empty_like(V)

        # 异步发送和接收
        send_K = dist.isend(tensor=K.contiguous(), dst=self.send_rank)
        send_V = dist.isend(tensor=V.contiguous(), dst=self.send_rank)
        recv_K = dist.irecv(tensor=K_recv, src=self.recv_rank)
        recv_V = dist.irecv(tensor=V_recv, src=self.recv_rank)

        # 等待完成
        send_K.wait()
        send_V.wait()
        recv_K.wait()
        recv_V.wait()

        return K_recv, V_recv

与 FlashAttention 结合

Python
def ring_flash_attention(Q, K, V, rank, world_size):
    """
    Ring Attention + Flash Attention 组合

    结合两者优势:
    - Ring: 分布式处理超长序列
    - Flash: 单GPU内高效计算
    """
    attention = RingAttention(rank=rank, world_size=world_size, head_dim=Q.size(-1))
    return attention.ring_attention(Q, K, V)

2.4 应用案例

公开生态里的典型长上下文层级

⚠️ 商业模型的上下文窗口、SKU 与默认限制会随产品版本、区域、配额和 API 套餐变化。比起记某个具体数字,更重要的是理解“能力层级”和其背后的工程手段。

类别 常见公开窗口级别 备注
商用前沿闭源模型 128K - 1M+ 具体上限常因接口和套餐不同而变化
开源主流长上下文模型 32K - 128K 更多依赖位置编码扩展和 KV Cache 优化
研究型超长上下文方案 128K - 近无限上下文 往往需要分布式注意力或特定训练策略

3. KV Cache 优化技术

3.1 PagedAttention (vLLM)

核心思想:将 KV Cache 按页管理,类似操作系统的虚拟内存。

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    PagedAttention原理                         │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  传统KV Cache:                                               │
│  ┌────────────────────────────────────────┐                  │
│  │  连续预分配,大量浪费│                  │
│  └────────────────────────────────────────┘                  │
│                                                              │
│  PagedAttention:                                             │
│  ┌────┬────┬────┬────┬────┬────┬────┬────┐                  │
│  │Page│Page│Page│Page│Page│Page│Page│Page│                  │
│  │ 0  │ 1  │ 2  │ 3  │ 4  │ 5  │ 6  │ 7  │                  │
│  └────┴────┴────┴────┴────┴────┴────┴────┘                  │
│                                                              │
│  Sequence A: Page 0 → Page 2 → Page 5 (逻辑连续)             │
│  Sequence B: Page 1 → Page 3 → Page 4 (逻辑连续)             │
│  Sequence C: Page 6 → Page 7 (共享beam search)               │
│                                                              │
└──────────────────────────────────────────────────────────────┘

内存效率对比

Python
# 传统预分配 vs PagedAttention
def memory_comparison():
    # 假设: batch_size=32, max_seq_len=2048, 实际平均长度=512
    num_layers = 32
    num_heads = 32
    head_dim = 128
    dtype_bytes = 2  # FP16
    kv_per_token = 2 * num_layers * num_heads * head_dim * dtype_bytes

    # 传统方法: 按最大长度预分配
    traditional_memory = 32 * 2048 * kv_per_token

    # PagedAttention: 按实际需要分配
    paged_memory = 32 * 512 * kv_per_token

    # 加上少量页表开销 (约5%)
    paged_memory *= 1.05

    savings = (traditional_memory - paged_memory) / traditional_memory
    print(f"内存节省: {savings:.1%}")  # 约75%节省

vLLM 核心实现

Python
class BlockManager:
    """PagedAttention 块管理器"""

    def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int):
        self.num_blocks = num_blocks
        self.block_size = block_size  # 每页token数
        self.num_heads = num_heads
        self.head_dim = head_dim

        # 物理块池
        self.k_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
        self.v_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)

        # 空闲块列表
        self.free_blocks = list(range(num_blocks))
        self.block_ref_count = [0 for _ in range(num_blocks)]

        # 序列到块的映射
        self.seq_to_blocks = {}  # seq_id -> [block_ids]
        self.seq_lengths = {}    # seq_id -> current_length

    def allocate(self, seq_id: int) -> int:
        """为序列分配新块"""
        if not self.free_blocks:
            raise RuntimeError("Out of memory: no free blocks")

        block_id = self.free_blocks.pop(0)
        if seq_id not in self.seq_to_blocks:
            self.seq_to_blocks[seq_id] = []
            self.seq_lengths[seq_id] = 0

        self.seq_to_blocks[seq_id].append(block_id)
        self.block_ref_count[block_id] = 1
        return block_id

    def append_token(self, seq_id: int, k: torch.Tensor, v: torch.Tensor):
        """追加一个token的KV"""
        if seq_id not in self.seq_to_blocks:
            self.allocate(seq_id)

        current_len = self.seq_lengths[seq_id]
        block_idx = current_len // self.block_size
        token_idx = current_len % self.block_size

        # 如果当前块满了,分配新块
        if token_idx == 0 and current_len > 0:
            self.allocate(seq_id)
            block_idx += 1
        elif current_len > 0:
            # 教学版的最小可用 copy-on-write:
            # 如果当前尾块被多个序列共享,而本序列还要继续往里面写,
            # 就复制一份当前尾块,避免 beam / fork 后相互污染。
            tail_block_id = self.seq_to_blocks[seq_id][block_idx]
            if self.block_ref_count[tail_block_id] > 1:
                new_block_id = self.free_blocks.pop(0)
                self.block_ref_count[new_block_id] = 1
                self.block_ref_count[tail_block_id] -= 1
                self.k_cache[new_block_id].copy_(self.k_cache[tail_block_id])
                self.v_cache[new_block_id].copy_(self.v_cache[tail_block_id])
                self.seq_to_blocks[seq_id][block_idx] = new_block_id

        block_id = self.seq_to_blocks[seq_id][block_idx]

        # 写入KV Cache
        self.k_cache[block_id, token_idx] = k
        self.v_cache[block_id, token_idx] = v

        self.seq_lengths[seq_id] += 1

    def get_kv_cache(self, seq_id: int):
        """获取序列的完整KV Cache"""
        block_ids = self.seq_to_blocks[seq_id]
        length = self.seq_lengths[seq_id]

        # 收集所有块的KV
        k_blocks = self.k_cache[block_ids]  # [num_blocks, block_size, heads, dim]
        v_blocks = self.v_cache[block_ids]

        # 展平并截取实际长度
        k = k_blocks.view(-1, self.num_heads, self.head_dim)[:length]
        v = v_blocks.view(-1, self.num_heads, self.head_dim)[:length]

        return k, v

    def fork(self, parent_seq: int, child_seq: int):
        """复制序列(用于beam search),共享已有块并增加引用计数"""
        parent_blocks = self.seq_to_blocks[parent_seq]
        self.seq_to_blocks[child_seq] = parent_blocks.copy()
        self.seq_lengths[child_seq] = self.seq_lengths[parent_seq]
        for block_id in parent_blocks:
            self.block_ref_count[block_id] += 1

    def free(self, seq_id: int):
        """释放序列的所有块"""
        if seq_id in self.seq_to_blocks:
            for block_id in self.seq_to_blocks[seq_id]:
                self.block_ref_count[block_id] -= 1
                if self.block_ref_count[block_id] == 0:
                    self.free_blocks.append(block_id)
            del self.seq_to_blocks[seq_id]
            del self.seq_lengths[seq_id]

3.2 Multi-Query Attention (MQA) / Grouped-Query Attention (GQA)

架构对比

Text Only
┌──────────────────────────────────────────────────────────────┐
│              MHA vs MQA vs GQA对比                            │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Multi-Head Attention (MHA):                                 │
│  每个头有独立的K, V                                           │
│  KV Cache: [batch, num_heads, seq_len, head_dim]             │
│  内存: O(h × s × d)                                          │
│                                                              │
│  Multi-Query Attention (MQA):                                │
│  所有头共享一组K, V                                           │
│  KV Cache: [batch, 1, seq_len, head_dim]                     │
│  内存: O(s × d),节省h倍!                                    │
│                                                              │
│  Grouped-Query Attention (GQA):                              │
│  头分成g组,每组共享K, V                                      │
│  KV Cache: [batch, g, seq_len, head_dim]                     │
│  内存: O(g × s × d),平衡质量和效率                           │
│                                                              │
└──────────────────────────────────────────────────────────────┘

GQA 实现

Python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention 实现"""

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int):
        """
        Args:
            hidden_size: 隐藏维度
            num_heads: 查询头数 (Q)
            num_kv_heads: KV头数 (通常 < num_heads)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        self.head_dim = hidden_size // num_heads
        self.num_heads_per_group = num_heads // num_kv_heads

        # Q有num_heads个头
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False)
        # K和V只有num_kv_heads个头
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)

        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)

    def forward(self, hidden_states, attention_mask=None, kv_cache=None):
        batch_size, seq_len, _ = hidden_states.shape

        # 计算Q, K, V
        Q = self.q_proj(hidden_states)
        K = self.k_proj(hidden_states)
        V = self.v_proj(hidden_states)

        # 重塑形状
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # 处理KV Cache
        if kv_cache is not None:
            K = torch.cat([kv_cache['k'], K], dim=2)
            V = torch.cat([kv_cache['v'], V], dim=2)

        # 更新缓存
        new_kv_cache = {'k': K, 'v': V}

        # 扩展K, V以匹配Q的头数
        # [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim]
        K = self._repeat_kv(K)
        V = self._repeat_kv(V)

        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            scores = scores + attention_mask

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # 重塑并投影
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, -1)
        output = self.o_proj(output)

        return output, new_kv_cache

    def _repeat_kv(self, x):
        """将KV头复制以匹配Q头数"""
        if self.num_heads_per_group == 1:
            return x

        batch, num_kv_heads, seq_len, head_dim = x.shape
        x = x[:, :, None, :, :].expand(batch, num_kv_heads, self.num_heads_per_group, seq_len, head_dim)
        return x.reshape(batch, num_kv_heads * self.num_heads_per_group, seq_len, head_dim)

KV Cache 节省计算

Python
def kv_cache_savings():
    """计算GQA的KV Cache节省"""
    # 以LLaMA-2-70B为例
    num_heads = 64
    num_kv_heads = 8  # GQA with 8 groups
    seq_len = 100000  # 100K context
    head_dim = 128
    dtype_bytes = 2   # FP16

    # MHA
    mha_kv = 2 * num_heads * seq_len * head_dim * dtype_bytes
    mha_gb = mha_kv / (1024**3)

    # GQA
    gqa_kv = 2 * num_kv_heads * seq_len * head_dim * dtype_bytes
    gqa_gb = gqa_kv / (1024**3)

    print(f"MHA KV Cache: {mha_gb:.2f} GB")
    print(f"GQA KV Cache: {gqa_gb:.2f} GB")
    print(f"节省: {(1 - gqa_gb/mha_gb):.1%}")

    # 输出:
    # MHA KV Cache: 3.05 GB
    # GQA KV Cache: 0.38 GB
    # 节省: 87.5%

3.3 KV Cache 量化

量化策略

Python
class KVCacheQuantizer:
    """KV Cache 量化器"""

    def __init__(self, bits: int = 8, group_size: int = 128):
        """
        Args:
            bits: 量化位数 (4 或 8)
            group_size: 分组量化的大小
        """
        self.bits = bits
        self.group_size = group_size
        self.scale = None
        self.zero_point = None

    def quantize(self, x: torch.Tensor) -> tuple:
        """
        将FP16 KV Cache量化到低精度

        Args:
            x: [batch, heads, seq_len, head_dim]
        Returns:
            quantized: 量化后的数据
            scale: 缩放因子
            zero_point: 零点
        """
        # 分组量化
        x_grouped = x.reshape(-1, self.group_size)

        # 计算每组的min/max
        x_min = x_grouped.min(dim=-1, keepdim=True).values
        x_max = x_grouped.max(dim=-1, keepdim=True).values

        # 计算scale和zero_point
        qmin = 0
        qmax = 2 ** self.bits - 1

        scale = (x_max - x_min) / (qmax - qmin)
        scale = scale.clamp(min=1e-8)
        zero_point = qmin - x_min / scale
        zero_point = zero_point.round().clamp(qmin, qmax).to(torch.int32)

        # 量化
        quantized = ((x_grouped / scale) + zero_point).round().clamp(qmin, qmax)
        quantized = quantized.to(torch.uint8 if self.bits == 8 else torch.int8)

        return quantized, scale, zero_point

    def dequantize(self, quantized: torch.Tensor, scale: torch.Tensor, 
                   zero_point: torch.Tensor) -> torch.Tensor:
        """反量化"""
        return (quantized.float() - zero_point) * scale

    def memory_savings(self):
        """计算内存节省"""
        original_bits = 16  # FP16
        savings = 1 - self.bits / original_bits
        return savings

# 使用示例
quantizer = KVCacheQuantizer(bits=4)  # 4-bit量化
print(f"内存节省: {quantizer.memory_savings():.1%}")  # 75%

4. FlashAttention 演进

4.1 FlashAttention-1

核心创新:分块计算 + 在线 Softmax

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    FlashAttention-1原理                       │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  标准Attention:                                              │
│  1. 计算 S = QK^T           [n×n矩阵,存HBM]                 │
│  2. 计算 P = softmax(S)     [n×n矩阵,存HBM]                 │
│  3. 计算 O = PV             [n×n矩阵操作]                    │
│                                                              │
│  HBM访问: O(n²) - 瓶颈!                                     │
│                                                              │
│  FlashAttention:                                             │
│  1. 分块: Q, K, V分成小块 [B_r, B_c]                         │
│  2. 在SRAM中计算:                                            │
│     - 加载Q块、K块到SRAM                                     │
│     - 计算局部S_ij = Q_i @ K_j^T                             │
│     - 在线softmax更新                                        │
│     - 累积O_i                                                │
│  3. 只写回最终O                                              │
│                                                              │
│  HBM访问: O(n) - 线性!                                      │
│                                                              │
└──────────────────────────────────────────────────────────────┘

在线 Softmax 算法

Python
def online_softmax(a: torch.Tensor, new_block: torch.Tensor):
    """
    在线Softmax: 增量更新softmax结果

    已知: 之前块的max值m_old, sum(exp)值d_old
    新块: 需要融合new_block

    算法:
    1. m_new = max(m_old, max(new_block))
    2. d_new = d_old * exp(m_old - m_new) + sum(exp(new_block - m_new))
    3. O_new = O_old * (d_old * exp(m_old - m_new) / d_new) + 
               new_output * (sum(exp(new_block - m_new)) / d_new)
    """
    m_old, d_old, O_old = a
    m_new_block = new_block.max(dim=-1, keepdim=True).values

    # 更新全局max
    m_new = torch.maximum(m_old, m_new_block)

    # 更新归一化因子
    d_new_block = torch.exp(new_block - m_new).sum(dim=-1, keepdim=True)
    d_new = d_old * torch.exp(m_old - m_new) + d_new_block

    # 更新输出
    scale_old = d_old * torch.exp(m_old - m_new) / d_new
    scale_new = d_new_block / d_new
    O_new = O_old * scale_old + new_block * scale_new

    return m_new, d_new, O_new

性能提升

指标 标准 Attention FlashAttention-1
HBM 访问 O(n²) O(n)
显存占用 O(n²) O(n)
速度 1x 2-4x
支持序列长度 ~2K ~16K (同显存)

4.2 FlashAttention-2

主要改进

Text Only
FlashAttention-2 相比 v1 的优化:

1. 减少非矩阵乘法操作
   - 优化softmax计算
   - 减少原子操作

2. 并行化改进
   - 序列长度维度并行
   - 更好的GPU利用率

3. 支持更多数据类型
   - FP16, BF16
   - FP8 (Hopper架构)

并行策略

Python
# FlashAttention-2 的并行策略
"""
FlashAttention-1: 
  - 在batch和head维度并行
  - 长序列时GPU利用率低

FlashAttention-2:
  - 额外在seq_len维度并行
  - 即使batch=1也能充分利用GPU

示例: seq_len=16K, batch=1
  - v1: 只有32个线程块 (32 heads)
  - v2: 128个线程块 (额外在seq_len上切分)
"""

性能对比

配置 FlashAttention-1 FlashAttention-2 提升
A100, seq=2K 185 TFLOPS 215 TFLOPS 16%
A100, seq=8K 150 TFLOPS 190 TFLOPS 27%
H100, seq=16K 210 TFLOPS 280 TFLOPS 33%

4.3 FlashAttention-3

针对 Hopper 架构的优化

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    FlashAttention-3特性                       │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  1. 异步计算                                                 │
│     - Tensor Core: GEMM操作                                  │
│     - CUDA Core: Softmax操作                                 │
│     - 两者并行执行!                                          │
│                                                              │
│  2. FP8支持                                                  │
│     - H100原生FP8 Tensor Core                                │
│     - 比FP16快2倍                                            │
│                                                              │
│  3. 低精度优化                                               │
│     - 块量化技术                                             │
│     - 动态缩放因子                                           │
│                                                              │
└──────────────────────────────────────────────────────────────┘

异步流水线

Python
# FlashAttention-3 异步执行伪代码
def flash_attention_3_async(Q, K, V):
    """
    异步流水线执行:
    - GEMM (Tensor Core) 和 Softmax (CUDA Core) 并行
    """
    # 流水线阶段
    for block_i in range(num_blocks):
        # 阶段1: Tensor Core计算QK^T (异步)
        s_ij = async_gemm(Q[block_i], K[block_j].T)  # Tensor Core

        # 阶段2: 同时,CUDA Core处理上一个块的softmax
        if block_i > 0:
            p_prev = softmax(s_prev)  # CUDA Core (与上面并行)
            o_prev = async_gemm(p_prev, V[block_j-1])

        # 保存当前块用于下一轮
        s_prev = s_ij

    return output

4.4 版本对比总结

特性 FlashAttention-1 FlashAttention-2 FlashAttention-3
发布时间 2022 2023 2024
HBM 访问 O(n) O(n) O(n)
并行维度 batch, head batch, head, seq batch, head, seq
数据类型 FP16, BF16 FP16, BF16 FP16, BF16, FP8
硬件优化 早期主流GPU 对新架构更友好 明显偏向 Hopper
实践收益 显著降低显存 更好的并行效率 进一步吃满新硬件
适用结论 作为基础版本理解 当前工程主线之一 依赖硬件与框架支持

5. 位置编码外推技术

5.1 RoPE 外推改进

RoPE 基础回顾

\[f(q, m) = q \cdot e^{i m \theta}\]

其中\(m\)是位置,\(\theta\)是频率。

外推问题

Text Only
训练时: 位置 m ∈ [0, L_train-1]
推理时: 位置 m ∈ [0, L_infer-1], L_infer > L_train

问题:
1. 高频分量在远距离衰减过快
2. 位置插值导致信息损失

5.2 Position Interpolation (PI)

核心思想:将长序列的位置压缩到训练范围内。

设扩展倍数为:

\[ s = \frac{L_{infer}}{L_{train}} \]

则 Position Interpolation 把推理位置 \(m\) 压缩为:

\[ m' = \frac{m}{s} = \frac{m \times L_{train}}{L_{infer}} \]

它的好处是简单直接,所有位置都会被映射回训练时见过的范围;代价是原本相距很远的位置在相位空间里会被“挤近”,因此分辨率下降。

例如两个位置 \(m_1, m_2\) 的相对间隔会从 \((m_1 - m_2)\) 变成 \((m_1 - m_2) / s\),这也是 PI 在大倍率扩窗时容易损失细粒度位置信息的根源。

Python
def position_interpolation(pos: int, max_train_len: int, max_infer_len: int) -> float:
    """位置插值"""
    return pos * max_train_len / max_infer_len

# 示例: 训练4K,推理16K
# 原始位置 16000 -> 插值后 4000

问题:压缩导致分辨率下降,性能损失。

5.3 NTK-aware Interpolation

核心洞察:RoPE 的不同频率分量对长上下文外推的敏感度不同。
与“所有位置统一压缩”的 PI 不同,NTK-aware scaling 的常见做法是增大 RoPE 的 base,让相位旋转整体变慢,尤其对低频分量影响更明显。

RoPE 的逆频率通常写成:

\[ \omega_i = \text{base}^{-\frac{2i}{d}} \]

如果把 base 调大,那么所有 \(\omega_i\) 都会变小,也就是旋转速度变慢;并且 \(i\) 越大,指数越大,低频维度受到的影响越明显。

说明:不同实现里的 NTK-aware / dynamic-NTK 公式并不完全一样。下面给出的是一种常见教学写法,不应当把它理解成所有框架都严格使用同一公式。

Python
def ntk_aware_rope_base(
    base: float,
    dim: int,
    max_train_len: int,
    max_infer_len: int,
) -> float:
    """
    一种常见的 NTK-aware base 调整写法。
    """
    if dim <= 2:
        raise ValueError("dim must be greater than 2")

    scale = max_infer_len / max_train_len
    if scale <= 1:
        return base

    return base * (scale ** (dim / (dim - 2)))

# 示例
original_base = 10000
dim = 128
train_len = 4096
infer_len = 32768

new_base = ntk_aware_rope_base(original_base, dim, train_len, infer_len)
print(f"new_base = {new_base:.0f}")

数学解释

Text Only
RoPE逆频率: ω_i = base^(-2i/d)

NTK-aware不是显式地“给每个频率单独乘不同系数”,
而是通过增大 base,让所有维度旋转变慢。

由于指数项 2i/d 随 i 增大而增大:
- 高频分量(i小)变化较小
- 低频分量(i大)变化更明显

因此它通常比直接 PI 更能保住长距离依赖。

5.4 YaRN (Yet another RoPE extension)

更准确地说:YaRN = 渐变频率缩放 + 注意力重标定

YaRN 不是简单的“NTK-aware + 温度缩放”。
它的核心思想是:

  1. 不同 rotary 维度使用不同强度的缩放
  2. 在“保持原始频率”和“按扩窗比例缩放频率”之间做平滑过渡
  3. 配合 attention rescaling,减轻长上下文下 logits 过尖或过平的问题

教学上可以把它理解为:

\[ \omega_i^{(\text{scaled})} = \frac{\omega_i}{s} \]
\[ \omega_i^{(\text{YaRN})} = r_i \omega_i + (1-r_i)\omega_i^{(\text{scaled})} \]

其中 \(r_i \in [0, 1]\) 是一个随维度变化的 ramp 系数:

  • 某些维度更接近原始频率
  • 某些维度更接近缩放后的频率
  • 中间维度平滑过渡
Python
import math
import torch

def yarn_rope_scaling(
    dim: int,
    base: float,
    scale: float,
    low_cut: float = 0.25,
    high_cut: float = 0.75,
):
    """
    YaRN教学版:
    - 用 ramp 在原始 inv_freq 和缩放后的 inv_freq 之间平滑过渡
    - 真实实现还会配合 attention_factor 等额外细节
    """
    half_dim = dim // 2
    idx = torch.arange(half_dim, dtype=torch.float32)
    inv_freq = base ** (-2 * idx / dim)
    scaled_inv_freq = inv_freq / scale

    dim_pos = idx / max(half_dim - 1, 1)
    ramp = ((dim_pos - low_cut) / max(high_cut - low_cut, 1e-6)).clamp(0.0, 1.0)

    return ramp * inv_freq + (1.0 - ramp) * scaled_inv_freq

def yarn_attention_factor(scale: float) -> float:
    """教学版 attention rescaling;真实实现的系数会因模型/仓库而异。"""
    if scale <= 1:
        return 1.0
    return 1.0 + 0.1 * math.log(scale)

典型趋势(定性)

方法 小倍率扩窗 中倍率扩窗 大倍率扩窗
直接外推 容易退化 退化明显 通常不稳定
Position Interpolation 简单稳妥 可能损失分辨率 大倍率下往往不够
NTK-aware 通常优于直接 PI 常作为实用基线 效果依赖实现
YaRN 往往更稳 更适合大倍率扩窗 需要结合具体模型验证

5.5 ALiBi (Attention with Linear Biases)

核心思想:不用位置编码,而是在注意力分数上添加线性偏置。

\[\text{Attention}(q, k, m) = \text{softmax}(qk^T - m \cdot |i-j|)\]

其中\(m\)是每个头的斜率,\(|i-j|\)是相对距离。

Python
class ALiBiAttention(nn.Module):
    """ALiBi Attention 实现"""

    def __init__(self, num_heads: int):
        super().__init__()
        self.num_heads = num_heads

        # 每个头的斜率: m_i = 1 / 2^(i * 8 / num_heads)
        # 或者使用几何序列
        slopes = self._get_slopes(num_heads)
        self.register_buffer('slopes', slopes)

    def _get_slopes(self, num_heads: int) -> torch.Tensor:
        """计算ALiBi斜率"""
        # 方法1: 几何序列
        n = 2 ** torch.floor(torch.log2(torch.tensor(num_heads)))
        m0 = 2 ** (-8 / n)
        slopes = m0 ** torch.arange(1, num_heads + 1)

        # 如果num_heads不是2的幂,需要额外处理
        if num_heads > n:
            extra_slopes = 2 ** (-4 / n)
            extra = extra_slopes ** torch.arange(1, 2 * (num_heads - n) + 1, 2)
            slopes = torch.cat([slopes, extra])

        return slopes

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        """
        Args:
            Q: [batch, num_heads, seq_len, head_dim]
            K: [batch, num_heads, seq_len, head_dim]
            V: [batch, num_heads, seq_len, head_dim]
        """
        batch_size, num_heads, seq_len, head_dim = Q.shape

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)

        # 添加ALiBi偏置
        # 相对位置矩阵: [seq_len, seq_len]
        positions = torch.arange(seq_len, device=Q.device)
        relative_pos = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))

        # ALiBi偏置: [num_heads, seq_len, seq_len]
        # slopes: [num_heads] -> [num_heads, 1, 1]
        alibi_bias = -self.slopes.view(num_heads, 1, 1) * relative_pos.unsqueeze(0)

        # 添加到分数
        scores = scores + alibi_bias

        # Softmax和输出
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        return output

ALiBi 优势

Text Only
1. 无需训练位置编码
   - 位置信息通过偏置隐式编码
   - 减少参数量

2. 天然支持外推
   - 线性衰减不依赖于训练长度
   - 可以处理任意长度序列

3. 计算高效
   - 偏置可以预计算
   - 无额外参数

4. 性能稳定
   - 长序列性能下降缓慢
   - 比RoPE更适合超长上下文

外推能力对比

方法 小倍率外推 中倍率外推 大倍率外推
RoPE 直接外推 可能可用 容易退化 往往明显掉点
RoPE + PI 较稳 常见基线 大倍率下可能吃力
RoPE + YaRN 往往更稳 常见强基线 需结合模型验证
ALiBi 通常外推性好 保持较稳 但建模偏置与质量要权衡

6. 实践指南

6.1 技术选型决策树

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    长上下文技术选型                           │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Q: 需要多长的上下文?                                        │
│                                                              │
│  ├─ < 8K:                                                    │
│  │   └─ 标准Transformer + FlashAttention-2                   │
│  │                                                          │
│  ├─ 8K - 32K:                                               │
│  │   ├─ 单GPU:                                              │
│  │   │   └─ FlashAttention-2 + GQA + KV Cache量化           │
│  │   └─ 多GPU:                                              │
│  │       └─ Ring Attention                                  │
│  │                                                          │
│  ├─ 32K - 128K:                                             │
│  │   ├─ 训练: Ring Attention + YaRN                         │
│  │   └─ 推理: vLLM (PagedAttention) + GQA                   │
│  │                                                          │
│  └─ > 128K (100K - 1M+):                                    │
│      └─ Ring Attention + KV Cache量化 + 分布式推理           │
│                                                              │
└──────────────────────────────────────────────────────────────┘

6.2 不同场景的推荐配置

场景 1 :长文档问答( 32K-128K )

Python
# 推荐配置
config = {
    'attention': 'FlashAttention-2',
    'position_encoding': 'YaRN',
    'kv_cache': 'PagedAttention (vLLM)',
    'attention_variant': 'GQA (8 groups)',
    'kv_quantization': 'FP8',
    'max_context': 128000,
}

# vLLM启动命令
# vllm serve meta-llama/Llama-3.1-70B \
#     --max-model-len 128000 \
#     --kv-cache-dtype fp8 \
#     --tensor-parallel-size 4

场景 2 :代码分析( 100K+)

Python
# 推荐配置
config = {
    'attention': 'Ring Attention',
    'position_encoding': 'YaRN',
    'kv_cache': 'PagedAttention + 量化',
    'attention_variant': 'GQA',
    'distributed': '张量并行',
    'max_context': 200000,
}

# 需要多GPU支持
# 具体启动命令取决于所用训练/推理框架;
# 常见组合是 sequence parallel / ring attention / tensor parallel 一起配置。

场景 3 :多轮对话( 8K-32K )

Python
# 推荐配置
config = {
    'attention': 'FlashAttention-2',
    'position_encoding': 'RoPE + NTK-aware',
    'kv_cache': 'PagedAttention',
    'attention_variant': 'GQA',
    'continuous_batching': True,
    'max_context': 32000,
}

# vLLM配置
# vllm serve model_name \
#     --max-model-len 32768 \
#     --enable-prefix-caching \
#     --gpu-memory-utilization 0.9

6.3 典型收益对比(定性)

⚠️ 不同模型、硬件、batch size、服务框架差异很大,下面只保留“趋势级”结论,不给出脱离上下文的固定数值。

配置 主要改善点 显存趋势 吞吐趋势 更适合的场景
基线 (MHA + FP16) 作为对照组 基线 短上下文 / 小规模实验
+ FlashAttention-⅔ 降低 IO 与中间激活开销 中到高 单卡更长上下文
+ GQA / MQA 显著降低 KV Cache 更低 长解码 / 服务化
+ KV 量化 进一步压缩 KV / 权重显存 更低 视硬件而定 显存紧张场景
+ PagedAttention 降低碎片,提高利用率 更稳 多请求在线推理
Ring / 序列并行 把超长序列拆到多卡 单卡占用下降 受通信影响 超长上下文训练 / 推理

6.4 常见问题与解决方案

Q1: 显存不足怎么办?

Python
# 按优先级尝试:
solutions = [
    "1. 启用KV Cache量化 (FP8/INT8)",
    "2. 使用GQA替代MHA",
    "3. 减小batch_size",
    "4. 使用PagedAttention (vLLM)",
    "5. 多卡分布式推理",
]

Q2: 长序列推理速度慢?

Python
# 优化步骤:
optimizations = [
    "1. 确保使用FlashAttention-2/3",
    "2. 启用连续批处理",
    "3. 考虑投机采样 (Speculative Decoding)",
    "4. 使用张量并行",
]

Q3: 如何评估外推能力?

Python
def evaluate_extrapolation(model, tokenizer, test_lengths):
    """评估模型的外推能力"""
    results = {}

    for length in test_lengths:
        # 构造测试样本
        text = "..." * length  # 长文本
        inputs = tokenizer(text, return_tensors="pt")

        # 测试困惑度
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs.input_ids)
            ppl = torch.exp(outputs.loss).item()

        results[length] = {
            'perplexity': ppl,
            'success': ppl < threshold
        }

    return results

7. 总结

7.1 技术关系图

Text Only
┌──────────────────────────────────────────────────────────────┐
│                    长上下文技术栈                             │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              应用层 (100K - 1M tokens)               │    │
│  │  长文档理解 │ 代码分析 │ 多轮对话 │ RAG              │    │
│  └─────────────────────────────────────────────────────┘    │
│                          ↓                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              系统层 (分布式 + 内存管理)               │    │
│  │  Ring Attention │ PagedAttention │ Continuous Batching│   │
│  └─────────────────────────────────────────────────────┘    │
│                          ↓                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              算法层 (注意力优化)                      │    │
│  │  FlashAttention-1/2/3 │ MQA/GQA │ KV量化            │    │
│  └─────────────────────────────────────────────────────┘    │
│                          ↓                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              基础层 (位置编码)                        │    │
│  │  RoPE │ ALiBi │ YaRN │ NTK-aware                    │    │
│  └─────────────────────────────────────────────────────┘    │
│                                                              │
└──────────────────────────────────────────────────────────────┘

7.2 关键要点

  1. Ring Attention:核心价值是把超长序列的注意力拆到多卡上,并用在线 softmax 保持结果正确。
  2. KV Cache 优化:GQA、PagedAttention、KV 量化是常见主线组合,但不是所有模型都必须三者同时使用。
  3. FlashAttention:核心是降低显存和 IO 开销,不改变自注意力的理论算术复杂度 \(O(n^2 d)\)
  4. 位置编码外推:没有"放之四海皆准"的最佳方案,PI、NTK-aware、YaRN 都要结合基础模型和扩窗倍率验证。
  5. 实践选型:先看目标上下文长度,再看单卡显存、通信带宽和服务形态,最后决定组合方式。

🤔 思考题

💡 思考题参考解答 **Q1: Ring Attention 的在线 softmax 为什么能保证结果与标准 softmax 完全一致?** 标准 softmax:$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$ Ring Attention 使用**在线 softmax(Online Softmax / FlashAttention 分块技巧)**: - 维护两个累加器:`m`(当前最大值)和 `l`(指数和) - 每处理一个新的 KV 块时,更新最大值并调整之前的累加结果 - 数学上等价于先收集所有分数再做全局 softmax 关键公式:当最大值从 $m_{old}$ 更新为 $m_{new}$ 时,之前的累加和需要乘以 $e^{m_{old} - m_{new}}$ 进行修正。 **Q2: YaRN 为什么比简单的 Position Interpolation (PI) 效果更好?** PI 直接将位置线性压缩到原始窗口内,导致: - 所有位置的相对距离被等比缩小 - 高频位置信息丢失,模型难以区分相邻 token YaRN 的改进: 1. **NTK-aware 缩放**:通过调整 RoPE 的频率基数,保留高频信息 2. **动态缩放因子**:对不同频率分量使用不同的缩放策略 3. **注意力温度调整**:补偿因位置缩放导致的注意力分布偏移 结果:YaRN 在 128K 外推时仍保持接近原始窗口的性能,而 PI 在 8-16K 就开始明显退化。 **Q3: 如果要在单张 A100 (80GB) 上服务 128K 上下文的 70B 模型,你会如何组合本章的技术?** 推荐组合方案: 1. **GQA**:减少 KV Cache 参数量(如 LLaMA-2 70B 已使用 8 组 GQA) 2. **KV Cache 量化(FP8/INT8)**:将 KV Cache 压缩 2 倍 3. **PagedAttention**:管理 KV Cache 内存碎片 4. **FlashAttention-2/3**:减少注意力计算的显存峰值 5. **模型量化(AWQ INT4)**:70B 模型从 140GB 压缩到 ~35GB 估算:35GB(模型权重)+ 20GB(128K KV Cache with GQA+FP8)+ 10GB(激活值)≈ 65GB < 80GB,可行。 **Q4: FlashAttention-3 相比 FlashAttention-2 的主要改进是什么?** FlashAttention-3 针对 Hopper 架构(H100)的三大优化: 1. **异步执行**:利用 H100 的 Tensor Memory Accelerator (TMA) 异步加载数据,计算与 IO 重叠 2. **FP8 支持**:利用 FP8 Tensor Core,吞吐量翻倍(FP16 的 2 倍) 3. ** warp 专业化**:将 warp 分为生产者(加载数据)和消费者(计算),减少等待 结果:在 H100 上达到接近理论峰值的 75% 利用率(FA2 约为 50%)。

8. 参考资料

8.1 核心论文

  1. Ring Attention: "Ring Attention with Blockwise Transformers for Near-Infinite Context" (Liu et al., 2023)
  2. FlashAttention-1: "FlashAttention: Fast and Memory-Efficient Exact Attention" (Dao et al., 2022)
  3. FlashAttention-2: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (Dao, 2023)
  4. FlashAttention-3: "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (Dao et al., 2024)
  5. PagedAttention: "Efficient Memory Management for Large Language Model Serving with PagedAttention" (vLLM, 2023)
  6. MQA: "Fast Transformer Decoding: One Write-Head is All You Need" (Shazeer, 2019)
  7. GQA: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023)
  8. YaRN: "YaRN: Efficient Context Window Extension of Large Language Models" (Peng et al., 2023)
  9. ALiBi: "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" (Press et al., 2022)

8.2 开源实现


最后更新日期: 2026-03-27 适用版本: LLM 学习教程 v2026

相关章节: - 02-推理优化技术 - 推理优化基础 - 03-大模型预训练 - 预训练中的长上下文处理