跳转至

04 - 扩散模型加速技术

学习时间: 3小时 重要性: ⭐⭐⭐⭐⭐ 实现快速、实用的扩散模型的关键


🎯 学习目标

完成本章后,你将能够: - 理解扩散模型加速的基本原理 - 掌握DDIM等确定性采样方法 - 学习渐进式扩散技术 - 了解潜空间扩散模型(LDM) - 实现高效的采样流程


1. 加速技术概述

1.1 为什么需要加速?

问题 影响 解决方案
采样慢 需要1000步迭代 减少采样步数
计算量大 每步需要前向传播 优化模型架构
显存占用高 大图像难以处理 潜空间建模
推理延迟高 实时应用困难 模型压缩和优化

1.2 加速技术分类

Text Only
加速技术
├── 采样加速
│   ├── DDIM(确定性采样)
│   ├── DDPM++(改进调度)
│   └── 一致性模型
├── 模型加速
│   ├── 知识蒸馏
│   ├── 模型剪枝
│   └── 量化
├── 架构优化
│   ├── 潜空间扩散(LDM)
│   ├── 级联扩散
│   └── 多尺度扩散
└── 训练加速
    ├── 混合精度训练
    ├── 分布式训练
    └── 梯度累积

2. DDIM加速采样

2.1 DDIM原理回顾

DDIM(Denoising Diffusion Implicit Models)的核心思想是:不需要遵循马尔可夫链,可以跳步

关键优势: - 确定性采样(相同输入相同输出) - 可以大幅减少采样步数 - 质量接近DDPM

2.2 DDIM数学原理

DDIM更新公式

\[x_{t-1} = \sqrt{\alpha_{t-1}} \left(\frac{x_t - \sqrt{1-\alpha_t} \epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}}\right) + \sqrt{1-\alpha_{t-1}} \epsilon_\theta(x_t, t)\]

参数 \(\eta\) 控制: $\(x_{t-1} = \sqrt{\alpha_{t-1}} \tilde{x}_0 + \sqrt{1-\alpha_{t-1}} \tilde{\epsilon}_t\)$

其中: $\(\tilde{x}_0 = \frac{x_t - \sqrt{1-\alpha_t} \epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}}\)$ $\(\tilde{\epsilon}_t = \sqrt{1-\alpha_{t-1} - \sigma_t^2} \epsilon_\theta(x_t, t) + \sigma_t \epsilon\)$

\[\sigma_t = \eta \sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}} \sqrt{1-\frac{\alpha_t}{\alpha_{t-1}}}\]
  • \(\eta = 0\):确定性DDIM
  • \(\eta = 1\):DDPM(随机采样)

2.3 DDIM代码实现

Python
import torch
import torch.nn as nn
import numpy as np

