跳转至

第12章 FlashAttention原理与实现

前置知识:建议先阅读深度学习教程的 Transformer 章节(理解 Self-Attention 机制)和本教程 第2章 低精度推理(理解 FP16/BF16/FP8 数值格式)。

本章定位:FlashAttention 是 2023-2026 年大模型推理优化领域最重要的算法创新之一,几乎所有主流推理引擎(vLLM、TensorRT-LLM、SGLang)都以其为核心。它也是深度学习系统/推理优化方向面试的 必考内容


目录


12.1 标准 Attention 的瓶颈

12.1.1 Self-Attention 回顾

标准 Scaled Dot-Product Attention 的计算公式:

\[\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]

其中 \(Q, K, V \in \mathbb{R}^{N \times d}\)\(N\) 为序列长度,\(d\) 为头维度(head dimension)。

展开计算步骤:

Text Only
步骤1: S = Q @ K^T          # (N, d) × (d, N) → (N, N)   — 计算注意力分数
步骤2: P = softmax(S / √d)  # (N, N) → (N, N)            — 归一化为概率
步骤3: O = P @ V            # (N, N) × (N, d) → (N, d)   — 加权求和

12.1.2 内存复杂度分析

关键瓶颈在于中间矩阵 \(S = QK^T\)

对象 形状 内存(FP16) N=4096, d=128 时
Q, K, V (N, d) \(O(Nd)\) 各 1 MB
\(S = QK^T\) (N, N) \(O(N^2)\) 32 MB
\(P = \text{softmax}(S)\) (N, N) \(O(N^2)\) 32 MB
O (N, d) \(O(Nd)\) 1 MB

\(N = 128K\)(长上下文模型常见长度)时:

\[S \text{ 的内存} = 128K \times 128K \times 2\text{B} = 32\text{ GB}\]

重要说明:这是标准 Attention 的理论内存需求。FlashAttention 通过分块计算(Tiling)技术,从不创建完整的 \(N \times N\) 注意力矩阵,而是逐块在 SRAM 中计算,从而将内存复杂度从 \(O(N^2)\) 降低到 \(O(N)\)

这已经超出单块 A100 (80GB) 显存的很大比例,而模型其他参数也要占用显存。\(O(N^2)\) 的内存占用是标准 Attention 无法扩展到长序列的根本原因。

12.1.3 GPU 内存层次:HBM vs SRAM

理解 FlashAttention 的前提是理解 GPU 的内存层次结构:

Text Only
┌─────────────────────────────────────────────┐
│                  GPU 芯片                     │
│                                               │
│   ┌───────────────────────────────────────┐   │
│   │          Streaming Multiprocessor      │   │
│   │   ┌─────────────────────────────┐     │   │
│   │   │   SRAM (Shared Memory)      │     │   │
│   │   │   容量: ~192 KB per SM      │     │   │
│   │   │   带宽: ~19 TB/s            │     │   │
│   │   │   延迟: ~几个时钟周期        │     │   │
│   │   └─────────────────────────────┘     │   │
│   └───────────────────────────────────────┘   │
│                    ↕ 数据搬运是瓶颈              │
│   ┌───────────────────────────────────────┐   │
│   │         HBM (High Bandwidth Memory)    │   │
│   │   容量: 40-80 GB (A100)               │   │
│   │   带宽: ~2 TB/s (A100)                │   │
│   │   延迟: ~数百个时钟周期                │   │
│   └───────────────────────────────────────┘   │
└─────────────────────────────────────────────┘

关键数据(NVIDIA A100)

指标 HBM2e SRAM (Shared Memory) 倍数差异
容量 40/80 GB ~20 MB (总计) HBM 大 4000x
带宽 ~2 TB/s ~19 TB/s SRAM 快 ~10x
延迟 ~数百周期 ~几十周期 SRAM 快 ~10x

12.1.4 标准 Attention 的内存访问模式

Text Only
标准 Attention 内存访问流程:

HBM                              SRAM                    计算单元
┌──────────┐                   ┌─────────┐              ┌────────┐
│ Q (Nd)   │ ──读取 Q,K──→    │ 计算     │              │        │
│ K (Nd)   │                   │ S=QK^T   │ ──写回 S──→  │ HBM    │
│ V (Nd)   │                   └─────────┘              │ 存储 S │
│          │                                             │ (N²)   │
│ S (N²) ☆│ ←──写回──────────                           └────────┘
│          │
│ S (N²)   │ ──再读取 S──→    ┌─────────┐
│          │                   │ 计算     │
│          │                   │ softmax  │
│ P (N²) ☆│ ←──写回 P──────  │ → P     │
│          │                   └─────────┘
│          │
│ P (N²)   │ ──再读取 P,V──→  ┌─────────┐
│ V (Nd)   │                   │ 计算     │
│          │                   │ O = PV   │
│ O (Nd)   │ ←──写回 O──────  └─────────┘
└──────────┘

☆ 标记: 巨大的 N² 中间结果需要反复在 HBM 中读写!

总 HBM 访问量\(O(Nd + N^2)\),由于 \(N \gg d\),近似为 \(O(N^2)\)

12.1.5 Memory-Bound vs Compute-Bound

面试考点:为什么 Attention 的瓶颈是 memory-bound 而非 compute-bound?

算术强度(Arithmetic Intensity) 分析:

\[\text{算术强度} = \frac{\text{计算量 (FLOPs)}}{\text{内存访问量 (Bytes)}}\]

对于标准 Attention: - 计算量:\(O(N^2 d)\) FLOPs(两次矩阵乘法各 \(O(N^2 d)\)) - 内存访问量:\(O(N^2 + Nd)\) bytes(读写中间矩阵 \(S\)\(P\)) - 算术强度 \(\approx O(d)\)

A100 的计算峰值与带宽之比(分界线): $\(\frac{312 \text{ TFLOPS (FP16)}}{2 \text{ TB/s}} = 156 \text{ FLOPs/Byte}\)$

\(d = 64\)\(128\) 时,算术强度远低于 156,因此 Attention 是典型的 memory-bound 操作——GPU 计算单元在等待数据从 HBM 搬运过来,大部分时间处于空闲状态。

