跳转至

03 - 变分推断基础

学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 理解扩散模型训练目标的核心数学工具


🎯 学习目标

完成本章后,你将能够: - 理解变分推断的核心思想和数学原理 - 掌握证据下界(ELBO)的完整推导 - 理解均值场近似和重参数化技巧 - 将变分推断应用于扩散模型


1. 变分推断概述

1.1 问题背景

在概率模型中,我们经常遇到后验分布计算困难的问题:

\[p(z | x) = \frac{p(x | z) p(z)}{p(x)}\]

其中 \(p(x) = \int p(x | z) p(z) dz\)边缘似然(evidence),通常难以计算。

变分推断的核心思想:用一个简单的分布 \(q(z)\) 来近似复杂的后验 \(p(z | x)\)

1.2 为什么需要变分推断

场景1:贝叶斯神经网络 - 需要计算权重的后验分布 - 积分空间维度极高(数百万参数)

场景2:生成模型 - VAE、扩散模型都需要近似后验 - 直接计算不可行

场景3:概率图模型 - 复杂图结构导致推断困难 - 需要近似算法

1.3 变分推断 vs MCMC

方法 优点 缺点 适用场景
变分推断 快、可扩展、确定性 有偏、近似误差 大规模数据、实时应用
MCMC 渐近精确 慢、难以收敛 小规模数据、精确推断

扩散模型选择变分推断,因为: 1. 需要高效训练 2. 可以接受近似 3. 与深度学习框架兼容


2. KL散度与变分目标

2.1 KL散度的定义

KL散度衡量两个分布之间的差异:

\[D_{KL}(q(z) \| p(z | x)) = \int q(z) \log \frac{q(z)}{p(z | x)} dz\]

性质: - \(D_{KL} \geq 0\),当且仅当 \(q = p\) 时等于0 - 不对称\(D_{KL}(q \Vert p) \neq D_{KL}(p \| q)\) - 变分推断通常最小化 \(D_{KL}(q \Vert p)\)(前向KL)

2.2 推导证据下界(ELBO)

我们的目标是最小化 \(D_{KL}(q(z) \| p(z | x))\),但 \(p(z | x)\) 未知。展开KL散度:

\[ \begin{aligned} D_{KL}(q(z) \| p(z | x)) &= \mathbb{E}_{q(z)} \left[ \log q(z) - \log p(z | x) \right] \\ &= \mathbb{E}_{q(z)} \left[ \log q(z) - \log \frac{p(x, z)}{p(x)} \right] \\ &= \mathbb{E}_{q(z)} \left[ \log q(z) - \log p(x, z) + \log p(x) \right] \\ &= \mathbb{E}_{q(z)} \left[ \log q(z) - \log p(x, z) \right] + \log p(x) \end{aligned} \]

重新整理:

\[\log p(x) = D_{KL}(q(z) \| p(z | x)) + \mathbb{E}_{q(z)} \left[ \log p(x, z) - \log q(z) \right]\]

由于 \(D_{KL} \geq 0\),我们得到:

\[\log p(x) \geq \mathbb{E}_{q(z)} \left[ \log p(x, z) - \log q(z) \right] =: \mathcal{L}(q)\]

这就是证据下界(Evidence Lower BOund, ELBO)

2.3 ELBO的等价形式

ELBO有多种等价表达形式:

形式1: $\(\mathcal{L}(q) = \mathbb{E}_{q(z)} [\log p(x | z)] - D_{KL}(q(z) \| p(z))\)$

  • 第一项:重构项,衡量解码质量
  • 第二项:正则项,使 \(q(z)\) 接近先验 \(p(z)\)

形式2: $\(\mathcal{L}(q) = \log p(x) - D_{KL}(q(z) \| p(z | x))\)$

这表明: - 最大化ELBO等价于最小化 \(D_{KL}(q(z) \| p(z | x))\) - ELBO越紧(接近 \(\log p(x)\)),近似越好

2.4 代码验证

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def kl_divergence_gaussian(mu1, logvar1, mu2, logvar2):
    """
    计算两个高斯分布之间的KL散度

    参数:
        mu1, logvar1: 分布q的均值和对数方差
        mu2, logvar2: 分布p的均值和对数方差

    返回:
        KL(q||p)
    """
    var1 = torch.exp(logvar1)
    var2 = torch.exp(logvar2)

    kl = 0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2)**2) / var2 - 1)
    return kl.sum(dim=-1)

