跳转至

07 - 连续时间扩散模型

学习时间: 5小时 重要性: ⭐⭐⭐⭐⭐ 从离散到连续,理解扩散模型的统一框架


🎯 学习目标

完成本章后,你将能够: - 理解从离散时间到连续时间的过渡 - 掌握VP-SDE、VE-SDE和子VP-SDE的区别 - 理解得分随机微分方程(Score SDE) - 掌握概率流ODE和似然计算 - 实现连续时间扩散模型


1. 从离散到连续

1.1 离散时间扩散的局限

回顾DDPM的离散时间前向过程: $\(x_{t} = \sqrt{1-\beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon_{t-1}\)$

问题: - 时间步数 \(T\) 是离散的(通常1000步) - 步长 \(\beta_t\) 需要精心设计 - 难以分析连续时间极限

1.2 连续时间极限

令时间步长 \(\Delta t \to 0\),离散过程收敛到随机微分方程(SDE)

\[dx = f(x, t) dt + g(t) dw\]

其中: - \(f(x, t)\):漂移系数 - \(g(t)\):扩散系数 - \(w\):维纳过程(标准布朗运动)

1.3 连续时间的优势

  1. 数学优雅:可以使用随机分析工具
  2. 统一框架:涵盖多种扩散模型
  3. 灵活采样:任意时间步长
  4. 精确似然:通过概率流ODE计算

2. 三种经典SDE

2.1 VP-SDE(Variance Preserving SDE)

动机:保持方差有界,避免爆炸或消失

SDE形式: $\(dx = -\frac{1}{2}\beta(t) x dt + \sqrt{\beta(t)} dw\)$

扰动核: $\(p(x(t) | x(0)) = \mathcal{N}(x(t); e^{-\frac{1}{2}\int_0^t \beta(s)ds} x(0), (1 - e^{-\int_0^t \beta(s)ds})I)\)$

特性: - 当 \(t \to \infty\)\(x(t) \sim \mathcal{N}(0, I)\) - 方差保持在 \([0, 1]\) 范围内 - 对应于DDPM的连续版本

2.2 VE-SDE(Variance Exploding SDE)

动机:允许方差随时间增长

SDE形式: $\(dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} dw\)$

扰动核: $\(p(x(t) | x(0)) = \mathcal{N}(x(t); x(0), (\sigma^2(t) - \sigma^2(0))I)\)$

特性: - 方差随时间增长(exploding) - 需要选择适当的 \(\sigma(t)\) 调度 - 对应于SMLD(Score Matching with Langevin Dynamics)

2.3 子VP-SDE(Sub-Variance Preserving SDE)

动机:VP-SDE的改进版本,更好的数值稳定性

SDE形式: $\(dx = -\frac{1}{2}\beta(t) x dt + \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s)ds})} dw\)$

特性: - 介于VP-SDE和VE-SDE之间 - 在Score SDE论文中表现最好 - 更好的数值稳定性

2.4 三种SDE对比

特性 VP-SDE VE-SDE 子VP-SDE
漂移
方差 保持 爆炸 次保持
稳定性 需调参 最好
对应离散模型 DDPM SMLD 改进版

3. 得分随机微分方程(Score SDE)

3.1 统一框架

一般形式的前向SDE: $\(dx = f(x, t) dt + g(t) dw\)$

关键洞察:通过Anderson定理,可以推导出对应的逆向SDE

\[dx = [f(x, t) - g^2(t) \nabla_x \log p_t(x)] dt + g(t) d\bar{w}\]

其中: - \(\nabla_x \log p_t(x)\)得分函数(score function) - \(d\bar{w}\) 是反向时间的维纳过程

3.2 得分函数估计

在实际中,我们用神经网络 \(s_\theta(x, t)\) 来近似得分函数:

\[s_\theta(x, t) \approx \nabla_x \log p_t(x)\]

训练目标(去噪得分匹配): $\(\mathcal{L}(\theta) = \mathbb{E}_{t, x(0), x(t)} \left[ \| s_\theta(x(t), t) - \nabla_{x(t)} \log p(x(t) | x(0)) \|^2 \right]\)$

为什么去噪得分匹配等价于显式得分匹配?

显式得分匹配(ESM)的目标是最小化 \(\mathbb{E}_{p_t(x)}[\|s_\theta(x,t) - \nabla_x \log p_t(x)\|^2]\),但 \(\nabla_x \log p_t(x)\) 未知。关键等式为:

\[\mathbb{E}_{p_t(x)}\left[\|s_\theta - \nabla_x \log p_t\|^2\right] = \mathbb{E}_{p_0(x_0)p_{t|0}(x_t|x_0)}\left[\|s_\theta(x_t,t) - \nabla_{x_t} \log p_{t|0}(x_t|x_0)\|^2\right] + C\]