class DDIMSampler:
    """DDIM采样器"""

    def __init__(self, model, alphas, alphas_cumprod, device='cuda'):
        """
        初始化DDIM采样器

        参数:
            model: 训练好的扩散模型
            alphas: alpha值
            alphas_cumprod: 累积alpha值
            device: 设备
        """
        self.model = model
        self.alphas = alphas.to(device)  # 移至GPU/CPU
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.device = device

        # 预计算
        self.sqrt_alphas = torch.sqrt(alphas)
        self.sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

    def sample(self, x_T, num_steps=50, eta=0.0):
        """
        DDIM采样

        参数:
            x_T: 初始噪声 [batch_size, channels, height, width]
            num_steps: 采样步数
            eta: 随机性参数 (0=确定性, 1=DDPM)

        返回:
            生成的图像 x_0
        """
        self.model.eval()  # eval()评估模式
        x_t = x_T.to(self.device)

        T = len(self.alphas)
        # 选择采样时间步
        # 注意:linspace + dtype=torch.long 会截断小数,可能导致步长不完全均匀。
        # 更精确的做法是先用 float linspace 再 round().long(),或手动构造等间隔索引。
        step_indices = torch.linspace(0, T-1, num_steps, dtype=torch.long)

        with torch.no_grad():  # 禁用梯度计算,节省内存
            for i, t in enumerate(reversed(step_indices)):  # enumerate同时获取索引和元素
                t = t.item()  # 将单元素张量转为Python数值
                alpha_t = self.alphas[t]
                alpha_t_bar = self.alphas_cumprod[t]

                # 预测噪声
                t_tensor = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long)
                predicted_noise = self.model(x_t, t_tensor)

                # 计算x_0的预测
                sqrt_alpha_t_bar = self.sqrt_alphas_cumprod[t]
                sqrt_one_minus_alpha_t_bar = self.sqrt_one_minus_alphas_cumprod[t]
                x_0_pred = (x_t - sqrt_one_minus_alpha_t_bar * predicted_noise) / sqrt_alpha_t_bar

                # 计算x_{t-1}
                if i < len(step_indices) - 1:
                    t_prev = step_indices[len(step_indices) - 2 - i].item()
                    alpha_t_prev = self.alphas[t_prev]
                    alpha_t_prev_bar = self.alphas_cumprod[t_prev]

                    # 计算sigma_t(使用累积alpha而非单步alpha)
                    sigma_t = eta * torch.sqrt(
                        (1 - alpha_t_prev_bar) / (1 - alpha_t_bar)
                        * (1 - alpha_t_bar / alpha_t_prev_bar)
                    )

                    # 计算方向(噪声系数)
                    sqrt_alpha_t_prev_bar = torch.sqrt(alpha_t_prev_bar)

                    direction = torch.sqrt(1 - alpha_t_prev_bar - sigma_t ** 2)
                    noise = torch.randn_like(x_t) if eta > 0 else 0

                    x_t = sqrt_alpha_t_prev_bar * x_0_pred + direction * predicted_noise + sigma_t * noise
                else:
                    x_t = x_0_pred

        return x_t

# 使用示例
def get_ddim_schedule(T=1000, beta_start=0.0001, beta_end=0.02):
    """获取DDIM调度表"""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, alphas_cumprod

# 创建采样器
alphas, alphas_cumprod = get_ddim_schedule(T=1000)
sampler = DDIMSampler(model, alphas, alphas_cumprod, device='cuda')

# 生成样本
x_T = torch.randn(4, 3, 32, 32).to('cuda')
x_0 = sampler.sample(x_T, num_steps=50, eta=0.0)
print(f"生成完成,形状: {x_0.shape}")

2.4 DDIM vs DDPM对比

