跳转至

03 - 反向去噪过程

学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 扩散模型的核心学习部分


🎯 学习目标

完成本章后,你将能够: - 理解反向过程的数学定义 - 掌握为什么反向过程也是高斯分布 - 理解神经网络在反向过程中的作用 - 了解反向过程的均值和方差参数化 - 实现反向过程的采样算法


1. 反向过程概述

1.1 什么是反向过程

定义:反向过程是一个学习的马尔可夫链,逐步从噪声中恢复数据。

直观理解

Text Only
x_T (纯噪声)
    ↓ - 预测噪声
x_{T-1} (稍微清晰)
    ↓ - 预测噪声
x_{T-2} (更清晰)
    ...
    ↓ - 预测噪声
x_0 (原始图像)

关键问题: - 如果前向过程是固定的,反向过程可以推导出来吗? - 为什么需要用神经网络学习?

1.2 反向过程的数学定义

目标:学习 \(p_\theta(x_{t-1} | x_t)\),即从 \(x_t\) 恢复 \(x_{t-1}\) 的条件分布。

数学形式: $\(p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)$

其中: - \(\mu_\theta(x_t, t)\):学习的均值函数 - \(\Sigma_\theta(x_t, t)\):学习的方差(或固定) - \(\theta\):神经网络的参数

1.3 为什么反向过程也是高斯分布?

关键定理:当 \(\beta_t\) 足够小时,反向过程 \(q(x_{t-1} | x_t)\) 也是高斯分布。

直观解释: - 前向过程每一步只添加少量噪声 - 因此反向过程每一步只需要去除少量噪声 - 小步长的反向过程可以用高斯分布很好地近似

数学推导(简要): 根据贝叶斯定理: $\(q(x_{t-1} | x_t) = \frac{q(x_t | x_{t-1}) q(x_{t-1})}{q(x_t)}\)$

\(\beta_t \to 0\) 时,这个分布趋近于高斯分布。


2. 后验分布 \(q(x_{t-1} | x_t, x_0)\)

2.1 为什么需要 \(x_0\)

直接计算 \(q(x_{t-1} | x_t)\) 是困难的,因为需要积分掉 \(x_0\)

但如果已知 \(x_0\),我们可以精确计算: $\(q(x_{t-1} | x_t, x_0)\)$

2.2 后验分布的闭合形式

定理: $\(q(x_{t-1} | x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t \mathbf{I})\)$

其中:

\[\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t\]
\[\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t\]

2.3 推导思路

利用贝叶斯定理和三个高斯分布: 1. \(q(x_t | x_{t-1}) = \mathcal{N}(\sqrt{\alpha_t} x_{t-1}, \beta_t \mathbf{I})\) 2. \(q(x_{t-1} | x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}} x_0, (1-\bar{\alpha}_{t-1})\mathbf{I})\) 3. \(q(x_t | x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)\mathbf{I})\)

通过高斯分布的乘积性质,可以得到后验分布。

2.4 代码实现

Python
import numpy as np
import torch

