跳转至

01-注意力机制详解

学习时间: 约6-8小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: 线性代数、神经网络基础、RNN/LSTM基础 学习目标: 深入理解注意力机制的各种形式,掌握自注意力和多头注意力的完整推导与实现

📌 定位说明:本章从深度学习基础视角讲解注意力机制(从Seq2Seq出发的动机、加性/点积/多头注意力推导)。面向大模型的注意力优化(FlashAttention/GQA/MQA/稀疏注意力等)请参考 LLM学习/01-基础巩固/02-注意力机制详解


目录


1. 注意力机制动机

1.1 信息瓶颈问题

在经典 Seq2Seq 模型中,编码器将整个输入序列压缩为一个固定维度的向量 \(c\)。当输入序列很长时,这个向量无法承载所有信息 — 这就是信息瓶颈

注意力机制让解码器在每一步都能"回看"编码器的所有输出,根据当前需要动态选择关注哪些输入位置。

1.2 人类注意力的类比

当你阅读一个英文句子来翻译时,翻译每个词时会"注意"源句子中不同的部分。注意力机制就是让模型学会这种"选择性关注"。


2. 注意力机制的一般框架

2.1 Query-Key-Value 框架

Query-Key-Value框架

注意力机制的核心是三元组 (Query, Key, Value)

  1. Query (Q): 当前要关注什么("我在找什么?")
  2. Key (K): 各个位置的索引("我有什么可以提供的?")
  3. Value (V): 各个位置的实际内容("我的实际信息是什么?")
\[\text{Attention}(Q, K, V) = \text{Aggregate}(\text{Score}(Q, K), V)\]

步骤: 1. 计算 Q 与每个 K 之间的相似度分数 2. 通过 Softmax 将分数转换为注意力权重(和为 1) 3. 用权重对 V 做加权求和,得到上下文向量

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def general_attention(query, keys, values, score_fn, mask=None):
    """
    通用注意力框架
    query: (batch, 1, d_q) 或 (batch, seq_q, d_q)
    keys: (batch, seq_k, d_k)
    values: (batch, seq_k, d_v)
    """
    # 1. 计算注意力分数
    scores = score_fn(query, keys)  # (batch, seq_q, seq_k)

    # 2. 掩码(可选)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # 3. Softmax 归一化
    weights = F.softmax(scores, dim=-1)  # (batch, seq_q, seq_k)

    # 4. 加权求和
    context = torch.bmm(weights, values)  # (batch, seq_q, d_v)

    return context, weights

3. 加性注意力

3.1 Bahdanau 注意力

Bahdanau et al.(2015)提出的加性注意力:

\[\text{score}(q, k) = v^T \tanh(W_q q + W_k k)\]

其中 \(W_q \in \mathbb{R}^{d_a \times d_q}\), \(W_k \in \mathbb{R}^{d_a \times d_k}\), \(v \in \mathbb{R}^{d_a}\) 是可学习参数,\(d_a\) 是注意力维度。

Python
class AdditiveAttention(nn.Module):  # 继承nn.Module定义神经网络层
    """加性(Bahdanau)注意力"""
    def __init__(self, query_dim, key_dim, attn_dim):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        self.W_q = nn.Linear(query_dim, attn_dim, bias=False)
        self.W_k = nn.Linear(key_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, query, keys, values, mask=None):
        """
        query: (batch, 1, query_dim)  — 单个查询位置
        keys: (batch, seq_len, key_dim)
        values: (batch, seq_len, value_dim)
        """
        # (batch, 1, attn_dim) + (batch, seq_len, attn_dim) → 广播相加
        scores = self.v(torch.tanh(self.W_q(query) + self.W_k(keys)))
        scores = scores.squeeze(-1)  # (batch, seq_len)  # squeeze去除大小为1的维度

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights.unsqueeze(1), values)  # (batch, 1, value_dim)

        return context, weights

# 测试
attn = AdditiveAttention(query_dim=256, key_dim=512, attn_dim=128)
q = torch.randn(4, 1, 256)
k = torch.randn(4, 20, 512)
v = torch.randn(4, 20, 512)
context, weights = attn(q, k, v)
print(f"Context: {context.shape}, Weights: {weights.shape}")

