04. 模型架构创新¶
目录¶
Mixture of Experts (MoE)¶
1.1 MoE 核心思想¶
混合专家模型( Mixture of Experts, MoE ) 通过条件计算( Conditional Computation )实现模型容量与计算成本的解耦。核心思想是:对于每个输入,只激活部分参数(专家),而非整个网络。
┌─────────────────────────────────────────────────────────────────┐
│ MoE 架构示意图 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Input │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Router / │ │
│ │ Gate Network │──▶ 生成路由权重 [0.1, 0.7, 0.05, 0.15] │
│ │ (可学习) │ │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────┴──────┬────────┬────────┐ │
│ ▼ ▼ ▼ ▼ │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │ E0 │ │ E1 │ │ E2 │ │ E3 │ Expert Networks │
│ │0.1 │ │0.7 │ │0.05 │ │0.15 │ (FFN层作为专家) │
│ └──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘ │
│ │ │ │ │ │
│ └──────────┴────┬────┴────────┘ │
│ ▼ │
│ ┌──────────┐ │
│ │ 加权求和 │──▶ 0.1×E0 + 0.7×E1 + 0.05×E2 + 0.15×E3 │
│ └──────────┘ │
│ │ │
│ ▼ │
│ Output │
│ │
│ 关键:只计算Top-K专家(如Top-2),其余专家输出设为0 │
│ │
└─────────────────────────────────────────────────────────────────┘
1.2 MoE 数学原理¶
MoE层的前向传播
═══════════════════════════════════════════════════════════════════
给定输入 x,MoE层的输出为:
N
MoE(x) = Σ G(x)_i · E_i(x)
i=1
其中:
- N = 专家数量(通常 8, 16, 64, 128)
- E_i = 第i个专家网络(通常是FFN)
- G(x)_i = 门控网络对第i个专家的权重
门控函数(Softmax Gate):
exp(W_g · x)_i
G(x)_i = ─────────────────────
Σ_j exp(W_g · x)_j
Top-K门控(实际使用):
G(x)_i = { Softmax(TopK(W_g · x, k))_i if i ∈ TopK
{ 0 otherwise
负载均衡损失(Auxiliary Loss):
L_aux = α · N · Σ_i f_i · P_i
其中:
- f_i = 批次中分配给专家i的token比例
- P_i = 路由器对专家i的平均预测概率
- α = 超参数(通常 0.01)
═══════════════════════════════════════════════════════════════════
1.3 MoE 实现代码¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class MoELayer(nn.Module):
"""
Mixture of Experts 层实现
"""
def __init__(
self,
d_model: int,
num_experts: int = 8,
top_k: int = 2,
expert_hidden_dim: int = None,
aux_loss_coef: float = 0.01
):
super().__init__() # super()调用父类方法
self.d_model = d_model
self.num_experts = num_experts
self.top_k = top_k
self.aux_loss_coef = aux_loss_coef
if expert_hidden_dim is None:
expert_hidden_dim = 4 * d_model
# 门控网络(路由器)
self.gate = nn.Linear(d_model, num_experts, bias=False)
# 专家网络(每个专家是一个FFN)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, expert_hidden_dim),
nn.GELU(),
nn.Linear(expert_hidden_dim, d_model)
)
for _ in range(num_experts)
])
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: [batch_size, seq_len, d_model]
Returns:
output: [batch_size, seq_len, d_model]
aux_loss: 负载均衡辅助损失
"""
batch_size, seq_len, d_model = x.shape
x_flat = x.view(-1, d_model) # [batch_size * seq_len, d_model] # view重塑张量形状
# 1. 计算门控分数
gate_logits = self.gate(x_flat) # [num_tokens, num_experts]
gate_probs = F.softmax(gate_logits, dim=-1)
# 2. Top-K路由
top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True) # 重新归一化
# 3. 计算专家输出并加权
output = torch.zeros_like(x_flat)
for i in range(self.top_k):
expert_indices = top_k_indices[:, i]
expert_probs = top_k_probs[:, i:i+1]
# 对每个专家处理分配给它的token
for expert_id in range(self.num_experts):
mask = expert_indices == expert_id
if mask.any(): # any()任一为True则返回True
expert_input = x_flat[mask]
expert_output = self.experts[expert_id](expert_input)
output[mask] += expert_probs[mask] * expert_output
# 4. 计算负载均衡辅助损失
aux_loss = self._compute_aux_loss(gate_probs, top_k_indices)
output = output.view(batch_size, seq_len, d_model)
return output, aux_loss
def _compute_aux_loss(
self,
gate_probs: torch.Tensor,
top_k_indices: torch.Tensor
) -> torch.Tensor:
"""
计算负载均衡辅助损失(来自Switch Transformer论文)
"""
num_tokens = gate_probs.size(0)
# f_i: 每个专家被选择的token比例
expert_mask = F.one_hot(top_k_indices, self.num_experts).sum(dim=1)
f = expert_mask.float().sum(dim=0) / num_tokens
# P_i: 路由器分配给每个专家的平均概率
P = gate_probs.mean(dim=0)
# 辅助损失:鼓励均匀分布
aux_loss = self.num_experts * (f * P).sum()
return self.aux_loss_coef * aux_loss
class MoETransformerBlock(nn.Module):
"""
使用MoE的Transformer块(替代标准FFN)
"""
def __init__(
self,
d_model: int,
num_heads: int,
num_experts: int = 8,
top_k: int = 2,
dropout: float = 0.1
):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
self.moe = MoELayer(d_model, num_experts, top_k)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
# Self-Attention
attn_out, _ = self.attention(x, x, x, attn_mask=mask)
x = self.norm1(x + self.dropout(attn_out))
# MoE FFN
moe_out, aux_loss = self.moe(x)
x = self.norm2(x + self.dropout(moe_out))
return x, aux_loss
1.4 MoE 变体与优化¶
MoE架构演进
═══════════════════════════════════════════════════════════════════
> ⚠️ **时效性说明(2026-02 复核)**:下述“性能对比”结论受评测集、训练数据与评测口径影响,请结合原论文与最新公开评测解读。
Switch Transformer (Google, 2021)
├── Top-1 路由(每个token只选一个专家)
├── 简化设计,降低通信开销
└── 1.6T参数,与T5-XXL相当性能但快4倍
GLaM (Google, 2021)
├── 64专家 per MoE层
├── 1.2T总参数,96B激活参数
└── 在部分公开任务上优于同期GPT-3基线
ST-MoE (Google, 2022)
├── 引入Expert Choice Routing
├── 从token选专家 → 专家选token
└── 更好的负载均衡
Mixtral 8x7B (Mistral, 2023)
├── 8专家,Top-2路由
├── 46.7B总参数,12.9B激活参数
├── 开源,在多项公开评测中表现接近或优于Llama 2 70B
└── 稀疏MoE首次在开源社区成功
Qwen1.5-MoE (阿里, 2024)
├── 共享专家 + 路由专家分离
├── 更好的专家特化
└── 中文MoE代表模型之一
═══════════════════════════════════════════════════════════════════
专家选择路由( Expert Choice Routing )¶
class ExpertChoiceMoE(nn.Module):
"""
Expert Choice Routing: 专家选择token而非token选择专家
优势:完美负载均衡,每个专家处理固定数量token
"""
def __init__(self, d_model: int, num_experts: int, expert_capacity: int):
super().__init__()
self.num_experts = num_experts
self.expert_capacity = expert_capacity # 每个专家处理的token数
self.gate = nn.Linear(d_model, num_experts)
self.experts = nn.ModuleList([
self._create_expert(d_model) for _ in range(num_experts)
])
def forward(self, x: torch.Tensor):
batch_size, seq_len, d_model = x.shape
num_tokens = batch_size * seq_len
x_flat = x.view(num_tokens, d_model)
# 计算每个token对每个专家的偏好分数
gate_scores = self.gate(x_flat) # [num_tokens, num_experts]
# 每个专家选择Top-C个token
selected_tokens = []
for expert_id in range(self.num_experts):
expert_scores = gate_scores[:, expert_id]
top_k_values, top_k_indices = torch.topk(
expert_scores,
min(self.expert_capacity, num_tokens)
)
selected_tokens.append((expert_id, top_k_indices, top_k_values))
# 聚合输出(一个token可能被多个专家处理)
output = torch.zeros_like(x_flat)
for expert_id, indices, scores in selected_tokens:
expert_input = x_flat[indices]
expert_output = self.experts[expert_id](expert_input)
output[indices] += scores.unsqueeze(1) * expert_output # unsqueeze增加一个维度
return output.view(batch_size, seq_len, d_model)
1.5 MoE 训练挑战与解决方案¶
| 挑战 | 原因 | 解决方案 |
|---|---|---|
| 负载不均衡 | 某些专家被过度使用 | 辅助损失、专家选择路由、容量因子 |
| 通信开销 | 专家分布在不同设备 | 专家并行、 All-to-All 通信优化 |
| 训练不稳定 | 路由决策离散 | 负载均衡损失、专家 dropout 、温度退火 |
| 显存占用 | 大量专家参数 | 专家并行、 CPU offloading 、激活检查点 |
状态空间模型 (Mamba)¶
2.1 从 RNN 到 SSM¶
状态空间模型( State Space Model, SSM ) 是一类将连续时间系统离散化的序列建模方法。 Mamba 通过选择性机制( Selective Mechanism )解决了传统 SSM 的内容依赖问题。
序列模型演进
═══════════════════════════════════════════════════════════════════
RNN
├── 特点:隐状态压缩历史信息
├── 问题:长程依赖困难、训练并行度低
└── 代表:LSTM, GRU
Transformer
├── 特点:全局注意力、训练并行
├── 问题:O(n²)复杂度、长序列内存爆炸
└── 代表:GPT, BERT, LLaMA
SSM (S4, 2021)
├── 特点:线性复杂度、连续时间建模
├── 问题:内容无关(参数固定)
└── 代表:S4, DSS, GSS
Mamba (2023)
├── 特点:选择性状态空间、输入依赖
├── 优势:线性复杂度 + 内容感知
└── 代表:Mamba, Jamba, Falcon Mamba
═══════════════════════════════════════════════════════════════════
2.2 状态空间模型基础¶
连续时间状态空间方程
═══════════════════════════════════════════════════════════════════
连续形式:
h'(t) = A · h(t) + B · x(t) (状态演化)
y(t) = C · h(t) (输出)
其中:
- h(t) ∈ R^N: 隐状态(N维)
- x(t) ∈ R: 输入(标量)
- y(t) ∈ R: 输出(标量)
- A ∈ R^(N×N): 状态转移矩阵
- B ∈ R^N: 输入影响矩阵
- C ∈ R^N: 输出投影矩阵
离散化(零阶保持,ZOH):
h_k = Ā · h_{k-1} + B̄ · x_k
y_k = C · h_k
其中离散化参数:
Ā = exp(Δ · A) (矩阵指数)
B̄ = (Ā - I) · A^(-1) · B
═══════════════════════════════════════════════════════════════════
2.3 Mamba 的选择性机制¶
Mamba 的核心创新是使 B 、 C 、Δ成为输入的函数,实现选择性记忆/遗忘。
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class SelectiveSSM(nn.Module):
"""
选择性状态空间模型(Mamba核心)
"""
def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_inner = int(expand * d_model)
# 输入投影(x → x和残差)
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
# 因果卷积(局部建模)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=d_conv,
padding=d_conv - 1,
groups=self.d_inner, # 深度可分离
bias=True
)
# 选择性参数生成(关键创新)
self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
# 输出: [B, Δ, C] 其中Δ是时间步长参数
# 状态转移矩阵A(可学习或固定)
# 通常使用HiPPO初始化或学习
self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1)).repeat(self.d_inner, 1))
self.D = nn.Parameter(torch.ones(self.d_inner)) # 跳跃连接
# 输出投影
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model]
Returns:
output: [batch, seq_len, d_model]
"""
batch, seq_len, dim = x.shape
# 输入投影
x_and_res = self.in_proj(x) # [batch, seq_len, 2*d_inner]
x_ssm, res = x_and_res.split([self.d_inner, self.d_inner], dim=-1)
# 因果卷积
x_conv = rearrange(x_ssm, 'b l d -> b d l')
x_conv = self.conv1d(x_conv)[..., :seq_len] # 因果截断
x_conv = rearrange(x_conv, 'b d l -> b l d')
x_conv = F.silu(x_conv)
# 选择性SSM
y = self.selective_scan(x_conv)
# 门控融合
y = y * F.silu(res)
# 输出投影
output = self.out_proj(y)
return output
def selective_scan(self, x):
"""
选择性扫描:核心计算
使用并行关联扫描(Parallel Associative Scan)实现高效计算
"""
batch, seq_len, d_in = x.shape
# 生成选择性参数(输入依赖)
params = self.x_proj(x) # [batch, seq_len, d_state*2 + 1]
B, C, delta = params.split([self.d_state, self.d_state, 1], dim=-1)
# softplus确保Δ为正
delta = F.softplus(delta.squeeze(-1)) # [batch, seq_len, d_in]
# 离散化
A = -torch.exp(self.A_log.float()) # [d_in, d_state]
# 使用并行扫描高效计算
# 这里使用简化的顺序实现,实际使用CUDA优化的并行版本
y = self.parallel_scan(x, delta, A, B, C)
return y
def parallel_scan(self, x, delta, A, B, C):
"""
并行关联扫描实现
将递归计算转化为可并行的关联操作
"""
batch, seq_len, d_in = x.shape
# 离散化参数
delta_A = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
delta_B = torch.einsum('bld,bln->bldn', delta, B)
# 输入加权
BX = delta_B * x.unsqueeze(-1) # [batch, seq_len, d_in, d_state]
# 关联扫描: h_k = delta_A_k * h_{k-1} + BX_k
# 使用Blelloch扫描算法并行化
h = self.associative_scan(delta_A, BX)
# 输出投影
y = torch.einsum('bldn,bln->bld', h, C)
# 跳跃连接
y = y + self.D.unsqueeze(0).unsqueeze(0) * x
return y
def associative_scan(self, A, B):
"""
关联扫描:将线性递归转化为并行计算
核心思想:h_k = A_k * h_{k-1} + B_k 是关联操作
"""
# 实际实现使用CUDA内核
# 这里展示概念性实现
batch, seq_len, d_in, d_state = B.shape
h = torch.zeros(batch, d_in, d_state, device=B.device, dtype=B.dtype)
hs = []
for k in range(seq_len):
h = A[:, k] * h + B[:, k]
hs.append(h)
return torch.stack(hs, dim=1)
2.4 Mamba 架构细节¶
Mamba Block 详细结构
═══════════════════════════════════════════════════════════════════
输入 x ∈ R^(L×D)
│
├──▶ Linear ──▶ [x_ssm, res] 分割为SSM输入和残差
│
├──▶ x_ssm ──▶ Conv1d (因果) ──▶ SiLU ──▶ Selective SSM ──┐
│ │
├──▶ res ─────────────────────────────────────────────▶ SiLU │
│ │
└──▶ 逐元素相乘 ◀───────────────────────────────────────────┘
│
▼
Linear ──▶ 输出
关键设计选择:
1. 因果卷积:引入局部依赖,弥补SSM缺乏局部性的不足
2. SiLU激活:平滑非线性,与门控机制配合
3. 选择性参数:B, C, Δ都依赖于输入x
4. 并行扫描:O(n)复杂度的高效实现
═══════════════════════════════════════════════════════════════════
2.5 Mamba 变体¶
Mamba 家族
═══════════════════════════════════════════════════════════════════
Mamba (2023)
├── 纯SSM架构,无注意力
├── 线性复杂度 O(n)
├── 语言建模性能匹敌Transformer
└── 长序列优势显著
Jamba (AI21, 2024)
├── Mamba + Transformer 混合
├── 注意力层与SSM层交替
├── 结合两者优势
└── 256K上下文支持
Falcon Mamba (TII, 2024)
├── 7B参数纯Mamba模型
├── 开源可商用
└── 在长文本任务上表现突出
Vision Mamba (2024)
├── 将Mamba应用于视觉
├── 双向SSM处理图像
└── 高效视觉backbone
Mamba-2 (2024)
├── 理论统一:SSM与结构化矩阵
├── 更高效的实现
└── 8倍训练速度提升
═══════════════════════════════════════════════════════════════════
线性注意力与 RNN 化 Transformer¶
3.1 线性注意力机制¶
标准注意力的复杂度为 O(n²),线性注意力通过核技巧将其降至 O(n)。
线性注意力原理
═══════════════════════════════════════════════════════════════════
标准Softmax Attention:
softmax(QK^T)V
Attn = ───────────────
Z
其中 Z = row_sum(softmax(QK^T))
复杂度:O(n²·d) 因为需要计算n×n的注意力矩阵
线性注意力(核技巧):
使用特征映射 φ,使得:
softmax(QK^T/√d) ≈ φ(Q)φ(K)^T
则:
φ(Q)φ(K)^T V φ(Q)(φ(K)^T V)
Attn = ───────────── = ─────────────────
Z Z
关键观察:φ(K)^T V 可以先计算(与序列长度无关)
复杂度:O(n·d²) 线性于序列长度!
═══════════════════════════════════════════════════════════════════
class LinearAttention(nn.Module):
"""
线性注意力实现(基于Katharopoulos et al.)
"""
def __init__(self, dim, num_heads, feature_dim=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.feature_dim = feature_dim or self.head_dim
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def elu_feature_map(self, x):
"""特征映射:ELU + 1"""
return F.elu(x) + 1
def forward(self, x):
B, N, C = x.shape
# 生成QKV
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, heads, N, head_dim]
q, k, v = qkv[0], qkv[1], qkv[2]
# 特征映射
q = self.elu_feature_map(q)
k = self.elu_feature_map(k)
# 线性注意力计算
# KV = Σ_t k_t^T v_t
KV = torch.einsum('bhsk,bhsv->bhkv', k, v) # [B, heads, head_dim, head_dim]
# Z = Σ_t k_t
Z = k.sum(dim=2) # [B, heads, head_dim]
# 输出 = Q @ KV / (Q @ Z)
numerator = torch.einsum('bhnk,bhkv->bhnv', q, KV) # [B, heads, N, head_dim]
denominator = torch.einsum('bhnk,bhk->bhn', q, Z).unsqueeze(-1) + 1e-6
out = numerator / denominator
out = out.transpose(1, 2).reshape(B, N, C)
out = self.proj(out)
return out
3.2 RNN 化 Transformer¶
将 Transformer 解码为 RNN 形式,实现 O(1)推理内存。
class RNNTransformer(nn.Module):
"""
Transformer的RNN形式(基于RWKV和RetNet思想)
"""
def __init__(self, dim, num_heads):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3)
self.decay = nn.Parameter(torch.randn(num_heads)) # 每头一个标量衰减率
def forward_training(self, x):
"""训练时使用并行计算"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# 因果掩码 + 衰减
positions = torch.arange(N, device=x.device)
rel_pos = positions.unsqueeze(0) - positions.unsqueeze(1) # [N, N]
# decay_matrix[h, n, m] = exp((n-m) * decay_h),对下三角有效
decay_matrix = torch.exp(
rel_pos.unsqueeze(0) * self.decay.view(-1, 1, 1) # [H, N, N]
)
decay_matrix = torch.tril(decay_matrix) # 因果掩码
# 注意力:Retention(X) = (Q K^T ⊙ D) V
attn = torch.einsum('bnhd,bmhd->bhnm', q, k) # [B, H, N, M]
attn = attn * decay_matrix.unsqueeze(0) # [B, H, N, M]
attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-6)
out = torch.einsum('bhnm,bmhd->bnhd', attn, v) # [B, N, H, D]
return out.reshape(B, N, C)
def forward_inference(self, x, state):
"""
推理时使用RNN形式,O(1)内存
state: (KV_cache, K_sum_cache)
"""
B, C = x.shape
qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim)
q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
KV_cache, K_sum_cache = state
# 更新状态(RNN形式)
decay = torch.exp(self.decay) # [num_heads]
# KV_cache = decay * KV_cache + k^T v
new_KV = KV_cache * decay.view(1, -1, 1, 1) # [B, H, D, D]
new_KV = new_KV + torch.einsum('bhd,bhe->bhde', k, v)
# K_sum_cache = decay * K_sum_cache + k
new_K_sum = K_sum_cache * decay.view(1, -1, 1) # [B, H, D]
new_K_sum = new_K_sum + k
# 计算输出
numerator = torch.einsum('bhd,bhde->bhe', q, new_KV)
denominator = torch.einsum('bhd,bhd->bh', q, new_K_sum).unsqueeze(-1) + 1e-6
out = numerator / denominator
return out.reshape(B, C), (new_KV, new_K_sum)
3.3 代表性模型¶
线性/RNN化Transformer模型
═══════════════════════════════════════════════════════════════════
RWKV (2023)
├── 架构:结合RNN和Transformer优点
├── 训练:并行化(类似Transformer)
├── 推理:O(1)内存(类似RNN)
├── 特点:时间衰减机制
└── 开源:完全开源可商用
RetNet (Microsoft, 2023)
├── 架构:保留Transformer结构
├── 核心:Retention机制(对偶形式)
├── 训练:并行
├── 推理:循环形式
└── 声称:性能匹敌Transformer
TransNormerLLM (2023)
├── 分块注意力 + 线性注意力
├── 局部精确 + 全局线性
└── 长序列高效
Linear Transformer (2020)
├── 最早的线性注意力工作
├── 使用核技巧近似softmax
└── 理论基础
═══════════════════════════════════════════════════════════════════
其他架构创新¶
4.1 混合专家系统(非 MoE 形式)¶
class MultiScaleTransformer(nn.Module):
"""
多尺度Transformer:不同层使用不同注意力模式
"""
def __init__(self, dim, num_layers):
super().__init__()
self.layers = nn.ModuleList()
for i in range(num_layers):
if i % 3 == 0:
# 局部窗口注意力(捕捉局部模式)
self.layers.append(LocalAttention(dim, window_size=512))
elif i % 3 == 1:
# 稀疏全局注意力(长程依赖)
self.layers.append(SparseGlobalAttention(dim, stride=16))
else:
# 标准注意力(精确建模)
self.layers.append(StandardAttention(dim))
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
4.2 分层注意力¶
class HierarchicalAttention(nn.Module):
"""
分层注意力:先粗粒度后细粒度
"""
def __init__(self, dim, num_tokens_per_group=64):
super().__init__()
self.num_tokens_per_group = num_tokens_per_group
# 粗粒度(组级别)
self.coarse_attention = nn.MultiheadAttention(dim, num_heads=8)
# 细粒度(组内)
self.fine_attention = nn.MultiheadAttention(dim, num_heads=8)
def forward(self, x):
B, N, C = x.shape
G = N // self.num_tokens_per_group # 组数
# 分组
x_grouped = x.reshape(B, G, self.num_tokens_per_group, C)
# 组内池化得到组表示
group_repr = x_grouped.mean(dim=2) # [B, G, C]
# 粗粒度:组间注意力
coarse_out, _ = self.coarse_attention(group_repr, group_repr, group_repr)
# 将组表示广播回各token
coarse_expanded = coarse_out.unsqueeze(2).expand(-1, -1, self.num_tokens_per_group, -1)
# 细粒度:组内注意力
x_combined = x_grouped + coarse_expanded # 融合
x_fine = x_combined.reshape(B * G, self.num_tokens_per_group, C)
fine_out, _ = self.fine_attention(x_fine, x_fine, x_fine)
return fine_out.reshape(B, N, C)
4.3 内存高效架构¶
class MemoryEfficientAttention(nn.Module):
"""
FlashAttention风格的内存高效实现
通过分块计算避免存储完整的注意力矩阵
"""
def __init__(self, dim, num_heads, block_size=256):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.block_size = block_size
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
"""
分块计算注意力,O(n)内存(无需存储完整N×N注意力矩阵)
"""
B, N, C = x.shape
H = self.num_heads
D = self.head_dim
qkv = self.qkv(x).reshape(B, N, 3, H, D)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# 分块
num_blocks = (N + self.block_size - 1) // self.block_size
outputs = []
for i in range(num_blocks):
start_i = i * self.block_size
end_i = min(start_i + self.block_size, N)
block_len = end_i - start_i
q_block = q[:, start_i:end_i] # [B, block_len, H, D]
# 在线计算softmax
max_score = torch.full((B, H, block_len, 1), float('-inf'), device=x.device)
numerator = torch.zeros(B, H, block_len, D, device=x.device)
denominator = torch.zeros(B, H, block_len, 1, device=x.device)
for j in range(num_blocks):
start_j = j * self.block_size
end_j = min(start_j + self.block_size, N)
k_block = k[:, start_j:end_j] # [B, block_j, H, D]
v_block = v[:, start_j:end_j] # [B, block_j, H, D]
# 计算块间注意力,保留多头维度
scores = torch.einsum('bnhd,bmhd->bhnm', q_block, k_block) # [B, H, block_i, block_j]
scores = scores / math.sqrt(D)
# 在线softmax更新
max_score_new = torch.maximum(max_score, scores.max(dim=-1, keepdim=True)[0])
exp_scores = torch.exp(scores - max_score_new)
numerator = numerator * torch.exp(max_score - max_score_new) + \
torch.einsum('bhnm,bmhd->bhnd', exp_scores, v_block)
denominator = denominator * torch.exp(max_score - max_score_new) + \
exp_scores.sum(dim=-1, keepdim=True)
max_score = max_score_new
out_block = numerator / denominator # [B, H, block_len, D]
out_block = out_block.permute(0, 2, 1, 3) # [B, block_len, H, D]
outputs.append(out_block)
out = torch.cat(outputs, dim=1)
out = out.reshape(B, N, C)
return self.proj(out)
架构选择指南¶
5.1 不同场景推荐¶
架构选择决策树
═══════════════════════════════════════════════════════════════════
你的主要约束是什么?
│
├── 长序列(>32K tokens)
│ │
│ ├── 需要强长程依赖捕捉
│ │ └──▶ Mamba / Jamba / Linear Attention
│ │
│ └── 主要关注局部模式
│ └──▶ 滑动窗口注意力 + 全局token
│
├── 大规模部署(成本敏感)
│ │
│ ├── 高吞吐推理
│ │ └──▶ MoE(激活参数少)/ RWKV(O(1)内存)
│ │
│ └── 训练成本控制
│ └──▶ 线性注意力 / 分层注意力
│
├── 通用性能优先
│ │
│ ├── 有足够计算资源
│ │ └──▶ 标准Transformer + FlashAttention
│ │
│ └── 资源受限
│ └──▶ 混合架构(MoE + 线性注意力)
│
└── 特定领域
│
├── 代码生成
│ └──▶ 标准Transformer(结构化模式重要)
│
├── 长文档理解
│ └──▶ Mamba / Longformer
│
└── 多模态
└──▶ 标准Transformer(跨模态对齐成熟)
═══════════════════════════════════════════════════════════════════
5.2 架构对比表¶
| 架构 | 训练复杂度 | 推理复杂度 | 长序列支持 | 实现难度 | 生态成熟度 |
|---|---|---|---|---|---|
| 标准 Transformer | O(n²) | O(n²) | 差 | 低 | 极高 |
| FlashAttention | O(n²) | O(n²) | 中 | 中 | 高 |
| MoE | O(n²)⁺ | O(n²)⁺ | 好 | 高 | 中 |
| Mamba | O(n) | O(n) | 极好 | 中 | 增长中 |
| 线性注意力 | O(n) | O(n) | 好 | 低 | 中 |
| RWKV | O(n) | O(1)⁺⁺ | 好 | 中 | 低 |
⁺ MoE 注意力复杂度仍为 O(n²),但每个 token 的 FFN 计算量显著降低(仅激活少数专家) ⁺⁺ RWKV 推理复杂度 O(1)指内存不随序列增长
5.3 实践建议¶
- 从标准 Transformer 开始
- 除非有明确的长序列需求,否则优先使用成熟的 Transformer
-
FlashAttention 已解决大部分内存问题
-
MoE 的适用场景
- 模型参数>10B 时考虑
- 推理吞吐量优先于延迟的场景
-
有分布式训练基础设施
-
Mamba/SSM 的适用场景
- 序列长度>32K
- 需要流式处理(如语音、视频)
-
内存严格受限的边端设备
-
混合策略
- 底层: Mamba/线性注意力处理长上下文
- 上层:标准 Transformer 进行精确建模
- 参考: Jamba 架构
面试常见问题¶
Q1: MoE 架构为什么能实现"参数量大但计算量小"?负载均衡损失的作用是什么?¶
答:MoE(Mixture of Experts)的核心思想是将 FFN 层替换为多个并行的专家网络,每个 token 只激活其中少数(通常 Top-2)专家。例如 DeepSeek-V3 有 256 个专家但每个 token 只用 8 个,因此总参数量虽大(671B),但每个 token 的激活参数仅 37B。
负载均衡损失(Load Balancing Loss) 解决的是"路由崩塌"问题:如果不加约束,路由器倾向于将大部分 token 分配给少数"热门"专家,导致: - 热门专家过载,成为训练瓶颈 - 冷门专家得不到训练,参数浪费 - GPU 间负载不均衡(跨设备 MoE 时尤其严重)
DeepSeek-V3 使用辅助无损负载均衡策略(Auxiliary-Loss-Free Load Balancing),通过偏置项动态调整专家选择,在不牺牲性能的前提下实现均衡。
Q2: Mamba/SSM 与 Transformer 的本质区别是什么?SSM 能完全替代 Transformer 吗?¶
答:本质区别在于状态表示: - Transformer:维护完整的 KV Cache(显式记忆所有历史),复杂度 O(n²) - SSM(Mamba):将历史信息压缩到固定大小的隐状态中(类似 RNN),复杂度 O(n)
SSM 的优势是长序列处理效率高、推理内存恒定;劣势是在需要精确回溯远距离 token 的任务上(如复制、精确检索)表现不如 Transformer。
不能完全替代。当前趋势是混合架构(如 Jamba = Mamba + Attention + MoE): - 底层用 Mamba 处理长上下文(效率优先) - 关键层保留 Attention(精度优先) - MoE 控制参数-计算比
Q3: FlashAttention 为什么能加速?它改变了数学计算吗?¶
答:FlashAttention 没有改变数学计算,输出与标准注意力完全一致。它加速的原理是减少 GPU HBM(高带宽内存)访问次数:
- 标准注意力:S = QK^T 写入 HBM → 读取 HBM 加掩码 → Softmax 写入 HBM → 乘 V 写入 HBM。中间矩阵 S(n×n) 占用大量内存带宽。
- FlashAttention:利用 GPU SRAM(片上内存,快但小),将 Q、K、V 分块(tiling),在 SRAM 中完成 QK^T → Softmax → ×V 的完整计算,只将最终结果写回 HBM。
核心技巧是在线 Softmax(Online Softmax):分块计算时动态维护全局最大值和累加和,避免二次读取。
FlashAttention-2 进一步优化了并行策略(沿序列长度维度并行),FlashAttention-3 利用 H100 的异步性和 TF32。
Q4: 线性注意力(如 RWKV)为什么没有成为主流?它的根本局限是什么?¶
答:线性注意力通过将 Softmax 分解为核函数近似(φ(Q)·φ(K)^T · V),将复杂度从 O(n²) 降为 O(n)。但根本局限在于:
- 信息压缩损失:将 n×d 的注意力矩阵压缩为 d×d 的累积和矩阵,等价于用一个固定大小的"摘要"代替完整的注意力分布,无法精确回溯特定位置
- 表达能力下降:核函数近似是 Softmax 的低秩近似,理论上无法完全恢复 Softmax 的建模能力
- 实践效果:在短-中序列任务上与 Transformer 差距不大,但在需要精确位置关联的任务(如代码生成、数学推理)上明显落后
RWKV 的创新在于结合了 RNN 的推理效率(O(1) 内存)和 Transformer 的训练并行性,适合资源受限的部署场景。
下一步¶
完成本章节学习后,建议继续阅读: - 05-大模型安全与对齐 - 理解 RLHF 、安全训练 - 06-大模型应用与产品化 - 工程实践与部署
实践项目: - 尝试用 Mamba 替换 Transformer 中的部分层,对比长序列性能 - 实现简化的 MoE 层,观察负载均衡损失的影响
📝 本章练习¶
🤔 思考题¶
- MoE 负载均衡:MoE 模型中为什么会出现"路由坍塌"(所有 Token 都被分配到少数 Expert)?负载均衡损失的原理是什么?
- Mamba vs Transformer:Mamba 的选择性状态空间机制(Selective SSM)相比传统 SSM 做了什么改进?为什么能在保持线性复杂度的同时获得更好的性能?
- FlashAttention:FlashAttention 是如何减少 GPU HBM 访问次数的?为什么减少内存访问比减少计算量更重要?
- 架构趋势:你认为未来 LLM 架构会朝什么方向发展?纯 Transformer、MoE+Transformer、还是 SSM+Transformer 混合?
💻 代码实践¶
- 入门:实现一个简化的 Top-K 路由 MoE 层,观察不同 K 值对负载分布的影响
- 进阶:对比 Mamba 和 Transformer 在长序列(8K+ tokens)上的推理速度和内存占用
- 高级:实现 Grouped Query Attention (GQA),对比与 MHA、MQA 的推理效率差异
💡 参考答案
#### 思考题参考答案 **1. MoE 负载均衡** 路由坍塌原因: - 初始化时某些 Expert 恰好对训练早期样本表现好 - 梯度更新强化了这种偏好,形成正反馈循环 - 最终所有 Token 都路由到少数 Expert,其余 Expert 得不到训练 负载均衡损失:`L_aux = α · N · Σ(f_i · P_i)`,其中 f_i 是 Expert i 被选中的频率,P_i 是平均路由概率。最小化这个损失鼓励均匀分布。 **2. Mamba 选择性 SSM** 传统 SSM(如 S4)的参数是固定的,不依赖输入。Selective SSM 让参数(B、C、Δ)成为输入的函数: - 输入相关的步长 Δ:对重要 Token 分配更多计算 - 输入相关的 B、C:选择性传播或遗忘信息 这解决了传统 SSM 无法"选择性记忆"的问题,在需要精确位置关联的任务上性能显著提升。 **3. FlashAttention** 核心优化:利用 GPU SRAM(片上内存,~20MB)的高带宽,减少对 HBM(高带宽内存,~80GB)的访问。 具体策略: - **分块计算**:将 Q、K、V 分成小块,在 SRAM 中完成注意力计算 - **在线 Softmax**:逐块更新 Softmax 结果,避免先完整计算再写回 HBM - **重计算**:反向传播时重新计算注意力矩阵而非存储,用计算换内存 为什么内存访问更重要:GPU 计算能力远超内存带宽(算力/带宽比持续增长),内存访问成为瓶颈。 **4. 架构趋势** 短期(2025-2026):MoE + Transformer 为主流(DeepSeek-V3、Mixtral),GQA 成为标配。 中期:SSM + Attention 混合架构(如 Jamba)可能在中长序列场景替代纯 Transformer。 长期:可能涌现全新的架构范式,但 Transformer 的生态优势会持续很长时间。 #### 代码实践参考答案 **实践 1:简化 MoE 层**import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleMoELayer(nn.Module):
def __init__(self, d_model, d_ff, num_experts=8, top_k=2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.experts = nn.ModuleList([
nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
for _ in range(num_experts)
])
self.gate = nn.Linear(d_model, num_experts)
def forward(self, x):
# x: [batch, seq_len, d_model]
logits = self.gate(x) # [batch, seq, num_experts]
top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
top_k_weights = F.softmax(top_k_logits, dim=-1)
# 负载统计
expert_counts = torch.zeros(self.num_experts)
for idx in top_k_indices.flatten().tolist():
expert_counts[idx] += 1
load_balance = (expert_counts / expert_counts.sum()).std().item()
output = torch.zeros_like(x)
for i in range(self.top_k):
for e in range(self.num_experts):
mask = (top_k_indices[:, :, i] == e)
if mask.any():
expert_input = x[mask]
expert_output = self.experts[e](expert_input)
output[mask] += top_k_weights[:, :, i][mask].unsqueeze(-1) * expert_output
return output, load_balance
# 测试
moe = SimpleMoELayer(d_model=256, d_ff=1024, num_experts=8, top_k=2)
x = torch.randn(4, 32, 256)
out, balance = moe(x)
print(f"负载均衡标准差: {balance:.4f} (越小越均衡)")
最后更新日期:2026-04-21 适用版本:LLM 学习教程 v2026.04
审查记录: - 2026-04-20: 修复标题编号为 "# 04.",新增面试常见问题章节(Q1-Q4: MoE 负载均衡、Mamba vs Transformer、FlashAttention 原理、线性注意力局限),更新日期 - 2026-03-26: 初始版本