03 - 潜空间扩散模型LDM¶
学习时间: 5小时 重要性: ⭐⭐⭐⭐⭐ Stable Diffusion的核心技术
🎯 学习目标¶
完成本章后,你将能够: - 理解潜空间扩散的原理和优势 - 掌握VAE编码器/解码器的设计 - 实现LDM的完整流程 - 理解感知损失和 patch-based 判别器
1. 为什么需要潜空间扩散¶
1.1 像素空间扩散的问题¶
计算成本高: - 高分辨率图像(如 1024×1024)有 3M+ 像素 - 每个像素都需要进行扩散过程 - 训练和推理都极其昂贵
内存占用大: - UNet需要在高分辨率上操作 - 注意力机制的复杂度是 O(N²) - 限制了 batch size 和模型大小
效率低: - 大部分计算花在了不重要的细节上 - 图像的语义信息被淹没在像素级噪声中
1.2 潜空间的优势¶
降维: - 将图像压缩到低维潜空间(如 64×64×4) - 计算量减少 96% 以上 - 可以在更高分辨率上训练
解耦语义和细节: - 潜空间捕获语义信息 - VAE解码器负责细节重建 - 扩散模型专注于语义生成
更好的表示: - VAE学习有意义的表示 - 潜空间具有更好的线性可分性 - 便于条件控制和编辑
1.3 潜空间扩散vs像素空间扩散¶
| 特性 | 像素空间DDPM | 潜空间LDM |
|---|---|---|
| 空间维度 | 高(256×256×3) | 低(32×32×4) |
| 计算成本 | 极高 | 低 |
| 训练速度 | 慢 | 快(4-10倍) |
| 内存占用 | 大 | 小 |
| 生成质量 | 高 | 高(相当) |
| 可控性 | 一般 | 更好 |
2. 变分自编码器(VAE)¶
2.1 VAE架构¶
编码器(Encoder):
解码器(Decoder):
2.2 VAE训练目标¶
重构损失: $\(\mathcal{L}_{\text{rec}} = \|x - \hat{x}\|^2\)$
KL散度: $\(\mathcal{L}_{\text{KL}} = D_{KL}(q(z|x) \Vert p(z))\)$
总损失: $\(\mathcal{L} = \mathcal{L}_{\text{rec}} + \beta \cdot \mathcal{L}_{\text{KL}}\)$
其中 \(\beta\) 控制重构质量和潜空间正则化的平衡。
2.3 感知损失(Perceptual Loss)¶
问题:像素级MSE损失会导致模糊
解决方案:使用预训练VGG网络的特征差异
class PerceptualLoss(nn.Module): # 继承nn.Module定义网络层
def __init__(self):
super().__init__() # super()调用父类方法
vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
self.feature_extractor = nn.Sequential(*list(vgg.features)[:16]) # 切片操作,取前n个元素
self.feature_extractor.eval()
for param in self.feature_extractor.parameters():
param.requires_grad = False
def forward(self, x, x_hat):
# 提取特征
features_x = self.feature_extractor(x)
features_x_hat = self.feature_extractor(x_hat)
# 计算L1距离
loss = F.l1_loss(features_x, features_x_hat) # F.xxx PyTorch函数式API
return loss
2.4 Patch-based 判别器¶
目的:提高重建图像的局部真实感
架构:
class PatchDiscriminator(nn.Module):
def __init__(self, input_channels=3):
super().__init__()
self.model = nn.Sequential(
# 输入: 3×256×256
nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
nn.LeakyReLU(0.2),
# 64×128×128
nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.InstanceNorm2d(128),
nn.LeakyReLU(0.2),
# 128×64×64
nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2),
# 256×32×32
nn.Conv2d(256, 512, 4, stride=1, padding=1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2),
# 512×31×31
nn.Conv2d(512, 1, 4, stride=1, padding=1),
# 1×30×30 (patch输出)
)
def forward(self, x):
return self.model(x)
优势: - 输出是特征图而非单一值 - 每个patch都有独立的真假判断 - 更好的局部细节监督
3. 完整VAE实现¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class VAEEncoder(nn.Module):
"""
VAE编码器
输入: 3×256×256
输出: 4×32×32 (均值和对数方差)
"""
def __init__(self, in_channels=3, latent_dim=4):
super().__init__()
# 下采样路径
self.encoder = nn.Sequential(
# 3×256×256 → 128×128×128
nn.Conv2d(in_channels, 128, 3, stride=2, padding=1),
nn.GroupNorm(32, 128),
nn.SiLU(),
# 128×128×128 → 128×64×64
nn.Conv2d(128, 128, 3, stride=2, padding=1),
nn.GroupNorm(32, 128),
nn.SiLU(),
# 128×64×64 → 256×32×32
nn.Conv2d(128, 256, 3, stride=2, padding=1),
nn.GroupNorm(32, 256),
nn.SiLU(),
# 256×32×32 → 256×32×32
nn.Conv2d(256, 256, 3, padding=1),
nn.GroupNorm(32, 256),
nn.SiLU(),
)
# 输出均值和对数方差
self.to_mu = nn.Conv2d(256, latent_dim, 3, padding=1)
self.to_logvar = nn.Conv2d(256, latent_dim, 3, padding=1)
def forward(self, x):
h = self.encoder(x)
mu = self.to_mu(h)
logvar = self.to_logvar(h)
return mu, logvar
class VAEDecoder(nn.Module):
"""
VAE解码器
输入: 4×32×32
输出: 3×256×256
"""
def __init__(self, latent_dim=4, out_channels=3):
super().__init__()
# 初始卷积
self.init_conv = nn.Sequential(
nn.Conv2d(latent_dim, 256, 3, padding=1),
nn.GroupNorm(32, 256),
nn.SiLU(),
)
# 上采样路径
self.decoder = nn.Sequential(
# 256×32×32 → 256×32×32
nn.Conv2d(256, 256, 3, padding=1),
nn.GroupNorm(32, 256),
nn.SiLU(),
# 256×32×32 → 128×64×64
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
nn.GroupNorm(32, 128),
nn.SiLU(),
# 128×64×64 → 128×128×128
nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1),
nn.GroupNorm(32, 128),
nn.SiLU(),
# 128×128×128 → 3×256×256
nn.ConvTranspose2d(128, out_channels, 4, stride=2, padding=1),
)
def forward(self, z):
h = self.init_conv(z)
x_hat = self.decoder(h)
return x_hat
class VAE(nn.Module):
"""
完整VAE
"""
def __init__(self, in_channels=3, latent_dim=4):
super().__init__()
self.encoder = VAEEncoder(in_channels, latent_dim)
self.decoder = VAEDecoder(latent_dim, in_channels)
self.latent_dim = latent_dim
def encode(self, x):
"""编码为潜变量"""
mu, logvar = self.encoder(x)
return mu, logvar
def reparameterize(self, mu, logvar):
"""重参数化采样"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std
return z
def decode(self, z):
"""从潜变量解码"""
return self.decoder(z)
def forward(self, x):
"""前向传播"""
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
x_hat = self.decode(z)
return x_hat, mu, logvar
def get_latent(self, x):
"""获取潜变量(用于LDM)"""
mu, _ = self.encode(x)
return mu # 使用均值作为确定性编码
# VAE训练
class VAETrainer:
def __init__(self, vae, perceptual_loss=True, use_discriminator=True):
self.vae = vae
self.perceptual_loss = perceptual_loss
self.use_discriminator = use_discriminator
if perceptual_loss:
self.perceptual = PerceptualLoss()
if use_discriminator:
self.discriminator = PatchDiscriminator()
self.adversarial_loss = nn.MSELoss()
def compute_loss(self, x, x_hat, mu, logvar):
"""
计算VAE总损失
"""
# 重构损失
rec_loss = F.l1_loss(x_hat, x)
# 感知损失
if self.perceptual_loss:
p_loss = self.perceptual(x, x_hat)
rec_loss = rec_loss + 0.1 * p_loss
# KL散度
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
kl_loss = kl_loss / x.size(0) # 平均
# 总损失
total_loss = rec_loss + 0.000001 * kl_loss
return total_loss, rec_loss, kl_loss
4. 潜空间扩散模型¶
4.1 LDM架构¶
4.2 LDM实现¶
class LatentDiffusion(nn.Module):
"""
潜空间扩散模型
"""
def __init__(self, vae, unet, diffusion):
super().__init__()
self.vae = vae
self.unet = unet
self.diffusion = diffusion
# 冻结VAE
for param in self.vae.parameters():
param.requires_grad = False
self.vae.eval()
def encode_images(self, x):
"""
将图像编码到潜空间
参数:
x: 图像 [B, 3, H, W]
返回:
z: 潜变量 [B, 4, H/8, W/8]
"""
with torch.no_grad(): # 禁用梯度计算,节省内存
z = self.vae.get_latent(x)
return z
def decode_latents(self, z):
"""
将潜变量解码为图像
参数:
z: 潜变量 [B, 4, H/8, W/8]
返回:
x: 图像 [B, 3, H, W]
"""
with torch.no_grad():
x = self.vae.decode(z)
return x
def forward(self, x, c=None):
"""
前向传播(训练)
参数:
x: 图像
c: 条件
返回:
loss: 扩散损失
"""
# 编码到潜空间
z_0 = self.encode_images(x)
# 在潜空间进行扩散训练
loss = self.diffusion.training_losses(self.unet, z_0, c)
return loss
@torch.no_grad()
def sample(self, shape, c=None, num_steps=50):
"""
采样生成
参数:
shape: 图像形状 (B, 3, H, W)
c: 条件
num_steps: 采样步数
返回:
x: 生成图像
"""
# 计算潜空间形状
b, _, h, w = shape
latent_shape = (b, self.vae.latent_dim, h // 8, w // 8)
# 在潜空间采样
z = self.diffusion.sample(self.unet, latent_shape, c, num_steps)
# 解码为图像
x = self.decode_latents(z)
return x
# 使用示例
if __name__ == "__main__":
# 创建VAE
vae = VAE(in_channels=3, latent_dim=4)
# 创建UNet(在潜空间操作)
unet = UNet(
in_channels=4, # 潜空间通道
out_channels=4,
base_channels=64,
channel_mults=(1, 2, 4),
)
# 创建扩散过程
diffusion = GaussianDiffusion(timesteps=1000)
# 创建LDM
ldm = LatentDiffusion(vae, unet, diffusion)
# 测试
x = torch.randn(2, 3, 256, 256)
# 编码
z = ldm.encode_images(x)
print(f"图像形状: {x.shape}")
print(f"潜变量形状: {z.shape}")
print(f"压缩率: {x.numel() / z.numel():.1f}x")
# 解码
x_hat = ldm.decode_latents(z)
print(f"重建形状: {x_hat.shape}")
4.3 LDM的优势总结¶
计算效率: - 256×256图像 → 32×32潜空间 - 计算量减少 64倍 - 可以使用更大的batch size
内存效率: - UNet在更低分辨率上运行 - 可以训练更大的模型 - 支持更高分辨率生成
质量保持: - VAE保留了足够的语义信息 - 扩散模型专注于语义生成 - 解码器负责细节重建
5. 本章总结¶
核心概念¶
- 潜空间扩散
- 在压缩的潜空间进行扩散
- 大幅降低计算成本
-
保持生成质量
-
VAE
- 编码器:图像→潜变量
- 解码器:潜变量→图像
-
感知损失提高质量
-
LDM架构
- VAE编码 → 潜空间扩散 → VAE解码
- 训练和推理都更高效
- 支持高分辨率生成
关键公式¶
| 概念 | 公式 |
|---|---|
| VAE损失 | \(\mathcal{L} = \|x - \hat{x}\|^2 + \beta \cdot D_{KL}(q(z\|x) \Vert p(z))\) |
| 压缩率 | \((H \times W \times 3) / (h \times w \times c)\) |
实现要点¶
# LDM核心流程
z = vae.encode(x) # 编码到潜空间
z_noisy = diffusion.add_noise(z, t) # 加噪
z_denoised = unet(z_noisy, t) # 去噪
x_hat = vae.decode(z_denoised) # 解码为图像
📝 自测问题¶
基础问题¶
- 潜空间扩散
- 为什么需要潜空间扩散?
- 潜空间相比像素空间的优势?
-
压缩率如何计算?
-
VAE
- VAE的两个主要损失?
- 感知损失的作用?
-
Patch判别器的优势?
-
LDM
- LDM的完整流程?
- 为什么VAE要冻结?
- LDM支持多高分辨率?
编程练习¶
- 实现完整的VAE
- 训练VAE并评估重建质量
- 实现LDM并在潜空间训练扩散模型
- 比较像素空间和潜空间的训练速度
思考题¶
- VAE的质量如何影响LDM的生成效果?
- 潜空间的维度如何选择?
- LDM的局限性是什么?
🔗 下一步¶
理解了潜空间扩散后,我们将学习稳定扩散Stable Diffusion,这是目前最流行的开源文本到图像生成模型。
→ 下一步:04-扩散模型加速技术.md