def compute_elbo(recon_x, x, mu, logvar):
    """
    计算VAE的ELBO

    参数:
        recon_x: 重构的图像
        x: 原始图像
        mu, logvar: 编码器输出的均值和对数方差

    返回:
        ELBO损失
    """
    # 重构损失(二元交叉熵或MSE)
    recon_loss = F.mse_loss(recon_x, x, reduction='sum') / x.shape[0]  # F.xxx PyTorch函数式API

    # KL散度
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.shape[0]

    # ELBO = - (重构损失 + KL散度)
    elbo = -(recon_loss + kl_loss)

    return elbo, recon_loss, kl_loss

# 测试
print("=" * 60)
print("ELBO计算测试")
print("=" * 60)

batch_size = 4
latent_dim = 10
image_size = 32

# 模拟数据
x = torch.randn(batch_size, 3, image_size, image_size)
recon_x = torch.randn(batch_size, 3, image_size, image_size)
mu = torch.randn(batch_size, latent_dim)
logvar = torch.randn(batch_size, latent_dim)

elbo, recon_loss, kl_loss = compute_elbo(recon_x, x, mu, logvar)

print(f"重构损失: {recon_loss.item():.4f}")  # 将单元素张量转为Python数值
print(f"KL散度: {kl_loss.item():.4f}")
print(f"ELBO: {elbo.item():.4f}")

3. 均值场近似

3.1 什么是均值场近似

均值场近似(Mean Field Approximation)假设变分分布可以分解为独立因子的乘积:

\[q(z) = \prod_{i=1}^M q_i(z_i)\]

这意味着我们假设隐变量的不同维度之间相互独立。

3.2 坐标上升变分推断(CAVI)

目标:找到最优的 \(q_i(z_i)\) 来最大化ELBO。

推导

对于每个因子 \(q_j(z_j)\),固定其他因子,最优解为:

\[\log q_j^*(z_j) = \mathbb{E}_{-j} [\log p(x, z)] + \text{const}\]

其中 \(\mathbb{E}_{-j}\) 表示对除 \(z_j\) 外的所有变量求期望。

算法流程

Text Only
算法: CAVI
─────────────────────────────────
初始化所有 q_i(z_i)

重复直到收敛:
  对于每个 j = 1, ..., M:
    计算: log q_j(z_j) = E_{-j}[log p(x, z)] + const
    归一化得到 q_j(z_j)

  计算 ELBO

3.3 高斯均值场

假设每个 \(q_i(z_i)\) 是高斯分布:

\[q_i(z_i) = \mathcal{N}(z_i; \mu_i, \sigma_i^2)\]

优化参数:均值 \(\mu_i\) 和方差 \(\sigma_i^2\)(或对数方差)。

梯度

\[\frac{\partial \mathcal{L}}{\partial \mu_i} = \mathbb{E}_{q} \left[ \frac{\partial \log p(x, z)}{\partial z_i} \right]\]
\[\frac{\partial \mathcal{L}}{\partial \sigma_i} = \mathbb{E}_{q} \left[ \frac{\partial \log p(x, z)}{\partial z_i} \cdot \frac{z_i - \mu_i}{\sigma_i} \right] + \frac{1}{\sigma_i}\]

4. 重参数化技巧(Reparameterization Trick)

4.1 问题:随机节点的梯度

在变分推断中,我们需要从 \(q_\phi(z)\) 采样,然后计算梯度:

\[\nabla_\phi \mathbb{E}_{q_\phi(z)} [f(z)]\]

问题:采样操作是不可导的!

4.2 解决方案

对于高斯分布 \(q_\phi(z) = \mathcal{N}(z; \mu_\phi(x), \sigma_\phi^2(x))\),我们可以:

\[z = \mu_\phi(x) + \sigma_\phi(x) \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]

这样:

\[\mathbb{E}_{q_\phi(z)} [f(z)] = \mathbb{E}_{\mathcal{N}(\epsilon; 0, I)} [f(\mu_\phi(x) + \sigma_\phi(x) \cdot \epsilon)]\]

梯度可以通过 \(\mu_\phi\)\(\sigma_\phi\) 传播!

4.3 一般形式

对于任意分布 \(q_\phi(z)\),如果存在变换:

\[z = g_\phi(\epsilon, x), \quad \epsilon \sim p(\epsilon)\]

