跳转至

04 - 完整DDPM实现

学习时间: 5小时 重要性: ⭐⭐⭐⭐⭐ 整合所有组件的完整实现


🎯 学习目标

完成本章后,你将能够:

- 整合所有组件实现完整DDPM
- 理解各模块之间的交互
- 实现完整的训练和采样流程
- 调试和优化DDPM模型


1. 项目结构

1.1 文件组织

Text Only
ddpm_implementation/
├── config.py           # 配置文件
├── model.py            # UNet模型
├── diffusion.py        # 扩散过程
├── trainer.py          # 训练器
├── sampler.py          # 采样器
├── dataset.py          # 数据加载
├── utils.py            # 工具函数
├── train.py            # 训练脚本
└── sample.py           # 采样脚本

1.2 配置文件

Python
# config.py
from dataclasses import dataclass
from typing import Tuple

@dataclass  # @dataclass auto-generates __init__, __repr__, __eq__, etc.
class DDPMConfig:
    """Hyperparameter configuration for DDPM training and sampling.

    Defaults target CIFAR-10.  Values can be overridden via the
    constructor or loaded from a JSON dict (see train.py).
    """
    # --- Data ---
    dataset: str = "cifar10"            # "cifar10" or "mnist" (see get_dataloader)
    image_size: int = 32                # square image side length in pixels
    channels: int = 3                   # image channels (3 for RGB)
    data_path: str = "./data"           # dataset download/cache root

    # --- Model (UNet) ---
    base_channels: int = 128            # width of the first UNet stage
    channel_mults: Tuple[int, ...] = (1, 2, 2, 2)   # per-stage width multipliers
    num_res_blocks: int = 2             # residual blocks per resolution
    attention_resolutions: Tuple[int, ...] = (16, 8)  # resolutions that get attention
    dropout: float = 0.1

    # --- Diffusion ---
    timesteps: int = 1000               # number of diffusion steps T
    beta_schedule: str = "linear"       # "linear" or "cosine"
    beta_start: float = 0.0001          # linear schedule start (ignored for cosine)
    beta_end: float = 0.02              # linear schedule end (ignored for cosine)

    # --- Training ---
    batch_size: int = 128
    num_epochs: int = 100
    learning_rate: float = 2e-4
    weight_decay: float = 0.0
    grad_clip: float = 1.0              # max gradient norm

    # --- EMA ---
    use_ema: bool = True                # track an exponential moving average of weights
    ema_decay: float = 0.9999

    # --- Logging / checkpointing cadence (in optimizer steps) ---
    log_interval: int = 100
    save_interval: int = 5000
    sample_interval: int = 5000

    # --- Device / data loading ---
    device: str = "cuda"                # falls back to CPU in DDPMTrainer if unavailable
    num_workers: int = 4

    # --- Output locations ---
    output_dir: str = "./outputs"
    checkpoint_dir: str = "./checkpoints"

2. 完整DDPM类

Python
# diffusion.py
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Tuple

