02 - VAE Fundamentals
Study time: about 6-8 hours · Difficulty: ⭐⭐⭐⭐ intermediate-to-advanced · Prerequisites: probability theory (Bayesian inference), basics of information theory, PyTorch · Learning goals: understand the mathematics behind VAEs (the ELBO derivation) and be able to implement the VAE and its variants
Table of Contents
- 1. Generative Models and Latent Variables
- 2. Mathematical Derivation of the VAE
- 3. The Reparameterization Trick
- 4. A Complete VAE Implementation
- 5. Conditional VAE (CVAE)
- 6. β-VAE and Disentangled Representations
- 7. VQ-VAE
- 8. VAE vs. GAN
- 9. Exercises and Self-Check
1. Generative Models and Latent Variables
1.1 Latent Variable Models
Assume the observed data \(x\) are generated from unobserved latent variables \(z\):
\[p_\theta(x) = \int p_\theta(x|z) p(z) dz\]
Problems:
- the integral is intractable, because \(z\) is high-dimensional;
- the posterior \(p_\theta(z|x)\) is intractable as well.
1.2 The Idea of Variational Inference
Approximate the true posterior \(p_\theta(z|x)\) with a trainable distribution \(q_\phi(z|x)\), the encoder.
2. Mathematical Derivation of the VAE
2.1 Deriving the ELBO
Start from the log marginal likelihood:
\[\log p_\theta(x) = \log \int p_\theta(x|z)p(z)dz\]
Introducing the variational distribution \(q_\phi(z|x)\), the log-likelihood decomposes exactly as:
\[\log p_\theta(x) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right] + KL(q_\phi(z|x) \| p_\theta(z|x))\]
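This identity can be verified directly from \(p_\theta(x, z) = p_\theta(z|x)\,p_\theta(x)\): the two terms on the right combine as
\[\begin{aligned}
\mathbb{E}_{q_\phi}\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right] + \mathbb{E}_{q_\phi}\left[\log \frac{q_\phi(z|x)}{p_\theta(z|x)}\right]
&= \mathbb{E}_{q_\phi}\left[\log \frac{p_\theta(x, z)}{p_\theta(z|x)}\right] \\
&= \mathbb{E}_{q_\phi}\left[\log p_\theta(x)\right] = \log p_\theta(x).
\end{aligned}\]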
Since \(KL \geq 0\):
\[\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right] = \text{ELBO}\]
2.2 Decomposing the ELBO
\[\text{ELBO} = \underbrace{\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{reconstruction term}} - \underbrace{KL(q_\phi(z|x) \| p(z))}_{\text{regularization term}}\]
- Reconstruction term: the decoder's ability to reconstruct the input
- Regularization term: pulls the encoder's distribution toward the prior \(p(z) = \mathcal{N}(0, I)\)
2.3 Analytic KL Divergence under the Gaussian Assumption
When \(q_\phi(z|x) = \mathcal{N}(\mu, \text{diag}(\sigma^2))\) and \(p(z) = \mathcal{N}(0, I)\):
\[KL(q_\phi \| p) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\]
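For completeness, this follows from the standard closed form for the KL divergence between univariate Gaussians, applied per dimension:
\[KL\left(\mathcal{N}(\mu_j, \sigma_j^2) \,\|\, \mathcal{N}(0, 1)\right) = \frac{1}{2}\left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right)\]
Summing over the \(d\) latent dimensions gives the expression above.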
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def kl_divergence(mu, log_var):
    """
    Analytic KL divergence for a diagonal Gaussian vs. N(0, I).
    mu:      (batch, latent_dim)
    log_var: (batch, latent_dim), the log variance
    """
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
```
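A quick sanity check (illustrative, not from the original): when \(q_\phi(z|x)\) equals the prior, i.e. \(\mu = 0\) and \(\log \sigma^2 = 0\), the KL divergence should be exactly zero.

```python
mu = torch.zeros(4, 20)       # mean 0
log_var = torch.zeros(4, 20)  # log variance 0, i.e. sigma^2 = 1
print(kl_divergence(mu, log_var))  # tensor([0., 0., 0., 0.])
```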
3. The Reparameterization Trick
3.1 The Problem
Training requires gradients of the ELBO, but the sampling operation \(z \sim q_\phi(z|x)\) is not differentiable.
3.2 The Solution
Separate the randomness from the parameters:
\[z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]
Now \(z\) is a deterministic function of \(\mu\) and \(\sigma\), so gradients can backpropagate through them.
```python
def reparameterize(mu, log_var):
    """Reparameterization trick."""
    std = torch.exp(0.5 * log_var)  # σ = exp(log_var / 2)
    eps = torch.randn_like(std)     # ε ~ N(0, I)
    z = mu + std * eps              # z = μ + σ * ε
    return z
```
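A minimal check (not in the original) that gradients really do flow through the sample back to \(\mu\) and \(\log \sigma^2\):

```python
mu = torch.zeros(2, 3, requires_grad=True)
log_var = torch.zeros(2, 3, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()
print(mu.grad)       # all ones: dz/dμ = 1
print(log_var.grad)  # 0.5 * σ * ε, so it varies with the sampled noise
```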
4. A Complete VAE Implementation
```python
class Encoder(nn.Module):  # subclassing nn.Module defines a network module
    """VAE encoder: outputs μ and log σ²."""
    def __init__(self, input_dim, hidden_dim, latent_dim):  # __init__ is the constructor, run on instantiation
        super().__init__()  # super() calls the parent class initializer
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = self.fc(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        return mu, log_var


class Decoder(nn.Module):
    """VAE decoder."""
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),  # outputs in [0, 1]
        )

    def forward(self, z):
        return self.fc(z)


class VAE(nn.Module):
    """Variational autoencoder."""
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var

    def loss_function(self, x, x_recon, mu, log_var):
        # Reconstruction loss (BCE)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        # KL divergence
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + kl_loss, recon_loss, kl_loss

    def sample(self, num_samples, device):
        """Sample from the prior and decode."""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        samples = self.decoder(z)
        return samples

    def interpolate(self, x1, x2, steps=10):
        """Linear interpolation between two samples in latent space."""
        mu1, _ = self.encoder(x1)
        mu2, _ = self.encoder(x2)
        alphas = torch.linspace(0, 1, steps, device=x1.device)
        interpolations = []
        for alpha in alphas:
            z = (1 - alpha) * mu1 + alpha * mu2
            recon = self.decoder(z)
            interpolations.append(recon)
        return torch.stack(interpolations)
```
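Before training, it is worth confirming the shapes on a random batch (an illustrative smoke test, not part of the original):

```python
vae = VAE(input_dim=784, hidden_dim=512, latent_dim=20)
x = torch.rand(8, 784)  # BCE requires targets in [0, 1]
x_recon, mu, log_var = vae(x)
loss, recon, kl = vae.loss_function(x, x_recon, mu, log_var)
print(x_recon.shape, mu.shape)  # torch.Size([8, 784]) torch.Size([8, 20])
```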
4.1 Training Loop
```python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def train_vae():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)  # batched loading with shuffling

    # Model
    model = VAE(input_dim=784, hidden_dim=512, latent_dim=20).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(50):
        model.train()  # enable training mode
        total_loss, total_recon, total_kl = 0, 0, 0
        for x, _ in dataloader:
            x = x.view(-1, 784).to(device)  # flatten images to (batch, 784); view requires contiguous memory
            x_recon, mu, log_var = model(x)
            loss, recon, kl = model.loss_function(x, x_recon, mu, log_var)

            optimizer.zero_grad()  # clear accumulated gradients
            loss.backward()        # backpropagate
            optimizer.step()       # update the parameters

            total_loss += loss.item()  # .item() converts a one-element tensor to a Python number
            total_recon += recon.item()
            total_kl += kl.item()

        n = len(dataset)
        print(f"Epoch {epoch+1}: Loss={total_loss/n:.2f}, "
              f"Recon={total_recon/n:.2f}, KL={total_kl/n:.2f}")
    return model

# model = train_vae()
```
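After training, the images drawn by model.sample can be written out as a grid, e.g. with torchvision.utils.save_image (a sketch; the helper name and file name are arbitrary):

```python
from torchvision.utils import save_image

def save_samples(model, device, path='vae_samples.png'):
    model.eval()
    with torch.no_grad():
        samples = model.sample(64, device)  # 64 draws from the prior, decoded
    save_image(samples.view(64, 1, 28, 28), path, nrow=8)
```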
4.2 Convolutional VAE
```python
class ConvEncoder(nn.Module):
    """Convolutional encoder."""
    def __init__(self, in_channels, latent_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 2, 1),  # (32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),           # (64, 7, 7)
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),          # (128, 4, 4)
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(128 * 4 * 4, latent_dim)
        self.fc_log_var = nn.Linear(128 * 4 * 4, latent_dim)

    def forward(self, x):
        h = self.conv(x).view(x.size(0), -1)
        return self.fc_mu(h), self.fc_log_var(h)


class ConvDecoder(nn.Module):
    """Convolutional decoder."""
    def __init__(self, latent_dim, out_channels):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 4 * 4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, 2, 1),              # (64, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, 2, 1, 1),            # (32, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 3, 2, 1, 1),  # (C, 28, 28)
            nn.Sigmoid(),
        )

    def forward(self, z):
        h = self.fc(z).view(-1, 128, 4, 4)
        return self.deconv(h)


class ConvVAE(nn.Module):
    def __init__(self, in_channels=1, latent_dim=32):
        super().__init__()
        self.encoder = ConvEncoder(in_channels, latent_dim)
        self.decoder = ConvDecoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var
```
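A shape check on a random batch (illustrative): the convolutional VAE should map a (4, 1, 28, 28) input back to the same shape.

```python
conv_vae = ConvVAE(in_channels=1, latent_dim=32)
x = torch.rand(4, 1, 28, 28)
x_recon, mu, log_var = conv_vae(x)
print(x_recon.shape, mu.shape)  # torch.Size([4, 1, 28, 28]) torch.Size([4, 32])
```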
5. Conditional VAE (CVAE)
A CVAE feeds extra conditioning information \(y\) (here a one-hot class label) into both the encoder and the decoder, so the trained model can generate samples of a requested class.
```python
class ConditionalVAE(nn.Module):
    """Conditional VAE: generates samples conditioned on e.g. a class label."""
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=20, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder takes x and y (one-hot)
        self.enc_fc = nn.Sequential(
            nn.Linear(input_dim + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)
        # Decoder takes z and y
        self.dec_fc = nn.Sequential(
            nn.Linear(latent_dim + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x, y):
        h = self.enc_fc(torch.cat([x, y], dim=1))
        return self.fc_mu(h), self.fc_log_var(h)

    def decode(self, z, y):
        return self.dec_fc(torch.cat([z, y], dim=1))

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

    def forward(self, x, y):
        mu, log_var = self.encode(x, y)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z, y)
        return x_recon, mu, log_var

    def generate(self, y, device):
        """Generate conditioned on y."""
        z = torch.randn(y.size(0), self.latent_dim, device=device)
        return self.decode(z, y)
```
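Class-conditional generation then just fixes the label and samples \(z\) from the prior. A usage sketch (illustrative; encoding the labels with F.one_hot is an assumption, any one-hot float labels work):

```python
cvae = ConditionalVAE()
labels = torch.full((8,), 3)  # ask for eight images of the digit "3"
y = F.one_hot(labels, num_classes=10).float()
digits = cvae.generate(y, device=torch.device('cpu'))
print(digits.shape)  # torch.Size([8, 784])
```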
6. β-VAE and Disentangled Representations
6.1 β-VAE
Higgins et al. (2017) introduce a hyperparameter \(\beta\) that reweights the KL term of the ELBO:
\[\mathcal{L} = \mathbb{E}[\log p_\theta(x|z)] - \beta \cdot KL(q_\phi(z|x) \| p(z))\]
- \(\beta > 1\): stronger regularization, encouraging disentangled latent factors
- \(\beta < 1\): more weight on reconstruction quality
```python
class BetaVAE(VAE):
    """β-VAE"""
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=10, beta=4.0):
        super().__init__(input_dim, hidden_dim, latent_dim)
        self.beta = beta

    def loss_function(self, x, x_recon, mu, log_var):
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + self.beta * kl_loss, recon_loss, kl_loss
```
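Usage is identical to the plain VAE; only the loss weighting changes (an illustrative sketch):

```python
beta_vae = BetaVAE(latent_dim=10, beta=4.0)
x = torch.rand(8, 784)
x_recon, mu, log_var = beta_vae(x)
loss, recon, kl = beta_vae.loss_function(x, x_recon, mu, log_var)
print(loss.item(), recon.item(), kl.item())  # loss = recon + 4.0 * kl
```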
7. VQ-VAE
7.1 The Idea
The VQ-VAE of van den Oord et al. (2017) uses discrete latent variables: vector quantization maps each encoder output to its nearest codebook vector.
\[z_q = e_k, \quad k = \arg\min_{j} \|z_e - e_j\|_2\]
```python
class VectorQuantizer(nn.Module):
    """Vector quantization layer."""
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.uniform_(self.embedding.weight, -1 / num_embeddings, 1 / num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, z_e):
        """
        z_e: (batch, embedding_dim, H, W), the encoder output
        """
        # Rearrange so the embedding dimension comes last
        z_e = z_e.permute(0, 2, 3, 1)             # (B, H, W, D)
        z_e_flat = z_e.reshape(-1, z_e.size(-1))  # (B*H*W, D)

        # Squared distances to every codebook vector: ||a||² + ||b||² - 2ab
        distances = (z_e_flat.pow(2).sum(1, keepdim=True)
                     + self.embedding.weight.pow(2).sum(1)
                     - 2 * z_e_flat @ self.embedding.weight.t())

        # Nearest codebook vector for each position
        encoding_indices = distances.argmin(1)
        z_q = self.embedding(encoding_indices).view(z_e.shape)

        # Losses
        codebook_loss = F.mse_loss(z_q, z_e.detach())    # moves the codebook toward the encoder
        commitment_loss = F.mse_loss(z_e, z_q.detach())  # keeps the encoder close to the codebook
        vq_loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: copy gradients from z_q to z_e
        z_q = z_e + (z_q - z_e).detach()
        z_q = z_q.permute(0, 3, 1, 2)  # (B, D, H, W)
        return z_q, vq_loss, encoding_indices


class VQVAE(nn.Module):
    """VQ-VAE"""
    def __init__(self, in_channels=1, hidden_dim=128, num_embeddings=512, embedding_dim=64):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim // 2, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 2, hidden_dim, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, embedding_dim, 3, 1, 1),
        )
        # Vector quantization
        self.vq = VectorQuantizer(num_embeddings, embedding_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, hidden_dim, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 2, in_channels, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, vq_loss, indices = self.vq(z_e)
        x_recon = self.decoder(z_q)
        recon_loss = F.mse_loss(x_recon, x)
        total_loss = recon_loss + vq_loss
        return x_recon, total_loss, recon_loss, vq_loss

# Quick test
vqvae = VQVAE(in_channels=1, hidden_dim=128, num_embeddings=512, embedding_dim=64)
x = torch.randn(4, 1, 28, 28)
x_recon, total_loss, recon_loss, vq_loss = vqvae(x)
print(f"VQ-VAE: recon={x_recon.shape}, loss={total_loss.item():.4f}")
```
8. VAE vs. GAN
| Property | VAE | GAN |
|---|---|---|
| Objective | Maximize the ELBO (approximate likelihood) | Adversarial game |
| Training stability | ✅ stable | ❌ can be unstable |
| Sample quality | somewhat blurry | sharper |
| Latent-space structure | ✅ continuous, interpolable | ❌ weaker structure |
| Density estimation | ✅ approximate (via the ELBO) | ❌ not available |
| Sample diversity | ✅ high | ❌ prone to mode collapse |
| Theoretical foundation | Variational inference | Game theory |
9. Exercises and Self-Check
Exercises
- ELBO derivation: derive the ELBO from scratch and prove it is a lower bound on the log-likelihood.
- Basic VAE: train a VAE on MNIST and visualize both reconstructions and images sampled from the prior.
- Latent-space visualization: use t-SNE to visualize the VAE latent space and check whether different digits form clusters.
- Interpolation: linearly interpolate between the latent codes of two samples and observe how the generated images transition.
- β-VAE: experiment with different values of \(\beta\) and observe the effect on reconstruction quality and disentanglement.
- VQ-VAE: implement a VQ-VAE and compare its reconstruction quality against a vanilla VAE.
Self-Check List
- Can derive the ELBO and explain the meaning of each term
- Understand what problem the reparameterization trick solves
- Can implement and train a VAE from scratch
- Understand the derivation of the analytic KL divergence
- Can contrast the strengths and weaknesses of VAEs and GANs
- Know the core innovations of β-VAE and VQ-VAE
Next: 06-Advanced Topics / 01-Model Compression and Acceleration