这就是 FlashAttention 的优化空间:减少 HBM 访问次数,而非减少计算量。FlashAttention 甚至增加了总计算量(recomputation),但通过减少内存访问获得了巨大加速。


12.2 FlashAttention 核心原理

FlashAttention(Dao et al., 2022)的核心思想:IO-Aware Algorithm Design——在算法设计时显式考虑 GPU 内存层次结构。

12.2.1 核心策略:Tiling(分块计算)

\(Q\)\(K\)\(V\) 分成小块(tiles),每次只将一小块加载到 SRAM 中计算,避免在 HBM 中创建完整的 \(N \times N\) 矩阵。

Text Only
                     分块策略示意图

        K₁    K₂    K₃    K₄         V₁    V₂    V₃    V₄
      ┌─────┬─────┬─────┬─────┐    ┌─────┬─────┬─────┬─────┐
      │     │     │     │     │    │     │     │     │     │
      │     │     │     │     │    │     │     │     │     │
      └─────┴─────┴─────┴─────┘    └─────┴─────┴─────┴─────┘

Q₁ ┌──┐   ┌─────┐                    遍历 K/V 的每个块
   │  │   │Sᵢⱼ= │ ← 只在 SRAM 中     不写回 HBM
   │  │   │QᵢKⱼᵀ│   存在的小块
Q₂ │  │   └─────┘
   │  │
Q₃ │  │   每次计算一个 Qᵢ 与所有 Kⱼ 的乘积
   │  │   累加得到输出 Oᵢ
Q₄ └──┘

块大小选择\(B_r \times B_c\),需要满足 \(B_r \times d + B_c \times d + B_r \times B_c\) 可以放入 SRAM。

12.2.2 Online Softmax 技巧

Tiling 的最大挑战是 softmax 需要全局信息

\[\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}\]

传统 softmax 需要三次遍历:

Text Only
Pass 1: m = max(x₁, x₂, ..., xₙ)           — 求全局最大值(数值稳定)
Pass 2: l = Σ exp(xᵢ - m)                    — 求指数和
Pass 3: softmax(xᵢ) = exp(xᵢ - m) / l       — 归一化

Online Softmax(Milakov & Gimelshein, 2018)将其压缩为一次遍历

核心数学推导

假设我们已经处理了前 \(j-1\) 个元素,维护: - \(m^{(j-1)}\):前 \(j-1\) 个元素的最大值 - \(l^{(j-1)}\)\(\sum_{i=1}^{j-1} e^{x_i - m^{(j-1)}}\)

当新元素 \(x_j\) 到来时:

\[m^{(j)} = \max(m^{(j-1)}, x_j)\]
\[l^{(j)} = l^{(j-1)} \cdot e^{m^{(j-1)} - m^{(j)}} + e^{x_j - m^{(j)}}\]

关键修正因子\(e^{m^{(j-1)} - m^{(j)}}\),当新最大值更新时,之前的累积和需要乘以此修正因子来"追溯调整"。

扩展到分块计算(FlashAttention 中的用法)

在 FlashAttention 中,我们按块处理 \(K\)\(V\)。对于第 \(i\) 个 Q 块,依次遍历 \(K\) 的第 \(j\) 个块:

Text Only
初始化: m⁽⁰⁾ = -∞, l⁽⁰⁾ = 0, O⁽⁰⁾ = 0

对于第 j 个 K/V 块:
  1. 计算局部注意力分数: Sᵢⱼ = Qᵢ @ Kⱼᵀ / √d
  2. 计算局部 max/sum:
     mᵢⱼ = rowmax(Sᵢⱼ)
     m_new = max(m⁽ʲ⁻¹⁾, mᵢⱼ)
     Pᵢⱼ = exp(Sᵢⱼ - m_new)
     lᵢⱼ = rowsum(Pᵢⱼ)
     l_new = l⁽ʲ⁻¹⁾ · exp(m⁽ʲ⁻¹⁾ - m_new) + lᵢⱼ
  3. 更新输出(关键!):
     O⁽ʲ⁾ = O⁽ʲ⁻¹⁾ · [exp(m⁽ʲ⁻¹⁾ - m_new) · l⁽ʲ⁻¹⁾ / l_new]
          + Pᵢⱼ @ Vⱼ / l_new
  4. 更新统计量: m⁽ʲ⁾ = m_new, l⁽ʲ⁾ = l_new

这就是 FlashAttention 的精髓:通过在线更新 \(m\)\(l\),可以在不存储完整 \(N \times N\) 矩阵的情况下,正确计算 softmax 加权和。

12.2.3 前向传播算法(伪代码)

Python
def flash_attention_forward(Q, K, V, B_r, B_c):
    """
    Q, K, V: (N, d) 存储在 HBM 中
    B_r: Q 的分块大小(行方向)
    B_c: K/V 的分块大小(列方向)
    """
    N, d = Q.shape
    O = zeros(N, d)          # 输出,存在 HBM
    l = zeros(N)             # softmax 分母(每行一个标量)
    m = full(N, -inf)        # 当前行最大值(每行一个标量)

    T_r = ceil(N / B_r)      # Q 的块数
    T_c = ceil(N / B_c)      # K/V 的块数

    # --- 外层循环: 遍历 K/V 的块(加载到 SRAM)---
    for j in range(T_c):
        # 从 HBM 加载 Kⱼ, Vⱼ 到 SRAM
        K_j = K[j*B_c : (j+1)*B_c, :]   # (B_c, d)
        V_j = V[j*B_c : (j+1)*B_c, :]   # (B_c, d)

        # --- 内层循环: 遍历 Q 的块 ---
        for i in range(T_r):
            # 从 HBM 加载 Qᵢ, Oᵢ, lᵢ, mᵢ 到 SRAM
            Q_i = Q[i*B_r : (i+1)*B_r, :]
            O_i = O[i*B_r : (i+1)*B_r, :]
            l_i = l[i*B_r : (i+1)*B_r]
            m_i = m[i*B_r : (i+1)*B_r]

            # Step 1: 计算局部注意力分数(在 SRAM 中)
            S_ij = Q_i @ K_j.T / sqrt(d)    # (B_r, B_c) — 小矩阵!

            # Step 2: Online Softmax 更新
            m_ij = rowmax(S_ij)              # (B_r,)
            m_new = max(m_i, m_ij)           # 新的全局最大值

            P_ij = exp(S_ij - m_new[:, None])  # (B_r, B_c)

            l_new = l_i * exp(m_i - m_new) + rowsum(P_ij)

            # Step 3: 更新输出(带修正因子)
            O_i = O_i * (l_i * exp(m_i - m_new))[:, None] / l_new[:, None] \
                + P_ij @ V_j / l_new[:, None]

            # Step 4: 写回 HBM
            m_i = m_new
            l_i = l_new
            O[i*B_r : (i+1)*B_r, :] = O_i
            l[i*B_r : (i+1)*B_r] = l_i
            m[i*B_r : (i+1)*B_r] = m_i

    return O

