跳转至

03 - 潜空间扩散模型LDM

学习时间: 5小时 重要性: ⭐⭐⭐⭐⭐ Stable Diffusion的核心技术


🎯 学习目标

完成本章后,你将能够: - 理解潜空间扩散的原理和优势 - 掌握VAE编码器/解码器的设计 - 实现LDM的完整流程 - 理解感知损失和 patch-based 判别器


1. 为什么需要潜空间扩散

1.1 像素空间扩散的问题

计算成本高: - 高分辨率图像(如 1024×1024)有 3M+ 像素 - 每个像素都需要进行扩散过程 - 训练和推理都极其昂贵

内存占用大: - UNet需要在高分辨率上操作 - 注意力机制的复杂度是 O(N²) - 限制了 batch size 和模型大小

效率低: - 大部分计算花在了不重要的细节上 - 图像的语义信息被淹没在像素级噪声中

1.2 潜空间的优势

降维: - 将图像压缩到低维潜空间(如 64×64×4) - 计算量减少 96% 以上 - 可以在更高分辨率上训练

解耦语义和细节: - 潜空间捕获语义信息 - VAE解码器负责细节重建 - 扩散模型专注于语义生成

更好的表示: - VAE学习有意义的表示 - 潜空间具有更好的线性可分性 - 便于条件控制和编辑

1.3 潜空间扩散vs像素空间扩散

特性 像素空间DDPM 潜空间LDM
空间维度 高(256×256×3) 低(32×32×4)
计算成本 极高
训练速度 快(4-10倍)
内存占用
生成质量 高(相当)
可控性 一般 更好

2. 变分自编码器(VAE)

2.1 VAE架构

编码器(Encoder)

Text Only
图像 x (H×W×3)
卷积下采样
潜变量 z (h×w×c)  其中 h=H/8, w=W/8

解码器(Decoder)

Text Only
潜变量 z (h×w×c)
卷积上采样
重建图像 x̂ (H×W×3)

2.2 VAE训练目标

重构损失: $\(\mathcal{L}_{\text{rec}} = \|x - \hat{x}\|^2\)$

KL散度: $\(\mathcal{L}_{\text{KL}} = D_{KL}(q(z|x) \Vert p(z))\)$

总损失: $\(\mathcal{L} = \mathcal{L}_{\text{rec}} + \beta \cdot \mathcal{L}_{\text{KL}}\)$

其中 \(\beta\) 控制重构质量和潜空间正则化的平衡。

2.3 感知损失(Perceptual Loss)

问题:像素级MSE损失会导致模糊

解决方案:使用预训练VGG网络的特征差异

Python
class PerceptualLoss(nn.Module):  # 继承nn.Module定义网络层
    def __init__(self):
        super().__init__()  # super()调用父类方法
        vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
        self.feature_extractor = nn.Sequential(*list(vgg.features)[:16])  # 切片操作,取前n个元素
        self.feature_extractor.eval()

        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, x, x_hat):
        # 提取特征
        features_x = self.feature_extractor(x)
        features_x_hat = self.feature_extractor(x_hat)

        # 计算L1距离
        loss = F.l1_loss(features_x, features_x_hat)  # F.xxx PyTorch函数式API
        return loss

2.4 Patch-based 判别器

目的:提高重建图像的局部真实感

架构

Python
class PatchDiscriminator(nn.Module):
    def __init__(self, input_channels=3):
        super().__init__()

        self.model = nn.Sequential(
            # 输入: 3×256×256
            nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 64×128×128

            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            # 128×64×64

            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2),
            # 256×32×32

            nn.Conv2d(256, 512, 4, stride=1, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2),
            # 512×31×31

            nn.Conv2d(512, 1, 4, stride=1, padding=1),
            # 1×30×30 (patch输出)
        )

    def forward(self, x):
        return self.model(x)

优势: - 输出是特征图而非单一值 - 每个patch都有独立的真假判断 - 更好的局部细节监督


