04 - 训练目标推导¶
学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 理解扩散模型如何训练的核心
🎯 学习目标¶
完成本章后,你将能够: - 理解变分下界(ELBO)的推导过程 - 掌握简化训练目标的由来 - 理解为什么预测噪声比预测图像更好 - 实现完整的训练损失函数 - 了解不同训练目标的优缺点
1. 问题设定¶
1.1 生成模型的目标¶
生成模型的目标是学习数据分布 \(p(x)\),使得我们可以从该分布中采样生成新数据。
扩散模型的方法: - 定义一个前向过程 \(q(x_{0:T})\),将数据逐步转化为噪声 - 学习一个反向过程 \(p_\theta(x_{0:T})\),从噪声恢复数据 - 通过最大化似然(或变分下界)来训练模型
1.2 扩散模型的联合分布¶
前向过程(固定): $\(q(x_{0:T}) = q(x_0) \prod_{t=1}^T q(x_t | x_{t-1})\)$
反向过程(学习): $\(p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} | x_t)\)$
其中: - \(p(x_T) = \mathcal{N}(0, \mathbf{I})\)(先验分布) - \(p_\theta(x_{t-1} | x_t) = \mathcal{N}(\mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)
2. 变分下界(ELBO)推导¶
2.1 最大似然估计¶
我们希望最大化数据的对数似然: $\(\log p_\theta(x_0)\)$
但直接计算需要积分: $\(p_\theta(x_0) = \int p_\theta(x_{0:T}) dx_{1:T}\)$
这个积分是高维的,难以直接计算。
2.2 引入变分分布¶
技巧:引入前向过程作为变分分布 \(q(x_{1:T} | x_0)\),使用Jensen不等式。
由Jensen不等式(对数函数是凹函数): $\(\geq \mathbb{E}_{q(x_{1:T} | x_0)} \left[ \log \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)} \right]\)$
这就是变分下界(Evidence Lower Bound, ELBO)。
2.3 ELBO的展开¶
可以重写为: $\(= \mathbb{E}_q \left[ \log p(x_T) - \sum_{t=1}^T \log q(x_t | x_{t-1}) + \sum_{t=1}^T \log p_\theta(x_{t-1} | x_t) \right]\)$
2.4 进一步分解¶
利用贝叶斯定理: $\(q(x_t | x_{t-1}) = \frac{q(x_{t-1} | x_t, x_0) q(x_t | x_0)}{q(x_{t-1} | x_0)}\)$
经过推导(详见附录),ELBO可以分解为:
这三项分别表示: - \(L_0\):重构项(最后一跳) - \(L_T\):先验匹配项(第一项) - \(L_{t-1}\):一致性项(中间步骤)
3. 各项的详细分析¶
3.1 重构项 \(L_0\)¶
- 表示从 \(x_1\) 重构 \(x_0\) 的质量
- 类似于VAE中的重构损失
- 通常用MSE或BCE近似
3.2 先验匹配项 \(L_T\)¶
- \(q(x_T | x_0) \approx \mathcal{N}(0, \mathbf{I})\)(当T足够大)
- \(p(x_T) = \mathcal{N}(0, \mathbf{I})\)
- 因此 \(L_T \approx 0\),可以忽略
3.3 一致性项 \(L_{t-1}\)(核心)¶
这是两个高斯分布之间的KL散度!
回忆: - \(q(x_{t-1} | x_t, x_0) = \mathcal{N}(\tilde{\mu}_t, \tilde{\beta}_t \mathbf{I})\)(后验,已知) - \(p_\theta(x_{t-1} | x_t) = \mathcal{N}(\mu_\theta, \Sigma_\theta)\)(模型,学习)
3.4 高斯分布的KL散度¶
对于两个高斯分布 \(\mathcal{N}(\mu_1, \sigma_1^2)\) 和 \(\mathcal{N}(\mu_2, \sigma_2^2)\):
在我们的情况下: - 如果固定 \(\Sigma_\theta = \tilde{\beta}_t \mathbf{I}\)(与后验相同) - 则KL散度简化为:
这就是均方误差损失!
4. 训练目标的简化¶
4.1 从预测 \(x_0\) 到预测噪声¶
原始后验均值: $\(\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t\)$
模型预测的均值: $\(\mu_\theta = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right)\)$
4.2 简化推导¶
将 \(\tilde{\mu}_t\) 也用噪声 \(\epsilon\) 表示:
从 \(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\),得到: $\(x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t} \epsilon}{\sqrt{\bar{\alpha}_t}}\)$
代入 \(\tilde{\mu}_t\): $\(\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon\right)\)$
4.3 最终的训练目标¶
因此,损失函数变为: $\(L_{t-1} = \mathbb{E}_{x_0, \epsilon, t} \left[ \frac{\beta_t^2}{2\tilde{\beta}_t \alpha_t (1-\bar{\alpha}_t)} \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]\)$
(注:根据DDPM论文公式(16),权重系数为 \(\frac{\beta_t^2}{2\tilde{\beta}_t \alpha_t (1-\bar{\alpha}_t)}\),其中 \(\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t\),分母中的 \(\alpha_t\) 来自均值差异 \(\tilde{\mu}_t - \mu_\theta\) 的系数 \(\frac{\beta_t}{\sqrt{\alpha_t}\sqrt{1-\bar{\alpha}_t}}\) 的平方)
DDPM的简化版本(去掉权重): $\(\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, \epsilon, t} \left[ \| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon, t) \|^2 \right]\)$
这就是简单的均方误差!
4.4 为什么简化有效?¶
- 权重项的影响不大:实验发现,使用统一权重效果也很好
- 数值稳定性:简化版本更稳定
- 训练效率:不需要计算复杂的权重
5. 完整的训练算法¶
5.1 训练流程¶
算法: DDPM训练
─────────────────────────────────
重复直到收敛:
1. 从数据分布采样: x_0 ~ q(x_0)
2. 随机选择时间步: t ~ Uniform({1, 2, ..., T})
3. 采样噪声: ε ~ N(0, I)
4. 计算加噪图像:
x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
5. 预测噪声: ε_θ(x_t, t)
6. 计算损失: L = ||ε - ε_θ||²
7. 梯度下降更新参数 θ
5.2 PyTorch实现¶
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
class DDPMTrainer:
"""
DDPM训练器
"""
def __init__(self, model, timesteps=1000, beta_schedule='linear', lr=1e-4):
"""
参数:
model: 噪声预测网络
timesteps: 扩散步数
beta_schedule: 噪声调度策略
lr: 学习率
"""
self.model = model
self.timesteps = timesteps
# 设置beta调度
self.betas = self._get_beta_schedule(beta_schedule, timesteps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
# 优化器
self.optimizer = optim.Adam(model.parameters(), lr=lr)
# 记录损失
self.losses = []
def _get_beta_schedule(self, schedule, timesteps):
"""获取beta调度"""
if schedule == 'linear':
return torch.linspace(0.0001, 0.02, timesteps)
elif schedule == 'cosine':
# 余弦调度实现
s = 0.008
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi / 2) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
else:
raise ValueError(f"Unknown schedule: {schedule}")
def q_sample(self, x_0, t, noise=None):
"""
前向扩散过程: 从 q(x_t | x_0) 采样
参数:
x_0: 原始图像 [B, C, H, W]
t: 时间步 [B]
noise: 可选的噪声
返回:
x_t: 加噪后的图像
noise: 使用的噪声
"""
if noise is None:
noise = torch.randn_like(x_0)
# 获取对应时间步的值
sqrt_alpha_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1) # 重塑张量形状
sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
# 重参数化
x_t = sqrt_alpha_cumprod_t * x_0 + sqrt_one_minus_alpha_cumprod_t * noise
return x_t, noise
def training_step(self, x_0):
"""
单步训练
参数:
x_0: 原始图像批次 [B, C, H, W]
返回:
loss: 损失值
"""
self.model.train() # train()训练模式
batch_size = x_0.shape[0]
device = x_0.device
# 将调度参数移到设备
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device) # 移至GPU/CPU
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
# 1. 随机选择时间步
t = torch.randint(0, self.timesteps, (batch_size,), device=device).long()
# 2. 采样噪声
noise = torch.randn_like(x_0)
# 3. 计算加噪图像
x_t, _ = self.q_sample(x_0, t, noise)
# 4. 预测噪声
noise_pred = self.model(x_t, t)
# 5. 计算损失 (MSE)
loss = nn.functional.mse_loss(noise_pred, noise)
# 6. 反向传播
self.optimizer.zero_grad() # 清零梯度
loss.backward() # 反向传播计算梯度
# 梯度裁剪(防止爆炸)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step() # 更新参数
return loss.item() # 将单元素张量转为Python数值
def train_epoch(self, dataloader):
"""
训练一个epoch
参数:
dataloader: 数据加载器
返回:
avg_loss: 平均损失
"""
epoch_losses = []
for batch_idx, (x_0, _) in enumerate(dataloader): # enumerate同时获取索引和元素
# 移动数据到设备
x_0 = x_0.to(next(self.model.parameters()).device)
# 训练步骤
loss = self.training_step(x_0)
epoch_losses.append(loss)
# 打印进度
if batch_idx % 100 == 0:
print(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss:.4f}")
avg_loss = sum(epoch_losses) / len(epoch_losses)
self.losses.extend(epoch_losses)
return avg_loss
# 使用示例
def train_ddpm():
"""训练DDPM的完整示例"""
# 假设有数据加载器
# dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
# 创建模型(这里使用简化版UNet)
# model = UNet(in_channels=3, out_channels=3)
# 创建训练器
# trainer = DDPMTrainer(model, timesteps=1000)
# 训练循环
# for epoch in range(num_epochs):
# avg_loss = trainer.train_epoch(dataloader)
# print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f}")
pass
6. 训练技巧和注意事项¶
6.1 学习率调度¶
# 使用余弦退火
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# 或使用预热
from torch.optim.lr_scheduler import LambdaLR
def warmup_scheduler(epoch, warmup_epochs=10):
if epoch < warmup_epochs:
return epoch / warmup_epochs
return 1.0
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_scheduler(epoch)) # lambda匿名函数
6.2 指数移动平均(EMA)¶
class EMA:
"""指数移动平均"""
def __init__(self, model, decay=0.9999):
self.model = model
self.decay = decay
self.shadow = {}
self.backup = {}
# 初始化shadow参数
for name, param in model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self):
"""更新shadow参数"""
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data
def apply_shadow(self):
"""使用shadow参数"""
for name, param in self.model.named_parameters():
if param.requires_grad:
self.backup[name] = param.data
param.data = self.shadow[name]
def restore(self):
"""恢复原始参数"""
for name, param in self.model.named_parameters():
if param.requires_grad:
param.data = self.backup[name]
self.backup = {}
# 使用
ema = EMA(model)
# 训练循环中
for epoch in range(num_epochs):
for x_0 in dataloader:
loss = trainer.training_step(x_0)
ema.update()
# 采样时使用EMA参数
ema.apply_shadow()
# sample(...)
ema.restore()
6.3 混合精度训练¶
from torch.amp import autocast, GradScaler
scaler = GradScaler()
def training_step_mixed_precision(self, x_0):
"""混合精度训练步骤"""
self.model.train()
batch_size = x_0.shape[0]
device = x_0.device
t = torch.randint(0, self.timesteps, (batch_size,), device=device).long()
noise = torch.randn_like(x_0)
x_t, _ = self.q_sample(x_0, t, noise)
# 使用自动混合精度
with autocast('cuda'):
noise_pred = self.model(x_t, t)
loss = nn.functional.mse_loss(noise_pred, noise)
# 缩放梯度
self.optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
scaler.step(self.optimizer)
scaler.update()
return loss.item()
7. 本章总结¶
核心概念¶
- 变分下界(ELBO)
- 通过Jensen不等式得到似然的下界
-
分解为重构项、先验匹配项和一致性项
-
训练目标
- 核心是KL散度:\(D_{KL}(q(x_{t-1}|x_t,x_0) \| p_\theta(x_{t-1}|x_t))\)
-
简化为预测噪声的MSE损失
-
简化目标
- \(\mathcal{L}_{\text{simple}} = \mathbb{E}[\| \epsilon - \epsilon_\theta \|^2]\)
- 简单但有效
关键公式¶
| 概念 | 公式 |
|---|---|
| ELBO | \(\mathbb{E}_q[\log p_\theta(x_0) - D_{KL}(q\|p)]\) |
| 训练损失 | \(\mathcal{L} = \mathbb{E}_{t,x_0,\epsilon}[\|\epsilon - \epsilon_\theta(x_t, t)\|^2]\) |
| 加噪过程 | \(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\) |
训练流程¶
📝 自测问题¶
基础问题¶
- ELBO推导
- 为什么要引入变分下界?
- Jensen不等式在推导中起什么作用?
-
ELBO的三项分别代表什么含义?
-
训练目标
- 为什么最终的训练目标是预测噪声?
- 简化目标去掉了哪些项?为什么有效?
-
如果直接预测 \(x_0\) 会怎样?
-
实现细节
- 为什么要随机选择时间步?
- 梯度裁剪的作用是什么?
- EMA为什么能提升效果?
编程练习¶
- 实现完整的DDPM训练循环
- 添加学习率调度和EMA
- 实现混合精度训练
- 可视化训练过程中的损失变化
思考题¶
- 训练目标中的权重项有什么作用?去掉会怎样?
- 为什么扩散模型训练稳定,不容易出现模式崩溃?
- 如何设计更好的训练目标?
🔗 下一步¶
理解了训练目标后,我们将学习采样算法详解,了解如何从训练好的模型中生成图像。
→ 下一步:05-采样算法详解.md