def q_posterior_mean_variance(x_0, x_t, t, alphas_cumprod, betas):
    """
    计算后验分布 q(x_{t-1} | x_t, x_0) 的均值和方差

    参数:
        x_0: 原始数据
        x_t: 当前时刻的数据
        t: 当前时间步
        alphas_cumprod: 累积alpha值
        betas: beta调度

    返回:
        posterior_mean: 后验均值
        posterior_variance: 后验方差
        posterior_log_variance: 后验对数方差(数值稳定性)
    """
    # 获取alpha_bar_{t-1}, alpha_bar_t
    alpha_bar_t = alphas_cumprod[t]
    alpha_bar_t_prev = alphas_cumprod[t-1] if t > 0 else torch.ones_like(alpha_bar_t)

    # 计算beta_t
    beta_t = betas[t]

    # 后验均值: μ̃_t(x_t, x_0)
    # 根据DDPM论文,后验均值系数为:
    # coef_x0 = sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t)
    # coef_xt = sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
    coef_x0 = torch.sqrt(alpha_bar_t_prev) * beta_t / (1 - alpha_bar_t)
    coef_xt = torch.sqrt(1 - betas[t]) * (1 - alpha_bar_t_prev) / (1 - alpha_bar_t)

    posterior_mean = coef_x0 * x_0 + coef_xt * x_t

    # 后验方差: β̃_t
    posterior_variance = beta_t * (1 - alpha_bar_t_prev) / (1 - alpha_bar_t)

    # 边界处理:当 t=0 时,posterior_variance 为 0,torch.log(0) = -inf
    # 解决方案:使用 clamp 确保方差有下界,或对 t=0 特殊处理
    # 方法1:clamp 方式(推荐)
    posterior_variance = torch.clamp(posterior_variance, min=1e-20)
    posterior_log_variance = torch.log(posterior_variance)

    # 方法2:对 t=0 特殊处理(DDPM原始实现)
    # if t == 0:
    #     # t=0 时直接返回 x_0,不需要方差
    #     posterior_log_variance = torch.zeros_like(posterior_variance)
    # else:
    #     posterior_log_variance = torch.log(torch.clamp(posterior_variance, min=1e-20))

    return posterior_mean, posterior_variance, posterior_log_variance

# 简化版本(DDPM中使用)
def predict_x0_from_eps(x_t, eps, t, alphas_cumprod):
    """
    从预测的噪声反推 x_0
    根据: x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
    解得: x_0 = (x_t - √(1-ᾱ_t) * ε) / √ᾱ_t
    """
    alpha_bar_t = alphas_cumprod[t]
    sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)

    x_0_pred = (x_t - sqrt_one_minus_alpha_bar_t * eps) / sqrt_alpha_bar_t
    return x_0_pred

3. 神经网络的参数化

3.1 预测目标的选择

神经网络 \(\epsilon_\theta(x_t, t)\) 可以预测不同的目标:

预测目标 公式 优点 缺点
直接预测 \(x_0\) \(x_0 = f_\theta(x_t, t)\) 直观 数值不稳定
预测噪声 \(\epsilon\) \(\epsilon = \epsilon_\theta(x_t, t)\) 稳定,效果好 需要转换得到 \(x_0\)
预测得分函数 \(\nabla \log p(x_t)\) 理论优雅 实现复杂

DDPM选择预测噪声 \(\epsilon\),因为: 1. 数值稳定性好 2. 与 \(x_t\) 同维度,容易学习 3. 可以直接用于重参数化

3.2 从噪声预测推导均值

已知: $\(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\)$

如果网络预测了噪声 \(\epsilon_\theta\),可以重构 \(x_0\): $\(x_0 \approx \frac{x_t - \sqrt{1-\bar{\alpha}_t} \epsilon_\theta}{\sqrt{\bar{\alpha}_t}}\)$

代入后验均值公式: $\(\mu_\theta(x_t, t) = \tilde{\mu}_t\left(x_t, \frac{x_t - \sqrt{1-\bar{\alpha}_t} \epsilon_\theta}{\sqrt{\bar{\alpha}_t}}\right)\)$

简化后: $\(\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right)\)$

3.3 方差的选择

方案1:固定方差 $\(\Sigma_\theta(x_t, t) = \beta_t \mathbf{I}\)$ 或 $\(\Sigma_\theta(x_t, t) = \tilde{\beta}_t \mathbf{I}\)$

方案2:学习方差(更复杂,收益有限)

DDPM采用固定方差,发现效果已经很好。


4. 反向采样算法

4.1 算法流程

Text Only
算法: 反向采样 (Reverse Sampling)
─────────────────────────────────
输入: 训练好的噪声预测网络 ε_θ
输出: 生成的样本 x_0

1. 从标准高斯分布采样: x_T ~ N(0, I)

2. 对于 t = T, T-1, ..., 1:

   a. 预测噪声: ε_θ(x_t, t)

   b. 计算均值: μ_θ = (x_t - β_t/√(1-ᾱ_t) * ε_θ) / √α_t

   c. 如果 t > 1:
         z ~ N(0, I)
      否则:
         z = 0

   d. 采样: x_{t-1} = μ_θ + √β_t * z