12.2.4 反向传播:Recomputation 策略

标准 Attention 反向传播需要存储 \(S\)\(P\)(均为 \(N \times N\))。FlashAttention 的 recomputation 策略:

前向传播时只保存: - \(Q, K, V\)\(O(Nd)\)) - \(O\)\(O(Nd)\)) - \(l, m\)(各 \(O(N)\))——softmax 统计量

反向传播时: - 重新计算 \(S_{ij}\)\(P_{ij}\)(利用保存的 \(l, m\)) - 按相同的分块模式进行梯度计算 - 增加了约 50% 的 FLOPs,但节省了 \(O(N^2)\) 的内存 → 净效果是端到端加速

Text Only
内存对比:

标准 Attention:
  存储: Q, K, V, S, P, O → O(Nd + N²) — 被 N² 主导

FlashAttention:
  存储: Q, K, V, O, l, m  → O(Nd)     — 线性内存!

12.2.5 复杂度分析

指标 标准 Attention FlashAttention
FLOPs \(O(N^2 d)\) \(O(N^2 d)\)(前向相同,反向多 ~50%)
HBM 内存 \(O(N^2)\) \(O(N)\)
HBM 访问次数 \(O(Nd + N^2)\) \(O(N^2 d^2 / M)\)

其中 \(M\) 为 SRAM 大小。当 \(M = O(Nd)\) 时,HBM 访问量为 \(O(N^2 d^2 / M) = O(Nd)\) 量级,远小于标准 Attention 的 \(O(N^2)\)

直觉理解:FlashAttention 用额外的计算量(便宜的 SRAM 内计算)换取更少的内存访问(昂贵的 HBM 读写) → 整体更快。


12.3 FlashAttention-2 改进

FlashAttention-2(Dao, 2023)在 v1 基础上进一步优化,在 A100 上达到理论峰值 FLOPs 的 72%(v1 约 50%)。

12.3.1 减少非 MatMul 计算

FlashAttention-1 中存在大量非矩阵乘法的操作(rescaling、softmax 统计量更新等),这些操作无法利用 Tensor Core,成为性能瓶颈。

FA-2 的优化: - 将 online softmax 的 rescaling 延迟到最后一步 - 前向过程中累加未归一化的 \(O\),最后一次性除以 \(l\) - 减少中间的除法和乘法操作

Python
# FlashAttention-1: 每处理一个 K/V 块都要 rescale
O_i = O_i * (l_old / l_new) + P_ij @ V_j / l_new    # 多次除法

# FlashAttention-2: 延迟归一化
O_i = diag(exp(m_old - m_new)) @ O_i + P_ij @ V_j    # 只维护未归一化的累加
# ... 最后一步:
O_i = diag(1/l_final) @ O_i                           # 一次归一化

12.3.2 更好的并行策略

FA-1:外层循环遍历 K/V 块,内层循环遍历 Q 块 - 一个 thread block 处理一个 Q 块 - K/V 块之间无法并行(存在依赖)

FA-2反转循环顺序——外层遍历 Q 块,内层遍历 K/V 块 - 不同 Q 块之间完全独立,可以跨 SM 并行 - 对于 batch_size × num_heads 不够大的情况,可以在序列维度上获得额外并行度

Text Only
FA-1 并行策略:                    FA-2 并行策略:
for j in K/V blocks:              for i in Q blocks:        ← 可并行!
  for i in Q blocks: ← 并行        for j in K/V blocks:    ← 顺序
    compute(Q_i, K_j, V_j)           compute(Q_i, K_j, V_j)

12.3.3 Warp 级别优化

在一个 thread block 内部,FA-2 对 warp 的分工也做了改进:

Text Only
FA-1 (split-K):
  4 个 warp 各自计算 QK^T 的不同列 → 需要通信来共享结果 → 跨 warp 通信开销

FA-2 (split-Q):
  4 个 warp 各自处理 Q 的不同行 → 各自独立累加 → 无需 warp 间同步
  最后各 warp 将结果写入不同位置(无冲突)

12.3.4 Causal Masking 优化

对于因果注意力(causal attention,decoder 常用),\(S\) 的上三角部分被 mask 为 \(-\infty\)

Text Only
标准方法:
  计算完整 S → 应用 mask → 浪费了上三角的计算

FA-2 的优化:
  对于完全在 mask 区域内的块,直接跳过不计算

         K₁   K₂   K₃   K₄
  Q₁   [计算] [跳过] [跳过] [跳过]
  Q₂   [计算] [计算] [跳过] [跳过]
  Q₃   [计算] [计算] [计算] [跳过]     "跳过" = 完全被 mask 的块
  Q₄   [计算] [计算] [计算] [计算]

  → 节省约 50% 的计算(对于因果 attention)

12.3.5 性能对比

指标 FA-1 FA-2 提升
A100 FLOPs 利用率 ~50% ~72% 1.44x
前向速度(seq=2k) ~124 TFLOPs ~180 TFLOPs 1.45x
前向+反向速度 ~100 TFLOPs ~155 TFLOPs 1.55x
Causal mask 额外加速 +1.7-1.9x 显著

12.4 FlashAttention-3(Hopper 架构)

FlashAttention-3(Shah et al., 2024)专为 NVIDIA Hopper 架构(H100/H200)设计,利用了 Hopper 的新硬件特性。

12.4.1 Hopper 架构新特性