3. 完整VAE实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class VAEEncoder(nn.Module):
    """
    VAE编码器
    输入: 3×256×256
    输出: 4×32×32 (均值和对数方差)
    """
    def __init__(self, in_channels=3, latent_dim=4):
        super().__init__()

        # 下采样路径
        self.encoder = nn.Sequential(
            # 3×256×256 → 128×128×128
            nn.Conv2d(in_channels, 128, 3, stride=2, padding=1),
            nn.GroupNorm(32, 128),
            nn.SiLU(),

            # 128×128×128 → 128×64×64
            nn.Conv2d(128, 128, 3, stride=2, padding=1),
            nn.GroupNorm(32, 128),
            nn.SiLU(),

            # 128×64×64 → 256×32×32
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.GroupNorm(32, 256),
            nn.SiLU(),

            # 256×32×32 → 256×32×32
            nn.Conv2d(256, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.SiLU(),
        )

        # 输出均值和对数方差
        self.to_mu = nn.Conv2d(256, latent_dim, 3, padding=1)
        self.to_logvar = nn.Conv2d(256, latent_dim, 3, padding=1)

    def forward(self, x):
        h = self.encoder(x)
        mu = self.to_mu(h)
        logvar = self.to_logvar(h)
        return mu, logvar

class VAEDecoder(nn.Module):
    """
    VAE解码器
    输入: 4×32×32
    输出: 3×256×256
    """
    def __init__(self, latent_dim=4, out_channels=3):
        super().__init__()

        # 初始卷积
        self.init_conv = nn.Sequential(
            nn.Conv2d(latent_dim, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.SiLU(),
        )

        # 上采样路径
        self.decoder = nn.Sequential(
            # 256×32×32 → 256×32×32
            nn.Conv2d(256, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.SiLU(),

            # 256×32×32 → 128×64×64
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.GroupNorm(32, 128),
            nn.SiLU(),

            # 128×64×64 → 128×128×128
            nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1),
            nn.GroupNorm(32, 128),
            nn.SiLU(),

            # 128×128×128 → 3×256×256
            nn.ConvTranspose2d(128, out_channels, 4, stride=2, padding=1),
        )

    def forward(self, z):
        h = self.init_conv(z)
        x_hat = self.decoder(h)
        return x_hat

class VAE(nn.Module):
    """
    完整VAE
    """
    def __init__(self, in_channels=3, latent_dim=4):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim)
        self.decoder = VAEDecoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def encode(self, x):
        """编码为潜变量"""
        mu, logvar = self.encoder(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """重参数化采样"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        """从潜变量解码"""
        return self.decoder(z)

    def forward(self, x):
        """前向传播"""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar

    def get_latent(self, x):
        """获取潜变量(用于LDM)"""
        mu, _ = self.encode(x)
        return mu  # 使用均值作为确定性编码

# VAE训练
class VAETrainer:
    def __init__(self, vae, perceptual_loss=True, use_discriminator=True):
        self.vae = vae
        self.perceptual_loss = perceptual_loss
        self.use_discriminator = use_discriminator

        if perceptual_loss:
            self.perceptual = PerceptualLoss()

        if use_discriminator:
            self.discriminator = PatchDiscriminator()
            self.adversarial_loss = nn.MSELoss()

    def compute_loss(self, x, x_hat, mu, logvar):
        """
        计算VAE总损失
        """
        # 重构损失
        rec_loss = F.l1_loss(x_hat, x)

        # 感知损失
        if self.perceptual_loss:
            p_loss = self.perceptual(x, x_hat)
            rec_loss = rec_loss + 0.1 * p_loss

        # KL散度
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_loss = kl_loss / x.size(0)  # 平均

        # 总损失
        total_loss = rec_loss + 0.000001 * kl_loss

        return total_loss, rec_loss, kl_loss

4. 潜空间扩散模型

4.1 LDM架构

Text Only
图像 x
VAE编码器 → 潜变量 z
扩散模型(在潜空间)
去噪潜变量 z_0
VAE解码器 → 重建图像 x̂

4.2 LDM实现

Python
class LatentDiffusion(nn.Module):
    """
    潜空间扩散模型
    """
    def __init__(self, vae, unet, diffusion):
        super().__init__()

        self.vae = vae
        self.unet = unet
        self.diffusion = diffusion

        # 冻结VAE
        for param in self.vae.parameters():
            param.requires_grad = False
        self.vae.eval()

    def encode_images(self, x):
        """
        将图像编码到潜空间

        参数:
            x: 图像 [B, 3, H, W]

        返回:
            z: 潜变量 [B, 4, H/8, W/8]
        """
        with torch.no_grad():  # 禁用梯度计算,节省内存
            z = self.vae.get_latent(x)
        return z

    def decode_latents(self, z):
        """
        将潜变量解码为图像

        参数:
            z: 潜变量 [B, 4, H/8, W/8]

        返回:
            x: 图像 [B, 3, H, W]
        """
        with torch.no_grad():
            x = self.vae.decode(z)
        return x

    def forward(self, x, c=None):
        """
        前向传播(训练)

        参数:
            x: 图像
            c: 条件

        返回:
            loss: 扩散损失
        """
        # 编码到潜空间
        z_0 = self.encode_images(x)

        # 在潜空间进行扩散训练
        loss = self.diffusion.training_losses(self.unet, z_0, c)

        return loss

    @torch.no_grad()
    def sample(self, shape, c=None, num_steps=50):
        """
        采样生成

        参数:
            shape: 图像形状 (B, 3, H, W)
            c: 条件
            num_steps: 采样步数

        返回:
            x: 生成图像
        """
        # 计算潜空间形状
        b, _, h, w = shape
        latent_shape = (b, self.vae.latent_dim, h // 8, w // 8)

        # 在潜空间采样
        z = self.diffusion.sample(self.unet, latent_shape, c, num_steps)

        # 解码为图像
        x = self.decode_latents(z)

        return x

# 使用示例
if __name__ == "__main__":
    # 创建VAE
    vae = VAE(in_channels=3, latent_dim=4)

    # 创建UNet(在潜空间操作)
    unet = UNet(
        in_channels=4,      # 潜空间通道
        out_channels=4,
        base_channels=64,
        channel_mults=(1, 2, 4),
    )

    # 创建扩散过程
    diffusion = GaussianDiffusion(timesteps=1000)

    # 创建LDM
    ldm = LatentDiffusion(vae, unet, diffusion)

    # 测试
    x = torch.randn(2, 3, 256, 256)

    # 编码
    z = ldm.encode_images(x)
    print(f"图像形状: {x.shape}")
    print(f"潜变量形状: {z.shape}")
    print(f"压缩率: {x.numel() / z.numel():.1f}x")

    # 解码
    x_hat = ldm.decode_latents(z)
    print(f"重建形状: {x_hat.shape}")

4.3 LDM的优势总结

计算效率: - 256×256图像 → 32×32潜空间 - 计算量减少 64倍 - 可以使用更大的batch size

内存效率: - UNet在更低分辨率上运行 - 可以训练更大的模型 - 支持更高分辨率生成

质量保持: - VAE保留了足够的语义信息 - 扩散模型专注于语义生成 - 解码器负责细节重建


5. 本章总结

核心概念

  1. 潜空间扩散
  2. 在压缩的潜空间进行扩散
  3. 大幅降低计算成本
  4. 保持生成质量

  5. VAE

  6. 编码器:图像→潜变量
  7. 解码器:潜变量→图像
  8. 感知损失提高质量

  9. LDM架构

  10. VAE编码 → 潜空间扩散 → VAE解码
  11. 训练和推理都更高效
  12. 支持高分辨率生成

关键公式

概念 公式
VAE损失 \(\mathcal{L} = \|x - \hat{x}\|^2 + \beta \cdot D_{KL}(q(z\|x) \Vert p(z))\)
压缩率 \((H \times W \times 3) / (h \times w \times c)\)

实现要点

Python
# LDM核心流程
z = vae.encode(x)           # 编码到潜空间
z_noisy = diffusion.add_noise(z, t)  # 加噪
z_denoised = unet(z_noisy, t)        # 去噪
x_hat = vae.decode(z_denoised)       # 解码为图像

📝 自测问题

基础问题

  1. 潜空间扩散
  2. 为什么需要潜空间扩散?
  3. 潜空间相比像素空间的优势?
  4. 压缩率如何计算?

  5. VAE

  6. VAE的两个主要损失?
  7. 感知损失的作用?
  8. Patch判别器的优势?

  9. LDM

  10. LDM的完整流程?
  11. 为什么VAE要冻结?
  12. LDM支持多高分辨率?

编程练习

  1. 实现完整的VAE
  2. 训练VAE并评估重建质量
  3. 实现LDM并在潜空间训练扩散模型
  4. 比较像素空间和潜空间的训练速度

思考题

  1. VAE的质量如何影响LDM的生成效果?
  2. 潜空间的维度如何选择?
  3. LDM的局限性是什么?

🔗 下一步

理解了潜空间扩散后,我们将学习稳定扩散Stable Diffusion,这是目前最流行的开源文本到图像生成模型。

→ 下一步:04-扩散模型加速技术.md