跳转至

05 - 模型训练与评估

学习时间: 3.5小时 重要性: ⭐⭐⭐⭐⭐ 将理论转化为实际可用的模型


🎯 学习目标

完成本章后,你将能够: - 实现完整的DDPM训练流程 - 掌握模型评估的方法 - 学习监控训练进度的技巧 - 实现检查点保存和恢复 - 评估生成图像的质量


1. 完整训练流程

1.1 训练类设计

Python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
import time
from tqdm import tqdm

class DDPMTrainer:
    """DDPM trainer.

    Owns the noise schedule, optimizer, LR scheduler, TensorBoard writer,
    and the training/validation loops; tracks the state needed for
    checkpointing and resume.
    """

    def __init__(self, model, train_loader, val_loader, config, device='cuda'):
        """
        Initialize the trainer.

        Args:
            model: DDPM noise-prediction network, called as model(x_t, t)
            train_loader: training DataLoader yielding (image, label) batches
            val_loader: validation DataLoader
            config: training configuration dict (T, learning_rate, save_dir, ...)
            device: compute device ('cuda' or 'cpu')
        """
        self.model = model.to(device)  # move model to GPU/CPU
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device

        # Build the diffusion noise schedule.
        self.T = config['T']
        self.alphas, self.betas, self.alphas_cumprod = self._get_schedule()
        # Only alphas_cumprod is moved to the device here; alphas/betas stay
        # on CPU (they are indexed with Python ints during sampling).
        self.alphas_cumprod = self.alphas_cumprod.to(device)

        # Optimizer (AdamW, see _get_optimizer).
        self.optimizer = self._get_optimizer()

        # Per-batch cosine LR schedule (see _get_scheduler).
        self.scheduler = self._get_scheduler()

        # TensorBoard writer for losses, LR and generated samples.
        self.writer = SummaryWriter(config['log_dir'])

        # Mutable training state, saved/restored by checkpoints.
        self.current_epoch = 0
        self.global_step = 0
        self.best_loss = float('inf')

        # Ensure the checkpoint directory exists.
        os.makedirs(config['save_dir'], exist_ok=True)

    def _get_schedule(self):
        """Build a linear beta schedule and its derived quantities.

        Returns:
            (alphas, betas, alphas_cumprod): 1-D tensors of length T, where
            alphas = 1 - betas and alphas_cumprod is the running product.
        """
        # Inline linear-schedule implementation (the tutorial's ddpm_utils
        # module is an example; adjust to your project's layout).
        beta_start = self.config['beta_start']
        beta_end = self.config['beta_end']
        betas = torch.linspace(beta_start, beta_end, self.T)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        return alphas, betas, alphas_cumprod

    def _get_optimizer(self):
        """Create the AdamW optimizer from config (learning_rate, weight_decay)."""
        return optim.AdamW(
            self.model.parameters(),
            lr=self.config['learning_rate'],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.config['weight_decay']
        )

    def _get_scheduler(self):
        """Create a cosine-annealing LR scheduler, stepped once per batch."""
        total_steps = self.config['num_epochs'] * len(self.train_loader)
        return optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=total_steps,
            eta_min=self.config['min_lr']
        )

    def train_step(self, x_0):
        """
        Run one optimization step on a batch.

        Args:
            x_0: original images [batch_size, channels, height, width];
                 assumed normalized to [-1, 1] -- TODO confirm with the
                 data pipeline.

        Returns:
            Scalar training loss (float).
        """
        self.model.train()  # enable training mode (dropout, batchnorm, ...)

        # Sample a random timestep t for every image in the batch.
        batch_size = x_0.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=self.device)

        # Sample Gaussian noise eps ~ N(0, I).
        noise = torch.randn_like(x_0)

        # Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
        sqrt_alpha_t_bar = torch.sqrt(self.alphas_cumprod[t]).view(-1, 1, 1, 1)  # reshape for broadcasting
        sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - self.alphas_cumprod[t]).view(-1, 1, 1, 1)
        x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise

        # Model predicts the injected noise.
        predicted_noise = self.model(x_t, t)

        # Simple DDPM objective: MSE between true and predicted noise.
        loss = nn.functional.mse_loss(predicted_noise, noise)

        # Backpropagation.
        self.optimizer.zero_grad()  # clear accumulated gradients
        loss.backward()  # compute gradients

        # Optional gradient clipping for training stability.
        if self.config['grad_clip'] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config['grad_clip']
            )

        # Parameter and LR updates (the scheduler is stepped per batch).
        self.optimizer.step()  # apply the parameter update
        self.scheduler.step()

        return loss.item()  # convert one-element tensor to a Python float

    def validate(self):
        """Compute the average denoising loss over the validation set.

        Returns:
            Mean MSE loss across validation batches (float).
        """
        self.model.eval()
        total_loss = 0

        with torch.no_grad():  # no gradients needed; saves memory
            for x_0, _ in tqdm(self.val_loader, desc='验证'):
                x_0 = x_0.to(self.device)

                # Same noising procedure as in train_step.
                batch_size = x_0.shape[0]
                t = torch.randint(0, self.T, (batch_size,), device=self.device)

                # Sample Gaussian noise.
                noise = torch.randn_like(x_0)

                # Forward diffusion to x_t.
                sqrt_alpha_t_bar = torch.sqrt(self.alphas_cumprod[t]).view(-1, 1, 1, 1)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - self.alphas_cumprod[t]).view(-1, 1, 1, 1)
                x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise

                # Predict the noise.
                predicted_noise = self.model(x_t, t)

                # Accumulate the MSE loss.
                loss = nn.functional.mse_loss(predicted_noise, noise)
                total_loss += loss.item()

        return total_loss / len(self.val_loader)

    def train(self, num_epochs=None):
        """
        Run the full training loop with validation, logging and checkpoints.

        Args:
            num_epochs: number of epochs to train (None -> config['num_epochs'])
        """
        if num_epochs is None:
            num_epochs = self.config['num_epochs']

        print(f"开始训练,共 {num_epochs} 轮")
        print(f"设备: {self.device}")
        print(f"模型参数量: {sum(p.numel() for p in self.model.parameters()):,}")

        start_time = time.time()

        # Start from current_epoch so a checkpoint-resumed run continues
        # counting epochs where it left off.
        for epoch in range(self.current_epoch, self.current_epoch + num_epochs):
            self.current_epoch = epoch
            epoch_loss = 0
            epoch_start_time = time.time()

            # --- Training phase ---
            self.model.train()
            pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

            for x_0, _ in pbar:
                x_0 = x_0.to(self.device)

                # One optimization step.
                loss = self.train_step(x_0)
                epoch_loss += loss

                # Show the running loss in the progress bar.
                pbar.set_postfix({'loss': f'{loss:.4f}'})

                # Per-step TensorBoard logging (loss and current LR).
                self.writer.add_scalar('Loss/train', loss, self.global_step)
                self.writer.add_scalar('Learning_Rate',
                                     self.optimizer.param_groups[0]['lr'],
                                     self.global_step)
                self.global_step += 1

            epoch_loss /= len(self.train_loader)
            epoch_time = time.time() - epoch_start_time

            # --- Validation phase ---
            val_loss = self.validate()

            # Per-epoch TensorBoard logging.
            self.writer.add_scalar('Loss/val', val_loss, epoch)
            self.writer.add_scalar('Epoch_Time', epoch_time, epoch)

            # Console progress report.
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print(f"  训练损失: {epoch_loss:.4f}")
            print(f"  验证损失: {val_loss:.4f}")
            print(f"  学习率: {self.optimizer.param_groups[0]['lr']:.6f}")
            print(f"  时间: {epoch_time:.2f}s")

            # Keep the best model by validation loss.
            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self.save_checkpoint('best_model.pth', is_best=True)
                print(f"  ✓ 保存最佳模型 (损失: {val_loss:.4f})")

            # Periodic checkpoint.
            if (epoch + 1) % self.config['save_interval'] == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pth')

            # Periodic sample generation for visual inspection.
            if (epoch + 1) % self.config['sample_interval'] == 0:
                self.generate_samples(epoch + 1)

        total_time = time.time() - start_time
        print(f"\n训练完成!总时间: {total_time/3600:.2f}小时")
        print(f"最佳验证损失: {self.best_loss:.4f}")

        self.writer.close()

