Model Architecture Innovations

Table of Contents

  1. Mixture of Experts (MoE)
  2. State Space Models (Mamba)
  3. Linear Attention and RNN-ified Transformers
  4. Other Architectural Innovations
  5. Architecture Selection Guide

Mixture of Experts (MoE)

1.1 Core Idea of MoE

Mixture of Experts (MoE) decouples model capacity from compute cost through conditional computation. The core idea: for each input, activate only a subset of the parameters (the experts) rather than the whole network.

Text Only
┌─────────────────────────────────────────────────────────────────┐
│                    MoE architecture overview                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  Input                                                           │
│    │                                                             │
│    ▼                                                             │
│  ┌─────────────────┐                                             │
│  │   Router /      │                                             │
│  │   Gate Network  │──▶ routing weights [0.1, 0.7, 0.05, 0.15]   │
│  │   (learned)     │                                             │
│  └─────────────────┘                                             │
│           │                                                      │
│           ▼                                                      │
│    ┌──────┴──────┬────────┬────────┐                            │
│    ▼             ▼        ▼        ▼                            │
│  ┌─────┐    ┌─────┐   ┌─────┐  ┌─────┐                         │
│  │ E0  │    │ E1  │   │ E2  │  │ E3  │    Expert Networks       │
│  │0.1  │    │0.7  │   │0.05 │  │0.15 │    (FFN layers as experts)│
│  └──┬──┘    └──┬──┘   └──┬──┘  └──┬──┘                         │
│     │          │         │        │                             │
│     └──────────┴────┬────┴────────┘                             │
│                     ▼                                            │
│              ┌──────────────┐                                    │
│              │ Weighted sum │──▶ 0.1×E0 + 0.7×E1 + 0.05×E2 + 0.15×E3 │
│              └──────────────┘                                    │
│                     │                                            │
│                     ▼                                            │
│                   Output                                         │
│                                                                  │
│  Key point: only the Top-K experts (e.g. Top-2) are computed;    │
│  the remaining experts' outputs are treated as 0                 │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

1.2 Mathematical Formulation of MoE

Text Only
Forward pass of an MoE layer
═══════════════════════════════════════════════════════════════════

Given input x, the MoE layer output is:

             N
MoE(x) =     Σ   G(x)_i · E_i(x)
            i=1

where:
- N = number of experts (typically 8, 16, 64, 128)
- E_i = the i-th expert network (usually an FFN)
- G(x)_i = the gate network's weight for expert i

Gating function (Softmax gate):

              exp(W_g · x)_i
G(x)_i = ─────────────────────
          Σ_j exp(W_g · x)_j

Top-K gating (used in practice):

G(x)_i = { Softmax(TopK(W_g · x, k))_i   if i ∈ TopK
         { 0                             otherwise

Load-balancing loss (auxiliary loss):

L_aux = α · N · Σ_i f_i · P_i

where:
- f_i = fraction of tokens in the batch routed to expert i
- P_i = the router's average predicted probability for expert i
- α = hyperparameter (typically 0.01)

═══════════════════════════════════════════════════════════════════
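As a quick numeric check of the Top-K gating above, the following sketch (with hypothetical router logits) zeroes all but the two largest gate probabilities and renormalizes them:

```python
import torch
import torch.nn.functional as F

# Hypothetical router logits for one token over N = 4 experts.
gate_logits = torch.tensor([0.2, 2.0, -1.0, 0.5])
gate_probs = F.softmax(gate_logits, dim=-1)

# Top-2 gating: keep the two largest probabilities, renormalize, zero the rest.
top_probs, top_idx = torch.topk(gate_probs, k=2)
top_probs = top_probs / top_probs.sum()

g = torch.zeros(4)
g[top_idx] = top_probs
print(g)  # non-zero only at the Top-2 experts; kept entries sum to 1
```

Only experts 1 and 3 (the two largest logits) receive non-zero weight; experts 0 and 2 are skipped entirely at compute time.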

1.3 MoE Implementation

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class MoELayer(nn.Module):
    """
    Mixture of Experts layer.
    """
    def __init__(
        self,
        d_model: int,
        num_experts: int = 8,
        top_k: int = 2,
        expert_hidden_dim: int = None,
        aux_loss_coef: float = 0.01
    ):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_coef = aux_loss_coef

        if expert_hidden_dim is None:
            expert_hidden_dim = 4 * d_model

        # Gate network (router)
        self.gate = nn.Linear(d_model, num_experts, bias=False)

        # Expert networks (each expert is an FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, expert_hidden_dim),
                nn.GELU(),
                nn.Linear(expert_hidden_dim, d_model)
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            output: [batch_size, seq_len, d_model]
            aux_loss: load-balancing auxiliary loss
        """
        batch_size, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)  # [batch_size * seq_len, d_model]

        # 1. Compute gate scores
        gate_logits = self.gate(x_flat)  # [num_tokens, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)

        # 2. Top-K routing
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)  # renormalize

        # 3. Compute expert outputs and combine with routing weights
        output = torch.zeros_like(x_flat)

        for i in range(self.top_k):
            expert_indices = top_k_indices[:, i]
            expert_probs = top_k_probs[:, i:i+1]

            # Each expert processes the tokens assigned to it
            for expert_id in range(self.num_experts):
                mask = expert_indices == expert_id
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[expert_id](expert_input)
                    output[mask] += expert_probs[mask] * expert_output

        # 4. Compute the load-balancing auxiliary loss
        aux_loss = self._compute_aux_loss(gate_probs, top_k_indices)

        output = output.view(batch_size, seq_len, d_model)
        return output, aux_loss

    def _compute_aux_loss(
        self,
        gate_probs: torch.Tensor,
        top_k_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Load-balancing auxiliary loss (from the Switch Transformer paper).
        """
        num_tokens = gate_probs.size(0)

        # f_i: fraction of tokens routed to each expert
        expert_mask = F.one_hot(top_k_indices, self.num_experts).sum(dim=1)
        f = expert_mask.float().sum(dim=0) / num_tokens

        # P_i: the router's average probability for each expert
        P = gate_probs.mean(dim=0)

        # Auxiliary loss: encourages a uniform token distribution across experts
        aux_loss = self.num_experts * (f * P).sum()

        return self.aux_loss_coef * aux_loss

class MoETransformerBlock(nn.Module):
    """
    Transformer block with an MoE layer in place of the standard FFN.
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_experts: int = 8,
        top_k: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        # batch_first=True so inputs are [batch, seq_len, d_model]
        self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.moe = MoELayer(d_model, num_experts, top_k)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # Self-attention
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))

        # MoE FFN
        moe_out, aux_loss = self.moe(x)
        x = self.norm2(x + self.dropout(moe_out))

        return x, aux_loss

1.4 MoE Variants and Optimizations

Text Only
Evolution of MoE architectures
═══════════════════════════════════════════════════════════════════

> ⚠️ Timeliness note (reviewed 2026-02): the performance comparisons below depend on evaluation sets, training data, and evaluation protocol; interpret them alongside the original papers and the latest public benchmarks.

Switch Transformer (Google, 2021)
├── Top-1 routing (each token selects a single expert)
├── Simplified design, lower communication overhead
└── 1.6T parameters; reports quality comparable to T5-XXL with ~4× faster pre-training

GLaM (Google, 2021)
├── 64 experts per MoE layer
├── 1.2T total parameters, 96B activated parameters
└── Outperforms the contemporaneous GPT-3 baseline on some public tasks

Expert Choice Routing (Google, 2022)
├── Introduces Expert Choice Routing
├── Tokens choosing experts → experts choosing tokens
└── Better load balancing

Mixtral 8x7B (Mistral, 2023)
├── 8 experts, Top-2 routing
├── 46.7B total parameters, 12.9B activated parameters
├── Open weights; close to or better than Llama 2 70B on several public benchmarks
└── First widely successful sparse MoE in the open-source community

Qwen1.5-MoE (Alibaba, 2024)
├── Separates shared experts from routed experts
├── Better expert specialization
└── A representative Chinese-language MoE model

═══════════════════════════════════════════════════════════════════

Expert Choice Routing

Python
class ExpertChoiceMoE(nn.Module):
    """
    Expert Choice Routing: experts choose tokens instead of tokens choosing experts.
    Advantage: perfect load balancing — each expert processes a fixed number of tokens.
    """
    def __init__(self, d_model: int, num_experts: int, expert_capacity: int):
        super().__init__()
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity  # number of tokens each expert processes

        self.gate = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList([
            self._create_expert(d_model) for _ in range(num_experts)
        ])

    @staticmethod
    def _create_expert(d_model: int) -> nn.Module:
        # A standard FFN expert
        return nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len
        x_flat = x.view(num_tokens, d_model)

        # Each token's affinity score for each expert
        gate_scores = self.gate(x_flat)  # [num_tokens, num_experts]

        # Each expert selects its Top-C tokens
        selected_tokens = []
        for expert_id in range(self.num_experts):
            expert_scores = gate_scores[:, expert_id]
            top_k_values, top_k_indices = torch.topk(
                expert_scores,
                min(self.expert_capacity, num_tokens)
            )
            selected_tokens.append((expert_id, top_k_indices, top_k_values))

        # Aggregate outputs (a token may be processed by several experts,
        # and may also be selected by none)
        output = torch.zeros_like(x_flat)
        for expert_id, indices, scores in selected_tokens:
            expert_input = x_flat[indices]
            expert_output = self.experts[expert_id](expert_input)
            output[indices] += scores.unsqueeze(1) * expert_output

        return output.view(batch_size, seq_len, d_model)

1.5 MoE Training Challenges and Mitigations

Challenge            | Cause                             | Mitigations
---------------------|-----------------------------------|---------------------------------------------------------
Load imbalance       | Some experts are over-used        | Auxiliary loss, expert choice routing, capacity factor
Communication cost   | Experts live on different devices | Expert parallelism, optimized All-to-All communication
Training instability | Routing decisions are discrete    | Load-balancing loss, expert dropout, temperature annealing
Memory footprint     | Large number of expert parameters | Expert parallelism, CPU offloading, activation checkpointing
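The capacity factor mentioned above caps how many tokens each expert may process; overflow tokens are dropped (their MoE output falls back to the residual path). A minimal sketch, with hypothetical Top-1 router assignments:

```python
import math
import torch

torch.manual_seed(0)
num_tokens, num_experts, capacity_factor = 16, 4, 1.25
# Hypothetical Top-1 assignments produced by a router.
assignments = torch.randint(0, num_experts, (num_tokens,))

# Each expert processes at most `capacity` tokens.
capacity = math.ceil(capacity_factor * num_tokens / num_experts)

kept = torch.zeros(num_tokens, dtype=torch.bool)
for e in range(num_experts):
    token_ids = torch.nonzero(assignments == e).flatten()
    kept[token_ids[:capacity]] = True  # tokens beyond capacity are dropped

print(f"capacity per expert: {capacity}, kept {int(kept.sum())}/{num_tokens} tokens")
```

A larger capacity factor drops fewer tokens but wastes more padded compute when the router is imbalanced; 1.0–1.25 is a common range in the literature.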

State Space Models (Mamba)

2.1 From RNNs to SSMs

State Space Models (SSMs) are a family of sequence models obtained by discretizing continuous-time systems. Mamba resolves the content-independence limitation of earlier SSMs through a selective mechanism.

Text Only
Evolution of sequence models
═══════════════════════════════════════════════════════════════════

RNN
├── Trait: hidden state compresses the history
├── Problems: long-range dependencies are hard; low training parallelism
└── Representatives: LSTM, GRU

Transformer
├── Trait: global attention, parallel training
├── Problems: O(n²) complexity; memory blows up on long sequences
└── Representatives: GPT, BERT, LLaMA

SSM (S4, 2021)
├── Trait: linear complexity, continuous-time modeling
├── Problem: content-independent (fixed parameters)
└── Representatives: S4, DSS, GSS

Mamba (2023)
├── Trait: selective state space, input-dependent parameters
├── Advantage: linear complexity + content awareness
└── Representatives: Mamba, Jamba, Falcon Mamba

═══════════════════════════════════════════════════════════════════

2.2 State Space Model Basics

Text Only
Continuous-time state space equations
═══════════════════════════════════════════════════════════════════

Continuous form:

h'(t) = A · h(t) + B · x(t)   (state evolution)
y(t)  = C · h(t)              (output)

where:
- h(t) ∈ R^N: hidden state (N-dimensional)
- x(t) ∈ R: input (scalar)
- y(t) ∈ R: output (scalar)
- A ∈ R^(N×N): state transition matrix
- B ∈ R^N: input matrix
- C ∈ R^N: output projection matrix

Discretization (zero-order hold, ZOH):

h_k = Ā · h_{k-1} + B̄ · x_k
y_k = C · h_k

with discretized parameters:
Ā = exp(Δ · A)               (matrix exponential)
B̄ = (Ā - I) · A^(-1) · B

═══════════════════════════════════════════════════════════════════
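The ZOH equations above can be checked numerically on a toy diagonal system; for small Δ the discretization approaches the Euler approximation Ā ≈ I + Δ·A and B̄ ≈ Δ·B:

```python
import torch

N = 4                                         # state dimension (toy)
A = -torch.diag(torch.arange(1.0, N + 1))     # stable diagonal A
B = torch.ones(N, 1)
delta = 0.01                                  # step size Δ

# Zero-order hold discretization, matching the equations above.
A_bar = torch.linalg.matrix_exp(delta * A)                # Ā = exp(Δ·A)
B_bar = (A_bar - torch.eye(N)) @ torch.linalg.inv(A) @ B  # B̄ = (Ā − I)·A⁻¹·B

print(A_bar.diagonal())  # each entry slightly below 1: exp(-Δ·k)
```

The diagonal choice keeps the matrix exponential and inverse trivial; real SSMs exploit exactly this structure (diagonal or diagonal-plus-low-rank A) to make discretization cheap.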

2.3 Mamba's Selective Mechanism

Mamba's core innovation is to make B, C and Δ functions of the input, enabling selective remembering and forgetting.

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class SelectiveSSM(nn.Module):
    """
    Selective state space model (the core of Mamba).
    """
    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)

        # Input projection (x → SSM input and residual branch)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Causal convolution (local modeling)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,  # depthwise
            bias=True
        )

        # Selective parameter generation (the key innovation)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        # Output: [B, C, Δ], where Δ is the step-size parameter

        # State transition matrix A (learned; often HiPPO-initialized)
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1).float()).repeat(self.d_inner, 1)
        )
        self.D = nn.Parameter(torch.ones(self.d_inner))  # skip connection

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch, seq_len, dim = x.shape

        # Input projection
        x_and_res = self.in_proj(x)  # [batch, seq_len, 2*d_inner]
        x_ssm, res = x_and_res.split([self.d_inner, self.d_inner], dim=-1)

        # Causal convolution
        x_conv = rearrange(x_ssm, 'b l d -> b d l')
        x_conv = self.conv1d(x_conv)[..., :seq_len]  # causal truncation
        x_conv = rearrange(x_conv, 'b d l -> b l d')
        x_conv = F.silu(x_conv)

        # Selective SSM
        y = self.selective_scan(x_conv)

        # Gated fusion with the residual branch
        y = y * F.silu(res)

        # Output projection
        output = self.out_proj(y)
        return output

    def selective_scan(self, x):
        """
        Selective scan: the core computation.
        An efficient implementation uses a parallel associative scan.
        """
        batch, seq_len, d_in = x.shape

        # Generate selective (input-dependent) parameters
        params = self.x_proj(x)  # [batch, seq_len, d_state*2 + 1]
        B, C, delta = params.split([self.d_state, self.d_state, 1], dim=-1)

        # softplus keeps Δ positive; here one Δ per token is broadcast across
        # channels (real Mamba projects a per-channel Δ via a dt_rank projection)
        delta = F.softplus(delta).expand(-1, -1, d_in)  # [batch, seq_len, d_in]

        # Discretization
        A = -torch.exp(self.A_log.float())  # [d_in, d_state]

        # A simplified sequential implementation is used below; production
        # code uses a CUDA-optimized parallel scan
        y = self.parallel_scan(x, delta, A, B, C)

        return y

    def parallel_scan(self, x, delta, A, B, C):
        """
        (Conceptually) parallel associative scan: the recurrence can be
        rewritten as an associative operation and parallelized.
        """
        batch, seq_len, d_in = x.shape

        # Discretized parameters
        delta_A = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
        delta_B = torch.einsum('bld,bln->bldn', delta, B)

        # Input weighting
        BX = delta_B * x.unsqueeze(-1)  # [batch, seq_len, d_in, d_state]

        # Associative scan: h_k = delta_A_k * h_{k-1} + BX_k
        # (parallelizable with a Blelloch-style scan)
        h = self.associative_scan(delta_A, BX)

        # Output projection
        y = torch.einsum('bldn,bln->bld', h, C)

        # Skip connection
        y = y + self.D.unsqueeze(0).unsqueeze(0) * x

        return y

    def associative_scan(self, A, B):
        """
        Associative scan: h_k = A_k * h_{k-1} + B_k is an associative
        operation, so the linear recurrence can be computed in parallel.
        This is a conceptual sequential version; real implementations
        use a fused CUDA kernel.
        """
        batch, seq_len, d_in, d_state = B.shape

        h = torch.zeros(batch, d_in, d_state, device=B.device, dtype=B.dtype)
        hs = []

        for k in range(seq_len):
            h = A[:, k] * h + B[:, k]
            hs.append(h)

        return torch.stack(hs, dim=1)

2.4 Mamba Architecture Details

Text Only
Mamba block structure
═══════════════════════════════════════════════════════════════════

Input x ∈ R^(L×D)
    ├──▶ Linear ──▶ [x_ssm, res]  split into SSM input and residual
    ├──▶ x_ssm ──▶ Conv1d (causal) ──▶ SiLU ──▶ Selective SSM ──┐
    │                                                            │
    ├──▶ res ──────────────────────────────────────────▶ SiLU    │
    │                                                            │
    └──▶ element-wise product ◀──────────────────────────────────┘
         Linear ──▶ output

Key design choices:
1. Causal convolution: adds the local dependencies the SSM otherwise lacks
2. SiLU activation: smooth nonlinearity that pairs with the gating mechanism
3. Selective parameters: B, C and Δ all depend on the input x
4. Parallel scan: an efficient O(n) implementation

═══════════════════════════════════════════════════════════════════

2.5 Mamba Variants

Text Only
The Mamba family
═══════════════════════════════════════════════════════════════════

Mamba (2023)
├── Pure SSM architecture, no attention
├── Linear complexity O(n)
├── Language modeling quality competitive with Transformers at similar scale
└── Pronounced advantages on long sequences

Jamba (AI21, 2024)
├── Mamba + Transformer hybrid
├── Alternating attention and SSM layers
├── Combines the strengths of both
└── 256K context support

Falcon Mamba (TII, 2024)
├── 7B-parameter pure Mamba model
├── Open weights, commercially usable
└── Strong results on long-text tasks

Vision Mamba (2024)
├── Applies Mamba to vision
├── Bidirectional SSM over image patches
└── Efficient vision backbone

Mamba-2 (2024)
├── Theoretical unification of SSMs and structured matrices
├── More efficient implementation
└── Reports up to ~8× faster core training kernel

═══════════════════════════════════════════════════════════════════

Linear Attention and RNN-ified Transformers

3.1 Linear Attention

Standard attention costs O(n²); linear attention reduces this to O(n) via a kernel trick.

Text Only
How linear attention works
═══════════════════════════════════════════════════════════════════

Standard softmax attention (written as a ratio):

        exp(QK^T/√d) V
Attn = ────────────────
              Z

where Z = row_sum(exp(QK^T/√d))

Complexity: O(n²·d), because the n×n attention matrix must be computed

Linear attention (kernel trick):

Choose a feature map φ such that:
exp(q·k/√d) ≈ φ(q)·φ(k)

Then:
        φ(Q)φ(K)^T V     φ(Q)(φ(K)^T V)
Attn = ────────────── = ─────────────────
              Z                 Z

Key observation: φ(K)^T V can be computed first — a d×d matrix,
independent of the sequence length

Complexity: O(n·d²) — linear in sequence length!

═══════════════════════════════════════════════════════════════════
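The key observation above, that the matrix product can be re-associated, is easy to verify numerically: both multiplication orders give the same (unnormalized) result, but the second never materializes an n×n matrix:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 8, 4                      # sequence length, head dimension
Q, K, V = torch.randn(3, n, d)

phi = lambda x: F.elu(x) + 1     # the ELU+1 feature map used below

# O(n²) order: build the n×n similarity matrix first.
quadratic = (phi(Q) @ phi(K).T) @ V

# O(n) order: contract φ(K)ᵀV (a d×d matrix) first, independent of n.
linear = phi(Q) @ (phi(K).T @ V)

print(torch.allclose(quadratic, linear, atol=1e-5))  # same result, cheaper order
```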
Python
class LinearAttention(nn.Module):
    """
    Linear attention (after Katharopoulos et al., 2020).
    """
    def __init__(self, dim, num_heads, feature_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.feature_dim = feature_dim or self.head_dim

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def elu_feature_map(self, x):
        """Feature map: ELU + 1 (keeps features positive)"""
        return F.elu(x) + 1

    def forward(self, x):
        B, N, C = x.shape

        # Compute QKV
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Feature map
        q = self.elu_feature_map(q)
        k = self.elu_feature_map(k)

        # Linear attention
        # KV = Σ_t k_t^T v_t
        KV = torch.einsum('bhsk,bhsv->bhkv', k, v)  # [B, heads, head_dim, head_dim]

        # Z = Σ_t k_t
        Z = k.sum(dim=2)  # [B, heads, head_dim]

        # output = Q @ KV / (Q @ Z)
        numerator = torch.einsum('bhnk,bhkv->bhnv', q, KV)  # [B, heads, N, head_dim]
        denominator = torch.einsum('bhnk,bhk->bhn', q, Z).unsqueeze(-1) + 1e-6

        out = numerator / denominator
        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)

        return out

3.2 RNN-ified Transformers

Rewriting the Transformer decoder in RNN form yields O(1) inference memory.

Python
class RNNTransformer(nn.Module):
    """
    Transformer in RNN form (in the spirit of RWKV and RetNet).
    """
    def __init__(self, dim, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.decay = nn.Parameter(torch.randn(num_heads))  # one decay rate per head

    def per_step_decay(self):
        """Per-step decay γ_h ∈ (0, 1), shared by both computation modes."""
        return torch.exp(-F.softplus(self.decay))  # [num_heads]

    def forward_training(self, x):
        """Parallel computation for training"""
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Causal mask + decay
        positions = torch.arange(N, device=x.device)
        rel_pos = (positions.unsqueeze(1) - positions.unsqueeze(0)).float()  # rel_pos[n, m] = n - m

        # decay_matrix[h, n, m] = γ_h^(n−m), valid on the lower triangle
        log_gamma = torch.log(self.per_step_decay())
        decay_matrix = torch.exp(rel_pos.clamp(min=0).unsqueeze(0) * log_gamma.view(-1, 1, 1))  # [H, N, N]
        decay_matrix = torch.tril(decay_matrix)  # causal mask

        # Retention(X) = (Q K^T ⊙ D) V
        attn = torch.einsum('bnhd,bmhd->bhnm', q, k)  # [B, H, N, M]
        attn = attn * decay_matrix.unsqueeze(0)  # [B, H, N, M]
        attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-6)

        out = torch.einsum('bhnm,bmhd->bnhd', attn, v)  # [B, N, H, D]
        return out.reshape(B, N, C)

    def forward_inference(self, x, state):
        """
        RNN-form inference with O(1) memory.
        state: (KV_cache, K_sum_cache)
        """
        B, C = x.shape

        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        KV_cache, K_sum_cache = state

        # State update (RNN form), using the same per-step decay as training
        gamma = self.per_step_decay()  # [num_heads]

        # KV_cache = γ * KV_cache + k^T v
        new_KV = KV_cache * gamma.view(1, -1, 1, 1)  # [B, H, D, D]
        new_KV = new_KV + torch.einsum('bhd,bhe->bhde', k, v)

        # K_sum_cache = γ * K_sum_cache + k
        new_K_sum = K_sum_cache * gamma.view(1, -1, 1)  # [B, H, D]
        new_K_sum = new_K_sum + k

        # Output
        numerator = torch.einsum('bhd,bhde->bhe', q, new_KV)
        denominator = torch.einsum('bhd,bhd->bh', q, new_K_sum).unsqueeze(-1) + 1e-6

        out = numerator / denominator
        return out.reshape(B, C), (new_KV, new_K_sum)
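The equivalence of the two forms above can be verified on the unnormalized retention: the parallel masked-decay computation and the O(1)-state recurrence produce identical outputs (single head, toy sizes):

```python
import torch

torch.manual_seed(0)
n, d = 6, 4
q, k, v = torch.randn(3, n, d)
gamma = 0.9   # per-step decay, i.e. exp(-decay) for a positive decay rate

# Parallel form: (Q Kᵀ ⊙ D) V with D[n, m] = γ^(n−m) on the lower triangle.
idx = torch.arange(n)
D = torch.tril(gamma ** (idx.unsqueeze(1) - idx.unsqueeze(0)).float())
parallel = (q @ k.T * D) @ v

# Recurrent form: S_n = γ·S_{n−1} + k_nᵀ v_n,  o_n = q_n S_n  — O(1) state.
S = torch.zeros(d, d)
outs = []
for t in range(n):
    S = gamma * S + torch.outer(k[t], v[t])
    outs.append(q[t] @ S)
recurrent = torch.stack(outs)

print(torch.allclose(parallel, recurrent, atol=1e-5))
```

This dual form is exactly why such models can train in parallel like a Transformer yet decode with a fixed-size state like an RNN.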

3.3 Representative Models

Text Only
Linear / RNN-ified Transformer models
═══════════════════════════════════════════════════════════════════

RWKV (2023)
├── Architecture: combines RNN and Transformer strengths
├── Training: parallelizable (Transformer-like)
├── Inference: O(1) memory (RNN-like)
├── Trait: time-decay mechanism
└── Fully open source, commercially usable

RetNet (Microsoft, 2023)
├── Architecture: retains the Transformer layout
├── Core: Retention mechanism (dual parallel/recurrent forms)
├── Training: parallel
├── Inference: recurrent form
└── Claimed: quality competitive with Transformers

TransNormerLLM (2023)
├── Blocked attention + linear attention
├── Locally exact + globally linear
└── Efficient on long sequences

Linear Transformer (2020)
├── The earliest linear attention work
├── Kernel trick to approximate softmax
└── Theoretical foundation

═══════════════════════════════════════════════════════════════════

Other Architectural Innovations

4.1 Mixing Attention Patterns (non-MoE specialization)

Python
class MultiScaleTransformer(nn.Module):
    """
    Multi-scale Transformer: different layers use different attention patterns.
    (LocalAttention, SparseGlobalAttention and StandardAttention are placeholder
    modules assumed to be defined elsewhere.)
    """
    def __init__(self, dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            if i % 3 == 0:
                # Local window attention (captures local patterns)
                self.layers.append(LocalAttention(dim, window_size=512))
            elif i % 3 == 1:
                # Sparse global attention (long-range dependencies)
                self.layers.append(SparseGlobalAttention(dim, stride=16))
            else:
                # Standard attention (exact modeling)
                self.layers.append(StandardAttention(dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

4.2 Hierarchical Attention

Python
class HierarchicalAttention(nn.Module):
    """
    Hierarchical attention: coarse-grained first, then fine-grained.
    """
    def __init__(self, dim, num_tokens_per_group=64):
        super().__init__()
        self.num_tokens_per_group = num_tokens_per_group

        # Coarse-grained (group level); batch_first=True for [B, G, C] inputs
        self.coarse_attention = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

        # Fine-grained (within groups)
        self.fine_attention = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

    def forward(self, x):
        B, N, C = x.shape
        G = N // self.num_tokens_per_group  # number of groups (N assumed divisible)

        # Group the tokens
        x_grouped = x.reshape(B, G, self.num_tokens_per_group, C)

        # Pool within each group to get a group representation
        group_repr = x_grouped.mean(dim=2)  # [B, G, C]

        # Coarse: attention across groups
        coarse_out, _ = self.coarse_attention(group_repr, group_repr, group_repr)

        # Broadcast group representations back to their tokens
        coarse_expanded = coarse_out.unsqueeze(2).expand(-1, -1, self.num_tokens_per_group, -1)

        # Fine: attention within groups
        x_combined = x_grouped + coarse_expanded  # fuse
        x_fine = x_combined.reshape(B * G, self.num_tokens_per_group, C)
        fine_out, _ = self.fine_attention(x_fine, x_fine, x_fine)

        return fine_out.reshape(B, N, C)

4.3 Memory-Efficient Architectures

Python
import math

class MemoryEfficientAttention(nn.Module):
    """
    FlashAttention-style memory-efficient attention:
    block-wise computation avoids storing the full attention matrix.
    """
    def __init__(self, dim, num_heads, block_size=256):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.block_size = block_size

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        Block-wise attention with O(n) memory (no full N×N attention matrix).
        """
        B, N, C = x.shape
        H = self.num_heads
        D = self.head_dim
        qkv = self.qkv(x).reshape(B, N, 3, H, D)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Blocking
        num_blocks = (N + self.block_size - 1) // self.block_size

        outputs = []
        for i in range(num_blocks):
            start_i = i * self.block_size
            end_i = min(start_i + self.block_size, N)
            block_len = end_i - start_i

            q_block = q[:, start_i:end_i]  # [B, block_len, H, D]

            # Online softmax accumulators
            max_score = torch.full((B, H, block_len, 1), float('-inf'), device=x.device)
            numerator = torch.zeros(B, H, block_len, D, device=x.device)
            denominator = torch.zeros(B, H, block_len, 1, device=x.device)

            for j in range(num_blocks):
                start_j = j * self.block_size
                end_j = min(start_j + self.block_size, N)

                k_block = k[:, start_j:end_j]  # [B, block_j, H, D]
                v_block = v[:, start_j:end_j]  # [B, block_j, H, D]

                # Block-to-block attention scores, keeping the head dimension
                scores = torch.einsum('bnhd,bmhd->bhnm', q_block, k_block)  # [B, H, block_i, block_j]
                scores = scores / math.sqrt(D)

                # Online softmax update
                max_score_new = torch.maximum(max_score, scores.max(dim=-1, keepdim=True)[0])
                exp_scores = torch.exp(scores - max_score_new)

                numerator = numerator * torch.exp(max_score - max_score_new) + \
                           torch.einsum('bhnm,bmhd->bhnd', exp_scores, v_block)
                denominator = denominator * torch.exp(max_score - max_score_new) + \
                             exp_scores.sum(dim=-1, keepdim=True)

                max_score = max_score_new

            out_block = numerator / denominator  # [B, H, block_len, D]
            out_block = out_block.permute(0, 2, 1, 3)  # [B, block_len, H, D]
            outputs.append(out_block)

        out = torch.cat(outputs, dim=1)
        out = out.reshape(B, N, C)
        return self.proj(out)

Architecture Selection Guide

5.1 Recommendations by Scenario

Text Only
Architecture decision tree
═══════════════════════════════════════════════════════════════════

What is your main constraint?
    ├── Long sequences (>32K tokens)
    │       │
    │       ├── Strong long-range dependency modeling needed
    │       │       └──▶ Mamba / Jamba / linear attention
    │       │
    │       └── Mostly local patterns
    │               └──▶ sliding-window attention + global tokens
    ├── Large-scale deployment (cost-sensitive)
    │       │
    │       ├── High-throughput inference
    │       │       └──▶ MoE (few activated parameters) / RWKV (O(1) memory)
    │       │
    │       └── Training-cost control
    │               └──▶ linear attention / hierarchical attention
    ├── General quality first
    │       │
    │       ├── Ample compute
    │       │       └──▶ standard Transformer + FlashAttention
    │       │
    │       └── Resource-constrained
    │               └──▶ hybrid architectures (MoE + linear attention)
    └── Domain-specific
            ├── Code generation
            │       └──▶ standard Transformer (structured patterns matter)
            ├── Long-document understanding
            │       └──▶ Mamba / Longformer
            └── Multimodal
                    └──▶ standard Transformer (mature cross-modal alignment)

═══════════════════════════════════════════════════════════════════

5.2 Architecture Comparison

Architecture         | Train complexity | Inference complexity | Long-seq support | Impl difficulty | Ecosystem maturity
---------------------|------------------|----------------------|------------------|-----------------|-------------------
Standard Transformer | O(n²)            | O(n²)                | —                | —               | very high
FlashAttention       | O(n²)            | O(n²)                | —                | —               | —
MoE                  | O(n²)⁺           | O(n²)⁺               | —                | —               | —
Mamba                | O(n)             | O(n)                 | excellent        | —               | growing
Linear attention     | O(n)             | O(n)                 | —                | —               | —
RWKV                 | O(n)             | O(1)⁺⁺               | —                | —               | —

⁺ MoE attention is still O(n²), but per-token FFN compute drops sharply (only a few experts are activated)
⁺⁺ RWKV's O(1) inference refers to memory that does not grow with sequence length
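To make the O(1)-memory footnote concrete, a back-of-envelope comparison of decode-time memory (the layer/head numbers below are illustrative, not any specific model):

```python
# Per-token decode memory for a Transformer KV cache vs an RNN-style constant state.
n_layers, n_kv_heads, head_dim = 32, 8, 128   # hypothetical 7B-class settings
bytes_per_elem = 2                             # fp16

kv_per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # K and V

seq_len = 32_000
kv_cache_gb = kv_per_token * seq_len / 1e9
print(f"KV cache at {seq_len} tokens: {kv_cache_gb:.2f} GB (grows linearly)")
# An RNN/RWKV-style model keeps a fixed-size state instead, independent of seq_len.
```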

5.3 Practical Recommendations

  1. Start with a standard Transformer
     - Unless you have a clear long-sequence requirement, prefer the mature Transformer stack
     - FlashAttention already solves most memory problems

  2. When MoE fits
     - Consider it for models beyond ~10B parameters
     - Scenarios where inference throughput matters more than latency
     - You have distributed-training infrastructure

  3. When Mamba/SSMs fit
     - Sequence lengths beyond 32K
     - Streaming workloads (e.g. speech, video)
     - Strictly memory-constrained edge devices

  4. Hybrid strategies
     - Lower layers: Mamba / linear attention for long context
     - Upper layers: standard Transformer for exact modeling
     - Reference: the Jamba architecture

Next Steps

After finishing this chapter, consider reading next: 05 — LLM Safety and Alignment (RLHF, safety training) and 06 — LLM Applications and Productization (engineering practice and deployment).

Hands-on projects: replace some Transformer layers with Mamba and compare long-sequence performance; implement a simplified MoE layer and observe the effect of the load-balancing loss.


Last updated: 2026-02-12. Applies to: LLM tutorial v2026