3. 返回 x_0

4.2 代码实现

Python
import torch
import torch.nn as nn
import numpy as np

class ReverseDiffusion:
    """
    反向扩散过程的实现
    """
    def __init__(self, model, timesteps=1000, beta_schedule='linear'):
        """
        参数:
            model: 噪声预测网络 ε_θ
            timesteps: 扩散步数
            beta_schedule: 噪声调度策略
        """
        self.model = model
        self.timesteps = timesteps

        # 设置beta调度
        if beta_schedule == 'linear':
            self.betas = self._linear_beta_schedule(timesteps)
        else:
            raise NotImplementedError()

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

        # 转换为torch张量
        self.betas = torch.from_numpy(self.betas).float()
        self.alphas = torch.from_numpy(self.alphas).float()
        self.alphas_cumprod = torch.from_numpy(self.alphas_cumprod).float()
        self.alphas_cumprod_prev = torch.from_numpy(self.alphas_cumprod_prev).float()

        # 预计算一些值
        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # 后验方差
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def _linear_beta_schedule(self, timesteps, beta_start=0.0001, beta_end=0.02):
        return np.linspace(beta_start, beta_end, timesteps)

    def p_mean_variance(self, x_t, t):
        """
        计算反向过程 p(x_{t-1} | x_t) 的均值和方差

        参数:
            x_t: 当前时刻的数据 [B, C, H, W]
            t: 时间步 [B]

        返回:
            mean: 均值
            variance: 方差
            log_variance: 对数方差
        """
        # 预测噪声
        eps_pred = self.model(x_t, t)

        # 计算均值: μ_θ = (x_t - β_t/√(1-ᾱ_t) * ε_θ) / √α_t
        beta_t = self._extract(self.betas, t, x_t.shape)
        sqrt_alpha_t = self._extract(self.sqrt_alphas, t, x_t.shape)
        sqrt_one_minus_alpha_cumprod_t = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_t.shape
        )

        mean = (x_t - beta_t / sqrt_one_minus_alpha_cumprod_t * eps_pred) / sqrt_alpha_t

        # 方差
        variance = self._extract(self.posterior_variance, t, x_t.shape)
        log_variance = torch.log(variance)

        return mean, variance, log_variance

    def p_sample(self, x_t, t):
        """
        从 p(x_{t-1} | x_t) 采样
        """
        mean, variance, log_variance = self.p_mean_variance(x_t, t)

        # 只在 t > 0 时添加噪声
        noise = torch.randn_like(x_t)
        nonzero_mask = (t != 0).float().view(-1, 1, 1, 1)  # 重塑张量形状

        x_prev = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        return x_prev

    def sample(self, shape, device='cpu'):
        """
        生成样本

        参数:
            shape: 输出形状 (B, C, H, W)
            device: 计算设备

        返回:
            x_0: 生成的样本
        """
        self.model.eval()  # eval()评估模式

        # 从纯噪声开始
        x_t = torch.randn(shape, device=device)

        # 反向迭代
        for t in reversed(range(self.timesteps)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            x_t = self.p_sample(x_t, t_batch)

        return x_t

    def _extract(self, a, t, x_shape):
        """提取对应时间步的值"""
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        return out.view(batch_size, *((1,) * (len(x_shape) - 1)))

# 简单的噪声预测网络示例
class SimpleUNet(nn.Module):  # 继承nn.Module定义网络层
    """简化的UNet用于噪声预测"""
    def __init__(self, in_channels=3, model_channels=64, time_emb_dim=128):
        super().__init__()  # super()调用父类方法
        # 时间嵌入
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # 简化的编码器-解码器结构
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, model_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(model_channels, model_channels, 3, padding=1),
            nn.ReLU(),
        )

        self.middle = nn.Sequential(
            nn.Conv2d(model_channels, model_channels, 3, padding=1),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(model_channels, model_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(model_channels, in_channels, 3, padding=1),
        )

    def forward(self, x, t):
        # 简化的前向传播
        h = self.encoder(x)
        h = self.middle(h)
        out = self.decoder(h)
        return out

5. 可视化反向过程

Python
import matplotlib.pyplot as plt

def visualize_reverse_process(model, diffusion, shape, save_path='reverse_process.png'):
    """
    可视化反向扩散过程
    """
    model.eval()

    # 从纯噪声开始
    x = torch.randn(shape)

    # 存储中间结果
    intermediates = [x.clone()]

    # 反向迭代,每隔一定步数保存
    save_steps = [999, 900, 800, 600, 400, 200, 100, 50, 0]

    with torch.no_grad():  # 禁用梯度计算,节省内存
        for t in reversed(range(diffusion.timesteps)):
            t_batch = torch.full((shape[0],), t, dtype=torch.long)
            x = diffusion.p_sample(x, t_batch)

            if t in save_steps:
                intermediates.append(x.clone())

    # 可视化
    n_show = len(intermediates)
    fig, axes = plt.subplots(1, n_show, figsize=(3*n_show, 3))

    for i, (ax, img) in enumerate(zip(axes, intermediates)):  # enumerate同时获取索引和元素  # zip按位置配对
        # 转换为可显示的格式
        img_np = img[0].permute(1, 2, 0).cpu().numpy()
        img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())

        ax.imshow(img_np)
        step = save_steps[i] if i < len(save_steps) else 0
        ax.set_title(f't={step}')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.show()