其中 \(C\) 不依赖 \(\theta\)。证明思路:将 \(\nabla_x \log p_t(x)\) 展开为 \(\nabla_x \log \int p_{t|0}(x|x_0)p_0(x_0)dx_0\),利用 \(\nabla_x \log p_t(x) = \mathbb{E}_{p_0(x_0|x)}[\nabla_x \log p_{t|0}(x|x_0)]\)(后验加权平均),然后展开平方范数并利用该恒等式消去交叉项中的未知量。最终两个目标关于 \(\theta\) 的梯度相同(Vincent, 2011)。

因此,我们可以用已知的条件得分 \(\nabla_{x_t} \log p(x_t|x_0)\) 替代未知的边缘得分进行训练。对于高斯扰动核 \(p(x_t|x_0) = \mathcal{N}(\alpha_t x_0, \sigma_t^2 I)\),条件得分为 \(-\frac{x_t - \alpha_t x_0}{\sigma_t^2}\)

3.3 与DDPM的联系

对于VP-SDE,得分函数与DDPM的噪声预测的关系:

\[s_\theta(x, t) = -\frac{\epsilon_\theta(x, t)}{\sqrt{1 - \alpha(t)}}\]

其中 \(\alpha(t) = e^{-\int_0^t \beta(s)ds}\)


4. 概率流ODE

4.1 确定性生成

定理:存在一个常微分方程(ODE),其边缘分布与SDE相同:

\[\frac{dx}{dt} = f(x, t) - \frac{1}{2}g^2(t) s_\theta(x, t)\]

这就是概率流ODE(Probability Flow ODE)

4.2 ODE vs SDE

特性 SDE ODE
随机性
可逆性 概率可逆 确定性可逆
似然计算 困难 容易(用流模型技术)
采样速度

4.3 精确似然计算

通过概率流ODE,可以使用瞬时变化率公式计算精确似然:

\[\log p(x(0)) = \log p(x(T)) + \int_0^T \nabla \cdot \tilde{f}(x(t), t) dt\]

其中 \(\tilde{f}(x, t) = f(x, t) - \frac{1}{2}g^2(t) s_\theta(x, t)\)

散度计算: - 精确方法:自动微分(计算量大) - 近似方法:Skilling-Hutchinson迹估计


5. 数值方法

5.1 SDE求解器

Euler-Maruyama方法: $\(x_{t+\Delta t} = x_t + f(x_t, t)\Delta t + g(t) \sqrt{\Delta t} z_t, \quad z_t \sim \mathcal{N}(0, I)\)$

Milstein方法(更高精度): 增加二阶修正项

预测-校正方法: 1. 预测步:用Euler-Maruyama得到 \(\tilde{x}_{t+\Delta t}\) 2. 校正步:用得分函数修正

5.2 ODE求解器

Runge-Kutta方法: - RK45:自适应步长 - RK4:固定步长

Adams-Bashforth方法: 多步法,利用历史信息

5.3 步长选择

自适应步长: 根据局部误差估计调整步长

固定步长: 均匀离散化,简单易实现


6. 实现

Python
import torch
import torch.nn as nn
import numpy as np
from scipy import integrate

