跳转至

04 - 训练目标推导

学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 理解扩散模型如何训练的核心


🎯 学习目标

完成本章后,你将能够: - 理解变分下界(ELBO)的推导过程 - 掌握简化训练目标的由来 - 理解为什么预测噪声比预测图像更好 - 实现完整的训练损失函数 - 了解不同训练目标的优缺点


1. 问题设定

1.1 生成模型的目标

生成模型的目标是学习数据分布 \(p(x)\),使得我们可以从该分布中采样生成新数据。

扩散模型的方法: - 定义一个前向过程 \(q(x_{0:T})\),将数据逐步转化为噪声 - 学习一个反向过程 \(p_\theta(x_{0:T})\),从噪声恢复数据 - 通过最大化似然(或变分下界)来训练模型

1.2 扩散模型的联合分布

前向过程(固定): $\(q(x_{0:T}) = q(x_0) \prod_{t=1}^T q(x_t | x_{t-1})\)$

反向过程(学习): $\(p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} | x_t)\)$

其中: - \(p(x_T) = \mathcal{N}(0, \mathbf{I})\)(先验分布) - \(p_\theta(x_{t-1} | x_t) = \mathcal{N}(\mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\)


2. 变分下界(ELBO)推导

2.1 最大似然估计

我们希望最大化数据的对数似然: $\(\log p_\theta(x_0)\)$

但直接计算需要积分: $\(p_\theta(x_0) = \int p_\theta(x_{0:T}) dx_{1:T}\)$

这个积分是高维的,难以直接计算。

2.2 引入变分分布

技巧:引入前向过程作为变分分布 \(q(x_{1:T} | x_0)\),使用Jensen不等式

\[\log p_\theta(x_0) = \log \int p_\theta(x_{0:T}) dx_{1:T}\]
\[= \log \int q(x_{1:T} | x_0) \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)} dx_{1:T}\]
\[= \log \mathbb{E}_{q(x_{1:T} | x_0)} \left[ \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)} \right]\]

由Jensen不等式(对数函数是凹函数): $\(\geq \mathbb{E}_{q(x_{1:T} | x_0)} \left[ \log \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)} \right]\)$

这就是变分下界(Evidence Lower Bound, ELBO)

2.3 ELBO的展开

\[\text{ELBO} = \mathbb{E}_{q(x_{1:T} | x_0)} \left[ \log p(x_T) + \sum_{t=1}^T \log \frac{p_\theta(x_{t-1} | x_t)}{q(x_t | x_{t-1})} \right]\]

可以重写为: $\(= \mathbb{E}_q \left[ \log p(x_T) - \sum_{t=1}^T \log q(x_t | x_{t-1}) + \sum_{t=1}^T \log p_\theta(x_{t-1} | x_t) \right]\)$

2.4 进一步分解

利用贝叶斯定理: $\(q(x_t | x_{t-1}) = \frac{q(x_{t-1} | x_t, x_0) q(x_t | x_0)}{q(x_{t-1} | x_0)}\)$

经过推导(详见附录),ELBO可以分解为:

\[\text{ELBO} = \mathbb{E}_q \left[ \underbrace{\log p_\theta(x_0 | x_1)}_{L_0} - \underbrace{D_{KL}(q(x_T | x_0) \| p(x_T))}_{L_T} - \sum_{t=2}^T \underbrace{D_{KL}(q(x_{t-1} | x_t, x_0) \| p_\theta(x_{t-1} | x_t))}_{L_{t-1}} \right]\]

这三项分别表示: - \(L_0\):重构项(最后一跳) - \(L_T\):先验匹配项(第一项) - \(L_{t-1}\):一致性项(中间步骤)


3. 各项的详细分析

3.1 重构项 \(L_0\)

\[L_0 = \mathbb{E}_q [\log p_\theta(x_0 | x_1)]\]
  • 表示从 \(x_1\) 重构 \(x_0\) 的质量
  • 类似于VAE中的重构损失
  • 通常用MSE或BCE近似

