跳转至

13. 注意力机制前沿创新

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

核心问题:Kimi新发布的注意力残差(Attention Residuals)是什么?当前注意力机制有哪些前沿创新?


目录

  1. 注意力机制演进全景
  2. Kimi MoBA (Block Attention)
  3. Kimi Attention Residuals(注意力残差)
  4. DeepSeek NSA稀疏注意力
  5. MLA多头潜在注意力
  6. 线性注意力与RNN化
  7. 注意力机制选型指南
  8. 面试高频问答

1. 注意力机制演进全景

1.1 演进时间线

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    注意力机制演进时间线                           │├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  2017: 标准多头注意力 (MHA)                                      │
│  ├── O(n²) 复杂度                                               │
│  ├── 所有token两两交互                                          │
│  └── 问题:长序列计算/显存爆炸                                  │
│                                                                 │
│  2020: 稀疏注意力 (Longformer/BigBird)                          │
│  ├── 局部窗口 + 全局token                                       │
│  ├── O(n) 复杂度                                                │
│  └── 问题:固定稀疏模式,灵活性差                               │
│                                                                 │
│  2021: FlashAttention                                           │
│  ├── IO感知的分块计算                                           │
│  ├── 计算复杂度不变,但大幅减少显存访问                          │
│  └── 成为2022-2024标配                                          │
│                                                                 │
│  2022: GQA/MQA (分组查询注意力)                                 │
│  ├── 多个查询头共享KV                                           │
│  ├── 推理时KV Cache大幅减少                                     │
│  └── LLaMA 2/3采用                                              │
│                                                                 │
│  2024: MLA (多头潜在注意力)                                      │
│  ├── KV压缩到低维潜在空间                                       │
│  ├── KV Cache减少90%+                                           │
│  └── DeepSeek-V2/V3采用                                         │
│                                                                 │
│  2024-2025: 动态稀疏注意力                                      │
│  ├── DSA (DeepSeek Sparse Attention)                            │
│  ├── NSA (Native Sparse Attention)                              │
│  └── MoBA (Mixture of Block Attention)                          │
│                                                                 │
│  2026: Attention Residuals(注意力残差)                         │
│  ├── 用注意力机制替代固定残差连接                                │
│  ├── 选择性聚合早期表示                                          │
│  └── Kimi 2026年3月发布,被马斯克/Karpathy点赞                                │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1.2 各方案对比

Text Only
┌────────────────────────────────────────────────────────────────────────┐
│                        注意力机制对比                                   │├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│  方案          复杂度        KV Cache    长上下文    质量    适用场景    │
│  ─────────────────────────────────────────────────────────────────────│
│  MHA           O(n²)        100%        差          最好    短序列     │
│  GQA           O(n²)        25-50%      一般        好      通用       │
│  MLA           O(n²)        5-10%       好          好      长上下文   │
│  FlashAttn     O(n²)        100%        一般        最好    训练加速   │
│  稀疏注意力    O(n·k)       变化        好          较好    超长序列   │
│  线性注意力    O(n)         O(n)        最好        较差    极长序列   │
│                                                                        │
│  注意:复杂度是理论值,实际性能还受kernel实现、硬件等因素影响          │
│                                                                        │
└────────────────────────────────────────────────────────────────────────┘

2. Kimi MoBA (Block Attention)

2.1 背景与动机

Text Only
Kimi(月之暗面)在2025年发布的MoBA(Mixture of Block Attention)

核心问题:
├── 长上下文(1M+ tokens)的注意力计算成本极高
├── 传统稀疏注意力的固定模式不够灵活
└── 需要在保持质量的同时实现线性复杂度

MoBA的核心创新:
├── 将序列分块,每块独立计算注意力
├── 引入"注意力残差"连接,保留全局信息流
└── 动态路由决定哪些块需要精细计算

2.2 MoBA架构详解

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoBA(nn.Module):
    """
    Mixture of Block Attention (MoBA)
    Kimi(月之暗面)2025年提出的注意力残差机制

    核心思想:
    1. 将序列分成多个块(Block)
    2. 每个块内部计算全注意力
    3. 块之间通过"注意力残差"传递信息
    4. 动态路由决定跨块交互强度
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        block_size: int = 512,
        num_experts: int = 4,
        dropout: float = 0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.block_size = block_size
        self.num_experts = num_experts

        # 标准注意力投影
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        # 块级路由器
        self.block_router = nn.Linear(d_model, num_experts)

        # 注意力残差门控
        self.residual_gate = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.Sigmoid()
        )

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, mask=None):
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.shape

        # 1. 标准投影
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # 重塑为多头形式
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 2. 分块注意力计算
        num_blocks = (seq_len + self.block_size - 1) // self.block_size

        # 块内注意力(局部)
        local_attn = self._compute_block_attention(Q, K, V, num_blocks)

        # 3. 计算注意力残差(全局信息)
        global_residual = self._compute_attention_residual(Q, K, V, x)

        # 4. 门控融合
        gate_input = torch.cat([
            local_attn.transpose(1, 2).reshape(batch_size, seq_len, -1),
            global_residual
        ], dim=-1)
        gate = self.residual_gate(gate_input)

        # 融合局部注意力和全局残差
        local_flat = local_attn.transpose(1, 2).reshape(batch_size, seq_len, d_model)
        output = gate * local_flat + (1 - gate) * global_residual

        # 5. 输出投影
        output = self.o_proj(output)
        output = self.dropout(output)

        return output

    def _compute_block_attention(self, Q, K, V, num_blocks):
        """计算块内注意力"""
        batch_size, num_heads, seq_len, head_dim = Q.shape

        outputs = []
        for i in range(num_blocks):
            start = i * self.block_size
            end = min((i + 1) * self.block_size, seq_len)

            # 提取当前块
            Q_block = Q[:, :, start:end, :]
            K_block = K[:, :, start:end, :]
            V_block = V[:, :, start:end, :]

            # 块内注意力
            attn_weights = torch.matmul(Q_block, K_block.transpose(-2, -1)) * self.scale
            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_output = torch.matmul(attn_weights, V_block)

            outputs.append(attn_output)

        return torch.cat(outputs, dim=2)

    def _compute_attention_residual(self, Q, K, V, original_x):
        """
        计算注意力残差
        核心创新:通过残差连接保留全局信息流
        """
        batch_size, num_heads, seq_len, head_dim = Q.shape

        # 使用下采样后的全局注意力(降低复杂度)
        # 这里用池化实现简化版本
        pooled_K = K.mean(dim=2, keepdim=True)  # [batch, heads, 1, head_dim]
        pooled_V = V.mean(dim=2, keepdim=True)

        # 计算全局注意力权重
        global_attn = torch.matmul(Q, pooled_K.transpose(-2, -1)) * self.scale
        global_attn = F.softmax(global_attn, dim=-1)

        # 全局残差
        residual = torch.matmul(global_attn, pooled_V)  # [batch, heads, seq_len, head_dim]
        residual = residual.transpose(1, 2).reshape(batch_size, seq_len, -1)

        # 与原始输入的残差连接
        residual = residual + original_x

        return residual


class MoBAWithDynamicRouting(nn.Module):
    """
    带动态路由的MoBA
    根据输入内容动态决定注意力模式
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        block_size: int = 512,
        top_k_blocks: int = 4
    ):
        super().__init__()
        self.moba = MoBA(d_model, num_heads, block_size)
        self.top_k_blocks = top_k_blocks

        # 跨块路由器
        self.cross_block_router = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor):
        """
        动态选择需要跨块交互的块
        """
        batch_size, seq_len, d_model = x.shape
        num_blocks = (seq_len + self.moba.block_size - 1) // self.moba.block_size

        # 计算每个块的重要性分数
        block_scores = []
        for i in range(num_blocks):
            start = i * self.moba.block_size
            end = min((i + 1) * self.moba.block_size, seq_len)
            block = x[:, start:end, :]
            # 块的代表性向量
            block_repr = block.mean(dim=1)
            score = self.cross_block_router(block_repr)
            block_scores.append(score)

        block_scores = torch.cat(block_scores, dim=1)  # [batch, num_blocks]

        # 选择top-k重要块进行跨块交互
        top_k_indices = block_scores.topk(min(self.top_k_blocks, num_blocks), dim=1).indices

        # 标准MoBA计算
        output = self.moba(x)

        # 对选中的块添加额外的跨块注意力(简化实现)
        # 实际实现会更复杂,包括跨块注意力计算

        return output