特性 Ampere (A100) Hopper (H100)
Tensor Core 代数 同步 异步(WGMMA)
内存搬运 同步共享内存加载 异步 TMA(Tensor Memory Accelerator)
FP8 Tensor Core 不支持 支持(1978 TFLOPS)
SRAM 容量 192 KB/SM 228 KB/SM

12.4.2 三大核心优化

1. 异步 Warp Specialization

将一个 thread block 内的 warp 分为两类:

Text Only
┌─────────────────────────────────┐
│         Thread Block            │
│                                 │
│  Producer Warps    Consumer Warps│
│  ┌────────────┐  ┌────────────┐│
│  │ 负责数据搬运 │  │ 负责矩阵计算 ││
│  │ GMEM→SMEM  │  │ WGMMA指令  ││
│  │ (TMA异步)   │  │ (Tensor Core)││
│  └────────────┘  └────────────┘│
│        ↓ 流水线 ↓                │
│  当 consumer 计算第 j 块时,       │
│  producer 已在加载第 j+1 块       │
└─────────────────────────────────┘

数据搬运和计算 重叠执行,隐藏内存延迟。

2. FP8 支持与混合精度

Text Only
GEMM (矩阵乘): FP8 × FP8 → FP32 累加器
  → H100 FP8: 1978 TFLOPS (是 FP16 的 2x)

Softmax / 统计量: 保持 FP32
  → 确保数值精度

策略: Incoherent Processing(块级量化 + 随机舍入)
  → 进一步提升 FP8 精度

3. 低精度 GEMM + 高精度 Softmax 流水线

利用 Hopper 的异步 WGMMA 指令,实现 GEMM 和 softmax 的指令级重叠

Text Only
时间 →
WGMMA (FP8): |──GEMM block j──|──GEMM block j+1──|──GEMM block j+2──|
Softmax:                |──soft j──|                |──soft j+1──|
TMA load:     |──load j+1──|          |──load j+2──|

12.4.3 性能数据

配置 FA-2 (A100) FA-3 (H100 FP16) FA-3 (H100 FP8)
前向速度 ~180 TF/s ~620 TF/s ~1200 TF/s
相比 FA-2 加速 1x ~1.5-2x ~3-4x

FlashAttention-3 在 H100 上接近硬件理论峰值的 75%


12.5 代码实战

12.5.1 使用 flash-attn 库(最简单方式)

Python
"""
使用 flash-attn 库的 API
安装: pip install flash-attn --no-build-isolation
要求: NVIDIA GPU (Ampere/Hopper), CUDA 11.6+, PyTorch 1.12+

> **编译环境要求**:flash-attn 需要从源码编译,需要:
> - CUDA 编译器 (nvcc) 在 PATH 中
> - 与 PyTorch 匹配的 CUDA 版本
> - C++ 编译器 (gcc/g++ 或 MSVC)
> - 充足的编译内存(建议 16GB+ RAM)
>
> 如遇编译问题,可尝试使用预编译 wheel:`pip install flash-attn --no-build-isolation` 或从 GitHub Releases 下载对应版本。
"""
import torch
from flash_attn import flash_attn_func

# 创建输入 (batch, seqlen, nheads, headdim) — 注意维度顺序!
batch_size, seq_len, n_heads, head_dim = 2, 4096, 32, 128
dtype = torch.float16
device = "cuda"

q = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype, device=device)
v = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype, device=device)

# 前向传播
output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,           # 推理时设为 0
    softmax_scale=None,      # 默认 1/sqrt(head_dim)
    causal=True,             # 因果 mask(decoder 模型使用)
)
# output shape: (batch, seqlen, nheads, headdim)
print(f"Output shape: {output.shape}")  # (2, 4096, 32, 128)

# 支持变长序列(packed/varlen)
from flash_attn import flash_attn_varlen_func
# 适用于 batch 内不同样本长度不同的情况,避免 padding 浪费

12.5.2 使用 PyTorch 原生 API(torch 2.0+)

Python
"""
PyTorch 2.0+ 内置 FlashAttention 后端
通过 torch.nn.functional.scaled_dot_product_attention 自动调用
"""
import torch
import torch.nn.functional as F

batch_size, n_heads, seq_len, head_dim = 2, 32, 4096, 128
dtype = torch.float16
device = "cuda"

# 注意: PyTorch API 的维度顺序是 (batch, nheads, seqlen, headdim)
q = torch.randn(batch_size, n_heads, seq_len, head_dim, dtype=dtype, device=device)
k = torch.randn(batch_size, n_heads, seq_len, head_dim, dtype=dtype, device=device)
v = torch.randn(batch_size, n_heads, seq_len, head_dim, dtype=dtype, device=device)

# 方式1: 自动选择最优后端(推荐)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,       # 启用 FlashAttention 后端
    enable_math=False,       # 禁用数学后端(标准实现)
    enable_mem_efficient=False  # 禁用 memory-efficient 后端
):
    output = F.scaled_dot_product_attention(  # F.xxx PyTorch函数式API
        q, k, v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,      # 因果 mask
        scale=None,          # 默认 1/sqrt(head_dim)
    )

print(f"Output shape: {output.shape}")  # (2, 32, 4096, 128)

# 方式2: 让 PyTorch 自动选择后端(最简单)
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# 查看当前使用了哪个后端
# 可通过 torch.backends.cuda.flash_sdp_enabled() 检查
print(f"FlashSDP enabled: {torch.backends.cuda.flash_sdp_enabled()}")

12.5.3 Triton 实现简化版 FlashAttention

这是面试和深入理解的核心内容——用 Triton 实现 FlashAttention 的分块计算逻辑:

Python
"""
简化版 FlashAttention Forward — Triton 实现
核心展示 Tiling + Online Softmax 的逻辑
参考: Triton 官方教程 + FlashAttention 论文算法1

注意: 这是教学用简化版本,省略了 causal mask、dropout 等特性
"""
import torch
import triton
import triton.language as tl

