01 - Hands-On Image Generation

Estimated study time: 4 hours | Importance: ⭐⭐⭐⭐⭐ | The key step in turning theory into real applications

🎯 Project Goals

After completing this project, you will be able to:

- Train a DDPM model from scratch
- Generate images on the CIFAR-10 dataset
- Master the full training and evaluation workflow
- Optimize model performance
- Deploy the model for inference
1. Project Overview

1.1 Introduction

In this project you will train a diffusion model to generate high-quality images on the CIFAR-10 dataset.

Tech stack:

- PyTorch: deep learning framework
- CIFAR-10: training dataset
- UNet: model architecture
- DDPM: diffusion model algorithm

1.2 Project Structure
```text
image_generation_project/
├── data/          # Data directory
├── checkpoints/   # Model checkpoints
├── logs/          # Training logs
├── samples/       # Generated samples
├── config.py      # Configuration
├── model.py       # Model definition
├── train.py       # Training script
├── sample.py      # Sampling script
├── evaluate.py    # Evaluation script
└── README.md      # Project description
```
2. Environment Setup

2.1 Install Dependencies

```bash
# Create a virtual environment
conda create -n diffusion python=3.9
conda activate diffusion

# Install PyTorch
pip install torch torchvision torchaudio

# Install other dependencies
pip install numpy matplotlib tqdm tensorboard
```
2.2 Download the CIFAR-10 Dataset

```python
# download_cifar10.py
import torchvision
import torchvision.transforms as transforms


def download_cifar10(data_dir='./data'):
    """Download the CIFAR-10 dataset."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=True,
        transform=transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root=data_dir,
        train=False,
        download=True,
        transform=transform
    )

    print(f"Training set size: {len(train_dataset)}")
    print(f"Test set size: {len(test_dataset)}")


if __name__ == '__main__':
    download_cifar10()
```

Run the script with `python download_cifar10.py`; the dataset is downloaded into `./data`.
3. Configuration File

```python
# config.py
import torch


class Config:
    """Project configuration."""
    # Data
    image_size = 32
    in_channels = 3
    out_channels = 3
    num_classes = 10

    # Model
    model_dim = 128
    num_heads = 4
    num_layers = 4

    # Diffusion
    T = 1000
    beta_start = 0.0001
    beta_end = 0.02

    # Training
    batch_size = 128
    num_epochs = 100
    learning_rate = 1e-4
    weight_decay = 0.01
    grad_clip = 1.0

    # Optimizer
    optimizer = 'adamw'
    scheduler = 'cosine'

    # Data augmentation
    augment = True

    # Checkpoints and logging
    save_dir = './checkpoints'
    log_dir = './logs'
    sample_dir = './samples'
    save_interval = 10
    sample_interval = 5

    # Device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Sampling
    num_samples = 16
    sampling_steps = 1000
    sampling_method = 'ddpm'  # 'ddpm' or 'ddim'


# Global configuration instance
config = Config()
```
4. Model Definition

```python
# model.py
import torch
import torch.nn as nn
import math


class SinusoidalPositionEmbedding(nn.Module):  # subclass nn.Module to define a network layer
    """Sinusoidal position (timestep) embedding."""
    def __init__(self, dim):
        super().__init__()  # super() calls the parent class constructor
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # torch.cat concatenates along an existing dimension
        return emb


class ResidualBlock(nn.Module):
    """Residual block with timestep conditioning."""
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Timestep embedding projection
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels)
        )

        # Skip connection
        if in_channels != out_channels:
            self.skip_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip_conv = nn.Identity()

    def forward(self, x, time_emb):
        h = self.norm1(x)
        h = nn.SiLU()(h)
        h = self.conv1(h)

        # Add the timestep embedding
        time_emb = self.time_mlp(time_emb)
        h = h + time_emb[:, :, None, None]

        h = self.norm2(h)
        h = nn.SiLU()(h)
        h = self.conv2(h)
        return h + self.skip_conv(x)


class UNet(nn.Module):
    """UNet denoising model."""
    def __init__(self, in_channels=3, out_channels=3, model_dim=128, time_emb_dim=128):
        super().__init__()
        # Timestep embedding
        self.time_embedding = SinusoidalPositionEmbedding(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, model_dim * 4),
            nn.SiLU(),
            nn.Linear(model_dim * 4, model_dim)
        )

        # Input convolution
        self.conv_in = nn.Conv2d(in_channels, model_dim, 3, padding=1)

        # Downsampling path
        self.down_blocks = nn.ModuleList([
            ResidualBlock(model_dim, model_dim, model_dim),
            ResidualBlock(model_dim, model_dim * 2, model_dim),
            ResidualBlock(model_dim * 2, model_dim * 4, model_dim),
        ])
        self.down_samples = nn.ModuleList([
            nn.Conv2d(model_dim, model_dim, 3, stride=2, padding=1),
            nn.Conv2d(model_dim * 2, model_dim * 2, 3, stride=2, padding=1),
            nn.Conv2d(model_dim * 4, model_dim * 4, 3, stride=2, padding=1),
        ])

        # Bottleneck
        self.mid_block1 = ResidualBlock(model_dim * 4, model_dim * 4, model_dim)
        self.mid_block2 = ResidualBlock(model_dim * 4, model_dim * 4, model_dim)

        # Upsampling path.
        # Note: skip connections are added (not concatenated), so each up block
        # keeps the channel count produced by the preceding transposed convolution.
        self.up_blocks = nn.ModuleList([
            ResidualBlock(model_dim * 4, model_dim * 4, model_dim),
            ResidualBlock(model_dim * 2, model_dim * 2, model_dim),
            ResidualBlock(model_dim, model_dim, model_dim),
        ])
        self.up_samples = nn.ModuleList([
            nn.ConvTranspose2d(model_dim * 4, model_dim * 4, 4, stride=2, padding=1),
            nn.ConvTranspose2d(model_dim * 4, model_dim * 2, 4, stride=2, padding=1),
            nn.ConvTranspose2d(model_dim * 2, model_dim, 4, stride=2, padding=1),
        ])

        # Output convolution
        self.conv_out = nn.Conv2d(model_dim, out_channels, 3, padding=1)

    def forward(self, x, t):
        # Timestep embedding
        time_emb = self.time_embedding(t)
        time_emb = self.time_mlp(time_emb)

        # Input convolution
        h = self.conv_in(x)

        # Downsampling
        skips = []
        for i, (down_block, down_sample) in enumerate(zip(self.down_blocks, self.down_samples)):  # enumerate yields index and element; zip pairs by position
            h = down_block(h, time_emb)
            skips.append(h)
            h = down_sample(h)

        # Bottleneck
        h = self.mid_block1(h, time_emb)
        h = self.mid_block2(h, time_emb)

        # Upsampling with skip connections
        for i, (up_block, up_sample) in enumerate(zip(self.up_blocks, self.up_samples)):
            h = up_sample(h)
            h = h + skips[-(i + 1)]
            h = up_block(h, time_emb)

        # Output
        h = self.conv_out(h)
        return h
```
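As a minimal sanity check (hypothetical usage, not one of the project files), you can verify that the network maps a CIFAR-10-sized batch back to a tensor of the same shape; this assumes the classes above are in scope, e.g. appended to the bottom of model.py:

```python
if __name__ == '__main__':
    model = UNet(in_channels=3, out_channels=3, model_dim=128, time_emb_dim=128)
    x = torch.randn(4, 3, 32, 32)     # dummy CIFAR-10-sized batch
    t = torch.randint(0, 1000, (4,))  # dummy timesteps
    out = model(x, t)
    print(out.shape)                  # expected: torch.Size([4, 3, 32, 32])
```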
5. Training Script

```python
# train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm import tqdm
import time

from config import config
from model import UNet


def get_schedule(T, beta_start, beta_end):
    """Create the noise schedule."""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, betas, alphas_cumprod

def get_dataloaders():
    """Build the data loaders."""
    # Training transforms
    if config.augment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(config.image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize(config.image_size),
            transforms.CenterCrop(config.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    # Test transforms
    test_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.CenterCrop(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Load the datasets
    train_dataset = datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=train_transform
    )
    test_dataset = datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=test_transform
    )

    # Build the data loaders
    train_loader = DataLoader(  # DataLoader batches and shuffles the data
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    return train_loader, test_loader

def train_step(model, x_0, alphas_cumprod, device):
    """A single training step."""
    model.train()  # switch to training mode

    # Sample random timesteps
    batch_size = x_0.shape[0]
    t = torch.randint(0, config.T, (batch_size,), device=device)

    # Sample noise
    noise = torch.randn_like(x_0)

    # Compute the noised image
    sqrt_alpha_t_bar = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)  # reshape for broadcasting
    sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
    x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise

    # Predict the noise
    predicted_noise = model(x_t, t)

    # Compute the loss
    loss = nn.functional.mse_loss(predicted_noise, noise)
    return loss


def validate(model, val_loader, alphas_cumprod, device):
    """Validate the model."""
    model.eval()
    total_loss = 0
    with torch.no_grad():  # disable gradient tracking to save memory
        for x_0, _ in val_loader:
            x_0 = x_0.to(device)  # move to GPU/CPU

            # Sample random timesteps
            batch_size = x_0.shape[0]
            t = torch.randint(0, config.T, (batch_size,), device=device)

            # Sample noise
            noise = torch.randn_like(x_0)

            # Compute the noised image
            sqrt_alpha_t_bar = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
            x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise

            # Predict the noise
            predicted_noise = model(x_t, t)

            # Compute the loss
            loss = nn.functional.mse_loss(predicted_noise, noise)
            total_loss += loss.item()  # .item() converts a one-element tensor to a Python number

    return total_loss / len(val_loader)

def train():
    """Main training loop."""
    # Create output directories
    os.makedirs(config.save_dir, exist_ok=True)
    os.makedirs(config.log_dir, exist_ok=True)
    os.makedirs(config.sample_dir, exist_ok=True)

    # Device
    device = torch.device(config.device)
    print(f"Using device: {device}")

    # Data
    print("Loading data...")
    train_loader, val_loader = get_dataloaders()
    print(f"Training set size: {len(train_loader.dataset)}")
    print(f"Validation set size: {len(val_loader.dataset)}")

    # Model
    print("Building model...")
    model = UNet(
        in_channels=config.in_channels,
        out_channels=config.out_channels,
        model_dim=config.model_dim,
        time_emb_dim=config.model_dim
    ).to(device)
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Noise schedule
    alphas, betas, alphas_cumprod = get_schedule(config.T, config.beta_start, config.beta_end)
    alphas_cumprod = alphas_cumprod.to(device)

    # Optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=config.weight_decay
    )

    # Learning rate scheduler
    total_steps = config.num_epochs * len(train_loader)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=1e-6
    )

    # TensorBoard writer
    writer = SummaryWriter(config.log_dir)

    # Training loop
    best_loss = float('inf')
    global_step = 0

    print(f"\nStarting training for {config.num_epochs} epochs")
    start_time = time.time()

    for epoch in range(config.num_epochs):
        epoch_loss = 0
        epoch_start_time = time.time()

        # Training phase
        model.train()
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.num_epochs}')
        for x_0, _ in pbar:
            x_0 = x_0.to(device)

            # Training step
            loss = train_step(model, x_0, alphas_cumprod, device)

            # Backpropagation
            optimizer.zero_grad()  # reset gradients
            loss.backward()        # compute gradients

            # Gradient clipping
            if config.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

            # Parameter update
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            # Log to TensorBoard
            writer.add_scalar('Loss/train', loss.item(), global_step)
            writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
            global_step += 1

        epoch_loss /= len(train_loader)
        epoch_time = time.time() - epoch_start_time

        # Validation phase
        val_loss = validate(model, val_loader, alphas_cumprod, device)

        # Log to TensorBoard
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Epoch_Time', epoch_time, epoch)

        # Print progress
        print(f"\nEpoch {epoch+1}/{config.num_epochs}")
        print(f"  Train loss: {epoch_loss:.4f}")
        print(f"  Val loss:   {val_loss:.4f}")
        print(f"  Learning rate: {optimizer.param_groups[0]['lr']:.6f}")
        print(f"  Time: {epoch_time:.2f}s")

        # Save the best model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
            }, os.path.join(config.save_dir, 'best_model.pth'))
            print(f"  ✓ Saved best model (loss: {val_loss:.4f})")

        # Save periodic checkpoints
        if (epoch + 1) % config.save_interval == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
            }, os.path.join(config.save_dir, f'checkpoint_epoch_{epoch+1}.pth'))

    total_time = time.time() - start_time
    print(f"\nTraining finished! Total time: {total_time/3600:.2f} hours")
    print(f"Best validation loss: {best_loss:.4f}")
    writer.close()


if __name__ == '__main__':
    train()
```
6. Sampling Script

```python
# sample.py
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os

from config import config
from model import UNet


def get_schedule(T, beta_start, beta_end):
    """Create the noise schedule."""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, betas, alphas_cumprod

def ddpm_sample(model, x_T, T, alphas, betas, alphas_cumprod, device):
    """DDPM sampling."""
    model.eval()
    x_t = x_T.to(device)
    # Prepend alpha_bar_0 = 1; keep everything on the same device as alphas_cumprod
    alphas_cumprod_prev = torch.cat(
        [torch.tensor([1.0], device=alphas_cumprod.device), alphas_cumprod[:-1]]
    )

    with torch.no_grad():
        for t in reversed(range(T)):
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]
            alpha_t_bar_prev = alphas_cumprod_prev[t]

            # Predict the noise
            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
            predicted_noise = model(x_t, t_tensor)

            # Posterior mean
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            mean = sqrt_recip_alpha_t * (
                x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            # Add noise (except at the final step)
            if t > 0:
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

    return x_t

def sample():
    """Generate samples with a trained model."""
    # Device
    device = torch.device(config.device)

    # Model
    model = UNet(
        in_channels=config.in_channels,
        out_channels=config.out_channels,
        model_dim=config.model_dim,
        time_emb_dim=config.model_dim
    ).to(device)

    # Load the checkpoint
    checkpoint_path = os.path.join(config.save_dir, 'best_model.pth')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded checkpoint: {checkpoint_path}")
        print(f"  Epoch: {checkpoint['epoch']}")
        print(f"  Loss: {checkpoint['loss']:.4f}")
    else:
        print(f"Checkpoint not found: {checkpoint_path}")
        return

    # Noise schedule
    alphas, betas, alphas_cumprod = get_schedule(config.T, config.beta_start, config.beta_end)
    alphas_cumprod = alphas_cumprod.to(device)

    # Generate samples
    print(f"\nGenerating {config.num_samples} samples...")
    x_T = torch.randn(config.num_samples, config.in_channels, config.image_size, config.image_size)
    samples = ddpm_sample(model, x_T, config.T, alphas, betas, alphas_cumprod, device)

    # Undo the [-1, 1] normalization
    samples = (samples + 1) / 2
    samples = samples.clamp(0, 1)

    # Visualize
    grid = make_grid(samples, nrow=4, padding=2, normalize=False)
    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title('Generated Samples')
    plt.axis('off')

    save_path = os.path.join(config.sample_dir, 'generated_samples.png')
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Samples saved to: {save_path}")
    plt.show()


if __name__ == '__main__':
    sample()
```
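config.py lists `'ddim'` as a `sampling_method`, but sample.py only implements DDPM. A minimal deterministic DDIM sampler (η = 0) could look like the sketch below; the function name `ddim_sample` and the `num_steps` parameter are assumptions for illustration, not part of the original script:

```python
def ddim_sample(model, x_T, T, alphas_cumprod, device, num_steps=50):
    """Deterministic DDIM sampling (eta = 0) over a subsequence of timesteps."""
    model.eval()
    x_t = x_T.to(device)
    # Evenly spaced, descending subsequence of timesteps, e.g. 50 out of 1000
    timesteps = torch.linspace(T - 1, 0, num_steps).long()
    with torch.no_grad():
        for i in range(len(timesteps)):
            t = int(timesteps[i])
            t_prev = int(timesteps[i + 1]) if i + 1 < len(timesteps) else -1
            alpha_bar_t = alphas_cumprod[t]
            alpha_bar_prev = alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=device)

            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
            eps = model(x_t, t_tensor)

            # Predict x_0, then jump deterministically to the previous timestep
            x0_pred = (x_t - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)
            x_t = torch.sqrt(alpha_bar_prev) * x0_pred + torch.sqrt(1 - alpha_bar_prev) * eps
    return x_t
```

With 50 steps instead of 1000, sampling needs roughly 20× fewer network evaluations, usually at some cost in sample quality.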
7. Running the Project

7.1 Train the Model

Start training with `python train.py`. Checkpoints are written to `./checkpoints` and TensorBoard logs to `./logs` (view them with `tensorboard --logdir ./logs`).

7.2 Generate Samples

Once training has produced `checkpoints/best_model.pth`, run `python sample.py`. The image grid is saved to `./samples/generated_samples.png`.
8. Project Summary

8.1 What Was Accomplished

- ✅ Set up the full training environment
- ✅ Implemented the UNet architecture
- ✅ Implemented the DDPM training loop
- ✅ Implemented DDPM sampling
- ✅ Added TensorBoard monitoring
8.2 Key Technical Points

| Technique | Description |
|---|---|
| UNet architecture | The core denoising network |
| Noise schedule | Controls how much noise is added at each step |
| Training objective | Predict the noise that was added |
| Sampling process | Iteratively recover the image from noise |
| TensorBoard | Monitor the training process |
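For reference, these are the standard DDPM equations that the code above implements (with α_t = 1 − β_t):

```latex
% Forward (noising) process used in train_step / validate
q(x_t \mid x_0) = \mathcal{N}\!\big(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\mathbf{I}\big),
\qquad \bar\alpha_t = \textstyle\prod_{s=1}^{t} (1-\beta_s)

% Training loss: predict the added noise
\mathcal{L} = \mathbb{E}_{x_0,\,\epsilon,\,t}\big[\,\lVert \epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t) \rVert^2\,\big]

% Reverse (sampling) step used in ddpm_sample
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big) + \sqrt{\tilde\beta_t}\, z,
\qquad \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t,\quad z \sim \mathcal{N}(0, \mathbf{I})
```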
8.3 Possible Improvements

- Speed up sampling with DDIM: fewer sampling steps (see the sketch at the end of Section 6)
- Add data augmentation: improve generalization
- Use EMA (an exponential moving average of the weights): improves sample quality (a sketch follows this list)
- Try a larger model: better sample quality
- Use a larger dataset, such as ImageNet
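The EMA item above takes only a few lines to add. A minimal sketch, as a hypothetical helper that is not one of the project files; the decay value of 0.999 is an assumption:

```python
class EMA:
    """Exponential moving average of model parameters (sketch)."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Keep a detached copy of every parameter
        self.shadow = {name: p.detach().clone() for name, p in model.named_parameters()}

    @torch.no_grad()
    def update(self, model):
        # shadow = decay * shadow + (1 - decay) * current parameters
        for name, p in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, model):
        # Overwrite the model's parameters with the averaged ones
        for name, p in model.named_parameters():
            p.copy_(self.shadow[name])
```

Call `ema.update(model)` after every `optimizer.step()`, and copy the shadow weights into a (separate) model instance before sampling.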
9. FAQ

Q1: Training takes too long?

A: You can try:

- Reducing the batch size
- Using mixed-precision training (see the sketch below this answer)
- Using a smaller model
- Training for fewer epochs
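A minimal sketch of the mixed-precision option, assuming the `train_step`, `config`, and loaders defined above; it replaces the inner loop of `train()` and is not part of the original script:

```python
# Mixed-precision training with automatic mixed precision (AMP)
scaler = torch.cuda.amp.GradScaler()

for x_0, _ in train_loader:
    x_0 = x_0.to(device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # run the forward pass in reduced precision
        loss = train_step(model, x_0, alphas_cumprod, device)
    scaler.scale(loss).backward()            # scale the loss to avoid underflow
    if config.grad_clip > 0:
        scaler.unscale_(optimizer)           # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
```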
Q2: Generated image quality is poor?

A: You can try:

- Training for more epochs
- Using a larger model
- Using better data augmentation
- Adjusting the noise schedule
Q3: Running out of GPU memory?

A: You can try:

- Reducing the batch size
- Using gradient accumulation (see the sketch below this answer)
- Using a smaller model
- Using mixed-precision training (see Q1)
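A minimal sketch of gradient accumulation under the same assumptions as the training loop above; the value of `accumulate_steps` is an assumption to be tuned:

```python
# Accumulate gradients over several micro-batches before each optimizer step:
# the effective batch size is batch_size * accumulate_steps, with the memory
# footprint of a single micro-batch.
accumulate_steps = 4

optimizer.zero_grad()
for step, (x_0, _) in enumerate(train_loader):
    x_0 = x_0.to(device)
    loss = train_step(model, x_0, alphas_cumprod, device) / accumulate_steps
    loss.backward()
    if (step + 1) % accumulate_steps == 0:
        if config.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```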
10. Next Steps

After completing this project, you can:

- Try other datasets, such as ImageNet or LSUN
- Implement conditional generation: add class conditioning (see the sketch after this list)
- Implement text-to-image generation: encode the text with CLIP
- Implement image editing, such as inpainting or style transfer
- Optimize the model with techniques such as DDIM or latent diffusion (LDM)
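As a starting point for the class-conditioning item above, here is a hedged sketch of how a class label could be injected into the UNet from model.py. The `ConditionalUNet` class, its `class_emb` layer, and the extra `y` argument are assumptions for illustration, not part of the project code:

```python
class ConditionalUNet(UNet):
    """UNet with a learned class embedding added to the timestep embedding."""
    def __init__(self, num_classes=10, model_dim=128, **kwargs):
        super().__init__(model_dim=model_dim, **kwargs)
        self.class_emb = nn.Embedding(num_classes, model_dim)

    def forward(self, x, t, y):
        # Same as UNet.forward, except the class embedding is summed into time_emb
        time_emb = self.time_mlp(self.time_embedding(t)) + self.class_emb(y)
        h = self.conv_in(x)
        skips = []
        for down_block, down_sample in zip(self.down_blocks, self.down_samples):
            h = down_block(h, time_emb)
            skips.append(h)
            h = down_sample(h)
        h = self.mid_block1(h, time_emb)
        h = self.mid_block2(h, time_emb)
        for i, (up_block, up_sample) in enumerate(zip(self.up_blocks, self.up_samples)):
            h = up_sample(h)
            h = h + skips[-(i + 1)]
            h = up_block(h, time_emb)
        return self.conv_out(h)
```

During training, `y` would be the CIFAR-10 labels already returned by the data loader; during sampling, you pick the classes you want to generate.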
Next chapter: 02 - Text-to-Image Generation - generating images conditioned on text