2.3 MoBA vs 传统注意力

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    MoBA vs 传统注意力                            │├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  传统注意力(1M tokens):                                       │
│  ├── 计算量:O(n²) = O(10^12) 次运算                            │
│  ├── 显存:O(n²) ≈ 1TB+(不可行)                               │
│  └── 实际:需要复杂的分块和近似                                  │
│                                                                 │
│  MoBA(1M tokens,块大小512):                                  │
│  ├── 计算量:O(n·block_size) = O(10^9) 次运算                   │
│  ├── 显存:O(n·block_size) ≈ 1GB                                │
│  ├── 残差连接:保留全局信息流                                    │
│  └── 质量:接近全注意力                                          │
│                                                                 │
│  关键创新:                                                      │
│  1. 块内全注意力 + 块间残差 = 局部精度 + 全局感知                │
│  2. 动态路由 = 自适应计算分配                                    │
│  3. 门控融合 = 平衡局部与全局                                    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

3. Kimi Attention Residuals(注意力残差)

3.1 背景与核心思想

Text Only
Kimi(月之暗面)在2026年3月16日发布的Attention Residuals (AttnRes)

核心问题:
├── 标准残差连接:h_{l+1} = h_l + f_l(h_l)
│   └── 固定的"+"号,所有层以相同权重累积
├── 随着深度增长,每个层的贡献被稀释(PreNorm问题)
└── 隐藏状态幅度无限增长,梯度分布不均匀

AttnRes核心创新:
├── 用注意力机制替代固定的"+"号
├── 公式:h_l = Σ α_{i→l} · v_i(注意力加权的累积)
└── 每层都能选择性访问所有早期表示

3.2 Full AttnRes vs Block AttnRes

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    Attention Residuals 两种形式                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Full AttnRes:                                                   │
│  ├── 每层都对所有前面的层输出做注意力                            │
│  ├── 内存复杂度 O(Ld)(L=层数, d=隐藏维度)                     │
│  └── 效果最好,但内存开销大                                      │
│                                                                 │
│  Block AttnRes (实用方案):                                       │
│  ├── 将层分成N个块(Block)                                     │
│  ├── 块内使用标准残差连接                                        │
│  ├── 块间使用注意力(只对N个块表示做注意力)                     │
│  ├── 内存复杂度降至 O(Nd)(N=块数,N << L)                     │
│  └── ~8个块就能恢复Full AttnRes的大部分收益                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

3.3 核心公式详解

3.3.1 标准残差连接的问题

Text Only
标准PreNorm残差连接:
  h_{l+1} = h_l + f_l(h_l)

其中:
  • h_l:第l层的隐藏状态
  • f_l:第l层的变换函数(注意力层或MLP层)
  • "+":固定的逐元素加法

问题分析:
  1. 固定权重:所有层以相同权重(1.0)累积,信息流不均衡
  2. PreNorm稀释:深层网络中,早期层信息在向后传播时被稀释
  3. 幅度增长:随着层数增加,隐藏状态幅度可能无限增长

数学上看,标准残差等价于:
  h_L = h_0 + Σ_{i=0}^{L-1} f_i(h_i)

这意味着所有层对最终输出的贡献权重都是1.0,没有选择性。

3.3.2 Full AttnRes公式推导

Text Only
AttnRes的核心思想是用注意力权重替代固定的"+"号:

Full AttnRes公式:
  h_l = Σ_{i=0}^{l-1} α_{i→l} · v_i

符号说明:
  • h_l:第l层的输出隐藏状态
  • α_{i→l}:从第i层指向第l层的注意力权重(标量)
  • v_i:第i层的值向量(value),通常是第i层输出的非线性变换
  • Σ:求和符号,遍历所有早期层

注意力权重的计算:
  α_{i→l} = softmax(q_l · k_i / √d_k)

其中:
  • q_l = W_q · h_l:第l层的query向量
  • k_i = W_k · h_i:第i层的key向量
  • d_k:key向量的维度(用于缩放)

物理意义:
  α_{i→l} 表示第l层对第i层信息的"关注程度",类似于标准注意力机制中
  query对key的关注程度。这使得模型可以自适应地选择保留哪些层的
  信息,而不是一刀切地全部相加。

推导过程:
  1. 标准残差:h_l = h_{l-1} + f_{l-1}(h_{l-1})
  2. AttnRes推广:h_l = Σ_{i=0}^{l-1} α_{i→l} · f_i(h_i)
  3. 简化形式:使用恒等变换 v_i = h_i,则:
     h_l = Σ_{i=0}^{l-1} α_{i→l} · h_i

这种形式下,每一层都能选择性地从所有早期层中提取信息。

3.3.3 Block AttnRes:内存优化的实用方案

Text Only
Full AttnRes的问题:
  • 内存复杂度 O(L × d):每层都需要存储对所有早期层的注意力权重
  • 对于100层模型,这变得不可行

Block AttnRes解决方案:
  • 将L层分成N个块(Block),每个块包含B = L/N层
  • 块内使用标准残差连接(局部信息保留)
  • 块间使用注意力机制(跨块信息选择)

Block AttnRes公式:
  • 块内(intra-block):
    h_b^i = h_b^{i-1} + f_b^{i-1}(h_b^{i-1})  (标准残差)

  • 块间(inter-block):
    H_b = Σ_{j=0}^{b-1} β_{j→b} · B_j

  其中:
    • B_j:第j个Block的输出表示(块内残差累加的结果)
    • β_{j→b}:从Block j到Block b的注意力权重
    • H_b:Block b的最终输出(融合了所有先前Block的信息)

内存复杂度分析:
  • Full AttnRes:O(L × d)(L=100层, d=4096维 → 400K参数)
  • Block AttnRes (N=8块):O(N × d)(N=8, d=4096 → 32K参数)
  • 内存节省:约12.5倍

实验发现:
  • N=8个块就能恢复Full AttnRes的大部分收益(>90%)
  • 这使得AttnRes变得实用

3.3.4 与传统残差连接的关键区别