则:

\[\mathbb{E}_{q_\phi(z)} [f(z)] = \mathbb{E}_{p(\epsilon)} [f(g_\phi(\epsilon, x))]\]

常见分布的重参数化

分布 重参数化
高斯 \(z = \mu + \sigma \cdot \epsilon, \epsilon \sim \mathcal{N}(0, I)\)
指数 \(z = -\log(1 - \epsilon) / \lambda, \epsilon \sim \text{Uniform}(0, 1)\)
Gumbel \(z = \mu - \beta \log(-\log \epsilon), \epsilon \sim \text{Uniform}(0, 1)\)

4.4 代码实现

Python
class ReparameterizedGaussian:
    """
    可重参数化的高斯分布
    """
    def __init__(self, mu, logvar):
        self.mu = mu
        self.logvar = logvar
        self.std = torch.exp(0.5 * logvar)

    def sample(self, num_samples=None):
        """
        重参数化采样
        """
        if num_samples is None:
            eps = torch.randn_like(self.std)
        else:
            eps = torch.randn(num_samples, *self.std.shape)
        return self.mu + self.std * eps

    def rsample(self, num_samples=None):
        """
        可微采样(PyTorch风格)
        """
        return self.sample(num_samples)

    def log_prob(self, z):
        """
        计算对数概率
        """
        var = torch.exp(self.logvar)
        log_prob = -0.5 * (
            torch.log(2 * torch.pi * var) +
            (z - self.mu).pow(2) / var
        )
        return log_prob.sum(dim=-1)

    def kl_divergence(self, other_mu=0, other_logvar=0):
        """
        计算与标准高斯的KL散度
        """
        return kl_divergence_gaussian(
            self.mu, self.logvar,
            torch.zeros_like(self.mu),
            torch.zeros_like(self.logvar)
        )

# 测试重参数化
print("\n" + "=" * 60)
print("重参数化技巧测试")
print("=" * 60)

mu = torch.tensor([0.0, 1.0, -1.0])
logvar = torch.tensor([0.0, 0.5, -0.5])

dist = ReparameterizedGaussian(mu, logvar)

# 多次采样
samples = []
for _ in range(1000):
    z = dist.sample()
    samples.append(z)

samples = torch.stack(samples)  # torch.stack沿新维度拼接张量

print(f"理论均值: {mu}")
print(f"采样均值: {samples.mean(dim=0)}")
print(f"理论标准差: {torch.exp(0.5 * logvar)}")
print(f"采样标准差: {samples.std(dim=0)}")

# 验证梯度
mu_param = torch.tensor([0.0, 1.0], requires_grad=True)
logvar_param = torch.tensor([0.0, 0.0], requires_grad=True)

dist_grad = ReparameterizedGaussian(mu_param, logvar_param)
z = dist_grad.sample()
loss = z.pow(2).sum()
loss.backward()  # 反向传播计算梯度

print(f"\n梯度验证:")
print(f"mu梯度: {mu_param.grad}")
print(f"logvar梯度: {logvar_param.grad}")

5. 变分自编码器(VAE)

5.1 VAE作为变分推断

VAE是变分推断在深度学习中的典型应用:

生成模型: $\(p_\theta(x, z) = p_\theta(x | z) p(z)\)$

推断模型(编码器): $\(q_\phi(z | x) = \mathcal{N}(z; \mu_\phi(x), \sigma_\phi^2(x))\)$

ELBO: $\(\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z | x)} [\log p_\theta(x | z)] - D_{KL}(q_\phi(z | x) \| p(z))\)$

5.2 与扩散模型的联系

特性 VAE 扩散模型
隐变量 低维 \(z\) 高维 \(x_{1:T}\)
编码器 学习 \(q_\phi(z \mid x)\) 固定前向过程
解码器 学习 \(p_\theta(x \mid z)\) 学习反向过程
推断 单步 多步马尔可夫链

扩散模型可以看作: - 编码器是固定的(前向扩散) - 解码器是多步的(反向去噪) - 隐变量与数据同维度


6. 变分推断在扩散模型中的应用

6.1 扩散模型的变分目标

回顾扩散模型的联合分布:

前向(固定): $\(q(x_{0:T}) = q(x_0) \prod_{t=1}^T q(x_t | x_{t-1})\)$

反向(学习): $\(p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} | x_t)\)$

