# 01-GAN基础
- 学习时间: 约 6-8 小时
- 难度级别: ⭐⭐⭐⭐ 中高级
- 前置知识: CNN 基础、概率论基础、PyTorch
- 学习目标: 理解 GAN 的核心原理,掌握从 GAN 到 DCGAN、条件 GAN 等主要变体的实现
## 目录
- 1. GAN的核心思想
- 2. 数学推导
- 3. 原始GAN实现
- 4. DCGAN
- 5. 条件GAN (cGAN)
- 6. GAN的训练技巧
- 7. GAN的评估指标
- 8. GAN变体概览
- 9. 练习与自我检查
## 1. GAN的核心思想

### 1.1 对抗博弈
Goodfellow et al.(2014)提出 GAN(Generative Adversarial Network),核心是两个网络的博弈:
- 生成器 (Generator, G): 从噪声 \(z \sim p_z(z)\) 生成"假"数据,试图骗过判别器
- 判别器 (Discriminator, D): 区分"真"数据和"假"数据
类比:G 是造假币者,D 是验钞机。两者互相博弈,最终 G 生成的假币足以以假乱真。
### 1.2 直觉理解

```text
噪声 z ~ N(0,1) ──→ [生成器 G] ──→ 生成样本 G(z) ──┐
                                                   ├──→ [判别器 D] ──→ 真还是假?
真实数据 x ~ p_data ───────────────────────────────┘
```
## 2. 数学推导

### 2.1 目标函数

GAN 的目标函数是一个极小极大博弈:

\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
\]

判别器目标:最大化 \(V\)

- 对真实样本 \(x\),希望 \(D(x) \to 1\),此时 \(\log D(x) \to 0\)(达到最大)
- 对生成样本 \(G(z)\),希望 \(D(G(z)) \to 0\),此时 \(\log(1 - D(G(z))) \to 0\)(达到最大)

生成器目标:最小化 \(V\)

- 让 \(D(G(z)) \to 1\),使 \(\log(1 - D(G(z))) \to -\infty\)
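目标函数中的 \(\log D(x)\) 项与 PyTorch 的 `BCEWithLogitsLoss`(目标为 1)只差一个符号,可以数值验证(下面的 logit 数值仅为示例):

```python
import torch
import torch.nn as nn

# 验证: BCEWithLogitsLoss(logit, target=1) == -log(sigmoid(logit)) == -log D(x)
logit = torch.tensor([0.7])   # 判别器输出的原始 logit (示例值)
bce = nn.BCEWithLogitsLoss()(logit, torch.ones(1))
manual = -torch.log(torch.sigmoid(logit))
print(torch.allclose(bce, manual))  # True
```

这也是后文实现中判别器和生成器损失都直接用 `BCEWithLogitsLoss` 表达的原因。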
### 2.2 最优判别器

固定 \(G\),对 \(D\) 求最优解:

\[
D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}
\]
### 2.3 全局最优

当 \(p_g = p_{data}\) 时,\(D^*(x) = \frac{1}{2}\),目标函数达到全局最小值 \(-\log 4\)。

此时目标函数可以写成 JS 散度的形式:

\[
C(G) = \max_D V(D, G) = -\log 4 + 2 \cdot JSD(p_{data} \,\|\, p_g)
\]
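2.2 与 2.3 的结论可以在一个小的离散分布上数值验证(下面的三点分布为任意假设的示例):

```python
import numpy as np

# 任意假设的两个三点离散分布
p_data = np.array([0.5, 0.3, 0.2])
p_g    = np.array([0.2, 0.3, 0.5])

# 最优判别器 D*(x) = p_data(x) / (p_data(x) + p_g(x))
d_star = p_data / (p_data + p_g)

# 在 D* 处, V = -log4 + 2*JSD(p_data || p_g)
V = (p_data * np.log(d_star) + p_g * np.log(1 - d_star)).sum()
m = (p_data + p_g) / 2
jsd = 0.5 * (p_data * np.log(p_data / m)).sum() + 0.5 * (p_g * np.log(p_g / m)).sum()
print(np.isclose(V, -np.log(4) + 2 * jsd))  # True
```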
### 2.4 非饱和 GAN 损失

实际训练中,\(\log(1 - D(G(z)))\) 在训练初期梯度很小(D 很容易分辨真假时,\(D(G(z)) \approx 0\),该项接近饱和)。因此实践中 G 的损失改为非饱和形式:

\[
L_G = -\mathbb{E}_{z \sim p_z(z)}[\log D(G(z))]
\]
```python
import torch
import torch.nn as nn

# 原始 GAN 损失
criterion = nn.BCEWithLogitsLoss()

def discriminator_loss(D, real_data, fake_data):
    real_pred = D(real_data)
    fake_pred = D(fake_data.detach())  # detach() 从计算图分离, 不向 G 传梯度
    real_loss = criterion(real_pred, torch.ones_like(real_pred))
    fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
    return (real_loss + fake_loss) / 2

def generator_loss(D, fake_data):
    # 非饱和损失: 最大化 log D(G(z)), 等价于以 1 为目标的 BCE
    fake_pred = D(fake_data)
    return criterion(fake_pred, torch.ones_like(fake_pred))
```
## 3. 原始GAN实现

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class Generator(nn.Module):
    """全连接生成器"""
    def __init__(self, latent_dim=100, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, img_dim),
            nn.Tanh()  # 输出范围 [-1, 1], 与 Normalize([0.5], [0.5]) 后的数据一致
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """全连接判别器"""
    def __init__(self, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1)  # 输出 logit, 配合 BCEWithLogitsLoss
        )

    def forward(self, x):
        return self.net(x)

def train_vanilla_gan():
    # 超参数
    latent_dim = 100
    img_dim = 28 * 28
    batch_size = 128
    epochs = 200
    lr = 2e-4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 数据
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 模型
    G = Generator(latent_dim, img_dim).to(device)
    D = Discriminator(img_dim).to(device)
    opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    criterion = nn.BCEWithLogitsLoss()
    fixed_noise = torch.randn(64, latent_dim, device=device)  # 固定噪声, 便于观察训练进展

    for epoch in range(epochs):
        for real_imgs, _ in dataloader:
            real_imgs = real_imgs.view(-1, img_dim).to(device)  # 展平成 (batch, 784)
            bs = real_imgs.size(0)

            # === 训练判别器 ===
            z = torch.randn(bs, latent_dim, device=device)
            fake_imgs = G(z)
            d_real = D(real_imgs)
            d_fake = D(fake_imgs.detach())  # detach: 判别器更新不影响 G
            loss_D = (criterion(d_real, torch.ones_like(d_real)) +
                      criterion(d_fake, torch.zeros_like(d_fake))) / 2
            opt_D.zero_grad()  # 清零梯度, 防止梯度累积
            loss_D.backward()
            opt_D.step()

            # === 训练生成器 ===
            z = torch.randn(bs, latent_dim, device=device)
            fake_imgs = G(z)
            d_fake = D(fake_imgs)
            loss_G = criterion(d_fake, torch.ones_like(d_fake))  # 非饱和损失
            opt_G.zero_grad()
            loss_G.backward()
            opt_G.step()

        if (epoch + 1) % 20 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] D_loss: {loss_D.item():.4f}, G_loss: {loss_G.item():.4f}")
            with torch.no_grad():  # 禁用梯度计算, 节省显存
                samples = G(fixed_noise).view(-1, 1, 28, 28)
                grid = torchvision.utils.make_grid(samples, nrow=8, normalize=True)
                plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
                plt.axis('off')
                plt.show()

# train_vanilla_gan()
```
## 4. DCGAN

### 4.1 架构设计原则
Radford et al.(2015)提出的 DCGAN 用卷积替代全连接层:
- 用步长卷积替代池化层
- 使用 BatchNorm(G 和 D 都用,D 的第一层除外)
- G 使用 ReLU(最后一层 Tanh),D 使用 LeakyReLU
- 去掉全连接层
```python
class DCGenerator(nn.Module):
    """DCGAN 生成器"""
    def __init__(self, latent_dim=100, channels=1, feature_maps=64):
        super().__init__()
        self.net = nn.Sequential(
            # 输入: (latent_dim, 1, 1)
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # (feature_maps*8, 4, 4)
            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # (feature_maps*4, 8, 8)
            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # (feature_maps*2, 16, 16)
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # (feature_maps, 32, 32)
            nn.ConvTranspose2d(feature_maps, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # (channels, 64, 64)
        )
        self.apply(self._init_weights)

    @staticmethod  # 静态方法, 无需实例即可调用
    def _init_weights(m):
        # DCGAN 论文建议的初始化: N(0, 0.02)
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

    def forward(self, z):
        return self.net(z.view(z.size(0), -1, 1, 1))

class DCDiscriminator(nn.Module):
    """DCGAN 判别器"""
    def __init__(self, channels=1, feature_maps=64):
        super().__init__()
        self.net = nn.Sequential(
            # 输入: (channels, 64, 64); 第一层不用 BatchNorm
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (feature_maps, 32, 32)
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (feature_maps*2, 16, 16)
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (feature_maps*4, 8, 8)
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # (feature_maps*8, 4, 4)
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False)
            # (1, 1, 1)
        )
        self.apply(DCGenerator._init_weights)

    def forward(self, x):
        return self.net(x).view(-1, 1)
```
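转置卷积的输出尺寸满足 \(H_{out} = (H_{in} - 1) \cdot s - 2p + k\)。下面用独立的层(通道数仅为示例)验证注释中 1→4→8→16→32→64 的尺寸推进:

```python
import torch
import torch.nn as nn

# 逐层验证 DCGAN 生成器的空间尺寸推进
# H_out = (H_in - 1)*stride - 2*padding + kernel
layers = [
    nn.ConvTranspose2d(100, 512, 4, 1, 0),  # 1x1  -> 4x4
    nn.ConvTranspose2d(512, 256, 4, 2, 1),  # 4x4  -> 8x8
    nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 8x8  -> 16x16
    nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 16x16 -> 32x32
    nn.ConvTranspose2d(64, 1, 4, 2, 1),     # 32x32 -> 64x64
]
x = torch.randn(2, 100, 1, 1)
for layer in layers:
    x = layer(x)
    print(tuple(x.shape))
# 最终输出 (2, 1, 64, 64)
```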
## 5. 条件GAN (cGAN)

### 5.1 核心思想

条件 GAN(Mirza & Osindero, 2014)在 G 和 D 中都加入条件信息 \(y\)(如类别标签):

\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z \mid y) \mid y))]
\]
```python
class ConditionalGenerator(nn.Module):
    """条件生成器"""
    def __init__(self, latent_dim=100, num_classes=10, img_dim=784):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.net = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z, labels):
        label_embed = self.label_emb(labels)    # (batch, num_classes)
        x = torch.cat([z, label_embed], dim=1)  # 噪声与标签嵌入在特征维拼接
        return self.net(x)

class ConditionalDiscriminator(nn.Module):
    """条件判别器"""
    def __init__(self, num_classes=10, img_dim=784):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.net = nn.Sequential(
            nn.Linear(img_dim + num_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, x, labels):
        label_embed = self.label_emb(labels)
        x = torch.cat([x, label_embed], dim=1)
        return self.net(x)
```
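条件信息注入的关键是 `nn.Embedding` 查表后与噪声在特征维拼接,下面是一个独立的小演示(维度与标签数值仅为示例):

```python
import torch
import torch.nn as nn

# 10 个类别, 每个类别映射为一个 10 维可学习向量
emb = nn.Embedding(10, 10)
z = torch.randn(4, 100)                  # 一个 batch 的噪声
labels = torch.tensor([0, 3, 7, 9])      # 希望生成的数字类别
x = torch.cat([z, emb(labels)], dim=1)   # 在特征维拼接
print(x.shape)  # torch.Size([4, 110]), 与生成器第一层 Linear(100+10, 256) 对应
```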
## 6. GAN的训练技巧

### 6.1 常见问题与解决方案
| 问题 | 症状 | 解决方案 |
|---|---|---|
| 模式崩塌 | G 只生成少数样本 | Mini-batch discrimination, Unrolled GAN |
| 训练不稳定 | 损失震荡,不收敛 | Spectral Normalization, 梯度惩罚 |
| 梯度消失 | D 太强,G 学不到东西 | 非饱和损失, WGAN |
### 6.2 谱归一化 (Spectral Normalization)
```python
# PyTorch 内置支持谱归一化
from torch.nn.utils import spectral_norm

class SNDiscriminator(nn.Module):
    """带谱归一化的判别器 (输入假定为 32x32 图像)"""
    def __init__(self, channels=1):
        super().__init__()
        self.net = nn.Sequential(
            spectral_norm(nn.Conv2d(channels, 64, 4, 2, 1)),   # 32 -> 16
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),        # 16 -> 8
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),       # 8 -> 4
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(256, 1, 4, 1, 0)),         # 4 -> 1
        )

    def forward(self, x):
        return self.net(x).view(-1, 1)
```
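可以直接验证 `spectral_norm` 的效果:它在每次前向时做一次幂迭代估计权重的最大奇异值,并用它归一化权重。下面在一个线性层上检查归一化后权重的谱范数(层的尺寸与迭代次数仅为示例):

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
layer = spectral_norm(nn.Linear(16, 16))
x = torch.randn(4, 16)
for _ in range(50):   # 每次前向执行一次幂迭代, 逐步逼近最大奇异值
    layer(x)
sigma = torch.linalg.svdvals(layer.weight.detach())[0]
print(float(sigma))   # 接近 1: 权重的谱范数被归一化
```

谱范数约为 1 意味着该层是(近似)1-Lipschitz 的,这正是它能稳定判别器训练的原因。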
### 6.3 WGAN-GP (梯度惩罚)
```python
def gradient_penalty(D, real_data, fake_data, device):
    """WGAN-GP 梯度惩罚: 在真假样本连线上的插值点处约束梯度范数趋近 1"""
    alpha = torch.rand(real_data.size(0), 1, 1, 1, device=device)
    interpolated = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_interpolated = D(interpolated)
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,   # 惩罚项本身要参与反向传播
        retain_graph=True
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gp

# WGAN-GP 判别器 (critic) 的单步训练
def wgan_gp_step(D, G, opt_D, real_data, latent_dim, device, lambda_gp=10):
    z = torch.randn(real_data.size(0), latent_dim, device=device)
    fake_data = G(z)
    d_real = D(real_data).mean()
    d_fake = D(fake_data.detach()).mean()
    gp = gradient_penalty(D, real_data, fake_data.detach(), device)
    loss_D = d_fake - d_real + lambda_gp * gp   # Wasserstein 损失 + 梯度惩罚
    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()
    return loss_D.item()  # .item() 将单元素张量转为 Python 数值
```
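梯度惩罚的行为可以在一个线性判别器上解析地核对:对 \(D(x) = w^\top x\),梯度处处等于 \(w\),因此惩罚项应恰为 \((\|w\| - 1)^2\)(下面的维度与初始化仅为示例):

```python
import torch
import torch.nn as nn

# 线性判别器 D(x) = w·x, 对任意 x 的梯度都是 w
D = nn.Linear(8, 1, bias=False)
nn.init.constant_(D.weight, 0.5)   # ||w|| = 0.5 * sqrt(8) = sqrt(2) ≈ 1.4142

x = torch.randn(16, 8, requires_grad=True)
out = D(x)
grads = torch.autograd.grad(out, x, torch.ones_like(out), create_graph=True)[0]
gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
print(round(gp.item(), 4))         # (sqrt(2) - 1)^2 ≈ 0.1716
```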
## 7. GAN的评估指标

### 7.1 Inception Score (IS)

\[
IS = \exp\left(\mathbb{E}_{x \sim p_g}\, KL\big(p(y \mid x) \,\|\, p(y)\big)\right)
\]

高 IS 意味着:

- 每张图片的分类输出很确定:条件分布 \(p(y \mid x)\) 低熵
- 所有图片的类别分布很均匀:边缘分布 \(p(y)\) 高熵
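IS 定义为 \(\exp(\mathbb{E}_x\, KL(p(y \mid x) \| p(y)))\)。下面用构造的 softmax 概率矩阵演示其计算(非真实 Inception 输出,仅示意两种极端情形):

```python
import torch

def inception_score(probs, eps=1e-12):
    """probs: (N, C), 每行是一张图片经分类器得到的 softmax 概率"""
    p_y = probs.mean(dim=0, keepdim=True)  # 边缘分布 p(y)
    kl = (probs * (torch.log(probs + eps) - torch.log(p_y + eps))).sum(dim=1)
    return kl.mean().exp().item()

# 输出确定 (每张图一类) 且类别均匀 → IS 接近类别数 10
sharp_diverse = torch.eye(10).repeat(10, 1)
print(round(inception_score(sharp_diverse), 2))  # 10.0

# 模式崩塌: 全部落在同一类 → p(y|x) = p(y), IS = 1
collapsed = torch.zeros(100, 10)
collapsed[:, 0] = 1.0
print(round(inception_score(collapsed), 2))      # 1.0
```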
### 7.2 Fréchet Inception Distance (FID)

\[
FID = \|\mu_r - \mu_g\|^2 + \mathrm{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\big)
\]

其中 \((\mu_r, \Sigma_r)\) 和 \((\mu_g, \Sigma_g)\) 分别是真实样本和生成样本在 Inception 特征空间中的均值和协方差。FID 越低越好,它直接衡量生成分布和真实分布在 Inception 特征空间的距离。
```python
# 使用 torchmetrics 计算 FID
# pip install torchmetrics[image]
from torchmetrics.image.fid import FrechetInceptionDistance

def compute_fid(real_images, fake_images):
    """计算 FID 分数; normalize=True 时输入应为 [0, 1] 范围的 float 张量, 形状 (N, 3, H, W)"""
    fid = FrechetInceptionDistance(feature=2048, normalize=True)
    fid.update(real_images, real=True)
    fid.update(fake_images, real=False)
    score = fid.compute()
    return score.item()
```
## 8. GAN变体概览
| 变体 | 年份 | 核心贡献 |
|---|---|---|
| GAN | 2014 | 对抗训练框架 |
| DCGAN | 2015 | 卷积架构设计原则 |
| WGAN | 2017 | Wasserstein距离,稳定训练 |
| WGAN-GP | 2017 | 梯度惩罚替代权重裁剪 |
| Progressive GAN | 2018 | 渐进式增长,高分辨率生成 |
| StyleGAN | 2019 | 样式控制,映射网络 |
| StyleGAN2 | 2020 | 改进架构,更高质量 |
| StyleGAN3 | 2021 | 消除纹理粘滞 |
## 9. 练习与自我检查

### 练习题
- 原始 GAN:在 MNIST 上训练原始 GAN,观察训练过程和生成质量的变化。
- DCGAN:实现 DCGAN 生成 64×64 的人脸图像(使用 CelebA 数据集)。
- 条件 GAN:实现条件 GAN,按指定数字生成 MNIST 图像。
- WGAN-GP:对比 WGAN-GP 和原始 GAN 的训练稳定性。
- FID 评估:计算不同训练阶段的 FID 分数,观察生成质量的提升。
### 自我检查清单
- 理解 GAN 的对抗训练思想
- 能推导 GAN 的最优判别器
- 理解模式崩塌和训练不稳定的原因
- 能实现基本 GAN 和 DCGAN
- 了解条件 GAN 的实现方式
- 知道谱归一化和梯度惩罚的作用
- 能使用 IS 和 FID 评估生成质量
下一篇: 02-VAE基础