Python
def compare_sampling_methods(model, x_T, alphas, betas, alphas_cumprod, device='cuda'):
    """
    对比DDIM和DDPM采样

    参数:
        model: 模型
        x_T: 初始噪声
        alphas, betas, alphas_cumprod: 调度表
        device: 设备
    """
    import time
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid

    # DDPM采样(1000步)
    start_time = time.time()
    # 参见 02-扩散模型核心原理/05-采样算法详解.md 中的 ddpm_sample 实现
    x_0_ddpm = ddpm_sample(model, x_T, 1000, alphas, betas, alphas_cumprod, device)
    ddpm_time = time.time() - start_time

    # DDIM采样(50步)
    start_time = time.time()
    ddim_sampler = DDIMSampler(model, alphas, alphas_cumprod, device)
    x_0_ddim = ddim_sampler.sample(x_T, num_steps=50, eta=0.0)
    ddim_time = time.time() - start_time

    # 可视化对比
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # DDPM结果
    ddpm_grid = make_grid((x_0_ddpm + 1) / 2, nrow=2, padding=2, normalize=False)
    axes[0].imshow(ddpm_grid.permute(1, 2, 0).cpu())
    axes[0].set_title(f'DDPM (1000步, {ddpm_time:.2f}s)')
    axes[0].axis('off')

    # DDIM结果
    ddim_grid = make_grid((x_0_ddim + 1) / 2, nrow=2, padding=2, normalize=False)
    axes[1].imshow(ddim_grid.permute(1, 2, 0).cpu())
    axes[1].set_title(f'DDIM (50步, {ddim_time:.2f}s)')
    axes[1].axis('off')

    plt.tight_layout()
    plt.savefig('ddim_vs_ddpm.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"DDPM时间: {ddpm_time:.2f}s")
    print(f"DDIM时间: {ddim_time:.2f}s")
    print(f"加速比: {ddpm_time/ddim_time:.2f}x")

3. 渐进式扩散

3.1 渐进式扩散原理

渐进式扩散(Progressive Diffusion)的核心思想是:逐步提高图像分辨率

优势: - 减少计算量 - 提高生成质量 - 更好地控制生成过程

3.2 渐进式扩散架构

Python
class ProgressiveDiffusion(nn.Module):  # 继承nn.Module定义网络层
    """渐进式扩散模型"""

    def __init__(self, stages=[32, 64, 128, 256]):
        """
        参数:
            stages: 各阶段的分辨率
        """
        super().__init__()  # super()调用父类方法
        self.stages = stages

        # 为每个分辨率创建模型
        self.models = nn.ModuleList()
        for i, size in enumerate(stages):
            if i == 0:
                in_channels = 3
            else:
                in_channels = 3 * 2  # 原始图像 + 上采样图像

            model = UNet(
                in_channels=in_channels,
                out_channels=3,
                model_dim=64 * (2 ** i)
            )
            self.models.append(model)

    def forward(self, x, t, stage_idx):
        """
        前向传播

        参数:
            x: 输入图像
            t: 时间步
            stage_idx: 阶段索引
        """
        return self.models[stage_idx](x, t)

    def progressive_sample(self, x_T, num_steps=50, device='cuda'):
        """
        渐进式采样

        参数:
            x_T: 初始噪声(最低分辨率)
            num_steps: 每阶段的采样步数
            device: 设备

        返回:
            生成的图像(最高分辨率)
        """
        current_x = x_T.to(device)

        for stage_idx, target_size in enumerate(self.stages):
            print(f"阶段 {stage_idx + 1}/{len(self.stages)}: 生成 {target_size}x{target_size} 图像")

            # 如果不是第一阶段,上采样前一阶段的输出
            if stage_idx > 0:
                current_x = torch.nn.functional.interpolate(
                    current_x,
                    size=(target_size, target_size),
                    mode='bilinear',
                    align_corners=False
                )

            # 在当前分辨率采样
            sampler = DDIMSampler(self.models[stage_idx], *get_ddim_schedule(), device)
            current_x = sampler.sample(current_x, num_steps=num_steps)

        return current_x

# 使用示例
progressive_model = ProgressiveDiffusion(stages=[32, 64, 128, 256])
x_T = torch.randn(1, 3, 32, 32).to('cuda')
x_0 = progressive_model.progressive_sample(x_T, num_steps=50)
print(f"生成完成,形状: {x_0.shape}")

4. 潜空间扩散模型(LDM)

4.1 LDM原理

潜空间扩散模型(Latent Diffusion Models, LDM)的核心思想是:在压缩的潜空间中进行扩散

优势: - 大幅减少计算量 - 支持更高分辨率 - 更好的生成质量

架构

Text Only
输入图像
VAE编码器
潜空间 z (压缩表示)
扩散模型(在潜空间)
潜空间 z'
VAE解码器
输出图像

4.2 VAE编码器/解码器

Python
class VAEEncoder(nn.Module):
    """VAE编码器"""

    def __init__(self, in_channels=3, latent_dim=4, base_dim=64):
        super().__init__()

        # 下采样
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, base_dim, 3, padding=1),
            nn.GroupNorm(8, base_dim),
            nn.SiLU(),

            nn.Conv2d(base_dim, base_dim, 3, stride=2, padding=1),  # /2
            nn.GroupNorm(8, base_dim),
            nn.SiLU(),

            nn.Conv2d(base_dim, base_dim * 2, 3, stride=2, padding=1),  # /4
            nn.GroupNorm(8, base_dim * 2),
            nn.SiLU(),

            nn.Conv2d(base_dim * 2, base_dim * 4, 3, stride=2, padding=1),  # /8
            nn.GroupNorm(8, base_dim * 4),
            nn.SiLU(),

            nn.Conv2d(base_dim * 4, base_dim * 4, 3, padding=1),
            nn.GroupNorm(8, base_dim * 4),
            nn.SiLU(),
        )

        # 输出均值和方差
        self.mean_conv = nn.Conv2d(base_dim * 4, latent_dim, 1)
        self.logvar_conv = nn.Conv2d(base_dim * 4, latent_dim, 1)

    def forward(self, x):
        """
        前向传播

        参数:
            x: 输入图像 [B, C, H, W]

        返回:
            mean, logvar: 潜空间的均值和对数方差
        """
        x = self.encoder(x)
        mean = self.mean_conv(x)
        logvar = self.logvar_conv(x)
        return mean, logvar

