第12章 FlashAttention原理与实现¶
前置知识:建议先阅读深度学习教程的 Transformer 章节(理解 Self-Attention 机制)和本教程 第2章 低精度推理(理解 FP16/BF16/FP8 数值格式)。
本章定位:FlashAttention 是 2023-2026 年大模型推理优化领域最重要的算法创新之一,几乎所有主流推理引擎(vLLM、TensorRT-LLM、SGLang)都以其为核心。它也是深度学习系统/推理优化方向面试的 必考内容。
目录¶
- 12.1 标准 Attention 的瓶颈
- 12.2 FlashAttention 核心原理
- 12.3 FlashAttention-2 改进
- 12.4 FlashAttention-3(Hopper 架构)
- 12.5 代码实战
- 12.6 FlashAttention 在推理中的应用
- 12.7 面试高频题
12.1 标准 Attention 的瓶颈¶
12.1.1 Self-Attention 回顾¶
标准 Scaled Dot-Product Attention 的计算公式:
其中 \(Q, K, V \in \mathbb{R}^{N \times d}\),\(N\) 为序列长度,\(d\) 为头维度(head dimension)。
展开计算步骤:
步骤1: S = Q @ K^T # (N, d) × (d, N) → (N, N) — 计算注意力分数
步骤2: P = softmax(S / √d) # (N, N) → (N, N) — 归一化为概率
步骤3: O = P @ V # (N, N) × (N, d) → (N, d) — 加权求和
12.1.2 内存复杂度分析¶
关键瓶颈在于中间矩阵 \(S = QK^T\):
| 对象 | 形状 | 内存(FP16) | N=4096, d=128 时 |
|---|---|---|---|
| Q, K, V | (N, d) | \(O(Nd)\) | 各 1 MB |
| \(S = QK^T\) | (N, N) | \(O(N^2)\) | 32 MB |
| \(P = \text{softmax}(S)\) | (N, N) | \(O(N^2)\) | 32 MB |
| O | (N, d) | \(O(Nd)\) | 1 MB |
当 \(N = 128K\)(长上下文模型常见长度)时:
重要说明:这是标准 Attention 的理论内存需求。FlashAttention 通过分块计算(Tiling)技术,从不创建完整的 \(N \times N\) 注意力矩阵,而是逐块在 SRAM 中计算,从而将内存复杂度从 \(O(N^2)\) 降低到 \(O(N)\)。
这已经超出单块 A100 (80GB) 显存的很大比例,而模型其他参数也要占用显存。\(O(N^2)\) 的内存占用是标准 Attention 无法扩展到长序列的根本原因。
12.1.3 GPU 内存层次:HBM vs SRAM¶
理解 FlashAttention 的前提是理解 GPU 的内存层次结构:
┌─────────────────────────────────────────────┐
│ GPU 芯片 │
│ │
│ ┌───────────────────────────────────────┐ │
│ │ Streaming Multiprocessor │ │
│ │ ┌─────────────────────────────┐ │ │
│ │ │ SRAM (Shared Memory) │ │ │
│ │ │ 容量: ~192 KB per SM │ │ │
│ │ │ 带宽: ~19 TB/s │ │ │
│ │ │ 延迟: ~几个时钟周期 │ │ │
│ │ └─────────────────────────────┘ │ │
│ └───────────────────────────────────────┘ │
│ ↕ 数据搬运是瓶颈 │
│ ┌───────────────────────────────────────┐ │
│ │ HBM (High Bandwidth Memory) │ │
│ │ 容量: 40-80 GB (A100) │ │
│ │ 带宽: ~2 TB/s (A100) │ │
│ │ 延迟: ~数百个时钟周期 │ │
│ └───────────────────────────────────────┘ │
└─────────────────────────────────────────────┘
关键数据(NVIDIA A100):
| 指标 | HBM2e | SRAM (Shared Memory) | 倍数差异 |
|---|---|---|---|
| 容量 | 40/80 GB | ~20 MB (总计) | HBM 大 4000x |
| 带宽 | ~2 TB/s | ~19 TB/s | SRAM 快 ~10x |
| 延迟 | ~数百周期 | ~几十周期 | SRAM 快 ~10x |
12.1.4 标准 Attention 的内存访问模式¶
标准 Attention 内存访问流程:
HBM SRAM 计算单元
┌──────────┐ ┌─────────┐ ┌────────┐
│ Q (Nd) │ ──读取 Q,K──→ │ 计算 │ │ │
│ K (Nd) │ │ S=QK^T │ ──写回 S──→ │ HBM │
│ V (Nd) │ └─────────┘ │ 存储 S │
│ │ │ (N²) │
│ S (N²) ☆│ ←──写回────────── └────────┘
│ │
│ S (N²) │ ──再读取 S──→ ┌─────────┐
│ │ │ 计算 │
│ │ │ softmax │
│ P (N²) ☆│ ←──写回 P────── │ → P │
│ │ └─────────┘
│ │
│ P (N²) │ ──再读取 P,V──→ ┌─────────┐
│ V (Nd) │ │ 计算 │
│ │ │ O = PV │
│ O (Nd) │ ←──写回 O────── └─────────┘
└──────────┘
☆ 标记: 巨大的 N² 中间结果需要反复在 HBM 中读写!
总 HBM 访问量:\(O(Nd + N^2)\),由于 \(N \gg d\),近似为 \(O(N^2)\)。
12.1.5 Memory-Bound vs Compute-Bound¶
面试考点:为什么 Attention 的瓶颈是 memory-bound 而非 compute-bound?
算术强度(Arithmetic Intensity) 分析:
对于标准 Attention: - 计算量:\(O(N^2 d)\) FLOPs(两次矩阵乘法各 \(O(N^2 d)\)) - 内存访问量:\(O(N^2 + Nd)\) bytes(读写中间矩阵 \(S\) 和 \(P\)) - 算术强度 \(\approx O(d)\)
A100 的计算峰值与带宽之比(分界线): $\(\frac{312 \text{ TFLOPS (FP16)}}{2 \text{ TB/s}} = 156 \text{ FLOPs/Byte}\)$
当 \(d = 64\) 或 \(128\) 时,算术强度远低于 156,因此 Attention 是典型的 memory-bound 操作——GPU 计算单元在等待数据从 HBM 搬运过来,大部分时间处于空闲状态。
这就是 FlashAttention 的优化空间:减少 HBM 访问次数,而非减少计算量。FlashAttention 甚至增加了总计算量(recomputation),但通过减少内存访问获得了巨大加速。
12.2 FlashAttention 核心原理¶
FlashAttention(Dao et al., 2022)的核心思想:IO-Aware Algorithm Design——在算法设计时显式考虑 GPU 内存层次结构。
12.2.1 核心策略:Tiling(分块计算)¶
将 \(Q\)、\(K\)、\(V\) 分成小块(tiles),每次只将一小块加载到 SRAM 中计算,避免在 HBM 中创建完整的 \(N \times N\) 矩阵。
分块策略示意图
K₁ K₂ K₃ K₄ V₁ V₂ V₃ V₄
┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐
│ │ │ │ │ │ │ │ │ │
│ │ │ │ │ │ │ │ │ │
└─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘
Q₁ ┌──┐ ┌─────┐ 遍历 K/V 的每个块
│ │ │Sᵢⱼ= │ ← 只在 SRAM 中 不写回 HBM
│ │ │QᵢKⱼᵀ│ 存在的小块
Q₂ │ │ └─────┘
│ │
Q₃ │ │ 每次计算一个 Qᵢ 与所有 Kⱼ 的乘积
│ │ 累加得到输出 Oᵢ
Q₄ └──┘
块大小选择:\(B_r \times B_c\),需要满足 \(B_r \times d + B_c \times d + B_r \times B_c\) 可以放入 SRAM。
12.2.2 Online Softmax 技巧¶
Tiling 的最大挑战是 softmax 需要全局信息:
传统 softmax 需要三次遍历:
Pass 1: m = max(x₁, x₂, ..., xₙ) — 求全局最大值(数值稳定)
Pass 2: l = Σ exp(xᵢ - m) — 求指数和
Pass 3: softmax(xᵢ) = exp(xᵢ - m) / l — 归一化
Online Softmax(Milakov & Gimelshein, 2018)将其压缩为一次遍历。
核心数学推导¶
假设我们已经处理了前 \(j-1\) 个元素,维护: - \(m^{(j-1)}\):前 \(j-1\) 个元素的最大值 - \(l^{(j-1)}\):\(\sum_{i=1}^{j-1} e^{x_i - m^{(j-1)}}\)
当新元素 \(x_j\) 到来时:
关键修正因子:\(e^{m^{(j-1)} - m^{(j)}}\),当新最大值更新时,之前的累积和需要乘以此修正因子来"追溯调整"。
扩展到分块计算(FlashAttention 中的用法)¶
在 FlashAttention 中,我们按块处理 \(K\) 和 \(V\)。对于第 \(i\) 个 Q 块,依次遍历 \(K\) 的第 \(j\) 个块:
初始化: m⁽⁰⁾ = -∞, l⁽⁰⁾ = 0, O⁽⁰⁾ = 0
对于第 j 个 K/V 块:
1. 计算局部注意力分数: Sᵢⱼ = Qᵢ @ Kⱼᵀ / √d
2. 计算局部 max/sum:
mᵢⱼ = rowmax(Sᵢⱼ)
m_new = max(m⁽ʲ⁻¹⁾, mᵢⱼ)
Pᵢⱼ = exp(Sᵢⱼ - m_new)
lᵢⱼ = rowsum(Pᵢⱼ)
l_new = l⁽ʲ⁻¹⁾ · exp(m⁽ʲ⁻¹⁾ - m_new) + lᵢⱼ
3. 更新输出(关键!):
O⁽ʲ⁾ = O⁽ʲ⁻¹⁾ · [exp(m⁽ʲ⁻¹⁾ - m_new) · l⁽ʲ⁻¹⁾ / l_new]
+ Pᵢⱼ @ Vⱼ / l_new
4. 更新统计量: m⁽ʲ⁾ = m_new, l⁽ʲ⁾ = l_new
这就是 FlashAttention 的精髓:通过在线更新 \(m\) 和 \(l\),可以在不存储完整 \(N \times N\) 矩阵的情况下,正确计算 softmax 加权和。
12.2.3 前向传播算法(伪代码)¶
def flash_attention_forward(Q, K, V, B_r, B_c):
"""
Q, K, V: (N, d) 存储在 HBM 中
B_r: Q 的分块大小(行方向)
B_c: K/V 的分块大小(列方向)
"""
N, d = Q.shape
O = zeros(N, d) # 输出,存在 HBM
l = zeros(N) # softmax 分母(每行一个标量)
m = full(N, -inf) # 当前行最大值(每行一个标量)
T_r = ceil(N / B_r) # Q 的块数
T_c = ceil(N / B_c) # K/V 的块数
# --- 外层循环: 遍历 K/V 的块(加载到 SRAM)---
for j in range(T_c):
# 从 HBM 加载 Kⱼ, Vⱼ 到 SRAM
K_j = K[j*B_c : (j+1)*B_c, :] # (B_c, d)
V_j = V[j*B_c : (j+1)*B_c, :] # (B_c, d)
# --- 内层循环: 遍历 Q 的块 ---
for i in range(T_r):
# 从 HBM 加载 Qᵢ, Oᵢ, lᵢ, mᵢ 到 SRAM
Q_i = Q[i*B_r : (i+1)*B_r, :]
O_i = O[i*B_r : (i+1)*B_r, :]
l_i = l[i*B_r : (i+1)*B_r]
m_i = m[i*B_r : (i+1)*B_r]
# Step 1: 计算局部注意力分数(在 SRAM 中)
S_ij = Q_i @ K_j.T / sqrt(d) # (B_r, B_c) — 小矩阵!
# Step 2: Online Softmax 更新
m_ij = rowmax(S_ij) # (B_r,)
m_new = max(m_i, m_ij) # 新的全局最大值
P_ij = exp(S_ij - m_new[:, None]) # (B_r, B_c)
l_new = l_i * exp(m_i - m_new) + rowsum(P_ij)
# Step 3: 更新输出(带修正因子)
O_i = O_i * (l_i * exp(m_i - m_new))[:, None] / l_new[:, None] \
+ P_ij @ V_j / l_new[:, None]
# Step 4: 写回 HBM
m_i = m_new
l_i = l_new
O[i*B_r : (i+1)*B_r, :] = O_i
l[i*B_r : (i+1)*B_r] = l_i
m[i*B_r : (i+1)*B_r] = m_i
return O
12.2.4 反向传播:Recomputation 策略¶
标准 Attention 反向传播需要存储 \(S\) 和 \(P\)(均为 \(N \times N\))。FlashAttention 的 recomputation 策略:
前向传播时只保存: - \(Q, K, V\)(\(O(Nd)\)) - \(O\)(\(O(Nd)\)) - \(l, m\)(各 \(O(N)\))——softmax 统计量
反向传播时: - 重新计算 \(S_{ij}\) 和 \(P_{ij}\)(利用保存的 \(l, m\)) - 按相同的分块模式进行梯度计算 - 增加了约 50% 的 FLOPs,但节省了 \(O(N^2)\) 的内存 → 净效果是端到端加速
内存对比:
标准 Attention:
存储: Q, K, V, S, P, O → O(Nd + N²) — 被 N² 主导
FlashAttention:
存储: Q, K, V, O, l, m → O(Nd) — 线性内存!
12.2.5 复杂度分析¶
| 指标 | 标准 Attention | FlashAttention |
|---|---|---|
| FLOPs | \(O(N^2 d)\) | \(O(N^2 d)\)(前向相同,反向多 ~50%) |
| HBM 内存 | \(O(N^2)\) | \(O(N)\) ✓ |
| HBM 访问次数 | \(O(Nd + N^2)\) | \(O(N^2 d^2 / M)\) ✓ |
其中 \(M\) 为 SRAM 大小。当 \(M = O(Nd)\) 时,HBM 访问量为 \(O(N^2 d^2 / M) = O(Nd)\) 量级,远小于标准 Attention 的 \(O(N^2)\)。
直觉理解:FlashAttention 用额外的计算量(便宜的 SRAM 内计算)换取更少的内存访问(昂贵的 HBM 读写) → 整体更快。
12.3 FlashAttention-2 改进¶
FlashAttention-2(Dao, 2023)在 v1 基础上进一步优化,在 A100 上达到理论峰值 FLOPs 的 72%(v1 约 50%)。
12.3.1 减少非 MatMul 计算¶
FlashAttention-1 中存在大量非矩阵乘法的操作(rescaling、softmax 统计量更新等),这些操作无法利用 Tensor Core,成为性能瓶颈。
FA-2 的优化: - 将 online softmax 的 rescaling 延迟到最后一步 - 前向过程中累加未归一化的 \(O\),最后一次性除以 \(l\) - 减少中间的除法和乘法操作
# FlashAttention-1: 每处理一个 K/V 块都要 rescale
O_i = O_i * (l_old / l_new) + P_ij @ V_j / l_new # 多次除法
# FlashAttention-2: 延迟归一化
O_i = diag(exp(m_old - m_new)) @ O_i + P_ij @ V_j # 只维护未归一化的累加
# ... 最后一步:
O_i = diag(1/l_final) @ O_i # 一次归一化
12.3.2 更好的并行策略¶
FA-1:外层循环遍历 K/V 块,内层循环遍历 Q 块 - 一个 thread block 处理一个 Q 块 - K/V 块之间无法并行(存在依赖)
FA-2:反转循环顺序——外层遍历 Q 块,内层遍历 K/V 块 - 不同 Q 块之间完全独立,可以跨 SM 并行 - 对于 batch_size × num_heads 不够大的情况,可以在序列维度上获得额外并行度
FA-1 并行策略: FA-2 并行策略:
for j in K/V blocks: for i in Q blocks: ← 可并行!
for i in Q blocks: ← 并行 for j in K/V blocks: ← 顺序
compute(Q_i, K_j, V_j) compute(Q_i, K_j, V_j)
12.3.3 Warp 级别优化¶
在一个 thread block 内部,FA-2 对 warp 的分工也做了改进:
FA-1 (split-K):
4 个 warp 各自计算 QK^T 的不同列 → 需要通信来共享结果 → 跨 warp 通信开销
FA-2 (split-Q):
4 个 warp 各自处理 Q 的不同行 → 各自独立累加 → 无需 warp 间同步
最后各 warp 将结果写入不同位置(无冲突)
12.3.4 Causal Masking 优化¶
对于因果注意力(causal attention,decoder 常用),\(S\) 的上三角部分被 mask 为 \(-\infty\):
标准方法:
计算完整 S → 应用 mask → 浪费了上三角的计算
FA-2 的优化:
对于完全在 mask 区域内的块,直接跳过不计算
K₁ K₂ K₃ K₄
Q₁ [计算] [跳过] [跳过] [跳过]
Q₂ [计算] [计算] [跳过] [跳过]
Q₃ [计算] [计算] [计算] [跳过] "跳过" = 完全被 mask 的块
Q₄ [计算] [计算] [计算] [计算]
→ 节省约 50% 的计算(对于因果 attention)
12.3.5 性能对比¶
| 指标 | FA-1 | FA-2 | 提升 |
|---|---|---|---|
| A100 FLOPs 利用率 | ~50% | ~72% | 1.44x |
| 前向速度(seq=2k) | ~124 TFLOPs | ~180 TFLOPs | 1.45x |
| 前向+反向速度 | ~100 TFLOPs | ~155 TFLOPs | 1.55x |
| Causal mask 额外加速 | — | +1.7-1.9x | 显著 |
12.4 FlashAttention-3(Hopper 架构)¶
FlashAttention-3(Shah et al., 2024)专为 NVIDIA Hopper 架构(H100/H200)设计,利用了 Hopper 的新硬件特性。
12.4.1 Hopper 架构新特性¶
| 特性 | Ampere (A100) | Hopper (H100) |
|---|---|---|
| Tensor Core 代数 | 同步 | 异步(WGMMA) |
| 内存搬运 | 同步共享内存加载 | 异步 TMA(Tensor Memory Accelerator) |
| FP8 Tensor Core | 不支持 | 支持(1978 TFLOPS) |
| SRAM 容量 | 192 KB/SM | 228 KB/SM |
12.4.2 三大核心优化¶
1. 异步 Warp Specialization¶
将一个 thread block 内的 warp 分为两类:
┌─────────────────────────────────┐
│ Thread Block │
│ │
│ Producer Warps Consumer Warps│
│ ┌────────────┐ ┌────────────┐│
│ │ 负责数据搬运 │ │ 负责矩阵计算 ││
│ │ GMEM→SMEM │ │ WGMMA指令 ││
│ │ (TMA异步) │ │ (Tensor Core)││
│ └────────────┘ └────────────┘│
│ ↓ 流水线 ↓ │
│ 当 consumer 计算第 j 块时, │
│ producer 已在加载第 j+1 块 │
└─────────────────────────────────┘
数据搬运和计算 重叠执行,隐藏内存延迟。
2. FP8 支持与混合精度¶
GEMM (矩阵乘): FP8 × FP8 → FP32 累加器
→ H100 FP8: 1978 TFLOPS (是 FP16 的 2x)
Softmax / 统计量: 保持 FP32
→ 确保数值精度
策略: Incoherent Processing(块级量化 + 随机舍入)
→ 进一步提升 FP8 精度
3. 低精度 GEMM + 高精度 Softmax 流水线¶
利用 Hopper 的异步 WGMMA 指令,实现 GEMM 和 softmax 的指令级重叠:
时间 →
WGMMA (FP8): |──GEMM block j──|──GEMM block j+1──|──GEMM block j+2──|
Softmax: |──soft j──| |──soft j+1──|
TMA load: |──load j+1──| |──load j+2──|
12.4.3 性能数据¶
| 配置 | FA-2 (A100) | FA-3 (H100 FP16) | FA-3 (H100 FP8) |
|---|---|---|---|
| 前向速度 | ~180 TF/s | ~620 TF/s | ~1200 TF/s |
| 相比 FA-2 加速 | 1x | ~1.5-2x | ~3-4x |
FlashAttention-3 在 H100 上接近硬件理论峰值的 75%。
12.5 代码实战¶
12.5.1 使用 flash-attn 库(最简单方式)¶
"""
使用 flash-attn 库的 API
安装: pip install flash-attn --no-build-isolation
要求: NVIDIA GPU (Ampere/Hopper), CUDA 11.6+, PyTorch 1.12+
> **编译环境要求**:flash-attn 需要从源码编译,需要:
> - CUDA 编译器 (nvcc) 在 PATH 中
> - 与 PyTorch 匹配的 CUDA 版本
> - C++ 编译器 (gcc/g++ 或 MSVC)
> - 充足的编译内存(建议 16GB+ RAM)
>
> 如遇编译问题,可尝试使用预编译 wheel:`pip install flash-attn --no-build-isolation` 或从 GitHub Releases 下载对应版本。
"""
import torch
from flash_attn import flash_attn_func
# 创建输入 (batch, seqlen, nheads, headdim) — 注意维度顺序!
batch_size, seq_len, n_heads, head_dim = 2, 4096, 32, 128
dtype = torch.float16
device = "cuda"
q = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype, device=device)
v = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype, device=device)
# 前向传播
output = flash_attn_func(
q, k, v,
dropout_p=0.0, # 推理时设为 0
softmax_scale=None, # 默认 1/sqrt(head_dim)
causal=True, # 因果 mask(decoder 模型使用)
)
# output shape: (batch, seqlen, nheads, headdim)
print(f"Output shape: {output.shape}") # (2, 4096, 32, 128)
# 支持变长序列(packed/varlen)
from flash_attn import flash_attn_varlen_func
# 适用于 batch 内不同样本长度不同的情况,避免 padding 浪费
12.5.2 使用 PyTorch 原生 API(torch 2.0+)¶
"""
PyTorch 2.0+ 内置 FlashAttention 后端
通过 torch.nn.functional.scaled_dot_product_attention 自动调用
"""
import torch
import torch.nn.functional as F
batch_size, n_heads, seq_len, head_dim = 2, 32, 4096, 128
dtype = torch.float16
device = "cuda"
# 注意: PyTorch API 的维度顺序是 (batch, nheads, seqlen, headdim)
q = torch.randn(batch_size, n_heads, seq_len, head_dim, dtype=dtype, device=device)
k = torch.randn(batch_size, n_heads, seq_len, head_dim, dtype=dtype, device=device)
v = torch.randn(batch_size, n_heads, seq_len, head_dim, dtype=dtype, device=device)
# 方式1: 自动选择最优后端(推荐)
with torch.backends.cuda.sdp_kernel(
enable_flash=True, # 启用 FlashAttention 后端
enable_math=False, # 禁用数学后端(标准实现)
enable_mem_efficient=False # 禁用 memory-efficient 后端
):
output = F.scaled_dot_product_attention( # F.xxx PyTorch函数式API
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=True, # 因果 mask
scale=None, # 默认 1/sqrt(head_dim)
)
print(f"Output shape: {output.shape}") # (2, 32, 4096, 128)
# 方式2: 让 PyTorch 自动选择后端(最简单)
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# 查看当前使用了哪个后端
# 可通过 torch.backends.cuda.flash_sdp_enabled() 检查
print(f"FlashSDP enabled: {torch.backends.cuda.flash_sdp_enabled()}")
12.5.3 Triton 实现简化版 FlashAttention¶
这是面试和深入理解的核心内容——用 Triton 实现 FlashAttention 的分块计算逻辑:
"""
简化版 FlashAttention Forward — Triton 实现
核心展示 Tiling + Online Softmax 的逻辑
参考: Triton 官方教程 + FlashAttention 论文算法1
注意: 这是教学用简化版本,省略了 causal mask、dropout 等特性
"""
import torch
import triton
import triton.language as tl
@triton.jit
def _flash_attn_fwd_kernel(
Q_ptr, K_ptr, V_ptr, O_ptr,
stride_qb, stride_qh, stride_qm, stride_qd, # Q 的 strides
stride_kb, stride_kh, stride_kn, stride_kd, # K 的 strides
stride_vb, stride_vh, stride_vn, stride_vd, # V 的 strides
stride_ob, stride_oh, stride_om, stride_od, # O 的 strides
N: tl.constexpr, # 序列长度
D: tl.constexpr, # 头维度
BLOCK_M: tl.constexpr, # Q 的分块大小
BLOCK_N: tl.constexpr, # K/V 的分块大小
sm_scale, # softmax 缩放因子 = 1/sqrt(D)
):
# ---- 确定当前 block 负责的区域 ----
block_m_idx = tl.program_id(0) # Q 的块索引
batch_head_idx = tl.program_id(1) # batch * num_heads
# 从 batch_head_idx 解析 batch 和 head
# (此处简化,假设 batch_head_idx 直接用于偏移)
# ---- 计算基址偏移 ----
q_offset = batch_head_idx * stride_qh
k_offset = batch_head_idx * stride_kh
v_offset = batch_head_idx * stride_vh
o_offset = batch_head_idx * stride_oh
# ---- 加载 Q 块到 SRAM ----
offs_m = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) # 行偏移
offs_d = tl.arange(0, D) # 列偏移
# Q_i: (BLOCK_M, D)
q_ptrs = Q_ptr + q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
q_mask = offs_m[:, None] < N
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
# ---- 初始化 Online Softmax 状态 ----
m_i = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32) # 行最大值
l_i = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) # 行指数和
o_i = tl.zeros([BLOCK_M, D], dtype=tl.float32) # 累加输出
# ---- 内层循环: 遍历 K/V 的所有块 ----
for block_n_start in range(0, N, BLOCK_N):
offs_n = block_n_start + tl.arange(0, BLOCK_N)
# 加载 K_j: (BLOCK_N, D)
k_ptrs = K_ptr + k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
k_mask = offs_n[:, None] < N
k = tl.load(k_ptrs, mask=k_mask, other=0.0)
# 加载 V_j: (BLOCK_N, D)
v_ptrs = V_ptr + v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
v = tl.load(v_ptrs, mask=k_mask, other=0.0)
# Step 1: 计算 S_ij = Q_i @ K_j^T — (BLOCK_M, BLOCK_N)
s_ij = tl.dot(q, tl.trans(k)) * sm_scale
# 对超出序列长度的位置 mask
s_ij = tl.where(offs_n[None, :] < N, s_ij, float("-inf"))
# Step 2: Online Softmax 更新
m_ij = tl.max(s_ij, axis=1) # 当前块的行最大值 (BLOCK_M,)
m_new = tl.maximum(m_i, m_ij) # 全局行最大值更新
# 修正因子
alpha = tl.exp(m_i - m_new) # 旧最大值的修正
p_ij = tl.exp(s_ij - m_new[:, None]) # 当前块的 exp 值 (BLOCK_M, BLOCK_N)
l_new = l_i * alpha + tl.sum(p_ij, axis=1)
# Step 3: 更新输出累加器
# 先对旧的 O 做修正(旧 max 变化带来的 rescale)
o_i = o_i * alpha[:, None]
# 加上当前块的贡献
o_i += tl.dot(p_ij.to(v.dtype), v)
# Step 4: 更新统计量
m_i = m_new
l_i = l_new
# ---- 最终归一化 ----
o_i = o_i / l_i[:, None]
# ---- 写回 HBM ----
o_ptrs = O_ptr + o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
o_mask = offs_m[:, None] < N
tl.store(o_ptrs, o_i.to(O_ptr.dtype.element_ty), mask=o_mask)
def flash_attention_triton(q, k, v):
"""
简化版 FlashAttention 包装函数
q, k, v: (batch, nheads, seqlen, headdim) — float16/bfloat16
"""
B, H, N, D = q.shape
assert D in {16, 32, 64, 128}, "headdim 必须是 16/32/64/128" # assert断言
o = torch.empty_like(q)
sm_scale = 1.0 / (D ** 0.5)
BLOCK_M = 128
BLOCK_N = 64
grid = (triton.cdiv(N, BLOCK_M), B * H)
_flash_attn_fwd_kernel[grid](
q, k, v, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
N=N, D=D,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
sm_scale=sm_scale,
)
return o
# ---- 测试 ----
if __name__ == "__main__":
torch.manual_seed(42)
B, H, N, D = 2, 4, 1024, 64
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
# FlashAttention (Triton 实现)
o_flash = flash_attention_triton(q, k, v)
# 标准 Attention (参考实现)
scale = 1.0 / (D ** 0.5)
attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
o_ref = attn @ v
# 验证正确性
max_diff = (o_flash.float() - o_ref.float()).abs().max().item() # 将单元素张量转为Python数值
print(f"Max absolute difference: {max_diff:.6f}")
assert max_diff < 1e-2, f"差异过大: {max_diff}"
print("✓ Triton FlashAttention 结果正确!")
12.5.4 性能对比基准测试¶
"""
标准 Attention vs FlashAttention 性能对比
测量不同序列长度下的速度和显存使用
"""
import torch
import torch.nn.functional as F
import time
def benchmark_attention(func, q, k, v, name, warmup=10, repeat=50):
"""通用 benchmark 函数"""
# Warmup
for _ in range(warmup):
_ = func(q, k, v)
torch.cuda.synchronize()
# 记录显存
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
# 计时
start = time.perf_counter()
for _ in range(repeat):
_ = func(q, k, v)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / repeat * 1000 # ms
mem_peak = torch.cuda.max_memory_allocated()
mem_used = (mem_peak - mem_before) / 1024**2 # MB
print(f" {name:30s} | Time: {elapsed:8.2f} ms | Peak Mem: {mem_used:8.1f} MB")
return elapsed, mem_used
def standard_attention(q, k, v):
"""标准 Attention(显式计算 N×N 矩阵)"""
scale = 1.0 / (q.shape[-1] ** 0.5) # [-1]负索引取最后元素
attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
return attn @ v
def flash_attention_pytorch(q, k, v):
"""PyTorch SDPA(自动选择 FlashAttention 后端)"""
return F.scaled_dot_product_attention(q, k, v, is_causal=False)
print("=" * 80)
print("Attention Performance Benchmark")
print("=" * 80)
B, H, D = 2, 32, 128
for N in [512, 1024, 2048, 4096, 8192]:
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
print(f"\n--- Seq Length = {N} ---")
try: # try/except捕获异常
t1, m1 = benchmark_attention(standard_attention, q, k, v, "Standard Attention")
except torch.cuda.OutOfMemoryError:
print(f" {'Standard Attention':30s} | OOM!")
t1, m1 = float("inf"), float("inf")
t2, m2 = benchmark_attention(flash_attention_pytorch, q, k, v, "FlashAttention (PyTorch SDPA)")
if t1 != float("inf"):
print(f" Speedup: {t1/t2:.2f}x | Memory Saving: {m1/m2:.2f}x")
print("\n" + "=" * 80)
预期输出趋势:
--- Seq Length = 512 ---
Standard Attention | Time: 0.35 ms | Peak Mem: 48.0 MB
FlashAttention (PyTorch SDPA) | Time: 0.25 ms | Peak Mem: 16.0 MB
Speedup: 1.40x | Memory Saving: 3.00x
--- Seq Length = 4096 ---
Standard Attention | Time: 12.50 ms | Peak Mem: 2048.0 MB
FlashAttention (PyTorch SDPA) | Time: 2.80 ms | Peak Mem: 64.0 MB
Speedup: 4.46x | Memory Saving: 32.00x
--- Seq Length = 8192 ---
Standard Attention | OOM!
FlashAttention (PyTorch SDPA) | Time: 10.50 ms | Peak Mem: 128.0 MB
关键观察: 1. 序列越长,FlashAttention 的优势越明显(\(O(N^2)\) vs \(O(N)\) 内存) 2. 在短序列上,FlashAttention 也有优势(减少 HBM 访问) 3. 标准 Attention 在长序列时可能 OOM,FlashAttention 不会
12.6 FlashAttention 在推理中的应用¶
12.6.1 与 KV-Cache 的结合¶
LLM 自回归推理分为两个阶段:
Prefill 阶段(处理 prompt):
Q: (1, prompt_len, d) ← 所有 token 一起处理
K: (1, prompt_len, d)
V: (1, prompt_len, d)
→ 标准 FlashAttention(compute-bound,批量矩阵乘)
Decode 阶段(逐 token 生成):
Q: (1, 1, d) ← 只有 1 个新 token
K: (1, seq_len, d) ← 包含所有历史 token(从 KV-Cache 读取)
V: (1, seq_len, d)
→ 特殊优化需求(memory-bound,Q 太小导致计算强度低)
KV-Cache 的作用:避免 decode 阶段重复计算历史 token 的 K、V。
KV-Cache 工作流:
Step 1: Prefill
输入: "The cat sat on the"
计算: K₁..₅, V₁..₅ → 存入 KV-Cache
Step 2: Decode token "mat"
Q₆ = Embed("mat") @ W_Q
K-Cache: [K₁, K₂, K₃, K₄, K₅] → append K₆
V-Cache: [V₁, V₂, V₃, V₄, V₅] → append V₆
Attention: Q₆ @ [K₁..₆]^T → softmax → @ [V₁..₆]
Step 3: Decode token "is"
Q₇ = Embed("is") @ W_Q
K-Cache: [K₁..₆] → append K₇ (不重新计算K₁..₆)
...
12.6.2 PagedAttention(vLLM)¶
问题:KV-Cache 需要预分配连续内存,导致大量内存碎片和浪费。
PagedAttention 的解决方案:借鉴操作系统的虚拟内存/分页机制:
传统 KV-Cache:
┌──────────────────────────────────────────────┐
│ Request 1 KV-Cache │ 碎片 │ Request 2 KV │ 碎片 │ ...
└──────────────────────────────────────────────┘
预分配最大长度 → 实际使用 30% → 70% 浪费
PagedAttention:
Page Table: Physical Pages:
Req 1: [P3, P7, P1, ...] ┌────┐ ┌────┐ ┌────┐
Req 2: [P2, P5, P8, ...] │ P1 │ │ P2 │ │ P3 │ ...
└────┘ └────┘ └────┘
每个 page 存固定数量 token 的 KV
按需分配 → 接近 0% 浪费
支持 copy-on-write → beam search 共享 KV
核心优势: - 内存利用率从 ~30% 提升到 ~95%+ - 支持更大的 batch size → 更高的吞吐量 - 结合 FlashAttention kernel 实现高效的分页 attention 计算
12.6.3 FlashDecoding¶
问题:Decode 阶段 \(Q\) 只有 1 行,标准 FlashAttention 无法充分利用 GPU 并行度。
Prefill: Q=(1, 2048, d) → 2048 / BLOCK_M = 16 个并行块 ✓
Decode: Q=(1, 1, d) → 1 / BLOCK_M = 1 个并行块 ✗ (GPU 利用率极低)
FlashDecoding(Dao et al., 2023)的策略:在 KV 序列维度上并行:
标准 FlashAttention (Decode):
1 个 block 顺序遍历所有 K/V 块 → 慢
FlashDecoding:
Step 1: Split-K — 多个 block 各处理 K/V 的一部分
Block 0: K[0:1024], V[0:1024] → (partial_O₀, partial_l₀, partial_m₀)
Block 1: K[1024:2048], V[1024:2048] → (partial_O₁, partial_l₁, partial_m₁)
Block 2: K[2048:3072], V[2048:3072] → (partial_O₂, partial_l₂, partial_m₂)
...
→ 所有 block 并行执行!
Step 2: Reduce — 一个小 kernel 合并所有部分结果
使用 Online Softmax 的合并公式:
m_global = max(m₀, m₁, m₂, ...)
l_global = Σ lᵢ · exp(mᵢ - m_global)
O = Σ Oᵢ · exp(mᵢ - m_global) / l_global
12.6.4 MHA / MQA / GQA 与 FlashAttention¶
多头注意力的变体对推理效率影响巨大:
Multi-Head Attention (MHA):
Q: H 个头, K: H 个头, V: H 个头
KV-Cache 大小: 2 × L × H × d × sizeof(dtype)
示例 (LLaMA-70B): 2 × 80 × 64 × 128 × 2B = 2.5 GB / token
Multi-Query Attention (MQA):
Q: H 个头, K: 1 个头, V: 1 个头 ← K/V 所有头共享
KV-Cache 缩小到 1/H
缺点: 质量可能下降
Grouped-Query Attention (GQA):
Q: H 个头, K: G 个头, V: G 个头 ← K/V 分 G 组共享
GQA-8: 8 个 KV 头, 每个服务 H/8 个 Q 头
平衡质量与效率(LLaMA-2-70B, Gemma 等采用)
MHA GQA (G=2) MQA
Q heads: Q₁ Q₂ Q₃ Q₄ Q₁ Q₂ Q₃ Q₄ Q₁ Q₂ Q₃ Q₄
↕ ↕ ↕ ↕ ↕ ↕ ↕ ↕ ↕ ↕ ↕ ↕
KV heads: K₁ K₂ K₃ K₄ K₁ K₁ K₂ K₂ K₁ K₁ K₁ K₁
V₁ V₂ V₃ V₄ V₁ V₁ V₂ V₂ V₁ V₁ V₁ V₁
KV-Cache: 4 份 KV 2 份 KV 1 份 KV
FlashAttention 对 GQA/MQA 的支持: - flash-attn 库原生支持 GQA/MQA(flash_attn_func 自动处理 Q 和 KV 头数不同的情况) - KV-Cache 减小 → 更少的 HBM 访问 → decode 更快 - GQA/MQA + FlashAttention + KV-Cache 是 2024-2026 年 LLM 推理的标准配置
12.7 面试高频题¶
题目1:为什么 FlashAttention 快?从 IO 角度解释¶
参考答案:
FlashAttention 快的原因不是减少了计算量(FLOPs 实际上略有增加),而是大幅减少了 GPU HBM 的读写次数。
标准 Attention 需要将 \(O(N^2)\) 大小的中间矩阵 \(S = QK^T\) 和 \(P = \text{softmax}(S)\) 写入 HBM 再读出来。由于 Attention 是 memory-bound 操作(算术强度低于 GPU 的计算带宽比),HBM 访问是性能瓶颈。
FlashAttention 通过分块(Tiling)策略,将 Q/K/V 的小块加载到 SRAM 中完成所有计算(包括 softmax),中间结果 \(S_{ij}\) 从不写回 HBM。这将 HBM 访问量从 \(O(Nd + N^2)\) 降低到 \(O(N^2 d^2 / M)\)(\(M\) 为 SRAM 大小),在实际参数下减少了数倍到数十倍的 HBM 访问。
SRAM 带宽约为 HBM 的 10 倍,因此将计算放在 SRAM 中完成可以大幅提高有效带宽利用率。
题目2:Online Softmax 是如何工作的?为什么它对 FlashAttention 至关重要?¶
参考答案:
传统的数值稳定 softmax 需要三次遍历数据: 1. 找到全局最大值 \(m = \max(x_i)\) 2. 计算指数和 \(l = \sum e^{x_i - m}\) 3. 归一化 \(\text{softmax}(x_i) = e^{x_i - m} / l\)
这要求能看到整行的所有数据,与 FlashAttention 的分块策略矛盾(每次只看一个 K 块)。
Online Softmax 解决了这个问题:维护两个滑动统计量 \(m\)(当前最大值)和 \(l\)(当前指数和),每处理一个新块时:
关键修正因子 \(e^{m_{\text{old}} - m_{\text{new}}}\) 确保了当最大值更新时,之前的累积和被正确调整。同样的修正也应用于输出累加器 \(O\)。
这使得 FlashAttention 可以一块一块地处理 K/V,无需看到全部数据就能正确计算 softmax。
题目3:FlashAttention 如何处理反向传播?什么是 Recomputation?¶
参考答案:
标准 Attention 的反向传播需要用到前向传播中的中间矩阵 \(S\)(注意力分数)和 \(P\)(softmax 概率),这两个矩阵都是 \(O(N^2)\) 大小。
FlashAttention 的 Recomputation 策略:前向传播时不存储 \(S\) 和 \(P\),只保存 \(Q, K, V, O\)(均为 \(O(Nd)\))以及 softmax 统计量 \(l, m\)(均为 \(O(N)\))。
反向传播时,使用保存的 \(Q, K, V, l, m\),按照相同的分块模式重新计算 \(S_{ij}\) 和 \(P_{ij}\),然后计算梯度。
这增加了约 50% 的 FLOPs(重新计算的代价),但节省了 \(O(N^2)\) → \(O(N)\) 的内存。由于 Attention 是 memory-bound,减少内存使用带来的好处(更大的 batch size、更少的 HBM 访问)远超额外计算的开销,端到端训练速度反而更快。
题目4:FlashAttention-1 和 FlashAttention-2 的主要区别是什么?¶
参考答案:
FlashAttention-2 的三大改进:
-
减少非 MatMul 计算:FA-1 在每个 K/V 块处理后都要做 rescaling(除以 \(l\)),这些标量运算无法利用 Tensor Core。FA-2 将归一化延迟到最后一步,累加未归一化的 \(O\),最后一次性归一化。
-
反转循环顺序:FA-1 外层遍历 K/V 块、内层遍历 Q 块;FA-2 反过来,外层遍历 Q 块。这使得不同 Q 块可以完全独立地并行执行,增加了序列维度的并行度。
-
Warp 级别优化:FA-1 使用 split-K(warp 间共享 K),需要跨 warp 通信;FA-2 使用 split-Q(各 warp 处理不同的 Q 行),消除了 warp 间同步开销。
额外地,FA-2 对 causal masking 做了特殊优化:完全位于 mask 区域内的块直接跳过,因果 Attention 节省约 50% 计算。
结果:A100 上 FLOPs 利用率从约 50% 提升到约 72%,速度提升 1.5-1.7 倍。
题目5:FlashAttention 的局限性有哪些?¶
参考答案:
-
硬件限制:需要 NVIDIA GPU(Ampere 或更新架构),不支持 CPU 或其他加速器(社区有部分移植如 AMD ROCm 的支持)。
-
头维度限制:最初版本仅支持 \(d \leq 128\)(FA-2 扩展到 \(d = 256\)),对于非标准头维度的模型可能不适用。
-
自定义 Attention Pattern 困难:对于非标准的 attention mask(如 sliding window、dilated attention、自定义 sparse pattern),需要专门的 kernel 实现,通用性有限。
-
反向传播开销:Recomputation 增加了约 50% 的反向 FLOPs。对于极度 compute-bound 的场景,这可能不划算(但实际中 Attention 几乎总是 memory-bound)。
-
实现复杂性:底层 CUDA/Triton kernel 编写和调试难度高,不容易自定义扩展。
-
短序列收益有限:当序列很短时(\(N < 256\)),标准 Attention 的 \(N^2\) 矩阵可以放入 SRAM,FlashAttention 的 tiling 开销反而可能增加 overhead。
题目6:MQA 和 GQA 的作用是什么?与 FlashAttention 的关系?¶
参考答案:
Multi-Query Attention (MQA) 让所有 Q 头共享一组 K/V,将 KV-Cache 缩小到 \(1/H\)。Grouped-Query Attention (GQA) 是折中方案,\(G\) 组 KV 头各服务 \(H/G\) 个 Q 头。
核心作用: 1. 减少 KV-Cache 大小:对于 70B 级别模型,MHA 的 KV-Cache 可达 2.5GB/token,GQA-8 缩减到约 300MB/token 2. 减少 decode 阶段的内存带宽需求:KV-Cache 是 decode 阶段的主要 IO 瓶颈 3. 增大可服务的 batch size:显存省出来可以放更多请求
与 FlashAttention 的关系是互补的: - FlashAttention 优化了 Attention 计算本身的 IO 效率 - GQA/MQA 优化了 KV-Cache 的存储和 IO - 两者结合是当前 LLM 推理的标准实践
题目7:PagedAttention 解决什么问题?它是如何工作的?¶
参考答案:
问题:传统 KV-Cache 管理需要为每个请求预分配定长连续内存(通常按最大可能长度分配)。由于实际生成长度远小于最大值,导致 60-80% 的显存浪费(内部碎片)。此外,不同请求长度不同,显存无法灵活复用(外部碎片)。
PagedAttention 的解决方案(借鉴 OS 虚拟内存机制):
- 将 KV-Cache 分成固定大小的 页(page),每个 page 存储固定数量 token 的 KV 向量
- 维护一个 页表(page table) 记录每个请求的逻辑页到物理页的映射
- 物理页在 GPU 显存中非连续存储,按需动态分配
关键好处: - 内存利用率从 30% 提升到 95%+ - 支持 copy-on-write:beam search 中多个候选可以共享前缀的 KV-Cache 页 - 更大的 batch size → 更高的吞吐量
vLLM 是 PagedAttention 的代表性实现,它通过定制化的 CUDA kernel 支持在非连续页上执行高效的 Attention 计算。
题目8:FlashDecoding 是什么?解决了什么问题?¶
参考答案:
问题:在自回归推理的 decode 阶段,每次只生成一个新 token,因此 \(Q\) 只有一行。标准 FlashAttention 在 Q 维度上并行(每个 block 处理若干 Q 行),单行 Q 意味着只有 \(\text{batch} \times \text{num\_heads}\) 个并行块,在 batch 较小或高端 GPU(有大量 SM)上无法充分利用硬件。
FlashDecoding 的方案——在 KV 序列维度上额外并行:
- Split-K 阶段:将 KV 序列分成若干段,每段由一个独立的 block 并行处理,计算出局部结果 \((O_{\text{partial}}, l_{\text{partial}}, m_{\text{partial}})\)
- Reduce 阶段:一个轻量 kernel 使用 Online Softmax 合并公式将所有局部结果合并为最终输出
这样并行块数变为 \(\text{batch} \times \text{num\_heads} \times \text{num\_splits}\),显著提高了 GPU 利用率。在 batch=1 的场景下可以获得 3-5 倍的加速。
总结与知识图谱¶
FlashAttention 知识体系:
标准 Attention 瓶颈
├── O(N²) 内存 → 长序列 OOM
├── Memory-bound(算术强度 < GPU 分界线)
└── 频繁 HBM 读写中间矩阵
FlashAttention 核心
├── IO-Aware Algorithm Design
├── Tiling: Q/K/V 分块加载到 SRAM
├── Online Softmax: 增量计算 softmax
├── Recomputation: 不存 N² 矩阵,反向重算
└── 复杂度: O(N²) 内存 → O(N)
版本演进
├── FA-1: 基础 Tiling + Online Softmax
├── FA-2: 减少非 MatMul 操作 + 循环反转 + warp 优化
└── FA-3: FP8 + 异步 Tensor Core + warp specialization (H100)
推理应用
├── KV-Cache: 避免重复计算
├── PagedAttention: 分页管理 KV-Cache
├── FlashDecoding: decode 阶段 KV 维度并行
└── GQA/MQA: 减少 KV-Cache 大小
参考文献¶
- Dao, T., et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022.
- Dao, T. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024.
- Shah, J., et al. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." 2024.
- Milakov, M. & Gimelshein, N. "Online normalizer calculation for softmax." arXiv 2018.
- Kwon, W., et al. "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023.
- Dao, T., et al. "FlashDecoding." Blog post, 2023.
- Ainslie, J., et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023.
下一章预告:第13章 推测解码与推理加速 将介绍投机解码技术——通过小模型"猜测"大模型输出来加速推理,是与 FlashAttention 互补的另一项重要推理加速技术。