class ContinuousDiffusion:
    """
    连续时间扩散模型
    """
    def __init__(self, sde_type='vp', beta_min=0.1, beta_max=20.0):
        """
        参数:
            sde_type: 'vp', 've', 或 'subvp'
            beta_min, beta_max: beta调度参数
        """
        self.sde_type = sde_type
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t):
        """beta(t)调度"""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def drift(self, x, t):
        """漂移系数 f(x, t)"""
        if self.sde_type == 'vp':
            return -0.5 * self.beta(t) * x
        elif self.sde_type == 've':
            return torch.zeros_like(x)
        elif self.sde_type == 'subvp':
            return -0.5 * self.beta(t) * x
        else:
            raise ValueError(f"Unknown SDE type: {self.sde_type}")

    def diffusion(self, t):
        """扩散系数 g(t)"""
        if self.sde_type == 'vp':
            return np.sqrt(self.beta(t))
        elif self.sde_type == 've':
            sigma = self.sigma(t)
            # g(t) = σ(t)√(2 ln(σ_max/σ_min)),由 d[σ²(t)]/dt 推导
            return sigma * np.sqrt(2 * np.log(self.beta_max / self.beta_min))
        elif self.sde_type == 'subvp':
            alpha = np.exp(-0.5 * self.beta_min * t - 0.25 * (self.beta_max - self.beta_min) * t**2)
            # g(t) = √(β(t)(1 - e^{-2∫β})),其中 e^{-2∫β} = α⁴
            return np.sqrt(self.beta(t) * (1 - alpha**4))
        else:
            raise ValueError(f"Unknown SDE type: {self.sde_type}")

    def sigma(self, t):
        """VE-SDE的sigma(t)"""
        return self.beta_min * (self.beta_max / self.beta_min) ** t

    def marginal_prob(self, x0, t):
        """
        计算边缘分布 p(x(t) | x(0))

        返回:
            mean: 均值
            std: 标准差
        """
        if self.sde_type == 'vp':
            alpha = np.exp(-0.5 * self.beta_min * t - 0.25 * (self.beta_max - self.beta_min) * t**2)
            mean = alpha * x0
            std = np.sqrt(1 - alpha**2)
        elif self.sde_type == 've':
            mean = x0
            # 方差为 σ²(t) - σ²(0),标准差取平方根
            std = np.sqrt(self.sigma(t)**2 - self.sigma(0)**2)
        elif self.sde_type == 'subvp':
            alpha = np.exp(-0.5 * self.beta_min * t - 0.25 * (self.beta_max - self.beta_min) * t**2)
            mean = alpha * x0
            # sub-VP-SDE: 方差 = (1-α²)²,标准差 = |1-α²| = 1-α² (因为α∈[0,1])
            # 注意:这里std是标准差,直接等于(1-α²)
            std = 1 - alpha**2
        else:
            raise ValueError(f"Unknown SDE type: {self.sde_type}")

        return mean, std

    def prior_sampling(self, shape):
        """从先验分布采样"""
        if self.sde_type == 'vp' or self.sde_type == 'subvp':
            return torch.randn(*shape)
        elif self.sde_type == 've':
            return torch.randn(*shape) * self.sigma(1.0)

    def forward_sde(self, x0, t, noise=None):
        """
        前向SDE: x(t) = mean + std * noise
        """
        mean, std = self.marginal_prob(x0, t)
        if noise is None:
            noise = torch.randn_like(x0)
        xt = mean + std * noise
        return xt, noise

    def reverse_sde(self, score_model, xt, t, dt, noise=None):
        """
        逆向SDE采样一步

        参数:
            score_model: 得分模型 s_θ(x, t)
            xt: 当前状态
            t: 当前时间
            dt: 时间步长(负值表示反向)
            noise: 可选的噪声
        """
        if noise is None:
            noise = torch.randn_like(xt)

        # 计算得分
        score = score_model(xt, t)

        # 漂移和扩散
        f = self.drift(xt, t)
        g = self.diffusion(t)

        # 逆向SDE: dx = (f - g² * score)dt + g * dw
        drift = f - g**2 * score
        diffusion = g

        # Euler-Maruyama
        x_prev = xt + drift * dt + diffusion * np.sqrt(-dt) * noise

        return x_prev

    def probability_flow_ode(self, score_model, xt, t, dt):
        """
        概率流ODE采样一步

        参数:
            score_model: 得分模型
            xt: 当前状态
            t: 当前时间
            dt: 时间步长(负值表示反向)
        """
        # 计算得分
        score = score_model(xt, t)

        # 漂移和扩散
        f = self.drift(xt, t)
        g = self.diffusion(t)

        # 概率流ODE: dx = (f - 0.5 * g² * score)dt
        dx = (f - 0.5 * g**2 * score) * dt

        x_prev = xt + dx

        return x_prev

    def sample_sde(self, score_model, shape, device='cuda', num_steps=1000, eps=1e-3):
        """
        使用SDE采样
        """
        score_model.eval()  # eval()评估模式

        # 从先验采样
        x = self.prior_sampling(shape).to(device)  # 移至GPU/CPU

        # 时间网格
        timesteps = torch.linspace(1.0, eps, num_steps)
        dt = -(1.0 - eps) / num_steps

        # 逆向采样
        with torch.no_grad():  # 禁用梯度计算,节省内存
            for i in range(num_steps):
                t = timesteps[i]
                t_batch = torch.ones(shape[0], device=device) * t
                x = self.reverse_sde(score_model, x, t_batch, dt)

        return x

    def sample_ode(self, score_model, shape, device='cuda', rtol=1e-5, atol=1e-5, method='RK45'):
        """
        使用ODE采样(自适应步长)
        """
        score_model.eval()

        # 初始条件
        x = self.prior_sampling(shape).to(device)

        # 定义ODE
        def ode_func(t, x):
            x = torch.tensor(x, device=device, dtype=torch.float32).reshape(shape)  # 重塑张量形状
            t = torch.ones(shape[0], device=device) * t

            with torch.no_grad():
                score = score_model(x, t)
                f = self.drift(x, t.item())  # 将单元素张量转为Python数值
                g = self.diffusion(t.item())
                dx = (f - 0.5 * g**2 * score).cpu().numpy().flatten()

            return dx

        # 求解ODE
        solution = integrate.solve_ivp(
            ode_func,
            [1.0, 1e-3],
            x.cpu().numpy().flatten(),
            rtol=rtol,
            atol=atol,
            method=method,
        )

        x_final = torch.tensor(solution.y[:, -1], device=device, dtype=torch.float32).reshape(shape)

        return x_final

