
05 - Sampling Algorithms in Detail

Study time: 3 hours | Importance: ⭐⭐⭐⭐⭐ | The key to generating high-quality images from noise


🎯 Learning Objectives

After completing this chapter, you will be able to:
  • Understand the basic principles of sampling in diffusion models
  • Master the implementation of the DDPM sampling algorithm
  • Learn about accelerated sampling algorithms such as DDIM
  • Understand how the number of sampling steps relates to generation quality
  • Implement and optimize the sampling process


1. Sampling Overview

1.1 What Is Sampling?

Sampling is the process of generating new images from a trained diffusion model.

An analogy:
  • Training: learning how to recover an image from noise (learning to "denoise")
  • Sampling: applying the learned denoising ability to generate images from pure noise

1.2 The Basic Sampling Flow

Text Only
Pure noise x_T
    ↓ apply the model ε_θ(x_T, T)
predicted noise
    ↓ denoise
x_{T-1}
    ↓ apply the model ε_θ(x_{T-1}, T-1)
predicted noise
    ↓ denoise
...
x_0 (the generated image)

2. The DDPM Sampling Algorithm

2.1 Mathematical Principles

DDPM (Denoising Diffusion Probabilistic Models) uses a Markov chain for sampling.

Core formulas

Given \(x_t\), how do we obtain \(x_{t-1}\)?

\[q(x_{t-1} | x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t \mathbf{I})\]

where:

\[\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t\]

\[\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t\]
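To make these formulas concrete, here is a small sketch that turns them into tensors. It assumes a noise schedule in the form produced by get_schedule in Section 2.3 (1-D tensors betas, alphas, alphas_cumprod of length T); the function name posterior_terms is just for illustration.

Python
import torch

def posterior_terms(betas, alphas, alphas_cumprod):
    """Coefficients of x_0 and x_t in the posterior mean, plus the posterior variance."""
    # alpha_bar_{t-1}; the "previous" value at t = 0 is defined as 1
    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

    coef_x0 = torch.sqrt(alphas_cumprod_prev) * betas / (1 - alphas_cumprod)         # multiplies x_0
    coef_xt = torch.sqrt(alphas) * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)  # multiplies x_t
    beta_tilde = (1 - alphas_cumprod_prev) / (1 - alphas_cumprod) * betas            # posterior variance

    return coef_x0, coef_xt, beta_tilde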

2.2 The Practical Sampling Formula

Since we do not know \(x_0\), we estimate it with the noise predicted by the model:

\[x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right) + \sqrt{\beta_t} \mathbf{z}\]

where:
  • \(\epsilon_\theta(x_t, t)\): the noise predicted by the model
  • \(\mathbf{z} \sim \mathcal{N}(0, \mathbf{I})\): random noise (added only when \(t > 1\))
  • \(\beta_t = 1 - \alpha_t\): the variance schedule parameter
  • \(\sqrt{\beta_t}\): the noise standard deviation

Note that \(\sigma_t^2 = \beta_t\) is one of the two variance choices discussed in the DDPM paper; the code below uses the other choice, the posterior variance \(\tilde{\beta}_t\). Both work in practice.

2.3 Code Implementation

Python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

def get_schedule(T, beta_start=0.0001, beta_end=0.02):
    """
    Create the noise schedule table.

    Args:
        T: total number of steps
        beta_start: initial beta value
        beta_end: final beta value

    Returns:
        alphas, betas, alphas_cumprod
    """
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    return alphas, betas, alphas_cumprod

def ddpm_sample(model, x_T, T, alphas, betas, alphas_cumprod, device='cpu'):
    """
    DDPM sampling algorithm.

    Args:
        model: trained denoising model
        x_T: initial noise [batch_size, channels, height, width]
        T: total number of steps
        alphas, betas, alphas_cumprod: noise schedule
        device: device

    Returns:
        the generated image x_0
    """
    model.eval()  # switch to evaluation mode
    x_t = x_T.to(device)  # move to GPU/CPU

    # Precompute some values
    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])  # prepend 1.0 so that ᾱ_{t-1} is defined at t = 0

    with torch.no_grad():  # disable gradient tracking to save memory
        for t in reversed(range(T)):
            # Parameters for the current time step
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]
            alpha_t_bar_prev = alphas_cumprod_prev[t]

            # Predict the noise
            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
            predicted_noise = model(x_t, t_tensor)

            # Compute the mean
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_alpha_t_bar = torch.sqrt(alpha_t_bar)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)

            mean = sqrt_recip_alpha_t * (
                x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            # Add random noise (skipped at the final step)
            if t > 0:
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

    return x_t

# Example: a simple UNet-style model
class SimpleUNet(nn.Module):  # subclass nn.Module to define the network layers
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()  # call the parent class constructor
        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, out_channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # Simplified; a real UNet is much more complex (and would use the time step t)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.conv4(x)
        return x

# Test sampling
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Create the model and the noise schedule
T = 1000
model = SimpleUNet().to(device)
alphas, betas, alphas_cumprod = get_schedule(T)

# Generate the initial noise
batch_size = 4
x_T = torch.randn(batch_size, 3, 32, 32)

# Sample
print(f"Starting sampling, total steps: {T}")
x_0 = ddpm_sample(model, x_T, T, alphas, betas, alphas_cumprod, device)
print(f"Sampling finished, output shape: {x_0.shape}")

3. Sampling Steps vs. Generation Quality

3.1 The Effect of the Number of Steps

| Steps | Quality | Speed | Typical use |
| --- | --- | --- | --- |
| Few (10-50) | Lower | Very fast | Quick previews, interactive applications |
| Medium (100-500) | Good | Medium | Balancing quality and speed |
| Many (1000+) | Best | Slower | High-quality generation, artistic work |

3.2 Visualizing the Effect of Different Step Counts

Python
def sample_with_steps(model, x_T, steps, alphas, betas, alphas_cumprod, device='cpu'):
    """
    Sample with the specified numbers of steps.

    Args:
        steps: list of step counts, e.g. [10, 50, 100, 500, 1000]

    Returns:
        generation results for each step count
    """
    results = {}
    T = len(alphas)

    for num_steps in steps:
        # Compute the sub-sampled time steps
        step_indices = torch.linspace(0, T-1, num_steps, dtype=torch.long)

        x_t = x_T.clone().to(device)  # work on a copy, on the target device
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

        with torch.no_grad():
            for i, t in enumerate(reversed(step_indices)):  # enumerate yields both index and element
                t = t.item()  # convert the single-element tensor to a Python number
                alpha_t = alphas[t]
                beta_t = betas[t]
                alpha_t_bar = alphas_cumprod[t]
                alpha_t_bar_prev = alphas_cumprod_prev[t]

                # Predict the noise
                t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
                predicted_noise = model(x_t, t_tensor)

                # Compute the mean
                sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)

                mean = sqrt_recip_alpha_t * (
                    x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
                )

                # Add noise
                if i < len(step_indices) - 1:  # not the last step
                    posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                    noise = torch.randn_like(x_t)
                    x_t = mean + torch.sqrt(posterior_variance) * noise
                else:
                    x_t = mean

        results[num_steps] = x_t.detach().cpu()  # detach from the graph; no gradients needed

    return results

# Try different step counts
steps_list = [10, 50, 100, 500, 1000]
results = sample_with_steps(model, x_T, steps_list, alphas, betas, alphas_cumprod, device)

# Visualize
fig, axes = plt.subplots(1, len(steps_list), figsize=(15, 3))

for idx, steps in enumerate(steps_list):
    img = results[steps][0]  # take the first image
    img = img.permute(1, 2, 0).numpy()
    img = (img - img.min()) / (img.max() - img.min())  # normalize to [0, 1]

    axes[idx].imshow(img)
    axes[idx].set_title(f'{steps} steps')
    axes[idx].axis('off')

plt.tight_layout()
plt.show()
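The speed column from the table in Section 3.1 can also be checked directly with a rough wall-clock measurement. Below is a small sketch reusing sample_with_steps from above; absolute times depend on hardware and model size, but the cost grows roughly linearly with the number of steps because each step is one forward pass of the model.

Python
import time

for num_steps in [10, 100, 1000]:
    start = time.time()
    _ = sample_with_steps(model, x_T, [num_steps], alphas, betas, alphas_cumprod, device)
    print(f"{num_steps:5d} steps: {time.time() - start:.2f} s")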

4. The DDIM Sampling Algorithm

4.1 DDIM vs DDPM

DDIM (Denoising Diffusion Implicit Models) is an accelerated version of DDPM.

| Property | DDPM | DDIM |
| --- | --- | --- |
| Sampling type | Stochastic | Deterministic |
| Steps required | Many | Far fewer |
| Sampling speed | Slow | Fast |
| Generation quality | Slightly better | Close |
| Deterministic | No (different every run) | Yes (same input gives same output) |

4.2 The Mathematics of DDIM

The core idea of DDIM: sampling does not have to follow the Markov chain, so time steps can be skipped.

DDIM update formula

\[x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \left(\frac{x_t - \sqrt{1-\bar{\alpha}_t} \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}\right) + \sqrt{1-\bar{\alpha}_{t-1}} \epsilon_\theta(x_t, t)\]

4.3 DDIM Code Implementation

Python
def ddim_sample(model, x_T, T, alphas, alphas_cumprod, num_steps=50, device='cpu'):
    """
    DDIM sampling algorithm.

    Args:
        model: trained denoising model
        x_T: initial noise
        T: original total number of steps
        alphas, alphas_cumprod: noise schedule
        num_steps: actual number of sampling steps
        device: device

    Returns:
        the generated image x_0
    """
    model.eval()
    x_t = x_T.to(device)

    # Select the sampling time steps
    step_indices = torch.linspace(0, T-1, num_steps, dtype=torch.long)

    with torch.no_grad():
        for i, t in enumerate(reversed(step_indices)):
            t = t.item()
            alpha_t = alphas[t]
            alpha_t_bar = alphas_cumprod[t]

            # Predict the noise
            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
            predicted_noise = model(x_t, t_tensor)

            # Predict x_0
            sqrt_alpha_t_bar = torch.sqrt(alpha_t_bar)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            x_0_pred = (x_t - sqrt_one_minus_alpha_t_bar * predicted_noise) / sqrt_alpha_t_bar

            # Compute x_{t-1}
            if i < len(step_indices) - 1:  # not the last step
                t_prev = step_indices[len(step_indices) - 2 - i].item()
                alpha_t_prev_bar = alphas_cumprod[t_prev]

                sqrt_alpha_t_prev_bar = torch.sqrt(alpha_t_prev_bar)
                sqrt_one_minus_alpha_t_prev_bar = torch.sqrt(1 - alpha_t_prev_bar)

                # DDIM update
                x_t = sqrt_alpha_t_prev_bar * x_0_pred + sqrt_one_minus_alpha_t_prev_bar * predicted_noise
            else:
                x_t = x_0_pred

    return x_t

# Test DDIM
print("Testing DDIM sampling...")
for num_steps in [10, 20, 50, 100]:
    x_0_ddim = ddim_sample(model, x_T, T, alphas, alphas_cumprod, num_steps, device)
    print(f"DDIM {num_steps}-step sampling finished, output shape: {x_0_ddim.shape}")
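For reference, the DDIM paper actually defines a family of samplers controlled by a parameter η: η = 0 gives the deterministic update implemented above, while η = 1 reintroduces DDPM-like stochasticity. Below is a sketch of the generalized single-step update; the function name and signature are illustrative and not part of the code above.

Python
import torch

def ddim_step(x_t, predicted_noise, alpha_bar_t, alpha_bar_prev, eta=0.0):
    """One generalized DDIM step: eta=0 is deterministic DDIM, eta=1 approximates DDPM.

    alpha_bar_t and alpha_bar_prev are 0-dim tensors taken from alphas_cumprod.
    """
    # Predict x_0 from the current sample and the predicted noise
    x_0_pred = (x_t - torch.sqrt(1 - alpha_bar_t) * predicted_noise) / torch.sqrt(alpha_bar_t)

    # Standard deviation of the injected noise (zero when eta = 0)
    sigma = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t)
                             * (1 - alpha_bar_t / alpha_bar_prev))

    # Direction pointing towards x_t, plus optional fresh noise
    dir_xt = torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * predicted_noise
    noise = sigma * torch.randn_like(x_t)
    return torch.sqrt(alpha_bar_prev) * x_0_pred + dir_xt + noise

Swapping this into the loop of ddim_sample in place of the η = 0 update would let you interpolate between DDIM and DDPM behaviour with a single knob.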

5. Sampling Optimization Tricks

5.1 Noise Schedule Optimization

Different noise schedules affect generation quality:

Python
def get_cosine_schedule(T, s=0.008):
    """
    Cosine noise schedule (Improved DDPM).

    Args:
        T: total number of steps
        s: offset parameter

    Returns:
        alphas, betas, alphas_cumprod
    """
    steps = T + 1
    x = torch.linspace(0, T, steps)
    alphas_cumprod = torch.cos(((x / T) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]

    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])

    # Clip the beta values, then derive the alphas from the clipped betas
    betas = torch.clip(betas, 0.0001, 0.9999)
    alphas = 1 - betas

    return alphas, betas, alphas_cumprod[1:]  # ᾱ_t aligned with cumprod(alphas)

# Compare different schedules
T = 1000

# Linear schedule
alphas_linear, betas_linear, alphas_cumprod_linear = get_schedule(T)

# Cosine schedule
alphas_cosine, betas_cosine, alphas_cumprod_cosine = get_cosine_schedule(T)

# Visualize
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.plot(betas_linear.numpy(), label='Linear')
plt.plot(betas_cosine.numpy(), label='Cosine')
plt.xlabel('Time step t')
plt.ylabel('β_t')
plt.title('Beta schedule')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 2)
plt.plot(alphas_cumprod_linear.numpy(), label='Linear')
plt.plot(alphas_cumprod_cosine.numpy(), label='Cosine')
plt.xlabel('Time step t')
plt.ylabel('ᾱ_t')
plt.title('Alpha bar schedule')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 3)
plt.plot(np.sqrt(1 - alphas_cumprod_linear.numpy()), label='Linear')
plt.plot(np.sqrt(1 - alphas_cumprod_cosine.numpy()), label='Cosine')
plt.xlabel('Time step t')
plt.ylabel('√(1-ᾱ_t)')
plt.title('Noise level')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

5.2 Classifier Guidance

Classifier guidance can improve generation quality:

Python
def classifier_guided_sample(model, classifier, x_T, T, alphas, betas, alphas_cumprod,
                             guidance_scale=1.0, class_label=None, device='cpu'):
    """
    Classifier-guided sampling.

    Args:
        model: denoising model
        classifier: classifier model
        x_T: initial noise
        T: total number of steps
        alphas, betas, alphas_cumprod: noise schedule
        guidance_scale: guidance strength
        class_label: target class
        device: device

    Returns:
        the generated image
    """
    model.eval()
    classifier.eval()
    x_t = x_T.to(device)

    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

    with torch.no_grad():
        for t in reversed(range(T)):
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]

            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)

            # Predict the unconditional noise
            predicted_noise = model(x_t, t_tensor)

            # If a classifier is given, compute the gradient guidance
            if classifier is not None and class_label is not None:
                # Compute the classifier gradient; autograd must be re-enabled
                # inside the surrounding torch.no_grad() block
                with torch.enable_grad():
                    x_in = x_t.detach().requires_grad_(True)
                    logits = classifier(x_in, t_tensor)
                    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                    target_log_prob = log_probs[:, class_label].sum()
                    grad = torch.autograd.grad(target_log_prob, x_in)[0]

                # Add the guidance term
                predicted_noise = predicted_noise - guidance_scale * grad * torch.sqrt(1 - alpha_t_bar)

            # Update
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)

            mean = sqrt_recip_alpha_t * (
                x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            if t > 0:
                alpha_t_bar_prev = alphas_cumprod_prev[t]
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

    return x_t

5.3 Classifier-Free Guidance

A more commonly used form of guidance that does not require a separate classifier:

Python
def classifier_free_guidance(model, x_T, T, alphas, betas, alphas_cumprod,
                            guidance_scale=7.5, condition=None, device='cpu'):
    """
    Classifier-free guidance sampling.

    Args:
        model: a model that supports conditional generation
        x_T: initial noise
        T: total number of steps
        alphas, betas, alphas_cumprod: noise schedule
        guidance_scale: guidance strength (typically 5-15)
        condition: the condition (e.g. a text prompt)
        device: device

    Returns:
        the generated image
    """
    model.eval()
    x_t = x_T.to(device)

    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

    with torch.no_grad():
        for t in reversed(range(T)):
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]

            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)

            # Predict the conditional and unconditional noise
            noise_cond = model(x_t, t_tensor, condition)
            noise_uncond = model(x_t, t_tensor, None)

            # Combine the predictions
            predicted_noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

            # Update
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)

            mean = sqrt_recip_alpha_t * (
                x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            if t > 0:
                alpha_t_bar_prev = alphas_cumprod_prev[t]
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

    return x_t
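The function above assumes a model whose forward pass takes a condition argument and treats condition=None as the unconditional branch; the SimpleUNet from Section 2.3 does not have that interface. The stub below is a hypothetical minimal example, just to illustrate the assumed call signature (a learned class embedding added to the input; not a real conditional UNet):

Python
class StubConditionalUNet(nn.Module):
    """Toy conditional model: model(x, t, condition), where condition=None means unconditional."""
    def __init__(self, num_classes=10, channels=3):
        super().__init__()
        self.null_idx = num_classes  # extra embedding index used for "no condition"
        self.class_emb = nn.Embedding(num_classes + 1, channels)
        self.net = nn.Sequential(
            nn.Conv2d(channels, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, channels, 3, padding=1),
        )

    def forward(self, x, t, condition=None):
        if condition is None:
            condition = torch.full((x.shape[0],), self.null_idx,
                                   device=x.device, dtype=torch.long)
        emb = self.class_emb(condition)[:, :, None, None]  # broadcast over H and W
        return self.net(x + emb)

# Usage sketch, reusing the linear schedule and device from earlier sections
cond_model = StubConditionalUNet().to(device)
labels = torch.zeros(4, dtype=torch.long, device=device)  # target class 0
samples_cfg = classifier_free_guidance(cond_model, torch.randn(4, 3, 32, 32), T,
                                       alphas, betas, alphas_cumprod,
                                       guidance_scale=7.5, condition=labels, device=device)
print(f"CFG sampling finished, shape: {samples_cfg.shape}")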

6. Evaluating Sample Quality

6.1 Visual Inspection

Python
def visualize_samples(samples, nrow=4, title="Generated samples"):
    """
    Visualize generated samples.

    Args:
        samples: [batch_size, channels, height, width]
        nrow: number of images per row
        title: figure title
    """
    from torchvision.utils import make_grid

    # Normalize to [0, 1]
    samples = (samples - samples.min()) / (samples.max() - samples.min())

    # Build the image grid
    grid = make_grid(samples, nrow=nrow, padding=2, normalize=False)

    # Display
    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title(title)
    plt.axis('off')
    plt.show()

# Generate and visualize
samples = torch.randn(16, 3, 32, 32)  # stand-in for generated results
visualize_samples(samples, nrow=4, title="DDPM generated samples")

6.2 FID Score (Fréchet Inception Distance)

FID is a commonly used metric for evaluating generation quality:

Python
def calculate_fid(real_images, generated_images, batch_size=32):
    """
    Compute the FID score.

    Args:
        real_images: real images [N, C, H, W]
        generated_images: generated images [N, C, H, W]
        batch_size: batch size

    Returns:
        FID score (lower is better)
    """
    from torchvision.models import inception_v3, Inception_V3_Weights
    from scipy import linalg

    # Load a pretrained Inception model and drop the final classification
    # layer to obtain 2048-dimensional pooled features
    inception = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
    inception.fc = torch.nn.Identity()  # replace the classifier head; output 2048-d features
    inception = inception.eval()

    def get_features(images):
        """Extract Inception features."""
        features = []
        with torch.no_grad():
            for i in range(0, len(images), batch_size):
                batch = images[i:i+batch_size]
                # Resize to 299x299 (Inception's expected input size)
                batch_resized = torch.nn.functional.interpolate(
                    batch, size=(299, 299), mode='bilinear', align_corners=False
                )
                feat = inception(batch_resized)
                features.append(feat)
        return torch.cat(features, dim=0).cpu().numpy()

    # Extract features
    real_features = get_features(real_images)
    gen_features = get_features(generated_images)

    # Compute means and covariances
    mu_real = np.mean(real_features, axis=0)
    mu_gen = np.mean(gen_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    sigma_gen = np.cov(gen_features, rowvar=False)

    # Compute FID
    diff = mu_real - mu_gen
    covmean, _ = linalg.sqrtm(sigma_real @ sigma_gen, disp=False)

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_real + sigma_gen - 2 * covmean)

    return fid
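A quick smoke test of the function with random tensors is sketched below; the number it prints is meaningless (FID needs thousands of real and generated images, preprocessed the way Inception expects), but it verifies that the pipeline runs end to end.

Python
# Hypothetical smoke test only; with 64 random images the score itself carries no information.
real = torch.rand(64, 3, 32, 32)   # stand-in for real images in [0, 1]
fake = torch.rand(64, 3, 32, 32)   # stand-in for generated images in [0, 1]
fid = calculate_fid(real, fake, batch_size=32)
print(f"FID: {fid:.2f}")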

7. Hands-On Project: A Complete Sampling Pipeline

Python
class DiffusionSampler:
    """Diffusion model sampler."""

    def __init__(self, model, T=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
        """
        初始化采样器

        参数:
            model: 训练好的模型
            T: 总步数
            beta_start, beta_end: beta范围
            device: 设备
        """
        self.model = model
        self.T = T
        self.device = device

        # Build the noise schedule
        self.alphas, self.betas, self.alphas_cumprod = self._get_schedule(T, beta_start, beta_end)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])

    def _get_schedule(self, T, beta_start, beta_end):
        """Create the noise schedule table."""
        betas = torch.linspace(beta_start, beta_end, T)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        return alphas, betas, alphas_cumprod

    def sample(self, batch_size=1, num_steps=None, method='ddpm', **kwargs):  # **kwargs collects extra keyword arguments
        """
        Sample.

        Args:
            batch_size: batch size
            num_steps: number of sampling steps (None means use T)
            method: sampling method ('ddpm' or 'ddim')
            **kwargs: additional arguments

        Returns:
            generated images [batch_size, channels, height, width]
        """
        # Create the initial noise
        x_T = torch.randn(batch_size, 3, 32, 32)  # assumes 32x32 inputs

        if method == 'ddpm':
            return self._ddpm_sample(x_T, num_steps or self.T)
        elif method == 'ddim':
            return self._ddim_sample(x_T, num_steps or self.T)
        else:
            raise ValueError(f"Unknown sampling method: {method}")

    def _ddpm_sample(self, x_T, num_steps):
        """DDPM sampling."""
        self.model.eval()
        x_t = x_T.to(self.device)

        # Compute the sub-sampled time steps
        if num_steps < self.T:
            step_indices = torch.linspace(0, self.T-1, num_steps, dtype=torch.long)
        else:
            step_indices = torch.arange(self.T)

        with torch.no_grad():
            for i, t in enumerate(reversed(step_indices)):
                t = t.item()
                alpha_t = self.alphas[t]
                beta_t = self.betas[t]
                alpha_t_bar = self.alphas_cumprod[t]
                alpha_t_bar_prev = self.alphas_cumprod_prev[t]

                # Predict the noise
                t_tensor = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long)
                predicted_noise = self.model(x_t, t_tensor)

                # Compute the mean
                sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)

                mean = sqrt_recip_alpha_t * (
                    x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
                )

                # Add noise
                if i < len(step_indices) - 1:
                    posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                    noise = torch.randn_like(x_t)
                    x_t = mean + torch.sqrt(posterior_variance) * noise
                else:
                    x_t = mean

        return x_t

    def _ddim_sample(self, x_T, num_steps):
        """DDIM sampling."""
        self.model.eval()
        x_t = x_T.to(self.device)

        # Select the sampling time steps
        step_indices = torch.linspace(0, self.T-1, num_steps, dtype=torch.long)

        with torch.no_grad():
            for i, t in enumerate(reversed(step_indices)):
                t = t.item()
                alpha_t = self.alphas[t]
                alpha_t_bar = self.alphas_cumprod[t]

                # Predict the noise
                t_tensor = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long)
                predicted_noise = self.model(x_t, t_tensor)

                # Predict x_0
                sqrt_alpha_t_bar = torch.sqrt(alpha_t_bar)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
                x_0_pred = (x_t - sqrt_one_minus_alpha_t_bar * predicted_noise) / sqrt_alpha_t_bar

                # Compute x_{t-1}
                if i < len(step_indices) - 1:
                    t_prev = step_indices[len(step_indices) - 2 - i].item()
                    alpha_t_prev_bar = self.alphas_cumprod[t_prev]

                    sqrt_alpha_t_prev_bar = torch.sqrt(alpha_t_prev_bar)
                    sqrt_one_minus_alpha_t_prev_bar = torch.sqrt(1 - alpha_t_prev_bar)

                    x_t = sqrt_alpha_t_prev_bar * x_0_pred + sqrt_one_minus_alpha_t_prev_bar * predicted_noise
                else:
                    x_t = x_0_pred

        return x_t

# Usage example
sampler = DiffusionSampler(model, T=1000, device=device)

# DDPM sampling
samples_ddpm = sampler.sample(batch_size=4, num_steps=1000, method='ddpm')
print(f"DDPM sampling finished, shape: {samples_ddpm.shape}")

# DDIM sampling (faster)
samples_ddim = sampler.sample(batch_size=4, num_steps=50, method='ddim')
print(f"DDIM sampling finished, shape: {samples_ddim.shape}")

8. Summary

8.1 Core Concepts Review

| Concept | Description |
| --- | --- |
| DDPM sampling | Standard Markov-chain sampling; good quality but slow |
| DDIM sampling | Deterministic sampling; the number of steps can be reduced dramatically |
| Noise schedule | Controls how much noise each step adds; affects generation quality |
| Guided sampling | Uses classifier or classifier-free guidance to improve generation |
| Number of steps | The key parameter balancing quality and speed |

8.2 Practical Advice

  1. Start with DDPM: understand the standard sampling algorithm first
  2. Try DDIM: faster sampling, well suited to quick iteration
  3. Optimize the schedule: try improvements such as the cosine schedule
  4. Use guidance: classifier-free guidance is the modern standard practice
  5. Evaluate quality: combine visual inspection with metrics such as FID

8.3 FAQ

Q: Are more sampling steps always better? A: Not necessarily. Too many steps waste time, while too few hurt quality. Tune the number of steps for the specific task.

Q: Which is better, DDIM or DDPM? A: It depends on the use case. DDIM is faster and suits interactive applications; DDPM quality is slightly better and suits high-quality generation.

Q: How do I choose the guidance strength? A: Usually between 5 and 15. Larger values make the output follow the condition more closely but can reduce diversity.


9. Recommended Resources

Papers

  • DDPM: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
  • DDIM: "Denoising Diffusion Implicit Models" (Song et al., 2020)
  • Improved DDPM: "Improved Denoising Diffusion Probabilistic Models" (Nichol & Dhariwal, 2021)

Code Repositories

  • Hugging Face Diffusers
  • OpenAI's guided-diffusion
  • CompVis/stable-diffusion

10. Self-Test Questions

  1. What is the main difference between DDPM and DDIM sampling?
  2. How does the number of sampling steps affect generation quality and speed?
  3. What is classifier-free guidance, and how does it work?
  4. What advantages does the cosine noise schedule have over the linear schedule?
  5. How can you evaluate the quality of generated images?

Next chapter: 06 - Training Tricks and Optimization, which covers how to train diffusion models efficiently