@triton.jit
def _flash_attn_fwd_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qb, stride_qh, stride_qm, stride_qd,  # Q 的 strides
    stride_kb, stride_kh, stride_kn, stride_kd,  # K 的 strides
    stride_vb, stride_vh, stride_vn, stride_vd,  # V 的 strides
    stride_ob, stride_oh, stride_om, stride_od,  # O 的 strides
    N: tl.constexpr,           # 序列长度
    D: tl.constexpr,           # 头维度
    BLOCK_M: tl.constexpr,     # Q 的分块大小
    BLOCK_N: tl.constexpr,     # K/V 的分块大小
    sm_scale,                  # softmax 缩放因子 = 1/sqrt(D)
):
    # ---- 确定当前 block 负责的区域 ----
    block_m_idx = tl.program_id(0)   # Q 的块索引
    batch_head_idx = tl.program_id(1)  # batch * num_heads
    # 从 batch_head_idx 解析 batch 和 head
    # (此处简化,假设 batch_head_idx 直接用于偏移)

    # ---- 计算基址偏移 ----
    q_offset = batch_head_idx * stride_qh
    k_offset = batch_head_idx * stride_kh
    v_offset = batch_head_idx * stride_vh
    o_offset = batch_head_idx * stride_oh

    # ---- 加载 Q 块到 SRAM ----
    offs_m = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)   # 行偏移
    offs_d = tl.arange(0, D)                                   # 列偏移

    # Q_i: (BLOCK_M, D)
    q_ptrs = Q_ptr + q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    q_mask = offs_m[:, None] < N
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # ---- 初始化 Online Softmax 状态 ----
    m_i = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)   # 行最大值
    l_i = tl.full([BLOCK_M], value=0.0, dtype=tl.float32)            # 行指数和
    o_i = tl.zeros([BLOCK_M, D], dtype=tl.float32)                   # 累加输出

    # ---- 内层循环: 遍历 K/V 的所有块 ----
    for block_n_start in range(0, N, BLOCK_N):
        offs_n = block_n_start + tl.arange(0, BLOCK_N)

        # 加载 K_j: (BLOCK_N, D)
        k_ptrs = K_ptr + k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
        k_mask = offs_n[:, None] < N
        k = tl.load(k_ptrs, mask=k_mask, other=0.0)

        # 加载 V_j: (BLOCK_N, D)
        v_ptrs = V_ptr + v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
        v = tl.load(v_ptrs, mask=k_mask, other=0.0)

        # Step 1: 计算 S_ij = Q_i @ K_j^T  — (BLOCK_M, BLOCK_N)
        s_ij = tl.dot(q, tl.trans(k)) * sm_scale

        # 对超出序列长度的位置 mask
        s_ij = tl.where(offs_n[None, :] < N, s_ij, float("-inf"))

        # Step 2: Online Softmax 更新
        m_ij = tl.max(s_ij, axis=1)              # 当前块的行最大值 (BLOCK_M,)
        m_new = tl.maximum(m_i, m_ij)             # 全局行最大值更新

        # 修正因子
        alpha = tl.exp(m_i - m_new)               # 旧最大值的修正
        p_ij = tl.exp(s_ij - m_new[:, None])      # 当前块的 exp 值 (BLOCK_M, BLOCK_N)

        l_new = l_i * alpha + tl.sum(p_ij, axis=1)

        # Step 3: 更新输出累加器
        # 先对旧的 O 做修正(旧 max 变化带来的 rescale)
        o_i = o_i * alpha[:, None]
        # 加上当前块的贡献
        o_i += tl.dot(p_ij.to(v.dtype), v)

        # Step 4: 更新统计量
        m_i = m_new
        l_i = l_new

    # ---- 最终归一化 ----
    o_i = o_i / l_i[:, None]

    # ---- 写回 HBM ----
    o_ptrs = O_ptr + o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    o_mask = offs_m[:, None] < N
    tl.store(o_ptrs, o_i.to(O_ptr.dtype.element_ty), mask=o_mask)

def flash_attention_triton(q, k, v):
    """
    简化版 FlashAttention 包装函数
    q, k, v: (batch, nheads, seqlen, headdim) — float16/bfloat16
    """
    B, H, N, D = q.shape
    assert D in {16, 32, 64, 128}, "headdim 必须是 16/32/64/128"  # assert断言

    o = torch.empty_like(q)
    sm_scale = 1.0 / (D ** 0.5)

    BLOCK_M = 128
    BLOCK_N = 64

    grid = (triton.cdiv(N, BLOCK_M), B * H)

    _flash_attn_fwd_kernel[grid](
        q, k, v, o,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        N=N, D=D,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        sm_scale=sm_scale,
    )
    return o

# ---- 测试 ----
if __name__ == "__main__":
    torch.manual_seed(42)
    B, H, N, D = 2, 4, 1024, 64
    q = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)

    # FlashAttention (Triton 实现)
    o_flash = flash_attention_triton(q, k, v)

    # 标准 Attention (参考实现)
    scale = 1.0 / (D ** 0.5)
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    o_ref = attn @ v

    # 验证正确性
    max_diff = (o_flash.float() - o_ref.float()).abs().max().item()  # 将单元素张量转为Python数值
    print(f"Max absolute difference: {max_diff:.6f}")
    assert max_diff < 1e-2, f"差异过大: {max_diff}"
    print("✓ Triton FlashAttention 结果正确!")

12.5.4 性能对比基准测试

Python
"""
标准 Attention vs FlashAttention 性能对比
测量不同序列长度下的速度和显存使用
"""
import torch
import torch.nn.functional as F
import time

def benchmark_attention(func, q, k, v, name, warmup=10, repeat=50):
    """通用 benchmark 函数"""
    # Warmup
    for _ in range(warmup):
        _ = func(q, k, v)
    torch.cuda.synchronize()

    # 记录显存
    torch.cuda.reset_peak_memory_stats()
    mem_before = torch.cuda.memory_allocated()

    # 计时
    start = time.perf_counter()
    for _ in range(repeat):
        _ = func(q, k, v)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / repeat * 1000  # ms

    mem_peak = torch.cuda.max_memory_allocated()
    mem_used = (mem_peak - mem_before) / 1024**2  # MB

    print(f"  {name:30s} | Time: {elapsed:8.2f} ms | Peak Mem: {mem_used:8.1f} MB")
    return elapsed, mem_used

def standard_attention(q, k, v):
    """标准 Attention(显式计算 N×N 矩阵)"""
    scale = 1.0 / (q.shape[-1] ** 0.5)  # [-1]负索引取最后元素
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return attn @ v

