03 - 反向去噪过程¶
学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 扩散模型的核心学习部分
🎯 学习目标¶
完成本章后,你将能够: - 理解反向过程的数学定义 - 掌握为什么反向过程也是高斯分布 - 理解神经网络在反向过程中的作用 - 了解反向过程的均值和方差参数化 - 实现反向过程的采样算法
1. 反向过程概述¶
1.1 什么是反向过程¶
定义:反向过程是一个学习的马尔可夫链,逐步从噪声中恢复数据。
直观理解:
关键问题: - 如果前向过程是固定的,反向过程可以推导出来吗? - 为什么需要用神经网络学习?
1.2 反向过程的数学定义¶
目标:学习 \(p_\theta(x_{t-1} | x_t)\),即从 \(x_t\) 恢复 \(x_{t-1}\) 的条件分布。
数学形式: $\(p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)$
其中: - \(\mu_\theta(x_t, t)\):学习的均值函数 - \(\Sigma_\theta(x_t, t)\):学习的方差(或固定) - \(\theta\):神经网络的参数
1.3 为什么反向过程也是高斯分布?¶
关键定理:当 \(\beta_t\) 足够小时,反向过程 \(q(x_{t-1} | x_t)\) 也是高斯分布。
直观解释: - 前向过程每一步只添加少量噪声 - 因此反向过程每一步只需要去除少量噪声 - 小步长的反向过程可以用高斯分布很好地近似
数学推导(简要): 根据贝叶斯定理: $\(q(x_{t-1} | x_t) = \frac{q(x_t | x_{t-1}) q(x_{t-1})}{q(x_t)}\)$
当 \(\beta_t \to 0\) 时,这个分布趋近于高斯分布。
2. 后验分布 \(q(x_{t-1} | x_t, x_0)\)¶
2.1 为什么需要 \(x_0\)¶
直接计算 \(q(x_{t-1} | x_t)\) 是困难的,因为需要积分掉 \(x_0\)。
但如果已知 \(x_0\),我们可以精确计算: $\(q(x_{t-1} | x_t, x_0)\)$
2.2 后验分布的闭合形式¶
定理: $\(q(x_{t-1} | x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t \mathbf{I})\)$
其中:
2.3 推导思路¶
利用贝叶斯定理和三个高斯分布: 1. \(q(x_t | x_{t-1}) = \mathcal{N}(\sqrt{\alpha_t} x_{t-1}, \beta_t \mathbf{I})\) 2. \(q(x_{t-1} | x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}} x_0, (1-\bar{\alpha}_{t-1})\mathbf{I})\) 3. \(q(x_t | x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)\mathbf{I})\)
通过高斯分布的乘积性质,可以得到后验分布。
2.4 代码实现¶
import numpy as np
import torch
def q_posterior_mean_variance(x_0, x_t, t, alphas_cumprod, betas):
"""
计算后验分布 q(x_{t-1} | x_t, x_0) 的均值和方差
参数:
x_0: 原始数据
x_t: 当前时刻的数据
t: 当前时间步
alphas_cumprod: 累积alpha值
betas: beta调度
返回:
posterior_mean: 后验均值
posterior_variance: 后验方差
posterior_log_variance: 后验对数方差(数值稳定性)
"""
# 获取alpha_bar_{t-1}, alpha_bar_t
alpha_bar_t = alphas_cumprod[t]
alpha_bar_t_prev = alphas_cumprod[t-1] if t > 0 else torch.ones_like(alpha_bar_t)
# 计算beta_t
beta_t = betas[t]
# 后验均值: μ̃_t(x_t, x_0)
# 根据DDPM论文,后验均值系数为:
# coef_x0 = sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t)
# coef_xt = sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
coef_x0 = torch.sqrt(alpha_bar_t_prev) * beta_t / (1 - alpha_bar_t)
coef_xt = torch.sqrt(1 - betas[t]) * (1 - alpha_bar_t_prev) / (1 - alpha_bar_t)
posterior_mean = coef_x0 * x_0 + coef_xt * x_t
# 后验方差: β̃_t
posterior_variance = beta_t * (1 - alpha_bar_t_prev) / (1 - alpha_bar_t)
# 边界处理:当 t=0 时,posterior_variance 为 0,torch.log(0) = -inf
# 解决方案:使用 clamp 确保方差有下界,或对 t=0 特殊处理
# 方法1:clamp 方式(推荐)
posterior_variance = torch.clamp(posterior_variance, min=1e-20)
posterior_log_variance = torch.log(posterior_variance)
# 方法2:对 t=0 特殊处理(DDPM原始实现)
# if t == 0:
# # t=0 时直接返回 x_0,不需要方差
# posterior_log_variance = torch.zeros_like(posterior_variance)
# else:
# posterior_log_variance = torch.log(torch.clamp(posterior_variance, min=1e-20))
return posterior_mean, posterior_variance, posterior_log_variance
# 简化版本(DDPM中使用)
def predict_x0_from_eps(x_t, eps, t, alphas_cumprod):
"""
从预测的噪声反推 x_0
根据: x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
解得: x_0 = (x_t - √(1-ᾱ_t) * ε) / √ᾱ_t
"""
alpha_bar_t = alphas_cumprod[t]
sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)
sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)
x_0_pred = (x_t - sqrt_one_minus_alpha_bar_t * eps) / sqrt_alpha_bar_t
return x_0_pred
3. 神经网络的参数化¶
3.1 预测目标的选择¶
神经网络 \(\epsilon_\theta(x_t, t)\) 可以预测不同的目标:
| 预测目标 | 公式 | 优点 | 缺点 |
|---|---|---|---|
| 直接预测 \(x_0\) | \(x_0 = f_\theta(x_t, t)\) | 直观 | 数值不稳定 |
| 预测噪声 \(\epsilon\) | \(\epsilon = \epsilon_\theta(x_t, t)\) | 稳定,效果好 | 需要转换得到 \(x_0\) |
| 预测得分函数 | \(\nabla \log p(x_t)\) | 理论优雅 | 实现复杂 |
DDPM选择预测噪声 \(\epsilon\),因为: 1. 数值稳定性好 2. 与 \(x_t\) 同维度,容易学习 3. 可以直接用于重参数化
3.2 从噪声预测推导均值¶
已知: $\(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\)$
如果网络预测了噪声 \(\epsilon_\theta\),可以重构 \(x_0\): $\(x_0 \approx \frac{x_t - \sqrt{1-\bar{\alpha}_t} \epsilon_\theta}{\sqrt{\bar{\alpha}_t}}\)$
代入后验均值公式: $\(\mu_\theta(x_t, t) = \tilde{\mu}_t\left(x_t, \frac{x_t - \sqrt{1-\bar{\alpha}_t} \epsilon_\theta}{\sqrt{\bar{\alpha}_t}}\right)\)$
简化后: $\(\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right)\)$
3.3 方差的选择¶
方案1:固定方差 $\(\Sigma_\theta(x_t, t) = \beta_t \mathbf{I}\)$ 或 $\(\Sigma_\theta(x_t, t) = \tilde{\beta}_t \mathbf{I}\)$
方案2:学习方差(更复杂,收益有限)
DDPM采用固定方差,发现效果已经很好。
4. 反向采样算法¶
4.1 算法流程¶
算法: 反向采样 (Reverse Sampling)
─────────────────────────────────
输入: 训练好的噪声预测网络 ε_θ
输出: 生成的样本 x_0
1. 从标准高斯分布采样: x_T ~ N(0, I)
2. 对于 t = T, T-1, ..., 1:
a. 预测噪声: ε_θ(x_t, t)
b. 计算均值: μ_θ = (x_t - β_t/√(1-ᾱ_t) * ε_θ) / √α_t
c. 如果 t > 1:
z ~ N(0, I)
否则:
z = 0
d. 采样: x_{t-1} = μ_θ + √β_t * z
3. 返回 x_0
4.2 代码实现¶
import torch
import torch.nn as nn
import numpy as np
class ReverseDiffusion:
"""
反向扩散过程的实现
"""
def __init__(self, model, timesteps=1000, beta_schedule='linear'):
"""
参数:
model: 噪声预测网络 ε_θ
timesteps: 扩散步数
beta_schedule: 噪声调度策略
"""
self.model = model
self.timesteps = timesteps
# 设置beta调度
if beta_schedule == 'linear':
self.betas = self._linear_beta_schedule(timesteps)
else:
raise NotImplementedError()
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
# 转换为torch张量
self.betas = torch.from_numpy(self.betas).float()
self.alphas = torch.from_numpy(self.alphas).float()
self.alphas_cumprod = torch.from_numpy(self.alphas_cumprod).float()
self.alphas_cumprod_prev = torch.from_numpy(self.alphas_cumprod_prev).float()
# 预计算一些值
self.sqrt_alphas = torch.sqrt(self.alphas)
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
# 后验方差
self.posterior_variance = (
self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
def _linear_beta_schedule(self, timesteps, beta_start=0.0001, beta_end=0.02):
return np.linspace(beta_start, beta_end, timesteps)
def p_mean_variance(self, x_t, t):
"""
计算反向过程 p(x_{t-1} | x_t) 的均值和方差
参数:
x_t: 当前时刻的数据 [B, C, H, W]
t: 时间步 [B]
返回:
mean: 均值
variance: 方差
log_variance: 对数方差
"""
# 预测噪声
eps_pred = self.model(x_t, t)
# 计算均值: μ_θ = (x_t - β_t/√(1-ᾱ_t) * ε_θ) / √α_t
beta_t = self._extract(self.betas, t, x_t.shape)
sqrt_alpha_t = self._extract(self.sqrt_alphas, t, x_t.shape)
sqrt_one_minus_alpha_cumprod_t = self._extract(
self.sqrt_one_minus_alphas_cumprod, t, x_t.shape
)
mean = (x_t - beta_t / sqrt_one_minus_alpha_cumprod_t * eps_pred) / sqrt_alpha_t
# 方差
variance = self._extract(self.posterior_variance, t, x_t.shape)
log_variance = torch.log(variance)
return mean, variance, log_variance
def p_sample(self, x_t, t):
"""
从 p(x_{t-1} | x_t) 采样
"""
mean, variance, log_variance = self.p_mean_variance(x_t, t)
# 只在 t > 0 时添加噪声
noise = torch.randn_like(x_t)
nonzero_mask = (t != 0).float().view(-1, 1, 1, 1) # 重塑张量形状
x_prev = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
return x_prev
def sample(self, shape, device='cpu'):
"""
生成样本
参数:
shape: 输出形状 (B, C, H, W)
device: 计算设备
返回:
x_0: 生成的样本
"""
self.model.eval() # eval()评估模式
# 从纯噪声开始
x_t = torch.randn(shape, device=device)
# 反向迭代
for t in reversed(range(self.timesteps)):
t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
x_t = self.p_sample(x_t, t_batch)
return x_t
def _extract(self, a, t, x_shape):
"""提取对应时间步的值"""
batch_size = t.shape[0]
out = a.to(t.device).gather(0, t).float()
return out.view(batch_size, *((1,) * (len(x_shape) - 1)))
# 简单的噪声预测网络示例
class SimpleUNet(nn.Module): # 继承nn.Module定义网络层
"""简化的UNet用于噪声预测"""
def __init__(self, in_channels=3, model_channels=64, time_emb_dim=128):
super().__init__() # super()调用父类方法
# 时间嵌入
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_emb_dim),
nn.SiLU(),
nn.Linear(time_emb_dim, time_emb_dim),
)
# 简化的编码器-解码器结构
self.encoder = nn.Sequential(
nn.Conv2d(in_channels, model_channels, 3, padding=1),
nn.ReLU(),
nn.Conv2d(model_channels, model_channels, 3, padding=1),
nn.ReLU(),
)
self.middle = nn.Sequential(
nn.Conv2d(model_channels, model_channels, 3, padding=1),
nn.ReLU(),
)
self.decoder = nn.Sequential(
nn.Conv2d(model_channels, model_channels, 3, padding=1),
nn.ReLU(),
nn.Conv2d(model_channels, in_channels, 3, padding=1),
)
def forward(self, x, t):
# 简化的前向传播
h = self.encoder(x)
h = self.middle(h)
out = self.decoder(h)
return out
5. 可视化反向过程¶
import matplotlib.pyplot as plt
def visualize_reverse_process(model, diffusion, shape, save_path='reverse_process.png'):
"""
可视化反向扩散过程
"""
model.eval()
# 从纯噪声开始
x = torch.randn(shape)
# 存储中间结果
intermediates = [x.clone()]
# 反向迭代,每隔一定步数保存
save_steps = [999, 900, 800, 600, 400, 200, 100, 50, 0]
with torch.no_grad(): # 禁用梯度计算,节省内存
for t in reversed(range(diffusion.timesteps)):
t_batch = torch.full((shape[0],), t, dtype=torch.long)
x = diffusion.p_sample(x, t_batch)
if t in save_steps:
intermediates.append(x.clone())
# 可视化
n_show = len(intermediates)
fig, axes = plt.subplots(1, n_show, figsize=(3*n_show, 3))
for i, (ax, img) in enumerate(zip(axes, intermediates)): # enumerate同时获取索引和元素 # zip按位置配对
# 转换为可显示的格式
img_np = img[0].permute(1, 2, 0).cpu().numpy()
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
ax.imshow(img_np)
step = save_steps[i] if i < len(save_steps) else 0
ax.set_title(f't={step}')
ax.axis('off')
plt.tight_layout()
plt.savefig(save_path, dpi=150)
plt.show()
# 使用示例(需要训练好的模型)
# model = SimpleUNet()
# diffusion = ReverseDiffusion(model)
# visualize_reverse_process(model, diffusion, shape=(1, 3, 32, 32))
6. 本章总结¶
核心概念¶
- 反向过程
- 学习的马尔可夫链
- 从噪声逐步恢复数据
-
也是高斯分布(当步长足够小)
-
后验分布
- \(q(x_{t-1} | x_t, x_0)\) 有闭合形式
- 均值依赖于 \(x_0\) 和 \(x_t\)
-
方差由前向过程决定
-
神经网络参数化
- 预测噪声 \(\epsilon\) 效果最好
- 从噪声预测推导均值
- 方差通常固定
关键公式¶
| 概念 | 公式 |
|---|---|
| 反向过程 | \(p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(\mu_\theta, \Sigma_\theta)\) |
| 后验均值 | \(\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t\) |
| 预测均值 | \(\mu_\theta = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta)\) |
| 采样 | \(x_{t-1} = \mu_\theta + \sqrt{\beta_t} \cdot z\) |
算法流程¶
def sample(model, timesteps):
x = randn(shape) # 从噪声开始
for t in reversed(range(timesteps)):
eps = model(x, t) # 预测噪声
x = denoise(x, eps, t) # 去噪
return x # 返回生成结果
📝 自测问题¶
基础问题¶
- 反向过程的特性
- 反向过程是固定的还是学习的?为什么?
- 为什么反向过程也是高斯分布?
-
神经网络在反向过程中起什么作用?
-
后验分布
- 解释 \(q(x_{t-1} | x_t, x_0)\) 的含义
- 为什么需要已知 \(x_0\) 才能计算后验?
-
后验均值公式中各项的含义是什么?
-
参数化选择
- 为什么DDPM选择预测噪声而不是直接预测 \(x_0\)?
- 如何从噪声预测推导均值?
- 方差为什么要固定?
编程练习¶
- 实现完整的反向采样算法
- 可视化不同时间步的去噪效果
- 比较不同方差选择的效果
思考题¶
- 如果前向过程的步长很大,反向过程还能用高斯近似吗?
- 为什么采样时只在 \(t > 0\) 时添加噪声?
- 如何改进反向过程以提高采样速度?
🔗 下一步¶
理解了反向过程后,我们将学习训练目标的推导,了解如何训练噪声预测网络。
→ 下一步:04-训练目标推导.md