04-GAN图像生成实战¶
学习时间: 约4-6小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: GAN基础、CNN基础、PyTorch 项目目标: 基于DCGAN实现人脸生成,掌握GAN的完整训练流程和调参技巧
目录¶
1. 项目概述¶
1.1 任务描述¶
使用 DCGAN 在人脸数据集(CelebA)上训练生成模型,能生成逼真的人脸图像。
1.2 技术路线¶
2. 数据准备¶
Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid, save_image
# Hyperparameters (DCGAN defaults from Radford et al., 2016)
IMG_SIZE = 64      # output image resolution (64x64)
BATCH_SIZE = 128
LATENT_DIM = 100   # dimensionality of the noise vector z
NUM_EPOCHS = 50
LR = 2e-4          # Adam learning rate recommended for DCGAN
BETA1 = 0.5        # Adam beta1 = 0.5 stabilizes GAN training
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Preprocessing for CelebA: resize, center-crop, then scale pixels to
# [-1, 1] so real images match the generator's tanh output range.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [-1, 1]
])
# CelebA dataset (if the download fails, substitute another face dataset)
# dataset = torchvision.datasets.CelebA(root='./data', split='train', transform=transform, download=True)
# Fallback: CIFAR-10 with its own inline transform (upscaled 32 -> 64).
# NOTE(review): `transform` defined above is unused on this fallback path
# (no CenterCrop here) — confirm whether CIFAR-10 should reuse it.
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]), download=True)
# drop_last avoids a smaller final batch, keeping BatchNorm statistics stable.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=4, drop_last=True)
# Visualize a grid of real training images for a sanity check.
real_batch = next(iter(dataloader))[0][:64]
plt.figure(figsize=(10, 10))
grid = make_grid(real_batch, nrow=8, normalize=True)
plt.imshow(grid.permute(1, 2, 0))
plt.title('Real Images')
plt.axis('off')
# plt.show()
3. DCGAN模型¶
Python
def weights_init(m):
    """Initialize a module's weights per the DCGAN paper.

    Conv / ConvTranspose weights ~ N(0, 0.02); BatchNorm weights
    ~ N(1, 0.02) with biases zeroed. Intended for `net.apply(...)`.
    """
    name = type(m).__name__
    if 'Conv' in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector z to a 64x64 image in [-1, 1].

    Architecture: one 4x4 "projection" transposed conv followed by three
    stride-2 upsampling stages and a final tanh-activated output conv.
    """

    def __init__(self, latent_dim=100, channels=3, feature_maps=64):
        super().__init__()
        f = feature_maps
        # (latent_dim, 1, 1) -> (f*8, 4, 4)
        layers = [
            nn.ConvTranspose2d(latent_dim, f * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(f * 8),
            nn.ReLU(True),
        ]
        # Three stride-2 stages double the spatial size: 4 -> 8 -> 16 -> 32,
        # halving the channel count each time (f*8 -> f*4 -> f*2 -> f).
        for mult in (8, 4, 2):
            layers += [
                nn.ConvTranspose2d(f * mult, f * mult // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(f * mult // 2),
                nn.ReLU(True),
            ]
        # (f, 32, 32) -> (channels, 64, 64); tanh squashes output into [-1, 1].
        layers += [
            nn.ConvTranspose2d(f, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        # Accept flat (batch, latent_dim) noise; reshape to (batch, latent_dim, 1, 1).
        return self.main(z.view(-1, z.size(1), 1, 1))
class Discriminator(nn.Module):
    """DCGAN discriminator: 64x64 image -> one real/fake logit per sample.

    Outputs raw logits (no sigmoid); pair with BCEWithLogitsLoss.
    """

    def __init__(self, channels=3, feature_maps=64):
        super().__init__()
        f = feature_maps
        # First stage has no BatchNorm, per the DCGAN paper.
        # (channels, 64, 64) -> (f, 32, 32)
        layers = [
            nn.Conv2d(channels, f, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Stride-2 downsampling stages: 32 -> 16 -> 8 -> 4 spatial,
        # doubling channels each time (f -> f*2 -> f*4 -> f*8).
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(f * mult, f * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(f * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # (f*8, 4, 4) -> (1, 1, 1): the final 4x4 conv collapses to one logit.
        layers.append(nn.Conv2d(f * 8, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten the (batch, 1, 1, 1) map to (batch, 1) logits.
        return self.main(x).view(-1, 1)
# Instantiate both networks on the target device and apply the DCGAN
# weight initialization recursively via Module.apply.
netG = Generator(LATENT_DIM).to(DEVICE)
netD = Discriminator().to(DEVICE)
netG.apply(weights_init)
netD.apply(weights_init)
print(f"生成器参数量: {sum(p.numel() for p in netG.parameters()):,}")
print(f"判别器参数量: {sum(p.numel() for p in netD.parameters()):,}")
4. 训练流程¶
Python
def train_dcgan(netG, netD, dataloader, num_epochs=50, device=DEVICE):
    """Train a DCGAN with alternating discriminator/generator steps.

    Args:
        netG: generator producing images from latent noise.
        netD: discriminator returning raw logits (no sigmoid).
        dataloader: yields (images, labels); labels are ignored.
        num_epochs: number of passes over the dataset.
        device: device to run training on.

    Returns:
        Tuple (G_losses, D_losses, D_real_scores, D_fake_scores),
        each a list with one entry per iteration.
    """
    # BCEWithLogitsLoss pairs with the discriminator's raw-logit output.
    criterion = nn.BCEWithLogitsLoss()
    optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.999))
    # Fixed noise so the per-epoch sample grids are comparable across epochs.
    fixed_noise = torch.randn(64, LATENT_DIM, device=device)
    # Per-iteration training history.
    G_losses, D_losses = [], []
    D_real_scores, D_fake_scores = [], []
    for epoch in range(num_epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.size(0)
            # === Train discriminator ===
            netD.zero_grad()
            # Real images should be classified as real (label 1).
            output_real = netD(real_imgs)
            label_real = torch.ones_like(output_real)
            loss_D_real = criterion(output_real, label_real)
            # Generated images should be classified as fake (label 0).
            # detach() keeps the D step from backpropagating into G.
            noise = torch.randn(batch_size, LATENT_DIM, device=device)
            fake_imgs = netG(noise)
            output_fake = netD(fake_imgs.detach())
            label_fake = torch.zeros_like(output_fake)
            loss_D_fake = criterion(output_fake, label_fake)
            loss_D = (loss_D_real + loss_D_fake) / 2
            loss_D.backward()
            optimizerD.step()
            # === Train generator ===
            netG.zero_grad()
            # Re-score the same fakes with the just-updated D; G's objective
            # is to have them labeled real.
            output_fake = netD(fake_imgs)
            loss_G = criterion(output_fake, torch.ones_like(output_fake))
            loss_G.backward()
            optimizerG.step()
            # Record per-iteration stats (sigmoid maps logits to [0, 1] scores).
            D_real_scores.append(torch.sigmoid(output_real).mean().item())
            D_fake_scores.append(torch.sigmoid(output_fake).mean().item())
            G_losses.append(loss_G.item())
            D_losses.append(loss_D.item())
        # End of epoch: report latest stats and save a fixed-noise sample grid.
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"D_loss: {loss_D.item():.4f}, G_loss: {loss_G.item():.4f}, "
              f"D(x): {D_real_scores[-1]:.3f}, D(G(z)): {D_fake_scores[-1]:.3f}")
        with torch.no_grad():
            fake_samples = netG(fixed_noise).cpu()
        save_image(fake_samples, f'generated_epoch_{epoch+1}.png',
                   nrow=8, normalize=True)
    return G_losses, D_losses, D_real_scores, D_fake_scores
5. 生成结果可视化¶
Python
def plot_training_progress(G_losses, D_losses, D_real_scores, D_fake_scores):
    """Plot per-iteration GAN losses and discriminator scores side by side.

    Saves the figure to 'training_progress.png' and shows it.
    """
    fig, (loss_ax, score_ax) = plt.subplots(1, 2, figsize=(14, 5))
    # Left panel: generator vs. discriminator loss curves.
    loss_ax.plot(G_losses, label='Generator', alpha=0.7)
    loss_ax.plot(D_losses, label='Discriminator', alpha=0.7)
    loss_ax.set_xlabel('Iteration')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training Losses')
    loss_ax.legend()
    # Right panel: D's mean score on real vs. fake batches;
    # the 0.5 line marks the theoretical equilibrium.
    score_ax.plot(D_real_scores, label='D(real)', alpha=0.7)
    score_ax.plot(D_fake_scores, label='D(fake)', alpha=0.7)
    score_ax.axhline(y=0.5, color='r', linestyle='--', alpha=0.5)
    score_ax.set_xlabel('Iteration')
    score_ax.set_ylabel('Score')
    score_ax.set_title('Discriminator Scores')
    score_ax.legend()
    plt.tight_layout()
    plt.savefig('training_progress.png', dpi=150)
    plt.show()
def visualize_latent_interpolation(netG, device=DEVICE):
    """Decode images along a straight line between two random latent points.

    Saves the strip to 'interpolation.png' and shows it.
    """
    netG.eval()  # evaluation mode (e.g. BatchNorm uses running stats)
    z_start = torch.randn(1, LATENT_DIM, device=device)
    z_end = torch.randn(1, LATENT_DIM, device=device)
    steps = 10
    frames = []
    for alpha in np.linspace(0, 1, steps):
        # Linear blend between the two endpoints.
        z_mix = (1 - alpha) * z_start + alpha * z_end
        with torch.no_grad():
            frames.append(netG(z_mix).cpu())
    grid = make_grid(torch.cat(frames), nrow=steps, normalize=True)
    plt.figure(figsize=(20, 3))
    plt.imshow(grid.permute(1, 2, 0))
    plt.title('Latent Space Interpolation')
    plt.axis('off')
    plt.savefig('interpolation.png', dpi=150)
    plt.show()
def visualize_latent_arithmetic(netG, device=DEVICE):
    """Latent-space arithmetic demo: decode A, B, C and A - B + C in one row.

    Illustrative only — with random z's the result has no semantic meaning
    (e.g. z_smile - z_neutral + z_glasses would need labeled latents).
    """
    netG.eval()
    z_a, z_b, z_c = (torch.randn(1, LATENT_DIM, device=device) for _ in range(3))
    z_result = z_a - z_b + z_c
    with torch.no_grad():
        imgs = [netG(z).cpu() for z in (z_a, z_b, z_c, z_result)]
    grid = make_grid(torch.cat(imgs), nrow=4, normalize=True)
    plt.figure(figsize=(12, 4))
    plt.imshow(grid.permute(1, 2, 0))
    plt.title('A - B + C = Result')
    plt.axis('off')
    plt.show()
6. 训练技巧与问题排查¶
6.1 常见问题诊断¶
| 症状 | 原因 | 解决方案 |
|---|---|---|
| D_loss → 0, G_loss → ∞ | D 太强 | 降低 D 学习率,减少 D 训练次数 |
| 生成图像全相同 | 模式崩塌 | 加入 mini-batch discrimination |
| 生成图像噪声 | 学习率过高 | 降低学习率 |
| 训练震荡 | 不稳定 | 使用谱归一化,梯度裁剪 |
6.2 改进方案¶
Python
# 1. Label smoothing: soft targets keep D from becoming overconfident.
def smooth_labels(size, real=True, device='cuda'):
    """Return (size, 1) targets drawn uniformly from [0.7, 1.2) for real
    samples or [0.0, 0.3) for fake samples."""
    low, high = (0.7, 1.2) if real else (0.0, 0.3)
    return torch.FloatTensor(size, 1).uniform_(low, high).to(device)
# 2. Occasionally flip labels so D cannot become too strong.
def noisy_labels(size, prob_flip=0.05, device='cuda'):
    """Return (size, 1) "real" labels with roughly prob_flip of them zeroed."""
    labels = torch.ones(size, 1, device=device)
    labels[torch.rand(size, 1, device=device) < prob_flip] = 0
    return labels
# 3. Progressive training schedule
class ProgressiveTrainer:
    """Dynamically rebalances how many discriminator steps run per G step."""

    def __init__(self, netG, netD):
        self.netG = netG
        self.netD = netD
        # Start with a single discriminator step per generator step.
        self.d_steps = 1

    def adjust_d_steps(self, d_loss, g_loss, threshold=0.1):
        """Adjust the D step count from the latest losses (clamped to [1, 5])."""
        if d_loss < threshold:
            # D is winning easily — train it less often.
            self.d_steps = max(1, self.d_steps - 1)
        elif g_loss > 5.0:
            # NOTE(review): this branch *increases* D steps when G struggles,
            # which strengthens D further — confirm the schedule is intended.
            self.d_steps = min(5, self.d_steps + 1)
7. 完整项目代码¶
Python
def main():
    """Run the full DCGAN pipeline: build, train, visualize, and save."""
    print(f"使用设备: {DEVICE}")
    # 1. Build both networks and apply DCGAN weight initialization.
    netG = Generator(LATENT_DIM).to(DEVICE)
    netD = Discriminator().to(DEVICE)
    for net in (netG, netD):
        net.apply(weights_init)
    # 2. Adversarial training.
    G_losses, D_losses, D_real, D_fake = train_dcgan(
        netG, netD, dataloader, num_epochs=NUM_EPOCHS
    )
    # 3. Loss / score curves.
    plot_training_progress(G_losses, D_losses, D_real, D_fake)
    # 4. Final sample grid from fresh noise.
    z = torch.randn(64, LATENT_DIM, device=DEVICE)
    with torch.no_grad():
        samples = netG(z).cpu()
    save_image(samples, 'final_generated.png', nrow=8, normalize=True)
    # 5. Latent-space exploration.
    visualize_latent_interpolation(netG, DEVICE)
    # 6. Persist both networks' weights.
    torch.save({
        'generator': netG.state_dict(),
        'discriminator': netD.state_dict(),
    }, 'dcgan_model.pt')
    print("项目完成!")
返回: 深度学习目录