4. 点积注意力与缩放点积注意力

4.1 点积注意力

\[\text{score}(q, k) = q^T k\]

简单但要求 \(d_q = d_k\)

4.2 缩放点积注意力(Scaled Dot-Product Attention)

缩放点积注意力

Vaswani et al.(2017)在 Transformer 中使用的注意力形式:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

为什么要除以 \(\sqrt{d_k}\)

\(d_k\) 较大时,\(q^Tk = \sum_{i=1}^{d_k} q_i k_i\) 的方差约为 \(d_k\)(假设 \(q_i, k_i\) 独立同分布,均值 0,方差 1),这使得某些 softmax 输出极端趋近 0 或 1,梯度几乎为零。除以 \(\sqrt{d_k}\) 将方差恢复到 1。

4.3 完整推导与实现

Python
class ScaledDotProductAttention(nn.Module):
    """缩放点积注意力 — Transformer 的核心"""
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        """
        Q: (batch, num_heads, seq_q, d_k)
        K: (batch, num_heads, seq_k, d_k)
        V: (batch, num_heads, seq_k, d_v)
        mask: (batch, 1, 1, seq_k) 或 (batch, 1, seq_q, seq_k)
        """
        d_k = Q.size(-1)

        # Step 1: 计算注意力分数
        # QK^T: (batch, num_heads, seq_q, seq_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        # Step 2: 应用掩码
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Step 3: Softmax 得到注意力权重
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Step 4: 加权求和
        context = torch.matmul(attn_weights, V)

        return context, attn_weights

# 测试
sdpa = ScaledDotProductAttention()
Q = torch.randn(2, 8, 10, 64)  # batch=2, 8 heads, seq=10, d_k=64
K = torch.randn(2, 8, 20, 64)  # seq_k=20
V = torch.randn(2, 8, 20, 64)
context, weights = sdpa(Q, K, V)
print(f"Context: {context.shape}")   # (2, 8, 10, 64)
print(f"Weights: {weights.shape}")   # (2, 8, 10, 20)

5. 自注意力(Self-Attention)

5.1 原理

自注意力机制

自注意力中 Q、K、V 都来自同一个序列。每个位置都能关注序列中的所有位置(包括自身),捕获序列内部的依赖关系。

对于输入序列 \(X \in \mathbb{R}^{n \times d}\)

\[Q = XW^Q, \quad K = XW^K, \quad V = XW^V\]
\[\text{SelfAttention}(X) = \text{softmax}\left(\frac{(XW^Q)(XW^K)^T}{\sqrt{d_k}}\right)(XW^V)\]

5.2 自注意力 vs RNN

特性 自注意力 RNN
并行计算 ✅ 完全并行 ❌ 必须顺序
长距离依赖 O(1) 连接路径 O(n) 连接路径
计算复杂度 \(O(n^2 d)\) \(O(n d^2)\)
适合长序列 \(n < d\) 时高效 \(n > d\) 时高效

5.3 实现