def flash_attention_pytorch(q, k, v):
    """PyTorch SDPA(自动选择 FlashAttention 后端)"""
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)

print("=" * 80)
print("Attention Performance Benchmark")
print("=" * 80)

B, H, D = 2, 32, 128

for N in [512, 1024, 2048, 4096, 8192]:
    q = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)

    print(f"\n--- Seq Length = {N} ---")
    try:  # try/except捕获异常
        t1, m1 = benchmark_attention(standard_attention, q, k, v, "Standard Attention")
    except torch.cuda.OutOfMemoryError:
        print(f"  {'Standard Attention':30s} | OOM!")
        t1, m1 = float("inf"), float("inf")

    t2, m2 = benchmark_attention(flash_attention_pytorch, q, k, v, "FlashAttention (PyTorch SDPA)")

    if t1 != float("inf"):
        print(f"  Speedup: {t1/t2:.2f}x | Memory Saving: {m1/m2:.2f}x")

print("\n" + "=" * 80)

预期输出趋势:

Text Only
--- Seq Length = 512 ---
  Standard Attention             | Time:     0.35 ms | Peak Mem:     48.0 MB
  FlashAttention (PyTorch SDPA)  | Time:     0.25 ms | Peak Mem:     16.0 MB
  Speedup: 1.40x | Memory Saving: 3.00x

--- Seq Length = 4096 ---
  Standard Attention             | Time:    12.50 ms | Peak Mem:   2048.0 MB
  FlashAttention (PyTorch SDPA)  | Time:     2.80 ms | Peak Mem:     64.0 MB
  Speedup: 4.46x | Memory Saving: 32.00x

--- Seq Length = 8192 ---
  Standard Attention             | OOM!
  FlashAttention (PyTorch SDPA)  | Time:    10.50 ms | Peak Mem:    128.0 MB

关键观察: 1. 序列越长,FlashAttention 的优势越明显(\(O(N^2)\) vs \(O(N)\) 内存) 2. 在短序列上,FlashAttention 也有优势(减少 HBM 访问) 3. 标准 Attention 在长序列时可能 OOM,FlashAttention 不会


12.6 FlashAttention 在推理中的应用

12.6.1 与 KV-Cache 的结合

LLM 自回归推理分为两个阶段:

Text Only
Prefill 阶段(处理 prompt):
  Q: (1, prompt_len, d)     ← 所有 token 一起处理
  K: (1, prompt_len, d)
  V: (1, prompt_len, d)
  → 标准 FlashAttention(compute-bound,批量矩阵乘)

Decode 阶段(逐 token 生成):
  Q: (1, 1, d)              ← 只有 1 个新 token
  K: (1, seq_len, d)        ← 包含所有历史 token(从 KV-Cache 读取)
  V: (1, seq_len, d)
  → 特殊优化需求(memory-bound,Q 太小导致计算强度低)

KV-Cache 的作用:避免 decode 阶段重复计算历史 token 的 K、V。

Text Only
KV-Cache 工作流:

Step 1: Prefill
  输入: "The cat sat on the"
  计算: K₁..₅, V₁..₅ → 存入 KV-Cache

Step 2: Decode token "mat"
  Q₆ = Embed("mat") @ W_Q
  K-Cache: [K₁, K₂, K₃, K₄, K₅] → append K₆
  V-Cache: [V₁, V₂, V₃, V₄, V₅] → append V₆
  Attention: Q₆ @ [K₁..₆]^T → softmax → @ [V₁..₆]

Step 3: Decode token "is"
  Q₇ = Embed("is") @ W_Q
  K-Cache: [K₁..₆] → append K₇ (不重新计算K₁..₆)
  ...

12.6.2 PagedAttention(vLLM)

问题:KV-Cache 需要预分配连续内存,导致大量内存碎片和浪费。

PagedAttention 的解决方案:借鉴操作系统的虚拟内存/分页机制:

Text Only
传统 KV-Cache:
  ┌──────────────────────────────────────────────┐
  │ Request 1 KV-Cache  │  碎片  │ Request 2 KV  │ 碎片 │ ...
  └──────────────────────────────────────────────┘
  预分配最大长度 → 实际使用 30% → 70% 浪费

PagedAttention:
  Page Table:                   Physical Pages:
  Req 1: [P3, P7, P1, ...]     ┌────┐ ┌────┐ ┌────┐
  Req 2: [P2, P5, P8, ...]     │ P1 │ │ P2 │ │ P3 │ ...
                                └────┘ └────┘ └────┘
  每个 page 存固定数量 token 的 KV
  按需分配 → 接近 0% 浪费
  支持 copy-on-write → beam search 共享 KV

核心优势: - 内存利用率从 ~30% 提升到 ~95%+ - 支持更大的 batch size → 更高的吞吐量 - 结合 FlashAttention kernel 实现高效的分页 attention 计算

12.6.3 FlashDecoding

问题:Decode 阶段 \(Q\) 只有 1 行,标准 FlashAttention 无法充分利用 GPU 并行度。

Text Only
Prefill: Q=(1, 2048, d) → 2048 / BLOCK_M = 16 个并行块 ✓
Decode:  Q=(1, 1, d)    → 1 / BLOCK_M = 1 个并行块 ✗ (GPU 利用率极低)

FlashDecoding(Dao et al., 2023)的策略:在 KV 序列维度上并行:

Text Only
标准 FlashAttention (Decode):
  1 个 block 顺序遍历所有 K/V 块 → 慢

FlashDecoding:
  Step 1: Split-K — 多个 block 各处理 K/V 的一部分
    Block 0: K[0:1024],   V[0:1024]   → (partial_O₀, partial_l₀, partial_m₀)
    Block 1: K[1024:2048], V[1024:2048] → (partial_O₁, partial_l₁, partial_m₁)
    Block 2: K[2048:3072], V[2048:3072] → (partial_O₂, partial_l₂, partial_m₂)
    ...
    → 所有 block 并行执行!

  Step 2: Reduce — 一个小 kernel 合并所有部分结果
    使用 Online Softmax 的合并公式:
    m_global = max(m₀, m₁, m₂, ...)
    l_global = Σ lᵢ · exp(mᵢ - m_global)
    O = Σ Oᵢ · exp(mᵢ - m_global) / l_global

