03 - 变分推断基础¶
学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 理解扩散模型训练目标的核心数学工具
🎯 学习目标¶
完成本章后,你将能够: - 理解变分推断的核心思想和数学原理 - 掌握证据下界(ELBO)的完整推导 - 理解均值场近似和重参数化技巧 - 将变分推断应用于扩散模型
1. 变分推断概述¶
1.1 问题背景¶
在概率模型中,我们经常遇到后验分布计算困难的问题:
其中 \(p(x) = \int p(x | z) p(z) dz\) 是边缘似然(evidence),通常难以计算。
变分推断的核心思想:用一个简单的分布 \(q(z)\) 来近似复杂的后验 \(p(z | x)\)。
1.2 为什么需要变分推断¶
场景1:贝叶斯神经网络 - 需要计算权重的后验分布 - 积分空间维度极高(数百万参数)
场景2:生成模型 - VAE、扩散模型都需要近似后验 - 直接计算不可行
场景3:概率图模型 - 复杂图结构导致推断困难 - 需要近似算法
1.3 变分推断 vs MCMC¶
| 方法 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| 变分推断 | 快、可扩展、确定性 | 有偏、近似误差 | 大规模数据、实时应用 |
| MCMC | 渐近精确 | 慢、难以收敛 | 小规模数据、精确推断 |
扩散模型选择变分推断,因为: 1. 需要高效训练 2. 可以接受近似 3. 与深度学习框架兼容
2. KL散度与变分目标¶
2.1 KL散度的定义¶
KL散度衡量两个分布之间的差异:
性质: - \(D_{KL} \geq 0\),当且仅当 \(q = p\) 时等于0 - 不对称:\(D_{KL}(q \Vert p) \neq D_{KL}(p \| q)\) - 变分推断通常最小化 \(D_{KL}(q \Vert p)\)(前向KL)
2.2 推导证据下界(ELBO)¶
我们的目标是最小化 \(D_{KL}(q(z) \| p(z | x))\),但 \(p(z | x)\) 未知。展开KL散度:
重新整理:
由于 \(D_{KL} \geq 0\),我们得到:
这就是证据下界(Evidence Lower BOund, ELBO)。
2.3 ELBO的等价形式¶
ELBO有多种等价表达形式:
形式1: $\(\mathcal{L}(q) = \mathbb{E}_{q(z)} [\log p(x | z)] - D_{KL}(q(z) \| p(z))\)$
- 第一项:重构项,衡量解码质量
- 第二项:正则项,使 \(q(z)\) 接近先验 \(p(z)\)
形式2: $\(\mathcal{L}(q) = \log p(x) - D_{KL}(q(z) \| p(z | x))\)$
这表明: - 最大化ELBO等价于最小化 \(D_{KL}(q(z) \| p(z | x))\) - ELBO越紧(接近 \(\log p(x)\)),近似越好
2.4 代码验证¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
def kl_divergence_gaussian(mu1, logvar1, mu2, logvar2):
"""
计算两个高斯分布之间的KL散度
参数:
mu1, logvar1: 分布q的均值和对数方差
mu2, logvar2: 分布p的均值和对数方差
返回:
KL(q||p)
"""
var1 = torch.exp(logvar1)
var2 = torch.exp(logvar2)
kl = 0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2)**2) / var2 - 1)
return kl.sum(dim=-1)
def compute_elbo(recon_x, x, mu, logvar):
"""
计算VAE的ELBO
参数:
recon_x: 重构的图像
x: 原始图像
mu, logvar: 编码器输出的均值和对数方差
返回:
ELBO损失
"""
# 重构损失(二元交叉熵或MSE)
recon_loss = F.mse_loss(recon_x, x, reduction='sum') / x.shape[0] # F.xxx PyTorch函数式API
# KL散度
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.shape[0]
# ELBO = - (重构损失 + KL散度)
elbo = -(recon_loss + kl_loss)
return elbo, recon_loss, kl_loss
# 测试
print("=" * 60)
print("ELBO计算测试")
print("=" * 60)
batch_size = 4
latent_dim = 10
image_size = 32
# 模拟数据
x = torch.randn(batch_size, 3, image_size, image_size)
recon_x = torch.randn(batch_size, 3, image_size, image_size)
mu = torch.randn(batch_size, latent_dim)
logvar = torch.randn(batch_size, latent_dim)
elbo, recon_loss, kl_loss = compute_elbo(recon_x, x, mu, logvar)
print(f"重构损失: {recon_loss.item():.4f}") # 将单元素张量转为Python数值
print(f"KL散度: {kl_loss.item():.4f}")
print(f"ELBO: {elbo.item():.4f}")
3. 均值场近似¶
3.1 什么是均值场近似¶
均值场近似(Mean Field Approximation)假设变分分布可以分解为独立因子的乘积:
这意味着我们假设隐变量的不同维度之间相互独立。
3.2 坐标上升变分推断(CAVI)¶
目标:找到最优的 \(q_i(z_i)\) 来最大化ELBO。
推导:
对于每个因子 \(q_j(z_j)\),固定其他因子,最优解为:
其中 \(\mathbb{E}_{-j}\) 表示对除 \(z_j\) 外的所有变量求期望。
算法流程:
算法: CAVI
─────────────────────────────────
初始化所有 q_i(z_i)
重复直到收敛:
对于每个 j = 1, ..., M:
计算: log q_j(z_j) = E_{-j}[log p(x, z)] + const
归一化得到 q_j(z_j)
计算 ELBO
3.3 高斯均值场¶
假设每个 \(q_i(z_i)\) 是高斯分布:
优化参数:均值 \(\mu_i\) 和方差 \(\sigma_i^2\)(或对数方差)。
梯度:
4. 重参数化技巧(Reparameterization Trick)¶
4.1 问题:随机节点的梯度¶
在变分推断中,我们需要从 \(q_\phi(z)\) 采样,然后计算梯度:
问题:采样操作是不可导的!
4.2 解决方案¶
对于高斯分布 \(q_\phi(z) = \mathcal{N}(z; \mu_\phi(x), \sigma_\phi^2(x))\),我们可以:
这样:
梯度可以通过 \(\mu_\phi\) 和 \(\sigma_\phi\) 传播!
4.3 一般形式¶
对于任意分布 \(q_\phi(z)\),如果存在变换:
则:
常见分布的重参数化:
| 分布 | 重参数化 |
|---|---|
| 高斯 | \(z = \mu + \sigma \cdot \epsilon, \epsilon \sim \mathcal{N}(0, I)\) |
| 指数 | \(z = -\log(1 - \epsilon) / \lambda, \epsilon \sim \text{Uniform}(0, 1)\) |
| Gumbel | \(z = \mu - \beta \log(-\log \epsilon), \epsilon \sim \text{Uniform}(0, 1)\) |
4.4 代码实现¶
class ReparameterizedGaussian:
"""
可重参数化的高斯分布
"""
def __init__(self, mu, logvar):
self.mu = mu
self.logvar = logvar
self.std = torch.exp(0.5 * logvar)
def sample(self, num_samples=None):
"""
重参数化采样
"""
if num_samples is None:
eps = torch.randn_like(self.std)
else:
eps = torch.randn(num_samples, *self.std.shape)
return self.mu + self.std * eps
def rsample(self, num_samples=None):
"""
可微采样(PyTorch风格)
"""
return self.sample(num_samples)
def log_prob(self, z):
"""
计算对数概率
"""
var = torch.exp(self.logvar)
log_prob = -0.5 * (
torch.log(2 * torch.pi * var) +
(z - self.mu).pow(2) / var
)
return log_prob.sum(dim=-1)
def kl_divergence(self, other_mu=0, other_logvar=0):
"""
计算与标准高斯的KL散度
"""
return kl_divergence_gaussian(
self.mu, self.logvar,
torch.zeros_like(self.mu),
torch.zeros_like(self.logvar)
)
# 测试重参数化
print("\n" + "=" * 60)
print("重参数化技巧测试")
print("=" * 60)
mu = torch.tensor([0.0, 1.0, -1.0])
logvar = torch.tensor([0.0, 0.5, -0.5])
dist = ReparameterizedGaussian(mu, logvar)
# 多次采样
samples = []
for _ in range(1000):
z = dist.sample()
samples.append(z)
samples = torch.stack(samples) # torch.stack沿新维度拼接张量
print(f"理论均值: {mu}")
print(f"采样均值: {samples.mean(dim=0)}")
print(f"理论标准差: {torch.exp(0.5 * logvar)}")
print(f"采样标准差: {samples.std(dim=0)}")
# 验证梯度
mu_param = torch.tensor([0.0, 1.0], requires_grad=True)
logvar_param = torch.tensor([0.0, 0.0], requires_grad=True)
dist_grad = ReparameterizedGaussian(mu_param, logvar_param)
z = dist_grad.sample()
loss = z.pow(2).sum()
loss.backward() # 反向传播计算梯度
print(f"\n梯度验证:")
print(f"mu梯度: {mu_param.grad}")
print(f"logvar梯度: {logvar_param.grad}")
5. 变分自编码器(VAE)¶
5.1 VAE作为变分推断¶
VAE是变分推断在深度学习中的典型应用:
生成模型: $\(p_\theta(x, z) = p_\theta(x | z) p(z)\)$
推断模型(编码器): $\(q_\phi(z | x) = \mathcal{N}(z; \mu_\phi(x), \sigma_\phi^2(x))\)$
ELBO: $\(\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z | x)} [\log p_\theta(x | z)] - D_{KL}(q_\phi(z | x) \| p(z))\)$
5.2 与扩散模型的联系¶
| 特性 | VAE | 扩散模型 |
|---|---|---|
| 隐变量 | 低维 \(z\) | 高维 \(x_{1:T}\) |
| 编码器 | 学习 \(q_\phi(z \mid x)\) | 固定前向过程 |
| 解码器 | 学习 \(p_\theta(x \mid z)\) | 学习反向过程 |
| 推断 | 单步 | 多步马尔可夫链 |
扩散模型可以看作: - 编码器是固定的(前向扩散) - 解码器是多步的(反向去噪) - 隐变量与数据同维度
6. 变分推断在扩散模型中的应用¶
6.1 扩散模型的变分目标¶
回顾扩散模型的联合分布:
前向(固定): $\(q(x_{0:T}) = q(x_0) \prod_{t=1}^T q(x_t | x_{t-1})\)$
反向(学习): $\(p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} | x_t)\)$
变分下界: $\(\log p_\theta(x_0) \geq \mathbb{E}_{q(x_{1:T} | x_0)} \left[ \log \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)} \right]\)$
6.2 简化为去噪目标¶
经过推导(详见第4章),ELBO可以简化为:
这就是去噪目标!
直观理解: - 变分推断告诉我们应该优化什么(ELBO) - 数学推导简化为具体的目标函数 - 最终形式简单但有效
6.3 为什么扩散模型有效¶
- 灵活的近似族:高斯分布可以近似复杂的后验
- 多步细化:每一步只需要做小的修正
- 重参数化:可以使用梯度下降高效训练
- 变分下界:提供了理论保证
7. 高级主题¶
7.1 重要性加权自编码器(IWAE)¶
使用多个样本获得更紧的下界:
当 \(k \to \infty\) 时,\(\mathcal{L}_k \to \log p(x)\)。
7.2 归一化流(Normalizing Flows)¶
使用可逆变换构建复杂的变分分布:
7.3 变分推断的局限性¶
- 近似误差:\(q(z)\) 可能无法很好地近似 \(p(z | x)\)
- 局部最优:优化可能陷入局部最优
- 模型选择:选择合适的 \(q(z)\) 族很重要
- 计算成本:对于复杂模型,计算成本仍然很高
8. 本章总结¶
核心概念¶
- 变分推断
- 用简单分布近似复杂后验
- 最小化KL散度
-
最大化证据下界(ELBO)
-
ELBO
- \(\mathcal{L} = \mathbb{E}_q[\log p(x, z) - \log q(z)]\)
- 重构项 + 正则项
-
提供了似然的下界
-
重参数化技巧
- 使随机采样可微
- 高斯:\(z = \mu + \sigma \cdot \epsilon\)
-
支持端到端训练
-
与扩散模型的联系
- 扩散模型是多步VAE
- 前向过程是固定编码器
- 反向过程是学习解码器
关键公式¶
| 概念 | 公式 |
|---|---|
| KL散度 | \(D_{KL}(q \Vert p) = \mathbb{E}_q[\log q - \log p]\) |
| ELBO | \(\mathcal{L} = \mathbb{E}_q[\log p(x, z) - \log q(z)]\) |
| 高斯ELBO | \(\mathcal{L} = \mathbb{E}[\log p(x \| z)] - D_{KL}(q(z) \| p(z))\) |
| 重参数化 | \(z = \mu + \sigma \cdot \epsilon\) |
算法流程¶
变分推断算法:
1. 选择变分分布族 q_φ(z)
2. 初始化参数 φ
3. 重复:
a. 从 q_φ(z) 采样 z
b. 计算 ELBO
c. 计算梯度 ∇_φ ELBO
d. 更新参数 φ
4. 直到收敛
📝 自测问题¶
基础问题¶
- 变分推断基础
- 为什么需要变分推断?
- KL散度为什么不对称?
-
前向KL和后向KL有什么区别?
-
ELBO推导
- 从KL散度推导出ELBO
- ELBO的两项分别代表什么?
-
为什么ELBO是似然的下界?
-
重参数化技巧
- 为什么需要重参数化?
- 高斯分布如何重参数化?
- 重参数化对梯度计算有什么帮助?
编程练习¶
- 实现完整的VAE模型
- 比较不同样本数下的ELBO
- 实现重要性加权自编码器(IWAE)
- 可视化变分分布的优化过程
思考题¶
- 变分推断的近似误差来自哪里?
- 如何设计更好的变分分布?
- 扩散模型相比VAE的优势在哪里?
- 变分推断和MCMC如何选择?
🔗 下一步¶
理解了变分推断后,我们将学习线性代数基础,这是理解扩散模型的数学工具。
→ 下一步:04-线性代数基础.md