
02 - 注意力机制详解(全面版)

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习目标:深入理解注意力机制的各种变体、数学原理、优化技术和可视化分析方法。

📌 定位说明:本章侧重大模型中的注意力优化与前沿变体(FlashAttention/GQA/MQA/稀疏注意力等)。注意力机制的基础教学(从Seq2Seq出发的动机、加性/点积/多头注意力推导)请参考 深度学习/04-Transformer/01-注意力机制详解


目录

  1. 注意力机制的统一框架
  2. 自注意力深度剖析
  3. 注意力变体详解
  4. 稀疏注意力机制
  5. 线性注意力与高效注意力
  6. 现代大模型的注意力优化
  7. 注意力可视化与分析
  8. 注意力机制的扩展与应用

注意力机制的统一框架

1.1 注意力机制的通用形式

所有注意力机制都可以统一表示为以下形式:

Text Only
Attention(Q, K, V) = f(Q, K) · V

其中:
- Q (Query): 查询向量,表示"我在寻找什么"
- K (Key): 键向量,表示"我有什么"
- V (Value): 值向量,表示"实际内容是什么"
- f(Q, K): 相似度函数,计算查询与键的匹配程度
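
下面用一小段代码给出该统一形式的最小示意(假设 f(Q, K) 的输出会经过 softmax 归一化为权重;函数名与示例张量均为演示用):

Python
import torch
import torch.nn.functional as F

def generic_attention(Q, K, V, score_fn=None):
    """统一形式:Attention(Q, K, V) = normalize(f(Q, K)) · V"""
    if score_fn is None:
        # 默认使用缩放点积作为相似度函数 f
        score_fn = lambda q, k: q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    weights = F.softmax(score_fn(Q, K), dim=-1)   # [batch, q_len, k_len]
    return weights @ V                            # [batch, q_len, d_v]

Q = torch.randn(2, 5, 64)   # [batch, q_len, d]
K = torch.randn(2, 7, 64)   # [batch, k_len, d]
V = torch.randn(2, 7, 64)
out = generic_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 5, 64])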

1.2 相似度函数对比

| 机制 | 相似度函数 | 时间复杂度 | 空间复杂度 | 特点 |
| --- | --- | --- | --- | --- |
| 加性注意力 | score = v^T tanh(W_Q·Q + W_K·K) | O(n²d) | O(n²) | 灵活,适合不同维度 |
| 点积注意力 | score = Q·K^T | O(n²d) | O(n²) | 简单快速,需维度匹配 |
| 缩放点积 | score = (Q·K^T) / √d_k | O(n²d) | O(n²) | 最常用,数值稳定 |
| 双线性注意力 | score = Q·W·K^T | O(n²d²) | O(n²) | 可学习相似度 |
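
下面的示意代码在同一组随机 Q、K 上分别计算表中四种相似度分数,便于对照各自的写法与输出形状(变量名与维度均为示例):

Python
import torch
import torch.nn as nn
import math

batch, q_len, k_len, d = 2, 4, 6, 32
Q = torch.randn(batch, q_len, d)
K = torch.randn(batch, k_len, d)

# 点积注意力:Q·K^T
dot = Q @ K.transpose(-2, -1)                    # [batch, q_len, k_len]

# 缩放点积:除以 sqrt(d_k),避免维度较大时分数方差过大
scaled = dot / math.sqrt(d)

# 双线性:Q·W·K^T,W 为可学习参数
W = nn.Parameter(torch.randn(d, d))
bilinear = Q @ W @ K.transpose(-2, -1)

# 加性(Bahdanau):v^T tanh(W_q·Q + W_k·K),通过广播得到 [q_len, k_len]
hidden = 16
W_q = nn.Linear(d, hidden, bias=False)
W_k = nn.Linear(d, hidden, bias=False)
v = nn.Linear(hidden, 1, bias=False)
additive = v(torch.tanh(W_q(Q).unsqueeze(2) + W_k(K).unsqueeze(1))).squeeze(-1)

print(dot.shape, scaled.shape, bilinear.shape, additive.shape)  # 均为 [2, 4, 6]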

1.3 注意力作为软寻址

Text Only
注意力机制可以看作是一种"软寻址":

硬寻址(如数据库查询):
- 精确匹配某个key
- 返回对应的value

软寻址(注意力):
- 计算query与所有key的相似度
- 返回所有value的加权平均
- 权重由相似度决定

类比:
- 硬寻址:在图书馆精确找到某本书
- 软寻址:根据主题相关性,从多本书中提取信息
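
用一个极简的数值例子示意"软寻址":查询与三个 key 的点积经 softmax 变成权重,输出是对应 value 的加权平均(数值为随意构造的演示数据):

Python
import torch
import torch.nn.functional as F

# 三个 key/value 槽位,value 用 2 维向量表示"内容"
keys   = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
values = torch.tensor([[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]])
query  = torch.tensor([1.0, 0.2])

scores  = keys @ query                 # 相似度:点积
weights = F.softmax(scores, dim=-1)    # 软寻址权重,和为1
output  = weights @ values             # 所有 value 的加权平均

print(weights)  # 与 query 最相似的 key 权重最大
print(output)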

自注意力机制

上图展示了自注意力机制的核心流程。输入token分别经过三个权重矩阵(W_Q、W_K、W_V)映射为查询(Query)、键(Key)和值(Value),然后通过计算查询和键的相似度得到注意力权重,最后用这些权重对值进行加权聚合,得到最终的输出。


自注意力深度剖析

2.1 自注意力的信息流动

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttentionAnalyzer:
    """
    自注意力分析工具
    """

    @staticmethod  # @staticmethod无需实例即可调用
    def analyze_information_flow(attention_weights, tokens):
        """
        分析自注意力中的信息流动

        Args:
            attention_weights: [seq_len, seq_len]
            tokens: 词列表
        """
        seq_len = len(tokens)

        print("=" * 60)
        print("自注意力信息流动分析")
        print("=" * 60)

        for i, token in enumerate(tokens):  # enumerate同时获取索引和元素
            # 获取当前token的注意力分布
            attn_dist = attention_weights[i]

            # 找到最关注的token
            top_k = 3
            top_indices = torch.topk(attn_dist, top_k).indices

            print(f"\n位置 {i} ('{token}') 最关注:")
            for idx in top_indices:
                print(f"  - 位置 {idx} ('{tokens[idx]}'): {attn_dist[idx]:.3f}")

            # 计算熵(注意力分布的集中度)
            entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum()
            print(f"  注意力熵: {entropy:.3f} (越低越集中)")

    @staticmethod
    def compute_attention_entropy(attention_weights):
        """
        计算注意力分布的熵

        熵高 = 注意力分散(关注多个位置)
        熵低 = 注意力集中(关注特定位置)
        """
        entropy = -(attention_weights * torch.log(attention_weights + 1e-10)).sum(dim=-1)
        return entropy.mean()

    @staticmethod
    def compute_attention_sparsity(attention_weights, threshold=0.1):
        """
        计算注意力稀疏度

        返回注意力权重大于threshold的比例
        """
        sparsity = (attention_weights > threshold).float().mean()
        return sparsity

# 示例分析
"""
输入: "The cat sat on the mat"

自注意力后的信息流动:

位置0 'The':
  - 可能关注: 'cat' (确定名词)
  - 熵: 中等(需要看上下文)

位置1 'cat':
  - 可能关注: 'The' (主语), 'sat' (谓语)
  - 熵: 较低(关系明确)

位置2 'sat':
  - 可能关注: 'cat' (主语), 'on' (介词)
  - 熵: 中等

位置3 'on':
  - 可能关注: 'sat' (动词), 'mat' (宾语)
  - 熵: 较低

位置4 'the':
  - 可能关注: 'mat' (确定名词)
  - 熵: 中等

位置5 'mat':
  - 可能关注: 'on' (介词), 'the' (冠词)
  - 熵: 较低
"""

2.2 注意力的梯度流动

Python
class AttentionGradientAnalyzer:
    """
    分析注意力的梯度流动
    """

    @staticmethod
    def analyze_gradient_flow(model, input_ids, target_ids):
        """
        分析注意力层的梯度
        """
        model.train()

        # 前向传播
        outputs = model(input_ids, labels=target_ids)
        loss = outputs.loss

        # 反向传播
        loss.backward()

        # 分析各层的梯度
        grad_stats = {}
        for name, param in model.named_parameters():
            if 'attention' in name and param.grad is not None:
                grad_stats[name] = {
                    'mean': param.grad.mean().item(),
                    'std': param.grad.std().item(),
                    'max': param.grad.max().item(),
                    'min': param.grad.min().item()
                }

        return grad_stats
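
下面给出一个示意性的调用方式(以 transformers 的 BERT 为例,模型名仅作演示;注意不同模型的参数命名不同,例如 GPT-2 的注意力参数名中是 "attn" 而非 "attention",过滤条件需相应调整):

Python
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

inputs = tokenizer("The cat sat on the mat", return_tensors="pt")
input_ids = inputs["input_ids"]

# 演示用法:直接以输入自身作为标签计算损失,触发反向传播
# (BERT 的注意力参数名包含 "attention",可被上面的过滤条件匹配)
stats = AttentionGradientAnalyzer.analyze_gradient_flow(model, input_ids, input_ids)
for name, s in stats.items():
    print(f"{name}: mean={s['mean']:.2e}, std={s['std']:.2e}")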

注意力变体详解

3.1 加性注意力(Additive Attention)

Python
class AdditiveAttention(nn.Module):
    """
    加性注意力(Bahdanau Attention)

    使用一个前馈网络计算相似度
    score = v^T * tanh(W_q * Q + W_k * K)

    优点:可以处理不同维度的Q和K
    """
    def __init__(self, query_dim, key_dim, hidden_dim):
        super().__init__()  # super()调用父类方法
        self.W_q = nn.Linear(query_dim, hidden_dim, bias=False)
        self.W_k = nn.Linear(key_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch, query_len, query_dim]
            key: [batch, key_len, key_dim]
            value: [batch, key_len, value_dim]
        """
        # 扩展维度以便广播
        # query: [batch, query_len, 1, hidden]
        # key: [batch, 1, key_len, hidden]
        query_proj = self.W_q(query).unsqueeze(2)  # unsqueeze增加一个维度
        key_proj = self.W_k(key).unsqueeze(1)

        # 计算分数
        scores = self.v(torch.tanh(query_proj + key_proj)).squeeze(-1)
        # scores: [batch, query_len, key_len]

        # 应用mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax
        attn_weights = F.softmax(scores, dim=-1)

        # 加权求和
        output = torch.bmm(attn_weights, value)

        return output, attn_weights
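
下面是一个简单的形状检查示例(维度均为演示取值),可以看到加性注意力允许 Q 与 K/V 的维度不同:

Python
attn = AdditiveAttention(query_dim=256, key_dim=128, hidden_dim=64)

query = torch.randn(2, 5, 256)   # 例如decoder端的查询
key   = torch.randn(2, 9, 128)   # 例如encoder端的键
value = torch.randn(2, 9, 128)   # 例如encoder端的值

output, weights = attn(query, key, value)
print(output.shape, weights.shape)  # [2, 5, 128], [2, 5, 9]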

3.2 双线性注意力

Python
class BilinearAttention(nn.Module):
    """
    双线性注意力

    score = Q * W * K^T

    可以学习Q和K之间的复杂交互
    """
    def __init__(self, query_dim, key_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(query_dim, key_dim))
        nn.init.xavier_uniform_(self.W)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch, query_len, query_dim]
            key: [batch, key_len, key_dim]
        """
        # scores[b, i, j] = sum_d query[b, i, d] * W[d, e] * key[b, j, e]
        scores = torch.matmul(torch.matmul(query, self.W), key.transpose(-2, -1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.bmm(attn_weights, value)

        return output, attn_weights

3.3 局部注意力(Local Attention)

Python
class LocalAttention(nn.Module):
    """
    局部注意力

    只关注窗口内的位置,降低计算复杂度
    """
    def __init__(self, window_size=128):
        super().__init__()
        self.window_size = window_size

    def forward(self, query, key, value, mask=None):
        """
        每个位置只关注前后window_size/2个位置
        """
        batch_size, seq_len, dim = query.shape
        half_window = self.window_size // 2

        outputs = []

        for i in range(seq_len):
            # 确定窗口范围
            start = max(0, i - half_window)
            end = min(seq_len, i + half_window + 1)

            # 提取局部key和value
            local_key = key[:, start:end, :]
            local_value = value[:, start:end, :]
            local_query = query[:, i:i+1, :]

            # 计算局部注意力
            scores = torch.matmul(local_query, local_key.transpose(-2, -1))
            scores = scores / math.sqrt(dim)

            attn_weights = F.softmax(scores, dim=-1)
            local_output = torch.matmul(attn_weights, local_value)

            outputs.append(local_output)

        return torch.cat(outputs, dim=1)

稀疏注意力机制

4.1 稀疏注意力的动机

Text Only
标准自注意力的复杂度是O(n²),当序列长度n很大时:
- n=1K: 1M 计算量
- n=4K: 16M 计算量
- n=32K: 1B 计算量
- n=128K: 16B 计算量

稀疏注意力的目标:
- 保持长距离依赖能力
- 将复杂度降至O(n log n)或O(n)
- 只计算"重要"的注意力对
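
可以用一个简单的估算脚本(参数仅作示意)感受 O(n²) 注意力分数矩阵的显存量级,以下假设 fp16、32 个注意力头:

Python
def attn_matrix_memory_gb(seq_len, num_heads=32, batch_size=1, bytes_per_elem=2):
    """估算一层注意力分数矩阵的显存占用(fp16,单位 GiB,仅量级估计)"""
    elems = batch_size * num_heads * seq_len * seq_len
    return elems * bytes_per_elem / 1024**3

for n in [1024, 4096, 32768, 131072]:
    print(f"n={n:>6}: {attn_matrix_memory_gb(n):.1f} GB")
# n=  1024: 0.1 GB
# n=  4096: 1.0 GB
# n= 32768: 64.0 GB
# n=131072: 1024.0 GB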

4.2 滑动窗口注意力

Python
class SlidingWindowAttention(nn.Module):
    """
    滑动窗口注意力(Longformer)

    每个token只关注窗口内的邻居
    """
    def __init__(self, d_model, num_heads, window_size=512):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.window_size = window_size

        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # 生成QKV
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, -1)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # 转置为 [batch, heads, seq_len, head_dim]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # 创建滑动窗口mask
        window_mask = torch.zeros(seq_len, seq_len, device=x.device)
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            window_mask[i, start:end] = 1

        # 计算注意力(演示实现:先计算全量分数,再用窗口mask屏蔽;真正的稀疏实现只计算窗口内的分数)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        scores = scores.masked_fill(window_mask.unsqueeze(0).unsqueeze(0) == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
        return self.out_proj(output)

4.3 全局-局部混合注意力

Python
class GlobalLocalAttention(nn.Module):
    """
    全局-局部混合注意力(Longformer)

    全局token: 可以关注所有位置,被所有位置关注
    局部token: 只关注邻居
    """
    def __init__(self, d_model, num_heads, window_size=512, num_global_tokens=16):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.window_size = window_size
        self.num_global_tokens = num_global_tokens

        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # 假设前num_global_tokens个是全局token
        global_indices = list(range(self.num_global_tokens))

        # 生成QKV
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, -1)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # 创建注意力mask
        attn_mask = torch.zeros(seq_len, seq_len, device=x.device)

        for i in range(seq_len):
            if i in global_indices:
                # 全局token可以关注所有位置
                attn_mask[i, :] = 1
            else:
                # 局部token只关注邻居和全局token
                start = max(0, i - self.window_size // 2)
                end = min(seq_len, i + self.window_size // 2 + 1)
                attn_mask[i, start:end] = 1
                attn_mask[i, global_indices] = 1

        # 计算注意力
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        scores = scores.masked_fill(attn_mask.unsqueeze(0).unsqueeze(0) == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.d_model)
        return self.out_proj(output)

4.4 稀疏Transformer模式

Python
class SparseTransformerAttention(nn.Module):
    """
    稀疏Transformer(Sparse Transformer)

    使用分解的稀疏模式(需要先将1D序列reshape为2D网格):
    - 偶数层(layer_idx为偶数):行注意力(每个token只关注同一行)
    - 奇数层:列注意力(每个token只关注同一列)
    """
    def __init__(self, d_model, num_heads, block_size=32):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.block_size = block_size

        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, layer_idx=0):
        """
        Args:
            x: [batch, seq_len, d_model]
            layer_idx: 当前层索引,决定使用行还是列注意力
        """
        batch_size, seq_len, _ = x.shape

        # 计算2D网格大小(向上取整,确保 grid_size² ≥ seq_len)
        grid_size = int(math.ceil(math.sqrt(seq_len)))
        padding = grid_size * grid_size - seq_len
        if padding > 0:
            # 填充到完全平方
            x = F.pad(x, (0, 0, 0, padding))

        # Reshape为2D [batch, height, width, dim]
        x_2d = x.reshape(batch_size, grid_size, grid_size, -1)

        # 生成QKV
        qkv = self.qkv(x_2d).reshape(batch_size, grid_size, grid_size, 3, self.num_heads, -1)
        q, k, v = qkv[:, :, :, 0], qkv[:, :, :, 1], qkv[:, :, :, 2]

        if layer_idx % 2 == 0:
            # 行注意力
            q = q.permute(0, 3, 1, 2, 4)  # [batch, heads, height, width, head_dim]
            k = k.permute(0, 3, 1, 2, 4)
            v = v.permute(0, 3, 1, 2, 4)

            # 每行内部做注意力(在width维度上),并做缩放
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
            attn_weights = F.softmax(scores, dim=-1)
            output = torch.matmul(attn_weights, v)

            output = output.permute(0, 2, 3, 1, 4).reshape(batch_size, grid_size, grid_size, self.d_model)
        else:
            # 列注意力
            q = q.permute(0, 3, 2, 1, 4)  # [batch, heads, width, height, head_dim]
            k = k.permute(0, 3, 2, 1, 4)
            v = v.permute(0, 3, 2, 1, 4)

            # 每列内部做注意力(在height维度上),并做缩放
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
            attn_weights = F.softmax(scores, dim=-1)
            output = torch.matmul(attn_weights, v)

            output = output.permute(0, 3, 2, 1, 4).reshape(batch_size, grid_size, grid_size, self.d_model)

        # Reshape回1D
        output = output.reshape(batch_size, -1, self.d_model)[:, :seq_len, :]

        return self.out_proj(output)

线性注意力与高效注意力

5.1 线性注意力原理

Text Only
标准注意力: O(n²) 复杂度
Softmax(QK^T)V

线性注意力: O(n) 复杂度
使用核技巧: φ(Q)φ(K)^T V = φ(Q)(φ(K)^T V)

关键洞察:先计算 φ(K)^T V,得到一个大小只与特征维度有关(与序列长度无关)的矩阵,整体计算量随序列长度线性增长!

Python
class LinearAttention(nn.Module):
    """
    线性注意力(Katharopoulos et al.)

    使用特征映射φ将softmax注意力转化为线性复杂度
    """
    def __init__(self, dim, num_heads, feature_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.feature_dim = feature_dim or self.head_dim

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def elu_feature_map(self, x):
        """特征映射:ELU + 1"""
        return F.elu(x) + 1

    def forward(self, x):
        batch, seq_len, dim = x.shape

        # 生成QKV
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # 转置 [batch, heads, seq_len, head_dim]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # 特征映射
        q = self.elu_feature_map(q)
        k = self.elu_feature_map(k)

        # 线性注意力计算
        # KV = Σ_t k_t^T v_t
        KV = torch.einsum('bhsk,bhsv->bhkv', k, v)

        # Z = Σ_t k_t
        Z = k.sum(dim=2)

        # 输出 = Q @ KV / (Q @ Z)
        numerator = torch.einsum('bhnk,bhkv->bhnv', q, KV)
        denominator = torch.einsum('bhnk,bhk->bhn', q, Z).unsqueeze(-1) + 1e-6

        out = numerator / denominator
        out = out.transpose(1, 2).reshape(batch, seq_len, dim)

        return self.proj(out)
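
一个简单的调用与粗略计时示例(仅在 CPU 上演示趋势,数值不具参考性):线性注意力的计算量随序列长度线性增长,而非平方增长:

Python
import time

linear_attn = LinearAttention(dim=256, num_heads=8)

for seq_len in [1024, 4096]:
    x = torch.randn(1, seq_len, 256)
    start = time.time()
    with torch.no_grad():
        out = linear_attn(x)
    print(f"seq_len={seq_len}, 输出形状={tuple(out.shape)}, 耗时约 {time.time() - start:.3f}s")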

5.2 FlashAttention

Python
class FlashAttention(nn.Module):
    """
    FlashAttention原理说明

    FlashAttention不是新的注意力机制,而是IO感知的注意力实现

    核心思想:
    1. 将输入分块(tiling),避免加载完整的N×N注意力矩阵到GPU HBM
    2. 在SRAM(高速缓存)中计算注意力
    3. 使用online softmax算法,避免存储中间结果

    优势:
    - 减少HBM访问次数(从O(N²)到O(N))
    - 内存高效(不需要存储N×N注意力矩阵)
    - 计算更快(虽然FLOPs相同,但内存访问更少)

    实际使用:
    from flash_attn import flash_attn_func
    output = flash_attn_func(q, k, v, causal=True)
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, causal=True):
        """
        这里只是接口说明,实际应使用flash_attn库
        """
        batch, seq_len, dim = x.shape

        # 生成QKV
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # 转置为 [batch, heads, seq_len, head_dim]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # 实际使用时:
        # from flash_attn import flash_attn_func
        # out = flash_attn_func(q, k, v, causal=causal)

        # 这里使用标准注意力作为fallback
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if causal:
            mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
            scores = scores.masked_fill(mask.to(x.device), float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj(out)
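
除了 flash_attn 库,PyTorch 2.x 还内置了 torch.nn.functional.scaled_dot_product_attention,在硬件与参数满足条件时会自动调度到 FlashAttention 等高效内核,示意用法如下:

Python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64)   # [batch, heads, seq_len, head_dim]
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

# is_causal=True 等价于应用下三角因果mask,无需显式构造 N×N 的mask矩阵
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])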

现代大模型的注意力优化

6.1 MQA (Multi-Query Attention)

Python
class MultiQueryAttention(nn.Module):
    """
    多查询注意力(Multi-Query Attention)

    来自PaLM,所有头共享同一组K和V

    优势:
    - 大幅减少KV Cache内存占用
    - 解码速度更快

    劣势:
    - 略微降低模型质量
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Q投影到所有头
        self.W_q = nn.Linear(d_model, d_model)
        # K,V只投影到一个头(共享)
        self.W_k = nn.Linear(d_model, self.head_dim)
        self.W_v = nn.Linear(d_model, self.head_dim)

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape

        # Q: [batch, seq_len, d_model]
        q = self.W_q(x)
        q = q.reshape(batch, seq_len, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]

        # K,V: [batch, seq_len, head_dim]
        k = self.W_k(x).unsqueeze(1)  # [batch, 1, seq_len, head_dim]
        v = self.W_v(x).unsqueeze(1)  # [batch, 1, seq_len, head_dim]

        # 广播K,V到所有头
        k = k.expand(-1, self.num_heads, -1, -1)
        v = v.expand(-1, self.num_heads, -1, -1)

        # 计算注意力
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        out = out.permute(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
        return self.W_o(out)

6.2 GQA (Grouped-Query Attention)

Python
class GroupedQueryAttention(nn.Module):
    """
    分组查询注意力(Grouped-Query Attention)

    来自LLaMA-2,MQA和MHA的折中

    多个查询头共享一组K,V
    例如:32个头,8个K,V组,每组4个头共享
    """
    def __init__(self, d_model, num_heads, num_kv_groups=8):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.head_dim = d_model // num_heads

        assert num_heads % num_kv_groups == 0  # assert断言:条件False时抛出AssertionError
        self.heads_per_group = num_heads // num_kv_groups

        # Q投影到所有头
        self.W_q = nn.Linear(d_model, d_model)
        # K,V投影到组数
        kv_dim = self.head_dim * num_kv_groups
        self.W_k = nn.Linear(d_model, kv_dim)
        self.W_v = nn.Linear(d_model, kv_dim)

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape

        # Q: [batch, seq_len, num_heads, head_dim]
        q = self.W_q(x).reshape(batch, seq_len, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # [batch, num_heads, seq_len, head_dim]

        # K,V: [batch, seq_len, num_kv_groups, head_dim]
        k = self.W_k(x).reshape(batch, seq_len, self.num_kv_groups, self.head_dim)
        v = self.W_v(x).reshape(batch, seq_len, self.num_kv_groups, self.head_dim)

        # 转置
        k = k.permute(0, 2, 1, 3)  # [batch, num_kv_groups, seq_len, head_dim]
        v = v.permute(0, 2, 1, 3)

        # 重复K,V以匹配Q的头数
        k = k.repeat_interleave(self.heads_per_group, dim=1)
        v = v.repeat_interleave(self.heads_per_group, dim=1)

        # 计算注意力
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        out = out.permute(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
        return self.W_o(out)
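
可以用一个简单的估算(模型配置仅作示意)对比 MHA / GQA / MQA 的 KV Cache 大小:KV Cache 的元素数约为 2 × 层数 × KV头数 × 序列长度 × head_dim,其中 MHA 的 KV 头数等于全部头数,GQA 为组数,MQA 为 1:

Python
def kv_cache_gb(num_layers, num_kv_heads, seq_len, head_dim, batch_size=1, bytes_per_elem=2):
    """估算 KV Cache 显存(fp16,单位 GiB):K 和 V 各存一份,故乘以2"""
    elems = 2 * num_layers * batch_size * num_kv_heads * seq_len * head_dim
    return elems * bytes_per_elem / 1024**3

# 以一个 32 层、32 个注意力头、head_dim=128、上下文 32K 的模型为例(参数仅作示意)
cfg = dict(num_layers=32, seq_len=32768, head_dim=128)
print(f"MHA (32 个KV头): {kv_cache_gb(num_kv_heads=32, **cfg):.1f} GB")
print(f"GQA (8 个KV组) : {kv_cache_gb(num_kv_heads=8,  **cfg):.1f} GB")
print(f"MQA (1 个KV头) : {kv_cache_gb(num_kv_heads=1,  **cfg):.2f} GB")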

注意力可视化与分析

7.1 注意力权重可视化

Python
import matplotlib.pyplot as plt
import seaborn as sns

class AttentionVisualizer:
    """
    注意力可视化工具
    """

    @staticmethod
    def plot_attention_heatmap(attention_weights, tokens, title="Attention Heatmap"):
        """
        绘制注意力热力图

        Args:
            attention_weights: [seq_len, seq_len]
            tokens: token列表
        """
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            attention_weights.cpu().numpy(),
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='viridis',
            cbar_kws={'label': 'Attention Weight'}
        )
        plt.title(title)
        plt.xlabel('Key Position')
        plt.ylabel('Query Position')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

多头注意力机制

上图展示了多头注意力机制的架构。多头注意力将输入分割到多个子空间,每个头独立计算注意力,最后将所有头的输出拼接并通过线性投影层。这种设计允许模型同时关注不同类型的关系和模式,大大增强了模型的表达能力。

Python
    @staticmethod
    def plot_multi_head_attention(attention_weights, tokens, num_heads=8):
        """
        绘制多头注意力

        Args:
            attention_weights: [num_heads, seq_len, seq_len]
        """
        fig, axes = plt.subplots(2, num_heads // 2, figsize=(20, 8))
        axes = axes.flatten()

        for i in range(num_heads):
            ax = axes[i]
            sns.heatmap(
                attention_weights[i].cpu().numpy(),
                ax=ax,
                cmap='viridis',
                cbar=False,
                xticklabels=False,
                yticklabels=False
            )
            ax.set_title(f'Head {i+1}')

        plt.suptitle('Multi-Head Attention Patterns')
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_attention_rollout(attention_weights_list, tokens):
        """
        注意力展开(Attention Rollout)

        累积多层注意力,显示从输入到输出的信息流动
        """
        # 初始化单位矩阵
        rollout = torch.eye(len(tokens))

        for attn in attention_weights_list:
            # 注意力归一化
            attn = attn + torch.eye(len(tokens))  # 添加残差连接
            attn = attn / attn.sum(dim=-1, keepdim=True)

            # 累积
            rollout = torch.matmul(attn, rollout)

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            rollout.cpu().numpy(),
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='viridis'
        )
        plt.title('Attention Rollout')
        plt.show()
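
实际使用时,可以先从 transformers 模型中取出注意力权重,再送入上面的可视化工具(以下以 BERT 为例,模型名仅作演示;output_attentions=True 会让模型额外返回各层注意力):

Python
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)

inputs = tokenizer("The cat sat on the mat", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# outputs.attentions: 每层一个 [batch, heads, seq_len, seq_len] 的张量
layer0_head0 = outputs.attentions[0][0, 0]

AttentionVisualizer.plot_attention_heatmap(layer0_head0, tokens)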

7.2 注意力模式分析

Python
class AttentionPatternAnalyzer:
    """
    注意力模式分析
    """

    @staticmethod
    def identify_attention_patterns(attention_weights):
        """
        识别注意力模式

        返回每种模式的占比
        """
        seq_len = attention_weights.shape[0]

        patterns = {
            'diagonal': 0,      # 对角线(关注邻近)
            'vertical': 0,      # 垂直(关注特定位置)
            'uniform': 0,       # 均匀分布
            'sparse': 0         # 稀疏
        }

        for i in range(seq_len):
            attn_dist = attention_weights[i]

            # 对角线模式
            if i > 0:
                diagonal_score = attn_dist[i-1] + (attn_dist[i+1] if i < seq_len-1 else 0)
            else:
                diagonal_score = attn_dist[i+1] if i < seq_len-1 else 0

            # 垂直模式(某个位置特别高)
            max_val = attn_dist.max()
            vertical_score = max_val

            # 均匀度
            entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum()
            uniform_score = 1.0 / (entropy + 1)

            # 稀疏度
            sparse_score = (attn_dist > 0.1).float().sum() / seq_len

            # 分类
            scores = {
                'diagonal': diagonal_score,
                'vertical': vertical_score,
                'uniform': uniform_score,
                'sparse': 1 - sparse_score
            }

            best_pattern = max(scores, key=scores.get)
            patterns[best_pattern] += 1

        # 归一化
        for key in patterns:
            patterns[key] /= seq_len

        return patterns

    @staticmethod
    def find_syntactic_patterns(attention_weights, tokens, pos_tags):
        """
        发现句法模式

        例如:代词-名词、动词-宾语关系
        """
        patterns = []

        for i, (token, tag) in enumerate(zip(tokens, pos_tags)):
            attn_dist = attention_weights[i]

            # 找到最关注的token
            top_idx = attn_dist.argmax()

            # 分析关系
            if tag in ['PRP', 'PRP$']:  # 代词
                if pos_tags[top_idx] in ['NN', 'NNS']:  # 名词
                    patterns.append({
                        'type': 'pronoun-noun',
                        'from': (i, token),
                        'to': (top_idx, tokens[top_idx]),
                        'weight': attn_dist[top_idx].item()
                    })

            elif tag in ['VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ']:  # 动词
                if pos_tags[top_idx] in ['NN', 'NNS', 'PRP']:  # 名词或代词
                    patterns.append({
                        'type': 'verb-object',
                        'from': (i, token),
                        'to': (top_idx, tokens[top_idx]),
                        'weight': attn_dist[top_idx].item()
                    })

        return patterns

注意力机制的扩展与应用

8.1 跨模态注意力

Python
class CrossModalAttention(nn.Module):
    """
    跨模态注意力(用于多模态模型)

    例如:视觉-语言模型中,文本查询关注图像特征
    """
    def __init__(self, text_dim, image_dim, hidden_dim):
        super().__init__()
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        self.image_proj_k = nn.Linear(image_dim, hidden_dim)
        self.image_proj_v = nn.Linear(image_dim, hidden_dim)

    def forward(self, text_features, image_features):
        """
        Args:
            text_features: [batch, text_len, text_dim]
            image_features: [batch, num_patches, image_dim]
        """
        # Q来自文本
        q = self.text_proj(text_features)

        # K,V来自图像
        k = self.image_proj_k(image_features)
        v = self.image_proj_v(image_features)

        # 计算跨模态注意力(缩放点积,避免hidden_dim较大时分数方差过大)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        attn = F.softmax(scores, dim=-1)

        # 加权图像特征
        output = torch.matmul(attn, v)

        return output, attn
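
一个简单的形状示例(维度仅作演示,如 196 对应 14×14 个图像 patch):

Python
cross_attn = CrossModalAttention(text_dim=768, image_dim=1024, hidden_dim=512)

text_features  = torch.randn(2, 20, 768)    # [batch, text_len, text_dim]
image_features = torch.randn(2, 196, 1024)  # [batch, num_patches, image_dim]

output, attn = cross_attn(text_features, image_features)
print(output.shape, attn.shape)  # [2, 20, 512], [2, 20, 196]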

8.2 图注意力(Graph Attention)

Python
class GraphAttentionLayer(nn.Module):
    """
    图注意力层(GAT)

    用于图结构数据的注意力
    """
    def __init__(self, in_features, out_features, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.out_features = out_features

        self.W = nn.Linear(in_features, out_features * num_heads)
        self.a = nn.Parameter(torch.randn(num_heads, 2 * out_features))

    def forward(self, x, adj_matrix):
        """
        Args:
            x: [num_nodes, in_features]
            adj_matrix: [num_nodes, num_nodes] 邻接矩阵
        """
        num_nodes = x.size(0)

        # 线性变换
        h = self.W(x)  # [num_nodes, out_features * num_heads]
        h = h.view(num_nodes, self.num_heads, -1)  # [num_nodes, heads, out_features]  # view重塑张量形状

        # 计算注意力系数
        # 对于每条边(i,j),计算e_ij = LeakyReLU(a^T [Wh_i || Wh_j])
        attn_scores = []

        for head in range(self.num_heads):
            h_head = h[:, head, :]  # [num_nodes, out_features]

            # 计算所有节点对的注意力
            # [num_nodes, 1, out_features] + [1, num_nodes, out_features]
            concat = torch.cat([
                h_head.unsqueeze(1).expand(-1, num_nodes, -1),
                h_head.unsqueeze(0).expand(num_nodes, -1, -1)
            ], dim=-1)  # [num_nodes, num_nodes, 2*out_features]

            e = torch.matmul(concat, self.a[head])  # [num_nodes, num_nodes]
            e = F.leaky_relu(e)

            # Mask掉不存在的边
            e = e.masked_fill(adj_matrix == 0, float('-inf'))

            attn_scores.append(e)

        attn_scores = torch.stack(attn_scores, dim=0)  # [heads, num_nodes, num_nodes]
        attn = F.softmax(attn_scores, dim=-1)

        # 聚合邻居特征
        out = torch.matmul(attn, h.transpose(0, 1))  # [heads, num_nodes, out_features]
        out = out.transpose(0, 1).reshape(num_nodes, -1)  # [num_nodes, heads*out_features]

        return out
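
一个最小用例(示意):4 个节点的小图,邻接矩阵含自环:

Python
num_nodes = 4
x = torch.randn(num_nodes, 16)

# 邻接矩阵(含自环):节点 0-1、1-2、2-3 相连
adj = torch.tensor([
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [0, 1, 1, 1],
    [0, 0, 1, 1],
], dtype=torch.float)

gat = GraphAttentionLayer(in_features=16, out_features=8, num_heads=4)
out = gat(x, adj)
print(out.shape)  # torch.Size([4, 32]),即 [num_nodes, heads * out_features]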

总结

注意力机制选择指南

Text Only
根据场景选择注意力机制:

短序列(<1K):
├── 标准自注意力 ✅
└── FlashAttention ✅(更快)

中等序列(1K-8K):
├── 滑动窗口注意力 ✅
├── 稀疏注意力 ✅
└── FlashAttention ✅

长序列(8K-32K):
├── 全局-局部混合 ✅
├── 稀疏Transformer ✅
└── 线性注意力 ✅

超长序列(>32K):
├── 线性注意力 ✅
├── Mamba/State Space ✅
└── 分块注意力 ✅

内存受限:
├── MQA ✅(减少KV Cache)
├── GQA ✅(平衡质量和内存)
└── 线性注意力 ✅(O(n)复杂度)

关键要点

  1. 注意力是软寻址:计算相似度,加权聚合
  2. 复杂度是瓶颈:O(n²)限制了长序列应用
  3. 稀疏化是方向:只计算重要的注意力对
  4. 内存访问是关键:FlashAttention优化内存访问模式
  5. MQA/GQA是工程优化:牺牲少量质量换取大幅速度提升

下一步:继续学习实践-手写Transformer,动手实现各种注意力机制!


最后更新日期:2026-02-12 适用版本:LLM学习教程 v2026