class GaussianDiffusion:
    """
    Complete Gaussian diffusion process (DDPM-style: fixed variance
    schedule, epsilon-prediction objective).

    All schedule-derived tensors are precomputed on CPU at construction
    time; ``_extract`` moves the gathered values to the right device
    on demand.
    """
    def __init__(
        self,
        timesteps: int = 1000,
        beta_schedule: str = "linear",
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
    ):
        """
        Args:
            timesteps: number of diffusion steps T.
            beta_schedule: "linear" or "cosine".
            beta_start: first beta of the linear schedule (ignored for cosine).
            beta_end: last beta of the linear schedule (ignored for cosine).

        Raises:
            ValueError: if ``beta_schedule`` is not recognized.
        """
        self.timesteps = timesteps

        # Build the beta schedule.
        if beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps)
        elif beta_schedule == "cosine":
            betas = self._cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"Unknown beta schedule: {beta_schedule}")

        self.betas = betas
        self.alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar at "step -1" defined as 1.
        self.alphas_cumprod_prev = torch.cat([  # torch.cat joins along an existing dim
            torch.tensor([1.0]),
            self.alphas_cumprod[:-1]
        ])

        # Precompute frequently used quantities.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)

        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # posterior_variance[0] is exactly 0 (alphas_cumprod_prev[0] == 1),
        # so "clip" the log by substituting the value at index 1 to avoid log(0).
        self.posterior_log_variance_clipped = torch.log(
            torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])
        )

        # Coefficients of the posterior mean (linear combination of x_0 and x_t).
        self.posterior_mean_coef1 = (
            betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
        )

    def _cosine_beta_schedule(self, timesteps: int, s: float = 0.008) -> torch.Tensor:
        """Cosine beta schedule; ``s`` is a small offset that keeps betas
        finite near t=0.  Betas are clipped to [1e-4, 0.9999] for stability."""
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)

    def q_sample(
        self,
        x_0: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward diffusion: sample from q(x_t | x_0) in closed form.

        Args:
            x_0: original images [B, C, H, W].
            t: per-sample timesteps [B].
            noise: optional pre-drawn standard normal noise (same shape as x_0).

        Returns:
            x_t: noised images.
            noise: the noise that was actually used.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_0.shape)
        sqrt_one_minus_alpha_cumprod_t = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
        )

        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        x_t = sqrt_alpha_cumprod_t * x_0 + sqrt_one_minus_alpha_cumprod_t * noise
        return x_t, noise

    def q_posterior_mean_variance(
        self,
        x_0: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mean and (log-)variance of the posterior q(x_{t-1} | x_t, x_0).

        Returns (mean, variance, clipped log-variance), each broadcastable
        against x_t.
        """
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 +
            self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance = self._extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance

    def predict_x0_from_eps(
        self,
        x_t: torch.Tensor,
        eps: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Recover x_0 from x_t and a predicted noise eps by inverting the
        forward-process equation.
        """
        sqrt_recip_alphas_cumprod_t = self._extract(
            self.sqrt_recip_alphas_cumprod, t, x_t.shape
        )
        sqrt_recipm1_alphas_cumprod_t = self._extract(
            self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
        )
        return sqrt_recip_alphas_cumprod_t * x_t - sqrt_recipm1_alphas_cumprod_t * eps

    def p_mean_variance(
        self,
        model: nn.Module,
        x_t: torch.Tensor,
        t: torch.Tensor,
        clip_denoised: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mean and variance of the reverse step p(x_{t-1} | x_t).

        Args:
            model: noise-prediction network, called as model(x_t, t).
            clip_denoised: clamp the predicted x_0 to [-1, 1] (valid image
                range under the (0.5, 0.5) normalization used in training).
        """
        # Predict the noise component.
        eps_pred = model(x_t, t)

        # Reconstruct x_0 from the predicted noise.
        x_0_pred = self.predict_x0_from_eps(x_t, eps_pred, t)

        if clip_denoised:
            x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)

        # Plug the x_0 estimate into the exact posterior formula.
        model_mean, posterior_variance, posterior_log_variance = \
            self.q_posterior_mean_variance(x_0_pred, x_t, t)

        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(
        self,
        model: nn.Module,
        x_t: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Draw one sample x_{t-1} ~ p(x_{t-1} | x_t).
        """
        model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t)

        if noise is None:
            noise = torch.randn_like(x_t)

        # Add noise only for t > 0; the final step (t == 0) is deterministic.
        nonzero_mask = (t != 0).float().view(-1, 1, 1, 1)  # reshape for broadcasting
        x_prev = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise

        return x_prev

    def p_sample_loop(
        self,
        model: nn.Module,
        shape: Tuple[int, ...],
        device: torch.device,
        noise: Optional[torch.Tensor] = None,
        progress: bool = True,
    ) -> torch.Tensor:
        """
        Full ancestral sampling loop: start from pure noise (or the given
        ``noise`` tensor) and denoise from t = T-1 down to t = 0.

        Args:
            shape: output shape, e.g. (B, C, H, W); shape[0] is the batch size.
            progress: show a tqdm progress bar (tqdm imported lazily).
        """
        if noise is not None:
            x_t = noise
        else:
            x_t = torch.randn(shape, device=device)

        indices = list(range(self.timesteps))[::-1]

        if progress:
            from tqdm import tqdm
            indices = tqdm(indices, desc="Sampling")

        for i in indices:
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            x_t = self.p_sample(model, x_t, t)

        return x_t

    def training_losses(
        self,
        model: nn.Module,
        x_0: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Simple training objective: MSE between the true noise and the
        model's noise prediction at the sampled timesteps.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        x_t, _ = self.q_sample(x_0, t, noise)
        eps_pred = model(x_t, t)

        # Mean over batch and all pixel dimensions.
        loss = torch.nn.functional.mse_loss(eps_pred, noise)
        return loss

    def _extract(
        self,
        a: torch.Tensor,
        t: torch.Tensor,
        x_shape: Tuple[int, ...],
    ) -> torch.Tensor:
        """Gather a[t] per batch element and reshape to [B, 1, 1, ...] so it
        broadcasts against a tensor of shape ``x_shape``."""
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        return out.view(batch_size, *((1,) * (len(x_shape) - 1)))

3. 完整训练器

Python
# trainer.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any
import json

class DDPMTrainer:
    """
    DDPM training loop: AdamW optimization with gradient clipping, cosine
    LR annealing, EMA weight tracking, validation, periodic sampling,
    TensorBoard logging and checkpointing.
    """
    def __init__(
        self,
        model: nn.Module,
        diffusion: 'GaussianDiffusion',
        config: 'DDPMConfig',
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
    ):
        """
        Args:
            model: noise-prediction network.  Must carry an ``init_kwargs``
                attribute (attached in train.py) when ``config.use_ema`` is
                True, so the EMA copy can be re-instantiated.
            diffusion: diffusion process providing losses and sampling.
            config: hyperparameter configuration.
            train_loader: yields (images, labels) batches.
            val_loader: optional validation loader.
        """
        self.model = model
        self.diffusion = diffusion
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Device (fall back to CPU when CUDA is unavailable).
        self.device = torch.device(config.device if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Cosine annealing over the whole run; stepped once per optimizer
        # step, hence T_max = epochs * batches-per-epoch.
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.num_epochs * len(train_loader),
        )

        # Frozen EMA copy of the weights (kept in eval mode; see below).
        if config.use_ema:
            self.ema_model = self._create_ema_model()
        else:
            self.ema_model = None

        # TensorBoard logging
        self.writer = SummaryWriter(log_dir=os.path.join(config.output_dir, "logs"))

        # Checkpoint directory
        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training state
        self.global_step = 0
        self.epoch = 0
        self.best_loss = float('inf')

    def _create_ema_model(self):
        """Create a frozen copy of the model for EMA weight tracking."""
        # Relies on model.init_kwargs being attached by the caller (train.py).
        ema_model = type(self.model)(**self.model.init_kwargs).to(self.device)
        ema_model.load_state_dict(self.model.state_dict())
        ema_model.requires_grad_(False)
        # Fix: keep the EMA copy permanently in eval mode.  It is only ever
        # used for sampling/evaluation; the default train mode would leave
        # dropout active during sampling.
        ema_model.eval()
        return ema_model

    def _update_ema(self):
        """Exponential-moving-average update of the EMA model's weights."""
        if self.ema_model is None:
            return

        with torch.no_grad():  # no autograd bookkeeping needed here
            # ema = decay * ema + (1 - decay) * current
            for ema_param, param in zip(  # zip pairs parameters positionally
                self.ema_model.parameters(),
                self.model.parameters()
            ):
                ema_param.data.mul_(self.config.ema_decay).add_(
                    param.data, alpha=1 - self.config.ema_decay
                )
            # Fix: also mirror non-parameter buffers (e.g. BatchNorm running
            # statistics) so the EMA model stays usable on its own.
            for ema_buf, buf in zip(self.ema_model.buffers(), self.model.buffers()):
                ema_buf.data.copy_(buf.data)

    def train_epoch(self) -> Dict[str, float]:
        """Run one training epoch; returns {"train_loss": epoch average}."""
        self.model.train()  # enable dropout etc.
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (images, _) in enumerate(self.train_loader):  # labels unused
            images = images.to(self.device)
            batch_size = images.shape[0]

            # Sample a uniformly random timestep per image.
            t = torch.randint(
                0,
                self.config.timesteps,
                (batch_size,),
                device=self.device,
            ).long()

            # Diffusion training loss (noise-prediction MSE).
            loss = self.diffusion.training_losses(self.model, images, t)

            # Backward pass
            self.optimizer.zero_grad()  # clear stale gradients
            loss.backward()

            # Gradient clipping (guards against exploding gradients).
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.grad_clip,
            )

            self.optimizer.step()  # apply the update
            self.scheduler.step()  # per-step LR annealing

            # Track the EMA weights after every optimizer step.
            self._update_ema()

            # Bookkeeping
            total_loss += loss.item()  # scalar tensor -> Python float
            num_batches += 1
            self.global_step += 1

            # Periodic logging
            if self.global_step % self.config.log_interval == 0:
                lr = self.scheduler.get_last_lr()[0]
                self.writer.add_scalar("train/loss", loss.item(), self.global_step)
                self.writer.add_scalar("train/lr", lr, self.global_step)

                print(
                    f"Epoch {self.epoch}, Step {self.global_step}, "
                    f"Loss: {loss.item():.4f}, LR: {lr:.6f}"
                )

            # Periodic checkpointing
            if self.global_step % self.config.save_interval == 0:
                self.save_checkpoint()

            # Periodic sampling
            if self.global_step % self.config.sample_interval == 0:
                self.sample_and_save()

        avg_loss = total_loss / num_batches
        return {"train_loss": avg_loss}

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Average the training objective over the validation set."""
        if self.val_loader is None:
            return {}

        # NOTE(review): validation uses the raw training model, not the EMA
        # copy.  To evaluate the EMA weights instead, use:
        #   model = self.ema_model if self.ema_model is not None else self.model
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for images, _ in self.val_loader:
            images = images.to(self.device)
            batch_size = images.shape[0]

            t = torch.randint(
                0,
                self.config.timesteps,
                (batch_size,),
                device=self.device,
            ).long()

            loss = self.diffusion.training_losses(self.model, images, t)
            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        self.writer.add_scalar("val/loss", avg_loss, self.global_step)

        return {"val_loss": avg_loss}

    @torch.no_grad()
    def sample_and_save(self, num_samples: int = 64):
        """Generate ``num_samples`` images and save them as a grid."""
        # Prefer the EMA weights when available — they generally yield
        # better samples than the raw training weights.
        model = self.ema_model if self.ema_model is not None else self.model

        # Fix: put the *sampling* model in eval mode.  The original only
        # switched self.model, leaving an EMA model's dropout active.
        model.eval()
        self.model.eval()

        samples = self.diffusion.p_sample_loop(
            model,
            shape=(num_samples, self.config.channels, self.config.image_size, self.config.image_size),
            device=self.device,
        )

        # Write the grid to disk.
        self._save_images(samples, f"sample_step_{self.global_step}.png")

        # Restore training mode (the frozen EMA copy stays in eval mode).
        self.model.train()

    def _save_images(self, images: torch.Tensor, filename: str):
        """De-normalize images from [-1, 1] to [0, 1] and save an 8-wide grid."""
        import torchvision

        # Undo the (0.5, 0.5) normalization used by the data pipeline.
        images = (images + 1) / 2
        images = torch.clamp(images, 0, 1)

        # Save under <output_dir>/samples/.
        save_path = Path(self.config.output_dir) / "samples" / filename
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torchvision.utils.save_image(images, save_path, nrow=8)

        print(f"Saved samples to {save_path}")

    def save_checkpoint(self, is_best: bool = False):
        """Persist model/optimizer/scheduler state plus training progress."""
        checkpoint = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "global_step": self.global_step,
            "epoch": self.epoch,
            "config": self.config.__dict__,
        }

        if self.ema_model is not None:
            checkpoint["ema_model"] = self.ema_model.state_dict()

        # Always refresh the rolling "latest" checkpoint.
        latest_path = self.checkpoint_dir / "latest.pt"
        torch.save(checkpoint, latest_path)

        # Best-so-far checkpoint (by validation loss; see train()).
        if is_best:
            best_path = self.checkpoint_dir / "best.pt"
            torch.save(checkpoint, best_path)

        # Keep a permanent snapshot every save_interval * 10 steps.
        if self.global_step % (self.config.save_interval * 10) == 0:
            step_path = self.checkpoint_dir / f"checkpoint_{self.global_step}.pt"
            torch.save(checkpoint, step_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Restore a checkpoint written by save_checkpoint (resume training)."""
        # weights_only=True restricts unpickling to tensors/primitive types.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)

        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scheduler.load_state_dict(checkpoint["scheduler"])
        self.global_step = checkpoint["global_step"]
        self.epoch = checkpoint["epoch"]

        if "ema_model" in checkpoint and self.ema_model is not None:
            self.ema_model.load_state_dict(checkpoint["ema_model"])

        print(f"Loaded checkpoint from {checkpoint_path}")

    def train(self):
        """Full training loop: epochs of train + optional validation,
        best-model tracking and per-epoch checkpointing."""
        print(f"Starting training for {self.config.num_epochs} epochs")
        print(f"Device: {self.device}")
        print(f"Total steps: {self.config.num_epochs * len(self.train_loader)}")

        # Resumes from self.epoch when a checkpoint was loaded.
        for epoch in range(self.epoch, self.config.num_epochs):
            self.epoch = epoch
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

            # Train
            train_metrics = self.train_epoch()
            print(f"Train Loss: {train_metrics['train_loss']:.4f}")

            # Validate
            val_metrics = self.validate()
            if val_metrics:
                print(f"Val Loss: {val_metrics['val_loss']:.4f}")

                # Track the best model by validation loss.
                if val_metrics['val_loss'] < self.best_loss:
                    self.best_loss = val_metrics['val_loss']
                    self.save_checkpoint(is_best=True)

            # Checkpoint at the end of every epoch.
            self.save_checkpoint()

        print("\nTraining completed!")
        self.writer.close()

4. 训练脚本

Python
# train.py
import argparse
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

from config import DDPMConfig
from model import UNet
from diffusion import GaussianDiffusion
from trainer import DDPMTrainer

def get_dataloader(config: DDPMConfig, train: bool = True):
    """Build a DataLoader for the dataset named in ``config``.

    CIFAR-10 gets random horizontal flips during training; MNIST uses a
    fixed resize + normalize pipeline for both splits (digits are
    orientation-sensitive, so no flip).
    """
    # RGB pipeline: optional flip, then tensor conversion and [-1, 1] scaling.
    rgb_steps = []
    if train:
        rgb_steps.append(transforms.RandomHorizontalFlip())
    rgb_steps += [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    rgb_transform = transforms.Compose(rgb_steps)

    # Dataset selection (downloads on first use).
    if config.dataset == "cifar10":
        dataset = torchvision.datasets.CIFAR10(
            root=config.data_path,
            train=train,
            download=True,
            transform=rgb_transform,
        )
    elif config.dataset == "mnist":
        # Single-channel pipeline, identical for train and eval.
        dataset = torchvision.datasets.MNIST(
            root=config.data_path,
            train=train,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(config.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]),
        )
    else:
        raise ValueError(f"Unknown dataset: {config.dataset}")

    # Wrap in a batched loader.
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=train,                    # shuffle only the training split
        num_workers=config.num_workers,
        pin_memory=True,                  # faster host-to-GPU transfers
        drop_last=train,                  # constant batch size during training
    )