1.2 采样方法

Python
class DDPMTrainer:
    def sample(self, num_samples=16, num_steps=None, method='ddpm'):
        """
        Generate images by running the reverse diffusion process.

        Args:
            num_samples: number of images to generate
            num_steps: number of sampling steps (None -> full T steps)
            method: 'ddpm' (stochastic ancestral) or 'ddim' (deterministic)

        Returns:
            Generated images [num_samples, 3, image_size, image_size]

        Raises:
            ValueError: for an unknown sampling method
        """
        self.model.eval()

        if num_steps is None:
            num_steps = self.T

        # Start from pure Gaussian noise x_T.
        x_T = torch.randn(num_samples, 3, self.config['image_size'],
                         self.config['image_size']).to(self.device)

        if method == 'ddpm':
            return self._ddpm_sample(x_T, num_steps)
        elif method == 'ddim':
            return self._ddim_sample(x_T, num_steps)
        else:
            raise ValueError(f"未知的采样方法: {method}")

    def _ddpm_sample(self, x_T, num_steps):
        """DDPM ancestral sampling (Ho et al. 2020, Algorithm 2)."""
        x_t = x_T

        # Pick the timesteps to visit (evenly spaced when subsampling).
        if num_steps < self.T:
            step_indices = torch.linspace(0, self.T-1, num_steps, dtype=torch.long)
        else:
            step_indices = torch.arange(self.T)

        # alpha_bar_{t-1}, with the convention alpha_bar_{-1} = 1.
        # Fix: allocate the leading 1.0 on the same device/dtype as
        # alphas_cumprod -- torch.cat raises a cross-device RuntimeError on
        # CUDA when mixing a CPU tensor with a CUDA tensor.
        alphas_cumprod_prev = torch.cat([
            torch.ones(1, dtype=self.alphas_cumprod.dtype,
                       device=self.alphas_cumprod.device),
            self.alphas_cumprod[:-1]
        ])

        # NOTE(review): with num_steps < T this still uses the adjacent
        # alpha_bar_{t-1} rather than the previous *selected* step's
        # alpha_bar, which only approximates the strided posterior --
        # confirm this is intended.
        with torch.no_grad():
            for i, t in enumerate(reversed(step_indices)):
                t = t.item()
                alpha_t = self.alphas[t]
                beta_t = self.betas[t]
                alpha_t_bar = self.alphas_cumprod[t]
                alpha_t_bar_prev = alphas_cumprod_prev[t]

                # Predict the noise for the whole batch at timestep t.
                t_tensor = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long)
                predicted_noise = self.model(x_t, t_tensor)

                # Posterior mean: (x_t - beta_t/sqrt(1-abar_t) * eps) / sqrt(alpha_t).
                sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)

                mean = sqrt_recip_alpha_t * (
                    x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
                )

                # Add posterior noise on all but the final step.
                if i < len(step_indices) - 1:
                    posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                    noise = torch.randn_like(x_t)
                    x_t = mean + torch.sqrt(posterior_variance) * noise
                else:
                    x_t = mean

        return x_t

    def _ddim_sample(self, x_T, num_steps):
        """DDIM deterministic sampling (eta = 0, Song et al. 2021)."""
        x_t = x_T

        # Evenly spaced subset of timesteps.
        step_indices = torch.linspace(0, self.T-1, num_steps, dtype=torch.long)

        with torch.no_grad():
            for i, t in enumerate(reversed(step_indices)):
                t = t.item()
                alpha_t_bar = self.alphas_cumprod[t]

                # Predict the noise at timestep t.
                t_tensor = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long)
                predicted_noise = self.model(x_t, t_tensor)

                # Predicted clean image x_0 from the noise estimate.
                sqrt_alpha_t_bar = torch.sqrt(alpha_t_bar)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
                x_0_pred = (x_t - sqrt_one_minus_alpha_t_bar * predicted_noise) / sqrt_alpha_t_bar

                # Deterministic step to the previous selected timestep.
                if i < len(step_indices) - 1:
                    t_prev = step_indices[len(step_indices) - 2 - i].item()
                    alpha_t_prev_bar = self.alphas_cumprod[t_prev]

                    sqrt_alpha_t_prev_bar = torch.sqrt(alpha_t_prev_bar)
                    sqrt_one_minus_alpha_t_prev_bar = torch.sqrt(1 - alpha_t_prev_bar)

                    x_t = sqrt_alpha_t_prev_bar * x_0_pred + sqrt_one_minus_alpha_t_prev_bar * predicted_noise
                else:
                    x_t = x_0_pred

        return x_t

