跳转至

04. 模型架构创新

目录

  1. Mixture of Experts (MoE)
  2. 状态空间模型 (Mamba)
  3. 线性注意力与 RNN 化 Transformer
  4. 其他架构创新
  5. 架构选择指南

Mixture of Experts (MoE)

1.1 MoE 核心思想

混合专家模型( Mixture of Experts, MoE ) 通过条件计算( Conditional Computation )实现模型容量与计算成本的解耦。核心思想是:对于每个输入,只激活部分参数(专家),而非整个网络。

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    MoE 架构示意图                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Input                                                           │
│    │                                                             │
│    ▼                                                             │
│  ┌─────────────────┐                                             │
│  │   Router /      │                                             │
│  │   Gate Network  │──▶ 生成路由权重 [0.1, 0.7, 0.05, 0.15]      │
│  │   (可学习)       │                                             │
│  └─────────────────┘                                             │
│           │                                                      │
│           ▼                                                      │
│    ┌──────┴──────┬────────┬────────┐                            │
│    ▼             ▼        ▼        ▼                            │
│  ┌─────┐    ┌─────┐   ┌─────┐  ┌─────┐                         │
│  │ E0  │    │ E1  │   │ E2  │  │ E3  │    Expert Networks       │
│  │0.1  │    │0.7  │   │0.05 │  │0.15 │    (FFN层作为专家)        │
│  └──┬──┘    └──┬──┘   └──┬──┘  └──┬──┘                         │
│     │          │         │        │                             │
│     └──────────┴────┬────┴────────┘                             │
│                     ▼                                            │
│              ┌──────────┐                                        │
│              │ 加权求和  │──▶ 0.1×E0 + 0.7×E1 + 0.05×E2 + 0.15×E3 │
│              └──────────┘                                        │
│                     │                                            │
│                     ▼                                            │
│                   Output                                         │
│                                                                  │
│  关键:只计算Top-K专家(如Top-2),其余专家输出设为0              │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

1.2 MoE 数学原理

Text Only
MoE层的前向传播
═══════════════════════════════════════════════════════════════════

给定输入 x,MoE层的输出为:

                N
MoE(x) = Σ  G(x)_i · E_i(x)
               i=1

其中:
- N = 专家数量(通常 8, 16, 64, 128)
- E_i = 第i个专家网络(通常是FFN)
- G(x)_i = 门控网络对第i个专家的权重

门控函数(Softmax Gate):

              exp(W_g · x)_i
G(x)_i = ─────────────────────
          Σ_j exp(W_g · x)_j

Top-K门控(实际使用):