12.6.4 MHA / MQA / GQA 与 FlashAttention

多头注意力的变体对推理效率影响巨大:

Text Only
Multi-Head Attention (MHA):
  Q: H 个头, K: H 个头, V: H 个头
  KV-Cache 大小: 2 × L × H × d × sizeof(dtype)
  示例 (LLaMA-70B): 2 × 80 × 64 × 128 × 2B = 2.5 GB / token

Multi-Query Attention (MQA):
  Q: H 个头, K: 1 个头, V: 1 个头    ← K/V 所有头共享
  KV-Cache 缩小到 1/H
  缺点: 质量可能下降

Grouped-Query Attention (GQA):
  Q: H 个头, K: G 个头, V: G 个头    ← K/V 分 G 组共享
  GQA-8: 8 个 KV 头, 每个服务 H/8 个 Q 头
  平衡质量与效率(LLaMA-2-70B, Gemma 等采用)
Text Only
                MHA              GQA (G=2)          MQA
Q heads:    Q₁ Q₂ Q₃ Q₄      Q₁ Q₂ Q₃ Q₄      Q₁ Q₂ Q₃ Q₄
            ↕  ↕  ↕  ↕        ↕  ↕  ↕  ↕        ↕  ↕  ↕  ↕
KV heads:   K₁ K₂ K₃ K₄      K₁ K₁ K₂ K₂      K₁ K₁ K₁ K₁
            V₁ V₂ V₃ V₄      V₁ V₁ V₂ V₂      V₁ V₁ V₁ V₁

KV-Cache:   4 份 KV          2 份 KV            1 份 KV

FlashAttention 对 GQA/MQA 的支持: - flash-attn 库原生支持 GQA/MQA(flash_attn_func 自动处理 Q 和 KV 头数不同的情况) - KV-Cache 减小 → 更少的 HBM 访问 → decode 更快 - GQA/MQA + FlashAttention + KV-Cache 是 2024-2026 年 LLM 推理的标准配置


12.7 面试高频题

题目1:为什么 FlashAttention 快?从 IO 角度解释

参考答案

FlashAttention 快的原因不是减少了计算量(FLOPs 实际上略有增加),而是大幅减少了 GPU HBM 的读写次数

标准 Attention 需要将 \(O(N^2)\) 大小的中间矩阵 \(S = QK^T\)\(P = \text{softmax}(S)\) 写入 HBM 再读出来。由于 Attention 是 memory-bound 操作(算术强度低于 GPU 的计算带宽比),HBM 访问是性能瓶颈。

FlashAttention 通过分块(Tiling)策略,将 Q/K/V 的小块加载到 SRAM 中完成所有计算(包括 softmax),中间结果 \(S_{ij}\) 从不写回 HBM。这将 HBM 访问量从 \(O(Nd + N^2)\) 降低到 \(O(N^2 d^2 / M)\)\(M\) 为 SRAM 大小),在实际参数下减少了数倍到数十倍的 HBM 访问。

SRAM 带宽约为 HBM 的 10 倍,因此将计算放在 SRAM 中完成可以大幅提高有效带宽利用率。


题目2:Online Softmax 是如何工作的?为什么它对 FlashAttention 至关重要?

参考答案

传统的数值稳定 softmax 需要三次遍历数据: 1. 找到全局最大值 \(m = \max(x_i)\) 2. 计算指数和 \(l = \sum e^{x_i - m}\) 3. 归一化 \(\text{softmax}(x_i) = e^{x_i - m} / l\)

这要求能看到整行的所有数据,与 FlashAttention 的分块策略矛盾(每次只看一个 K 块)。

Online Softmax 解决了这个问题:维护两个滑动统计量 \(m\)(当前最大值)和 \(l\)(当前指数和),每处理一个新块时:

\[m_{\text{new}} = \max(m_{\text{old}}, m_{\text{block}})$$ $$l_{\text{new}} = l_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + \sum e^{x_i - m_{\text{new}}}\]

关键修正因子 \(e^{m_{\text{old}} - m_{\text{new}}}\) 确保了当最大值更新时,之前的累积和被正确调整。同样的修正也应用于输出累加器 \(O\)

这使得 FlashAttention 可以一块一块地处理 K/V,无需看到全部数据就能正确计算 softmax。


题目3:FlashAttention 如何处理反向传播?什么是 Recomputation?

参考答案

标准 Attention 的反向传播需要用到前向传播中的中间矩阵 \(S\)(注意力分数)和 \(P\)(softmax 概率),这两个矩阵都是 \(O(N^2)\) 大小。

FlashAttention 的 Recomputation 策略:前向传播时不存储 \(S\)\(P\),只保存 \(Q, K, V, O\)(均为 \(O(Nd)\))以及 softmax 统计量 \(l, m\)(均为 \(O(N)\))。

反向传播时,使用保存的 \(Q, K, V, l, m\),按照相同的分块模式重新计算 \(S_{ij}\)\(P_{ij}\),然后计算梯度。

这增加了约 50% 的 FLOPs(重新计算的代价),但节省了 \(O(N^2)\)\(O(N)\) 的内存。由于 Attention 是 memory-bound,减少内存使用带来的好处(更大的 batch size、更少的 HBM 访问)远超额外计算的开销,端到端训练速度反而更快


题目4:FlashAttention-1 和 FlashAttention-2 的主要区别是什么?

参考答案

FlashAttention-2 的三大改进:

  1. 减少非 MatMul 计算:FA-1 在每个 K/V 块处理后都要做 rescaling(除以 \(l\)),这些标量运算无法利用 Tensor Core。FA-2 将归一化延迟到最后一步,累加未归一化的 \(O\),最后一次性归一化。

  2. 反转循环顺序:FA-1 外层遍历 K/V 块、内层遍历 Q 块;FA-2 反过来,外层遍历 Q 块。这使得不同 Q 块可以完全独立地并行执行,增加了序列维度的并行度。

  3. Warp 级别优化:FA-1 使用 split-K(warp 间共享 K),需要跨 warp 通信;FA-2 使用 split-Q(各 warp 处理不同的 Q 行),消除了 warp 间同步开销。

