
02 - VAE Fundamentals

Study time: about 6-8 hours. Difficulty: ⭐⭐⭐⭐ (upper-intermediate). Prerequisites: probability theory (Bayesian inference), basic information theory, PyTorch. Learning goals: understand the mathematics of the VAE (the ELBO derivation) and implement the VAE and its variants.




1. Generative Models and Latent Variables

1.1 Latent-Variable Models

Assume the observed data \(x\) are generated from unobserved latent variables \(z\):

\[p_\theta(x) = \int p_\theta(x|z) p(z) dz\]

Problems:
  • The integral is intractable, because \(z\) is high-dimensional.
  • The posterior \(p_\theta(z|x)\) is also intractable.
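To see the first problem concretely, here is a minimal sketch of the naive Monte Carlo estimator \(p_\theta(x) \approx \frac{1}{N}\sum_{i=1}^{N} p_\theta(x|z_i)\) with \(z_i \sim p(z)\). The function decoder_mean is a hypothetical stand-in for any network mapping \(z\) to the mean of a Gaussian \(p_\theta(x|z)\):

Python
import math
import torch

def naive_log_px(x, decoder_mean, latent_dim, num_samples=1000):
    """Naive Monte Carlo estimate of log p(x) by sampling from the prior.

    decoder_mean: hypothetical function z -> mean of p(x|z);
    p(x|z) is taken to be N(mean, I) for illustration.
    """
    z = torch.randn(num_samples, latent_dim)               # z_i ~ N(0, I)
    mu_x = decoder_mean(z)                                 # (N, x_dim)
    log_px_given_z = -0.5 * ((x - mu_x) ** 2).sum(dim=1)   # log p(x|z_i) up to constants
    # log( (1/N) * sum_i exp(log p(x|z_i)) ), computed stably
    return torch.logsumexp(log_px_given_z, dim=0) - math.log(num_samples)

As latent_dim grows, almost every \(z_i\) lands where \(p_\theta(x|z_i)\) is vanishingly small, so the estimate is dominated by rare lucky samples; this variance blow-up is why the integral is intractable in practice.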

1.2 The Idea of Variational Inference

Approximate the true posterior \(p_\theta(z|x)\) with a trainable distribution \(q_\phi(z|x)\) (the encoder).

(Figure: VAE architecture)


2. Mathematical Derivation of the VAE

2.1 Deriving the ELBO

Start from the log marginal likelihood:

\[\log p_\theta(x) = \log \int p_\theta(x|z)p(z)dz\]

Introducing the variational distribution \(q_\phi(z|x)\), we obtain the exact decomposition:

\[\log p_\theta(x) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right] + KL(q_\phi(z|x) \| p_\theta(z|x))\]
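This identity can be verified in two steps: since \(p_\theta(x) = p_\theta(x, z)/p_\theta(z|x)\) holds for every \(z\), take the expectation under \(q_\phi(z|x)\) and multiply inside the logarithm by \(q_\phi(z|x)/q_\phi(z|x)\):

\[\log p_\theta(x) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{p_\theta(z|x)}\right] = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right] + \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{q_\phi(z|x)}{p_\theta(z|x)}\right]\]

The second expectation is exactly \(KL(q_\phi(z|x) \| p_\theta(z|x))\).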

Since \(KL \geq 0\):

\[\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right] = \text{ELBO}\]

2.2 Decomposing the ELBO

\[\text{ELBO} = \underbrace{\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{重构项}} - \underbrace{KL(q_\phi(z|x) \| p(z))}_{\text{正则项}}\]
  • Reconstruction term: the decoder's ability to reconstruct the input
  • Regularization term: pushes the encoder distribution toward the prior \(p(z) = \mathcal{N}(0, I)\)

2.3 Closed-Form KL Divergence under the Gaussian Assumption

When \(q_\phi(z|x) = \mathcal{N}(\mu, \text{diag}(\sigma^2))\) and \(p(z) = \mathcal{N}(0, I)\):

\[KL(q_\phi \| p) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\]
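The intermediate step, per dimension \(j\): write \(KL = \mathbb{E}_q[\log q] - \mathbb{E}_q[\log p]\) and use the Gaussian expectations

\[\mathbb{E}_q[\log q] = -\frac{1}{2}\left(1 + \log 2\pi\sigma_j^2\right), \qquad \mathbb{E}_q[\log p] = -\frac{1}{2}\left(\log 2\pi + \mu_j^2 + \sigma_j^2\right)\]

Subtracting gives \(\frac{1}{2}\left(\mu_j^2 + \sigma_j^2 - 1 - \log\sigma_j^2\right)\) per dimension, which matches the sum above.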
Python
import torch
import torch.nn as nn
import torch.nn.functional as F

def kl_divergence(mu, log_var):
    """
    Closed-form KL divergence between N(mu, diag(sigma^2)) and N(0, I)
    mu: (batch, latent_dim)
    log_var: (batch, latent_dim), the log-variance
    """
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
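As a sanity check, the closed form can be compared against PyTorch's built-in Gaussian KL; a minimal sketch using torch.distributions:

Python
from torch.distributions import Normal, kl_divergence as torch_kl

mu = torch.randn(8, 20)
log_var = torch.randn(8, 20)

analytic = kl_divergence(mu, log_var)                  # our closed form, shape (8,)
q = Normal(mu, torch.exp(0.5 * log_var))               # q(z|x)
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))  # p(z) = N(0, I)
print(torch.allclose(analytic, torch_kl(q, p).sum(dim=1), atol=1e-5))  # True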

3. The Reparameterization Trick

3.1 The Problem

Training requires gradients of the ELBO with respect to \(\phi\), but the sampling operation \(z \sim q_\phi(z|x)\) is not differentiable.

3.2 The Solution

Separate the randomness from the parameters:

\[z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]

Now \(z\) is a deterministic function of \(\mu\) and \(\sigma\), so gradients can backpropagate through it.

Python
def reparameterize(mu, log_var):
    """重参数化技巧"""
    std = torch.exp(0.5 * log_var)  # σ = exp(log_var / 2)
    eps = torch.randn_like(std)     # ε ~ N(0, I)
    z = mu + std * eps              # z = μ + σ * ε
    return z
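A quick check that gradients now flow through the sampling step (a minimal sketch):

Python
mu = torch.zeros(4, 20, requires_grad=True)
log_var = torch.zeros(4, 20, requires_grad=True)

z = reparameterize(mu, log_var)
z.sum().backward()
print(mu.grad is not None, log_var.grad is not None)  # True True

By contrast, torch.distributions.Normal(mu, std).sample() detaches the graph; the reparameterized counterpart is .rsample(), which does exactly what reparameterize does above.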

(Figure: the reparameterization trick)


4. A Complete VAE Implementation

Python
class Encoder(nn.Module):
    """VAE encoder: outputs μ and log σ²"""
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = self.fc(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        return mu, log_var

class Decoder(nn.Module):
    """VAE 解码器"""
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()  # outputs in [0, 1]
        )

    def forward(self, z):
        return self.fc(z)

class VAE(nn.Module):
    """变分自编码器"""
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var

    def loss_function(self, x, x_recon, mu, log_var):
        # Reconstruction loss (BCE, summed over pixels and batch)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        # KL divergence
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + kl_loss, recon_loss, kl_loss

    def sample(self, num_samples, device):
        """从先验分布采样并解码"""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        samples = self.decoder(z)
        return samples

    def interpolate(self, x1, x2, steps=10):
        """在两个样本之间进行隐空间插值"""
        mu1, _ = self.encoder(x1)
        mu2, _ = self.encoder(x2)

        alphas = torch.linspace(0, 1, steps, device=x1.device)
        interpolations = []
        for alpha in alphas:
            z = (1 - alpha) * mu1 + alpha * mu2
            recon = self.decoder(z)
            interpolations.append(recon)

        return torch.stack(interpolations)

4.1 Training Loop

Python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def train_vae():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)  # batched, shuffled loading

    # Model
    model = VAE(input_dim=784, hidden_dim=512, latent_dim=20).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(50):
        model.train()  # switch to training mode
        total_loss, total_recon, total_kl = 0, 0, 0

        for x, _ in dataloader:
            x = x.view(-1, 784).to(device)  # flatten to (batch, 784) and move to device

            x_recon, mu, log_var = model(x)
            loss, recon, kl = model.loss_function(x, x_recon, mu, log_var)

            optimizer.zero_grad()  # clear old gradients so they do not accumulate
            loss.backward()        # backpropagate
            optimizer.step()       # update parameters

            total_loss += loss.item()  # .item() converts a one-element tensor to a Python number
            total_recon += recon.item()
            total_kl += kl.item()

        n = len(dataset)
        print(f"Epoch {epoch+1}: Loss={total_loss/n:.2f}, "
              f"Recon={total_recon/n:.2f}, KL={total_kl/n:.2f}")

    return model

# model = train_vae()
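Once training finishes, new digits can be generated by decoding samples from the prior. A minimal sketch (the helper show_samples and the filename samples.png are illustrative, not part of the training code):

Python
from torchvision.utils import save_image

def show_samples(model, num_samples=64, path='samples.png'):
    """Decode prior samples and save them as an 8x8 image grid."""
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        samples = model.sample(num_samples, device)  # (N, 784), values in [0, 1]
    save_image(samples.view(-1, 1, 28, 28), path, nrow=8)

# show_samples(model)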

4.2 Convolutional VAE

Python
class ConvEncoder(nn.Module):
    """卷积编码器"""
    def __init__(self, in_channels, latent_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 2, 1),  # (32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),            # (64, 7, 7)
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),           # (128, 4, 4)
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(128 * 4 * 4, latent_dim)
        self.fc_log_var = nn.Linear(128 * 4 * 4, latent_dim)

    def forward(self, x):
        h = self.conv(x).view(x.size(0), -1)
        return self.fc_mu(h), self.fc_log_var(h)

class ConvDecoder(nn.Module):
    """卷积解码器"""
    def __init__(self, latent_dim, out_channels):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 4 * 4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, 2, 1),      # (64, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, 2, 1, 1),    # (32, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 3, 2, 1, 1),  # (C, 28, 28)
            nn.Sigmoid(),
        )

    def forward(self, z):
        h = self.fc(z).view(-1, 128, 4, 4)
        return self.deconv(h)

class ConvVAE(nn.Module):
    def __init__(self, in_channels=1, latent_dim=32):
        super().__init__()
        self.encoder = ConvEncoder(in_channels, latent_dim)
        self.decoder = ConvDecoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var
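A quick shape check on an MNIST-sized batch confirms the geometry annotated in the comments above:

Python
conv_vae = ConvVAE(in_channels=1, latent_dim=32)
x = torch.randn(4, 1, 28, 28)
x_recon, mu, log_var = conv_vae(x)
print(x_recon.shape, mu.shape, log_var.shape)  # (4, 1, 28, 28), (4, 32), (4, 32)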

(Figure: latent-space interpolation)


5. Conditional VAE (CVAE)

Python
class ConditionalVAE(nn.Module):
    """条件 VAE — 可以按条件(如类别标签)生成"""
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=20, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: takes x and a one-hot label y
        self.enc_fc = nn.Sequential(
            nn.Linear(input_dim + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder: takes z and y
        self.dec_fc = nn.Sequential(
            nn.Linear(latent_dim + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x, y):
        h = self.enc_fc(torch.cat([x, y], dim=1))
        return self.fc_mu(h), self.fc_log_var(h)

    def decode(self, z, y):
        return self.dec_fc(torch.cat([z, y], dim=1))

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

    def forward(self, x, y):
        mu, log_var = self.encode(x, y)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z, y)
        return x_recon, mu, log_var

    def generate(self, y, device):
        """按条件生成"""
        z = torch.randn(y.size(0), self.latent_dim, device=device)
        return self.decode(z, y)
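Conditioning works through one-hot labels fed to both the encoder and the decoder. A minimal usage sketch (with untrained weights the outputs are noise; the digit choice is arbitrary):

Python
cvae = ConditionalVAE()
x = torch.rand(16, 784)  # flattened images in [0, 1]
y = F.one_hot(torch.randint(0, 10, (16,)), num_classes=10).float()

x_recon, mu, log_var = cvae(x, y)  # training-time forward pass

# Generate 8 images of the digit 3
y_gen = F.one_hot(torch.full((8,), 3, dtype=torch.long), num_classes=10).float()
samples = cvae.generate(y_gen, device=torch.device('cpu'))
print(samples.shape)  # torch.Size([8, 784])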

6. β-VAE and Disentangled Representations

6.1 β-VAE

Higgins et al. (2017) add a hyperparameter \(\beta\) to the ELBO:

\[\mathcal{L} = \mathbb{E}[\log p_\theta(x|z)] - \beta \cdot KL(q_\phi(z|x) \| p(z))\]
  • \(\beta > 1\): stronger regularization, which encourages disentangled latent factors
  • \(\beta < 1\): puts more weight on reconstruction quality
Python
class BetaVAE(VAE):
    """β-VAE"""
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=10, beta=4.0):
        super().__init__(input_dim, hidden_dim, latent_dim)
        self.beta = beta

    def loss_function(self, x, x_recon, mu, log_var):
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + self.beta * kl_loss, recon_loss, kl_loss

7. VQ-VAE

7.1 Core Idea

VQ-VAE (van den Oord et al., 2017) uses discrete latent variables: vector quantization maps each encoder output to the nearest vector in a learned codebook.

\[z_q = \arg\min_{e_j \in \mathcal{E}} \|z_e - e_j\|_2\]
Python
class VectorQuantizer(nn.Module):
    """向量量化层"""
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.embedding.weight, -1/num_embeddings, 1/num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, z_e):
        """
        z_e: (batch, embedding_dim, H, W), the encoder output
        """
        # Move the channel dimension last
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, D)
        z_e_flat = z_e.reshape(-1, z_e.size(-1))  # (B*H*W, D)

        # Squared distance from each z_e to every codebook vector
        distances = (z_e_flat.pow(2).sum(1, keepdim=True)
                    + self.embedding.weight.pow(2).sum(1)
                    - 2 * z_e_flat @ self.embedding.weight.t())

        # Index of the nearest codebook vector
        encoding_indices = distances.argmin(1)
        z_q = self.embedding(encoding_indices).view(z_e.shape)

        # Losses
        codebook_loss = F.mse_loss(z_q, z_e.detach())    # moves the codebook toward encoder outputs
        commitment_loss = F.mse_loss(z_e, z_q.detach())  # keeps the encoder committed to the codebook
        vq_loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: copy gradients from z_q back to z_e
        z_q = z_e + (z_q - z_e).detach()

        z_q = z_q.permute(0, 3, 1, 2)  # (B, D, H, W)
        return z_q, vq_loss, encoding_indices

class VQVAE(nn.Module):
    """VQ-VAE"""
    def __init__(self, in_channels=1, hidden_dim=128, num_embeddings=512, embedding_dim=64):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim // 2, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 2, hidden_dim, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, embedding_dim, 3, 1, 1),
        )

        # Vector quantization
        self.vq = VectorQuantizer(num_embeddings, embedding_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, hidden_dim, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 2, in_channels, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, vq_loss, indices = self.vq(z_e)
        x_recon = self.decoder(z_q)

        recon_loss = F.mse_loss(x_recon, x)
        total_loss = recon_loss + vq_loss

        return x_recon, total_loss, recon_loss, vq_loss

# Quick test
vqvae = VQVAE(in_channels=1, hidden_dim=128, num_embeddings=512, embedding_dim=64)
x = torch.randn(4, 1, 28, 28)
x_recon, total_loss, recon_loss, vq_loss = vqvae(x)
print(f"VQ-VAE: recon={x_recon.shape}, loss={total_loss.item():.4f}")

(Figure: VQ-VAE architecture)


8. VAE vs GAN

| Property | VAE | GAN |
| --- | --- | --- |
| Objective | maximize the ELBO (approximate likelihood) | adversarial game |
| Training stability | ✅ stable | ❌ can be unstable |
| Sample quality | slightly blurry | sharper |
| Latent-space structure | ✅ continuous, supports interpolation | ❌ weaker structure |
| Density estimation | ✅ approximate | ❌ not available |
| Diversity | ✅ high | ❌ prone to mode collapse |
| Theoretical basis | variational inference | game theory |



9. Exercises and Self-Check

Exercises

  1. ELBO derivation: derive the ELBO from scratch and show that it lower-bounds the log-likelihood.
  2. Basic VAE: train a VAE on MNIST; visualize reconstructions and images sampled from the prior.
  3. Latent-space visualization: use t-SNE to visualize the VAE latent space and check whether different digits form clusters.
  4. Interpolation: linearly interpolate between the latent codes of two samples and observe how the generated images transition.
  5. β-VAE: experiment with different values of \(\beta\) and observe the effect on reconstruction quality and disentanglement.
  6. VQ-VAE: implement VQ-VAE and compare its reconstruction quality with a vanilla VAE.

Self-Check Checklist

  • Can derive the ELBO and explain what each term means
  • Understand what problem the reparameterization trick solves
  • Can implement and train a VAE from scratch
  • Understand the derivation of the closed-form KL divergence
  • Can weigh the strengths and weaknesses of VAEs versus GANs
  • Know the core innovations of β-VAE and VQ-VAE

Next: 06-Advanced Topics/01-Model Compression and Acceleration