class VAEDecoder(nn.Module):
    """VAE解码器"""

    def __init__(self, latent_dim=4, out_channels=3, base_dim=64):
        super().__init__()

        # 上采样
        self.decoder = nn.Sequential(
            nn.Conv2d(latent_dim, base_dim * 4, 3, padding=1),
            nn.GroupNorm(8, base_dim * 4),
            nn.SiLU(),

            nn.Conv2d(base_dim * 4, base_dim * 4, 3, padding=1),
            nn.GroupNorm(8, base_dim * 4),
            nn.SiLU(),

            nn.ConvTranspose2d(base_dim * 4, base_dim * 2, 4, stride=2, padding=1),  # *2
            nn.GroupNorm(8, base_dim * 2),
            nn.SiLU(),

            nn.ConvTranspose2d(base_dim * 2, base_dim, 4, stride=2, padding=1),  # *2
            nn.GroupNorm(8, base_dim),
            nn.SiLU(),

            nn.ConvTranspose2d(base_dim, base_dim, 4, stride=2, padding=1),  # *2
            nn.GroupNorm(8, base_dim),
            nn.SiLU(),

            nn.Conv2d(base_dim, out_channels, 3, padding=1),
        )

    def forward(self, z):
        """
        前向传播

        参数:
            z: 潜空间表示 [B, latent_dim, H/8, W/8]

        返回:
            x: 重建图像 [B, out_channels, H, W]
        """
        x = self.decoder(z)
        return x