额外地,FA-2 对 causal masking 做了特殊优化:完全位于 mask 区域内的块直接跳过,因果 Attention 节省约 50% 计算。

结果:A100 上 FLOPs 利用率从约 50% 提升到约 72%,速度提升 1.5-1.7 倍。


题目5:FlashAttention 的局限性有哪些?

参考答案

  1. 硬件限制:需要 NVIDIA GPU(Ampere 或更新架构),不支持 CPU 或其他加速器(社区有部分移植如 AMD ROCm 的支持)。

  2. 头维度限制:最初版本仅支持 \(d \leq 128\)(FA-2 扩展到 \(d = 256\)),对于非标准头维度的模型可能不适用。

  3. 自定义 Attention Pattern 困难:对于非标准的 attention mask(如 sliding window、dilated attention、自定义 sparse pattern),需要专门的 kernel 实现,通用性有限。

  4. 反向传播开销:Recomputation 增加了约 50% 的反向 FLOPs。对于极度 compute-bound 的场景,这可能不划算(但实际中 Attention 几乎总是 memory-bound)。

  5. 实现复杂性:底层 CUDA/Triton kernel 编写和调试难度高,不容易自定义扩展。

  6. 短序列收益有限:当序列很短时(\(N < 256\)),标准 Attention 的 \(N^2\) 矩阵可以放入 SRAM,FlashAttention 的 tiling 开销反而可能增加 overhead。


题目6:MQA 和 GQA 的作用是什么?与 FlashAttention 的关系?

参考答案

Multi-Query Attention (MQA) 让所有 Q 头共享一组 K/V,将 KV-Cache 缩小到 \(1/H\)Grouped-Query Attention (GQA) 是折中方案,\(G\) 组 KV 头各服务 \(H/G\) 个 Q 头。

核心作用: 1. 减少 KV-Cache 大小:对于 70B 级别模型,MHA 的 KV-Cache 可达 2.5GB/token,GQA-8 缩减到约 300MB/token 2. 减少 decode 阶段的内存带宽需求:KV-Cache 是 decode 阶段的主要 IO 瓶颈 3. 增大可服务的 batch size:显存省出来可以放更多请求

与 FlashAttention 的关系是互补的: - FlashAttention 优化了 Attention 计算本身的 IO 效率 - GQA/MQA 优化了 KV-Cache 的存储和 IO - 两者结合是当前 LLM 推理的标准实践


题目7:PagedAttention 解决什么问题?它是如何工作的?

参考答案

问题:传统 KV-Cache 管理需要为每个请求预分配定长连续内存(通常按最大可能长度分配)。由于实际生成长度远小于最大值,导致 60-80% 的显存浪费(内部碎片)。此外,不同请求长度不同,显存无法灵活复用(外部碎片)。

PagedAttention 的解决方案(借鉴 OS 虚拟内存机制):

  1. 将 KV-Cache 分成固定大小的 页(page),每个 page 存储固定数量 token 的 KV 向量
  2. 维护一个 页表(page table) 记录每个请求的逻辑页到物理页的映射
  3. 物理页在 GPU 显存中非连续存储,按需动态分配

关键好处: - 内存利用率从 30% 提升到 95%+ - 支持 copy-on-write:beam search 中多个候选可以共享前缀的 KV-Cache 页 - 更大的 batch size → 更高的吞吐量

vLLM 是 PagedAttention 的代表性实现,它通过定制化的 CUDA kernel 支持在非连续页上执行高效的 Attention 计算。


题目8:FlashDecoding 是什么?解决了什么问题?

参考答案

问题:在自回归推理的 decode 阶段,每次只生成一个新 token,因此 \(Q\) 只有一行。标准 FlashAttention 在 Q 维度上并行(每个 block 处理若干 Q 行),单行 Q 意味着只有 \(\text{batch} \times \text{num\_heads}\) 个并行块,在 batch 较小或高端 GPU(有大量 SM)上无法充分利用硬件。

FlashDecoding 的方案——在 KV 序列维度上额外并行

  1. Split-K 阶段:将 KV 序列分成若干段,每段由一个独立的 block 并行处理,计算出局部结果 \((O_{\text{partial}}, l_{\text{partial}}, m_{\text{partial}})\)
  2. Reduce 阶段:一个轻量 kernel 使用 Online Softmax 合并公式将所有局部结果合并为最终输出

这样并行块数变为 \(\text{batch} \times \text{num\_heads} \times \text{num\_splits}\),显著提高了 GPU 利用率。在 batch=1 的场景下可以获得 3-5 倍的加速。


总结与知识图谱

Text Only
FlashAttention 知识体系:

标准 Attention 瓶颈
├── O(N²) 内存 → 长序列 OOM
├── Memory-bound(算术强度 < GPU 分界线)
└── 频繁 HBM 读写中间矩阵

FlashAttention 核心
├── IO-Aware Algorithm Design
├── Tiling: Q/K/V 分块加载到 SRAM
├── Online Softmax: 增量计算 softmax
├── Recomputation: 不存 N² 矩阵,反向重算
└── 复杂度: O(N²) 内存 → O(N)

版本演进
├── FA-1: 基础 Tiling + Online Softmax
├── FA-2: 减少非 MatMul 操作 + 循环反转 + warp 优化
└── FA-3: FP8 + 异步 Tensor Core + warp specialization (H100)

推理应用
├── KV-Cache: 避免重复计算
├── PagedAttention: 分页管理 KV-Cache
├── FlashDecoding: decode 阶段 KV 维度并行
└── GQA/MQA: 减少 KV-Cache 大小

参考文献

  1. Dao, T., et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022.
  2. Dao, T. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024.
  3. Shah, J., et al. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." 2024.
  4. Milakov, M. & Gimelshein, N. "Online normalizer calculation for softmax." arXiv 2018.
  5. Kwon, W., et al. "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP 2023.
  6. Dao, T., et al. "FlashDecoding." Blog post, 2023.
  7. Ainslie, J., et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." EMNLP 2023.

下一章预告第13章 推测解码与推理加速 将介绍投机解码技术——通过小模型"猜测"大模型输出来加速推理,是与 FlashAttention 互补的另一项重要推理加速技术。