def main():
    """Entry point: parse CLI args, assemble model/diffusion/trainer, train."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--resume", type=str, default=None)
    args = parser.parse_args()

    # Configuration: an optional JSON file overrides the dataclass defaults.
    if args.config:
        import json
        with open(args.config, 'r') as f:  # context manager closes the file
            config_dict = json.load(f)
        config = DDPMConfig(**config_dict)
    else:
        config = DDPMConfig()

    print("Configuration:")
    for key, value in config.__dict__.items():
        print(f"  {key}: {value}")

    # Data loaders
    train_loader = get_dataloader(config, train=True)
    val_loader = get_dataloader(config, train=False)

    print(f"\nTrain batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Model.  Fix: build the constructor kwargs once and reuse the same dict
    # as model.init_kwargs (DDPMTrainer needs it to clone the EMA model);
    # the original wrote the dict out twice, which could silently drift.
    unet_kwargs = {
        "in_channels": config.channels,
        "out_channels": config.channels,
        "base_channels": config.base_channels,
        "channel_mults": config.channel_mults,
        "num_res_blocks": config.num_res_blocks,
        "attention_resolutions": config.attention_resolutions,
        "dropout": config.dropout,
    }
    model = UNet(**unet_kwargs)
    model.init_kwargs = unet_kwargs

    num_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {num_params / 1e6:.2f}M")

    # Diffusion process
    diffusion = GaussianDiffusion(
        timesteps=config.timesteps,
        beta_schedule=config.beta_schedule,
        beta_start=config.beta_start,
        beta_end=config.beta_end,
    )

    # Trainer
    trainer = DDPMTrainer(
        model=model,
        diffusion=diffusion,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
    )

    # Optionally resume from a checkpoint.
    if args.resume:
        trainer.load_checkpoint(args.resume)

    # Run the full training loop.
    trainer.train()

if __name__ == "__main__":
    main()

5. 本章总结

核心组件

  1. GaussianDiffusion: 完整的扩散过程实现
  2. DDPMTrainer: 完整的训练流程
  3. 配置系统: 灵活的配置管理
  4. 日志系统: TensorBoard集成
  5. 检查点: 自动保存和恢复

关键特性

  • EMA: 指数移动平均提升生成质量
  • 学习率调度: 余弦退火
  • 梯度裁剪: 防止梯度爆炸
  • 混合精度: 加速训练(可扩展)
  • 验证: 监控过拟合

使用方法

Bash
# 训练
python train.py

# 恢复训练
python train.py --resume checkpoints/latest.pt

# 使用自定义配置
python train.py --config config.json

🔗 下一步

实现了完整的DDPM后,我们将学习模型训练与评估,包括分布式训练、超参数调优等。

→ 下一步:05-模型训练与评估.md