G(x)_i = { Softmax(TopK(W_g · x, k))_i  if i ∈ TopK
         { 0                           otherwise

负载均衡损失(Auxiliary Loss):

L_aux = α · N · Σ_i f_i · P_i

其中:
- f_i = 批次中分配给专家i的token比例
- P_i = 路由器对专家i的平均预测概率
- α = 超参数(通常 0.01)

═══════════════════════════════════════════════════════════════════

1.3 MoE 实现代码

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class MoELayer(nn.Module):
    """
    Mixture of Experts 层实现
    """
    def __init__(
        self,
        d_model: int,
        num_experts: int = 8,
        top_k: int = 2,
        expert_hidden_dim: int = None,
        aux_loss_coef: float = 0.01
    ):
        super().__init__()  # super()调用父类方法
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_coef = aux_loss_coef

        if expert_hidden_dim is None:
            expert_hidden_dim = 4 * d_model

        # 门控网络(路由器)
        self.gate = nn.Linear(d_model, num_experts, bias=False)

        # 专家网络(每个专家是一个FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, expert_hidden_dim),
                nn.GELU(),
                nn.Linear(expert_hidden_dim, d_model)
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            output: [batch_size, seq_len, d_model]
            aux_loss: 负载均衡辅助损失
        """
        batch_size, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)  # [batch_size * seq_len, d_model]  # view重塑张量形状

        # 1. 计算门控分数
        gate_logits = self.gate(x_flat)  # [num_tokens, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)

        # 2. Top-K路由
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)  # 重新归一化

        # 3. 计算专家输出并加权
        output = torch.zeros_like(x_flat)

        for i in range(self.top_k):
            expert_indices = top_k_indices[:, i]
            expert_probs = top_k_probs[:, i:i+1]

            # 对每个专家处理分配给它的token
            for expert_id in range(self.num_experts):
                mask = expert_indices == expert_id
                if mask.any():  # any()任一为True则返回True
                    expert_input = x_flat[mask]
                    expert_output = self.experts[expert_id](expert_input)
                    output[mask] += expert_probs[mask] * expert_output

        # 4. 计算负载均衡辅助损失
        aux_loss = self._compute_aux_loss(gate_probs, top_k_indices)

        output = output.view(batch_size, seq_len, d_model)
        return output, aux_loss

    def _compute_aux_loss(
        self,
        gate_probs: torch.Tensor,
        top_k_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        计算负载均衡辅助损失(来自Switch Transformer论文)
        """
        num_tokens = gate_probs.size(0)

        # f_i: 每个专家被选择的token比例
        expert_mask = F.one_hot(top_k_indices, self.num_experts).sum(dim=1)
        f = expert_mask.float().sum(dim=0) / num_tokens

        # P_i: 路由器分配给每个专家的平均概率
        P = gate_probs.mean(dim=0)

        # 辅助损失:鼓励均匀分布
        aux_loss = self.num_experts * (f * P).sum()

        return self.aux_loss_coef * aux_loss

class MoETransformerBlock(nn.Module):
    """
    使用MoE的Transformer块(替代标准FFN)
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.moe = MoELayer(d_model, num_experts, top_k)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # Self-Attention
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))

        # MoE FFN
        moe_out, aux_loss = self.moe(x)
        x = self.norm2(x + self.dropout(moe_out))

        return x, aux_loss

1.4 MoE 变体与优化

Text Only
MoE架构演进
═══════════════════════════════════════════════════════════════════

> ⚠️ **时效性说明(2026-02 复核)**:下述“性能对比”结论受评测集、训练数据与评测口径影响,请结合原论文与最新公开评测解读。

Switch Transformer (Google, 2021)
├── Top-1 路由(每个token只选一个专家)
├── 简化设计,降低通信开销
└── 1.6T参数,与T5-XXL相当性能但快4倍

GLaM (Google, 2021)
├── 64专家 per MoE层
├── 1.2T总参数,96B激活参数
└── 在部分公开任务上优于同期GPT-3基线

ST-MoE (Google, 2022)
├── 引入Expert Choice Routing
├── 从token选专家 → 专家选token
└── 更好的负载均衡

Mixtral 8x7B (Mistral, 2023)
├── 8专家,Top-2路由
├── 46.7B总参数,12.9B激活参数
├── 开源,在多项公开评测中表现接近或优于Llama 2 70B
└── 稀疏MoE首次在开源社区成功

Qwen1.5-MoE (阿里, 2024)
├── 共享专家 + 路由专家分离
├── 更好的专家特化
└── 中文MoE代表模型之一

═══════════════════════════════════════════════════════════════════

专家选择路由( Expert Choice Routing )

Python
class ExpertChoiceMoE(nn.Module):
    """
    Expert Choice Routing: 专家选择token而非token选择专家
    优势:完美负载均衡,每个专家处理固定数量token
    """
    def __init__(self, d_model: int, num_experts: int, expert_capacity: int):
        super().__init__()
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity  # 每个专家处理的token数

        self.gate = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList([
            self._create_expert(d_model) for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len
        x_flat = x.view(num_tokens, d_model)

        # 计算每个token对每个专家的偏好分数
        gate_scores = self.gate(x_flat)  # [num_tokens, num_experts]

        # 每个专家选择Top-C个token
        selected_tokens = []
        for expert_id in range(self.num_experts):
            expert_scores = gate_scores[:, expert_id]
            top_k_values, top_k_indices = torch.topk(
                expert_scores,
                min(self.expert_capacity, num_tokens)
            )
            selected_tokens.append((expert_id, top_k_indices, top_k_values))

        # 聚合输出(一个token可能被多个专家处理)
        output = torch.zeros_like(x_flat)
        for expert_id, indices, scores in selected_tokens:
            expert_input = x_flat[indices]
            expert_output = self.experts[expert_id](expert_input)
            output[indices] += scores.unsqueeze(1) * expert_output  # unsqueeze增加一个维度

        return output.view(batch_size, seq_len, d_model)

1.5 MoE 训练挑战与解决方案

挑战 原因 解决方案
负载不均衡 某些专家被过度使用 辅助损失、专家选择路由、容量因子
通信开销 专家分布在不同设备 专家并行、 All-to-All 通信优化
训练不稳定 路由决策离散 负载均衡损失、专家 dropout 、温度退火
显存占用 大量专家参数 专家并行、 CPU offloading 、激活检查点

状态空间模型 (Mamba)

2.1 从 RNN 到 SSM

状态空间模型( State Space Model, SSM ) 是一类将连续时间系统离散化的序列建模方法。 Mamba 通过选择性机制( Selective Mechanism )解决了传统 SSM 的内容依赖问题。

Text Only
序列模型演进
═══════════════════════════════════════════════════════════════════

RNN
├── 特点:隐状态压缩历史信息
├── 问题:长程依赖困难、训练并行度低
└── 代表:LSTM, GRU

Transformer
├── 特点:全局注意力、训练并行
├── 问题:O(n²)复杂度、长序列内存爆炸
└── 代表:GPT, BERT, LLaMA

SSM (S4, 2021)
├── 特点:线性复杂度、连续时间建模
├── 问题:内容无关(参数固定)
└── 代表:S4, DSS, GSS

Mamba (2023)
├── 特点:选择性状态空间、输入依赖
├── 优势:线性复杂度 + 内容感知
└── 代表:Mamba, Jamba, Falcon Mamba

═══════════════════════════════════════════════════════════════════

2.2 状态空间模型基础

Text Only
连续时间状态空间方程
═══════════════════════════════════════════════════════════════════

连续形式:

h'(t) = A · h(t) + B · x(t)   (状态演化)
y(t)  = C · h(t)             (输出)

其中:
- h(t) ∈ R^N: 隐状态(N维)
- x(t) ∈ R: 输入(标量)
- y(t) ∈ R: 输出(标量)
- A ∈ R^(N×N): 状态转移矩阵
- B ∈ R^N: 输入影响矩阵
- C ∈ R^N: 输出投影矩阵

离散化(零阶保持,ZOH):

h_k = Ā · h_{k-1} + B̄ · x_k
y_k = C · h_k

其中离散化参数:
Ā = exp(Δ · A)          (矩阵指数)
B̄ = (Ā - I) · A^(-1) · B

═══════════════════════════════════════════════════════════════════

2.3 Mamba 的选择性机制

Mamba 的核心创新是使 B 、 C 、Δ成为输入的函数,实现选择性记忆/遗忘。

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class SelectiveSSM(nn.Module):
    """
    选择性状态空间模型(Mamba核心)
    """
    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)

        # 输入投影(x → x和残差)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # 因果卷积(局部建模)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,  # 深度可分离
            bias=True
        )

        # 选择性参数生成(关键创新)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        # 输出: [B, Δ, C] 其中Δ是时间步长参数

        # 状态转移矩阵A(可学习或固定)
        # 通常使用HiPPO初始化或学习
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1)).repeat(self.d_inner, 1))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # 跳跃连接

        # 输出投影
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch, seq_len, dim = x.shape

        # 输入投影
        x_and_res = self.in_proj(x)  # [batch, seq_len, 2*d_inner]
        x_ssm, res = x_and_res.split([self.d_inner, self.d_inner], dim=-1)

        # 因果卷积
        x_conv = rearrange(x_ssm, 'b l d -> b d l')
        x_conv = self.conv1d(x_conv)[..., :seq_len]  # 因果截断
        x_conv = rearrange(x_conv, 'b d l -> b l d')
        x_conv = F.silu(x_conv)

        # 选择性SSM
        y = self.selective_scan(x_conv)

        # 门控融合
        y = y * F.silu(res)

        # 输出投影
        output = self.out_proj(y)
        return output

    def selective_scan(self, x):
        """
        选择性扫描:核心计算
        使用并行关联扫描(Parallel Associative Scan)实现高效计算
        """
        batch, seq_len, d_in = x.shape

        # 生成选择性参数(输入依赖)
        params = self.x_proj(x)  # [batch, seq_len, d_state*2 + 1]
        B, C, delta = params.split([self.d_state, self.d_state, 1], dim=-1)

        # softplus确保Δ为正
        delta = F.softplus(delta.squeeze(-1))  # [batch, seq_len, d_in]

        # 离散化
        A = -torch.exp(self.A_log.float())  # [d_in, d_state]

        # 使用并行扫描高效计算
        # 这里使用简化的顺序实现,实际使用CUDA优化的并行版本
        y = self.parallel_scan(x, delta, A, B, C)

        return y

    def parallel_scan(self, x, delta, A, B, C):
        """
        并行关联扫描实现
        将递归计算转化为可并行的关联操作
        """
        batch, seq_len, d_in = x.shape

        # 离散化参数
        delta_A = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
        delta_B = torch.einsum('bld,bln->bldn', delta, B)

        # 输入加权
        BX = delta_B * x.unsqueeze(-1)  # [batch, seq_len, d_in, d_state]

        # 关联扫描: h_k = delta_A_k * h_{k-1} + BX_k
        # 使用Blelloch扫描算法并行化
        h = self.associative_scan(delta_A, BX)

        # 输出投影
        y = torch.einsum('bldn,bln->bld', h, C)

        # 跳跃连接
        y = y + self.D.unsqueeze(0).unsqueeze(0) * x

        return y

    def associative_scan(self, A, B):
        """
        关联扫描:将线性递归转化为并行计算
        核心思想:h_k = A_k * h_{k-1} + B_k 是关联操作
        """
        # 实际实现使用CUDA内核
        # 这里展示概念性实现
        batch, seq_len, d_in, d_state = B.shape

        h = torch.zeros(batch, d_in, d_state, device=B.device, dtype=B.dtype)
        hs = []

        for k in range(seq_len):
            h = A[:, k] * h + B[:, k]
            hs.append(h)

        return torch.stack(hs, dim=1)

2.4 Mamba 架构细节

Text Only
Mamba Block 详细结构
═══════════════════════════════════════════════════════════════════

输入 x ∈ R^(L×D)
    ├──▶ Linear ──▶ [x_ssm, res]  分割为SSM输入和残差
    ├──▶ x_ssm ──▶ Conv1d (因果) ──▶ SiLU ──▶ Selective SSM ──┐
    │                                                          │
    ├──▶ res ─────────────────────────────────────────────▶ SiLU │
    │                                                          │
    └──▶ 逐元素相乘 ◀───────────────────────────────────────────┘
         Linear ──▶ 输出

关键设计选择:
1. 因果卷积:引入局部依赖,弥补SSM缺乏局部性的不足
2. SiLU激活:平滑非线性,与门控机制配合
3. 选择性参数:B, C, Δ都依赖于输入x
4. 并行扫描:O(n)复杂度的高效实现

═══════════════════════════════════════════════════════════════════

2.5 Mamba 变体

Text Only
Mamba 家族
═══════════════════════════════════════════════════════════════════

Mamba (2023)
├── 纯SSM架构,无注意力
├── 线性复杂度 O(n)
├── 语言建模性能匹敌Transformer
└── 长序列优势显著

Jamba (AI21, 2024)
├── Mamba + Transformer 混合
├── 注意力层与SSM层交替
├── 结合两者优势
└── 256K上下文支持

Falcon Mamba (TII, 2024)
├── 7B参数纯Mamba模型
├── 开源可商用
└── 在长文本任务上表现突出

Vision Mamba (2024)
├── 将Mamba应用于视觉
├── 双向SSM处理图像
└── 高效视觉backbone

Mamba-2 (2024)
├── 理论统一:SSM与结构化矩阵
├── 更高效的实现
└── 8倍训练速度提升

═══════════════════════════════════════════════════════════════════

线性注意力与 RNN 化 Transformer

3.1 线性注意力机制

标准注意力的复杂度为 O(n²),线性注意力通过核技巧将其降至 O(n)。

Text Only
线性注意力原理
═══════════════════════════════════════════════════════════════════

标准Softmax Attention:

       softmax(QK^T)V
Attn = ───────────────
          Z

其中 Z = row_sum(softmax(QK^T))

复杂度:O(n²·d) 因为需要计算n×n的注意力矩阵

线性注意力(核技巧):

使用特征映射 φ,使得:
softmax(QK^T/√d) ≈ φ(Q)φ(K)^T

则:
       φ(Q)φ(K)^T V     φ(Q)(φ(K)^T V)
Attn = ───────────── = ─────────────────
            Z                Z

关键观察:φ(K)^T V 可以先计算(与序列长度无关)

复杂度:O(n·d²) 线性于序列长度!

═══════════════════════════════════════════════════════════════════
Python
class LinearAttention(nn.Module):
    """
    线性注意力实现(基于Katharopoulos et al.)
    """
    def __init__(self, dim, num_heads, feature_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.feature_dim = feature_dim or self.head_dim

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def elu_feature_map(self, x):
        """特征映射:ELU + 1"""
        return F.elu(x) + 1

    def forward(self, x):
        B, N, C = x.shape

        # 生成QKV
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # 特征映射
        q = self.elu_feature_map(q)
        k = self.elu_feature_map(k)

        # 线性注意力计算
        # KV = Σ_t k_t^T v_t
        KV = torch.einsum('bhsk,bhsv->bhkv', k, v)  # [B, heads, head_dim, head_dim]

        # Z = Σ_t k_t
        Z = k.sum(dim=2)  # [B, heads, head_dim]

        # 输出 = Q @ KV / (Q @ Z)
        numerator = torch.einsum('bhnk,bhkv->bhnv', q, KV)  # [B, heads, N, head_dim]
        denominator = torch.einsum('bhnk,bhk->bhn', q, Z).unsqueeze(-1) + 1e-6

        out = numerator / denominator
        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)

        return out

3.2 RNN 化 Transformer

将 Transformer 解码为 RNN 形式,实现 O(1)推理内存。

Python
class RNNTransformer(nn.Module):
    """
    Transformer的RNN形式(基于RWKV和RetNet思想)
    """
    def __init__(self, dim, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.decay = nn.Parameter(torch.randn(num_heads))  # 每头一个标量衰减率

    def forward_training(self, x):
        """训练时使用并行计算"""
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # 因果掩码 + 衰减
        positions = torch.arange(N, device=x.device)
        rel_pos = positions.unsqueeze(0) - positions.unsqueeze(1)  # [N, N]

        # decay_matrix[h, n, m] = exp((n-m) * decay_h),对下三角有效
        decay_matrix = torch.exp(
            rel_pos.unsqueeze(0) * self.decay.view(-1, 1, 1)  # [H, N, N]
        )
        decay_matrix = torch.tril(decay_matrix)  # 因果掩码

        # 注意力:Retention(X) = (Q K^T ⊙ D) V
        attn = torch.einsum('bnhd,bmhd->bhnm', q, k)  # [B, H, N, M]
        attn = attn * decay_matrix.unsqueeze(0)  # [B, H, N, M]
        attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-6)

        out = torch.einsum('bhnm,bmhd->bnhd', attn, v)  # [B, N, H, D]
        return out.reshape(B, N, C)

    def forward_inference(self, x, state):
        """
        推理时使用RNN形式,O(1)内存
        state: (KV_cache, K_sum_cache)
        """
        B, C = x.shape

        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        KV_cache, K_sum_cache = state

        # 更新状态(RNN形式)
        decay = torch.exp(self.decay)  # [num_heads]

        # KV_cache = decay * KV_cache + k^T v
        new_KV = KV_cache * decay.view(1, -1, 1, 1)  # [B, H, D, D]
        new_KV = new_KV + torch.einsum('bhd,bhe->bhde', k, v)

        # K_sum_cache = decay * K_sum_cache + k
        new_K_sum = K_sum_cache * decay.view(1, -1, 1)  # [B, H, D]
        new_K_sum = new_K_sum + k

        # 计算输出
        numerator = torch.einsum('bhd,bhde->bhe', q, new_KV)
        denominator = torch.einsum('bhd,bhd->bh', q, new_K_sum).unsqueeze(-1) + 1e-6

        out = numerator / denominator
        return out.reshape(B, C), (new_KV, new_K_sum)

3.3 代表性模型

Text Only
线性/RNN化Transformer模型
═══════════════════════════════════════════════════════════════════

RWKV (2023)
├── 架构:结合RNN和Transformer优点
├── 训练:并行化(类似Transformer)
├── 推理:O(1)内存(类似RNN)
├── 特点:时间衰减机制
└── 开源:完全开源可商用

RetNet (Microsoft, 2023)
├── 架构:保留Transformer结构
├── 核心:Retention机制(对偶形式)
├── 训练:并行
├── 推理:循环形式
└── 声称:性能匹敌Transformer

TransNormerLLM (2023)
├── 分块注意力 + 线性注意力
├── 局部精确 + 全局线性
└── 长序列高效

Linear Transformer (2020)
├── 最早的线性注意力工作
├── 使用核技巧近似softmax
└── 理论基础

═══════════════════════════════════════════════════════════════════

其他架构创新

4.1 混合专家系统(非 MoE 形式)

Python
class MultiScaleTransformer(nn.Module):
    """
    多尺度Transformer:不同层使用不同注意力模式
    """
    def __init__(self, dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            if i % 3 == 0:
                # 局部窗口注意力(捕捉局部模式)
                self.layers.append(LocalAttention(dim, window_size=512))
            elif i % 3 == 1:
                # 稀疏全局注意力(长程依赖)
                self.layers.append(SparseGlobalAttention(dim, stride=16))
            else:
                # 标准注意力(精确建模)
                self.layers.append(StandardAttention(dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

4.2 分层注意力

Python
class HierarchicalAttention(nn.Module):
    """
    分层注意力:先粗粒度后细粒度
    """
    def __init__(self, dim, num_tokens_per_group=64):
        super().__init__()
        self.num_tokens_per_group = num_tokens_per_group

        # 粗粒度(组级别)
        self.coarse_attention = nn.MultiheadAttention(dim, num_heads=8)

        # 细粒度(组内)
        self.fine_attention = nn.MultiheadAttention(dim, num_heads=8)

    def forward(self, x):
        B, N, C = x.shape
        G = N // self.num_tokens_per_group  # 组数

        # 分组
        x_grouped = x.reshape(B, G, self.num_tokens_per_group, C)

        # 组内池化得到组表示
        group_repr = x_grouped.mean(dim=2)  # [B, G, C]

        # 粗粒度:组间注意力
        coarse_out, _ = self.coarse_attention(group_repr, group_repr, group_repr)

        # 将组表示广播回各token
        coarse_expanded = coarse_out.unsqueeze(2).expand(-1, -1, self.num_tokens_per_group, -1)

        # 细粒度:组内注意力
        x_combined = x_grouped + coarse_expanded  # 融合
        x_fine = x_combined.reshape(B * G, self.num_tokens_per_group, C)
        fine_out, _ = self.fine_attention(x_fine, x_fine, x_fine)

        return fine_out.reshape(B, N, C)

4.3 内存高效架构

Python
class MemoryEfficientAttention(nn.Module):
    """
    FlashAttention风格的内存高效实现
    通过分块计算避免存储完整的注意力矩阵
    """
    def __init__(self, dim, num_heads, block_size=256):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.block_size = block_size

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        分块计算注意力,O(n)内存(无需存储完整N×N注意力矩阵)
        """
        B, N, C = x.shape
        H = self.num_heads
        D = self.head_dim
        qkv = self.qkv(x).reshape(B, N, 3, H, D)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # 分块
        num_blocks = (N + self.block_size - 1) // self.block_size

        outputs = []
        for i in range(num_blocks):
            start_i = i * self.block_size
            end_i = min(start_i + self.block_size, N)
            block_len = end_i - start_i

            q_block = q[:, start_i:end_i]  # [B, block_len, H, D]

            # 在线计算softmax
            max_score = torch.full((B, H, block_len, 1), float('-inf'), device=x.device)
            numerator = torch.zeros(B, H, block_len, D, device=x.device)
            denominator = torch.zeros(B, H, block_len, 1, device=x.device)

            for j in range(num_blocks):
                start_j = j * self.block_size
                end_j = min(start_j + self.block_size, N)

                k_block = k[:, start_j:end_j]  # [B, block_j, H, D]
                v_block = v[:, start_j:end_j]  # [B, block_j, H, D]

                # 计算块间注意力,保留多头维度
                scores = torch.einsum('bnhd,bmhd->bhnm', q_block, k_block)  # [B, H, block_i, block_j]
                scores = scores / math.sqrt(D)

                # 在线softmax更新
                max_score_new = torch.maximum(max_score, scores.max(dim=-1, keepdim=True)[0])
                exp_scores = torch.exp(scores - max_score_new)

                numerator = numerator * torch.exp(max_score - max_score_new) + \
                           torch.einsum('bhnm,bmhd->bhnd', exp_scores, v_block)
                denominator = denominator * torch.exp(max_score - max_score_new) + \
                             exp_scores.sum(dim=-1, keepdim=True)

                max_score = max_score_new

            out_block = numerator / denominator  # [B, H, block_len, D]
            out_block = out_block.permute(0, 2, 1, 3)  # [B, block_len, H, D]
            outputs.append(out_block)

        out = torch.cat(outputs, dim=1)
        out = out.reshape(B, N, C)
        return self.proj(out)

架构选择指南

5.1 不同场景推荐

Text Only
架构选择决策树
═══════════════════════════════════════════════════════════════════

你的主要约束是什么?
    ├── 长序列(>32K tokens)
    │       │
    │       ├── 需要强长程依赖捕捉
    │       │       └──▶ Mamba / Jamba / Linear Attention
    │       │
    │       └── 主要关注局部模式
    │               └──▶ 滑动窗口注意力 + 全局token
    ├── 大规模部署(成本敏感)
    │       │
    │       ├── 高吞吐推理
    │       │       └──▶ MoE(激活参数少)/ RWKV(O(1)内存)
    │       │
    │       └── 训练成本控制
    │               └──▶ 线性注意力 / 分层注意力
    ├── 通用性能优先
    │       │
    │       ├── 有足够计算资源
    │       │       └──▶ 标准Transformer + FlashAttention
    │       │
    │       └── 资源受限
    │               └──▶ 混合架构(MoE + 线性注意力)
    └── 特定领域
            ├── 代码生成
            │       └──▶ 标准Transformer(结构化模式重要)
            ├── 长文档理解
            │       └──▶ Mamba / Longformer
            └── 多模态
                    └──▶ 标准Transformer(跨模态对齐成熟)

═══════════════════════════════════════════════════════════════════

5.2 架构对比表

架构 训练复杂度 推理复杂度 长序列支持 实现难度 生态成熟度
标准 Transformer O(n²) O(n²) 极高
FlashAttention O(n²) O(n²)
MoE O(n²)⁺ O(n²)⁺
Mamba O(n) O(n) 极好 增长中
线性注意力 O(n) O(n)
RWKV O(n) O(1)⁺⁺

⁺ MoE 注意力复杂度仍为 O(n²),但每个 token 的 FFN 计算量显著降低(仅激活少数专家) ⁺⁺ RWKV 推理复杂度 O(1)指内存不随序列增长

5.3 实践建议

  1. 从标准 Transformer 开始
  2. 除非有明确的长序列需求,否则优先使用成熟的 Transformer
  3. FlashAttention 已解决大部分内存问题

  4. MoE 的适用场景

  5. 模型参数>10B 时考虑
  6. 推理吞吐量优先于延迟的场景
  7. 有分布式训练基础设施

  8. Mamba/SSM 的适用场景

  9. 序列长度>32K
  10. 需要流式处理(如语音、视频)
  11. 内存严格受限的边端设备

  12. 混合策略

  13. 底层: Mamba/线性注意力处理长上下文
  14. 上层:标准 Transformer 进行精确建模
  15. 参考: Jamba 架构

面试常见问题

Q1: MoE 架构为什么能实现"参数量大但计算量小"?负载均衡损失的作用是什么?

:MoE(Mixture of Experts)的核心思想是将 FFN 层替换为多个并行的专家网络,每个 token 只激活其中少数(通常 Top-2)专家。例如 DeepSeek-V3 有 256 个专家但每个 token 只用 8 个,因此总参数量虽大(671B),但每个 token 的激活参数仅 37B。

负载均衡损失(Load Balancing Loss) 解决的是"路由崩塌"问题:如果不加约束,路由器倾向于将大部分 token 分配给少数"热门"专家,导致: - 热门专家过载,成为训练瓶颈 - 冷门专家得不到训练,参数浪费 - GPU 间负载不均衡(跨设备 MoE 时尤其严重)

DeepSeek-V3 使用辅助无损负载均衡策略(Auxiliary-Loss-Free Load Balancing),通过偏置项动态调整专家选择,在不牺牲性能的前提下实现均衡。

Q2: Mamba/SSM 与 Transformer 的本质区别是什么?SSM 能完全替代 Transformer 吗?

:本质区别在于状态表示: - Transformer:维护完整的 KV Cache(显式记忆所有历史),复杂度 O(n²) - SSM(Mamba):将历史信息压缩到固定大小的隐状态中(类似 RNN),复杂度 O(n)

SSM 的优势是长序列处理效率高、推理内存恒定;劣势是在需要精确回溯远距离 token 的任务上(如复制、精确检索)表现不如 Transformer。

不能完全替代。当前趋势是混合架构(如 Jamba = Mamba + Attention + MoE): - 底层用 Mamba 处理长上下文(效率优先) - 关键层保留 Attention(精度优先) - MoE 控制参数-计算比

Q3: FlashAttention 为什么能加速?它改变了数学计算吗?

:FlashAttention 没有改变数学计算,输出与标准注意力完全一致。它加速的原理是减少 GPU HBM(高带宽内存)访问次数

  1. 标准注意力:S = QK^T 写入 HBM → 读取 HBM 加掩码 → Softmax 写入 HBM → 乘 V 写入 HBM。中间矩阵 S(n×n) 占用大量内存带宽。
  2. FlashAttention:利用 GPU SRAM(片上内存,快但小),将 Q、K、V 分块(tiling),在 SRAM 中完成 QK^T → Softmax → ×V 的完整计算,只将最终结果写回 HBM。

核心技巧是在线 Softmax(Online Softmax):分块计算时动态维护全局最大值和累加和,避免二次读取。

FlashAttention-2 进一步优化了并行策略(沿序列长度维度并行),FlashAttention-3 利用 H100 的异步性和 TF32。

Q4: 线性注意力(如 RWKV)为什么没有成为主流?它的根本局限是什么?

:线性注意力通过将 Softmax 分解为核函数近似(φ(Q)·φ(K)^T · V),将复杂度从 O(n²) 降为 O(n)。但根本局限在于:

  1. 信息压缩损失:将 n×d 的注意力矩阵压缩为 d×d 的累积和矩阵,等价于用一个固定大小的"摘要"代替完整的注意力分布,无法精确回溯特定位置
  2. 表达能力下降:核函数近似是 Softmax 的低秩近似,理论上无法完全恢复 Softmax 的建模能力
  3. 实践效果:在短-中序列任务上与 Transformer 差距不大,但在需要精确位置关联的任务(如代码生成、数学推理)上明显落后

RWKV 的创新在于结合了 RNN 的推理效率(O(1) 内存)和 Transformer 的训练并行性,适合资源受限的部署场景。


下一步

完成本章节学习后,建议继续阅读: - 05-大模型安全与对齐 - 理解 RLHF 、安全训练 - 06-大模型应用与产品化 - 工程实践与部署

实践项目: - 尝试用 Mamba 替换 Transformer 中的部分层,对比长序列性能 - 实现简化的 MoE 层,观察负载均衡损失的影响


📝 本章练习

🤔 思考题

  1. MoE 负载均衡:MoE 模型中为什么会出现"路由坍塌"(所有 Token 都被分配到少数 Expert)?负载均衡损失的原理是什么?
  2. Mamba vs Transformer:Mamba 的选择性状态空间机制(Selective SSM)相比传统 SSM 做了什么改进?为什么能在保持线性复杂度的同时获得更好的性能?
  3. FlashAttention:FlashAttention 是如何减少 GPU HBM 访问次数的?为什么减少内存访问比减少计算量更重要?
  4. 架构趋势:你认为未来 LLM 架构会朝什么方向发展?纯 Transformer、MoE+Transformer、还是 SSM+Transformer 混合?

💻 代码实践

  1. 入门:实现一个简化的 Top-K 路由 MoE 层,观察不同 K 值对负载分布的影响
  2. 进阶:对比 Mamba 和 Transformer 在长序列(8K+ tokens)上的推理速度和内存占用
  3. 高级:实现 Grouped Query Attention (GQA),对比与 MHA、MQA 的推理效率差异
💡 参考答案 #### 思考题参考答案 **1. MoE 负载均衡** 路由坍塌原因: - 初始化时某些 Expert 恰好对训练早期样本表现好 - 梯度更新强化了这种偏好,形成正反馈循环 - 最终所有 Token 都路由到少数 Expert,其余 Expert 得不到训练 负载均衡损失:`L_aux = α · N · Σ(f_i · P_i)`,其中 f_i 是 Expert i 被选中的频率,P_i 是平均路由概率。最小化这个损失鼓励均匀分布。 **2. Mamba 选择性 SSM** 传统 SSM(如 S4)的参数是固定的,不依赖输入。Selective SSM 让参数(B、C、Δ)成为输入的函数: - 输入相关的步长 Δ:对重要 Token 分配更多计算 - 输入相关的 B、C:选择性传播或遗忘信息 这解决了传统 SSM 无法"选择性记忆"的问题,在需要精确位置关联的任务上性能显著提升。 **3. FlashAttention** 核心优化:利用 GPU SRAM(片上内存,~20MB)的高带宽,减少对 HBM(高带宽内存,~80GB)的访问。 具体策略: - **分块计算**:将 Q、K、V 分成小块,在 SRAM 中完成注意力计算 - **在线 Softmax**:逐块更新 Softmax 结果,避免先完整计算再写回 HBM - **重计算**:反向传播时重新计算注意力矩阵而非存储,用计算换内存 为什么内存访问更重要:GPU 计算能力远超内存带宽(算力/带宽比持续增长),内存访问成为瓶颈。 **4. 架构趋势** 短期(2025-2026):MoE + Transformer 为主流(DeepSeek-V3、Mixtral),GQA 成为标配。 中期:SSM + Attention 混合架构(如 Jamba)可能在中长序列场景替代纯 Transformer。 长期:可能涌现全新的架构范式,但 Transformer 的生态优势会持续很长时间。 #### 代码实践参考答案 **实践 1:简化 MoE 层**
Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMoELayer(nn.Module):
    def __init__(self, d_model, d_ff, num_experts=8, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        logits = self.gate(x)  # [batch, seq, num_experts]
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_logits, dim=-1)

        # 负载统计
        expert_counts = torch.zeros(self.num_experts)
        for idx in top_k_indices.flatten().tolist():
            expert_counts[idx] += 1
        load_balance = (expert_counts / expert_counts.sum()).std().item()

        output = torch.zeros_like(x)
        for i in range(self.top_k):
            for e in range(self.num_experts):
                mask = (top_k_indices[:, :, i] == e)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.experts[e](expert_input)
                    output[mask] += top_k_weights[:, :, i][mask].unsqueeze(-1) * expert_output

        return output, load_balance

# 测试
moe = SimpleMoELayer(d_model=256, d_ff=1024, num_experts=8, top_k=2)
x = torch.randn(4, 32, 256)
out, balance = moe(x)
print(f"负载均衡标准差: {balance:.4f} (越小越均衡)")

最后更新日期:2026-04-21 适用版本:LLM 学习教程 v2026.04

审查记录: - 2026-04-20: 修复标题编号为 "# 04.",新增面试常见问题章节(Q1-Q4: MoE 负载均衡、Mamba vs Transformer、FlashAttention 原理、线性注意力局限),更新日期 - 2026-03-26: 初始版本