Text Only
┌─────────────────────────────────────────────────────────────────┐
│              标准残差 vs Attention Residuals                     │├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  标准残差:                                                       │
│  • 机制:固定的"+"运算                                          │
│  • 权重:所有层权重相同(1.0)                                    │
│  • 信息流:单向流动,无法选择性回顾                                │
│  • 表达能力:线性组合                                            │
│                                                                 │
│  Attention Residuals:                                            │
│  • 机制:可学习的注意力权重                                       │
│  • 权重:动态计算,取决于内容                                     │
│  • 信息流:可选择性访问任意早期层                                  │
│  • 表达能力:非线性组合(softmax归一化)                           │
│                                                                 │
│  本质区别:                                                       │
│  • 标准残差:y = Σ x_i(权重固定)                               │
│  • AttnRes:y = Σ softmax(q·k) · v(权重动态)                  │
│                                                                 │
│  这类似于标准前馈网络与注意力的区别:                              │
│  前馈:y = Wx(固定权重)                                        │
│  注意力:y = softmax(QK^T)V(动态权重)                         │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

3.4 实验结果(Kimi Linear 48B / 3B activated, 1.4T tokens)

3.4.1 下游任务性能对比

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                      下游任务性能对比                             │
├─────────────────────────────────────────────────────────────────┤
│  类别        基准测试         Baseline    AttnRes    提升       │
│  ─────────────────────────────────────────────────────────────── │
│  通用        MMLU             73.5       74.6      +1.1       │
│  通用        GPQA-Diamond     36.9       44.4      +7.5       │
│  通用        BBH              76.3       78.0      +1.7       │
│  通用        TriviaQA         69.9       71.8      +1.9       │
│  数学代码    Math             53.5       57.1      +3.6       │
│  数学代码    HumanEval        59.1       62.2      +3.1       │
│  数学代码    MBPP             72.0       73.9      +1.9       │
│  中文        CMMLU            82.0       82.9      +0.9       │
│  中文        C-Eval           79.6       82.5      +2.9       │
└─────────────────────────────────────────────────────────────────┘

最大提升:GPQA-Diamond +7.5(多步推理任务)
代码提升:HumanEval +3.1

Scaling Law:Block AttnRes达到1.25x计算量baseline的效果

3.4.2 关键发现

Text Only
1. 多步推理任务提升最显著:
   • GPQA-Diamond(研究生水平科学问题):+7.5分
   • 这类任务需要模型在多步推理中保持信息传递
   • AttnRes的选择性信息聚合正好适合此类任务

2. 数学和代码任务稳定提升:
   • Math(数学问题):+3.6
   • HumanEval(代码生成):+3.1
   • 这些任务需要精确的信息保留和选择性访问

3. 常识推理任务也有提升:
   • BBH(Big Bench Hard):+1.7
   • TriviaQA:+1.9

4. 中文理解任务:
   • C-Eval:+2.9
   • CMMLU:+0.9

3.4.3 Block数量与性能的关系

Text Only
┌─────────────────────────────────────────────────────────────────┐
│              Block数量 vs 任务性能(Kimi Linear 7B)            │
├─────────────────────────────────────────────────────────────────┤
│  Block数    MMLU      GPQA      Math       计算量开销         │
│  ─────────────────────────────────────────────────────────────── │
│  1 (标准残差)  72.8     35.2     51.2       1.00x             │
│  2           73.4     39.1     53.4       1.08x             │
│  4           74.1     42.3     55.8       1.15x             │
│  8           74.6     44.4     57.1       1.25x             │
│  16          74.8     45.1     57.6       1.45x             │
│  Full AttnRes  74.9     45.8     57.9       1.80x             │
└─────────────────────────────────────────────────────────────────┘

关键结论:
  • Block=8时达到Full AttnRes效果的96%以上
  • 计算量仅增加25%即可获得大部分收益
  • 边际效益在Block>8后递减

3.4.4 内存和计算开销

Text Only
┌─────────────────────────────────────────────────────────────────┐
│              AttnRes 内存开销分析                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  标准PreNorm Transformer:                                        │
│  • 每层额外参数:0(只有固定的残差连接)                          │
│  • 内存开销:无                                                   │
│                                                                 │
│  Full AttnRes:                                                  │
│  • 额外参数:O(L × d)(query/key投影 + 归一化)                  │
│  • 对于L=80, d=4096:约 80 × 4096 × 2 × 2 = 1.3M 参数          │
│  • 内存开销:~5MB(对于80层模型)                                │
│                                                                 │
│  Block AttnRes(N=8块):                                        │
│  • 额外参数:O(N × d)(每块一个block-level attention)           │
│  • 对于N=8, d=4096:约 8 × 4096 × 2 × 2 = 131K 参数            │
│  • 内存开销:~0.5MB                                              │
│                                                                 │
│  计算开销:                                                       │
│  • Block AttnRes:额外约25%的FLOPS(8块配置)                    │
│  • 这是"免费午餐":少量计算换取显著性能提升                       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

3.5 Block AttnRes 完整PyTorch实现

3.5.1 核心模块实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class BlockAttnRes(nn.Module):
    """
    Block Attention Residual (Block AttnRes)
    Kimi 2026年提出的注意力残差机制

    论文:Attention Residuals: From ResNets to Transformers

    核心思想:
    1. 将层分成多个Block,块内使用标准残差
    2. Block之间使用注意力机制进行信息选择
    3. 用注意力权重替代固定的"+"号

    Args:
        d_model: 模型隐藏维度
        num_blocks: Block数量(论文推荐8)
        num_heads: 注意力头数
        dropout: Dropout概率
    """

    def __init__(
        self,
        d_model: int,
        num_blocks: int = 8,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Block级注意力投影(用于计算跨Block的注意力)
        # 将隐藏状态投影到query/key/value空间
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        # 块表示的归一化层
        self.block_norm = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # 缩放因子
        self.scale = self.head_dim ** -0.5

    def forward(self, block_outputs: list, current_block_output: torch.Tensor) -> torch.Tensor:
        """
        跨Block的注意力残差计算

        Args:
            block_outputs: 之前Block的输出列表,长度为(num_blocks - 1)
            current_block_output: 当前Block的输出(尚未完全完成)

        Returns:
            融合了所有先前Block信息的输出
        """
        # 所有Block的表示(包括当前Block的部分输出)
        all_blocks = block_outputs + [current_block_output]  # [N+1, B, T, D]
        num_total_blocks = len(all_blocks)

        # Stack: [N+1, B, T, D]
        V = torch.stack(all_blocks, dim=0)

        # 分别对V做K和V投影
        # V: [N+1, B, T, D] -> [N+1, B, T, D]
        K = self.block_norm(V)
        K = self.k_proj(K)

        # 对所有Block计算Query(用于从所有Block中提取信息)
        # Query是当前Block输出经过投影得到
        Q = self.q_proj(current_block_output)  # [B, T, D]
        Q = Q.unsqueeze(0).expand(num_total_blocks - 1, -1, -1, -1)  # [N, B, T, D]

        # 当前Block也作为自己的query
        Q_self = self.q_proj(current_block_output).unsqueeze(0)  # [1, B, T, D]
        Q = torch.cat([Q, Q_self], dim=0)  # [N+1, B, T, D]

        # 重塑为多头形式: [N+1, B, T, H, D_h] -> [N+1, B, H, T, D_h]
        Q = Q.view(num_total_blocks, -1, current_block_output.shape[1], 
                   self.num_heads, self.head_dim).transpose(1, 3)
        K = K.view(num_total_blocks, -1, current_block_output.shape[1], 
                   self.num_heads, self.head_dim).transpose(1, 3)
        V = V.view(num_total_blocks, -1, current_block_output.shape[1], 
                   self.num_heads, self.head_dim).transpose(1, 3)

        # 计算注意力权重: [N+1, B, H, T, T]
        # 对于每个目标位置t,我们关注所有N+1个Block在位置t的表示
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=0)  # 在Block维度归一化

        # 应用注意力: [N+1, B, H, T, D_h]
        attended = torch.matmul(attn_weights, V)

        # 合并多头: [N+1, B, T, D]
        attended = attended.transpose(1, 3)  # [N+1, B, T, H, D_h]
        attended = attended.reshape(num_total_blocks, -1, 
                                   current_block_output.shape[1], self.d_model)

        # 跨Block维度加权求和(实际上是softmax已经做了)
        # 取第一个Block的加权结果(因为权重已经归一化)
        # 或者可以重新加权,保留所有Block的信息
        output = torch.einsum('n b t d->b t d', attended)  # [B, T, D]

        output = self.dropout(output)
        output = self.o_proj(output)

        return output