class VAE(nn.Module):
    """完整的VAE"""

    def __init__(self, in_channels=3, latent_dim=4, base_dim=64):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim, base_dim)
        self.decoder = VAEDecoder(latent_dim, in_channels, base_dim)

    def encode(self, x):
        """编码"""
        mean, logvar = self.encoder(x)
        return mean, logvar

    def decode(self, z):
        """解码"""
        return self.decoder(z)

    def reparameterize(self, mean, logvar):
        """重参数化"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        """前向传播"""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decode(z)
        return x_recon, mean, logvar

    def encode_to_latent(self, x):
        """编码到潜空间(确定性)"""
        mean, _ = self.encode(x)
        return mean

4.3 潜空间扩散模型

Python
class LatentDiffusion(nn.Module):
    """潜空间扩散模型"""

    def __init__(self, vae, diffusion_model, latent_dim=4, T=1000):
        super().__init__()
        self.vae = vae
        self.diffusion_model = diffusion_model
        self.latent_dim = latent_dim
        self.T = T

        # 创建噪声调度
        betas = torch.linspace(0.0001, 0.02, T)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)

    def encode(self, x):
        """编码到潜空间"""
        with torch.no_grad():
            z = self.vae.encode_to_latent(x)
        return z

    def decode(self, z):
        """从潜空间解码"""
        with torch.no_grad():
            x = self.vae.decode(z)
        return x

    def get_loss(self, x_0):
        """
        计算损失

        参数:
            x_0: 原始图像

        返回:
            损失值
        """
        # 编码到潜空间
        z_0 = self.encode(x_0)

        # 随机采样时间步
        batch_size = x_0.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=x_0.device)

        # 生成噪声
        noise = torch.randn_like(z_0)

        # 计算加噪后的潜空间
        sqrt_alpha_t_bar = torch.sqrt(self.alphas_cumprod[t]).view(-1, 1, 1, 1)  # 重塑张量形状
        sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - self.alphas_cumprod[t]).view(-1, 1, 1, 1)
        z_t = sqrt_alpha_t_bar * z_0 + sqrt_one_minus_alpha_t_bar * noise

        # 模型预测噪声
        predicted_noise = self.diffusion_model(z_t, t)

        # 计算损失
        loss = nn.functional.mse_loss(predicted_noise, noise)

        return loss

    def sample(self, num_samples=4, num_steps=50, method='ddim', image_size=256):
        """
        生成样本

        参数:
            num_samples: 样本数量
            num_steps: 采样步数
            method: 采样方法
            image_size: 输出图像大小

        返回:
            生成的图像
        """
        latent_size = image_size // 8

        # 在潜空间采样
        z_T = torch.randn(num_samples, self.latent_dim, latent_size, latent_size)

        if method == 'ddim':
            sampler = DDIMSampler(self.diffusion_model, self.alphas, self.alphas_cumprod)
            z_0 = sampler.sample(z_T, num_steps=num_steps)
        else:
            # 参见 02-扩散模型核心原理/05-采样算法详解.md 中的 ddpm_sample 实现
            z_0 = ddpm_sample(self.diffusion_model, z_T, self.T, self.alphas,
                            self.betas, self.alphas_cumprod)

        # 解码到图像空间
        x_0 = self.decode(z_0)

        return x_0

# 使用示例
# 创建VAE
vae = VAE(in_channels=3, latent_dim=4, base_dim=64)

# 创建扩散模型
diffusion_model = UNet(in_channels=4, out_channels=4, model_dim=128)

# 创建潜空间扩散模型
ldm = LatentDiffusion(vae, diffusion_model, latent_dim=4, T=1000)

# 生成样本
samples = ldm.sample(num_samples=4, num_steps=50, method='ddim', image_size=256)
print(f"生成完成,形状: {samples.shape}")

5. 模型压缩与优化

5.1 知识蒸馏

Python
def distill_model(teacher_model, student_model, dataloader, num_epochs=10, device='cuda'):
    """
    知识蒸馏

    参数:
        teacher_model: 教师模型
        student_model: 学生模型
        dataloader: 数据加载器
        num_epochs: 训练轮数
        device: 设备
    """
    teacher_model.eval()
    student_model.train()

    optimizer = optim.Adam(student_model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        total_loss = 0

        for x_0, _ in dataloader:
            x_0 = x_0.to(device)

            # 教师模型预测
            with torch.no_grad():
                teacher_output = teacher_model(x_0)

            # 学生模型预测
            student_output = student_model(x_0)

            # 蒸馏损失
            loss = nn.functional.mse_loss(student_output, teacher_output)

            # 反向传播
            optimizer.zero_grad()  # 清零梯度
            loss.backward()  # 反向传播计算梯度
            optimizer.step()  # 更新参数

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    return student_model

5.2 模型量化

Python
def quantize_model(model):
    """
    量化模型

    参数:
        model: 模型

    返回:
        量化后的模型
    """
    # 动态量化
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {nn.Conv2d, nn.Linear},
        dtype=torch.qint8
    )

    return quantized_model

# 使用示例
quantized_model = quantize_model(model)
print(f"原始模型大小: {get_model_size(model):.2f} MB")
print(f"量化后模型大小: {get_model_size(quantized_model):.2f} MB")

6. 总结

6.1 核心技术回顾

技术 原理 加速比 质量影响
DDIM 确定性采样,跳步 10-20x 轻微下降
渐进式扩散 逐步提高分辨率 5-10x 轻微提升
LDM 潜空间扩散 20-50x 轻微提升
知识蒸馏 大模型教小模型 2-5x 中等下降
量化 降低精度 2-4x 轻微下降

6.2 最佳实践

  1. 从DDIM开始:最容易实现的加速方法
  2. 考虑LDM:适合高分辨率生成
  3. 组合使用:多种技术组合效果更好
  4. 权衡质量与速度:根据应用场景选择
  5. 评估影响:量化加速对质量的影响

6.3 学习建议

  1. 理解原理:先理解每种技术的原理
  2. 动手实现:亲自实现加速方法
  3. 对比实验:对比不同方法的效果
  4. 组合优化:尝试组合多种技术

7. 推荐资源

论文

  • DDIM: "Denoising Diffusion Implicit Models"
  • LDM: "High-Resolution Image Synthesis with Latent Diffusion Models"
  • Progressive Diffusion: "Progressive Distillation for Fast Sampling of Diffusion Models"

代码库

  • Hugging Face Diffusers
  • CompVis/stable-diffusion
  • NVIDIA/latent-diffusion

8. 自测问题

  1. DDIM相比DDPM有什么优势?
  2. 潜空间扩散模型如何加速采样?
  3. 渐进式扩散的原理是什么?
  4. 知识蒸馏如何用于扩散模型?
  5. 如何平衡加速和生成质量?

下一章: 05-条件生成与控制 - 学习如何控制扩散模型的生成过程