变分下界: $\(\log p_\theta(x_0) \geq \mathbb{E}_{q(x_{1:T} | x_0)} \left[ \log \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)} \right]\)$

6.2 简化为去噪目标

经过推导(详见第4章),ELBO可以简化为:

\[\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]\]

这就是去噪目标

直观理解: - 变分推断告诉我们应该优化什么(ELBO) - 数学推导简化为具体的目标函数 - 最终形式简单但有效

6.3 为什么扩散模型有效

  1. 灵活的近似族:高斯分布可以近似复杂的后验
  2. 多步细化:每一步只需要做小的修正
  3. 重参数化:可以使用梯度下降高效训练
  4. 变分下界:提供了理论保证

7. 高级主题

7.1 重要性加权自编码器(IWAE)

使用多个样本获得更紧的下界:

\[\mathcal{L}_k = \mathbb{E}_{z_1, ..., z_k \sim q(z | x)} \left[ \log \frac{1}{k} \sum_{i=1}^k \frac{p(x, z_i)}{q(z_i | x)} \right]\]

\(k \to \infty\) 时,\(\mathcal{L}_k \to \log p(x)\)

7.2 归一化流(Normalizing Flows)

使用可逆变换构建复杂的变分分布:

\[z = f_\phi(\epsilon), \quad \epsilon \sim \mathcal{N}(0, I)\]
\[q_\phi(z) = p(\epsilon) \left| \det \frac{\partial f_\phi^{-1}}{\partial z} \right|\]

7.3 变分推断的局限性

  1. 近似误差\(q(z)\) 可能无法很好地近似 \(p(z | x)\)
  2. 局部最优:优化可能陷入局部最优
  3. 模型选择:选择合适的 \(q(z)\) 族很重要
  4. 计算成本:对于复杂模型,计算成本仍然很高

8. 本章总结

核心概念

  1. 变分推断
  2. 用简单分布近似复杂后验
  3. 最小化KL散度
  4. 最大化证据下界(ELBO)

  5. ELBO

  6. \(\mathcal{L} = \mathbb{E}_q[\log p(x, z) - \log q(z)]\)
  7. 重构项 + 正则项
  8. 提供了似然的下界

  9. 重参数化技巧

  10. 使随机采样可微
  11. 高斯:\(z = \mu + \sigma \cdot \epsilon\)
  12. 支持端到端训练

  13. 与扩散模型的联系

  14. 扩散模型是多步VAE
  15. 前向过程是固定编码器
  16. 反向过程是学习解码器

关键公式

概念 公式
KL散度 \(D_{KL}(q \Vert p) = \mathbb{E}_q[\log q - \log p]\)
ELBO \(\mathcal{L} = \mathbb{E}_q[\log p(x, z) - \log q(z)]\)
高斯ELBO \(\mathcal{L} = \mathbb{E}[\log p(x \| z)] - D_{KL}(q(z) \| p(z))\)
重参数化 \(z = \mu + \sigma \cdot \epsilon\)

算法流程

Text Only
变分推断算法:
1. 选择变分分布族 q_φ(z)
2. 初始化参数 φ
3. 重复:
   a. 从 q_φ(z) 采样 z
   b. 计算 ELBO
   c. 计算梯度 ∇_φ ELBO
   d. 更新参数 φ
4. 直到收敛

📝 自测问题

基础问题

  1. 变分推断基础
  2. 为什么需要变分推断?
  3. KL散度为什么不对称?
  4. 前向KL和后向KL有什么区别?

  5. ELBO推导

  6. 从KL散度推导出ELBO
  7. ELBO的两项分别代表什么?
  8. 为什么ELBO是似然的下界?

  9. 重参数化技巧

  10. 为什么需要重参数化?
  11. 高斯分布如何重参数化?
  12. 重参数化对梯度计算有什么帮助?

编程练习

  1. 实现完整的VAE模型
  2. 比较不同样本数下的ELBO
  3. 实现重要性加权自编码器(IWAE)
  4. 可视化变分分布的优化过程

思考题

  1. 变分推断的近似误差来自哪里?
  2. 如何设计更好的变分分布?
  3. 扩散模型相比VAE的优势在哪里?
  4. 变分推断和MCMC如何选择?

🔗 下一步

理解了变分推断后,我们将学习线性代数基础,这是理解扩散模型的数学工具。

→ 下一步:04-线性代数基础.md