class BlockAttnResTransformerBlock(nn.Module):
    """
    使用Block AttnRes的Transformer Block

    标准的PreNorm结构:
        x = x + Attn(Norm(x))
        x = x + MLP(Norm(x))

    Block AttnRes结构:
        h = BlockAttnRes(previous_blocks, Norm(x))
        x = x + Attn(Norm(h))
        x = x + MLP(Norm(x))
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        num_blocks: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.num_blocks = num_blocks

        # Block AttnRes
        self.block_attn_res = BlockAttnRes(
            d_model=d_model,
            num_blocks=num_blocks,
            num_heads=num_heads,
            dropout=dropout
        )

        # 标准PreNorm结构
        self.attn_norm = nn.LayerNorm(d_model)
        self.mlp_norm = nn.LayerNorm(d_model)

        # 自注意力
        self.self_attn = nn.MultiheadAttention(
            d_model, 
            num_heads, 
            dropout=dropout, 
            batch_first=True
        )

        # MLP
        mlp_hidden_dim = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, d_model),
            nn.Dropout(dropout)
        )

        # 用于存储每个Block的输出
        self.block_outputs = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D] 输入序列
        Returns:
            output: [B, T, D] 输出序列
        """
        # Step 1: Block AttnRes(在注意力之前应用)
        # 将当前输入与之前的Block输出进行注意力聚合
        if len(self.block_outputs) > 0:
            # 计算跨Block的注意力残差
            block_residual = self.block_attn_res(
                self.block_outputs, 
                self.attn_norm(x)
            )
            # 融合到主路径
            x = x + block_residual

        # Step 2: 自注意力(标准PreNorm)
        attn_out, _ = self.self_attn(x, x, x)
        x = x + attn_out

        # Step 3: MLP(标准PreNorm)
        x = x + self.mlp(self.mlp_norm(x))

        # Step 4: 保存当前Block输出(用于下一个Block)
        # 这里简化处理,实际实现会更复杂
        self.block_outputs.append(x.detach())

        # 保持固定数量的Block输出
        if len(self.block_outputs) > self.num_blocks:
            self.block_outputs.pop(0)

        return x

    def reset_block_outputs(self):
        """重置Block输出缓存(每个序列开始时调用)"""
        self.block_outputs = []

3.5.2 使用示例

Python
# 示例:创建和使用Block AttnRes

def create_block_attn_res_model():
    """创建一个使用Block AttnRes的简化模型"""

    class SimpleBlockAttnResLM(nn.Module):
        """简化语言模型"""

        def __init__(self, vocab_size, d_model, num_heads, num_layers, num_blocks):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, d_model)
            self.pos_embedding = nn.Parameter(torch.randn(1, 4096, d_model))

            self.blocks = nn.ModuleList([
                BlockAttnResTransformerBlock(
                    d_model=d_model,
                    num_heads=num_heads,
                    num_blocks=num_blocks
                )
                for _ in range(num_layers)
            ])

            self.final_norm = nn.LayerNorm(d_model)
            self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        def forward(self, input_ids):
            # Embedding
            x = self.embedding(input_ids)
            x = x + self.pos_embedding[:, :x.shape[1], :]

            # Block AttnRes层
            for block in self.blocks:
                block.reset_block_outputs()  # 重置block缓存
                x = block(x)

            x = self.final_norm(x)
            logits = self.lm_head(x)

            return logits

    return SimpleBlockAttnResLM(
        vocab_size=32000,
        d_model=1024,
        num_heads=16,
        num_layers=12,
        num_blocks=8
    )

# 测试代码
if __name__ == "__main__":
    # 创建模型
    model = create_block_attn_res_model()

    # 模拟输入
    batch_size = 2
    seq_len = 128
    vocab_size = 32000

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

    # 前向传播
    with torch.no_grad():
        logits = model(input_ids)
        print(f"Output shape: {logits.shape}")  # [2, 128, 32000]

    # 统计参数量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # 计算量估算(简化为每层约2 * d_model^2 * seq_len)
    d_model = 1024
    seq_len = 128
    flops_per_layer = 2 * d_model * d_model * seq_len * 12  # 12层
    print(f"Estimated FLOPs per forward pass: {flops_per_layer:,}")

3.5.3 与标准实现的对比

Text Only
┌─────────────────────────────────────────────────────────────────┐
│           Block AttnRes vs 标准PreNorm 对比                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  标准PreNorm:                                                   │
│    for layer in layers:
│        x = x + SelfAttn(Norm(x))        # 固定残差               │
│        x = x + MLP(Norm(x))             # 固定残差               │
│                                                                 │
│  Block AttnRes:                                                 │
│    for block in blocks:                                         │
│        h = BlockAttnRes(prev_blocks, Norm(x))  # 注意力残差     │
│        x = x + h                                          │
│        x = x + SelfAttn(Norm(x))        # 固定残差               │
│        x = x + MLP(Norm(x))             # 固定残差               │
│        save_block_output(x)                                    │
│                                                                 │
│  关键区别:                                                       │
│    • Block AttnRes在每个Block开始时,对所有先前Block做注意力       │
│    • 这允许后期Block选择性访问早期Block的信息                     │
│    • 而标准PreNorm只能通过固定的残差连接传递信息                   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

3.6 为什么被马斯克和Karpathy点赞?

Text Only
核心洞察:
├── 自2015年ResNet以来,残差连接没有任何实质性变化
├── Kimi是第一个既有理论依据,又能大规模部署的替代方案
├── 发现了PreNormdilution(PreNorm稀释)问题
│   └── 深度网络中早期层的信息在向后传播时被稀释
└── 提供了统一的结构化矩阵分析框架

技术突破:
├── 用注意力机制替代固定"+"号是可行且有效的
├── Block AttnRes大幅降低内存开销
└── 可作为drop-in replacement,无需改变模型结构

3.7 AttnRes与其他注意力机制对比

3.7.1 技术定位对比