1.3 检查点管理

Python
class DDPMTrainer:
    def save_checkpoint(self, filename, is_best=False):
        """
        Save the full training state to config['save_dir']/filename.

        Args:
            filename: checkpoint file name
            is_best: if True, also copy the state to 'best_model.pth'
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_loss': self.best_loss,
            'config': self.config,
            'global_step': self.global_step
        }

        filepath = os.path.join(self.config['save_dir'], filename)
        torch.save(checkpoint, filepath)

        if is_best:
            best_filepath = os.path.join(self.config['save_dir'], 'best_model.pth')
            torch.save(checkpoint, best_filepath)

    def load_checkpoint(self, filename):
        """
        Restore training state from a checkpoint in config['save_dir'].

        Args:
            filename: checkpoint file name

        Raises:
            FileNotFoundError: if the checkpoint file does not exist
        """
        filepath = os.path.join(self.config['save_dir'], filename)

        if not os.path.exists(filepath):
            raise FileNotFoundError(f"检查点文件不存在: {filepath}")

        # weights_only=True restricts unpickling to tensors and primitive
        # containers, which is safe here because the checkpoint holds only
        # state dicts and a plain config dict.
        checkpoint = torch.load(filepath, map_location=self.device, weights_only=True)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Resume from the epoch after the one that was saved.
        self.current_epoch = checkpoint['epoch'] + 1
        self.best_loss = checkpoint['best_loss']
        self.global_step = checkpoint['global_step']

        # Fix: report the actual checkpoint path instead of the
        # "(unknown)" placeholder the original printed.
        print(f"从检查点恢复训练: {filepath}")
        print(f"  当前轮数: {self.current_epoch}")
        print(f"  最佳损失: {self.best_loss:.4f}")

1.4 生成和可视化样本

Python
class DDPMTrainer:
    def generate_samples(self, epoch, num_samples=16):
        """
        Generate samples with the current model and save a visualization.

        Args:
            epoch: current epoch number (used in titles and filenames)
            num_samples: number of images to generate
        """
        print(f"\n生成样本 (Epoch {epoch})...")

        # Fix: use the configured diffusion length instead of a hard-coded
        # 1000, so sampling stays consistent with the training schedule
        # (config['T'] may differ from 1000).
        samples = self.sample(num_samples=num_samples, num_steps=self.T, method='ddpm')

        # De-normalize from [-1, 1] to [0, 1] for display.
        samples = (samples + 1) / 2
        samples = samples.clamp(0, 1)

        # Log the image batch to TensorBoard.
        self.writer.add_images(f'Generated_Samples/Epoch_{epoch}', samples, epoch)

        # Also save a PNG grid to disk.
        from torchvision.utils import make_grid
        grid = make_grid(samples, nrow=4, padding=2, normalize=False)

        import matplotlib.pyplot as plt
        plt.figure(figsize=(12, 12))
        plt.imshow(grid.permute(1, 2, 0).cpu())
        plt.title(f'Generated Samples - Epoch {epoch}')
        plt.axis('off')

        save_path = os.path.join(self.config['save_dir'], f'samples_epoch_{epoch}.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"  样本已保存到: {save_path}")

2. 配置管理

2.1 训练配置

Python
def get_default_config():
    """
    Build the default training configuration.

    Returns:
        A flat dict combining model, training, logging and device settings.
    """
    model_cfg = {
        'image_size': 32,
        'in_channels': 3,
        'out_channels': 3,
        'model_dim': 128,
        'num_heads': 4,
    }

    train_cfg = {
        'T': 1000,
        'beta_start': 0.0001,
        'beta_end': 0.02,
        'num_epochs': 100,
        'batch_size': 128,
        'learning_rate': 1e-4,
        'min_lr': 1e-6,
        'weight_decay': 0.01,
        'grad_clip': 1.0,
    }

    io_cfg = {
        'save_dir': './checkpoints',
        'log_dir': './runs/ddpm',
        'save_interval': 10,
        'sample_interval': 5,
    }

    misc_cfg = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }

    # Merge the groups into the single flat dict callers expect.
    return {**model_cfg, **train_cfg, **io_cfg, **misc_cfg}

2.2 命令行参数

Python
import argparse

def parse_args():
    """Parse the DDPM training command-line arguments."""
    parser = argparse.ArgumentParser(description='DDPM训练')

    # (flag, add_argument kwargs) specs, grouped: data / model / training / misc.
    arg_specs = [
        ('--dataset', dict(type=str, default='cifar10',
                           choices=['cifar10', 'imagenet'], help='数据集名称')),
        ('--data_dir', dict(type=str, default='./data', help='数据集目录')),
        ('--image_size', dict(type=int, default=32, help='图像大小')),
        ('--model_dim', dict(type=int, default=128, help='模型维度')),
        ('--batch_size', dict(type=int, default=128, help='批量大小')),
        ('--num_epochs', dict(type=int, default=100, help='训练轮数')),
        ('--learning_rate', dict(type=float, default=1e-4, help='学习率')),
        ('--num_workers', dict(type=int, default=4, help='数据加载工作进程数')),
        ('--resume', dict(type=str, default=None, help='恢复训练的检查点文件')),
        ('--device', dict(type=str, default=None, help='设备 (cuda/cpu)')),
    ]
    for flag, kwargs in arg_specs:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()

3. 完整训练脚本

Python
def main():
    """Entry point: wire up config, data, model and trainer, then train."""
    args = parse_args()

    # Start from the defaults and apply the command-line overrides.
    config = get_default_config()
    overrides = {
        'image_size': args.image_size,
        'model_dim': args.model_dim,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.learning_rate,
        'device': args.device if args.device else config['device'],
    }
    config.update(overrides)

    # Resolve the compute device.
    device = torch.device(config['device'])
    print(f"使用设备: {device}")

    # Build the data loaders.
    print(f"\n加载 {args.dataset} 数据集...")
    from ddpm_data import prepare_data  # example import; adjust to your project
    train_loader, val_loader = prepare_data(
        dataset_name=args.dataset,
        batch_size=config['batch_size'],
        image_size=config['image_size'],
        num_workers=args.num_workers,
        augment=True
    )

    # Build the noise-prediction network.
    print("\n创建模型...")
    from unet_model import UNet  # example import; adjust to your project
    model = UNet(
        in_channels=config['in_channels'],
        out_channels=config['out_channels'],
        model_dim=config['model_dim']
    )

    # Assemble the trainer.
    trainer = DDPMTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device
    )

    # Optionally resume from a checkpoint.
    if args.resume:
        trainer.load_checkpoint(args.resume)

    print("\n开始训练...")
    trainer.train(num_epochs=config['num_epochs'])

    print("\n训练完成!")

if __name__ == '__main__':
    main()

4. 模型评估

4.1 生成质量评估

Python
def evaluate_generation_quality(model, dataloader, num_samples=100, device='cuda'):
    """
    Evaluate generation quality against real data.

    Args:
        model: trained noise-prediction model
        dataloader: loader over (image, label) batches of real data
        num_samples: number of images to collect and generate
        device: compute device

    Returns:
        Dict with 'FID' and 'IS' entries.
        NOTE(review): calculate_inception_score returns a (mean, std) tuple,
        so the 'IS' entry is a tuple, not a scalar -- confirm this is intended.
    """
    model.eval()

    # Collect at least num_samples real images, then truncate.
    real_images = []
    for images, _ in dataloader:
        real_images.append(images)
        if len(real_images) * images.size(0) >= num_samples:
            break
    real_images = torch.cat(real_images, dim=0)[:num_samples]

    # Generate images.
    # NOTE(review): get_schedule / ddpm_sample are not defined in this file;
    # they belong to the example ddpm_utils module (see the commented import).
    # Adjust to your project, or use DDPMTrainer.sample() instead.
    # from ddpm_utils import ddpm_sample
    alphas, betas, alphas_cumprod = get_schedule(1000)
    x_T = torch.randn(num_samples, 3, 32, 32).to(device)
    generated_images = ddpm_sample(model, x_T, 1000, alphas, betas, alphas_cumprod, device)

    # Frechet Inception Distance (distribution similarity; lower is better).
    fid_score = calculate_fid(real_images, generated_images)

    # Inception Score (quality/diversity; higher mean is better).
    is_score = calculate_inception_score(generated_images)

    return {
        'FID': fid_score,
        'IS': is_score
    }

def calculate_fid(real_images, generated_images):
    """
    Compute the Frechet Inception Distance (FID) between two image sets.

    Args:
        real_images: real images [N, 3, H, W]
        generated_images: generated images [N, 3, H, W]

    Returns:
        FID score (float; lower is better)
    """
    from torchvision.models import inception_v3, Inception_V3_Weights
    from scipy import linalg
    import numpy as np

    # Load Inception and strip the classifier to get 2048-d pool features.
    inception = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
    inception.fc = torch.nn.Identity()  # replace classifier; output 2048-d features
    inception = inception.eval()
    # Fix: move input batches to the model's device -- previously a
    # CUDA-resident image tensor fed to the CPU-resident Inception model
    # raised a cross-device error.
    feat_device = next(inception.parameters()).device

    def get_features(images):
        """Extract Inception pool features in mini-batches of 32."""
        features = []
        with torch.no_grad():
            for i in range(0, len(images), 32):
                batch = images[i:i+32].to(feat_device)
                # Inception v3 expects 299x299 inputs.
                batch_resized = torch.nn.functional.interpolate(
                    batch, size=(299, 299), mode='bilinear', align_corners=False
                )
                feat = inception(batch_resized)
                features.append(feat)
        return torch.cat(features, dim=0).cpu().numpy()

    # Extract features for both sets.
    real_features = get_features(real_images)
    gen_features = get_features(generated_images)

    # Gaussian statistics of each feature set.
    mu_real = np.mean(real_features, axis=0)
    mu_gen = np.mean(gen_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    sigma_gen = np.cov(gen_features, rowvar=False)

    # FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2}).
    diff = mu_real - mu_gen
    covmean, _ = linalg.sqrtm(sigma_real @ sigma_gen, disp=False)

    # Numerical noise can make sqrtm slightly complex; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_real + sigma_gen - 2 * covmean)

    return fid

def calculate_inception_score(images, splits=10):
    """
    Compute the Inception Score (IS).

    Args:
        images: images [N, 3, H, W]
        splits: number of splits for the mean/std estimate

    Returns:
        Tuple (mean, std) of the per-split IS values (higher mean is better).
        Note: this is a tuple, not a single scalar.
    """
    from torchvision.models import inception_v3, Inception_V3_Weights
    import numpy as np

    # IS needs class probabilities, so keep the full classifier.
    inception = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False).eval()
    # Fix: move input batches to the model's device -- previously a
    # CUDA-resident image tensor fed to the CPU-resident Inception model
    # raised a cross-device error.
    feat_device = next(inception.parameters()).device

    # Collect softmax class predictions in mini-batches of 32.
    preds = []
    with torch.no_grad():
        for i in range(0, len(images), 32):
            batch = images[i:i+32].to(feat_device)
            # Inception v3 expects 299x299 inputs.
            batch_resized = torch.nn.functional.interpolate(
                batch, size=(299, 299), mode='bilinear', align_corners=False
            )
            pred = torch.nn.functional.softmax(inception(batch_resized), dim=1)
            preds.append(pred)

    preds = torch.cat(preds, dim=0).cpu().numpy()

    # IS = exp(E_x[ KL(p(y|x) || p(y)) ]), estimated separately per split.
    split_scores = []
    for i in range(splits):
        part = preds[i * (len(preds) // splits): (i + 1) * (len(preds) // splits), :]
        kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
        kl = np.mean(np.sum(kl, 1))
        split_scores.append(np.exp(kl))

    return np.mean(split_scores), np.std(split_scores)

4.2 可视化评估

Python
def visualize_comparison(model, dataloader, num_samples=8, device='cuda'):
    """
    Plot real vs. generated images side by side and save 'comparison.png'.

    Args:
        model: trained noise-prediction model
        dataloader: loader over (image, label) batches of real data
        num_samples: number of image pairs to show
        device: compute device
    """
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid

    # Take one batch of real images.
    real_images, _ = next(iter(dataloader))
    real_images = real_images[:num_samples]

    # Generate images.
    # NOTE(review): get_schedule / ddpm_sample are not defined in this file;
    # they belong to the example ddpm_utils module (see the commented import).
    # Adjust to your project, or use DDPMTrainer.sample() instead.
    # from ddpm_utils import ddpm_sample
    alphas, betas, alphas_cumprod = get_schedule(1000)
    x_T = torch.randn(num_samples, 3, 32, 32).to(device)
    generated_images = ddpm_sample(model, x_T, 1000, alphas, betas, alphas_cumprod, device)

    # De-normalize from [-1, 1] to [0, 1] for display.
    real_images = (real_images + 1) / 2
    real_images = real_images.clamp(0, 1)
    generated_images = (generated_images + 1) / 2
    generated_images = generated_images.clamp(0, 1)

    # Two rows of subplots: real on top, generated below.
    fig, axes = plt.subplots(2, num_samples, figsize=(2*num_samples, 4))

    for i in range(num_samples):
        # Real image.
        axes[0, i].imshow(real_images[i].permute(1, 2, 0))
        axes[0, i].set_title('真实')
        axes[0, i].axis('off')

        # Generated image.
        axes[1, i].imshow(generated_images[i].cpu().permute(1, 2, 0))
        axes[1, i].set_title('生成')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.savefig('comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

5. 训练监控

5.1 实时监控

Python
def monitor_training(trainer):
    """
    Plot the training and validation loss curves recorded in TensorBoard.

    Args:
        trainer: trainer whose config['log_dir'] points at the event files
    """
    import matplotlib.pyplot as plt
    from tensorboard.backend.event_processing import event_accumulator

    # Load the event files written during training.
    acc = event_accumulator.EventAccumulator(trainer.config['log_dir'])
    acc.Reload()

    # Pull the scalar series out of the log, keyed by plot label.
    curves = {
        '训练损失': [event.value for event in acc.Scalars('Loss/train')],
        '验证损失': [event.value for event in acc.Scalars('Loss/val')],
    }

    # Draw both curves on one figure and save it.
    plt.figure(figsize=(10, 5))
    for label, values in curves.items():
        plt.plot(values, label=label, alpha=0.8)
    plt.xlabel('步数')
    plt.ylabel('损失')
    plt.title('训练和验证损失')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('loss_curve.png', dpi=150, bbox_inches='tight')
    plt.show()

5.2 异常检测

Python
def find_loss_anomalies(train_loss, window, threshold):
    """Pure-NumPy anomaly scan: return the indices where the loss deviates
    from its moving average by more than `threshold` moving standard
    deviations. Internal helper for detect_anomalies; kept separate so it
    can be tested without TensorBoard logs."""
    import numpy as np

    # Not enough points for even one full window -> nothing to flag
    # (matches the original behavior, where the moving arrays came out empty).
    if len(train_loss) < window:
        return []

    # Moving average and moving standard deviation over a sliding window.
    moving_avg = np.convolve(train_loss, np.ones(window) / window, mode='valid')
    moving_std = np.array([
        np.std(train_loss[i:i+window]) for i in range(len(train_loss) - window + 1)
    ])

    # Flag the last point of each window whose deviation exceeds the band.
    anomalies = []
    for i in range(len(moving_avg)):
        if abs(train_loss[i + window - 1] - moving_avg[i]) > threshold * moving_std[i]:
            anomalies.append(i + window - 1)
    return anomalies

def detect_anomalies(trainer, threshold=3.0):
    """
    Detect anomalous steps in the training loss recorded in TensorBoard.

    Args:
        trainer: trainer whose config['log_dir'] points at the event files
        threshold: number of moving standard deviations that counts as anomalous

    Returns:
        List of anomalous step indices (possibly empty).
    """
    from tensorboard.backend.event_processing import event_accumulator
    import numpy as np

    # Load the recorded training-loss scalars.
    ea = event_accumulator.EventAccumulator(trainer.config['log_dir'])
    ea.Reload()
    train_loss = np.array([s.value for s in ea.Scalars('Loss/train')])

    # Run the pure-NumPy scan (window of 100 steps, as before).
    anomalies = find_loss_anomalies(train_loss, window=100, threshold=threshold)

    if anomalies:
        print(f"检测到 {len(anomalies)} 个异常点:")
        for idx in anomalies[:10]:  # show at most the first 10
            print(f"  步数 {idx}: 损失 = {train_loss[idx]:.4f}")
    else:
        print("未检测到异常")

    return anomalies

6. 总结

6.1 核心概念回顾

概念 说明
训练流程 数据加载→前向传播→损失计算→反向传播→参数更新
检查点 定期保存模型状态,便于恢复训练
评估指标 FID、IS等量化生成质量
监控 TensorBoard可视化训练过程
异常检测 识别训练中的问题

6.2 最佳实践

  1. 定期保存:定期保存检查点,防止训练中断
  2. 监控训练:使用TensorBoard监控训练过程
  3. 评估质量:使用FID、IS等指标评估生成质量
  4. 可视化结果:定期生成样本并可视化
  5. 调优超参数:根据训练结果调整超参数

6.3 学习建议

  1. 从小规模开始:先用小数据集和小模型验证流程
  2. 逐步扩展:成功后再扩大规模
  3. 记录实验:记录每次实验的配置和结果
  4. 对比分析:对比不同配置的效果

7. 推荐资源

工具

  • TensorBoard: 训练监控
  • Weights & Biases: 实验跟踪
  • MLflow: 模型管理

论文

  • DDPM: "Denoising Diffusion Probabilistic Models"
  • ADM: "Diffusion Models Beat GANs on Image Synthesis"

8. 自测问题

  1. 如何设计一个完整的训练流程?
  2. 检查点保存和恢复有什么作用?
  3. FID和IS分别衡量什么?
  4. 如何监控训练过程?
  5. 如何检测训练异常?

下一章: 06-扩散模型变体与进阶 - 探索扩散模型的各种改进和变体