3.2 先验匹配项 \(L_T\)

\[L_T = D_{KL}(q(x_T | x_0) \| p(x_T))\]
  • \(q(x_T | x_0) \approx \mathcal{N}(0, \mathbf{I})\)(当T足够大)
  • \(p(x_T) = \mathcal{N}(0, \mathbf{I})\)
  • 因此 \(L_T \approx 0\),可以忽略

3.3 一致性项 \(L_{t-1}\)(核心)

\[L_{t-1} = D_{KL}(q(x_{t-1} | x_t, x_0) \| p_\theta(x_{t-1} | x_t))\]

这是两个高斯分布之间的KL散度

回忆: - \(q(x_{t-1} | x_t, x_0) = \mathcal{N}(\tilde{\mu}_t, \tilde{\beta}_t \mathbf{I})\)(后验,已知) - \(p_\theta(x_{t-1} | x_t) = \mathcal{N}(\mu_\theta, \Sigma_\theta)\)(模型,学习)

3.4 高斯分布的KL散度

对于两个高斯分布 \(\mathcal{N}(\mu_1, \sigma_1^2)\)\(\mathcal{N}(\mu_2, \sigma_2^2)\)

\[D_{KL}(\mathcal{N}_1 \| \mathcal{N}_2) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}\]

在我们的情况下: - 如果固定 \(\Sigma_\theta = \tilde{\beta}_t \mathbf{I}\)(与后验相同) - 则KL散度简化为:

\[L_{t-1} = \frac{1}{2\tilde{\beta}_t} \| \tilde{\mu}_t - \mu_\theta \|^2\]

这就是均方误差损失!


4. 训练目标的简化

4.1 从预测 \(x_0\) 到预测噪声

原始后验均值: $\(\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t\)$

模型预测的均值: $\(\mu_\theta = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right)\)$

4.2 简化推导

\(\tilde{\mu}_t\) 也用噪声 \(\epsilon\) 表示:

\(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\),得到: $\(x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t} \epsilon}{\sqrt{\bar{\alpha}_t}}\)$

代入 \(\tilde{\mu}_t\): $\(\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon\right)\)$

4.3 最终的训练目标

因此,损失函数变为: $\(L_{t-1} = \mathbb{E}_{x_0, \epsilon, t} \left[ \frac{\beta_t^2}{2\tilde{\beta}_t \alpha_t (1-\bar{\alpha}_t)} \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]\)$

(注:根据DDPM论文公式(16),权重系数为 \(\frac{\beta_t^2}{2\tilde{\beta}_t \alpha_t (1-\bar{\alpha}_t)}\),其中 \(\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t\),分母中的 \(\alpha_t\) 来自均值差异 \(\tilde{\mu}_t - \mu_\theta\) 的系数 \(\frac{\beta_t}{\sqrt{\alpha_t}\sqrt{1-\bar{\alpha}_t}}\) 的平方)

DDPM的简化版本(去掉权重): $\(\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, \epsilon, t} \left[ \| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon, t) \|^2 \right]\)$

这就是简单的均方误差

4.4 为什么简化有效?

  1. 权重项的影响不大:实验发现,使用统一权重效果也很好
  2. 数值稳定性:简化版本更稳定
  3. 训练效率:不需要计算复杂的权重

5. 完整的训练算法

5.1 训练流程

Text Only
算法: DDPM训练
─────────────────────────────────
重复直到收敛:

  1. 从数据分布采样: x_0 ~ q(x_0)

  2. 随机选择时间步: t ~ Uniform({1, 2, ..., T})

  3. 采样噪声: ε ~ N(0, I)

  4. 计算加噪图像:
     x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε

  5. 预测噪声: ε_θ(x_t, t)

  6. 计算损失: L = ||ε - ε_θ||²

  7. 梯度下降更新参数 θ

5.2 PyTorch实现