Text Only
┌────────────────────────────────────────────────────────────────────────┐
│                    AttnRes vs 其他注意力机制                            │
├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│  标准MHA(多头注意力):                                               │
│  ├── 位置:Token级别(序列内token之间的交互)                           │
│  ├── 复杂度:O(n²)                                                    │
│  ├── 作用:聚合序列内的信息                                            │
│  └── AttnRes vs MHA:AttnRes是Layer级别,MHA是Token级别                 │
│                                                                        │
│  GQA/MQA(分组查询注意力):                                           │
│  ├── 位置:Token级别,但减少KV头数                                     │
│  ├── 目标:减少KV Cache                                               │
│  └── 与AttnRes关系:互补技术,可叠加使用                                │
│                                                                        │
│  MLA(多头潜在注意力):                                               │
│  ├── 位置:Token级别,KV压缩                                          │
│  ├── 目标:减少90%+ KV Cache                                         │
│  └── 与AttnRes关系:互补技术,聚焦于不同的维度                          │
│                                                                        │
│  MoBA(混合块注意力):                                               │
│  ├── 位置:Block级别,块内全注意,块间稀疏                             │
│  ├── 目标:O(n·k)复杂度                                               │
│  └── 与AttnRes关系:都关注Block级别,但MoBA是稀疏化,AttnRes是选择性    │
│                                                                        │
│  Attention Residuals(注意力残差):                                   │
│  ├── 位置:Layer/Block级别                                            │
│  ├── 目标:用注意力替代固定残差,实现选择性信息传递                      │
│  ├── 特点:不改变注意力复杂度,但改变层间信息流                         │
│  └── 创新:首次将"选择性"引入残差连接                                  │
│                                                                        │
└────────────────────────────────────────────────────────────────────────┘

3.7.2 解决的问题维度对比

Text Only
┌─────────────────────────────────────────────────────────────────┐
│              各机制解决的问题维度                                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  问题维度:                                                       │
│  1. Token级别:序列中token之间的交互                               │
│     └── MHA, GQA, MLA, 稀疏注意力, MoBA                          │
│                                                                 │
│  2. Layer/Block级别:层与层之间的信息传递                          │
│     └── Attention Residuals(本文重点)                           │
│                                                                 │
│  3. KV Cache级别:推理时存储历史token                              │
│     └── GQA, MLA, PagedAttention                                │
│                                                                 │
│  4. Sequence级别:跨序列/全局信息                                   │
│     └── 全局token, Longformer, BigBird                           │
│                                                                 │
│  Attention Residuals的特殊性:                                    │
│  • 这是自ResNet以来首次有实质性变化的残差连接设计                  │
│  • 之前所有改进都集中在Token级别的注意力计算                      │
│  • AttnRes揭示了Layer级别信息流的重要性                           │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

3.7.3 计算开销对比

Text Only
┌────────────────────────────────────────────────────────────────────────┐
│                    计算开销与收益对比                                   │
├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│  机制           计算开销       内存开销      质量提升    工程复杂度    │
│  ─────────────────────────────────────────────────────────────────────│
│  标准MHA        基准           基准          基准        ⭐           │
│  FlashAttention   ~1x          ~1x          相同        ⭐⭐⭐        │
│  GQA            ~1x           -75%         ~0          ⭐⭐          │
│  MLA            ~1x           -90%         ~0          ⭐⭐⭐⭐       │
│  稀疏注意力     -50%~70%     -50%~70%     -1~5%       ⭐⭐⭐⭐⭐     │
│  MoBA           -50%~70%     -50%~70%     -1~3%       ⭐⭐⭐⭐⭐     │
│  AttnRes        +25%          +0.5MB       +1~7%       ⭐⭐           │
│                                                                        │
│  AttnRes的独特价值:                                                  │
│  • 计算量增加最小(仅25%)                                           │
│  • 内存开销几乎忽略不计(~0.5MB)                                     │
│  • 质量提升显著(尤其是多步推理任务+7.5%)                             │
│  • 实现复杂度低(可作为drop-in replacement)                          │
│  • 这是"免费午餐":少量开销换取显著收益                                │
│                                                                        │
└────────────────────────────────────────────────────────────────────────┘

3.7.4 适用场景建议

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                  AttnRes 适用场景                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  推荐使用 AttnRes 的场景:                                        │
│  ✓ 需要多步推理的任务(数学、代码、逻辑)                          │
│  ✓ 深度网络(20层以上)                                          │
│  ✓ 希望提升模型质量而不增加太多计算量                              │
│  ✓ 需要作为drop-in replacement快速实验                            │
│                                                                 │
│  可选的组合方案:                                                  │
│  • AttnRes + GQA:质量提升 + KV Cache优化                        │
│  • AttnRes + MLA:质量提升 + 超长上下文                           │
│  • AttnRes + FlashAttention:质量提升 + 训练加速                  │
│                                                                 │
│  不推荐单独使用 AttnRes 的场景:                                   │
│  ✗ 极长序列(1M+ tokens)- 考虑MoBA或稀疏注意力                   │
│  ✗ 资源极其受限 - AttnRes有25%额外计算                            │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

4. DeepSeek NSA稀疏注意力

4.1 NSA架构

