
04-Mamba与状态空间模型

学习时间: 约10-12小时
难度级别: ⭐⭐⭐⭐⭐ 高级
前置知识: 线性代数、微分方程基础、Transformer架构、PyTorch
学习目标: 深入理解状态空间模型的数学原理,掌握Mamba架构的核心创新,理解线性注意力的理论与实现

📌 定位说明:本章介绍新兴的SSM架构,是Transformer的重要替代方案。建议结合 01-注意力机制详解 与 02-Transformer架构 学习,理解O(n²)到O(n)的复杂度突破。


1. 状态空间模型基础

1.1 什么是状态空间模型

状态空间模型(State Space Model, SSM) 源自控制理论,通过一个隐状态 \(h(t)\) 来描述系统的动态演化。在深度学习语境下,SSM提供了一种新的序列建模范式。

连续时间SSM

经典的连续时间状态空间模型定义为:

\[\frac{dh(t)}{dt} = \mathbf{A}h(t) + \mathbf{B}x(t) \quad \text{(状态方程)}\]
\[y(t) = \mathbf{C}h(t) + \mathbf{D}x(t) \quad \text{(输出方程)}\]

其中:

  • \(x(t) \in \mathbb{R}\):输入信号(标量)
  • \(h(t) \in \mathbb{R}^N\):隐状态向量(N维)
  • \(y(t) \in \mathbb{R}\):输出信号(标量)
  • \(\mathbf{A} \in \mathbb{R}^{N \times N}\):状态转移矩阵
  • \(\mathbf{B} \in \mathbb{R}^{N \times 1}\):输入投影矩阵
  • \(\mathbf{C} \in \mathbb{R}^{1 \times N}\):输出投影矩阵
  • \(\mathbf{D} \in \mathbb{R}^{1 \times 1}\):直通项(通常省略)

1.2 离散化:从连续到离散

深度学习处理的是离散序列,需要将连续SSM离散化。

零阶保持(Zero-Order Hold, ZOH)离散化

假设输入在采样间隔 \(\Delta\) 内保持不变:

\[h_k = \bar{\mathbf{A}}h_{k-1} + \bar{\mathbf{B}}x_k\]
\[y_k = \mathbf{C}h_k\]

离散化参数为:

\[\bar{\mathbf{A}} = \exp(\Delta \mathbf{A})\]
\[\bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}\]

简化离散化(实际常用)

为了计算效率,常用一阶近似:

\[\bar{\mathbf{A}} \approx \mathbf{I} + \Delta \mathbf{A}\]
\[\bar{\mathbf{B}} \approx \Delta \mathbf{B}\]
Python
import torch
import torch.nn as nn
import math

def discretize_zoh(A, B, delta):
    """
    零阶保持(ZOH)离散化
    A: (N, N) 状态转移矩阵
    B: (N, 1) 输入矩阵
    delta: 标量时间步长
    """
    N = A.shape[0]
    I = torch.eye(N, device=A.device, dtype=A.dtype)
    A_bar = torch.matrix_exp(delta * A)
    # (ΔA)^{-1}(exp(ΔA) - I)·ΔB 等价于 A^{-1}(exp(ΔA) - I)·B
    B_bar = torch.linalg.solve(A, A_bar - I) @ B
    return A_bar, B_bar

def discretize_first_order(A, B, delta):
    """简化版本:一阶近似 A_bar ≈ I + ΔA, B_bar ≈ ΔB"""
    A_bar = torch.eye(A.shape[0], device=A.device) + delta * A
    B_bar = delta * B
    return A_bar, B_bar

1.3 SSM的递归形式与卷积形式

SSM的一个关键优势是可以用两种等价形式计算:

递归形式(RNN式)

\[h_t = \bar{\mathbf{A}}h_{t-1} + \bar{\mathbf{B}}x_t\]
\[y_t = \mathbf{C}h_t\]
  • 优点:推理时O(1)复杂度(每个时间步)
  • 缺点:无法并行训练

卷积形式(CNN式)

展开递归,可以得到:

\[y_t = \mathbf{C}\bar{\mathbf{A}}^t\bar{\mathbf{B}}x_0 + \mathbf{C}\bar{\mathbf{A}}^{t-1}\bar{\mathbf{B}}x_1 + \cdots + \mathbf{C}\bar{\mathbf{B}}x_t\]

定义SSM卷积核:

\[\mathbf{K} = (\mathbf{C}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}^2\bar{\mathbf{B}}, \ldots, \mathbf{C}\bar{\mathbf{A}}^{L-1}\bar{\mathbf{B}})\]

则输出可以写成卷积:

\[\mathbf{y} = \mathbf{K} * \mathbf{x}\]
  • 优点:训练时可并行(FFT加速)
  • 缺点:推理时需要重新计算整个序列
Python
def ssm_conv_kernel(A_bar, B_bar, C, L):
    """
    计算SSM的卷积核
    A_bar: (N, N) 离散化状态矩阵
    B_bar: (N, 1) 离散化输入矩阵  
    C: (1, N) 输出矩阵
    L: 序列长度
    返回: (L,) 卷积核
    """
    N = A_bar.shape[0]
    K = torch.zeros(L, device=A_bar.device)

    A_power = torch.eye(N, device=A_bar.device)  # A^0
    for i in range(L):
        # K[i] = C @ A^i @ B
        K[i] = (C @ A_power @ B_bar).squeeze()
        A_power = A_power @ A_bar  # A^(i+1)

    return K

1.4 与RNN、Transformer的对比

| 特性 | RNN/LSTM | Transformer | SSM (Mamba) |
|------|----------|-------------|-------------|
| 训练复杂度 | O(L) 串行 | O(L²) 并行 | O(L) 并行 |
| 推理复杂度 | O(1) 每步 | O(L) 每步 | O(1) 每步 |
| 长程依赖 | 梯度消失 | 优秀 | 优秀(HiPPO) |
| 显存占用 | O(1) | O(L²) | O(L) |
| 状态大小 | 固定 | KV Cache增长 | 固定 |

SSM的"两全其美":

  • 训练时:像Transformer一样并行(卷积形式)
  • 推理时:像RNN一样高效(递归形式)
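两种形式的等价性可以用一个小数值例子直接验证(示意代码,矩阵取随机稳定值,并非真实的HiPPO初始化):

```python
import torch

torch.manual_seed(0)
N, L = 4, 16
A_bar = torch.diag(torch.rand(N) * 0.9)  # 稳定的对角状态矩阵(谱半径 < 1)
B_bar = torch.randn(N, 1)
C = torch.randn(1, N)
x = torch.randn(L)

# 递归形式: h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t
h = torch.zeros(N, 1)
y_rec = []
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y_rec.append((C @ h).squeeze())
y_rec = torch.stack(y_rec)

# 卷积形式: K[i] = C A_bar^i B_bar, 输出为因果卷积 y = K * x
K = torch.zeros(L)
A_pow = torch.eye(N)
for i in range(L):
    K[i] = (C @ A_pow @ B_bar).squeeze()
    A_pow = A_pow @ A_bar

y_conv = torch.stack([
    sum(K[j] * x[t - j] for j in range(t + 1)) for t in range(L)
])

print(torch.allclose(y_rec, y_conv, atol=1e-5))  # True
```

两条路径得到逐点相同的输出:训练时选卷积形式并行,推理时选递归形式逐步生成。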


2. HiPPO矩阵初始化理论

2.1 长程记忆的挑战

普通SSM的矩阵 \(\mathbf{A}\) 如果随机初始化,会导致:

  • 梯度消失或爆炸
  • 无法有效捕捉长程依赖

HiPPO(High-order Polynomial Projection Operators) 理论提供了系统性的初始化方法。

2.2 HiPPO的核心思想

将输入信号 \(x(t)\) 投影到多项式基函数上,用隐状态 \(h(t)\) 存储多项式系数。

关键洞察:为了保持对历史信息的"记忆",需要设计特殊的 \(\mathbf{A}\) 矩阵,使得隐状态能够以特定方式压缩历史信息。

2.3 HiPPO矩阵

对于均匀覆盖全部历史的记忆模式(HiPPO-LegS,基于缩放的Legendre测度),HiPPO矩阵为:

\[\mathbf{A}_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}\]
Python
def hippo_matrix(N):
    """
    构造HiPPO矩阵(LegS变体)
    N: 隐状态维度
    """
    A = torch.zeros(N, N)
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -((2*n + 1) ** 0.5) * ((2*k + 1) ** 0.5)
            elif n == k:
                A[n, k] = -(n + 1)
            else:
                A[n, k] = 0
    return A

def hippo_init(N):
    """
    HiPPO初始化的完整版本(S4/Mamba使用)
    """
    # HiPPO-LegS矩阵
    A = torch.zeros(N, N)
    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = (2*i + 1) ** 0.5 * (2*j + 1) ** 0.5
            elif i == j:
                A[i, j] = i + 1

    A = -A  # 加负号

    # B矩阵
    B = torch.zeros(N, 1)
    for i in range(N):
        B[i, 0] = (2*i + 1) ** 0.5

    return A, B

2.4 为什么HiPPO有效

HiPPO矩阵的特殊结构保证了:

  1. 稳定的特征值分布:特征值分布在负实轴上,避免数值不稳定
  2. 指数衰减的记忆:能够以指数衰减的方式"记住"历史
  3. 正交多项式基:隐状态对应Legendre多项式系数,具有最优逼近性质
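第1点可以直接数值确认:HiPPO-LegS矩阵是下三角矩阵,特征值即对角线元素 \(-(n+1)\),全部落在负实轴上。下面的小例子沿用上文 hippo_matrix 的构造方式:

```python
import torch

N = 32
A = torch.zeros(N, N)
for n in range(N):
    for k in range(N):
        if n > k:
            A[n, k] = -((2*n + 1) ** 0.5) * ((2*k + 1) ** 0.5)
        elif n == k:
            A[n, k] = -(n + 1)

eigvals = torch.linalg.eigvals(A)
# 所有特征值实部为负: 连续系统稳定, 离散化后 exp(ΔA) 的谱半径 < 1
print(eigvals.real.max() < 0)  # True
```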

3. Mamba架构详解

3.1 从S4到Mamba的演进

Text Only
LSSL (2021) → S4 (2022) → S4D (2022) → H3 (2022) → Mamba (2023) → Mamba-2 (2024)

S4(Structured State Spaces):首次让SSM在长序列任务上超越Transformer
Mamba:引入选择性机制,让SSM具有内容感知能力

3.2 选择性状态空间(Selective SSM)

核心问题

传统SSM的参数 \(\mathbf{A}, \mathbf{B}, \mathbf{C}\) 是时不变的(Time-Invariant),即对所有时间步使用相同的参数。这限制了模型的表达能力。

Mamba的创新:时变参数

让参数依赖于输入:

\[\mathbf{B}_t = \text{Linear}_B(x_t)\]
\[\mathbf{C}_t = \text{Linear}_C(x_t)\]
\[\Delta_t = \text{Broadcast}(\text{Softplus}(\text{Linear}_{\Delta}(x_t)))\]

选择性机制的意义:

  • \(\mathbf{B}_t\) 控制输入多少信息进入隐状态
  • \(\mathbf{C}_t\) 控制输出多少信息从隐状态提取
  • \(\Delta_t\) 控制时间分辨率(关注细粒度还是粗粒度)

Python
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    """
    选择性状态空间模块(Mamba核心)
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state  # N, 隐状态维度
        self.d_conv = d_conv    # 局部卷积核大小
        self.d_inner = d_model * expand  # 扩展维度

        # 输入投影
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # 局部卷积
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner, 
            kernel_size=d_conv, 
            padding=d_conv - 1,
            groups=self.d_inner
        )

        # SSM参数投影
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)

        # 可学习的A参数(对角线)
        self.A_log = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # Skip connection

        # 输出投影
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        """
        batch, seq_len, _ = x.shape

        # 1. 输入投影并分割
        xz = self.in_proj(x)  # (batch, seq_len, d_inner * 2)
        x, z = xz.chunk(2, dim=-1)  # 各 (batch, seq_len, d_inner)

        # 2. 局部卷积 + 激活
        x = x.transpose(1, 2)  # (batch, d_inner, seq_len)
        x = self.conv1d(x)[:, :, :seq_len]
        x = x.transpose(1, 2)  # (batch, seq_len, d_inner)
        x = F.silu(x)

        # 3. 计算时变参数
        x_dbl = self.x_proj(x)  # (batch, seq_len, d_state * 2 + 1)
        delta, B, C = x_dbl.split([1, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(delta)  # 确保正值

        # 4. SSM计算
        A = -torch.exp(self.A_log)  # (d_inner, d_state)
        y = self.selective_scan(x, delta, A, B, C, self.D)

        # 5. 门控 + 输出投影
        y = y * F.silu(z)
        return self.out_proj(y)

    def selective_scan(self, u, delta, A, B, C, D):
        """
        选择性扫描算法(核心计算)
        u: (batch, seq_len, d_inner) 输入
        delta: (batch, seq_len, 1) 时间步长
        A: (d_inner, d_state) 状态矩阵(对角)
        B: (batch, seq_len, d_state) 输入矩阵
        C: (batch, seq_len, d_state) 输出矩阵
        D: (d_inner,) skip connection
        """
        batch, seq_len, d_inner = u.shape
        d_state = A.shape[1]

        # 离散化 A_bar = exp(delta * A), B_bar = delta * B
        deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (batch, seq_len, d_inner, d_state)
        deltaB_u = (delta * B).unsqueeze(2) * u.unsqueeze(-1)  # (batch, seq_len, d_inner, d_state)

        # 递归计算
        h = torch.zeros(batch, d_inner, d_state, device=u.device)
        ys = []

        for i in range(seq_len):
            h = deltaA[:, i] * h + deltaB_u[:, i]  # 状态更新
            y = torch.sum(h * C[:, i].unsqueeze(1), dim=-1)  # 输出
            ys.append(y)

        y = torch.stack(ys, dim=1)  # (batch, seq_len, d_inner)
        y = y + u * D  # skip connection

        return y

3.3 硬件感知算法

问题:选择性SSM无法使用卷积形式

由于参数 \(\mathbf{B}_t, \mathbf{C}_t, \Delta_t\) 依赖于输入,无法预先计算卷积核,只能用递归形式。

解决方案:并行扫描算法

扫描算法(Scan/Prefix Sum) 可以并行化递归计算:

Text Only
传统递归:h[0] → h[1] → h[2] → h[3] → ... (串行)
并行扫描:使用二叉树结构,O(log n)步完成
Python
def parallel_scan(A, B, h0):
    """
    并行扫描算法示意(简化版)
    实际实现需要CUDA kernel
    h[t] = A[t] * h[t-1] + B[t]
    """
    L = A.shape[0]

    # 向上扫描(reduce)
    # 实际实现使用平衡二叉树

    # 向下扫描(downsweep)
    # 实际实现使用平衡二叉树

    pass  # 实际实现需要自定义CUDA kernel
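扫描思想可以用一个可运行的PyTorch草图说明(Hillis-Steele风格的倍增扫描,每轮都是整段张量运算;真实Mamba使用更省功的Blelloch式CUDA kernel,此处仅为原理示意):

```python
import torch

def parallel_scan(a, b):
    """
    求解递归 h[t] = a[t] * h[t-1] + b[t], h[-1] = 0
    结合算子: (a1, b1) ∘ (a2, b2) = (a1*a2, a2*b1 + b2)
    O(log L) 轮, 每轮为整段张量运算(GPU友好)
    """
    L = a.shape[0]
    A, B = a.clone(), b.clone()
    step = 1
    while step < L:
        # 取 step 步之前的前缀算子(越界处用单位元 (1, 0) 填充)
        A_prev = torch.ones_like(A)
        B_prev = torch.zeros_like(B)
        A_prev[step:] = A[:-step]
        B_prev[step:] = B[:-step]
        # 先更新 B(用旧的 A), 再更新 A
        B = A * B_prev + B
        A = A * A_prev
        step *= 2
    return B  # B[t] 即 h[t]

# 与串行递归对照
a = torch.full((8,), 0.5)
b = torch.ones(8)
h_scan = parallel_scan(a, b)

h, h_seq = 0.0, []
for t in range(8):
    h = a[t] * h + b[t]
    h_seq.append(h)

print(torch.allclose(h_scan, torch.stack(h_seq)))  # True
```

把 `a[t]` 换成 `deltaA[:, t]`、`b[t]` 换成 `deltaB_u[:, t]`,同样的算子就能并行化选择性SSM的状态更新。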

Mamba的优化:

  1. Kernel Fusion:将离散化、扫描、输出合并到一个CUDA kernel
  2. 重计算策略:前向时不存储中间状态,反向时重计算(节省显存)
  3. IO感知:减少HBM(高带宽内存)访问次数

3.4 Mamba块的完整结构

Text Only
输入 x
    │
    ▼
┌──────────────────────┐
│ Linear (d_inner × 2) │   in_proj,随后 chunk 为 x 和 z
└──────────┬───────────┘
           │
      ┌────┴─────────┐
      ▼              ▼
      x              z
      │              │
      ▼              ▼
 ┌─────────┐    ┌────────┐
 │ Conv1D  │    │  SiLU  │
 │ (local) │    └───┬────┘
 └────┬────┘        │
      ▼             │
 ┌─────────┐        │
 │  SiLU   │        │
 └────┬────┘        │
      ▼             │
┌─────────────────┐ │
│ Selective SSM   │ │
│ (B,C,Δ from x)  │ │
└────────┬────────┘ │
         ▼          │
        (×)◄────────┘
         │
         ▼
  ┌───────────────┐
  │    Linear     │   out_proj: d_inner → d_model
  └───────┬───────┘
          ▼
        输出

4. 线性注意力机制

4.1 标准注意力的复杂度瓶颈

标准缩放点积注意力:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

复杂度分析:

  • \(QK^T\): \(O(L^2 \cdot d)\)
  • Softmax: \(O(L^2)\)
  • 与V相乘: \(O(L^2 \cdot d)\)
  • 总计: \(O(L^2 \cdot d)\) 时间和空间

当序列长度 \(L\) 很大时(如100K tokens),\(L^2\) 成为瓶颈。

4.2 线性注意力的数学推导

核心思想:利用结合律

对输出位置 \(i\),标准注意力可以写成:

\[y_i = \sum_{j=1}^{L} \frac{\exp(q_i \cdot k_j)}{\sum_{j'=1}^{L}\exp(q_i \cdot k_{j'})} v_j\]

关键洞察:如果能找到一个核函数 \(\phi\),使得:

\[\exp(q_i \cdot k_j) \approx \phi(q_i) \cdot \phi(k_j)\]

那么可以利用矩阵乘法的结合律改变计算顺序(先忽略softmax归一化):

\[\text{LinearAttention}(Q, K, V) = \phi(Q) \cdot (\phi(K)^T V)\]

归一化分母 \(\sum_{j}\phi(q_i) \cdot \phi(k_j)\) 同样只需对 \(\phi(K)\) 累加一次,复杂度 \(O(L \cdot d)\)。

  • 先计算 \(\phi(K)^T V\): \(O(L \cdot d^2)\)
  • 再与 \(\phi(Q)\) 相乘: \(O(L \cdot d^2)\)
  • 总计: \(O(L \cdot d^2)\),当 \(d < L\) 时关于 \(L\) 是线性的!
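结合律本身可以直接数值验证(示意代码,\(\phi\) 用非负随机特征代替具体核函数):

```python
import torch

torch.manual_seed(0)
L, d = 64, 8
phi_Q = torch.rand(L, d)   # 视作已经过核函数映射, 元素非负
phi_K = torch.rand(L, d)
V = torch.randn(L, d)

out_quadratic = (phi_Q @ phi_K.T) @ V   # 先算 L×L 注意力矩阵: O(L^2 d)
out_linear = phi_Q @ (phi_K.T @ V)      # 先算 d×d 的 K^T V:   O(L d^2)

print(torch.allclose(out_quadratic, out_linear, atol=1e-4))  # True
```

两种括号顺序结果相同,但中间张量一个是 \(L \times L\),一个只有 \(d \times d\),这正是线性注意力节省的来源。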

4.3 Kernel化注意力

常用的核函数

  1. ReLU核:\(\phi(x) = \text{ReLU}(x)\)

  2. ELU+1核:\(\phi(x) = \text{ELU}(x) + 1\)

  3. 随机特征核(Performer):\(\phi(x) = \frac{1}{\sqrt{m}}[\sin(W_1 x), \cos(W_1 x), \ldots, \sin(W_m x), \cos(W_m x)]\),其中 \(W_i \sim \mathcal{N}(0, I)\)

Python
class LinearAttention(nn.Module):
    """
    线性注意力实现
    """
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = heads * dim_head

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        """
        x: (batch, seq_len, dim)
        """
        batch, seq_len, _ = x.shape

        # QKV投影
        qkv = self.to_qkv(x).reshape(batch, seq_len, 3, self.heads, self.dim_head)
        q, k, v = qkv.unbind(2)  # 各 (batch, seq_len, heads, dim_head)

        # 应用核函数(这里用ELU+1)
        q = F.elu(q) + 1
        k = F.elu(k) + 1

        # 重排维度
        q = q.transpose(1, 2)  # (batch, heads, seq_len, dim_head)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # 线性注意力: Q(K^T V) 而不是 (Q K^T)V
        # K^T V: (batch, heads, dim_head, dim_head)
        kv = torch.einsum('bhnd,bhne->bhde', k, v)

        # Q (K^T V): (batch, heads, seq_len, dim_head)
        qkv = torch.einsum('bhnd,bhde->bhne', q, kv)

        # 归一化
        k_sum = k.sum(dim=2, keepdim=True)  # (batch, heads, 1, dim_head)
        normalizer = torch.einsum('bhnd,bhmd->bhnm', q, k_sum).squeeze(-1)  # (batch, heads, seq_len)
        normalizer = normalizer.unsqueeze(-1) + 1e-6

        out = qkv / normalizer

        # 输出
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.to_out(out)

4.4 复杂度对比总结

| 方法 | 时间复杂度 | 空间复杂度 | 特点 |
|------|-----------|-----------|------|
| 标准Attention | \(O(L^2 d)\) | \(O(L^2)\) | 精确,但长序列慢 |
| 线性Attention | \(O(L d^2)\) | \(O(L d)\) | 近似,长序列快 |
| Mamba/SSM | \(O(L d N)\) | \(O(L d + d N)\) | 线性,状态压缩 |

其中 \(N\) 是SSM的状态维度(通常 \(N \ll L\))。
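以 L=100K、d=4096、N=16 为例,按上表粗算每层主要矩阵运算的量级(仅数量级示意,忽略常数因子):

```python
L, d, N = 100_000, 4096, 16

flops_attn = L * L * d      # 标准Attention: O(L^2 d)
flops_linear = L * d * d    # 线性Attention: O(L d^2)
flops_ssm = L * d * N       # SSM:          O(L d N)

print(f"标准Attention: {flops_attn:.1e}")    # 4.1e+13
print(f"线性Attention: {flops_linear:.1e}")  # 1.7e+12
print(f"SSM:           {flops_ssm:.1e}")     # 6.6e+09
```

在这一设置下,三者相差分别约25倍和250倍,直观解释了为什么超长序列场景偏向线性方法。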


5. 代码实现

5.1 完整的Mamba块实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange

class MambaBlock(nn.Module):
    """
    Mamba块的完整实现

    参考: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
    作者: Gu & Dao (2023)
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        dt_rank: str | int = "auto",
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        dt_init: str = "random",
        dt_scale: float = 1.0,
        dt_init_floor: float = 1e-4,
        conv_bias: bool = True,
        bias: bool = False,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        # 输入投影: d_model -> d_inner * 2
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)

        # 局部卷积
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
        )

        # x投影: d_inner -> dt_rank + d_state * 2
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)

        # dt投影: dt_rank -> d_inner
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # 初始化dt_proj特殊bias
        dt_init_std = self.dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # 初始化dt bias
        dt = torch.exp(
            torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

        # A参数(对数空间)
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True

        # D参数(skip connection)
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.D._no_weight_decay = True

        # 输出投影
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)

    def forward(self, x, cache=None):
        """
        x: (batch, seq_len, d_model)
        返回: (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape

        # 1. 输入投影
        xz = self.in_proj(x)  # (batch, seq_len, d_inner * 2)
        x, z = xz.chunk(2, dim=-1)  # 各 (batch, seq_len, d_inner)

        # 2. 局部卷积
        x = rearrange(x, 'b l d -> b d l')
        x = self.conv1d(x)[:, :, :seq_len]
        x = rearrange(x, 'b d l -> b l d')
        x = F.silu(x)

        # 3. 计算SSM参数
        x_dbl = self.x_proj(x)  # (batch, seq_len, dt_rank + d_state * 2)
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = self.dt_proj(dt)  # (batch, seq_len, d_inner)

        # 4. 选择性扫描
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        y = self.selective_scan(x, dt, A, B, C)

        # 5. 门控和输出
        y = y * F.silu(z)
        output = self.out_proj(y)

        return output

    def selective_scan(self, u, delta, A, B, C):
        """
        选择性扫描(简化版,实际应用需要CUDA优化)
        u: (batch, seq_len, d_inner) 输入
        delta: (batch, seq_len, d_inner) 时间步长
        A: (d_inner, d_state) 状态矩阵
        B: (batch, seq_len, d_state) 输入矩阵
        C: (batch, seq_len, d_state) 输出矩阵
        """
        batch, seq_len, d_inner = u.shape
        d_state = A.shape[1]

        # 确保delta为正
        delta = F.softplus(delta)

        # 离散化
        deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
        deltaB_u = torch.einsum('bld,bln,bld->bldn', delta, B, u)

        # 递归计算(这里用Python循环示意,实际需要并行扫描)
        h = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=u.dtype)
        ys = []

        for i in range(seq_len):
            h = deltaA[:, i] * h + deltaB_u[:, i]
            y = torch.einsum('bdn,bn->bd', h, C[:, i])
            ys.append(y)

        y = torch.stack(ys, dim=1)  # (batch, seq_len, d_inner)
        y = y + u * self.D  # skip connection

        return y


class MambaLayer(nn.Module):
    """
    Mamba层 = Mamba块 + 归一化
    """
    def __init__(self, d_model, **kwargs):
        super().__init__()
        self.mamba = MambaBlock(d_model, **kwargs)
        self.norm = nn.RMSNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.mamba(x))


class MambaModel(nn.Module):
    """
    完整的Mamba语言模型
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_layer: int = 24,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            MambaLayer(d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            for _ in range(n_layer)
        ])
        self.norm_f = nn.RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        """
        input_ids: (batch, seq_len)
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits

5.2 SSM前向传播(详细注释版)

Python
class SSMLayer(nn.Module):
    """
    基础SSM层(非选择性,用于理解原理)
    """
    def __init__(self, d_model, d_state=64, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # HiPPO初始化
        A, B = hippo_init(d_state)
        self.A = nn.Parameter(A)  # (d_state, d_state)
        self.B = nn.Parameter(B)  # (d_state, 1)
        self.C = nn.Parameter(torch.randn(1, d_state) * 0.01)  # (1, d_state)
        self.D = nn.Parameter(torch.zeros(1))  # skip connection

        # 时间步长(可学习)
        self.delta = nn.Parameter(torch.ones(1) * 0.01)

        # 投影层
        self.in_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: (batch, seq_len, d_model)
        每个通道视作独立的标量输入序列,共享同一组SSM参数
        """
        residual = x
        x = self.norm(x)
        x = self.in_proj(x)

        # 离散化(ZOH)
        delta = F.softplus(self.delta)
        I = torch.eye(self.d_state, device=x.device)
        A_bar = torch.matrix_exp(delta * self.A)
        B_bar = torch.linalg.solve(self.A, A_bar - I) @ self.B  # (d_state, 1)

        # SSM递归
        batch, seq_len, d_model = x.shape
        h = torch.zeros(batch, d_model, self.d_state, device=x.device)
        outputs = []

        for t in range(seq_len):
            # h_t = A_bar h_{t-1} + B_bar x_t (逐通道, h: (batch, d_model, d_state))
            h = h @ A_bar.T + x[:, t].unsqueeze(-1) * B_bar.squeeze(-1)
            # y_t = C h_t + D x_t
            y = h @ self.C.squeeze(0) + self.D * x[:, t]
            outputs.append(y)

        y = torch.stack(outputs, dim=1)  # (batch, seq_len, d_model)
        y = self.out_proj(y)
        y = self.dropout(y)

        return residual + y

5.3 与Transformer块的对比

Python
class TransformerBlock(nn.Module):
    """
    标准Transformer块(用于对比)
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # 自注意力
        residual = x
        x = self.norm1(x)
        x, _ = self.attention(x, x, x)
        x = residual + x

        # 前馈网络
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = residual + x

        return x


def compare_complexity():
    """
    复杂度对比实验
    """
    import time

    d_model = 768
    batch_size = 1
    seq_lengths = [128, 512, 1024, 2048, 4096]

    transformer = TransformerBlock(d_model, n_heads=12, d_ff=3072)
    mamba = MambaBlock(d_model)

    print(f"{'Seq Len':<10} {'Transformer (ms)':<20} {'Mamba (ms)':<15} {'Speedup':<10}")
    print("-" * 60)

    for seq_len in seq_lengths:
        x = torch.randn(batch_size, seq_len, d_model)

        # Transformer
        start = time.time()
        with torch.no_grad():
            _ = transformer(x)
        trans_time = (time.time() - start) * 1000

        # Mamba
        start = time.time()
        with torch.no_grad():
            _ = mamba(x)
        mamba_time = (time.time() - start) * 1000

        speedup = trans_time / mamba_time
        print(f"{seq_len:<10} {trans_time:<20.2f} {mamba_time:<15.2f} {speedup:<10.2f}x")


if __name__ == "__main__":
    # 测试Mamba块
    batch, seq_len, d_model = 2, 128, 256
    x = torch.randn(batch, seq_len, d_model)

    mamba = MambaBlock(d_model)
    out = mamba(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {out.shape}")

    # 复杂度对比
    compare_complexity()

6. 应用场景与前沿

6.1 长序列建模优势

Mamba/SSM在以下场景具有显著优势:

1. 长文本生成

  • 问题:Transformer的KV Cache随长度线性增长
  • Mamba优势:固定大小的隐状态,不受序列长度限制
Python
# 内存对比
seq_len = 100000  # 100K tokens
d_model = 4096
n_layers = 32

# Transformer KV Cache: K和V两个张量, fp32每元素4字节
kv_cache_size = 2 * n_layers * seq_len * d_model * 4  # bytes
print(f"Transformer KV Cache: {kv_cache_size / 1e9:.2f} GB")

# Mamba 状态(简化估计: 每层 d_model * d_state; 实际还需乘 expand 因子并加上卷积状态)
d_state = 16
mamba_state_size = n_layers * d_model * d_state * 4
print(f"Mamba State: {mamba_state_size / 1e6:.2f} MB")

2. 基因组学

  • DNA序列可达百万级长度
  • SSM可以高效处理超长序列

3. 音频/语音处理

  • 高采样率音频序列极长
  • 实时推理需要低延迟

4. 时间序列预测

  • 长历史依赖
  • 需要高效推理

6.2 与Transformer混合架构

Jamba架构

Jamba (AI21 Labs, 2024) 将Transformer和Mamba层混合:

Text Only
[Transformer Layer] → [Mamba Layer] → [Mamba Layer] → [Transformer Layer] → ...

优势:

  • Mamba层:高效处理长序列
  • Transformer层:保持强泛化能力

Python
class JambaBlock(nn.Module):
    """
    Jamba混合块:Transformer + Mamba
    """
    def __init__(self, d_model, n_heads, d_ff, d_state=16, layer_type="mamba"):
        super().__init__()
        self.layer_type = layer_type

        if layer_type == "transformer":
            self.layer = TransformerBlock(d_model, n_heads, d_ff)
        else:
            self.layer = MambaLayer(d_model, d_state=d_state)

    def forward(self, x):
        return self.layer(x)


class JambaModel(nn.Module):
    """
    Jamba风格混合模型
    """
    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)

        # 混合层配置:每4层中1个Transformer,3个Mamba
        layers = []
        for i in range(n_layers):
            if i % 4 == 0:
                layers.append(JambaBlock(d_model, n_heads, d_ff, layer_type="transformer"))
            else:
                layers.append(JambaBlock(d_model, n_heads, d_ff, layer_type="mamba"))

        self.layers = nn.ModuleList(layers)
        self.norm = nn.RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.lm_head(x)

6.3 Mamba-2与最新进展

Mamba-2 (2024)

核心改进:

  1. 状态空间对偶性(State Space Duality)
     • 揭示了SSM与半可分离矩阵的联系
     • 允许使用矩阵乘法优化

  2. 并行算法改进
     • 使用块分解矩阵乘法
     • 更好的GPU利用率

  3. 与Attention的统一视角
     • SSM可以看作"线性Attention"的特例
     • 便于混合设计
Python
class Mamba2Block(nn.Module):
    """
    Mamba-2 块(简化示意)
    主要改进:使用块并行算法
    """
    def __init__(self, d_model, d_state=128, block_size=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.block_size = block_size

        # 参数
        self.A = nn.Parameter(torch.randn(d_model, d_state))
        self.B = nn.Parameter(torch.randn(d_model, d_state))
        self.C = nn.Parameter(torch.randn(d_model, d_state))

        self.in_proj = nn.Linear(d_model, d_model * 2)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """
        使用块并行算法
        """
        batch, seq_len, _ = x.shape

        # 投影
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # 块并行SSM(示意)
        n_blocks = (seq_len + self.block_size - 1) // self.block_size
        outputs = []

        for block_idx in range(n_blocks):
            start = block_idx * self.block_size
            end = min(start + self.block_size, seq_len)
            x_block = x[:, start:end, :]

            # 块内并行计算
            # 实际实现使用优化的CUDA kernel
            y_block = self._block_ssm(x_block)
            outputs.append(y_block)

        y = torch.cat(outputs, dim=1)
        y = y * F.silu(z)

        return self.out_proj(y)

    def _block_ssm(self, x_block):
        """块内SSM计算"""
        # 简化实现
        return x_block

6.4 其他SSM变体

| 模型 | 年份 | 主要创新 |
|------|------|---------|
| S4 | 2022 | 结构化SSM,HiPPO初始化 |
| S4D | 2022 | 对角SSM,简化实现 |
| H3 | 2022 | 混合SSM-Attention |
| Mamba | 2023 | 选择性SSM |
| Mamba-2 | 2024 | 块并行,与Attention统一 |
| Jamba | 2024 | Transformer-Mamba混合 |
| MambaMixer | 2024 | 用于视觉任务 |
| VMamba | 2024 | 视觉Mamba |

6.5 实际应用建议

何时选择Mamba/SSM

推荐使用:

  • 序列长度 > 8K tokens
  • 推理延迟敏感场景
  • 内存受限环境
  • 流式处理需求

谨慎使用:

  • 短序列任务(< 1K)
  • 需要双向注意力(如BERT类任务)
  • 生态兼容性要求高

训练技巧

Python
# 1. 学习率设置
# Mamba的A_log参数需要较小学习率
optimizer = torch.optim.AdamW([
    {'params': model.embedding.parameters(), 'lr': 3e-4},
    {'params': model.layers.parameters(), 'lr': 3e-4},
    {'params': [p for n, p in model.named_parameters() if 'A_log' in n], 'lr': 1e-4}
])

# 2. 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 3. 混合精度训练
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()

with autocast():
    output = model(input_ids)
    loss = criterion(output, targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

7. 练习与自我检查

7.1 理论问题

  1. SSM离散化:推导零阶保持(ZOH)离散化公式,解释为什么 \(\bar{\mathbf{A}} = \exp(\Delta \mathbf{A})\)

  2. 复杂度分析
     • 计算标准Attention在序列长度L=100K,d=4096时的内存需求
     • 对比Mamba在相同设置下的内存需求

  3. HiPPO矩阵:解释HiPPO矩阵为什么能够保持长程记忆,其特征值分布有什么特点?

  4. 选择性机制:Mamba中的 \(\Delta_t\) 参数如何影响模型对序列的处理?

7.2 编程练习

Python
# 练习1:实现简化版SSM
class SimpleSSM(nn.Module):
    """
    TODO: 实现一个简化版SSM
    要求:
    1. 使用HiPPO初始化
    2. 实现递归形式的前向传播
    3. 实现卷积形式的前向传播
    4. 验证两种形式输出一致
    """
    pass

# 练习2:实现线性Attention
class LinearAttention(nn.Module):
    """
    TODO: 实现线性Attention
    要求:
    1. 使用ELU+1核函数
    2. 实现O(L*d²)复杂度的前向传播
    3. 与标准Attention对比输出差异
    """
    pass

# 练习3:复杂度实验
def complexity_experiment():
    """
    TODO: 设计实验对比以下模型在不同序列长度下的性能
    1. 标准Transformer
    2. 线性Attention
    3. Mamba

    测量指标:
    - 前向时间
    - 反向时间
    - 显存占用
    """
    pass

7.3 思考题

  1. SSM vs RNN:SSM如何解决RNN的梯度消失问题?两者在数学形式上的关键区别是什么?

  2. SSM vs Transformer
     • 为什么Transformer在短序列上仍然占优?
     • Mamba是否可以完全取代Transformer?为什么?

  3. 选择性机制
     • 如果移除选择性机制(让B、C、Δ固定),Mamba的性能会如何变化?
     • 设计实验验证你的假设

  4. 未来方向
     • 如何将SSM扩展到多模态场景?
     • SSM在强化学习中有哪些潜在应用?

7.4 进阶阅读

必读论文:

  1. Mamba: Linear-Time Sequence Modeling with Selective State Spaces - Gu & Dao, 2023
  2. Efficiently Modeling Long Sequences with Structured State Spaces - Gu et al., 2021 (S4)
  3. Transformers are SSMs: Generalized Models and Efficient Algorithms - Dao & Gu, 2024 (Mamba-2)

推荐资源:

  • The Annotated S4 - S4的详细代码注释
  • Mamba官方代码库
  • State Spaces系列论文列表


总结

本章介绍了Mamba和状态空间模型这一新兴架构:

  1. SSM基础:从控制论的状态空间模型出发,理解连续到离散的转换,以及递归/卷积两种等价形式

  2. HiPPO理论:通过特殊矩阵初始化实现长程记忆,是SSM成功的关键

  3. Mamba创新:选择性机制让SSM具有内容感知能力,硬件感知算法实现高效训练

  4. 线性注意力:通过核函数近似,将O(L²)复杂度降为O(L),与SSM有深刻联系

  5. 实践应用:长序列建模、混合架构、最新进展(Mamba-2、Jamba)

核心要点:

  • Mamba实现了"两全其美":训练时并行,推理时O(1)
  • 选择性机制是Mamba超越传统SSM的关键
  • SSM与线性Attention有统一的数学视角
  • 长序列场景是SSM的主要应用领域


下一章预告:03-视觉Transformer 将介绍Transformer在计算机视觉领域的应用,包括ViT、Swin Transformer等架构。