
01-GAN Basics

Estimated time: about 6-8 hours. Difficulty: ⭐⭐⭐⭐ (upper-intermediate). Prerequisites: CNN basics, probability, PyTorch. Goal: understand the core principles of GANs and master the main variants, from the original GAN through DCGAN and conditional GAN.




1. The Core Idea of GANs

1.1 The Adversarial Game

Goodfellow et al. (2014) introduced the GAN (Generative Adversarial Network). At its core are two networks playing a game:

  • Generator (G): maps noise \(z \sim p_z(z)\) to "fake" samples and tries to fool the discriminator
  • Discriminator (D): distinguishes "real" data from "fake" samples

Analogy: G is a counterfeiter and D is a banknote detector. The two keep playing against each other until G's counterfeits pass for the real thing.

GAN architecture diagram

1.2 Intuition

Text Only
noise z ~ N(0,1) → [Generator G] → generated sample G(z)
real data x ~ p_data     →  [Discriminator D]  → D(x): real or fake?
generated sample G(z)    →

GAN training process diagram


2. Mathematical Derivation

2.1 The Objective Function

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]

Discriminator objective: maximize \(V\). For real samples \(x\), push \(D(x) \to 1\), so \(\log D(x) \to 0\); for generated samples \(G(z)\), push \(D(G(z)) \to 0\), so \(\log(1 - D(G(z))) \to 0\).

Generator objective: minimize \(V\). Push \(D(G(z)) \to 1\), so \(\log(1 - D(G(z))) \to -\infty\).

2.2 The Optimal Discriminator

Fix G and solve for the optimal D:

\[D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\]
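A brief sketch of where this comes from: write the value function as an integral and maximize pointwise in \(D(x)\).

\[V(D, G) = \int_x \left[ p_{data}(x)\log D(x) + p_g(x)\log(1 - D(x)) \right] dx\]

For each fixed \(x\) the integrand has the form \(a \log t + b \log(1 - t)\) with \(a = p_{data}(x)\) and \(b = p_g(x)\); setting the derivative \(\frac{a}{t} - \frac{b}{1 - t} = 0\) gives the maximizer \(t^* = \frac{a}{a + b}\), which is exactly \(D^*(x)\).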

2.3 Global Optimum

\(p_g = p_{data}\) 时,\(D^*(x) = \frac{1}{2}\),目标函数达到全局最小值 \(-\log 4\)

Substituting \(D^*\) into the objective expresses it in terms of the JS divergence:

\[C(G) = -\log 4 + 2 \cdot JSD(p_{data} \| p_g)\]
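This identity follows by substituting \(D^*\) into \(V\) (a sketch):

\[C(G) = \mathbb{E}_{x \sim p_{data}}\left[\log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\right] + \mathbb{E}_{x \sim p_g}\left[\log \frac{p_g(x)}{p_{data}(x) + p_g(x)}\right]\]

Adding and subtracting \(\log 2\) inside each logarithm turns the two terms into KL divergences against the mixture \(m = (p_{data} + p_g)/2\):

\[C(G) = -\log 4 + KL(p_{data} \| m) + KL(p_g \| m) = -\log 4 + 2 \cdot JSD(p_{data} \| p_g)\]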

2.4 The Non-Saturating GAN Loss

In practice \(\log(1 - D(G(z)))\) yields vanishingly small gradients early in training: when D easily separates real from fake, \(D(G(z)) \approx 0\) and the loss sits on the flat part of the curve. The generator loss is therefore replaced with:

\[\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]\]
Python
import torch
import torch.nn as nn

# Original GAN losses (with the non-saturating generator variant)
criterion = nn.BCEWithLogitsLoss()

def discriminator_loss(D, real_data, fake_data):
    real_pred = D(real_data)
    fake_pred = D(fake_data.detach())  # detach() keeps D's update from backpropagating into G

    real_loss = criterion(real_pred, torch.ones_like(real_pred))
    fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))

    return (real_loss + fake_loss) / 2

def generator_loss(D, fake_data):
    # Non-saturating loss: label fake samples as real
    fake_pred = D(fake_data)
    return criterion(fake_pred, torch.ones_like(fake_pred))
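As a quick sanity check of the two loss functions above (a sketch with a stand-in linear `D` and random tensors, not the real models):

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# Stand-in "discriminator": a single linear layer over flattened 28x28 inputs.
D = nn.Linear(784, 1)
real = torch.randn(16, 784)
fake = torch.randn(16, 784)

# Discriminator loss: real -> 1, fake -> 0 (fake batch detached, as above)
real_pred = D(real)
fake_pred = D(fake.detach())
loss_D = (criterion(real_pred, torch.ones_like(real_pred)) +
          criterion(fake_pred, torch.zeros_like(fake_pred))) / 2

# Non-saturating generator loss: label fakes as real
fake_pred_g = D(fake)
loss_G = criterion(fake_pred_g, torch.ones_like(fake_pred_g))

print(loss_D.item(), loss_G.item())  # two finite positive scalars
```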

3. Vanilla GAN Implementation

Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class Generator(nn.Module):  # subclass nn.Module to define the network
    """Fully connected generator"""
    def __init__(self, latent_dim=100, img_dim=784):  # runs when the object is created
        super().__init__()  # call the parent class constructor
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),

            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """全连接判别器"""
    def __init__(self, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.net(x)

def train_vanilla_gan():
    # Hyperparameters
    latent_dim = 100
    img_dim = 28 * 28
    batch_size = 128
    epochs = 200
    lr = 2e-4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)  # batches and shuffles the dataset

    # Models
    G = Generator(latent_dim, img_dim).to(device)
    D = Discriminator(img_dim).to(device)

    opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    criterion = nn.BCEWithLogitsLoss()

    fixed_noise = torch.randn(64, latent_dim, device=device)

    for epoch in range(epochs):
        for real_imgs, _ in dataloader:
            real_imgs = real_imgs.view(-1, img_dim).to(device)  # flatten (N, 1, 28, 28) to (N, 784) and move to device
            bs = real_imgs.size(0)

            # === Train the discriminator ===
            z = torch.randn(bs, latent_dim, device=device)
            fake_imgs = G(z)

            d_real = D(real_imgs)
            d_fake = D(fake_imgs.detach())

            loss_D = (criterion(d_real, torch.ones_like(d_real)) +
                      criterion(d_fake, torch.zeros_like(d_fake))) / 2

            opt_D.zero_grad()  # zero gradients so they do not accumulate
            loss_D.backward()  # backpropagate
            opt_D.step()

            # === Train the generator ===
            z = torch.randn(bs, latent_dim, device=device)
            fake_imgs = G(z)
            d_fake = D(fake_imgs)

            loss_G = criterion(d_fake, torch.ones_like(d_fake))

            opt_G.zero_grad()
            loss_G.backward()
            opt_G.step()

        if (epoch + 1) % 20 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] D_loss: {loss_D:.4f}, G_loss: {loss_G:.4f}")
            with torch.no_grad():  # 禁用梯度计算,节省内存
                samples = G(fixed_noise).view(-1, 1, 28, 28)
                grid = torchvision.utils.make_grid(samples, nrow=8, normalize=True)

# train_vanilla_gan()
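A quick shape and range check, using a deliberately shrunk stand-in for the fully connected generator above (the layer widths here are illustrative): the final `Tanh` keeps outputs in \([-1, 1]\), matching the `Normalize([0.5], [0.5])` preprocessing of the real images.

```python
import torch
import torch.nn as nn

# Small stand-in for the fully connected generator above.
G = nn.Sequential(
    nn.Linear(100, 128), nn.LeakyReLU(0.2), nn.BatchNorm1d(128),
    nn.Linear(128, 784), nn.Tanh(),
)
G.eval()  # use fixed BatchNorm statistics for a deterministic check

z = torch.randn(8, 100)
imgs = G(z)
print(imgs.shape)                                # torch.Size([8, 784])
print(float(imgs.min()), float(imgs.max()))      # both inside [-1, 1] thanks to Tanh
```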

4. DCGAN

4.1 Architecture Design Principles

Radford et al. (2015)'s DCGAN replaces fully connected layers with convolutions:

  • Strided convolutions replace pooling layers
  • BatchNorm in both G and D (except D's first layer and G's output layer)
  • ReLU in G (Tanh on the output layer); LeakyReLU in D
  • No fully connected layers
Python
class DCGenerator(nn.Module):
    """DCGAN 生成器"""
    def __init__(self, latent_dim=100, channels=1, feature_maps=64):
        super().__init__()
        self.net = nn.Sequential(
            # Input: (latent_dim, 1, 1)
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # (feature_maps*8, 4, 4)

            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # (feature_maps*4, 8, 8)

            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # (feature_maps*2, 16, 16)

            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # (feature_maps, 32, 32)

            nn.ConvTranspose2d(feature_maps, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # (channels, 64, 64)
        )
        self.apply(self._init_weights)

    @staticmethod  # static method: callable without an instance
    def _init_weights(m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):  # check the module type
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

    def forward(self, z):
        return self.net(z.view(z.size(0), -1, 1, 1))

class DCDiscriminator(nn.Module):
    """DCGAN 判别器"""
    def __init__(self, channels=1, feature_maps=64):
        super().__init__()
        self.net = nn.Sequential(
            # Input: (channels, 64, 64)
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (feature_maps, 32, 32)

            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (feature_maps*2, 16, 16)

            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (feature_maps*4, 8, 8)

            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # (feature_maps*8, 4, 4)

            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False)
            # (1, 1, 1)
        )
        self.apply(DCGenerator._init_weights)

    def forward(self, x):
        return self.net(x).view(-1, 1)
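The per-layer shape comments above can be verified with the transposed-convolution size formula, out = (in - 1) * stride - 2 * padding + kernel. A standalone walk-through with the same kernel/stride/padding settings (feature_maps=64, channels=1):

```python
import torch
import torch.nn as nn

# The same transposed-conv stack as DCGenerator, one layer at a time.
layers = [
    nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),  # 1x1  -> 4x4
    nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),  # 4x4  -> 8x8
    nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # 8x8  -> 16x16
    nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),   # 16x16 -> 32x32
    nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),     # 32x32 -> 64x64
]
x = torch.randn(2, 100, 1, 1)
for layer in layers:
    x = layer(x)
    print(tuple(x.shape))  # each step matches out = (in - 1) * stride - 2 * padding + kernel
```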

DCGAN architecture diagram


5. Conditional GAN (cGAN)

5.1 Core Idea

A conditional GAN injects conditioning information \(y\) (such as a class label) into both G and D:

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x|y)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z|y)|y))]\]
Python
class ConditionalGenerator(nn.Module):
    """条件生成器"""
    def __init__(self, latent_dim=100, num_classes=10, img_dim=784):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.net = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z, labels):
        label_embed = self.label_emb(labels)
        x = torch.cat([z, label_embed], dim=1)
        return self.net(x)

class ConditionalDiscriminator(nn.Module):
    """条件判别器"""
    def __init__(self, num_classes=10, img_dim=784):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.net = nn.Sequential(
            nn.Linear(img_dim + num_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, x, labels):
        label_embed = self.label_emb(labels)
        x = torch.cat([x, label_embed], dim=1)
        return self.net(x)
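The conditioning mechanism is plain concatenation: a learned label embedding is appended to the noise (or image) vector before the MLP. A minimal shape check, with dimensions following the defaults above:

```python
import torch
import torch.nn as nn

latent_dim, num_classes = 100, 10
label_emb = nn.Embedding(num_classes, num_classes)  # each class maps to a 10-dim vector

z = torch.randn(4, latent_dim)
labels = torch.tensor([3, 3, 7, 7])  # e.g. ask for two 3s and two 7s

# The concatenated input that the first Linear(latent_dim + num_classes, 256) sees.
x = torch.cat([z, label_emb(labels)], dim=1)
print(x.shape)  # torch.Size([4, 110])
```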

6. GAN Training Tricks

6.1 Common Problems and Remedies

Problem | Symptom | Remedy
Mode collapse | G produces only a few distinct samples | Mini-batch discrimination, Unrolled GAN
Unstable training | Oscillating, non-converging losses | Spectral normalization, gradient penalty
Vanishing gradients | D too strong, G stops learning | Non-saturating loss, WGAN

Mode collapse illustration
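One further remedy often combined with the above is one-sided label smoothing (Salimans et al., 2016): train D against soft targets of 0.9 for real samples while keeping 0.0 for fakes, so D does not become overconfident. A minimal sketch with stand-in logits:

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
real_pred = torch.randn(16, 1)   # stand-in discriminator logits on real samples
fake_pred = torch.randn(16, 1)   # stand-in logits on fakes

# One-sided label smoothing: only the real targets are softened to 0.9.
real_loss = criterion(real_pred, torch.full_like(real_pred, 0.9))
fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
loss_D = (real_loss + fake_loss) / 2
print(loss_D.item())
```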

6.2 Spectral Normalization

Python
# Built into PyTorch
from torch.nn.utils import spectral_norm

class SNDiscriminator(nn.Module):
    """带谱归一化的判别器"""
    def __init__(self, channels=1):
        super().__init__()
        self.net = nn.Sequential(
            spectral_norm(nn.Conv2d(channels, 64, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(256, 1, 4, 1, 0)),
        )

    def forward(self, x):
        return self.net(x).view(-1, 1)
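Mechanically, `spectral_norm` estimates the weight's largest singular value with one power-iteration step per forward pass and divides the weight by it, so the effective weight's spectral norm approaches 1 (this keeps D roughly 1-Lipschitz). A rough standalone check; the tolerance is loose since convergence depends on the random weight:

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm

layer = spectral_norm(nn.Linear(8, 8))
x = torch.randn(32, 8)
for _ in range(100):   # each forward pass runs one power-iteration step
    layer(x)

with torch.no_grad():
    # Largest singular value of the effective (normalized) weight.
    sigma = torch.linalg.svdvals(layer.weight)[0]
print(float(sigma))    # close to 1.0
```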

6.3 WGAN-GP (Gradient Penalty)

Python
def gradient_penalty(D, real_data, fake_data, device):
    """WGAN-GP 梯度惩罚"""
    alpha = torch.rand(real_data.size(0), 1, 1, 1, device=device)
    interpolated = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

    d_interpolated = D(interpolated)

    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gp

# One WGAN-GP discriminator training step
def wgan_gp_step(D, G, opt_D, real_data, latent_dim, device, lambda_gp=10):
    z = torch.randn(real_data.size(0), latent_dim, device=device)
    fake_data = G(z)

    d_real = D(real_data).mean()
    d_fake = D(fake_data.detach()).mean()
    gp = gradient_penalty(D, real_data, fake_data.detach(), device)

    loss_D = d_fake - d_real + lambda_gp * gp

    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()

    return loss_D.item()  # .item() converts a one-element tensor to a Python float
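A quick way to convince yourself the penalty is computed correctly: for a critic D(x) = sum(x) the gradient is exactly 1 everywhere, so the penalty should be 0. The sketch below repeats `gradient_penalty` verbatim so it runs standalone:

```python
import torch

def gradient_penalty(D, real_data, fake_data, device):
    # Same function as above, repeated so this check is self-contained.
    alpha = torch.rand(real_data.size(0), 1, 1, 1, device=device)
    interpolated = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_interpolated = D(interpolated)
    gradients = torch.autograd.grad(
        outputs=d_interpolated, inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# With 1-element samples and D(x) = sum(x), every per-sample gradient norm is 1,
# so the penalty is (1 - 1)^2 = 0.
D = lambda x: x.sum(dim=(1, 2, 3))
real = torch.randn(8, 1, 1, 1)
fake = torch.randn(8, 1, 1, 1)
gp = gradient_penalty(D, real, fake, torch.device('cpu'))
print(gp.item())  # ≈ 0.0
```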

7. GAN Evaluation Metrics

7.1 Inception Score (IS)

\[IS = \exp\left(\mathbb{E}_x [KL(p(y|x) \| p(y))]\right)\]

A high IS requires both:

  • confident per-image predictions: \(p(y|x)\) has low entropy
  • diverse outputs overall: the marginal \(p(y)\) has high entropy
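A toy numeric check of this intuition (hand-made distributions, not real Inception outputs): with three images classified with full confidence into three distinct classes, every KL term equals \(\log 3\), so IS = 3.

```python
import torch

# Three "images", each classified with full confidence into a distinct class.
p_y_given_x = torch.eye(3)            # rows are p(y|x) per image
p_y = p_y_given_x.mean(dim=0)         # marginal p(y): uniform over 3 classes

eps = 1e-12                           # avoid log(0) on the zero entries
kl = (p_y_given_x * (torch.log(p_y_given_x + eps) - torch.log(p_y + eps))).sum(dim=1)
is_score = torch.exp(kl.mean())
print(is_score.item())                # ≈ 3.0: confident predictions + uniform marginal
```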

7.2 Fréchet Inception Distance (FID)

\[FID = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2})\]

Lower FID is better: it directly measures the distance between the generated and real distributions in Inception feature space.

Python
# Compute FID with torchmetrics
# pip install torchmetrics[image]
from torchmetrics.image.fid import FrechetInceptionDistance

def compute_fid(real_images, fake_images):
    """计算 FID 分数"""
    fid = FrechetInceptionDistance(feature=2048, normalize=True)

    fid.update(real_images, real=True)
    fid.update(fake_images, real=False)

    score = fid.compute()
    return score.item()

8. Overview of GAN Variants

Variant | Year | Key contribution
GAN | 2014 | Adversarial training framework
DCGAN | 2015 | Convolutional architecture guidelines
WGAN | 2017 | Wasserstein distance, more stable training
WGAN-GP | 2017 | Gradient penalty instead of weight clipping
Progressive GAN | 2018 | Progressive growing, high-resolution synthesis
StyleGAN | 2019 | Style-based control, mapping network
StyleGAN2 | 2020 | Redesigned architecture, higher quality
StyleGAN3 | 2021 | Alias-free design that removes texture sticking

Latent-space interpolation visualization


9. Exercises and Self-Check

Exercises

  1. Vanilla GAN: train the original GAN on MNIST and watch how training dynamics and sample quality evolve.
  2. DCGAN: implement DCGAN to generate 64×64 face images (CelebA dataset).
  3. Conditional GAN: implement a cGAN that generates MNIST digits of a specified class.
  4. WGAN-GP: compare the training stability of WGAN-GP and the original GAN.
  5. FID evaluation: compute FID at different training stages and observe how sample quality improves.

Self-Check Checklist

  • Understand the adversarial training idea behind GANs
  • Can derive the optimal discriminator
  • Understand the causes of mode collapse and training instability
  • Can implement a basic GAN and DCGAN
  • Know how conditional GANs are implemented
  • Know what spectral normalization and the gradient penalty do
  • Can evaluate sample quality with IS and FID

Next: 02-VAE Basics