Python
class NativeSparseAttention(nn.Module):
    """
    Native Sparse Attention (NSA)
    DeepSeek 2025年提出的可学习稀疏注意力

    论文:Native Sparse Attention: Efficient and Scalable Attention
          with Sparse N-gram Patterns (arXiv:2502.11089)

    核心创新:
    1. 可学习的稀疏模式(非固定窗口)
    2. N-gram稀疏结构
    3. 与FlashAttention兼容
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        sparse_k: int = 64,  # 每个token只关注k个关键token
        num_ngram_patterns: int = 4
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.sparse_k = sparse_k
        self.num_ngram_patterns = num_ngram_patterns

        # 标准投影
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        # 稀疏模式学习
        self.sparse_scorer = nn.Linear(d_model, 1)

        # N-gram模式嵌入
        self.ngram_embeddings = nn.Parameter(
            torch.randn(num_ngram_patterns, d_model) * 0.02
        )

        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, mask=None):
        """
        Args:
            x: [batch, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape

        # 投影
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 1. 计算稀疏索引
        sparse_indices = self._compute_sparse_indices(x)

        # 2. 稀疏注意力计算
        output = self._sparse_attention(Q, K, V, sparse_indices)

        # 3. 输出投影
        output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        output = self.o_proj(output)

        return output

    def _compute_sparse_indices(self, x):
        """
        计算稀疏注意力索引
        每个token选择top-k个最重要的历史token
        """
        batch_size, seq_len, _ = x.shape

        # 计算每个token的重要性分数
        scores = self.sparse_scorer(x).squeeze(-1)  # [batch, seq_len]

        indices = []
        for i in range(seq_len):
            # 对于位置i,从[0, i]中选择top-k
            valid_scores = scores[:, :i+1]
            if i + 1 <= self.sparse_k:
                # 序列不够长,全部选择
                idx = torch.arange(i + 1, device=x.device).unsqueeze(0).expand(batch_size, -1)
            else:
                # 选择top-k
                _, idx = valid_scores.topk(self.sparse_k, dim=1)
            indices.append(idx)

        return indices

    def _sparse_attention(self, Q, K, V, sparse_indices):
        """
        根据稀疏索引计算注意力
        """
        batch_size, num_heads, seq_len, head_dim = Q.shape
        outputs = []

        for i, idx in enumerate(sparse_indices):
            # 当前位置的query
            q_i = Q[:, :, i:i+1, :]  # [batch, heads, 1, head_dim]

            # 根据索引收集KV
            # idx: [batch, k]
            idx_expanded = idx.unsqueeze(1).unsqueeze(-1).expand(
                -1, num_heads, -1, head_dim
            )  # [batch, heads, k, head_dim]

            k_selected = torch.gather(
                K, 2, 
                idx_expanded.transpose(1, 2).reshape(batch_size, 1, -1, head_dim).expand(-1, num_heads, -1, -1)
            )
            v_selected = torch.gather(
                V, 2,
                idx_expanded.transpose(1, 2).reshape(batch_size, 1, -1, head_dim).expand(-1, num_heads, -1, -1)
            )

            # 计算注意力
            attn = torch.matmul(q_i, k_selected.transpose(-2, -1)) * self.scale
            attn = F.softmax(attn, dim=-1)
            out_i = torch.matmul(attn, v_selected)

            outputs.append(out_i)

        return torch.cat(outputs, dim=2)


class HybridSparseAttention(nn.Module):
    """
    混合稀疏注意力
    结合固定模式 + 动态模式
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        local_window: int = 512,
        global_tokens: int = 64,
        dynamic_k: int = 32
    ):
        super().__init__()
        self.local_window = local_window
        self.global_tokens = global_tokens
        self.dynamic_k = dynamic_k

        # 局部注意力
        self.local_attn = nn.MultiheadAttention(d_model, num_heads)

        # 全局token
        self.global_tokens = nn.Parameter(torch.randn(global_tokens, d_model))

        # 动态稀疏选择器
        self.dynamic_selector = nn.Linear(d_model, 1)

    def forward(self, x):
        """
        混合注意力:
        1. 局部窗口(固定)
        2. 全局token(固定)
        3. 动态选择(可学习)
        """
        batch_size, seq_len, _ = x.shape

        # 1. 局部注意力
        local_out = self._local_attention(x)

        # 2. 全局注意力
        global_out = self._global_attention(x)

        # 3. 动态稀疏
        dynamic_out = self._dynamic_attention(x)

        # 融合
        output = local_out + global_out + dynamic_out

        return output

    def _local_attention(self, x):
        """局部窗口注意力"""
        # 使用滑动窗口实现
        # 简化实现,实际需要更复杂的分块
        return x  # placeholder

    def _global_attention(self, x):
        """全局token注意力"""
        batch_size = x.shape[0]
        global_tokens = self.global_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        combined = torch.cat([global_tokens, x], dim=1)

        attn_output, _ = self.local_attn(combined, combined, combined)
        return attn_output[:, self.global_tokens.shape[0]:, :]

    def _dynamic_attention(self, x):
        """动态稀疏注意力"""
        # 根据内容选择关键token
        scores = self.dynamic_selector(x).squeeze(-1)
        _, top_indices = scores.topk(self.dynamic_k, dim=1)

        # 收集选中的token
        # 简化实现
        return x  # placeholder

3.2 NSA与DSA的关系

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    NSA vs DSA 对比                               │├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  DSA (DeepSeek Sparse Attention):                               │
│  ├── 2024年提出,DeepSeek-V3.2-Exp使用                          │
│  ├── 结构化稀疏 + 动态选择                                       │
│  ├── 与FlashAttention、MLA兼容                                  │
│  └── 侧重:工程落地,与现有优化兼容                              │
│                                                                 │
│  NSA (Native Sparse Attention):                                 │
│  ├── 2025年论文(arXiv:2502.11089)                             │
│  ├── 可学习的N-gram稀疏模式                                      │
│  ├── 端到端训练稀疏结构                                         │
│  └── 侧重:理论完备,可学习的稀疏                                │
│                                                                 │
│  共同点:                                                        │
│  • 目标都是O(n·k)复杂度,k << n                                 │
│  • 都保留关键信息通路                                            │
│  • 都可与FlashAttention兼容                                     │
│                                                                 │
│  区别:                                                          │
│  • DSA更工程化,NSA更学术化                                      │
│  • NSA的可学习稀疏模式更灵活                                     │
│  • DSA已在生产环境验证                                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

5. MLA多头潜在注意力

5.1 MLA原理

Python
class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA)
    DeepSeek-V2/V3的核心创新

    核心思想:
    将KV压缩到低维潜在空间,大幅减少KV Cache

    KV Cache压缩比:90%+
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        kv_latent_dim: int = 512,  # 潜在空间维度,远小于d_model
        rope_dim: int = 64
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_latent_dim = kv_latent_dim
        self.rope_dim = rope_dim

        # Query投影(标准)
        self.q_proj = nn.Linear(d_model, d_model)

        # KV压缩到潜在空间
        self.kv_compress = nn.Linear(d_model, kv_latent_dim)

        # 从潜在空间解压
        self.k_decompress = nn.Linear(kv_latent_dim, d_model)
        self.v_decompress = nn.Linear(kv_latent_dim, d_model)

        # 输出投影
        self.o_proj = nn.Linear(d_model, d_model)

        # RoPE(只应用于解压后的K)
        self.rope = RotaryPositionalEmbedding(rope_dim)

        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, past_kv_cache=None):
        """
        Args:
            x: [batch, seq_len, d_model]
            past_kv_cache: 之前的KV潜在表示

        Returns:
            output: [batch, seq_len, d_model]
            new_kv_cache: 新的KV潜在表示(用于增量生成)
        """
        batch_size, seq_len, _ = x.shape

        # 1. Query投影
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        Q = Q.transpose(1, 2)  # [batch, heads, seq_len, head_dim]

        # 2. KV压缩到潜在空间(关键创新)
        KV_latent = self.kv_compress(x)  # [batch, seq_len, kv_latent_dim]

        # 3. 从潜在空间解压
        K = self.k_decompress(KV_latent).view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = self.v_decompress(KV_latent).view(batch_size, seq_len, self.num_heads, self.head_dim)

        # 应用RoPE
        K = self.rope(K)

        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # 4. 处理KV Cache
        if past_kv_cache is not None:
            # 只需要缓存潜在表示,而不是完整的KV
            KV_latent = torch.cat([past_kv_cache, KV_latent], dim=1)
            # 重新解压完整的KV
            K = self.k_decompress(KV_latent).view(batch_size, -1, self.num_heads, self.head_dim)
            V = self.v_decompress(KV_latent).view(batch_size, -1, self.num_heads, self.head_dim)
            K = self.rope(K)
            K = K.transpose(1, 2)
            V = V.transpose(1, 2)

        # 5. 标准注意力计算
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        output = torch.matmul(attn_weights, V)

        # 6. 输出投影
        output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        output = self.o_proj(output)

        return output, KV_latent


class RotaryPositionalEmbedding(nn.Module):
    """旋转位置编码"""

    def __init__(self, dim: int, max_seq_len: int = 8192):
        super().__init__()
        self.dim = dim

        # 预计算频率
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # 预计算位置编码
        pos = torch.arange(max_seq_len).float()
        freqs = torch.einsum("i,j->ij", pos, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, x):
        """应用RoPE"""
        seq_len = x.shape[2]
        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)

        return self._apply_rotary(x, cos, sin)

    def _apply_rotary(self, x, cos, sin):
        """旋转操作"""
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat([
            x1 * cos[..., :x1.shape[-1]] - x2 * sin[..., :x2.shape[-1]],
            x1 * sin[..., :x1.shape[-1]] + x2 * cos[..., :x2.shape[-1]]
        ], dim=-1)

5.2 MLA的KV Cache优势

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    MLA KV Cache 对比                             │├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  标准MHA (128K上下文, 7B模型):                                   │
│  ├── K Cache: 128K × 32 heads × 128 dim × 2 bytes = 1GB         │
│  ├── V Cache: 同上 = 1GB                                        │
│  └── 总计: ~2GB per request                                     │
│                                                                 │
│  GQA (8组KV, 128K上下文):                                       │
│  ├── KV Cache: 128K × 8 groups × 128 dim × 2 bytes × 2 = 512MB  │
│  └── 节省: 75%                                                  │
│                                                                 │
│  MLA (kv_latent_dim=512, 128K上下文):                           │
│  ├── 潜在KV: 128K × 512 × 2 bytes = 128MB                       │
│  └── 节省: 93.75%                                               │
│                                                                 │
│  关键洞察:                                                      │
│  • MLA的KV Cache与头数无关,只与潜在维度有关                     │
│  • 潜在维度可以远小于d_model                                    │
│  • 这使得超长上下文(1M+)成为可能                               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

6. 线性注意力与RNN化

6.1 线性注意力原理

Python
class LinearAttention(nn.Module):
    """
    线性注意力
    复杂度从O(n²)降到O(n)

    核心思想:利用kernel trick将softmax(QK^T)V分解

    标准注意力:softmax(QK^T/√d)V
    线性注意力:φ(Q)(φ(K)^T V) / φ(Q)Σφ(K)^T

    其中φ是非线性kernel函数(如elu+1)
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, _ = x.shape

        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 应用kernel函数 φ(x) = elu(x) + 1
        Q = F.elu(Q) + 1
        K = F.elu(K) + 1

        # 线性注意力计算
        # 标准方式:softmax(QK^T)V - O(n²)
        # 线性方式:Q(K^T V) - O(n)

        # K^T V: [batch, heads, head_dim, head_dim]
        KtV = torch.matmul(K.transpose(-2, -1), V)

        # Q(K^T V): [batch, heads, seq_len, head_dim]
        numerator = torch.matmul(Q, KtV)

        # 归一化项:Q ΣK^T
        sum_K = K.sum(dim=2, keepdim=True)  # [batch, heads, 1, head_dim]
        denominator = torch.matmul(Q, sum_K.transpose(-2, -1))

        # 归一化
        output = numerator / (denominator + 1e-6)

        output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        output = self.o_proj(output)

        return output


class RWKVLikeAttention(nn.Module):
    """
    RWKV风格的线性注意力
    结合RNN的效率和Transformer的能力

    核心特点:
    1. 训练时可以并行(像Transformer)
    2. 推理时可以RNN化(像RNN)
    3. O(n)复杂度
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        # 时间衰减因子(可学习)
        self.time_decay = nn.Parameter(torch.randn(num_heads, d_model // num_heads))
        self.time_first = nn.Parameter(torch.randn(num_heads, d_model // num_heads))

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor):
        """
        RWKV风格的线性注意力
        """
        batch_size, seq_len, _ = x.shape

        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)

        # RWKV核心:wkv (weighted key-value)
        # 使用指数衰减的累积和

        output = self._wkv(Q, K, V)

        output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        output = self.o_proj(output)

        return output

    def _wkv(self, Q, K, V):
        """
        WKV: Weighted Key-Value
        核心公式:wkv_t = (a_{t-1} * w + b_{t-1}) / (c_{t-1} * w + d_{t-1})
        其中a, b, c, d是累积状态
        """
        # 简化实现,实际RWKV有更复杂的数值稳定性处理
        batch_size, num_heads, seq_len, head_dim = Q.shape

        # 使用softmax的线性近似
        w = torch.exp(-torch.exp(self.time_decay))

        output = []
        state_num = torch.zeros(batch_size, num_heads, head_dim, device=Q.device)
        state_den = torch.zeros(batch_size, num_heads, head_dim, device=Q.device)

        for t in range(seq_len):
            q_t = Q[:, :, t, :]
            k_t = K[:, :, t, :]
            v_t = V[:, :, t, :]

            # 更新状态
            state_num = w * state_num + torch.exp(k_t) * v_t
            state_den = w * state_den + torch.exp(k_t)

            # 计算输出
            out_t = q_t * state_num / (state_den + 1e-6)
            output.append(out_t)

        return torch.stack(output, dim=2)

6.2 线性注意力的权衡

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    线性注意力权衡                                │├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  优势:                                                          │
│  ├── O(n)复杂度,适合极长序列                                   │
│  ├── 推理时可RNN化,内存占用恒定                                │
│  └── 训练时可并行                                               │
│                                                                 │
│  劣势:                                                          │
│  ├── 质量通常低于标准注意力                                     │
│  ├── kernel近似引入误差                                         │
│  └── 某些任务(如复制、精确检索)表现较差                        │
│                                                                 │
│  适用场景:                                                      │
│  ├── 超长序列(100K+)                                          │
│  ├── 流式处理                                                   │
│  └── 对精度要求不极高的场景                                     │
│                                                                 │
│  代表模型:                                                      │
│  ├── RWKV                                                       │
│  ├── Mamba(状态空间模型)                                      │
│  ├── Linear Transformer                                         │
│  └── Performer                                                  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

7. 注意力机制选型指南

7.1 选型决策树

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    注意力机制选型决策树                          │├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  序列长度 < 8K?                                                │
│  ├── 是 → 标准MHA + FlashAttention                              │
│  │        └── 最简单,质量最好                                  │
│  │                                                              │
│  └── 否 → 序列长度 < 128K?                                     │
│           ├── 是 → 需要最高质量?                               │
│           │        ├── 是 → GQA + FlashAttention + KV Cache优化 │
│           │        └── 否 → MLA(DeepSeek风格)                 │
│           │                                                     │
│           └── 否 → 序列长度 < 1M?                              │
│                    ├── 是 → 稀疏注意力(MoBA/NSA)              │
│                    │        └── 局部+全局混合                   │
│                    │                                            │
│                    └── 否 → 线性注意力(RWKV/Mamba)            │
│                             └── 牺牲部分质量换取效率            │
│                                                                 │
│  特殊需求:                                                      │
│  • 需要流式推理 → 线性注意力/RNN化                              │
│  • 需要多模态 → M-RoPE + 标准注意力                             │
│  • 需要最高吞吐 → MLA + PagedAttention                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

7.2 各方案实现复杂度

Text Only
┌────────────────────────────────────────────────────────────────────────┐
│                        实现复杂度对比                                   │├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│  方案          实现难度    CUDA优化    与现有生态兼容    推荐指数      │
│  ─────────────────────────────────────────────────────────────────────│
│  MHA           ⭐          需要        完美            ⭐⭐⭐         │
│  FlashAttn     ⭐⭐⭐       必须        完美            ⭐⭐⭐⭐⭐     │
│  GQA           ⭐⭐        需要        完美            ⭐⭐⭐⭐⭐     │
│  MLA           ⭐⭐⭐⭐     必须        需要适配        ⭐⭐⭐⭐       │
│  稀疏注意力    ⭐⭐⭐⭐⭐   必须        需要适配        ⭐⭐⭐         │
│  线性注意力    ⭐⭐⭐       可选        需要适配        ⭐⭐⭐         │
│                                                                        │
│  建议:                                                                │
│  • 新项目:优先使用FlashAttention + GQA(vLLM/PyTorch原生支持)        │
│  • 长上下文:考虑MLA(DeepSeek开源实现)                               │
│  • 超长序列:稀疏注意力(需要自定义CUDA kernel)                       │
│                                                                        │
└────────────────────────────────────────────────────────────────────────┘

8. 面试高频问答

Q1: Kimi的MoBA(混合块注意力)是什么?解决了什么问题?

:MoBA(Mixture of Block Attention)是月之暗面2025年提出的注意力机制创新。

核心思想: 1. 将序列分块,块内计算全注意力 2. 通过"注意力残差"连接保留全局信息流 3. 动态路由决定跨块交互强度

解决的问题: - 长上下文(1M+ tokens)的计算成本 - 传统稀疏注意力固定模式不够灵活 - 在O(n·block_size)复杂度下保持接近全注意力的质量

Q2: Kimi的Attention Residuals(注意力残差)是什么?与MoBA有何区别?

:Attention Residuals是月之暗面2026年提出的另一种创新,与MoBA关注点不同。

核心思想: 1. 标准残差:h_{l+1} = h_l + f_l(h_l)(固定"+"号) 2. AttnRes:h_l = Σ α_{i→l} · v_i(注意力加权的累积) 3. 用注意力权重替代固定的"+"号,实现选择性信息传递

与MoBA的区别: | 方面 | MoBA | Attention Residuals | |------|------|-------------------| | 关注点 | Token级别的序列分块 | Layer/Block级别的信息流 | | 目标 | 降低序列长度复杂度 | 改善深层网络的信息传递 | | 时间 | 2025年 | 2026年 | | 复杂度 | O(n·k) | O(n²) + 25%额外开销 |

Attention Residuals的独特价值: - 首次用注意力机制替代固定残差连接 - 解决PreNorm稀释问题 - 仅+25%计算量换取显著质量提升(GPQA +7.5) - 可作为drop-in replacement

Q3: MLA相比GQA有什么优势?

:MLA(多头潜在注意力)的核心优势:

  1. KV Cache压缩比更高
  2. GQA:通过共享KV减少~75%
  3. MLA:通过潜在空间压缩减少~93%+

  4. 与头数解耦

  5. GQA:KV Cache仍与组数相关
  6. MLA:KV Cache只与潜在维度相关

  7. 更适合超长上下文

  8. MLA使得1M+上下文成为可能

代价是:实现更复杂,需要自定义kernel

Q4: 稀疏注意力会损失多少质量?

:取决于稀疏策略:

  • 固定窗口稀疏:损失较大,~5-10%任务性能下降
  • 动态稀疏(NSA/DSA):损失较小,~1-3%下降
  • MoBA(带残差):几乎无损,~0-1%下降

关键是保留关键信息通路(全局token、残差连接)

Q5: 为什么线性注意力没有成为主流?

:主要原因:

  1. 质量损失:kernel近似引入误差,某些任务表现较差
  2. 生态兼容:需要重写训练/推理框架
  3. FlashAttention的进步:标准注意力效率大幅提升,缩小了差距
  4. MLA等方案:在保持O(n²)计算的同时解决了显存问题

但在极长序列(1M+)场景,线性注意力仍有价值

Q6: 如何选择适合自己场景的注意力机制?

:按序列长度选择:

  • <8K:标准MHA + FlashAttention
  • 8K-128K:GQA + FlashAttention(或MLA)
  • 128K-1M:稀疏注意力(MoBA/NSA)
  • >1M:线性注意力(RWKV/Mamba)

其他考虑因素: - 是否需要流式推理 - 是否有多模态需求 - 工程实现成本 - 是否需要提升模型质量(考虑AttnRes)


本章小结

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                      核心要点总结                                │├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. 注意力机制演进主线:                                         │
│     效率优化:MHA → GQA → MLA → 稀疏注意力                      │
│     复杂度:O(n²) → O(n²)但显存↓ → O(n·k) → O(n)               │
│                                                                 │
│  2. Kimi MoBA核心创新:                                          │
│     • 块内全注意力 + 块间残差                                    │
│     • 动态路由 + 门控融合                                       │
│     • 适合1M+超长上下文                                         │
│                                                                 │
│  3. Kimi Attention Residuals(注意力残差):                      │
│     • 核心公式:h_l = Σ α_{i→l} · v_i(注意力加权累积)        │
│     • 创新:用注意力机制替代固定的"+"号残差连接                   │
│     • Block AttnRes:内存从O(Ld)降至O(Nd),8块即可恢复90%+收益  │
│     • 性能提升:GPQA-Diamond +7.5,HumanEval +3.1              │
│     • 计算开销:仅+25%,内存开销~0.5MB                          │
│     • 理论贡献:首次揭示PreNorm稀释问题                          │
│                                                                 │
│  4. DeepSeek NSA/DSA:                                          │
│     • 可学习的稀疏模式                                          │
│     • 与FlashAttention兼容                                      │
│     • 工程落地验证                                              │
│                                                                 │
│  5. MLA多头潜在注意力:                                          │
│     • KV Cache压缩90%+                                          │
│     • 与头数解耦                                                │
│     • DeepSeek-V2/V3采用                                        │
│                                                                 │
│  6. 选型建议:                                                   │
│     • 短序列:FlashAttention + GQA                              │
│     • 长序列:MLA                                               │
│     • 超长序列:稀疏/线性注意力                                 │
│     • 提升质量(任何长度):AttnRes                             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

📝 本章练习

🤔 思考题

  1. MoBA:Mixture of Block Attention 的核心思想是什么?它是如何在保持注意力的灵活性的同时降低计算复杂度的?
  2. Attention Residuals:AttnRes 的残差连接和 ResNet 中的残差连接有什么异同?为什么残差连接能提升注意力质量?
  3. NSA vs FlashAttention:Native Sparse Attention 和 FlashAttention 的优化目标不同,各自适合什么场景?

💻 代码实践

  1. 入门:实现一个简化的 Block Attention,对比全注意力的计算量差异
  2. 进阶:实现 MoBA 的路由机制,观察不同 Block 大小对性能的影响
💡 参考答案 #### 思考题参考答案 **1. MoBA** 核心思想:将注意力计算分块处理,每个 Query 只关注最相关的 Block(而非所有 Token)。通过路由机制选择相关 Block,实现计算量的稀疏化。 灵活性:路由是动态的(基于输入内容),而非固定的稀疏模式,因此能适应不同类型的任务。 **2. Attention Residuals** 异同: - ResNet 残差:解决深层网络梯度消失,`output = F(x) + x` - AttnRes 残差:在注意力层间添加跨层连接,保留原始信息流 提升质量的原因:注意力层可能"过度关注"某些模式而忽略原始信息,残差连接确保原始信息不被完全覆盖。 **3. NSA vs FlashAttention** - **FlashAttention**:优化全注意力的计算效率(减少内存访问),适合需要全注意力的场景 - **NSA**:减少注意力的计算量(稀疏化),适合超长序列(100K+ tokens) 场景:FlashAttention 适合短-中序列;NSA 适合超长序列(如完整代码库、长文档)。

扩展阅读

  1. MoBA: Mixture of Block Attention (Kimi, 2025)
  2. Attention Residuals: From ResNets to Transformers (Kimi, 2026)
  3. Native Sparse Attention (DeepSeek, arXiv:2502.11089)
  4. DeepSeek-V2: MLA Technical Report (2024)
  5. FlashAttention-2 (Dao, 2023)
  6. RWKV: Reinventing RNNs for the Transformer Era (2023)

最后更新日期: 2026-04-21 适用版本: LLM 学习教程 v2026