# 06 - Training Techniques and Optimization

Study time: 3 hours | Importance: ⭐⭐⭐⭐⭐ | The key to training diffusion models efficiently

## 🎯 Learning Objectives

After completing this chapter, you will be able to:

- Understand the core techniques for training diffusion models
- Master data augmentation and preprocessing methods
- Learn optimizer and learning-rate scheduling strategies
- Understand the key factors behind training stability
- Implement an efficient training pipeline
## 1. Training Basics Review

### 1.1 The DDPM Training Objective

A diffusion model is trained to minimize the following loss:
\[L = \mathbb{E}_{t, x_0, \epsilon} \left[ \|\epsilon - \epsilon_\theta(x_t, t)\|^2 \right]\]
where:

- \(x_0\): the original image
- \(\epsilon\): the added Gaussian noise
- \(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon\): the noised image
- \(\epsilon_\theta(x_t, t)\): the noise predicted by the model
### 1.2 Training Procedure

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


def ddpm_loss(model, x_0, t, alphas_cumprod, device='cpu'):
    """
    Compute the DDPM loss.

    Args:
        model: denoising model
        x_0: original images [batch_size, channels, height, width]
        t: timesteps [batch_size]
        alphas_cumprod: cumulative product of the alphas
        device: device

    Returns:
        the loss value
    """
    # Sample Gaussian noise
    noise = torch.randn_like(x_0)
    # Compute the noised image x_t
    sqrt_alpha_t_bar = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)  # reshape for broadcasting
    sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
    x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
    # Predict the noise with the model
    predicted_noise = model(x_t, t)
    # MSE between predicted and true noise
    loss = nn.functional.mse_loss(predicted_noise, noise)
    return loss
```
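Later sections call a `get_schedule(T)` helper that comes from the earlier noise-schedule chapter. For self-containedness, here is a minimal sketch assuming the DDPM linear beta schedule and the `(alphas, betas, alphas_cumprod)` return order used in the training loops below:

```python
def get_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Minimal linear beta schedule (the DDPM default); a sketch assuming
    the (alphas, betas, alphas_cumprod) return order used in this chapter."""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, betas, alphas_cumprod
```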
## 2. Data Augmentation Techniques

### 2.1 Basic Augmentation

```python
from torchvision import transforms


def get_train_transforms(image_size=32, augment=True):
    """
    Build the training-time transforms.

    Args:
        image_size: target image size
        augment: whether to apply data augmentation

    Returns:
        a transform pipeline
    """
    if augment:
        return transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # map to [-1, 1]
        ])
    else:
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])


def get_val_transforms(image_size=32):
    """Build the validation-time transforms."""
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
```
### 2.2 Advanced Augmentation

```python
import random

from torchvision import transforms


class AdvancedAugmentation:
    """Randomly applied geometric augmentations."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):  # __call__ lets the instance be used like a function
        """Apply each augmentation independently with probability p."""
        if random.random() < self.p:
            img = transforms.RandomRotation(15)(img)
        if random.random() < self.p:
            img = transforms.functional.vflip(img)  # the outer check already randomizes
        if random.random() < self.p:
            img = transforms.RandomAffine(
                degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
            )(img)
        return img


# Usage example
advanced_transform = transforms.Compose([
    AdvancedAugmentation(p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
```
## 3. Optimizer Choice and Configuration

### 3.1 The AdamW Optimizer

AdamW is the most commonly used optimizer for training diffusion models:

```python
def get_optimizer(model, learning_rate=1e-4, weight_decay=0.01):
    """
    Build an AdamW optimizer.

    Args:
        model: the model
        learning_rate: learning rate
        weight_decay: weight decay

    Returns:
        the optimizer
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=weight_decay
    )
    return optimizer
```
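A common refinement is to exclude biases and normalization parameters from weight decay. The sketch below does this with a simple shape/name heuristic; the helper name and the heuristic itself are illustrative assumptions rather than part of the chapter's baseline recipe:

```python
def get_optimizer_grouped(model, learning_rate=1e-4, weight_decay=0.01):
    """AdamW with weight decay applied only to weight matrices; a sketch
    assuming norm/bias parameters can be identified by name or shape."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D params (biases, norm scales/shifts) typically skip weight decay
        if param.ndim <= 1 or name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    return optim.AdamW(
        [{'params': decay, 'weight_decay': weight_decay},
         {'params': no_decay, 'weight_decay': 0.0}],
        lr=learning_rate, betas=(0.9, 0.999), eps=1e-8
    )
```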
### 3.2 Learning-Rate Scheduling

#### 3.2.1 Cosine Annealing

```python
def get_cosine_scheduler(optimizer, T_max, eta_min=1e-6):
    """
    Cosine annealing learning-rate scheduler.

    Args:
        optimizer: the optimizer
        T_max: maximum number of iterations
        eta_min: minimum learning rate

    Returns:
        the scheduler
    """
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=T_max,
        eta_min=eta_min
    )
    return scheduler


# Usage example
optimizer = get_optimizer(model, learning_rate=1e-4)
scheduler = get_cosine_scheduler(optimizer, T_max=100000)
```
#### 3.2.2 Exponential Decay

```python
def get_exponential_scheduler(optimizer, gamma=0.999):
    """
    Exponential-decay learning-rate scheduler.

    Args:
        optimizer: the optimizer
        gamma: decay factor

    Returns:
        the scheduler
    """
    scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=gamma
    )
    return scheduler
```
#### 3.2.3 Warmup

```python
from torch.optim.lr_scheduler import LambdaLR


def get_warmup_scheduler(optimizer, warmup_steps):
    """
    Linear warmup scheduler. The optimizer's base lr acts as the target:
    LambdaLR multiplies it by the returned factor, so a separate target_lr
    argument would be ignored and is omitted here.

    Args:
        optimizer: the optimizer
        warmup_steps: number of warmup steps

    Returns:
        the scheduler
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # ramp linearly up to the base lr
        return 1.0

    return LambdaLR(optimizer, lr_lambda)
```
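In practice, warmup is usually chained with a decay schedule rather than used on its own. A minimal sketch of linear warmup followed by cosine decay, again via `LambdaLR` (the helper name and the `min_ratio` default are illustrative assumptions):

```python
import math

from torch.optim.lr_scheduler import LambdaLR


def get_warmup_cosine_scheduler(optimizer, warmup_steps, total_steps, min_ratio=0.01):
    """Linear warmup to the base lr, then cosine decay to min_ratio * base lr.
    A sketch; call scheduler.step() once per optimizer step."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_ratio + (1.0 - min_ratio) * cosine

    return LambdaLR(optimizer, lr_lambda)
```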
## 4. Training Stability Techniques

### 4.1 Gradient Clipping

To prevent exploding gradients:

```python
def train_step_with_grad_clip(model, optimizer, x_0, t, alphas_cumprod,
                              max_grad_norm=1.0, device='cpu'):
    """
    One training step with gradient clipping.

    Args:
        model: the model
        optimizer: the optimizer
        x_0: original images
        t: timesteps
        alphas_cumprod: cumulative product of the alphas
        max_grad_norm: maximum gradient norm
        device: device

    Returns:
        the loss value
    """
    model.train()          # switch to training mode
    optimizer.zero_grad()  # clear old gradients
    # Forward pass
    loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
    # Backward pass
    loss.backward()
    # Clip the global gradient norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # Update the parameters
    optimizer.step()
    return loss.item()     # convert the scalar tensor to a Python float
```
### 4.2 EMA (Exponential Moving Average)

Keeping an EMA of the weights can improve sample quality:

```python
class EMA:
    """Exponential moving average of model parameters."""

    def __init__(self, model, decay=0.9999):
        """
        Initialize the EMA.

        Args:
            model: the model
            decay: decay coefficient
        """
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        # Initialize the shadow parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Update the EMA parameters."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        """Copy the shadow parameters into the model."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        """Restore the original parameters."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}


# Usage example
ema = EMA(model, decay=0.9999)

# Inside the training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        # Normal training step
        loss = train_step(model, optimizer, ...)
        # Update the EMA after each optimizer step
        ema.update()

# Use the EMA weights for evaluation
ema.apply_shadow()
evaluate(model)
ema.restore()
```
### 4.3 Mixed-Precision Training

Mixed precision speeds up training and reduces GPU memory usage:

```python
from torch.amp import autocast, GradScaler


def train_mixed_precision(model, optimizer, dataloader, alphas_cumprod,
                          num_epochs, device='cuda'):
    """
    Mixed-precision training.

    Args:
        model: the model
        optimizer: the optimizer
        dataloader: the dataloader
        alphas_cumprod: cumulative product of the alphas
        num_epochs: number of epochs
        device: device
    """
    model.to(device)  # move to GPU/CPU
    scaler = GradScaler()
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for x_0, _ in dataloader:
            x_0 = x_0.to(device)
            # Sample random timesteps
            batch_size = x_0.shape[0]
            t = torch.randint(0, len(alphas_cumprod), (batch_size,), device=device)
            optimizer.zero_grad()
            # Forward pass under autocast (lower precision where safe)
            with autocast('cuda'):
                loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
            # Scaled backward pass to avoid fp16 underflow
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(dataloader):.4f}")
```
## 5. Batch-Level Training Techniques

### 5.1 Gradient Accumulation

When GPU memory is tight, gradient accumulation simulates a larger batch:

```python
def train_with_gradient_accumulation(model, optimizer, dataloader, alphas_cumprod,
                                     accumulation_steps=4, num_epochs=10, device='cuda'):
    """
    Training with gradient accumulation.

    Args:
        model: the model
        optimizer: the optimizer
        dataloader: the dataloader
        alphas_cumprod: cumulative product of the alphas
        accumulation_steps: number of accumulation steps
        num_epochs: number of epochs
        device: device
    """
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        epoch_loss = 0
        for i, (x_0, _) in enumerate(dataloader):  # enumerate yields index and element
            x_0 = x_0.to(device)
            # Sample random timesteps
            batch_size = x_0.shape[0]
            t = torch.randint(0, len(alphas_cumprod), (batch_size,), device=device)
            # Compute the loss, scaled so accumulated gradients average out
            loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
            loss = loss / accumulation_steps
            loss.backward()
            epoch_loss += loss.item() * accumulation_steps
            # Step once enough gradients have accumulated
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        # Flush leftover gradients when len(dataloader) is not divisible
        if len(dataloader) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(dataloader):.4f}")
```
### 5.2 Dynamic Batch Size

```python
def get_dynamic_batch_size(model, image_size, max_batch_size=32, device='cuda'):
    """
    Pick a batch size that fits in GPU memory by trial and error.

    Args:
        model: the model
        image_size: image size
        max_batch_size: maximum batch size to try
        device: device

    Returns:
        a batch size that fits
    """
    if device == 'cpu':
        return min(max_batch_size, 8)
    # Try progressively smaller batch sizes
    for batch_size in range(max_batch_size, 0, -4):
        try:  # catch CUDA out-of-memory errors
            x = torch.randn(batch_size, 3, image_size, image_size).to(device)
            t = torch.randint(0, 1000, (batch_size,)).to(device)
            with torch.no_grad():  # no gradients needed for the probe
                _ = model(x, t)
            return batch_size
        except RuntimeError as e:
            if 'out of memory' in str(e):
                torch.cuda.empty_cache()
                continue
            else:
                raise
    return 1  # fall back to the smallest batch size
```
## 6. Loss-Function Improvements

### 6.1 Timestep-Weighted Loss

Assign different weights to different timesteps:

```python
def weighted_ddpm_loss(model, x_0, t, alphas_cumprod, weights=None, device='cpu'):
    """
    Timestep-weighted DDPM loss.

    Args:
        model: the model
        x_0: original images
        t: timesteps
        alphas_cumprod: cumulative product of the alphas
        weights: per-timestep weights [T]
        device: device

    Returns:
        the weighted loss
    """
    # Sample Gaussian noise
    noise = torch.randn_like(x_0)
    # Compute the noised image
    sqrt_alpha_t_bar = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
    x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
    # Predict the noise
    predicted_noise = model(x_t, t)
    # Per-sample MSE
    loss = nn.functional.mse_loss(predicted_noise, noise, reduction='none')
    loss = loss.mean(dim=[1, 2, 3])  # [batch_size]
    # Apply the per-timestep weights
    if weights is not None:
        t_weights = weights.to(device)[t]  # move first so GPU timestep indices work
        loss = loss * t_weights
    return loss.mean()


# Build per-timestep weights (SNR weighting)
def create_snr_weights(alphas_cumprod):
    """
    Build weights from the signal-to-noise ratio.

    Args:
        alphas_cumprod: cumulative product of the alphas

    Returns:
        weights normalized to a maximum of 1
    """
    snr = alphas_cumprod / (1 - alphas_cumprod)
    weights = snr / snr.max()
    return weights
```
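A short usage sketch tying the two helpers together (variable names follow the training loops above):

```python
# Usage sketch: build the weights once, reuse them every step
weights = create_snr_weights(alphas_cumprod)
loss = weighted_ddpm_loss(model, x_0, t, alphas_cumprod,
                          weights=weights, device=device)
```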
### 6.2 Huber Loss

More robust to outliers:

```python
def huber_loss(model, x_0, t, alphas_cumprod, delta=1.0, device='cpu'):
    """
    Huber loss.

    Args:
        model: the model
        x_0: original images
        t: timesteps
        alphas_cumprod: cumulative product of the alphas
        delta: Huber threshold
        device: device

    Returns:
        the Huber loss
    """
    # Sample Gaussian noise
    noise = torch.randn_like(x_0)
    # Compute the noised image
    sqrt_alpha_t_bar = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
    x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
    # Predict the noise
    predicted_noise = model(x_t, t)
    # Compute the Huber loss
    residual = torch.abs(predicted_noise - noise)
    quadratic = torch.clamp(residual, max=delta)  # device-safe, unlike torch.min with a CPU scalar tensor
    linear = residual - quadratic
    loss = 0.5 * quadratic ** 2 + delta * linear
    return loss.mean()
```
## 7. A Complete Training Pipeline

```python
from torch.amp import autocast, GradScaler


def train_diffusion_model(model, train_loader, val_loader, num_epochs,
                          learning_rate=1e-4, device='cuda',
                          use_ema=True, use_mixed_precision=True):
    """
    Complete training pipeline for a diffusion model.

    Args:
        model: the model
        train_loader: training dataloader
        val_loader: validation dataloader
        num_epochs: number of epochs
        learning_rate: learning rate
        device: device
        use_ema: whether to maintain an EMA of the weights
        use_mixed_precision: whether to use mixed precision
    """
    # Build the noise schedule (get_schedule: see the sketch in Section 1.2)
    T = 1000
    alphas, betas, alphas_cumprod = get_schedule(T)
    alphas_cumprod = alphas_cumprod.to(device)
    # Optimizer and scheduler
    optimizer = get_optimizer(model, learning_rate=learning_rate)
    scheduler = get_cosine_scheduler(optimizer, T_max=num_epochs * len(train_loader))
    # EMA
    ema = EMA(model, decay=0.9999) if use_ema else None
    # Mixed precision
    scaler = GradScaler() if use_mixed_precision else None
    # Training loop
    model.to(device)
    best_loss = float('inf')
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        for x_0, _ in train_loader:
            x_0 = x_0.to(device)
            batch_size = x_0.shape[0]
            t = torch.randint(0, T, (batch_size,), device=device)
            optimizer.zero_grad()
            if use_mixed_precision:
                with autocast('cuda'):
                    loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
                loss.backward()
                optimizer.step()
            # Update the EMA weights
            if ema is not None:
                ema.update()
            train_loss += loss.item()
            # Step per batch, matching T_max = num_epochs * len(train_loader)
            scheduler.step()
        train_loss /= len(train_loader)
        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x_0, _ in val_loader:
                x_0 = x_0.to(device)
                batch_size = x_0.shape[0]
                t = torch.randint(0, T, (batch_size,), device=device)
                loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        # Progress report
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")
        print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")
        # Save the best model (with the EMA weights applied, if available)
        if val_loss < best_loss:
            best_loss = val_loss
            if ema is not None:
                ema.apply_shadow()
            torch.save(model.state_dict(), 'best_model.pth')
            if ema is not None:
                ema.restore()
            print(f"  ✓ Saved best model (loss: {best_loss:.4f})")
        # Periodic checkpoints
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': val_loss,
            }
            if ema is not None:
                checkpoint['ema_shadow'] = ema.shadow
            torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
    print("\nTraining complete!")
    print(f"Best validation loss: {best_loss:.4f}")
```
## 8. Training Monitoring and Visualization

### 8.1 TensorBoard Integration

```python
from torch.utils.tensorboard import SummaryWriter


def train_with_tensorboard(model, train_loader, val_loader, num_epochs,
                           log_dir='runs/diffusion', device='cuda'):
    """
    Monitor training with TensorBoard.

    Args:
        model: the model
        train_loader: training dataloader
        val_loader: validation dataloader
        num_epochs: number of epochs
        log_dir: TensorBoard log directory
        device: device
    """
    writer = SummaryWriter(log_dir)
    # Build the noise schedule
    T = 1000
    alphas, betas, alphas_cumprod = get_schedule(T)
    alphas_cumprod = alphas_cumprod.to(device)
    # Optimizer and scheduler
    optimizer = get_optimizer(model)
    scheduler = get_cosine_scheduler(optimizer, T_max=num_epochs * len(train_loader))
    model.to(device)
    global_step = 0
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for x_0, _ in train_loader:
            x_0 = x_0.to(device)
            batch_size = x_0.shape[0]
            t = torch.randint(0, T, (batch_size,), device=device)
            optimizer.zero_grad()
            loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()
            # Log the training loss and learning rate per step
            writer.add_scalar('Loss/train', loss.item(), global_step)
            writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
            global_step += 1
        train_loss /= len(train_loader)
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x_0, _ in val_loader:
                x_0 = x_0.to(device)
                batch_size = x_0.shape[0]
                t = torch.randint(0, T, (batch_size,), device=device)
                loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        # Log the validation loss per epoch
        writer.add_scalar('Loss/val', val_loss, epoch)
        # Periodically generate and log samples
        if (epoch + 1) % 5 == 0:
            model.eval()
            with torch.no_grad():
                x_T = torch.randn(4, 3, 32, 32).to(device)
                # ddpm_sample is the sampler from the earlier sampling chapter
                samples = ddpm_sample(model, x_T, T, alphas, betas, alphas_cumprod, device)
                samples = (samples + 1) / 2  # rescale from [-1, 1] to [0, 1]
                writer.add_images('Generated_Samples', samples, epoch)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    writer.close()
```
### 8.2 Plotting Loss Curves

```python
import matplotlib.pyplot as plt


def plot_loss_curves(train_losses, val_losses, save_path='loss_curves.png'):
    """
    Plot the training and validation loss curves.

    Args:
        train_losses: list of training losses
        val_losses: list of validation losses
        save_path: output path
    """
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss', alpha=0.8)
    plt.plot(val_losses, label='Val Loss', alpha=0.8)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
```
## 9. Common Problems and Solutions

### 9.1 Unstable Training

Symptoms: the loss oscillates or diverges.

Solutions:

1. Lower the learning rate
2. Use gradient clipping
3. Check the data preprocessing
4. Make sure batch normalization is behaving correctly

```python
# Lower the learning rate
optimizer = get_optimizer(model, learning_rate=1e-5)  # dropped from 1e-4 to 1e-5
# Add gradient clipping (after loss.backward(), before optimizer.step())
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```
### 9.2 Poor Sample Quality

Symptoms: generated images are blurry or noisy.

Solutions:

1. Train for more epochs
2. Use EMA
3. Adjust the noise schedule
4. Increase model capacity

```python
# Use EMA
ema = EMA(model, decay=0.9999)
# Switch to a cosine noise schedule (get_cosine_schedule is assumed from the schedule chapter)
alphas, betas, alphas_cumprod = get_cosine_schedule(T)
```
### 9.3 Slow Training

Symptoms: training takes too long.

Solutions:

1. Use mixed-precision training
2. Increase the batch size
3. Use a more efficient model architecture
4. Use distributed training (see the DDP sketch below)

```python
# Mixed-precision training
scaler = GradScaler()
with autocast('cuda'):
    loss = ddpm_loss(model, x_0, t, alphas_cumprod, device)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
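For the distributed option above, here is a minimal single-node DistributedDataParallel sketch; it assumes launch via `torchrun`, that `model` is already built, and it omits dataset sharding details:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Minimal DDP setup sketch; launch with: torchrun --nproc_per_node=N train.py
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])  # set by torchrun
torch.cuda.set_device(local_rank)

model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank])
# Training proceeds as before; gradients are synchronized across processes.
# Use a DistributedSampler in the DataLoader and call
# sampler.set_epoch(epoch) at the start of every epoch.
```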
## 10. Summary

### 10.1 Core Techniques at a Glance

| Technique | Effect | Difficulty |
|---|---|---|
| Data augmentation | Better generalization | Easy |
| AdamW optimizer | Stable training | Easy |
| LR scheduling | Better convergence | Easy |
| Gradient clipping | Prevents exploding gradients | Easy |
| EMA | Better sample quality | Medium |
| Mixed precision | Faster training | Medium |
| Gradient accumulation | Works around limited memory | Medium |
| Loss weighting | More stable training | Harder |
### 10.2 Best Practices

- Start simple: train with a basic configuration first and confirm the pipeline works
- Optimize incrementally: add techniques one at a time on top of the baseline
- Monitor training: use tools like TensorBoard to watch the training process
- Save regularly: write checkpoints periodically so an interruption costs little
- Validate by sampling: generate samples regularly to assess quality

### 10.3 Study Tips

- Understand the principles: learn why each technique works before applying it
- Practice hands-on: use these techniques in real projects
- Run comparisons: measure the effect of each technique against a baseline
- Keep notes: record the configurations and tricks that worked
## 11. Recommended Resources

### Papers

- DDPM: "Denoising Diffusion Probabilistic Models"
- Improved DDPM: "Improved Denoising Diffusion Probabilistic Models"
- ADM: "Diffusion Models Beat GANs on Image Synthesis"

### Code Repositories

- OpenAI's guided-diffusion
- CompVis's latent-diffusion
- Hugging Face Diffusers

### Tools

- TensorBoard
- Weights & Biases
- MLflow
## 12. Self-Check Questions

- Why does EMA improve sample quality?
- What are the advantages and drawbacks of mixed-precision training?
- How do you choose a suitable learning-rate schedule?
- What does gradient clipping accomplish?
- How would you diagnose unstable training?

Next chapter: 03-DDPM从零实现 (DDPM from Scratch): implement a complete DDPM model by hand