Python
class SelfAttention(nn.Module):
    """单头自注意力"""
    def __init__(self, d_model, d_k=None, d_v=None, dropout=0.1):
        super().__init__()
        d_k = d_k or d_model
        d_v = d_v or d_model

        self.W_Q = nn.Linear(d_model, d_k)
        self.W_K = nn.Linear(d_model, d_k)
        self.W_V = nn.Linear(d_model, d_v)
        self.scale = math.sqrt(d_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        x: (batch, seq_len, d_model)
        """
        Q = self.W_Q(x)  # (batch, seq_len, d_k)
        K = self.W_K(x)  # (batch, seq_len, d_k)
        V = self.W_V(x)  # (batch, seq_len, d_v)

        # 注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, V)

        return context, weights

# 测试
self_attn = SelfAttention(d_model=512)
x = torch.randn(4, 20, 512)
context, weights = self_attn(x)
print(f"自注意力: context={context.shape}, weights={weights.shape}")
# weights[0, i, j] 表示位置 i 对位置 j 的注意力权重

6. 多头注意力(Multi-Head Attention)

6.1 动机

多头注意力机制

单头注意力只能学习一种注意力模式。多头注意力将注意力分成多个"头",每个头可以关注不同的表示子空间和不同的位置关系。

6.2 数学公式

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\]

其中: $\(\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\)$

参数:\(W_i^Q \in \mathbb{R}^{d_{model} \times d_k}\), \(W_i^K \in \mathbb{R}^{d_{model} \times d_k}\), \(W_i^V \in \mathbb{R}^{d_{model} \times d_v}\), \(W^O \in \mathbb{R}^{hd_v \times d_{model}}\)

通常 \(d_k = d_v = d_{model} / h\)

6.3 完整实现

Python
class MultiHeadAttention(nn.Module):
    """多头注意力 — 从零实现"""
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 线性投影(所有头合并在一起计算更高效)
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def split_heads(self, x):
        """将 (batch, seq, d_model) 重塑为 (batch, num_heads, seq, d_k)"""
        batch_size, seq_len, _ = x.shape
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def merge_heads(self, x):
        """将 (batch, num_heads, seq, d_k) 重塑为 (batch, seq, d_model)"""
        batch_size, _, seq_len, _ = x.shape
        x = x.transpose(1, 2).contiguous()  # 链式调用,连续执行多个方法
        return x.view(batch_size, seq_len, self.d_model)

    def forward(self, query, key, value, mask=None):
        """
        query: (batch, seq_q, d_model)
        key:   (batch, seq_k, d_model)
        value: (batch, seq_k, d_model)
        mask:  (batch, 1, 1, seq_k) 或 (batch, 1, seq_q, seq_k)
        """
        # 1. 线性投影
        Q = self.split_heads(self.W_Q(query))  # (batch, heads, seq_q, d_k)
        K = self.split_heads(self.W_K(key))    # (batch, heads, seq_k, d_k)
        V = self.split_heads(self.W_V(value))  # (batch, heads, seq_k, d_k)

        # 2. 缩放点积注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(attn_weights, V)

        # 3. 合并多头
        context = self.merge_heads(context)  # (batch, seq_q, d_model)

        # 4. 输出投影
        output = self.W_O(context)

        return output, attn_weights

# 测试
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(4, 20, 512)
output, weights = mha(x, x, x)  # 自注意力:Q=K=V=x
print(f"多头注意力: output={output.shape}, weights={weights.shape}")
# output: (4, 20, 512), weights: (4, 8, 20, 20)

7. 交叉注意力(Cross-Attention)

交叉注意力

交叉注意力中 Q 来自一个序列(通常是解码器),K 和 V 来自另一个序列(通常是编码器)。这是 Transformer 解码器中连接编码器和解码器的关键机制。

Python
# 交叉注意力的使用(在 Transformer 解码器中)
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # 自注意力(带因果掩码)
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # 交叉注意力 — Q 来自解码器,K/V 来自编码器
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # 前馈网络
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # 1. 自注意力
        attn_out, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_out))

        # 2. 交叉注意力 — Q=x(解码器), K=V=enc_output(编码器)
        cross_out, cross_weights = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_out))

        # 3. 前馈网络
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))

        return x, cross_weights

8. 注意力可视化

Python
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def visualize_attention(attention_weights, src_tokens, tgt_tokens=None, head_idx=0):
    """
    可视化注意力权重
    attention_weights: (num_heads, tgt_len, src_len)
    """
    if tgt_tokens is None:
        tgt_tokens = src_tokens

    # 选择某个头
    weights = attention_weights[head_idx].detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(weights, xticklabels=src_tokens, yticklabels=tgt_tokens,
                cmap='Blues', annot=True, fmt='.2f', ax=ax)
    ax.set_title(f'Attention Weights (Head {head_idx})')
    ax.set_xlabel('Source')
    ax.set_ylabel('Target')
    plt.tight_layout()
    plt.show()

def visualize_all_heads(attention_weights, src_tokens, tgt_tokens=None, ncols=4):
    """可视化所有头的注意力"""
    num_heads = attention_weights.shape[0]
    nrows = (num_heads + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(5*ncols, 5*nrows))
    if tgt_tokens is None:
        tgt_tokens = src_tokens

    for idx in range(num_heads):
        ax = axes[idx // ncols, idx % ncols] if nrows > 1 else axes[idx % ncols]
        weights = attention_weights[idx].detach().cpu().numpy()
        sns.heatmap(weights, ax=ax, cmap='Blues', xticklabels=src_tokens,
                    yticklabels=tgt_tokens if idx % ncols == 0 else False)
        ax.set_title(f'Head {idx}')

    plt.tight_layout()
    plt.show()

# 使用示例
# tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat']
# weights = torch.randn(8, 6, 6).softmax(dim=-1)  # 8 heads
# visualize_attention(weights, tokens, head_idx=0)
# visualize_all_heads(weights, tokens)

9. 注意力的变体与优化

9.0 缩放点积注意力的梯度推导

\(S = \frac{QK^T}{\sqrt{d_k}}\)\(A = \text{softmax}(S)\)(逐行),\(O = AV\)。已知上游梯度 \(\frac{\partial \mathcal{L}}{\partial O}\)(记为 \(dO\)),反向传播求 \(dQ, dK, dV\)

Step 1: 对 V 的梯度

\[dV = A^T \cdot dO\]

Step 2: 对 A 的梯度

\[dA = dO \cdot V^T\]

Step 3: 通过 softmax 反向传播(逐行,设第 \(i\) 行的 softmax 输出为 \(a\),上游梯度为 \(da\)):

\[dS_i = a \odot \left(da - (a^T da)\mathbf{1}\right)\]

\(ds_{ij} = a_{ij}(da_{ij} - \sum_k a_{ik} \cdot da_{ik})\)。这是 softmax 反向传播的标准公式。

Step 4: 对 Q 和 K 的梯度

\[dQ = \frac{dS \cdot K}{\sqrt{d_k}}, \quad dK = \frac{dS^T \cdot Q}{\sqrt{d_k}}\]

实践要点:标准实现中 softmax 反向传播的中间矩阵 \(dS\) 大小为 \(O(n^2)\),是注意力的显存瓶颈。Flash Attention 通过分块计算和在线 softmax 重计算避免了显式存储 \(S\)\(dS\),将显存从 \(O(n^2)\) 降到 \(O(n)\)

9.1 常见变体

变体 复杂度 思想
标准注意力 \(O(n^2)\) 全注意力
稀疏注意力 \(O(n\sqrt{n})\) 只关注固定模式的位置
线性注意力 \(O(n)\) 用核函数近似 softmax
Flash Attention \(O(n^2)\), 低显存 IO 感知的精确注意力
滑动窗口 \(O(nw)\) 只关注局部窗口
分组查询注意力(GQA) \(O(n^2)\), 低显存 K/V 共享组

9.2 Flash Attention

Python
# PyTorch 2.0+ 内置了 Flash Attention
# 使用 scaled_dot_product_attention(自动选择最优实现)
from torch.nn.functional import scaled_dot_product_attention

Q = torch.randn(2, 8, 1024, 64, device='cuda')
K = torch.randn(2, 8, 1024, 64, device='cuda')
V = torch.randn(2, 8, 1024, 64, device='cuda')

# 自动使用 Flash Attention(如果满足条件)
output = scaled_dot_product_attention(Q, K, V, is_causal=True)
print(f"Flash Attention 输出: {output.shape}")

10. 练习与自我检查

练习题

  1. 实现对比:分别实现加性注意力和缩放点积注意力,在同一任务上对比效率。
  2. 自注意力:从零实现单头自注意力,加入掩码支持,在简单序列分类上测试。
  3. 多头注意力:从零实现多头注意力,验证不同头学到了不同的注意力模式。
  4. 注意力可视化:训练一个带注意力的翻译模型,可视化并解读注意力热力图。
  5. 缩放分析:实验验证不缩放 (\(1/\sqrt{d_k}\)) 时梯度会趋向零的现象。

自我检查清单

  • 理解注意力机制解决了什么问题
  • 能区分 Q、K、V 的各自角色
  • 理解为什么需要 \(1/\sqrt{d_k}\) 缩放
  • 能从零实现缩放点积注意力
  • 理解多头注意力的动机和具体操作
  • 能区分自注意力和交叉注意力
  • 了解 Flash Attention 等效率优化方法

下一篇: 02-Transformer架构 — 完整的Transformer架构详解