# 得分模型示例
class ScoreNet(nn.Module):  # 继承nn.Module定义网络层
    """简化的得分网络"""
    def __init__(self, channels=3, time_emb_dim=256):
        super().__init__()  # super()调用父类方法
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # 简化的UNet结构
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, 64, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.SiLU(),
        )

        self.middle = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.SiLU(),
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, channels, 3, padding=1),
        )

    def forward(self, x, t):
        # 时间嵌入
        if t.dim() == 0:
            t = t.unsqueeze(0).expand(x.shape[0])  # unsqueeze增加一个维度
        t_emb = self.time_embed(t.view(-1, 1).float())

        # 前向传播
        h = self.encoder(x)
        h = self.middle(h)
        h = self.decoder(h)

        return h

# 使用示例
if __name__ == "__main__":
    # 创建扩散模型
    diffusion = ContinuousDiffusion(sde_type='vp', beta_min=0.1, beta_max=20.0)

    # 创建得分模型
    score_model = ScoreNet(channels=3)

    # 测试前向过程
    x0 = torch.randn(4, 3, 32, 32)
    t = torch.tensor([0.5, 0.5, 0.5, 0.5])
    xt, noise = diffusion.forward_sde(x0, t)

    print(f"x0 shape: {x0.shape}")
    print(f"xt shape: {xt.shape}")
    print(f"noise shape: {noise.shape}")

    # 测试采样(需要训练好的模型)
    # samples = diffusion.sample_sde(score_model, shape=(4, 3, 32, 32))
    # print(f"Samples shape: {samples.shape}")

7. 本章总结

核心概念

  1. 连续时间扩散
  2. 从离散SDE到连续SDE
  3. 数学更优雅,分析更方便
  4. 统一框架

  5. 三种SDE

  6. VP-SDE:方差保持(DDPM)
  7. VE-SDE:方差爆炸(SMLD)
  8. 子VP-SDE:改进版本

  9. Score SDE

  10. 用神经网络估计得分函数
  11. 逆向SDE进行采样
  12. 统一训练目标

  13. 概率流ODE

  14. 确定性采样
  15. 精确似然计算
  16. 更快的采样

关键公式

概念 公式
VP-SDE \(dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)} dw\)
VE-SDE \(dx = \sqrt{d[\sigma^2(t)]/dt} dw\)
逆向SDE \(dx = (f - g^2 \nabla \log p)dt + g d\bar{w}\)
概率流ODE \(\frac{dx}{dt} = f - \frac{1}{2}g^2 \nabla \log p\)

实现要点

Python
# 核心采样循环
def sample_sde(score_model, shape):
    x = prior_sampling(shape)
    for t in reversed(timesteps):
        score = score_model(x, t)
        x = reverse_sde_step(x, t, score)
    return x

📝 自测问题

基础问题

  1. 连续时间优势
  2. 为什么需要连续时间扩散模型?
  3. 离散和连续的主要区别?
  4. 连续时间的数学工具?

  5. 三种SDE

  6. VP-SDE、VE-SDE、子VP-SDE的区别?
  7. 各自的优势和适用场景?
  8. 如何选择合适的SDE?

  9. Score SDE

  10. 什么是得分函数?
  11. 如何训练得分模型?
  12. 与DDPM的关系?

  13. 概率流ODE

  14. ODE与SDE的区别?
  15. 如何计算精确似然?
  16. 数值求解方法?

编程练习

  1. 实现三种SDE的前向过程
  2. 实现逆向SDE采样
  3. 实现概率流ODE采样
  4. 比较SDE和ODE的采样质量

思考题

  1. 连续时间模型的计算挑战?
  2. 如何设计更好的SDE?
  3. 概率流ODE的局限性?

🔗 下一步

理解了连续时间扩散模型后,我们将进入扩散模型变体与进阶,学习DDIM加速采样等高级技术。

→ 下一步:04-扩散模型变体与进阶/01-DDIM加速采样.md