Python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class DDPMTrainer:
    """
    DDPM训练器
    """
    def __init__(self, model, timesteps=1000, beta_schedule='linear', lr=1e-4):
        """
        参数:
            model: 噪声预测网络
            timesteps: 扩散步数
            beta_schedule: 噪声调度策略
            lr: 学习率
        """
        self.model = model
        self.timesteps = timesteps

        # 设置beta调度
        self.betas = self._get_beta_schedule(beta_schedule, timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # 优化器
        self.optimizer = optim.Adam(model.parameters(), lr=lr)

        # 记录损失
        self.losses = []

    def _get_beta_schedule(self, schedule, timesteps):
        """获取beta调度"""
        if schedule == 'linear':
            return torch.linspace(0.0001, 0.02, timesteps)
        elif schedule == 'cosine':
            # 余弦调度实现
            s = 0.008
            steps = timesteps + 1
            x = torch.linspace(0, timesteps, steps)
            alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi / 2) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            return torch.clip(betas, 0.0001, 0.9999)
        else:
            raise ValueError(f"Unknown schedule: {schedule}")

    def q_sample(self, x_0, t, noise=None):
        """
        前向扩散过程: 从 q(x_t | x_0) 采样

        参数:
            x_0: 原始图像 [B, C, H, W]
            t: 时间步 [B]
            noise: 可选的噪声

        返回:
            x_t: 加噪后的图像
            noise: 使用的噪声
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        # 获取对应时间步的值
        sqrt_alpha_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)  # 重塑张量形状
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        # 重参数化
        x_t = sqrt_alpha_cumprod_t * x_0 + sqrt_one_minus_alpha_cumprod_t * noise

        return x_t, noise

    def training_step(self, x_0):
        """
        单步训练

        参数:
            x_0: 原始图像批次 [B, C, H, W]

        返回:
            loss: 损失值
        """
        self.model.train()  # train()训练模式
        batch_size = x_0.shape[0]
        device = x_0.device

        # 将调度参数移到设备
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)  # 移至GPU/CPU
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)

        # 1. 随机选择时间步
        t = torch.randint(0, self.timesteps, (batch_size,), device=device).long()

        # 2. 采样噪声
        noise = torch.randn_like(x_0)

        # 3. 计算加噪图像
        x_t, _ = self.q_sample(x_0, t, noise)

        # 4. 预测噪声
        noise_pred = self.model(x_t, t)

        # 5. 计算损失 (MSE)
        loss = nn.functional.mse_loss(noise_pred, noise)

        # 6. 反向传播
        self.optimizer.zero_grad()  # 清零梯度
        loss.backward()  # 反向传播计算梯度

        # 梯度裁剪(防止爆炸)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        self.optimizer.step()  # 更新参数

        return loss.item()  # 将单元素张量转为Python数值

    def train_epoch(self, dataloader):
        """
        训练一个epoch

        参数:
            dataloader: 数据加载器

        返回:
            avg_loss: 平均损失
        """
        epoch_losses = []

        for batch_idx, (x_0, _) in enumerate(dataloader):  # enumerate同时获取索引和元素
            # 移动数据到设备
            x_0 = x_0.to(next(self.model.parameters()).device)

            # 训练步骤
            loss = self.training_step(x_0)
            epoch_losses.append(loss)

            # 打印进度
            if batch_idx % 100 == 0:
                print(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss:.4f}")

        avg_loss = sum(epoch_losses) / len(epoch_losses)
        self.losses.extend(epoch_losses)

        return avg_loss

# 使用示例
def train_ddpm():
    """训练DDPM的完整示例"""

    # 假设有数据加载器
    # dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

    # 创建模型(这里使用简化版UNet)
    # model = UNet(in_channels=3, out_channels=3)

    # 创建训练器
    # trainer = DDPMTrainer(model, timesteps=1000)

    # 训练循环
    # for epoch in range(num_epochs):
    #     avg_loss = trainer.train_epoch(dataloader)
    #     print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f}")

    pass

6. 训练技巧和注意事项

6.1 学习率调度

Python
# 使用余弦退火
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# 或使用预热
from torch.optim.lr_scheduler import LambdaLR

def warmup_scheduler(epoch, warmup_epochs=10):
    if epoch < warmup_epochs:
        return epoch / warmup_epochs
    return 1.0

scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_scheduler(epoch))  # lambda匿名函数

6.2 指数移动平均(EMA)

Python
class EMA:
    """指数移动平均"""
    def __init__(self, model, decay=0.9999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        # 初始化shadow参数
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """更新shadow参数"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data

    def apply_shadow(self):
        """使用shadow参数"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        """恢复原始参数"""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}

# 使用
ema = EMA(model)

# 训练循环中
for epoch in range(num_epochs):
    for x_0 in dataloader:
        loss = trainer.training_step(x_0)
        ema.update()

    # 采样时使用EMA参数
    ema.apply_shadow()
    # sample(...)
    ema.restore()

6.3 混合精度训练

Python
from torch.amp import autocast, GradScaler

scaler = GradScaler()

def training_step_mixed_precision(self, x_0):
    """混合精度训练步骤"""
    self.model.train()
    batch_size = x_0.shape[0]
    device = x_0.device

    t = torch.randint(0, self.timesteps, (batch_size,), device=device).long()
    noise = torch.randn_like(x_0)
    x_t, _ = self.q_sample(x_0, t, noise)

    # 使用自动混合精度
    with autocast('cuda'):
        noise_pred = self.model(x_t, t)
        loss = nn.functional.mse_loss(noise_pred, noise)

    # 缩放梯度
    self.optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.unscale_(self.optimizer)
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
    scaler.step(self.optimizer)
    scaler.update()

    return loss.item()

7. 本章总结

核心概念

  1. 变分下界(ELBO)
  2. 通过Jensen不等式得到似然的下界
  3. 分解为重构项、先验匹配项和一致性项

  4. 训练目标

  5. 核心是KL散度:\(D_{KL}(q(x_{t-1}|x_t,x_0) \| p_\theta(x_{t-1}|x_t))\)
  6. 简化为预测噪声的MSE损失

  7. 简化目标

  8. \(\mathcal{L}_{\text{simple}} = \mathbb{E}[\| \epsilon - \epsilon_\theta \|^2]\)
  9. 简单但有效

关键公式

概念 公式
ELBO \(\mathbb{E}_q[\log p_\theta(x_0) - D_{KL}(q\|p)]\)
训练损失 \(\mathcal{L} = \mathbb{E}_{t,x_0,\epsilon}[\|\epsilon - \epsilon_\theta(x_t, t)\|^2]\)
加噪过程 \(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon\)

训练流程

Text Only
1. 采样数据 x_0
2. 随机选择时间步 t
3. 采样噪声 ε
4. 计算 x_t
5. 预测噪声 ε_θ
6. 计算 MSE 损失
7. 反向传播更新

📝 自测问题

基础问题

  1. ELBO推导
  2. 为什么要引入变分下界?
  3. Jensen不等式在推导中起什么作用?
  4. ELBO的三项分别代表什么含义?

  5. 训练目标

  6. 为什么最终的训练目标是预测噪声?
  7. 简化目标去掉了哪些项?为什么有效?
  8. 如果直接预测 \(x_0\) 会怎样?

  9. 实现细节

  10. 为什么要随机选择时间步?
  11. 梯度裁剪的作用是什么?
  12. EMA为什么能提升效果?

编程练习

  1. 实现完整的DDPM训练循环
  2. 添加学习率调度和EMA
  3. 实现混合精度训练
  4. 可视化训练过程中的损失变化

思考题

  1. 训练目标中的权重项有什么作用?去掉会怎样?
  2. 为什么扩散模型训练稳定,不容易出现模式崩溃?
  3. 如何设计更好的训练目标?

🔗 下一步

理解了训练目标后,我们将学习采样算法详解,了解如何从训练好的模型中生成图像。

→ 下一步:05-采样算法详解.md