# 使用示例(需要训练好的模型)
# model = SimpleUNet()
# diffusion = ReverseDiffusion(model)
# visualize_reverse_process(model, diffusion, shape=(1, 3, 32, 32))

6. 本章总结

核心概念

  1. 反向过程
  2. 学习的马尔可夫链
  3. 从噪声逐步恢复数据
  4. 也是高斯分布(当步长足够小)

  5. 后验分布

  6. \(q(x_{t-1} | x_t, x_0)\) 有闭合形式
  7. 均值依赖于 \(x_0\)\(x_t\)
  8. 方差由前向过程决定

  9. 神经网络参数化

  10. 预测噪声 \(\epsilon\) 效果最好
  11. 从噪声预测推导均值
  12. 方差通常固定

关键公式

概念 公式
反向过程 \(p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(\mu_\theta, \Sigma_\theta)\)
后验均值 \(\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t\)
预测均值 \(\mu_\theta = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta)\)
采样 \(x_{t-1} = \mu_\theta + \sqrt{\beta_t} \cdot z\)

算法流程

Python
def sample(model, timesteps):
    x = randn(shape)  # 从噪声开始
    for t in reversed(range(timesteps)):
        eps = model(x, t)  # 预测噪声
        x = denoise(x, eps, t)  # 去噪
    return x  # 返回生成结果

📝 自测问题

基础问题

  1. 反向过程的特性
  2. 反向过程是固定的还是学习的?为什么?
  3. 为什么反向过程也是高斯分布?
  4. 神经网络在反向过程中起什么作用?

  5. 后验分布

  6. 解释 \(q(x_{t-1} | x_t, x_0)\) 的含义
  7. 为什么需要已知 \(x_0\) 才能计算后验?
  8. 后验均值公式中各项的含义是什么?

  9. 参数化选择

  10. 为什么DDPM选择预测噪声而不是直接预测 \(x_0\)
  11. 如何从噪声预测推导均值?
  12. 方差为什么要固定?

编程练习

  1. 实现完整的反向采样算法
  2. 可视化不同时间步的去噪效果
  3. 比较不同方差选择的效果

思考题

  1. 如果前向过程的步长很大,反向过程还能用高斯近似吗?
  2. 为什么采样时只在 \(t > 0\) 时添加噪声?
  3. 如何改进反向过程以提高采样速度?

🔗 下一步

理解了反向过程后,我们将学习训练目标的推导,了解如何训练噪声预测网络。

→ 下一步:04-训练目标推导.md