Model Architecture Innovations

Mixture of Experts (MoE)

1.1 The Core Idea of MoE

Mixture of Experts (MoE) decouples model capacity from compute cost through conditional computation. The core idea: for each input, activate only a subset of the parameters (the experts) rather than the whole network.
Text Only
┌──────────────────────────────────────────────────────────────────┐
│                   MoE architecture (schematic)                   │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│   Input                                                          │
│     │                                                            │
│     ▼                                                            │
│  ┌─────────────────┐                                             │
│  │    Router /     │                                             │
│  │  Gate Network   │──▶ routing weights [0.1, 0.7, 0.05, 0.15]   │
│  │   (learnable)   │                                             │
│  └─────────────────┘                                             │
│     │                                                            │
│     ▼                                                            │
│  ┌─────┐  ┌─────┐  ┌─────┐  ┌─────┐                              │
│  │ E0  │  │ E1  │  │ E2  │  │ E3  │   Expert networks            │
│  │ 0.1 │  │ 0.7 │  │0.05 │  │0.15 │   (FFN layers as experts)    │
│  └──┬──┘  └──┬──┘  └──┬──┘  └──┬──┘                              │
│     └────────┴────┬───┴────────┘                                 │
│                   ▼                                              │
│           ┌──────────────┐                                       │
│           │ Weighted sum │──▶ 0.1·E0 + 0.7·E1 + 0.05·E2          │
│           └──────────────┘       + 0.15·E3                       │
│                   │                                              │
│                   ▼                                              │
│                Output                                            │
│                                                                  │
│  Key: only the Top-K experts (e.g. Top-2) are computed;          │
│  the other experts' outputs are set to 0                         │
└──────────────────────────────────────────────────────────────────┘
1.2 MoE Mathematics
Text Only
Forward pass of an MoE layer
═══════════════════════════════════════════════════════════════════
Given an input x, the MoE layer computes:

    MoE(x) = Σ_{i=1}^{N} G(x)_i · E_i(x)

where:
- N      = number of experts (typically 8, 16, 64, 128)
- E_i    = the i-th expert network (usually an FFN)
- G(x)_i = the gate network's weight for expert i

Softmax gate:

    G(x)_i = exp((W_g · x)_i) / Σ_j exp((W_g · x)_j)

Top-K gate (used in practice):

    G(x)_i = { Softmax(TopK(W_g · x, k))_i   if i ∈ TopK
             { 0                             otherwise

Load-balancing auxiliary loss:

    L_aux = α · N · Σ_i f_i · P_i

where:
- f_i = fraction of tokens in the batch routed to expert i
- P_i = mean router probability assigned to expert i
- α   = hyperparameter (typically 0.01)
═══════════════════════════════════════════════════════════════════
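To make the Top-K gate and the auxiliary loss concrete, here is a dependency-free toy sketch (pure Python rather than the PyTorch used elsewhere in this chapter; the logits and batch are made up, with N = 4 experts and Top-2 routing):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def top2_route(logits):
    """Return (expert indices, renormalized weights) of the 2 largest gate probs."""
    probs = softmax(logits)
    idx = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:2]
    w = [probs[i] for i in idx]
    z = sum(w)
    return idx, [x / z for x in w]

# A toy batch: 4 tokens routed over N = 4 experts
batch_logits = [
    [2.0, 0.5, 0.1, -1.0],
    [0.1, 1.5, 0.2, 0.3],
    [1.0, 1.1, -0.5, 0.0],
    [0.2, 0.1, 2.2, 0.4],
]
N = 4
counts = [0] * N     # how often each expert lands in a token's Top-2
P = [0.0] * N        # mean router probability per expert
for logits in batch_logits:
    probs = softmax(logits)
    for i in range(N):
        P[i] += probs[i] / len(batch_logits)
    idx, w = top2_route(logits)
    for i in idx:
        counts[i] += 1

f = [c / len(batch_logits) for c in counts]   # fraction of tokens per expert
aux = N * sum(fi * pi for fi, pi in zip(f, P))  # L_aux with α = 1 for clarity
```

With Top-2 routing the f_i sum to 2, and `aux` shrinks as routing becomes more uniform, which is exactly what the auxiliary loss rewards.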
1.3 MoE Implementation
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class MoELayer(nn.Module):
    """Mixture of Experts layer."""

    def __init__(
        self,
        d_model: int,
        num_experts: int = 8,
        top_k: int = 2,
        expert_hidden_dim: Optional[int] = None,
        aux_loss_coef: float = 0.01,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_coef = aux_loss_coef

        if expert_hidden_dim is None:
            expert_hidden_dim = 4 * d_model

        # Gate network (router)
        self.gate = nn.Linear(d_model, num_experts, bias=False)

        # Expert networks (each expert is an FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, expert_hidden_dim),
                nn.GELU(),
                nn.Linear(expert_hidden_dim, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            output: [batch_size, seq_len, d_model]
            aux_loss: load-balancing auxiliary loss
        """
        batch_size, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)  # [batch_size * seq_len, d_model]

        # 1. Gate scores
        gate_logits = self.gate(x_flat)  # [num_tokens, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)

        # 2. Top-K routing
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)  # renormalize

        # 3. Run each expert on its assigned tokens and combine
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_indices = top_k_indices[:, i]
            expert_probs = top_k_probs[:, i:i + 1]
            for expert_id in range(self.num_experts):
                mask = expert_indices == expert_id
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[expert_id](expert_input)
                    output[mask] += expert_probs[mask] * expert_output

        # 4. Load-balancing auxiliary loss
        aux_loss = self._compute_aux_loss(gate_probs, top_k_indices)

        output = output.view(batch_size, seq_len, d_model)
        return output, aux_loss

    def _compute_aux_loss(
        self,
        gate_probs: torch.Tensor,
        top_k_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Load-balancing auxiliary loss (from the Switch Transformer paper)."""
        num_tokens = gate_probs.size(0)

        # f_i: fraction of tokens routed to each expert
        expert_mask = F.one_hot(top_k_indices, self.num_experts).sum(dim=1)
        f = expert_mask.float().sum(dim=0) / num_tokens

        # P_i: mean router probability assigned to each expert
        P = gate_probs.mean(dim=0)

        # Encourage a uniform token distribution over experts
        aux_loss = self.num_experts * (f * P).sum()
        return self.aux_loss_coef * aux_loss

class MoETransformerBlock(nn.Module):
    """Transformer block with an MoE layer in place of the standard FFN."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.moe = MoELayer(d_model, num_experts, top_k)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # Self-attention
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))

        # MoE FFN
        moe_out, aux_loss = self.moe(x)
        x = self.norm2(x + self.dropout(moe_out))
        return x, aux_loss
1.4 MoE Variants and Optimizations
Text Only
Evolution of MoE architectures
═══════════════════════════════════════════════════════════════════
> ⚠️ **Timeliness note (reviewed 2026-02)**: the performance claims below
depend on benchmarks, training data, and evaluation protocols; read them
alongside the original papers and current public evaluations.

Switch Transformer (Google, 2021)
├── Top-1 routing (each token selects a single expert)
├── Simplified design with lower communication overhead
└── 1.6T parameters; reported quality comparable to T5-XXL at roughly 4× faster pre-training

GLaM (Google, 2021)
├── 64 experts per MoE layer
├── 1.2T total parameters, 96B activated per token
└── Outperformed the contemporaneous GPT-3 baseline on several public tasks

Expert Choice MoE (Google, 2022)
├── Introduced Expert Choice Routing
├── Experts pick tokens instead of tokens picking experts
└── Better load balance by construction

Mixtral 8x7B (Mistral, 2023)
├── 8 experts, Top-2 routing
├── 46.7B total parameters, 12.9B activated per token
├── Open weights; matched or beat Llama 2 70B on many public benchmarks
└── First major open-source success for sparse MoE

Qwen1.5-MoE (Alibaba, 2024)
├── Separates shared experts from routed experts
├── Better expert specialization
└── A representative Chinese MoE model
═══════════════════════════════════════════════════════════════════
Expert Choice Routing
Python
class ExpertChoiceMoE(nn.Module):
    """
    Expert Choice Routing: experts select tokens instead of tokens selecting experts.
    Advantage: perfect load balance — every expert processes a fixed number of tokens.
    """

    def __init__(self, d_model: int, num_experts: int, expert_capacity: int):
        super().__init__()
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity  # tokens processed per expert
        self.gate = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList([
            self._create_expert(d_model) for _ in range(num_experts)
        ])

    @staticmethod
    def _create_expert(d_model: int) -> nn.Module:
        """A simple FFN expert."""
        return nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len
        x_flat = x.view(num_tokens, d_model)

        # Token-to-expert affinity scores
        gate_scores = self.gate(x_flat)  # [num_tokens, num_experts]

        # Each expert selects its Top-C tokens
        selected_tokens = []
        for expert_id in range(self.num_experts):
            expert_scores = gate_scores[:, expert_id]
            top_k_values, top_k_indices = torch.topk(
                expert_scores,
                min(self.expert_capacity, num_tokens),
            )
            selected_tokens.append((expert_id, top_k_indices, top_k_values))

        # Aggregate outputs (a token may be processed by several experts)
        output = torch.zeros_like(x_flat)
        for expert_id, indices, scores in selected_tokens:
            expert_input = x_flat[indices]
            expert_output = self.experts[expert_id](expert_input)
            output[indices] += scores.unsqueeze(1) * expert_output

        return output.view(batch_size, seq_len, d_model)
1.5 MoE Training Challenges and Solutions
| Challenge | Cause | Solutions |
|---|---|---|
| Load imbalance | Some experts are over-used | Auxiliary loss, expert-choice routing, capacity factor |
| Communication overhead | Experts live on different devices | Expert parallelism, optimized all-to-all communication |
| Training instability | Routing decisions are discrete | Load-balancing loss, expert dropout, temperature annealing |
| Memory footprint | Many expert parameters | Expert parallelism, CPU offloading, activation checkpointing |
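The capacity factor mentioned above caps how many tokens each expert may accept; overflow tokens are dropped (and pass through via the residual connection). A framework-free toy sketch, assuming Top-1 routing with token dropping as in Switch Transformer (all numbers made up):

```python
import math

def assign_with_capacity(expert_ids, num_experts, capacity_factor=1.25, top_k=1):
    """Token-drop routing: each expert accepts at most `capacity` tokens,
    overflow tokens are dropped. `expert_ids[t]` is the expert chosen for
    token t (Top-1 routing for simplicity)."""
    num_tokens = len(expert_ids)
    capacity = math.ceil(num_tokens * top_k / num_experts * capacity_factor)
    load = [0] * num_experts
    kept, dropped = [], []
    for t, e in enumerate(expert_ids):
        if load[e] < capacity:
            load[e] += 1
            kept.append(t)
        else:
            dropped.append(t)
    return capacity, kept, dropped

# 8 tokens, 4 experts, routing heavily skewed toward expert 0
cap, kept, dropped = assign_with_capacity([0, 0, 0, 0, 0, 1, 2, 3], num_experts=4)
```

Here capacity = ceil(8/4 × 1.25) = 3, so two of the five tokens routed to expert 0 are dropped; a larger capacity factor drops fewer tokens at the cost of more padding and memory.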
State Space Models (Mamba)

2.1 From RNNs to SSMs

State Space Models (SSMs) are a family of sequence models obtained by discretizing continuous-time systems. Mamba adds a selective mechanism that fixes the content-independence of earlier SSMs.
Text Only
Evolution of sequence models
═══════════════════════════════════════════════════════════════════
RNN
├── Idea: a hidden state compresses the history
├── Problems: long-range dependencies are hard; low training parallelism
└── Examples: LSTM, GRU

Transformer
├── Idea: global attention, parallel training
├── Problems: O(n²) complexity; memory blows up on long sequences
└── Examples: GPT, BERT, LLaMA

SSM (S4, 2021)
├── Idea: linear complexity, continuous-time modeling
├── Problem: content-independent (fixed parameters)
└── Examples: S4, DSS, GSS

Mamba (2023)
├── Idea: selective state spaces, input-dependent parameters
├── Advantage: linear complexity + content awareness
└── Examples: Mamba, Jamba, Falcon Mamba
═══════════════════════════════════════════════════════════════════
2.2 State Space Model Basics
Text Only
Continuous-time state space equations
═══════════════════════════════════════════════════════════════════
Continuous form:

    h'(t) = A · h(t) + B · x(t)     (state evolution)
    y(t)  = C · h(t)                (output)

where:
- h(t) ∈ R^N:   hidden state (N-dimensional)
- x(t) ∈ R:     input (scalar)
- y(t) ∈ R:     output (scalar)
- A ∈ R^(N×N):  state transition matrix
- B ∈ R^N:      input matrix
- C ∈ R^N:      output projection

Discretization (zero-order hold, ZOH):

    h_k = Ā · h_{k-1} + B̄ · x_k
    y_k = C · h_k

with discretized parameters:

    Ā = exp(Δ · A)                  (matrix exponential)
    B̄ = (Ā - I) · A⁻¹ · B
═══════════════════════════════════════════════════════════════════
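The ZOH formulas above can be checked numerically in the scalar case (a pure-Python sketch; a, b, Δ, and x are arbitrary, with a < 0 and a ≠ 0 so a steady state exists). For constant input x, the ODE h' = a·h + b·x settles at h = -b·x/a, and the discrete recurrence must converge to the same value:

```python
import math

def zoh_discretize(a, b, dt):
    """Zero-order-hold discretization of the scalar SSM h' = a*h + b*x."""
    a_bar = math.exp(dt * a)
    b_bar = (a_bar - 1.0) / a * b   # (Ā - I) A⁻¹ B in the scalar case
    return a_bar, b_bar

a, b, dt, x = -0.5, 1.0, 0.1, 2.0
a_bar, b_bar = zoh_discretize(a, b, dt)

# Step response: iterate h_k = Ā h_{k-1} + B̄ x with constant input x
h = 0.0
for _ in range(2000):
    h = a_bar * h + b_bar * x

steady = -b * x / a   # continuous-time steady state
```

The discrete fixed point b̄·x/(1 - ā) equals -b·x/a exactly, which is one way to see why ZOH is the "right" discretization here.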
2.3 Mamba's Selective Mechanism

Mamba's core innovation is making B, C, and Δ functions of the input, enabling selective remembering and forgetting.
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class SelectiveSSM(nn.Module):
    """Selective state space model (the core of Mamba)."""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)

        # Input projection (x → SSM branch and gating residual)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Causal convolution (local modeling)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,  # depthwise
            bias=True,
        )

        # Selective parameter generation (the key innovation):
        # outputs [B, C, Δ], where Δ is the step-size parameter
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)

        # State transition matrix A (learned; real Mamba uses a HiPPO-style init)
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1, dtype=torch.float)).repeat(self.d_inner, 1)
        )
        self.D = nn.Parameter(torch.ones(self.d_inner))  # skip connection

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch, seq_len, dim = x.shape

        # Input projection
        x_and_res = self.in_proj(x)  # [batch, seq_len, 2*d_inner]
        x_ssm, res = x_and_res.split([self.d_inner, self.d_inner], dim=-1)

        # Causal convolution
        x_conv = rearrange(x_ssm, 'b l d -> b d l')
        x_conv = self.conv1d(x_conv)[..., :seq_len]  # truncate to stay causal
        x_conv = rearrange(x_conv, 'b d l -> b l d')
        x_conv = F.silu(x_conv)

        # Selective SSM
        y = self.selective_scan(x_conv)

        # Gated fusion with the residual branch
        y = y * F.silu(res)

        # Output projection
        output = self.out_proj(y)
        return output

    def selective_scan(self, x):
        """
        Selective scan: the core computation.
        Production implementations use a fused CUDA parallel associative scan.
        """
        batch, seq_len, d_in = x.shape

        # Input-dependent selective parameters
        params = self.x_proj(x)  # [batch, seq_len, d_state*2 + 1]
        B, C, delta = params.split([self.d_state, self.d_state, 1], dim=-1)

        # softplus keeps Δ positive; broadcast the per-token scalar Δ across
        # channels (real Mamba uses a low-rank per-channel Δ projection)
        delta = F.softplus(delta).expand(-1, -1, d_in)  # [batch, seq_len, d_in]

        # Discretization
        A = -torch.exp(self.A_log.float())  # [d_in, d_state]

        y = self.parallel_scan(x, delta, A, B, C)
        return y

    def parallel_scan(self, x, delta, A, B, C):
        """
        Scan over the discretized recurrence (named for the parallel
        associative formulation used by optimized implementations).
        """
        batch, seq_len, d_in = x.shape

        # Discretized parameters
        delta_A = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
        delta_B = torch.einsum('bld,bln->bldn', delta, B)

        # Weighted input
        BX = delta_B * x.unsqueeze(-1)  # [batch, seq_len, d_in, d_state]

        # Associative scan: h_k = delta_A_k * h_{k-1} + BX_k
        h = self.associative_scan(delta_A, BX)

        # Output projection
        y = torch.einsum('bldn,bln->bld', h, C)

        # Skip connection
        y = y + self.D * x
        return y

    def associative_scan(self, A, B):
        """
        Linear recurrence h_k = A_k * h_{k-1} + B_k.
        The operation is associative, so it can be parallelized with a
        Blelloch scan; the sequential loop below is a conceptual reference
        (real implementations use a CUDA kernel).
        """
        batch, seq_len, d_in, d_state = B.shape
        h = torch.zeros(batch, d_in, d_state, device=B.device, dtype=B.dtype)
        hs = []
        for k in range(seq_len):
            h = A[:, k] * h + B[:, k]
            hs.append(h)
        return torch.stack(hs, dim=1)
2.4 Mamba Architecture Details
Text Only
Mamba block structure
═══════════════════════════════════════════════════════════════════
Input x ∈ R^(L×D)
  │
  ├──▶ Linear ──▶ [x_ssm, res]    split into SSM input and residual
  │
  ├──▶ x_ssm ──▶ Conv1d (causal) ──▶ SiLU ──▶ Selective SSM ──┐
  │                                                           │
  ├──▶ res ──▶ SiLU ──────────────────────────────────────────┤
  │                                                           │
  └──▶ element-wise product ◀─────────────────────────────────┘
          │
          ▼
        Linear ──▶ output

Key design choices:
1. Causal convolution: adds the local dependencies the SSM lacks
2. SiLU activation: smooth nonlinearity that pairs with the gating
3. Selectivity: B, C, and Δ all depend on the input x
4. Parallel scan: an efficient O(n) implementation
═══════════════════════════════════════════════════════════════════
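The causal-convolution design choice above (left-pad by k-1, then truncate, as in the `Conv1d(padding=d_conv - 1)` + `[..., :seq_len]` pattern) can be illustrated for a single channel in plain Python (the signal and kernel are made up):

```python
def causal_conv1d(x, w):
    """Single-channel causal convolution: y[t] depends only on
    x[t], x[t-1], ..., x[t-len(w)+1] — equivalent to left-padding
    with k-1 zeros and truncating to the original length."""
    k = len(w)
    padded = [0.0] * (k - 1) + list(x)
    return [sum(w[j] * padded[t + k - 1 - j] for j in range(k))
            for t in range(len(x))]

x = [1.0, 2.0, 3.0, 4.0]
w = [0.5, 0.25, 0.25]   # w[0] multiplies the current step
y = causal_conv1d(x, w)  # [0.5, 1.25, 2.25, 3.25]
```

Changing a later input never changes an earlier output, which is the causality property autoregressive models need.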
2.5 Mamba Variants
Text Only
The Mamba family
═══════════════════════════════════════════════════════════════════
Mamba (2023)
├── Pure SSM architecture, no attention
├── Linear O(n) complexity
├── Language-modeling quality competitive with Transformers
└── Clear advantages on long sequences

Jamba (AI21, 2024)
├── Mamba + Transformer hybrid
├── Alternates attention and SSM layers
├── Combines the strengths of both
└── Supports 256K context

Falcon Mamba (TII, 2024)
├── 7B-parameter pure-Mamba model
├── Open weights, commercially usable
└── Strong results on long-text tasks

Vision Mamba (2024)
├── Applies Mamba to vision
├── Bidirectional SSMs over images
└── An efficient vision backbone

Mamba-2 (2024)
├── Theoretical unification of SSMs and structured matrices
├── More efficient implementation
└── Reported up to 8× faster training
═══════════════════════════════════════════════════════════════════
Linear Attention and RNN-ized Transformers

3.1 Linear Attention

Standard attention costs O(n²); linear attention uses a kernel trick to bring this down to O(n).
Text Only
Linear attention
═══════════════════════════════════════════════════════════════════
Standard softmax attention:

    Attn = softmax(QK^T / √d) V

Complexity: O(n²·d), because the n×n attention matrix must be computed.

Linear attention (kernel trick):
Use a feature map φ so that

    exp(q · k / √d) ≈ φ(q) · φ(k)

Then (with the row-wise normalizer Z):

    Attn = φ(Q) φ(K)^T V / Z = φ(Q) (φ(K)^T V) / Z

Key observation: φ(K)^T V can be computed first, and its size does not
depend on the sequence length.

Complexity: O(n·d²) — linear in the sequence length!
═══════════════════════════════════════════════════════════════════
Python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearAttention(nn.Module):
    """Linear attention (after Katharopoulos et al., 2020)."""

    def __init__(self, dim, num_heads, feature_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.feature_dim = feature_dim or self.head_dim
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def elu_feature_map(self, x):
        """Feature map φ(x) = ELU(x) + 1 (positive, so the denominator is safe)."""
        return F.elu(x) + 1

    def forward(self, x):
        B, N, C = x.shape

        # Project to Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Feature map
        q = self.elu_feature_map(q)
        k = self.elu_feature_map(k)

        # Linear attention:
        # KV = Σ_t φ(k_t)^T v_t — its size is independent of sequence length
        KV = torch.einsum('bhsk,bhsv->bhkv', k, v)  # [B, heads, head_dim, head_dim]
        # Z = Σ_t φ(k_t)
        Z = k.sum(dim=2)  # [B, heads, head_dim]

        # output = φ(Q) @ KV / (φ(Q) @ Z)
        numerator = torch.einsum('bhnk,bhkv->bhnv', q, KV)  # [B, heads, N, head_dim]
        denominator = torch.einsum('bhnk,bhk->bhn', q, Z).unsqueeze(-1) + 1e-6
        out = numerator / denominator

        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return out
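The key associativity trick — (φ(Q)φ(K)ᵀ)V = φ(Q)(φ(K)ᵀV) — can be verified on toy matrices (a plain-Python sketch; the values are arbitrary, with n = 3 tokens and d = 2 features):

```python
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

phiQ = [[1.0, 2.0], [0.5, 1.0], [2.0, 0.5]]   # n×d
phiK = [[1.0, 0.0], [0.5, 1.5], [1.0, 1.0]]   # n×d
V    = [[1.0], [2.0], [3.0]]                  # n×1

phiKT = [list(col) for col in zip(*phiK)]     # d×n

left  = matmul(matmul(phiQ, phiKT), V)  # (φ(Q)φ(K)ᵀ)V: builds an n×n matrix
right = matmul(phiQ, matmul(phiKT, V))  # φ(Q)(φ(K)ᵀV): only a d×d-sized state
```

Both groupings give the same answer, but the right-hand grouping never materializes anything of size n×n, which is where the O(n·d²) cost comes from.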
3.2 RNN-ized Transformers

Reformulating the Transformer decoder in recurrent form gives O(1) inference memory.
Python
class RNNTransformer(nn.Module):
    """RNN form of a Transformer (in the spirit of RWKV / RetNet)."""

    def __init__(self, dim, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        # Per-head scalar log-decay, kept negative so that
        # exp((n - m) * decay) <= 1 for causal positions m <= n
        self.decay = nn.Parameter(-torch.rand(num_heads))

    def forward_training(self, x):
        """Parallel computation for training."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Causal mask + decay
        positions = torch.arange(N, device=x.device)
        rel_pos = positions.unsqueeze(1) - positions.unsqueeze(0)  # rel_pos[n, m] = n - m
        # decay_matrix[h, n, m] = exp((n - m) * decay_h), valid on the lower triangle
        decay_matrix = torch.exp(
            rel_pos.unsqueeze(0) * self.decay.view(-1, 1, 1)  # [H, N, N]
        )
        decay_matrix = torch.tril(decay_matrix)  # causal mask

        # Attention: Retention(X) = (Q K^T ⊙ D) V
        attn = torch.einsum('bnhd,bmhd->bhnm', q, k)  # [B, H, N, M]
        attn = attn * decay_matrix.unsqueeze(0)  # [B, H, N, M]
        attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-6)
        out = torch.einsum('bhnm,bmhd->bnhd', attn, v)  # [B, N, H, D]
        return out.reshape(B, N, C)

    def forward_inference(self, x, state):
        """
        RNN form for inference: O(1) memory.
        state: (KV_cache, K_sum_cache)
        """
        B, C = x.shape
        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        KV_cache, K_sum_cache = state

        # State update (RNN form)
        decay = torch.exp(self.decay)  # [num_heads], per-step multiplier in (0, 1)

        # KV_cache = decay * KV_cache + k^T v
        new_KV = KV_cache * decay.view(1, -1, 1, 1)  # [B, H, D, D]
        new_KV = new_KV + torch.einsum('bhd,bhe->bhde', k, v)

        # K_sum_cache = decay * K_sum_cache + k
        new_K_sum = K_sum_cache * decay.view(1, -1, 1)  # [B, H, D]
        new_K_sum = new_K_sum + k

        # Output
        numerator = torch.einsum('bhd,bhde->bhe', q, new_KV)
        denominator = torch.einsum('bhd,bhd->bh', q, new_K_sum).unsqueeze(-1) + 1e-6
        out = numerator / denominator
        return out.reshape(B, C), (new_KV, new_K_sum)
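The parallel/recurrent duality behind this class can be demonstrated on an unnormalized scalar toy (plain-Python sketch; q, k, v, and the decay are made up, and the class above additionally applies a normalizing denominator):

```python
import math

# Toy scalar "retention": out_n = Σ_{m<=n} γ^(n-m) · (q_n k_m) · v_m
q = [0.5, 1.0, -0.5, 2.0]
k = [1.0, 0.5, 1.0, -1.0]
v = [1.0, 2.0, 3.0, 4.0]
gamma = math.exp(-0.5)   # decay per step, in (0, 1)

# Parallel form (training): materialize the decayed attention "matrix"
parallel = []
for n in range(len(q)):
    s = sum(gamma ** (n - m) * q[n] * k[m] * v[m] for m in range(n + 1))
    parallel.append(s)

# Recurrent form (inference): O(1) state S_n = γ·S_{n-1} + k_n·v_n
S = 0.0
recurrent = []
for n in range(len(q)):
    S = gamma * S + k[n] * v[n]
    recurrent.append(q[n] * S)
```

Both forms compute the same outputs; the parallel one trains like a Transformer, the recurrent one decodes like an RNN.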
3.3 Representative Models
Text Only
Linear / RNN-ized Transformer models
═══════════════════════════════════════════════════════════════════
RWKV (2023)
├── Architecture: combines RNN and Transformer strengths
├── Training: parallel (Transformer-like)
├── Inference: O(1) memory (RNN-like)
├── Signature: time-decay mechanism
└── Fully open source, commercially usable

RetNet (Microsoft, 2023)
├── Architecture: keeps the Transformer layout
├── Core: the Retention mechanism (dual parallel/recurrent forms)
├── Training: parallel
├── Inference: recurrent form
└── Claimed quality on par with Transformers

TransNormerLLM (2023)
├── Blockwise attention + linear attention
├── Exact locally, linear globally
└── Efficient on long sequences

Linear Transformer (2020)
├── The earliest linear-attention work
├── Kernel trick to approximate softmax
└── Theoretical foundation
═══════════════════════════════════════════════════════════════════
Other Architecture Innovations

4.1 Hybrid Attention Systems (non-MoE)
Python
class MultiScaleTransformer(nn.Module):
    """
    Multi-scale Transformer: different layers use different attention patterns.
    LocalAttention, SparseGlobalAttention, and StandardAttention stand in for
    the corresponding attention modules.
    """

    def __init__(self, dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            if i % 3 == 0:
                # Local windowed attention (captures local patterns)
                self.layers.append(LocalAttention(dim, window_size=512))
            elif i % 3 == 1:
                # Sparse global attention (long-range dependencies)
                self.layers.append(SparseGlobalAttention(dim, stride=16))
            else:
                # Standard attention (exact modeling)
                self.layers.append(StandardAttention(dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
4.2 Hierarchical Attention
Python
class HierarchicalAttention(nn.Module):
    """Hierarchical attention: coarse-grained first, then fine-grained."""

    def __init__(self, dim, num_tokens_per_group=64):
        super().__init__()
        self.num_tokens_per_group = num_tokens_per_group
        # Coarse-grained (group level)
        self.coarse_attention = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        # Fine-grained (within groups)
        self.fine_attention = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

    def forward(self, x):
        # Assumes N is divisible by num_tokens_per_group
        B, N, C = x.shape
        G = N // self.num_tokens_per_group  # number of groups

        # Group the tokens
        x_grouped = x.reshape(B, G, self.num_tokens_per_group, C)

        # Pool within each group to get a group representation
        group_repr = x_grouped.mean(dim=2)  # [B, G, C]

        # Coarse-grained: attention across groups
        coarse_out, _ = self.coarse_attention(group_repr, group_repr, group_repr)

        # Broadcast group representations back to their tokens
        coarse_expanded = coarse_out.unsqueeze(2).expand(-1, -1, self.num_tokens_per_group, -1)

        # Fine-grained: attention within groups
        x_combined = x_grouped + coarse_expanded  # fuse
        x_fine = x_combined.reshape(B * G, self.num_tokens_per_group, C)
        fine_out, _ = self.fine_attention(x_fine, x_fine, x_fine)

        return fine_out.reshape(B, N, C)
4.3 Memory-Efficient Architectures
Python
import math


class MemoryEfficientAttention(nn.Module):
    """
    FlashAttention-style memory-efficient implementation.
    Blocked computation avoids storing the full attention matrix.
    """

    def __init__(self, dim, num_heads, block_size=256):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.block_size = block_size
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        Blocked attention with O(n) memory
        (the full N×N attention matrix is never materialized).
        """
        B, N, C = x.shape
        H = self.num_heads
        D = self.head_dim

        qkv = self.qkv(x).reshape(B, N, 3, H, D)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        num_blocks = (N + self.block_size - 1) // self.block_size
        outputs = []

        for i in range(num_blocks):
            start_i = i * self.block_size
            end_i = min(start_i + self.block_size, N)
            block_len = end_i - start_i
            q_block = q[:, start_i:end_i]  # [B, block_len, H, D]

            # Online softmax statistics for this query block
            max_score = torch.full((B, H, block_len, 1), float('-inf'), device=x.device)
            numerator = torch.zeros(B, H, block_len, D, device=x.device)
            denominator = torch.zeros(B, H, block_len, 1, device=x.device)

            for j in range(num_blocks):
                start_j = j * self.block_size
                end_j = min(start_j + self.block_size, N)
                k_block = k[:, start_j:end_j]  # [B, block_j, H, D]
                v_block = v[:, start_j:end_j]  # [B, block_j, H, D]

                # Block-to-block attention scores, keeping the head dimension
                scores = torch.einsum('bnhd,bmhd->bhnm', q_block, k_block)  # [B, H, block_i, block_j]
                scores = scores / math.sqrt(D)

                # Online softmax update
                max_score_new = torch.maximum(max_score, scores.max(dim=-1, keepdim=True)[0])
                exp_scores = torch.exp(scores - max_score_new)
                numerator = numerator * torch.exp(max_score - max_score_new) + \
                    torch.einsum('bhnm,bmhd->bhnd', exp_scores, v_block)
                denominator = denominator * torch.exp(max_score - max_score_new) + \
                    exp_scores.sum(dim=-1, keepdim=True)
                max_score = max_score_new

            out_block = numerator / denominator  # [B, H, block_len, D]
            out_block = out_block.permute(0, 2, 1, 3)  # [B, block_len, H, D]
            outputs.append(out_block)

        out = torch.cat(outputs, dim=1)
        out = out.reshape(B, N, C)
        return self.proj(out)
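The online-softmax bookkeeping used above can be checked against a plain two-pass softmax on a single row of scalar scores (dependency-free sketch; the scores and values are made up):

```python
import math

def softmax_weighted_sum(scores, values):
    """Reference: full softmax over all scores, then the weighted sum."""
    m = max(scores)
    es = [math.exp(s - m) for s in scores]
    z = sum(es)
    return sum(e * v for e, v in zip(es, values)) / z

def blocked_softmax_weighted_sum(scores, values, block=2):
    """Streaming softmax: one pass over blocks, carrying the running
    (max, numerator, denominator) triple, as in the class above."""
    m, num, den = float('-inf'), 0.0, 0.0
    for i in range(0, len(scores), block):
        s_blk, v_blk = scores[i:i + block], values[i:i + block]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # rescales old stats; exp(-inf) == 0.0 on the first block
        num = num * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_blk, v_blk))
        den = den * scale + sum(math.exp(s - m_new) for s in s_blk)
        m = m_new
    return num / den

scores = [0.3, 2.0, -1.0, 0.7, 1.2]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
full = softmax_weighted_sum(scores, values)
streamed = blocked_softmax_weighted_sum(scores, values, block=2)
```

The streamed version never sees all scores at once, yet matches the two-pass result — the scalar analogue of why FlashAttention needs only O(n) memory.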
Architecture Selection Guide

5.1 Recommendations by Scenario
Text Only
Architecture decision tree
═══════════════════════════════════════════════════════════════════
What is your main constraint?
│
├── Long sequences (>32K tokens)
│   │
│   ├── Need strong long-range dependency modeling
│   │   └──▶ Mamba / Jamba / linear attention
│   │
│   └── Mostly local patterns
│       └──▶ sliding-window attention + global tokens
│
├── Large-scale deployment (cost-sensitive)
│   │
│   ├── High-throughput inference
│   │   └──▶ MoE (few activated parameters) / RWKV (O(1) memory)
│   │
│   └── Training-cost control
│       └──▶ linear attention / hierarchical attention
│
├── General quality first
│   │
│   ├── Ample compute
│   │   └──▶ standard Transformer + FlashAttention
│   │
│   └── Constrained resources
│       └──▶ hybrid architectures (MoE + linear attention)
│
└── Specific domains
    │
    ├── Code generation
    │   └──▶ standard Transformer (structured patterns matter)
    │
    ├── Long-document understanding
    │   └──▶ Mamba / Longformer
    │
    └── Multimodal
        └──▶ standard Transformer (mature cross-modal alignment)
═══════════════════════════════════════════════════════════════════
5.2 Architecture Comparison

| Architecture | Training complexity | Inference complexity | Long-sequence support | Implementation difficulty | Ecosystem maturity |
|---|---|---|---|---|---|
| Standard Transformer | O(n²) | O(n²) | Poor | Low | Very high |
| FlashAttention | O(n²) | O(n²) | Medium | Medium | High |
| MoE | O(n²)⁺ | O(n²)⁺ | Good | High | Medium |
| Mamba | O(n) | O(n) | Excellent | Medium | Growing |
| Linear attention | O(n) | O(n) | Good | Low | Medium |
| RWKV | O(n) | O(1)⁺⁺ | Good | Medium | Low |

⁺ MoE attention cost is still O(n²), but per-token FFN compute drops sharply (only a few experts are activated). ⁺⁺ RWKV's O(1) inference refers to memory that does not grow with sequence length.
5.3 Practical Recommendations

- **Start with a standard Transformer**
    - Unless you have a clear long-sequence requirement, prefer the mature Transformer stack
    - FlashAttention already solves most memory problems
- **When MoE is a good fit**
    - Consider it above roughly 10B parameters
    - Inference throughput matters more than latency
    - You have distributed-training infrastructure
- **When Mamba/SSMs are a good fit**
    - Sequence lengths beyond 32K
    - Streaming workloads (e.g. speech, video)
    - Strictly memory-constrained edge devices
- **Hybrid strategies**
    - Lower layers: Mamba / linear attention for long context
    - Upper layers: standard Transformer for precise modeling
    - Reference: the Jamba architecture
Next Steps

After this chapter, continue with:

- 05 LLM Safety and Alignment — RLHF and safety training
- 06 LLM Applications and Productization — engineering practice and deployment

Hands-on projects:

- Replace some Transformer layers with Mamba and compare long-sequence performance
- Implement a simplified MoE layer and observe the effect of the load-balancing loss

Last updated: 2026-02-12 · Applies to: LLM tutorial v2026