13. 注意力机制前沿创新¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
核心问题:Kimi新发布的注意力残差(Attention Residuals)是什么?当前注意力机制有哪些前沿创新?
目录¶
- 注意力机制演进全景
- Kimi MoBA (Block Attention)
- Kimi Attention Residuals(注意力残差)
- DeepSeek NSA稀疏注意力
- MLA多头潜在注意力
- 线性注意力与RNN化
- 注意力机制选型指南
- 面试高频问答
1. 注意力机制演进全景¶
1.1 演进时间线¶
┌─────────────────────────────────────────────────────────────────┐
│ 注意力机制演进时间线 │├─────────────────────────────────────────────────────────────────┤
│ │
│ 2017: 标准多头注意力 (MHA) │
│ ├── O(n²) 复杂度 │
│ ├── 所有token两两交互 │
│ └── 问题:长序列计算/显存爆炸 │
│ │
│ 2020: 稀疏注意力 (Longformer/BigBird) │
│ ├── 局部窗口 + 全局token │
│ ├── O(n) 复杂度 │
│ └── 问题:固定稀疏模式,灵活性差 │
│ │
│ 2021: FlashAttention │
│ ├── IO感知的分块计算 │
│ ├── 计算复杂度不变,但大幅减少显存访问 │
│ └── 成为2022-2024标配 │
│ │
│ 2022: GQA/MQA (分组查询注意力) │
│ ├── 多个查询头共享KV │
│ ├── 推理时KV Cache大幅减少 │
│ └── LLaMA 2/3采用 │
│ │
│ 2024: MLA (多头潜在注意力) │
│ ├── KV压缩到低维潜在空间 │
│ ├── KV Cache减少90%+ │
│ └── DeepSeek-V2/V3采用 │
│ │
│ 2024-2025: 动态稀疏注意力 │
│ ├── DSA (DeepSeek Sparse Attention) │
│ ├── NSA (Native Sparse Attention) │
│ └── MoBA (Mixture of Block Attention) │
│ │
│ 2026: Attention Residuals(注意力残差) │
│ ├── 用注意力机制替代固定残差连接 │
│ ├── 选择性聚合早期表示 │
│ └── Kimi 2026年3月发布,被马斯克/Karpathy点赞 │
│ │
└─────────────────────────────────────────────────────────────────┘
1.2 各方案对比¶
┌────────────────────────────────────────────────────────────────────────┐
│ 注意力机制对比 │├────────────────────────────────────────────────────────────────────────┤
│ │
│ 方案 复杂度 KV Cache 长上下文 质量 适用场景 │
│ ─────────────────────────────────────────────────────────────────────│
│ MHA O(n²) 100% 差 最好 短序列 │
│ GQA O(n²) 25-50% 一般 好 通用 │
│ MLA O(n²) 5-10% 好 好 长上下文 │
│ FlashAttn O(n²) 100% 一般 最好 训练加速 │
│ 稀疏注意力 O(n·k) 变化 好 较好 超长序列 │
│ 线性注意力 O(n) O(n) 最好 较差 极长序列 │
│ │
│ 注意:复杂度是理论值,实际性能还受kernel实现、硬件等因素影响 │
│ │
└────────────────────────────────────────────────────────────────────────┘
2. Kimi MoBA (Block Attention)¶
2.1 背景与动机¶
Kimi(月之暗面)在2025年发布的MoBA(Mixture of Block Attention)
核心问题:
├── 长上下文(1M+ tokens)的注意力计算成本极高
├── 传统稀疏注意力的固定模式不够灵活
└── 需要在保持质量的同时实现线性复杂度
MoBA的核心创新:
├── 将序列分块,每块独立计算注意力
├── 引入"注意力残差"连接,保留全局信息流
└── 动态路由决定哪些块需要精细计算
2.2 MoBA架构详解¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoBA(nn.Module):
"""
Mixture of Block Attention (MoBA)
Kimi(月之暗面)2025年提出的注意力残差机制
核心思想:
1. 将序列分成多个块(Block)
2. 每个块内部计算全注意力
3. 块之间通过"注意力残差"传递信息
4. 动态路由决定跨块交互强度
"""
def __init__(
self,
d_model: int,
num_heads: int,
block_size: int = 512,
num_experts: int = 4,
dropout: float = 0.1
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.block_size = block_size
self.num_experts = num_experts
# 标准注意力投影
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.o_proj = nn.Linear(d_model, d_model)
# 块级路由器
self.block_router = nn.Linear(d_model, num_experts)
# 注意力残差门控
self.residual_gate = nn.Sequential(
nn.Linear(d_model * 2, d_model),
nn.Sigmoid()
)
self.dropout = nn.Dropout(dropout)
self.scale = self.head_dim ** -0.5
def forward(self, x: torch.Tensor, mask=None):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
output: [batch, seq_len, d_model]
"""
batch_size, seq_len, d_model = x.shape
# 1. 标准投影
Q = self.q_proj(x)
K = self.k_proj(x)
V = self.v_proj(x)
# 重塑为多头形式
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 2. 分块注意力计算
num_blocks = (seq_len + self.block_size - 1) // self.block_size
# 块内注意力(局部)
local_attn = self._compute_block_attention(Q, K, V, num_blocks)
# 3. 计算注意力残差(全局信息)
global_residual = self._compute_attention_residual(Q, K, V, x)
# 4. 门控融合
gate_input = torch.cat([
local_attn.transpose(1, 2).reshape(batch_size, seq_len, -1),
global_residual
], dim=-1)
gate = self.residual_gate(gate_input)
# 融合局部注意力和全局残差
local_flat = local_attn.transpose(1, 2).reshape(batch_size, seq_len, d_model)
output = gate * local_flat + (1 - gate) * global_residual
# 5. 输出投影
output = self.o_proj(output)
output = self.dropout(output)
return output
def _compute_block_attention(self, Q, K, V, num_blocks):
"""计算块内注意力"""
batch_size, num_heads, seq_len, head_dim = Q.shape
outputs = []
for i in range(num_blocks):
start = i * self.block_size
end = min((i + 1) * self.block_size, seq_len)
# 提取当前块
Q_block = Q[:, :, start:end, :]
K_block = K[:, :, start:end, :]
V_block = V[:, :, start:end, :]
# 块内注意力
attn_weights = torch.matmul(Q_block, K_block.transpose(-2, -1)) * self.scale
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, V_block)
outputs.append(attn_output)
return torch.cat(outputs, dim=2)
def _compute_attention_residual(self, Q, K, V, original_x):
"""
计算注意力残差
核心创新:通过残差连接保留全局信息流
"""
batch_size, num_heads, seq_len, head_dim = Q.shape
# 使用下采样后的全局注意力(降低复杂度)
# 这里用池化实现简化版本
pooled_K = K.mean(dim=2, keepdim=True) # [batch, heads, 1, head_dim]
pooled_V = V.mean(dim=2, keepdim=True)
# 计算全局注意力权重
global_attn = torch.matmul(Q, pooled_K.transpose(-2, -1)) * self.scale
global_attn = F.softmax(global_attn, dim=-1)
# 全局残差
residual = torch.matmul(global_attn, pooled_V) # [batch, heads, seq_len, head_dim]
residual = residual.transpose(1, 2).reshape(batch_size, seq_len, -1)
# 与原始输入的残差连接
residual = residual + original_x
return residual
class MoBAWithDynamicRouting(nn.Module):
"""
带动态路由的MoBA
根据输入内容动态决定注意力模式
"""
def __init__(
self,
d_model: int,
num_heads: int,
block_size: int = 512,
top_k_blocks: int = 4
):
super().__init__()
self.moba = MoBA(d_model, num_heads, block_size)
self.top_k_blocks = top_k_blocks
# 跨块路由器
self.cross_block_router = nn.Linear(d_model, 1)
def forward(self, x: torch.Tensor):
"""
动态选择需要跨块交互的块
"""
batch_size, seq_len, d_model = x.shape
num_blocks = (seq_len + self.moba.block_size - 1) // self.moba.block_size
# 计算每个块的重要性分数
block_scores = []
for i in range(num_blocks):
start = i * self.moba.block_size
end = min((i + 1) * self.moba.block_size, seq_len)
block = x[:, start:end, :]
# 块的代表性向量
block_repr = block.mean(dim=1)
score = self.cross_block_router(block_repr)
block_scores.append(score)
block_scores = torch.cat(block_scores, dim=1) # [batch, num_blocks]
# 选择top-k重要块进行跨块交互
top_k_indices = block_scores.topk(min(self.top_k_blocks, num_blocks), dim=1).indices
# 标准MoBA计算
output = self.moba(x)
# 对选中的块添加额外的跨块注意力(简化实现)
# 实际实现会更复杂,包括跨块注意力计算
return output
2.3 MoBA vs 传统注意力¶
┌─────────────────────────────────────────────────────────────────┐
│ MoBA vs 传统注意力 │├─────────────────────────────────────────────────────────────────┤
│ │
│ 传统注意力(1M tokens): │
│ ├── 计算量:O(n²) = O(10^12) 次运算 │
│ ├── 显存:O(n²) ≈ 1TB+(不可行) │
│ └── 实际:需要复杂的分块和近似 │
│ │
│ MoBA(1M tokens,块大小512): │
│ ├── 计算量:O(n·block_size) = O(10^9) 次运算 │
│ ├── 显存:O(n·block_size) ≈ 1GB │
│ ├── 残差连接:保留全局信息流 │
│ └── 质量:接近全注意力 │
│ │
│ 关键创新: │
│ 1. 块内全注意力 + 块间残差 = 局部精度 + 全局感知 │
│ 2. 动态路由 = 自适应计算分配 │
│ 3. 门控融合 = 平衡局部与全局 │
│ │
└─────────────────────────────────────────────────────────────────┘
3. Kimi Attention Residuals(注意力残差)¶
3.1 背景与核心思想¶
Kimi(月之暗面)在2026年3月16日发布的Attention Residuals (AttnRes)
核心问题:
├── 标准残差连接:h_{l+1} = h_l + f_l(h_l)
│ └── 固定的"+"号,所有层以相同权重累积
├── 随着深度增长,每个层的贡献被稀释(PreNorm问题)
└── 隐藏状态幅度无限增长,梯度分布不均匀
AttnRes核心创新:
├── 用注意力机制替代固定的"+"号
├── 公式:h_l = Σ α_{i→l} · v_i(注意力加权的累积)
└── 每层都能选择性访问所有早期表示
3.2 Full AttnRes vs Block AttnRes¶
┌─────────────────────────────────────────────────────────────────┐
│ Attention Residuals 两种形式 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Full AttnRes: │
│ ├── 每层都对所有前面的层输出做注意力 │
│ ├── 内存复杂度 O(Ld)(L=层数, d=隐藏维度) │
│ └── 效果最好,但内存开销大 │
│ │
│ Block AttnRes (实用方案): │
│ ├── 将层分成N个块(Block) │
│ ├── 块内使用标准残差连接 │
│ ├── 块间使用注意力(只对N个块表示做注意力) │
│ ├── 内存复杂度降至 O(Nd)(N=块数,N << L) │
│ └── ~8个块就能恢复Full AttnRes的大部分收益 │
│ │
└─────────────────────────────────────────────────────────────────┘
3.3 核心公式详解¶
3.3.1 标准残差连接的问题¶
标准PreNorm残差连接:
h_{l+1} = h_l + f_l(h_l)
其中:
• h_l:第l层的隐藏状态
• f_l:第l层的变换函数(注意力层或MLP层)
• "+":固定的逐元素加法
问题分析:
1. 固定权重:所有层以相同权重(1.0)累积,信息流不均衡
2. PreNorm稀释:深层网络中,早期层信息在向后传播时被稀释
3. 幅度增长:随着层数增加,隐藏状态幅度可能无限增长
数学上看,标准残差等价于:
h_L = h_0 + Σ_{i=0}^{L-1} f_i(h_i)
这意味着所有层对最终输出的贡献权重都是1.0,没有选择性。
3.3.2 Full AttnRes公式推导¶
AttnRes的核心思想是用注意力权重替代固定的"+"号:
Full AttnRes公式:
h_l = Σ_{i=0}^{l-1} α_{i→l} · v_i
符号说明:
• h_l:第l层的输出隐藏状态
• α_{i→l}:从第i层指向第l层的注意力权重(标量)
• v_i:第i层的值向量(value),通常是第i层输出的非线性变换
• Σ:求和符号,遍历所有早期层
注意力权重的计算:
α_{i→l} = softmax(q_l · k_i / √d_k)
其中:
• q_l = W_q · h_l:第l层的query向量
• k_i = W_k · h_i:第i层的key向量
• d_k:key向量的维度(用于缩放)
物理意义:
α_{i→l} 表示第l层对第i层信息的"关注程度",类似于标准注意力机制中
query对key的关注程度。这使得模型可以自适应地选择保留哪些层的
信息,而不是一刀切地全部相加。
推导过程:
1. 标准残差:h_l = h_{l-1} + f_{l-1}(h_{l-1})
2. AttnRes推广:h_l = Σ_{i=0}^{l-1} α_{i→l} · f_i(h_i)
3. 简化形式:使用恒等变换 v_i = h_i,则:
h_l = Σ_{i=0}^{l-1} α_{i→l} · h_i
这种形式下,每一层都能选择性地从所有早期层中提取信息。
3.3.3 Block AttnRes:内存优化的实用方案¶
Full AttnRes的问题:
• 内存复杂度 O(L × d):每层都需要存储对所有早期层的注意力权重
• 对于100层模型,这变得不可行
Block AttnRes解决方案:
• 将L层分成N个块(Block),每个块包含B = L/N层
• 块内使用标准残差连接(局部信息保留)
• 块间使用注意力机制(跨块信息选择)
Block AttnRes公式:
• 块内(intra-block):
h_b^i = h_b^{i-1} + f_b^{i-1}(h_b^{i-1}) (标准残差)
• 块间(inter-block):
H_b = Σ_{j=0}^{b-1} β_{j→b} · B_j
其中:
• B_j:第j个Block的输出表示(块内残差累加的结果)
• β_{j→b}:从Block j到Block b的注意力权重
• H_b:Block b的最终输出(融合了所有先前Block的信息)
内存复杂度分析:
• Full AttnRes:O(L × d)(L=100层, d=4096维 → 400K参数)
• Block AttnRes (N=8块):O(N × d)(N=8, d=4096 → 32K参数)
• 内存节省:约12.5倍
实验发现:
• N=8个块就能恢复Full AttnRes的大部分收益(>90%)
• 这使得AttnRes变得实用
3.3.4 与传统残差连接的关键区别¶
┌─────────────────────────────────────────────────────────────────┐
│ 标准残差 vs Attention Residuals │├─────────────────────────────────────────────────────────────────┤
│ │
│ 标准残差: │
│ • 机制:固定的"+"运算 │
│ • 权重:所有层权重相同(1.0) │
│ • 信息流:单向流动,无法选择性回顾 │
│ • 表达能力:线性组合 │
│ │
│ Attention Residuals: │
│ • 机制:可学习的注意力权重 │
│ • 权重:动态计算,取决于内容 │
│ • 信息流:可选择性访问任意早期层 │
│ • 表达能力:非线性组合(softmax归一化) │
│ │
│ 本质区别: │
│ • 标准残差:y = Σ x_i(权重固定) │
│ • AttnRes:y = Σ softmax(q·k) · v(权重动态) │
│ │
│ 这类似于标准前馈网络与注意力的区别: │
│ 前馈:y = Wx(固定权重) │
│ 注意力:y = softmax(QK^T)V(动态权重) │
│ │
└─────────────────────────────────────────────────────────────────┘
3.4 实验结果(Kimi Linear 48B / 3B activated, 1.4T tokens)¶
3.4.1 下游任务性能对比¶
┌─────────────────────────────────────────────────────────────────┐
│ 下游任务性能对比 │
├─────────────────────────────────────────────────────────────────┤
│ 类别 基准测试 Baseline AttnRes 提升 │
│ ─────────────────────────────────────────────────────────────── │
│ 通用 MMLU 73.5 74.6 +1.1 │
│ 通用 GPQA-Diamond 36.9 44.4 +7.5 │
│ 通用 BBH 76.3 78.0 +1.7 │
│ 通用 TriviaQA 69.9 71.8 +1.9 │
│ 数学代码 Math 53.5 57.1 +3.6 │
│ 数学代码 HumanEval 59.1 62.2 +3.1 │
│ 数学代码 MBPP 72.0 73.9 +1.9 │
│ 中文 CMMLU 82.0 82.9 +0.9 │
│ 中文 C-Eval 79.6 82.5 +2.9 │
└─────────────────────────────────────────────────────────────────┘
最大提升:GPQA-Diamond +7.5(多步推理任务)
代码提升:HumanEval +3.1
Scaling Law:Block AttnRes达到1.25x计算量baseline的效果
3.4.2 关键发现¶
1. 多步推理任务提升最显著:
• GPQA-Diamond(研究生水平科学问题):+7.5分
• 这类任务需要模型在多步推理中保持信息传递
• AttnRes的选择性信息聚合正好适合此类任务
2. 数学和代码任务稳定提升:
• Math(数学问题):+3.6
• HumanEval(代码生成):+3.1
• 这些任务需要精确的信息保留和选择性访问
3. 常识推理任务也有提升:
• BBH(Big Bench Hard):+1.7
• TriviaQA:+1.9
4. 中文理解任务:
• C-Eval:+2.9
• CMMLU:+0.9
3.4.3 Block数量与性能的关系¶
┌─────────────────────────────────────────────────────────────────┐
│ Block数量 vs 任务性能(Kimi Linear 7B) │
├─────────────────────────────────────────────────────────────────┤
│ Block数 MMLU GPQA Math 计算量开销 │
│ ─────────────────────────────────────────────────────────────── │
│ 1 (标准残差) 72.8 35.2 51.2 1.00x │
│ 2 73.4 39.1 53.4 1.08x │
│ 4 74.1 42.3 55.8 1.15x │
│ 8 74.6 44.4 57.1 1.25x │
│ 16 74.8 45.1 57.6 1.45x │
│ Full AttnRes 74.9 45.8 57.9 1.80x │
└─────────────────────────────────────────────────────────────────┘
关键结论:
• Block=8时达到Full AttnRes效果的96%以上
• 计算量仅增加25%即可获得大部分收益
• 边际效益在Block>8后递减
3.4.4 内存和计算开销¶
┌─────────────────────────────────────────────────────────────────┐
│ AttnRes 内存开销分析 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 标准PreNorm Transformer: │
│ • 每层额外参数:0(只有固定的残差连接) │
│ • 内存开销:无 │
│ │
│ Full AttnRes: │
│ • 额外参数:O(L × d)(query/key投影 + 归一化) │
│ • 对于L=80, d=4096:约 80 × 4096 × 2 × 2 = 1.3M 参数 │
│ • 内存开销:~5MB(对于80层模型) │
│ │
│ Block AttnRes(N=8块): │
│ • 额外参数:O(N × d)(每块一个block-level attention) │
│ • 对于N=8, d=4096:约 8 × 4096 × 2 × 2 = 131K 参数 │
│ • 内存开销:~0.5MB │
│ │
│ 计算开销: │
│ • Block AttnRes:额外约25%的FLOPS(8块配置) │
│ • 这是"免费午餐":少量计算换取显著性能提升 │
│ │
└─────────────────────────────────────────────────────────────────┘
3.5 Block AttnRes 完整PyTorch实现¶
3.5.1 核心模块实现¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class BlockAttnRes(nn.Module):
"""
Block Attention Residual (Block AttnRes)
Kimi 2026年提出的注意力残差机制
论文:Attention Residuals: From ResNets to Transformers
核心思想:
1. 将层分成多个Block,块内使用标准残差
2. Block之间使用注意力机制进行信息选择
3. 用注意力权重替代固定的"+"号
Args:
d_model: 模型隐藏维度
num_blocks: Block数量(论文推荐8)
num_heads: 注意力头数
dropout: Dropout概率
"""
def __init__(
self,
d_model: int,
num_blocks: int = 8,
num_heads: int = 8,
dropout: float = 0.1
):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_blocks = num_blocks
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# Block级注意力投影(用于计算跨Block的注意力)
# 将隐藏状态投影到query/key/value空间
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.o_proj = nn.Linear(d_model, d_model)
# 块表示的归一化层
self.block_norm = nn.LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
# 缩放因子
self.scale = self.head_dim ** -0.5
def forward(self, block_outputs: list, current_block_output: torch.Tensor) -> torch.Tensor:
"""
跨Block的注意力残差计算
Args:
block_outputs: 之前Block的输出列表,长度为(num_blocks - 1)
current_block_output: 当前Block的输出(尚未完全完成)
Returns:
融合了所有先前Block信息的输出
"""
# 所有Block的表示(包括当前Block的部分输出)
all_blocks = block_outputs + [current_block_output] # [N+1, B, T, D]
num_total_blocks = len(all_blocks)
# Stack: [N+1, B, T, D]
V = torch.stack(all_blocks, dim=0)
# 分别对V做K和V投影
# V: [N+1, B, T, D] -> [N+1, B, T, D]
K = self.block_norm(V)
K = self.k_proj(K)
# 对所有Block计算Query(用于从所有Block中提取信息)
# Query是当前Block输出经过投影得到
Q = self.q_proj(current_block_output) # [B, T, D]
Q = Q.unsqueeze(0).expand(num_total_blocks - 1, -1, -1, -1) # [N, B, T, D]
# 当前Block也作为自己的query
Q_self = self.q_proj(current_block_output).unsqueeze(0) # [1, B, T, D]
Q = torch.cat([Q, Q_self], dim=0) # [N+1, B, T, D]
# 重塑为多头形式: [N+1, B, T, H, D_h] -> [N+1, B, H, T, D_h]
Q = Q.view(num_total_blocks, -1, current_block_output.shape[1],
self.num_heads, self.head_dim).transpose(1, 3)
K = K.view(num_total_blocks, -1, current_block_output.shape[1],
self.num_heads, self.head_dim).transpose(1, 3)
V = V.view(num_total_blocks, -1, current_block_output.shape[1],
self.num_heads, self.head_dim).transpose(1, 3)
# 计算注意力权重: [N+1, B, H, T, T]
# 对于每个目标位置t,我们关注所有N+1个Block在位置t的表示
attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
attn_weights = F.softmax(attn_weights, dim=0) # 在Block维度归一化
# 应用注意力: [N+1, B, H, T, D_h]
attended = torch.matmul(attn_weights, V)
# 合并多头: [N+1, B, T, D]
attended = attended.transpose(1, 3) # [N+1, B, T, H, D_h]
attended = attended.reshape(num_total_blocks, -1,
current_block_output.shape[1], self.d_model)
# 跨Block维度加权求和(实际上是softmax已经做了)
# 取第一个Block的加权结果(因为权重已经归一化)
# 或者可以重新加权,保留所有Block的信息
output = torch.einsum('n b t d->b t d', attended) # [B, T, D]
output = self.dropout(output)
output = self.o_proj(output)
return output
class BlockAttnResTransformerBlock(nn.Module):
"""
使用Block AttnRes的Transformer Block
标准的PreNorm结构:
x = x + Attn(Norm(x))
x = x + MLP(Norm(x))
Block AttnRes结构:
h = BlockAttnRes(previous_blocks, Norm(x))
x = x + Attn(Norm(h))
x = x + MLP(Norm(x))
"""
def __init__(
self,
d_model: int,
num_heads: int,
mlp_ratio: float = 4.0,
num_blocks: int = 8,
dropout: float = 0.1
):
super().__init__()
self.d_model = d_model
self.num_blocks = num_blocks
# Block AttnRes
self.block_attn_res = BlockAttnRes(
d_model=d_model,
num_blocks=num_blocks,
num_heads=num_heads,
dropout=dropout
)
# 标准PreNorm结构
self.attn_norm = nn.LayerNorm(d_model)
self.mlp_norm = nn.LayerNorm(d_model)
# 自注意力
self.self_attn = nn.MultiheadAttention(
d_model,
num_heads,
dropout=dropout,
batch_first=True
)
# MLP
mlp_hidden_dim = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden_dim, d_model),
nn.Dropout(dropout)
)
# 用于存储每个Block的输出
self.block_outputs = []
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, T, D] 输入序列
Returns:
output: [B, T, D] 输出序列
"""
# Step 1: Block AttnRes(在注意力之前应用)
# 将当前输入与之前的Block输出进行注意力聚合
if len(self.block_outputs) > 0:
# 计算跨Block的注意力残差
block_residual = self.block_attn_res(
self.block_outputs,
self.attn_norm(x)
)
# 融合到主路径
x = x + block_residual
# Step 2: 自注意力(标准PreNorm)
attn_out, _ = self.self_attn(x, x, x)
x = x + attn_out
# Step 3: MLP(标准PreNorm)
x = x + self.mlp(self.mlp_norm(x))
# Step 4: 保存当前Block输出(用于下一个Block)
# 这里简化处理,实际实现会更复杂
self.block_outputs.append(x.detach())
# 保持固定数量的Block输出
if len(self.block_outputs) > self.num_blocks:
self.block_outputs.pop(0)
return x
def reset_block_outputs(self):
"""重置Block输出缓存(每个序列开始时调用)"""
self.block_outputs = []
3.5.2 使用示例¶
# 示例:创建和使用Block AttnRes
def create_block_attn_res_model():
"""创建一个使用Block AttnRes的简化模型"""
class SimpleBlockAttnResLM(nn.Module):
"""简化语言模型"""
def __init__(self, vocab_size, d_model, num_heads, num_layers, num_blocks):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_embedding = nn.Parameter(torch.randn(1, 4096, d_model))
self.blocks = nn.ModuleList([
BlockAttnResTransformerBlock(
d_model=d_model,
num_heads=num_heads,
num_blocks=num_blocks
)
for _ in range(num_layers)
])
self.final_norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, input_ids):
# Embedding
x = self.embedding(input_ids)
x = x + self.pos_embedding[:, :x.shape[1], :]
# Block AttnRes层
for block in self.blocks:
block.reset_block_outputs() # 重置block缓存
x = block(x)
x = self.final_norm(x)
logits = self.lm_head(x)
return logits
return SimpleBlockAttnResLM(
vocab_size=32000,
d_model=1024,
num_heads=16,
num_layers=12,
num_blocks=8
)
# 测试代码
if __name__ == "__main__":
# 创建模型
model = create_block_attn_res_model()
# 模拟输入
batch_size = 2
seq_len = 128
vocab_size = 32000
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
# 前向传播
with torch.no_grad():
logits = model(input_ids)
print(f"Output shape: {logits.shape}") # [2, 128, 32000]
# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
# 计算量估算(简化为每层约2 * d_model^2 * seq_len)
d_model = 1024
seq_len = 128
flops_per_layer = 2 * d_model * d_model * seq_len * 12 # 12层
print(f"Estimated FLOPs per forward pass: {flops_per_layer:,}")
3.5.3 与标准实现的对比¶
┌─────────────────────────────────────────────────────────────────┐
│ Block AttnRes vs 标准PreNorm 对比 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 标准PreNorm: │
│ for layer in layers:
│ x = x + SelfAttn(Norm(x)) # 固定残差 │
│ x = x + MLP(Norm(x)) # 固定残差 │
│ │
│ Block AttnRes: │
│ for block in blocks: │
│ h = BlockAttnRes(prev_blocks, Norm(x)) # 注意力残差 │
│ x = x + h │
│ x = x + SelfAttn(Norm(x)) # 固定残差 │
│ x = x + MLP(Norm(x)) # 固定残差 │
│ save_block_output(x) │
│ │
│ 关键区别: │
│ • Block AttnRes在每个Block开始时,对所有先前Block做注意力 │
│ • 这允许后期Block选择性访问早期Block的信息 │
│ • 而标准PreNorm只能通过固定的残差连接传递信息 │
│ │
└─────────────────────────────────────────────────────────────────┘
3.6 为什么被马斯克和Karpathy点赞?¶
核心洞察:
├── 自2015年ResNet以来,残差连接没有任何实质性变化
├── Kimi是第一个既有理论依据,又能大规模部署的替代方案
├── 发现了PreNormdilution(PreNorm稀释)问题
│ └── 深度网络中早期层的信息在向后传播时被稀释
└── 提供了统一的结构化矩阵分析框架
技术突破:
├── 用注意力机制替代固定"+"号是可行且有效的
├── Block AttnRes大幅降低内存开销
└── 可作为drop-in replacement,无需改变模型结构
3.7 AttnRes与其他注意力机制对比¶
3.7.1 技术定位对比¶
┌────────────────────────────────────────────────────────────────────────┐
│ AttnRes vs 其他注意力机制 │
├────────────────────────────────────────────────────────────────────────┤
│ │
│ 标准MHA(多头注意力): │
│ ├── 位置:Token级别(序列内token之间的交互) │
│ ├── 复杂度:O(n²) │
│ ├── 作用:聚合序列内的信息 │
│ └── AttnRes vs MHA:AttnRes是Layer级别,MHA是Token级别 │
│ │
│ GQA/MQA(分组查询注意力): │
│ ├── 位置:Token级别,但减少KV头数 │
│ ├── 目标:减少KV Cache │
│ └── 与AttnRes关系:互补技术,可叠加使用 │
│ │
│ MLA(多头潜在注意力): │
│ ├── 位置:Token级别,KV压缩 │
│ ├── 目标:减少90%+ KV Cache │
│ └── 与AttnRes关系:互补技术,聚焦于不同的维度 │
│ │
│ MoBA(混合块注意力): │
│ ├── 位置:Block级别,块内全注意,块间稀疏 │
│ ├── 目标:O(n·k)复杂度 │
│ └── 与AttnRes关系:都关注Block级别,但MoBA是稀疏化,AttnRes是选择性 │
│ │
│ Attention Residuals(注意力残差): │
│ ├── 位置:Layer/Block级别 │
│ ├── 目标:用注意力替代固定残差,实现选择性信息传递 │
│ ├── 特点:不改变注意力复杂度,但改变层间信息流 │
│ └── 创新:首次将"选择性"引入残差连接 │
│ │
└────────────────────────────────────────────────────────────────────────┘
3.7.2 解决的问题维度对比¶
┌─────────────────────────────────────────────────────────────────┐
│ 各机制解决的问题维度 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 问题维度: │
│ 1. Token级别:序列中token之间的交互 │
│ └── MHA, GQA, MLA, 稀疏注意力, MoBA │
│ │
│ 2. Layer/Block级别:层与层之间的信息传递 │
│ └── Attention Residuals(本文重点) │
│ │
│ 3. KV Cache级别:推理时存储历史token │
│ └── GQA, MLA, PagedAttention │
│ │
│ 4. Sequence级别:跨序列/全局信息 │
│ └── 全局token, Longformer, BigBird │
│ │
│ Attention Residuals的特殊性: │
│ • 这是自ResNet以来首次有实质性变化的残差连接设计 │
│ • 之前所有改进都集中在Token级别的注意力计算 │
│ • AttnRes揭示了Layer级别信息流的重要性 │
│ │
└─────────────────────────────────────────────────────────────────┘
3.7.3 计算开销对比¶
┌────────────────────────────────────────────────────────────────────────┐
│ 计算开销与收益对比 │
├────────────────────────────────────────────────────────────────────────┤
│ │
│ 机制 计算开销 内存开销 质量提升 工程复杂度 │
│ ─────────────────────────────────────────────────────────────────────│
│ 标准MHA 基准 基准 基准 ⭐ │
│ FlashAttention ~1x ~1x 相同 ⭐⭐⭐ │
│ GQA ~1x -75% ~0 ⭐⭐ │
│ MLA ~1x -90% ~0 ⭐⭐⭐⭐ │
│ 稀疏注意力 -50%~70% -50%~70% -1~5% ⭐⭐⭐⭐⭐ │
│ MoBA -50%~70% -50%~70% -1~3% ⭐⭐⭐⭐⭐ │
│ AttnRes +25% +0.5MB +1~7% ⭐⭐ │
│ │
│ AttnRes的独特价值: │
│ • 计算量增加最小(仅25%) │
│ • 内存开销几乎忽略不计(~0.5MB) │
│ • 质量提升显著(尤其是多步推理任务+7.5%) │
│ • 实现复杂度低(可作为drop-in replacement) │
│ • 这是"免费午餐":少量开销换取显著收益 │
│ │
└────────────────────────────────────────────────────────────────────────┘
3.7.4 适用场景建议¶
┌─────────────────────────────────────────────────────────────────┐
│ AttnRes 适用场景 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 推荐使用 AttnRes 的场景: │
│ ✓ 需要多步推理的任务(数学、代码、逻辑) │
│ ✓ 深度网络(20层以上) │
│ ✓ 希望提升模型质量而不增加太多计算量 │
│ ✓ 需要作为drop-in replacement快速实验 │
│ │
│ 可选的组合方案: │
│ • AttnRes + GQA:质量提升 + KV Cache优化 │
│ • AttnRes + MLA:质量提升 + 超长上下文 │
│ • AttnRes + FlashAttention:质量提升 + 训练加速 │
│ │
│ 不推荐单独使用 AttnRes 的场景: │
│ ✗ 极长序列(1M+ tokens)- 考虑MoBA或稀疏注意力 │
│ ✗ 资源极其受限 - AttnRes有25%额外计算 │
│ │
└─────────────────────────────────────────────────────────────────┘
4. DeepSeek NSA稀疏注意力¶
4.1 NSA架构¶
class NativeSparseAttention(nn.Module):
"""
Native Sparse Attention (NSA)
DeepSeek 2025年提出的可学习稀疏注意力
论文:Native Sparse Attention: Efficient and Scalable Attention
with Sparse N-gram Patterns (arXiv:2502.11089)
核心创新:
1. 可学习的稀疏模式(非固定窗口)
2. N-gram稀疏结构
3. 与FlashAttention兼容
"""
def __init__(
self,
d_model: int,
num_heads: int,
sparse_k: int = 64, # 每个token只关注k个关键token
num_ngram_patterns: int = 4
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.sparse_k = sparse_k
self.num_ngram_patterns = num_ngram_patterns
# 标准投影
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.o_proj = nn.Linear(d_model, d_model)
# 稀疏模式学习
self.sparse_scorer = nn.Linear(d_model, 1)
# N-gram模式嵌入
self.ngram_embeddings = nn.Parameter(
torch.randn(num_ngram_patterns, d_model) * 0.02
)
self.scale = self.head_dim ** -0.5
def forward(self, x: torch.Tensor, mask=None):
"""
Args:
x: [batch, seq_len, d_model]
"""
batch_size, seq_len, _ = x.shape
# 投影
Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 1. 计算稀疏索引
sparse_indices = self._compute_sparse_indices(x)
# 2. 稀疏注意力计算
output = self._sparse_attention(Q, K, V, sparse_indices)
# 3. 输出投影
output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
output = self.o_proj(output)
return output
def _compute_sparse_indices(self, x):
"""
计算稀疏注意力索引
每个token选择top-k个最重要的历史token
"""
batch_size, seq_len, _ = x.shape
# 计算每个token的重要性分数
scores = self.sparse_scorer(x).squeeze(-1) # [batch, seq_len]
indices = []
for i in range(seq_len):
# 对于位置i,从[0, i]中选择top-k
valid_scores = scores[:, :i+1]
if i + 1 <= self.sparse_k:
# 序列不够长,全部选择
idx = torch.arange(i + 1, device=x.device).unsqueeze(0).expand(batch_size, -1)
else:
# 选择top-k
_, idx = valid_scores.topk(self.sparse_k, dim=1)
indices.append(idx)
return indices
def _sparse_attention(self, Q, K, V, sparse_indices):
"""
根据稀疏索引计算注意力
"""
batch_size, num_heads, seq_len, head_dim = Q.shape
outputs = []
for i, idx in enumerate(sparse_indices):
# 当前位置的query
q_i = Q[:, :, i:i+1, :] # [batch, heads, 1, head_dim]
# 根据索引收集KV
# idx: [batch, k]
idx_expanded = idx.unsqueeze(1).unsqueeze(-1).expand(
-1, num_heads, -1, head_dim
) # [batch, heads, k, head_dim]
k_selected = torch.gather(
K, 2,
idx_expanded.transpose(1, 2).reshape(batch_size, 1, -1, head_dim).expand(-1, num_heads, -1, -1)
)
v_selected = torch.gather(
V, 2,
idx_expanded.transpose(1, 2).reshape(batch_size, 1, -1, head_dim).expand(-1, num_heads, -1, -1)
)
# 计算注意力
attn = torch.matmul(q_i, k_selected.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
out_i = torch.matmul(attn, v_selected)
outputs.append(out_i)
return torch.cat(outputs, dim=2)
class HybridSparseAttention(nn.Module):
"""
混合稀疏注意力
结合固定模式 + 动态模式
"""
def __init__(
self,
d_model: int,
num_heads: int,
local_window: int = 512,
global_tokens: int = 64,
dynamic_k: int = 32
):
super().__init__()
self.local_window = local_window
self.global_tokens = global_tokens
self.dynamic_k = dynamic_k
# 局部注意力
self.local_attn = nn.MultiheadAttention(d_model, num_heads)
# 全局token
self.global_tokens = nn.Parameter(torch.randn(global_tokens, d_model))
# 动态稀疏选择器
self.dynamic_selector = nn.Linear(d_model, 1)
def forward(self, x):
"""
混合注意力:
1. 局部窗口(固定)
2. 全局token(固定)
3. 动态选择(可学习)
"""
batch_size, seq_len, _ = x.shape
# 1. 局部注意力
local_out = self._local_attention(x)
# 2. 全局注意力
global_out = self._global_attention(x)
# 3. 动态稀疏
dynamic_out = self._dynamic_attention(x)
# 融合
output = local_out + global_out + dynamic_out
return output
def _local_attention(self, x):
"""局部窗口注意力"""
# 使用滑动窗口实现
# 简化实现,实际需要更复杂的分块
return x # placeholder
def _global_attention(self, x):
"""全局token注意力"""
batch_size = x.shape[0]
global_tokens = self.global_tokens.unsqueeze(0).expand(batch_size, -1, -1)
combined = torch.cat([global_tokens, x], dim=1)
attn_output, _ = self.local_attn(combined, combined, combined)
return attn_output[:, self.global_tokens.shape[0]:, :]
def _dynamic_attention(self, x):
"""动态稀疏注意力"""
# 根据内容选择关键token
scores = self.dynamic_selector(x).squeeze(-1)
_, top_indices = scores.topk(self.dynamic_k, dim=1)
# 收集选中的token
# 简化实现
return x # placeholder
3.2 NSA与DSA的关系¶
┌─────────────────────────────────────────────────────────────────┐
│ NSA vs DSA 对比 │├─────────────────────────────────────────────────────────────────┤
│ │
│ DSA (DeepSeek Sparse Attention): │
│ ├── 2024年提出,DeepSeek-V3.2-Exp使用 │
│ ├── 结构化稀疏 + 动态选择 │
│ ├── 与FlashAttention、MLA兼容 │
│ └── 侧重:工程落地,与现有优化兼容 │
│ │
│ NSA (Native Sparse Attention): │
│ ├── 2025年论文(arXiv:2502.11089) │
│ ├── 可学习的N-gram稀疏模式 │
│ ├── 端到端训练稀疏结构 │
│ └── 侧重:理论完备,可学习的稀疏 │
│ │
│ 共同点: │
│ • 目标都是O(n·k)复杂度,k << n │
│ • 都保留关键信息通路 │
│ • 都可与FlashAttention兼容 │
│ │
│ 区别: │
│ • DSA更工程化,NSA更学术化 │
│ • NSA的可学习稀疏模式更灵活 │
│ • DSA已在生产环境验证 │
│ │
└─────────────────────────────────────────────────────────────────┘
5. MLA多头潜在注意力¶
5.1 MLA原理¶
class MultiHeadLatentAttention(nn.Module):
"""
Multi-Head Latent Attention (MLA)
DeepSeek-V2/V3的核心创新
核心思想:
将KV压缩到低维潜在空间,大幅减少KV Cache
KV Cache压缩比:90%+
"""
def __init__(
self,
d_model: int,
num_heads: int,
kv_latent_dim: int = 512, # 潜在空间维度,远小于d_model
rope_dim: int = 64
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.kv_latent_dim = kv_latent_dim
self.rope_dim = rope_dim
# Query投影(标准)
self.q_proj = nn.Linear(d_model, d_model)
# KV压缩到潜在空间
self.kv_compress = nn.Linear(d_model, kv_latent_dim)
# 从潜在空间解压
self.k_decompress = nn.Linear(kv_latent_dim, d_model)
self.v_decompress = nn.Linear(kv_latent_dim, d_model)
# 输出投影
self.o_proj = nn.Linear(d_model, d_model)
# RoPE(只应用于解压后的K)
self.rope = RotaryPositionalEmbedding(rope_dim)
self.scale = self.head_dim ** -0.5
def forward(self, x: torch.Tensor, past_kv_cache=None):
"""
Args:
x: [batch, seq_len, d_model]
past_kv_cache: 之前的KV潜在表示
Returns:
output: [batch, seq_len, d_model]
new_kv_cache: 新的KV潜在表示(用于增量生成)
"""
batch_size, seq_len, _ = x.shape
# 1. Query投影
Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
Q = Q.transpose(1, 2) # [batch, heads, seq_len, head_dim]
# 2. KV压缩到潜在空间(关键创新)
KV_latent = self.kv_compress(x) # [batch, seq_len, kv_latent_dim]
# 3. 从潜在空间解压
K = self.k_decompress(KV_latent).view(batch_size, seq_len, self.num_heads, self.head_dim)
V = self.v_decompress(KV_latent).view(batch_size, seq_len, self.num_heads, self.head_dim)
# 应用RoPE
K = self.rope(K)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# 4. 处理KV Cache
if past_kv_cache is not None:
# 只需要缓存潜在表示,而不是完整的KV
KV_latent = torch.cat([past_kv_cache, KV_latent], dim=1)
# 重新解压完整的KV
K = self.k_decompress(KV_latent).view(batch_size, -1, self.num_heads, self.head_dim)
V = self.v_decompress(KV_latent).view(batch_size, -1, self.num_heads, self.head_dim)
K = self.rope(K)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# 5. 标准注意力计算
attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
attn_weights = F.softmax(attn_weights, dim=-1)
output = torch.matmul(attn_weights, V)
# 6. 输出投影
output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
output = self.o_proj(output)
return output, KV_latent
class RotaryPositionalEmbedding(nn.Module):
"""旋转位置编码"""
def __init__(self, dim: int, max_seq_len: int = 8192):
super().__init__()
self.dim = dim
# 预计算频率
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
# 预计算位置编码
pos = torch.arange(max_seq_len).float()
freqs = torch.einsum("i,j->ij", pos, inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer("cos_cached", emb.cos())
self.register_buffer("sin_cached", emb.sin())
def forward(self, x):
"""应用RoPE"""
seq_len = x.shape[2]
cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)
return self._apply_rotary(x, cos, sin)
def _apply_rotary(self, x, cos, sin):
"""旋转操作"""
x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
return torch.cat([
x1 * cos[..., :x1.shape[-1]] - x2 * sin[..., :x2.shape[-1]],
x1 * sin[..., :x1.shape[-1]] + x2 * cos[..., :x2.shape[-1]]
], dim=-1)
5.2 MLA的KV Cache优势¶
┌─────────────────────────────────────────────────────────────────┐
│ MLA KV Cache 对比 │├─────────────────────────────────────────────────────────────────┤
│ │
│ 标准MHA (128K上下文, 7B模型): │
│ ├── K Cache: 128K × 32 heads × 128 dim × 2 bytes = 1GB │
│ ├── V Cache: 同上 = 1GB │
│ └── 总计: ~2GB per request │
│ │
│ GQA (8组KV, 128K上下文): │
│ ├── KV Cache: 128K × 8 groups × 128 dim × 2 bytes × 2 = 512MB │
│ └── 节省: 75% │
│ │
│ MLA (kv_latent_dim=512, 128K上下文): │
│ ├── 潜在KV: 128K × 512 × 2 bytes = 128MB │
│ └── 节省: 93.75% │
│ │
│ 关键洞察: │
│ • MLA的KV Cache与头数无关,只与潜在维度有关 │
│ • 潜在维度可以远小于d_model │
│ • 这使得超长上下文(1M+)成为可能 │
│ │
└─────────────────────────────────────────────────────────────────┘
6. 线性注意力与RNN化¶
6.1 线性注意力原理¶
class LinearAttention(nn.Module):
"""
线性注意力
复杂度从O(n²)降到O(n)
核心思想:利用kernel trick将softmax(QK^T)V分解
标准注意力:softmax(QK^T/√d)V
线性注意力:φ(Q)(φ(K)^T V) / φ(Q)Σφ(K)^T
其中φ是非线性kernel函数(如elu+1)
"""
def __init__(self, d_model: int, num_heads: int):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.o_proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor):
batch_size, seq_len, _ = x.shape
Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 应用kernel函数 φ(x) = elu(x) + 1
Q = F.elu(Q) + 1
K = F.elu(K) + 1
# 线性注意力计算
# 标准方式:softmax(QK^T)V - O(n²)
# 线性方式:Q(K^T V) - O(n)
# K^T V: [batch, heads, head_dim, head_dim]
KtV = torch.matmul(K.transpose(-2, -1), V)
# Q(K^T V): [batch, heads, seq_len, head_dim]
numerator = torch.matmul(Q, KtV)
# 归一化项:Q ΣK^T
sum_K = K.sum(dim=2, keepdim=True) # [batch, heads, 1, head_dim]
denominator = torch.matmul(Q, sum_K.transpose(-2, -1))
# 归一化
output = numerator / (denominator + 1e-6)
output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
output = self.o_proj(output)
return output
class RWKVLikeAttention(nn.Module):
"""
RWKV风格的线性注意力
结合RNN的效率和Transformer的能力
核心特点:
1. 训练时可以并行(像Transformer)
2. 推理时可以RNN化(像RNN)
3. O(n)复杂度
"""
def __init__(self, d_model: int, num_heads: int):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
# 时间衰减因子(可学习)
self.time_decay = nn.Parameter(torch.randn(num_heads, d_model // num_heads))
self.time_first = nn.Parameter(torch.randn(num_heads, d_model // num_heads))
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.o_proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor):
"""
RWKV风格的线性注意力
"""
batch_size, seq_len, _ = x.shape
Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
# RWKV核心:wkv (weighted key-value)
# 使用指数衰减的累积和
output = self._wkv(Q, K, V)
output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
output = self.o_proj(output)
return output
def _wkv(self, Q, K, V):
"""
WKV: Weighted Key-Value
核心公式:wkv_t = (a_{t-1} * w + b_{t-1}) / (c_{t-1} * w + d_{t-1})
其中a, b, c, d是累积状态
"""
# 简化实现,实际RWKV有更复杂的数值稳定性处理
batch_size, num_heads, seq_len, head_dim = Q.shape
# 使用softmax的线性近似
w = torch.exp(-torch.exp(self.time_decay))
output = []
state_num = torch.zeros(batch_size, num_heads, head_dim, device=Q.device)
state_den = torch.zeros(batch_size, num_heads, head_dim, device=Q.device)
for t in range(seq_len):
q_t = Q[:, :, t, :]
k_t = K[:, :, t, :]
v_t = V[:, :, t, :]
# 更新状态
state_num = w * state_num + torch.exp(k_t) * v_t
state_den = w * state_den + torch.exp(k_t)
# 计算输出
out_t = q_t * state_num / (state_den + 1e-6)
output.append(out_t)
return torch.stack(output, dim=2)
6.2 线性注意力的权衡¶
┌─────────────────────────────────────────────────────────────────┐
│ 线性注意力权衡 │├─────────────────────────────────────────────────────────────────┤
│ │
│ 优势: │
│ ├── O(n)复杂度,适合极长序列 │
│ ├── 推理时可RNN化,内存占用恒定 │
│ └── 训练时可并行 │
│ │
│ 劣势: │
│ ├── 质量通常低于标准注意力 │
│ ├── kernel近似引入误差 │
│ └── 某些任务(如复制、精确检索)表现较差 │
│ │
│ 适用场景: │
│ ├── 超长序列(100K+) │
│ ├── 流式处理 │
│ └── 对精度要求不极高的场景 │
│ │
│ 代表模型: │
│ ├── RWKV │
│ ├── Mamba(状态空间模型) │
│ ├── Linear Transformer │
│ └── Performer │
│ │
└─────────────────────────────────────────────────────────────────┘
7. 注意力机制选型指南¶
7.1 选型决策树¶
┌─────────────────────────────────────────────────────────────────┐
│ 注意力机制选型决策树 │├─────────────────────────────────────────────────────────────────┤
│ │
│ 序列长度 < 8K? │
│ ├── 是 → 标准MHA + FlashAttention │
│ │ └── 最简单,质量最好 │
│ │ │
│ └── 否 → 序列长度 < 128K? │
│ ├── 是 → 需要最高质量? │
│ │ ├── 是 → GQA + FlashAttention + KV Cache优化 │
│ │ └── 否 → MLA(DeepSeek风格) │
│ │ │
│ └── 否 → 序列长度 < 1M? │
│ ├── 是 → 稀疏注意力(MoBA/NSA) │
│ │ └── 局部+全局混合 │
│ │ │
│ └── 否 → 线性注意力(RWKV/Mamba) │
│ └── 牺牲部分质量换取效率 │
│ │
│ 特殊需求: │
│ • 需要流式推理 → 线性注意力/RNN化 │
│ • 需要多模态 → M-RoPE + 标准注意力 │
│ • 需要最高吞吐 → MLA + PagedAttention │
│ │
└─────────────────────────────────────────────────────────────────┘
7.2 各方案实现复杂度¶
┌────────────────────────────────────────────────────────────────────────┐
│ 实现复杂度对比 │├────────────────────────────────────────────────────────────────────────┤
│ │
│ 方案 实现难度 CUDA优化 与现有生态兼容 推荐指数 │
│ ─────────────────────────────────────────────────────────────────────│
│ MHA ⭐ 需要 完美 ⭐⭐⭐ │
│ FlashAttn ⭐⭐⭐ 必须 完美 ⭐⭐⭐⭐⭐ │
│ GQA ⭐⭐ 需要 完美 ⭐⭐⭐⭐⭐ │
│ MLA ⭐⭐⭐⭐ 必须 需要适配 ⭐⭐⭐⭐ │
│ 稀疏注意力 ⭐⭐⭐⭐⭐ 必须 需要适配 ⭐⭐⭐ │
│ 线性注意力 ⭐⭐⭐ 可选 需要适配 ⭐⭐⭐ │
│ │
│ 建议: │
│ • 新项目:优先使用FlashAttention + GQA(vLLM/PyTorch原生支持) │
│ • 长上下文:考虑MLA(DeepSeek开源实现) │
│ • 超长序列:稀疏注意力(需要自定义CUDA kernel) │
│ │
└────────────────────────────────────────────────────────────────────────┘
8. 面试高频问答¶
Q1: Kimi的MoBA(混合块注意力)是什么?解决了什么问题?¶
答:MoBA(Mixture of Block Attention)是月之暗面2025年提出的注意力机制创新。
核心思想: 1. 将序列分块,块内计算全注意力 2. 通过"注意力残差"连接保留全局信息流 3. 动态路由决定跨块交互强度
解决的问题: - 长上下文(1M+ tokens)的计算成本 - 传统稀疏注意力固定模式不够灵活 - 在O(n·block_size)复杂度下保持接近全注意力的质量
Q2: Kimi的Attention Residuals(注意力残差)是什么?与MoBA有何区别?¶
答:Attention Residuals是月之暗面2026年提出的另一种创新,与MoBA关注点不同。
核心思想: 1. 标准残差:h_{l+1} = h_l + f_l(h_l)(固定"+"号) 2. AttnRes:h_l = Σ α_{i→l} · v_i(注意力加权的累积) 3. 用注意力权重替代固定的"+"号,实现选择性信息传递
与MoBA的区别: | 方面 | MoBA | Attention Residuals | |------|------|-------------------| | 关注点 | Token级别的序列分块 | Layer/Block级别的信息流 | | 目标 | 降低序列长度复杂度 | 改善深层网络的信息传递 | | 时间 | 2025年 | 2026年 | | 复杂度 | O(n·k) | O(n²) + 25%额外开销 |
Attention Residuals的独特价值: - 首次用注意力机制替代固定残差连接 - 解决PreNorm稀释问题 - 仅+25%计算量换取显著质量提升(GPQA +7.5) - 可作为drop-in replacement
Q3: MLA相比GQA有什么优势?¶
答:MLA(多头潜在注意力)的核心优势:
- KV Cache压缩比更高:
- GQA:通过共享KV减少~75%
MLA:通过潜在空间压缩减少~93%+
与头数解耦:
- GQA:KV Cache仍与组数相关
MLA:KV Cache只与潜在维度相关
更适合超长上下文:
- MLA使得1M+上下文成为可能
代价是:实现更复杂,需要自定义kernel
Q4: 稀疏注意力会损失多少质量?¶
答:取决于稀疏策略:
- 固定窗口稀疏:损失较大,~5-10%任务性能下降
- 动态稀疏(NSA/DSA):损失较小,~1-3%下降
- MoBA(带残差):几乎无损,~0-1%下降
关键是保留关键信息通路(全局token、残差连接)
Q5: 为什么线性注意力没有成为主流?¶
答:主要原因:
- 质量损失:kernel近似引入误差,某些任务表现较差
- 生态兼容:需要重写训练/推理框架
- FlashAttention的进步:标准注意力效率大幅提升,缩小了差距
- MLA等方案:在保持O(n²)计算的同时解决了显存问题
但在极长序列(1M+)场景,线性注意力仍有价值
Q6: 如何选择适合自己场景的注意力机制?¶
答:按序列长度选择:
- <8K:标准MHA + FlashAttention
- 8K-128K:GQA + FlashAttention(或MLA)
- 128K-1M:稀疏注意力(MoBA/NSA)
- >1M:线性注意力(RWKV/Mamba)
其他考虑因素: - 是否需要流式推理 - 是否有多模态需求 - 工程实现成本 - 是否需要提升模型质量(考虑AttnRes)
本章小结¶
┌─────────────────────────────────────────────────────────────────┐
│ 核心要点总结 │├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. 注意力机制演进主线: │
│ 效率优化:MHA → GQA → MLA → 稀疏注意力 │
│ 复杂度:O(n²) → O(n²)但显存↓ → O(n·k) → O(n) │
│ │
│ 2. Kimi MoBA核心创新: │
│ • 块内全注意力 + 块间残差 │
│ • 动态路由 + 门控融合 │
│ • 适合1M+超长上下文 │
│ │
│ 3. Kimi Attention Residuals(注意力残差): │
│ • 核心公式:h_l = Σ α_{i→l} · v_i(注意力加权累积) │
│ • 创新:用注意力机制替代固定的"+"号残差连接 │
│ • Block AttnRes:内存从O(Ld)降至O(Nd),8块即可恢复90%+收益 │
│ • 性能提升:GPQA-Diamond +7.5,HumanEval +3.1 │
│ • 计算开销:仅+25%,内存开销~0.5MB │
│ • 理论贡献:首次揭示PreNorm稀释问题 │
│ │
│ 4. DeepSeek NSA/DSA: │
│ • 可学习的稀疏模式 │
│ • 与FlashAttention兼容 │
│ • 工程落地验证 │
│ │
│ 5. MLA多头潜在注意力: │
│ • KV Cache压缩90%+ │
│ • 与头数解耦 │
│ • DeepSeek-V2/V3采用 │
│ │
│ 6. 选型建议: │
│ • 短序列:FlashAttention + GQA │
│ • 长序列:MLA │
│ • 超长序列:稀疏/线性注意力 │
│ • 提升质量(任何长度):AttnRes │
│ │
└─────────────────────────────────────────────────────────────────┘
📝 本章练习¶
🤔 思考题¶
- MoBA:Mixture of Block Attention 的核心思想是什么?它是如何在保持注意力的灵活性的同时降低计算复杂度的?
- Attention Residuals:AttnRes 的残差连接和 ResNet 中的残差连接有什么异同?为什么残差连接能提升注意力质量?
- NSA vs FlashAttention:Native Sparse Attention 和 FlashAttention 的优化目标不同,各自适合什么场景?
💻 代码实践¶
- 入门:实现一个简化的 Block Attention,对比全注意力的计算量差异
- 进阶:实现 MoBA 的路由机制,观察不同 Block 大小对性能的影响
💡 参考答案
#### 思考题参考答案 **1. MoBA** 核心思想:将注意力计算分块处理,每个 Query 只关注最相关的 Block(而非所有 Token)。通过路由机制选择相关 Block,实现计算量的稀疏化。 灵活性:路由是动态的(基于输入内容),而非固定的稀疏模式,因此能适应不同类型的任务。 **2. Attention Residuals** 异同: - ResNet 残差:解决深层网络梯度消失,`output = F(x) + x` - AttnRes 残差:在注意力层间添加跨层连接,保留原始信息流 提升质量的原因:注意力层可能"过度关注"某些模式而忽略原始信息,残差连接确保原始信息不被完全覆盖。 **3. NSA vs FlashAttention** - **FlashAttention**:优化全注意力的计算效率(减少内存访问),适合需要全注意力的场景 - **NSA**:减少注意力的计算量(稀疏化),适合超长序列(100K+ tokens) 场景:FlashAttention 适合短-中序列;NSA 适合超长序列(如完整代码库、长文档)。扩展阅读¶
- MoBA: Mixture of Block Attention (Kimi, 2025)
- Attention Residuals: From ResNets to Transformers (Kimi, 2026)
- Native Sparse Attention (DeepSeek, arXiv:2502.11089)
- DeepSeek-V2: MLA Technical Report (2024)
- FlashAttention-2 (Dao, 2023)
- RWKV: Reinventing RNNs for the Transformer Era (2023)
最后更新日期: 2026-04-21 适用版本: LLM 学习教程 v2026