08 - 长上下文技术¶
⚠️ 时效性说明:本章涉及前沿技术,可能随研究进展快速变化;请以论文原文和官方实现为准。
学习目标:深入理解大模型长上下文处理的核心技术,包括Ring Attention、KV Cache优化、FlashAttention演进、位置编码外推等。
1. 长上下文挑战¶
1.1 核心问题概述¶
长上下文处理是大模型面临的关键技术挑战之一。随着应用场景对更长上下文的需求(如长文档理解、多轮对话、代码分析等),传统Transformer架构面临三大瓶颈:
┌─────────────────────────────────────────────────────────────┐
│ 长上下文三大挑战│
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 计算复杂度O(n²) │
│ - 注意力矩阵: n×n │
│ - 100K tokens → 10^10次计算 │
│ │
│ 2. 显存占用爆炸 │
│ - KV Cache线性增长 │
│ - 100K context → 数十GB显存 │
│ │
│ 3. 位置编码外推 │
│ - 训练长度有限(如4K/8K) │
│ - 推理时需要处理更长序列 │
│ │
└─────────────────────────────────────────────────────────────┘
1.2 注意力复杂度O(n²)问题¶
标准自注意力计算:
复杂度分析:
输入序列长度: n
隐藏维度: d
计算步骤:
1. Q = XW_Q: O(n × d × d) = O(nd²)
2. K = XW_K: O(n × d × d) = O(nd²)
3. V = XW_V: O(n × d × d) = O(nd²)
4. QK^T: O(n × d × n) = O(n²d) ← 瓶颈!
5. Softmax: O(n²)
6. Attention × V: O(n² × d) = O(n²d) ← 瓶颈!
总复杂度: O(n²d)
空间复杂度: O(n²)(存储注意力矩阵)
实际影响:
# 不同序列长度的注意力矩阵大小
def attention_memory_mb(seq_len, dtype_bytes=2): # FP16
"""计算注意力矩阵显存占用"""
elements = seq_len * seq_len
mb = elements * dtype_bytes / (1024 * 1024)
return mb
# 示例
lengths = [1024, 4096, 16384, 32768, 100000]
for n in lengths:
mem = attention_memory_mb(n)
print(f"序列长度 {n:>6}: 注意力矩阵 {mem:>10.2f} MB")
# 输出:
# 序列长度1024: 注意力矩阵2.00 MB
# 序列长度4096: 注意力矩阵32.00 MB
# 序列长度 16384: 注意力矩阵512.00 MB
# 序列长度 32768: 注意力矩阵2048.00 MB
# 序列长度 100000: 注意力矩阵19073.49 MB (~19GB!)
1.3 显存占用随序列长度增长¶
KV Cache显存计算:
单层KV Cache大小:
K: batch_size × num_heads × seq_len × head_dim × dtype_size
V: batch_size × num_heads × seq_len × head_dim × dtype_size
总计 (L层):
2 × batch_size × num_heads × seq_len × head_dim × L × dtype_size
简化公式:
KV_Cache = 2 × b × h × s × d × L × bytes
实际案例(LLaMA-65B):
def kv_cache_memory(batch_size, num_layers, num_heads, head_dim, seq_len, dtype_bytes=2):
"""计算KV Cache显存占用"""
return 2 * batch_size * num_layers * num_heads * head_dim * seq_len * dtype_bytes
# LLaMA-65B 参数
config = {
'batch_size': 1,
'num_layers': 80,
'num_heads': 64,
'head_dim': 128,
'seq_len': 100000 # 100K context
}
mem_bytes = kv_cache_memory(**config)
mem_gb = mem_bytes / (1024**3)
print(f"100K context KV Cache: {mem_gb:.2f} GB")
# 输出: 100K context KV Cache: 122.07 GB
1.4 位置编码外推问题¶
问题本质:
训练时: 最大序列长度 L_train (如 2048/4096)
推理时: 需要处理 L_infer > L_train 的序列
位置编码困境:
- 训练时未见过的位置 → 位置编码无意义或效果差
- 外推导致注意力分布异常 → 生成质量下降
不同位置编码的外推能力:
| 位置编码类型 | 外推能力 | 问题 |
|---|---|---|
| 绝对位置编码 | 差 | 超出训练范围的位置无定义 |
| RoPE | 中等 | 远距离位置衰减过快 |
| ALiBi | 好 | 线性衰减,可外推 |
| 相对位置编码 | 较好 | 但计算开销大 |
2. Ring Attention:分布式长上下文¶
2.1 核心原理¶
Ring Attention是一种分布式注意力计算技术,通过分块计算和环状通信实现任意长度序列的处理。
核心思想:
┌──────────────────────────────────────────────────────────────┐
│ Ring Attention原理 │
├──────────────────────────────────────────────────────────────┤
│ │
│ 传统方法:单GPU需要存储完整注意力矩阵 │
│ │
│ Ring Attention: │
│ 1. 将序列分块: Q, K, V → [Block_1, Block_2, ..., Block_n] │
│ 2. 每个GPU持有部分K、V块 │
│ 3. 环状传递K、V块,计算局部注意力 │
│ 4. 累积得到完整注意力结果 │
│ │
│ GPU0: [Q_0] ← K_1 ← K_2 ← K_3 ← K_0 (环形) │
│ GPU1: [Q_1] ← K_2 ← K_3 ← K_0 ← K_1 │
│ GPU2: [Q_2] ← K_3 ← K_0 ← K_1 ← K_2 │
│ GPU3: [Q_3] ← K_0 ← K_1 ← K_2 ← K_3 │
│ │
└──────────────────────────────────────────────────────────────┘
2.2 数学形式化¶
分块注意力计算:
将序列分成\(P\)个块,每个GPU处理一个块:
第\(i\)个GPU的注意力计算:
环状通信模式:
时间步 t=0: GPU_i 持有 K_i, V_i
时间步 t=1: GPU_i 接收 K_{i-1}, V_{i-1} (从左边GPU)
时间步 t=2: GPU_i 接收 K_{i-2}, V_{i-2}
...
时间步 t=P-1: 完成所有块的注意力计算
复杂度分析:
计算复杂度: O(n²/P) (每个GPU)
通信复杂度: O(n × d × P) (环状传递)
总复杂度: O(n²/P + n×d×P)
当 P 个GPU时:
- 显存: O(n/P × d) (每个GPU只存部分)
- 支持: 任意长度 (理论上)
2.3 代码实现¶
简化的Ring Attention实现:
import torch
import torch.distributed as dist
class RingAttention:
"""Ring Attention 简化实现"""
def __init__(self, rank, world_size, head_dim, block_size=1024):
self.rank = rank
self.world_size = world_size
self.head_dim = head_dim
self.block_size = block_size
# 环状通信的邻居
self.send_rank = (rank + 1) % world_size
self.recv_rank = (rank - 1) % world_size
def compute_attention_block(self, Q_block, K_block, V_block):
"""计算单个块的注意力"""
# Scaled Dot-Product Attention
scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V_block)
return output
def ring_attention(self, Q, K, V):
"""
Ring Attention 主函数
Args:
Q: [batch, heads, seq_len/P, head_dim] - 本GPU的Q块
K: [batch, heads, seq_len/P, head_dim] - 本GPU的K块
V: [batch, heads, seq_len/P, head_dim] - 本GPU的V块
Returns:
output: [batch, heads, seq_len/P, head_dim]
"""
# 初始化输出和累积变量
output = torch.zeros_like(Q)
max_score = torch.full(Q.shape[:3], float('-inf'), device=Q.device)
sum_exp = torch.zeros(Q.shape[:3], device=Q.device)
# 当前持有的K, V块
K_current = K.clone()
V_current = V.clone()
for step in range(self.world_size):
# 计算当前块的注意力
block_output, block_max, block_sum = self._flash_attention_block(
Q, K_current, V_current
)
# 累积结果(使用数值稳定的softmax)
new_max = torch.maximum(max_score, block_max)
scale_old = torch.exp(max_score - new_max)
scale_new = torch.exp(block_max - new_max)
output = output * scale_old.unsqueeze(-1) + block_output * scale_new.unsqueeze(-1)
sum_exp = sum_exp * scale_old + block_sum * scale_new
max_score = new_max
# 环状传递K, V块
if step < self.world_size - 1:
K_current, V_current = self._ring_send_recv(K_current, V_current)
# 归一化
output = output / sum_exp.unsqueeze(-1)
return output
def _flash_attention_block(self, Q, K, V):
"""Flash Attention风格的分块计算"""
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
# 数值稳定的softmax
max_score = scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(scores - max_score)
sum_exp = exp_scores.sum(dim=-1, keepdim=True)
output = torch.matmul(exp_scores, V)
return output, max_score.squeeze(-1), sum_exp.squeeze(-1)
def _ring_send_recv(self, K, V):
"""环状发送和接收K, V块"""
K_recv = torch.empty_like(K)
V_recv = torch.empty_like(V)
# 异步发送和接收
send_K = dist.isend(K.contiguous(), self.send_rank)
send_V = dist.isend(V.contiguous(), self.send_rank)
recv_K = dist.irecv(K_recv, self.recv_rank)
recv_V = dist.irecv(V_recv, self.recv_rank)
# 等待完成
send_K.wait()
send_V.wait()
recv_K.wait()
recv_V.wait()
return K_recv, V_recv
与FlashAttention结合:
def ring_flash_attention(Q, K, V, rank, world_size):
"""
Ring Attention + Flash Attention 组合
结合两者优势:
- Ring: 分布式处理超长序列
- Flash: 单GPU内高效计算
"""
from flash_attn import flash_attn_func
output = torch.zeros_like(Q)
K_current = K.clone()
V_current = V.clone()
for step in range(world_size):
# 使用Flash Attention计算当前块
block_output = flash_attn_func(Q, K_current, V_current, causal=True)
# 累积结果...
output += block_output # 简化版,实际需要正确的softmax累积
# 环状传递
if step < world_size - 1:
K_current, V_current = ring_send_recv(K_current, V_current, rank, world_size)
return output
2.4 应用案例¶
支持的超长上下文模型:
| 模型 | 上下文长度 | 使用技术 |
|---|---|---|
| Claude 3 | 200K | Ring Attention |
| Gemini 1.5 Pro | 1M-2M | Ring Attention + 优化 |
| GPT-4-Turbo | 128K | 类似技术 |
| LLaMA 3.1 | 128K | Ring Attention |
3. KV Cache优化技术¶
3.1 PagedAttention (vLLM)¶
核心思想:将KV Cache按页管理,类似操作系统的虚拟内存。
┌──────────────────────────────────────────────────────────────┐
│ PagedAttention原理 │
├──────────────────────────────────────────────────────────────┤
│ │
│ 传统KV Cache: │
│ ┌────────────────────────────────────────┐ │
│ │ 连续预分配,大量浪费│ │
│ └────────────────────────────────────────┘ │
│ │
│ PagedAttention: │
│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │
│ │Page│Page│Page│Page│Page│Page│Page│Page│ │
│ │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ │
│ └────┴────┴────┴────┴────┴────┴────┴────┘ │
│ │
│ Sequence A: Page 0 → Page 2 → Page 5 (逻辑连续) │
│ Sequence B: Page 1 → Page 3 → Page 4 (逻辑连续) │
│ Sequence C: Page 6 → Page 7 (共享beam search) │
│ │
└──────────────────────────────────────────────────────────────┘
内存效率对比:
# 传统预分配 vs PagedAttention
def memory_comparison():
# 假设: batch_size=32, max_seq_len=2048, 实际平均长度=512
# 传统方法: 按最大长度预分配
traditional_memory = 32 * 2048 * kv_per_token
# PagedAttention: 按实际需要分配
paged_memory = 32 * 512 * kv_per_token
# 加上少量页表开销 (约5%)
paged_memory *= 1.05
savings = (traditional_memory - paged_memory) / traditional_memory
print(f"内存节省: {savings:.1%}") # 约75%节省
vLLM核心实现:
class BlockManager:
"""PagedAttention 块管理器"""
def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int):
self.num_blocks = num_blocks
self.block_size = block_size # 每页token数
self.num_heads = num_heads
self.head_dim = head_dim
# 物理块池
self.k_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
self.v_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
# 空闲块列表
self.free_blocks = list(range(num_blocks))
# 序列到块的映射
self.seq_to_blocks = {} # seq_id -> [block_ids]
self.seq_lengths = {} # seq_id -> current_length
def allocate(self, seq_id: int) -> int:
"""为序列分配新块"""
if not self.free_blocks:
raise RuntimeError("Out of memory: no free blocks")
block_id = self.free_blocks.pop(0)
if seq_id not in self.seq_to_blocks:
self.seq_to_blocks[seq_id] = []
self.seq_lengths[seq_id] = 0
self.seq_to_blocks[seq_id].append(block_id)
return block_id
def append_token(self, seq_id: int, k: torch.Tensor, v: torch.Tensor):
"""追加一个token的KV"""
if seq_id not in self.seq_to_blocks:
self.allocate(seq_id)
current_len = self.seq_lengths[seq_id]
block_idx = current_len // self.block_size
token_idx = current_len % self.block_size
# 如果当前块满了,分配新块
if token_idx == 0 and current_len > 0:
self.allocate(seq_id)
block_idx += 1
block_id = self.seq_to_blocks[seq_id][block_idx]
# 写入KV Cache
self.k_cache[block_id, token_idx] = k
self.v_cache[block_id, token_idx] = v
self.seq_lengths[seq_id] += 1
def get_kv_cache(self, seq_id: int):
"""获取序列的完整KV Cache"""
block_ids = self.seq_to_blocks[seq_id]
length = self.seq_lengths[seq_id]
# 收集所有块的KV
k_blocks = self.k_cache[block_ids] # [num_blocks, block_size, heads, dim]
v_blocks = self.v_cache[block_ids]
# 展平并截取实际长度
k = k_blocks.view(-1, self.num_heads, self.head_dim)[:length]
v = v_blocks.view(-1, self.num_heads, self.head_dim)[:length]
return k, v
def fork(self, parent_seq: int, child_seq: int):
"""复制序列(用于beam search),共享块"""
parent_blocks = self.seq_to_blocks[parent_seq]
self.seq_to_blocks[child_seq] = parent_blocks.copy() # Copy-on-write
self.seq_lengths[child_seq] = self.seq_lengths[parent_seq]
def free(self, seq_id: int):
"""释放序列的所有块"""
if seq_id in self.seq_to_blocks:
self.free_blocks.extend(self.seq_to_blocks[seq_id])
del self.seq_to_blocks[seq_id]
del self.seq_lengths[seq_id]
3.2 Multi-Query Attention (MQA) / Grouped-Query Attention (GQA)¶
架构对比:
┌──────────────────────────────────────────────────────────────┐
│ MHA vs MQA vs GQA对比 │
├──────────────────────────────────────────────────────────────┤
│ │
│ Multi-Head Attention (MHA): │
│ 每个头有独立的K, V │
│ KV Cache: [batch, num_heads, seq_len, head_dim] │
│ 内存: O(h × s × d) │
│ │
│ Multi-Query Attention (MQA): │
│ 所有头共享一组K, V │
│ KV Cache: [batch, 1, seq_len, head_dim] │
│ 内存: O(s × d),节省h倍! │
│ │
│ Grouped-Query Attention (GQA): │
│ 头分成g组,每组共享K, V │
│ KV Cache: [batch, g, seq_len, head_dim] │
│ 内存: O(g × s × d),平衡质量和效率 │
│ │
└──────────────────────────────────────────────────────────────┘
GQA实现:
import torch
import torch.nn as nn
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention 实现"""
def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int):
"""
Args:
hidden_size: 隐藏维度
num_heads: 查询头数 (Q)
num_kv_heads: KV头数 (通常 < num_heads)
"""
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_size // num_heads
self.num_heads_per_group = num_heads // num_kv_heads
# Q有num_heads个头
self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False)
# K和V只有num_kv_heads个头
self.k_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)
def forward(self, hidden_states, attention_mask=None, kv_cache=None):
batch_size, seq_len, _ = hidden_states.shape
# 计算Q, K, V
Q = self.q_proj(hidden_states)
K = self.k_proj(hidden_states)
V = self.v_proj(hidden_states)
# 重塑形状
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
# 处理KV Cache
if kv_cache is not None:
K = torch.cat([kv_cache['k'], K], dim=2)
V = torch.cat([kv_cache['v'], V], dim=2)
# 更新缓存
new_kv_cache = {'k': K, 'v': V}
# 扩展K, V以匹配Q的头数
# [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim]
K = self._repeat_kv(K)
V = self._repeat_kv(V)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
if attention_mask is not None:
scores = scores + attention_mask
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
# 重塑并投影
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = self.o_proj(output)
return output, new_kv_cache
def _repeat_kv(self, x):
"""将KV头复制以匹配Q头数"""
if self.num_heads_per_group == 1:
return x
batch, num_kv_heads, seq_len, head_dim = x.shape
x = x[:, :, None, :, :].expand(batch, num_kv_heads, self.num_heads_per_group, seq_len, head_dim)
return x.reshape(batch, num_kv_heads * self.num_heads_per_group, seq_len, head_dim)
KV Cache节省计算:
def kv_cache_savings():
"""计算GQA的KV Cache节省"""
# 以LLaMA-2-70B为例
num_heads = 64
num_kv_heads = 8 # GQA with 8 groups
seq_len = 100000 # 100K context
head_dim = 128
dtype_bytes = 2 # FP16
# MHA
mha_kv = 2 * num_heads * seq_len * head_dim * dtype_bytes
mha_gb = mha_kv / (1024**3)
# GQA
gqa_kv = 2 * num_kv_heads * seq_len * head_dim * dtype_bytes
gqa_gb = gqa_kv / (1024**3)
print(f"MHA KV Cache: {mha_gb:.2f} GB")
print(f"GQA KV Cache: {gqa_gb:.2f} GB")
print(f"节省: {(1 - gqa_gb/mha_gb):.1%}")
# 输出:
# MHA KV Cache: 3.05 GB
# GQA KV Cache: 0.38 GB
# 节省: 87.5%
3.3 KV Cache量化¶
量化策略:
class KVCacheQuantizer:
"""KV Cache 量化器"""
def __init__(self, bits: int = 8, group_size: int = 128):
"""
Args:
bits: 量化位数 (4 或 8)
group_size: 分组量化的大小
"""
self.bits = bits
self.group_size = group_size
self.scale = None
self.zero_point = None
def quantize(self, x: torch.Tensor) -> tuple:
"""
将FP16 KV Cache量化到低精度
Args:
x: [batch, heads, seq_len, head_dim]
Returns:
quantized: 量化后的数据
scale: 缩放因子
zero_point: 零点
"""
# 分组量化
x_grouped = x.reshape(-1, self.group_size)
# 计算每组的min/max
x_min = x_grouped.min(dim=-1, keepdim=True).values
x_max = x_grouped.max(dim=-1, keepdim=True).values
# 计算scale和zero_point
qmin = 0
qmax = 2 ** self.bits - 1
scale = (x_max - x_min) / (qmax - qmin)
scale = scale.clamp(min=1e-8)
zero_point = qmin - x_min / scale
zero_point = zero_point.round().clamp(qmin, qmax).to(torch.int32)
# 量化
quantized = ((x_grouped / scale) + zero_point).round().clamp(qmin, qmax)
quantized = quantized.to(torch.uint8 if self.bits == 8 else torch.int8)
return quantized, scale, zero_point
def dequantize(self, quantized: torch.Tensor, scale: torch.Tensor,
zero_point: torch.Tensor) -> torch.Tensor:
"""反量化"""
return (quantized.float() - zero_point) * scale
def memory_savings(self):
"""计算内存节省"""
original_bits = 16 # FP16
savings = 1 - self.bits / original_bits
return savings
# 使用示例
quantizer = KVCacheQuantizer(bits=4) # 4-bit量化
print(f"内存节省: {quantizer.memory_savings():.1%}") # 75%
4. FlashAttention演进¶
4.1 FlashAttention-1¶
核心创新:分块计算 + 在线Softmax
┌──────────────────────────────────────────────────────────────┐
│ FlashAttention-1原理 │
├──────────────────────────────────────────────────────────────┤
│ │
│ 标准Attention: │
│ 1. 计算 S = QK^T [n×n矩阵,存HBM] │
│ 2. 计算 P = softmax(S) [n×n矩阵,存HBM] │
│ 3. 计算 O = PV [n×n矩阵操作] │
│ │
│ HBM访问: O(n²) - 瓶颈! │
│ │
│ FlashAttention: │
│ 1. 分块: Q, K, V分成小块 [B_r, B_c] │
│ 2. 在SRAM中计算: │
│ - 加载Q块、K块到SRAM │
│ - 计算局部S_ij = Q_i @ K_j^T │
│ - 在线softmax更新 │
│ - 累积O_i │
│ 3. 只写回最终O │
│ │
│ HBM访问: O(n) - 线性! │
│ │
└──────────────────────────────────────────────────────────────┘
在线Softmax算法:
def online_softmax(a: torch.Tensor, new_block: torch.Tensor):
"""
在线Softmax: 增量更新softmax结果
已知: 之前块的max值m_old, sum(exp)值d_old
新块: 需要融合new_block
算法:
1. m_new = max(m_old, max(new_block))
2. d_new = d_old * exp(m_old - m_new) + sum(exp(new_block - m_new))
3. O_new = O_old * (d_old * exp(m_old - m_new) / d_new) +
new_output * (sum(exp(new_block - m_new)) / d_new)
"""
m_old, d_old, O_old = a
m_new_block = new_block.max(dim=-1, keepdim=True).values
# 更新全局max
m_new = torch.maximum(m_old, m_new_block)
# 更新归一化因子
d_new_block = torch.exp(new_block - m_new).sum(dim=-1, keepdim=True)
d_new = d_old * torch.exp(m_old - m_new) + d_new_block
# 更新输出
scale_old = d_old * torch.exp(m_old - m_new) / d_new
scale_new = d_new_block / d_new
O_new = O_old * scale_old + new_block * scale_new
return m_new, d_new, O_new
性能提升:
| 指标 | 标准Attention | FlashAttention-1 |
|---|---|---|
| HBM访问 | O(n²) | O(n) |
| 显存占用 | O(n²) | O(n) |
| 速度 | 1x | 2-4x |
| 支持序列长度 | ~2K | ~16K (同显存) |
4.2 FlashAttention-2¶
主要改进:
FlashAttention-2 相比 v1 的优化:
1. 减少非矩阵乘法操作
- 优化softmax计算
- 减少原子操作
2. 并行化改进
- 序列长度维度并行
- 更好的GPU利用率
3. 支持更多数据类型
- FP16, BF16
- FP8 (Hopper架构)
并行策略:
# FlashAttention-2 的并行策略
"""
FlashAttention-1:
- 在batch和head维度并行
- 长序列时GPU利用率低
FlashAttention-2:
- 额外在seq_len维度并行
- 即使batch=1也能充分利用GPU
示例: seq_len=16K, batch=1
- v1: 只有32个线程块 (32 heads)
- v2: 128个线程块 (额外在seq_len上切分)
"""
性能对比:
| 配置 | FlashAttention-1 | FlashAttention-2 | 提升 |
|---|---|---|---|
| A100, seq=2K | 185 TFLOPS | 215 TFLOPS | 16% |
| A100, seq=8K | 150 TFLOPS | 190 TFLOPS | 27% |
| H100, seq=16K | 210 TFLOPS | 280 TFLOPS | 33% |
4.3 FlashAttention-3¶
针对Hopper架构的优化:
┌──────────────────────────────────────────────────────────────┐
│ FlashAttention-3特性 │
├──────────────────────────────────────────────────────────────┤
│ │
│ 1. 异步计算 │
│ - Tensor Core: GEMM操作 │
│ - CUDA Core: Softmax操作 │
│ - 两者并行执行! │
│ │
│ 2. FP8支持 │
│ - H100原生FP8 Tensor Core │
│ - 比FP16快2倍 │
│ │
│ 3. 低精度优化 │
│ - 块量化技术 │
│ - 动态缩放因子 │
│ │
└──────────────────────────────────────────────────────────────┘
异步流水线:
# FlashAttention-3 异步执行伪代码
def flash_attention_3_async(Q, K, V):
"""
异步流水线执行:
- GEMM (Tensor Core) 和 Softmax (CUDA Core) 并行
"""
# 流水线阶段
for block_i in range(num_blocks):
# 阶段1: Tensor Core计算QK^T (异步)
s_ij = async_gemm(Q[block_i], K[block_j].T) # Tensor Core
# 阶段2: 同时,CUDA Core处理上一个块的softmax
if block_i > 0:
p_prev = softmax(s_prev) # CUDA Core (与上面并行)
o_prev = async_gemm(p_prev, V[block_j-1])
# 保存当前块用于下一轮
s_prev = s_ij
return output
4.4 版本对比总结¶
| 特性 | FlashAttention-1 | FlashAttention-2 | FlashAttention-3 |
|---|---|---|---|
| 发布时间 | 2022 | 2023 | 2024 |
| HBM访问 | O(n) | O(n) | O(n) |
| 并行维度 | batch, head | batch, head, seq | batch, head, seq |
| 数据类型 | FP16, BF16 | FP16, BF16 | FP16, BF16, FP8 |
| 硬件优化 | Ampere | Ampere, Hopper | Hopper专用 |
| 相对速度 | 2-4x | 3-5x | 5-8x |
| 最长序列 | ~16K | ~64K | ~128K+ |
5. 位置编码外推技术¶
5.1 RoPE外推改进¶
RoPE基础回顾:
其中\(m\)是位置,\(\theta\)是频率。
外推问题:
训练时: 位置 m ∈ [0, L_train-1]
推理时: 位置 m ∈ [0, L_infer-1], L_infer > L_train
问题:
1. 高频分量在远距离衰减过快
2. 位置插值导致信息损失
5.2 Position Interpolation (PI)¶
核心思想:将长序列的位置压缩到训练范围内。
def position_interpolation(pos: int, max_train_len: int, max_infer_len: int) -> float:
"""位置插值"""
return pos * max_train_len / max_infer_len
# 示例: 训练4K,推理16K
# 原始位置 16000 -> 插值后 4000
问题:压缩导致分辨率下降,性能损失。
5.3 NTK-aware Interpolation¶
核心洞察:不同频率分量需要不同的缩放。
def ntk_aware_rope(base: float, max_train_len: int, max_infer_len: int) -> float:
"""
NTK-aware位置插值
原理: 高频分量不插值,低频分量插值
通过调整base来实现
"""
scale = max_infer_len / max_train_len
new_base = base * (scale ** (dim / (dim - 2)))
return new_base
# 示例
original_base = 10000
train_len = 4096
infer_len = 32768
new_base = ntk_aware_rope(original_base, train_len, infer_len)
# new_base ≈ 1250000 (大幅增加)
数学解释:
RoPE频率: θ_i = base^(-2i/d)
NTK-aware调整:
- 高频 (i小): 保持原频率,不插值
- 低频 (i大): 频率降低,允许外推
新base计算:
base_new = base_old × scale^(d/(d-2))
5.4 YaRN (Yet another RoPE extension)¶
YaRN = NTK-aware + 温度缩放
def yarn_rope_scaling(
pos: int,
dim: int,
base: float,
max_train_len: int,
max_infer_len: int,
temperature: float = 1.0
):
"""
YaRN: 结合NTK-aware和温度缩放
温度缩放: 调整softmax前的注意力分数
"""
scale = max_infer_len / max_train_len
# NTK-aware base调整
ntk_base = base * (scale ** (dim / (dim - 2)))
# 计算频率
freq = 1.0 / (ntk_base ** (torch.arange(0, dim, 2).float() / dim))
# 应用位置
angles = pos * freq
# 温度缩放
angles = angles / temperature
return angles
# YaRN推荐温度
def get_yarn_temperature(scale: float, dim: int) -> float:
"""计算YaRN温度参数"""
if scale <= 1:
return 1.0
return 1.0 + 0.1 * torch.log(scale) # 经验公式
性能对比:
| 方法 | 4K→8K | 4K→16K | 4K→32K |
|---|---|---|---|
| 直接外推 | 65% | 45% | 25% |
| Position Interpolation | 80% | 70% | 55% |
| NTK-aware | 85% | 78% | 65% |
| YaRN | 90% | 85% | 75% |
5.5 ALiBi (Attention with Linear Biases)¶
核心思想:不用位置编码,而是在注意力分数上添加线性偏置。
其中\(m\)是每个头的斜率,\(|i-j|\)是相对距离。
class ALiBiAttention(nn.Module):
"""ALiBi Attention 实现"""
def __init__(self, num_heads: int):
super().__init__()
self.num_heads = num_heads
# 每个头的斜率: m_i = 1 / 2^(i * 8 / num_heads)
# 或者使用几何序列
slopes = self._get_slopes(num_heads)
self.register_buffer('slopes', slopes)
def _get_slopes(self, num_heads: int) -> torch.Tensor:
"""计算ALiBi斜率"""
# 方法1: 几何序列
n = 2 ** torch.floor(torch.log2(torch.tensor(num_heads)))
m0 = 2 ** (-8 / n)
slopes = m0 ** torch.arange(1, num_heads + 1)
# 如果num_heads不是2的幂,需要额外处理
if num_heads > n:
extra_slopes = 2 ** (-4 / n)
extra = extra_slopes ** torch.arange(1, 2 * (num_heads - n) + 1, 2)
slopes = torch.cat([slopes, extra])
return slopes
def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
"""
Args:
Q: [batch, num_heads, seq_len, head_dim]
K: [batch, num_heads, seq_len, head_dim]
V: [batch, num_heads, seq_len, head_dim]
"""
batch_size, num_heads, seq_len, head_dim = Q.shape
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)
# 添加ALiBi偏置
# 相对位置矩阵: [seq_len, seq_len]
positions = torch.arange(seq_len, device=Q.device)
relative_pos = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
# ALiBi偏置: [num_heads, seq_len, seq_len]
# slopes: [num_heads] -> [num_heads, 1, 1]
alibi_bias = -self.slopes.view(num_heads, 1, 1) * relative_pos.unsqueeze(0)
# 添加到分数
scores = scores + alibi_bias
# Softmax和输出
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output
ALiBi优势:
1. 无需训练位置编码
- 位置信息通过偏置隐式编码
- 减少参数量
2. 天然支持外推
- 线性衰减不依赖于训练长度
- 可以处理任意长度序列
3. 计算高效
- 偏置可以预计算
- 无额外参数
4. 性能稳定
- 长序列性能下降缓慢
- 比RoPE更适合超长上下文
外推能力对比:
| 方法 | 训练长度 | 推理2x | 推理4x | 推理8x |
|---|---|---|---|---|
| RoPE | 2K | 95% | 75% | 50% |
| RoPE + PI | 2K | 92% | 85% | 70% |
| RoPE + YaRN | 2K | 96% | 90% | 80% |
| ALiBi | 2K | 98% | 95% | 90% |
6. 实践指南¶
6.1 技术选型决策树¶
┌──────────────────────────────────────────────────────────────┐
│ 长上下文技术选型 │
├──────────────────────────────────────────────────────────────┤
│ │
│ Q: 需要多长的上下文? │
│ │
│ ├─ < 8K: │
│ │ └─ 标准Transformer + FlashAttention-2 │
│ │ │
│ ├─ 8K - 32K: │
│ │ ├─ 单GPU: │
│ │ │ └─ FlashAttention-2 + GQA + KV Cache量化 │
│ │ └─ 多GPU: │
│ │ └─ Ring Attention │
│ │ │
│ ├─ 32K - 128K: │
│ │ ├─ 训练: Ring Attention + YaRN │
│ │ └─ 推理: vLLM (PagedAttention) + GQA │
│ │ │
│ └─ > 128K (100K - 1M+): │
│ └─ Ring Attention + KV Cache量化 + 分布式推理 │
│ │
└──────────────────────────────────────────────────────────────┘
6.2 不同场景的推荐配置¶
场景1:长文档问答(32K-128K)
# 推荐配置
config = {
'attention': 'FlashAttention-2',
'position_encoding': 'YaRN',
'kv_cache': 'PagedAttention (vLLM)',
'attention_variant': 'GQA (8 groups)',
'kv_quantization': 'FP8',
'max_context': 128000,
}
# vLLM启动命令
# python -m vllm.entrypoints.api_server \
# --model meta-llama/Llama-3.1-70B \
# --max-context-length 128000 \
# --kv-cache-dtype fp8 \
# --tensor-parallel-size 4
场景2:代码分析(100K+)
# 推荐配置
config = {
'attention': 'Ring Attention',
'position_encoding': 'YaRN',
'kv_cache': 'PagedAttention + 量化',
'attention_variant': 'GQA',
'distributed': '张量并行',
'max_context': 200000,
}
# 需要多GPU支持
# ring-attention --model path/to/model \
# --sequence-length 200000 \
# --ring-size 8
场景3:多轮对话(8K-32K)
# 推荐配置
config = {
'attention': 'FlashAttention-2',
'position_encoding': 'RoPE + NTK-aware',
'kv_cache': 'PagedAttention',
'attention_variant': 'GQA',
'continuous_batching': True,
'max_context': 32000,
}
# vLLM配置
# vllm serve model_name \
# --max-model-len 32768 \
# --enable-prefix-caching \
# --gpu-memory-utilization 0.9
6.3 性能对比表¶
| 配置 | 显存 (70B) | 吞吐量 | TTFT (1K) | 最大长度 |
|---|---|---|---|---|
| 基线 (MHA + FP16) | 140GB | 100 tok/s | 2.5s | 4K |
| + FlashAttention-2 | 80GB | 250 tok/s | 1.2s | 16K |
| + GQA | 45GB | 280 tok/s | 1.0s | 32K |
| + KV量化 (INT8) | 30GB | 260 tok/s | 1.1s | 64K |
| + PagedAttention | 25GB | 350 tok/s | 0.8s | 128K |
| Ring (4x A100) | 20GB/GPU | 400 tok/s | 0.6s | 512K |
6.4 常见问题与解决方案¶
Q1: 显存不足怎么办?
# 按优先级尝试:
solutions = [
"1. 启用KV Cache量化 (FP8/INT8)",
"2. 使用GQA替代MHA",
"3. 减小batch_size",
"4. 使用PagedAttention (vLLM)",
"5. 多卡分布式推理",
]
Q2: 长序列推理速度慢?
# 优化步骤:
optimizations = [
"1. 确保使用FlashAttention-2/3",
"2. 启用连续批处理",
"3. 考虑投机采样 (Speculative Decoding)",
"4. 使用张量并行",
]
Q3: 如何评估外推能力?
def evaluate_extrapolation(model, tokenizer, test_lengths):
"""评估模型的外推能力"""
results = {}
for length in test_lengths:
# 构造测试样本
text = "..." * length # 长文本
inputs = tokenizer(text, return_tensors="pt")
# 测试困惑度
with torch.no_grad():
outputs = model(**inputs, labels=inputs.input_ids)
ppl = torch.exp(outputs.loss).item()
results[length] = {
'perplexity': ppl,
'success': ppl < threshold
}
return results
7. 总结¶
7.1 技术关系图¶
┌──────────────────────────────────────────────────────────────┐
│ 长上下文技术栈 │
├──────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 应用层 (100K - 1M tokens) │ │
│ │ 长文档理解 │ 代码分析 │ 多轮对话 │ RAG │ │
│ └─────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 系统层 (分布式 + 内存管理) │ │
│ │ Ring Attention │ PagedAttention │ Continuous Batching│ │
│ └─────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 算法层 (注意力优化) │ │
│ │ FlashAttention-1/2/3 │ MQA/GQA │ KV量化 │ │
│ └─────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 基础层 (位置编码) │ │
│ │ RoPE │ ALiBi │ YaRN │ NTK-aware │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────┘
7.2 关键要点¶
- Ring Attention:分布式处理超长序列的核心技术
- KV Cache优化:GQA + PagedAttention + 量化是标配组合
- FlashAttention:从算法层面解决O(n²)问题,v3是当前最优
- 位置编码外推:YaRN是RoPE的最佳扩展方案,ALiBi天然支持外推
- 实践选型:根据序列长度和资源选择合适的技术组合
8. 参考资料¶
8.1 核心论文¶
- Ring Attention: "Ring Attention with Blockwise Transformers for Near-Infinite Context" (Liu et al., 2023)
- FlashAttention-1: "FlashAttention: Fast and Memory-Efficient Exact Attention" (Dao et al., 2022)
- FlashAttention-2: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (Dao, 2023)
- FlashAttention-3: "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (Dao et al., 2024)
- PagedAttention: "Efficient Memory Management for Large Language Model Serving with PagedAttention" (vLLM, 2023)
- MQA: "Fast Transformer Decoding: One Write-Head is All You Need" (Shazeer, 2019)
- GQA: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023)
- YaRN: "YaRN: Efficient Context Window Extension of Large Language Models" (Peng et al., 2023)
- ALiBi: "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" (Press et al., 2022)
8.2 开源实现¶
- vLLM - PagedAttention实现
- FlashAttention - 官方实现
- Ring Attention - 分布式实现
最后更新日期:2026-02-21 适用版本:LLM学习教程 v2026