05 - Sampling Algorithms in Depth

Study time: 3 hours | Importance: ⭐⭐⭐⭐⭐ | The key to generating high-quality images from noise

🎯 Learning Objectives

After completing this chapter, you will be able to:

- Understand the basic principles of diffusion-model sampling
- Implement the DDPM sampling algorithm
- Understand accelerated samplers such as DDIM
- Understand how the number of sampling steps relates to generation quality
- Implement and optimize the sampling procedure
1. Sampling Overview

1.1 What Is Sampling?

Sampling is the process of generating new images from a trained diffusion model.

An analogy:

- Training: learning how to recover an image from noise (learning to "denoise")
- Sampling: applying the learned denoising ability to turn pure noise into an image

1.2 The Basic Sampling Flow
Pure noise x_T
  ↓ apply the model ε_θ(x_T, T)
predicted noise
  ↓ denoise
x_{T-1}
  ↓ apply the model ε_θ(x_{T-1}, T-1)
predicted noise
  ↓ denoise
...
  ↓
x_0 (the generated image)
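To make the loop above concrete, here is a minimal skeleton of the reverse process (a sketch only: denoise_step is a placeholder callable standing in for the update rules derived in the next sections; the full, runnable versions appear in Sections 2.3 and 4.3):

import torch

def reverse_loop_sketch(model, denoise_step, T, shape, device='cpu'):
    """Minimal sketch of the reverse (sampling) loop."""
    x_t = torch.randn(shape, device=device)  # start from pure noise x_T
    with torch.no_grad():
        for t in reversed(range(T)):  # T-1, T-2, ..., 0
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            predicted_noise = model(x_t, t_batch)  # ε_θ(x_t, t)
            x_t = denoise_step(x_t, predicted_noise, t)  # one reverse step: x_t -> x_{t-1}
    return x_t  # x_0, the generated image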
2. The DDPM Sampling Algorithm

2.1 Mathematical Principles

DDPM (Denoising Diffusion Probabilistic Models) samples by running a Markov chain in reverse.

Core formula:

Given \(x_t\), how do we obtain \(x_{t-1}\)? The reverse step samples from the posterior

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t \mathbf{I}\big)$$

where:

$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t,
\qquad
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t$$
2.2 The Practical Sampling Formula

Since \(x_0\) is unknown at sampling time, we estimate it with the noise predicted by the model, which yields the update

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \right) + \sigma_t\, \mathbf{z}$$

where:

- \(\epsilon_\theta(x_t, t)\): the noise predicted by the model
- \(\mathbf{z} \sim \mathcal{N}(0, \mathbf{I})\): random noise (added only when \(t > 1\))
- \(\beta_t = 1 - \alpha_t\): the variance-schedule parameter
- \(\sigma_t\): the noise standard deviation, either \(\sqrt{\beta_t}\) or \(\sqrt{\tilde{\beta}_t}\) (the code below uses the posterior variance \(\tilde{\beta}_t\))
2.3 Code Implementation
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
def get_schedule(T, beta_start=0.0001, beta_end=0.02):
    """
    Create the noise schedule.

    Args:
        T: total number of steps
        beta_start: initial beta value
        beta_end: final beta value

    Returns:
        alphas, betas, alphas_cumprod
    """
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, betas, alphas_cumprod
def ddpm_sample(model, x_T, T, alphas, betas, alphas_cumprod, device='cpu'):
    """
    DDPM sampling algorithm.

    Args:
        model: the trained denoising model
        x_T: initial noise [batch_size, channels, height, width]
        T: total number of steps
        alphas, betas, alphas_cumprod: the noise schedule
        device: device to run on

    Returns:
        The generated image x_0
    """
    model.eval()  # switch to evaluation mode
    x_t = x_T.to(device)  # move to GPU/CPU

    # Precompute ᾱ_{t-1} (torch.cat concatenates tensors along an existing dimension)
    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

    with torch.no_grad():  # disable gradient tracking to save memory
        for t in reversed(range(T)):
            # Schedule parameters for the current timestep
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]
            alpha_t_bar_prev = alphas_cumprod_prev[t]

            # Predict the noise
            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
            predicted_noise = model(x_t, t_tensor)

            # Compute the mean
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            mean = sqrt_recip_alpha_t * (
                x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            # Add random noise (none on the final step)
            if t > 0:
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

    return x_t
# Example: a simple UNet-like model
class SimpleUNet(nn.Module):  # subclass nn.Module to define the network
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()  # call the parent constructor
        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, out_channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # Simplified version; a real UNet is much more complex
        # (note: the timestep t is ignored here for brevity)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.conv4(x)
        return x
# Test the sampler
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Create the model and the schedule
T = 1000
model = SimpleUNet().to(device)
alphas, betas, alphas_cumprod = get_schedule(T)

# Generate the initial noise
batch_size = 4
x_T = torch.randn(batch_size, 3, 32, 32)

# Sample
print(f"Starting sampling, total steps: {T}")
x_0 = ddpm_sample(model, x_T, T, alphas, betas, alphas_cumprod, device)
print(f"Sampling finished, output shape: {x_0.shape}")
3. Sampling Steps vs. Quality

3.1 The Effect of the Number of Steps

| Steps | Quality | Speed | Typical use |
|---|---|---|---|
| Few (10-50) | Lower | Very fast | Quick previews, interactive applications |
| Medium (100-500) | Good | Medium | Balancing quality and speed |
| Many (1000+) | Best | Slow | High-quality generation, art creation |
3.2 Visualizing the Effect of Different Step Counts
def sample_with_steps(model, x_T, steps, alphas, betas, alphas_cumprod, device='cpu'):
    """
    Sample with a list of different step counts.

    Args:
        steps: list of step counts, e.g. [10, 50, 100, 500, 1000]

    Returns:
        Generated results for each step count
    """
    results = {}
    T = len(alphas)
    model.eval()

    for num_steps in steps:
        # Choose a subsampled set of timesteps
        # (note: reusing the single-step DDPM update on a subsampled grid is only a rough approximation)
        step_indices = torch.linspace(0, T - 1, num_steps, dtype=torch.long)

        x_t = x_T.clone().to(device)  # move to the same device as the model
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

        with torch.no_grad():
            for i, t in enumerate(reversed(step_indices)):  # enumerate gives both index and element
                t = t.item()  # convert a one-element tensor to a Python number

                alpha_t = alphas[t]
                beta_t = betas[t]
                alpha_t_bar = alphas_cumprod[t]
                alpha_t_bar_prev = alphas_cumprod_prev[t]

                # Predict the noise
                t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
                predicted_noise = model(x_t, t_tensor)

                # Compute the mean
                sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
                mean = sqrt_recip_alpha_t * (
                    x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
                )

                # Add noise
                if i < len(step_indices) - 1:  # not the last step
                    posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                    noise = torch.randn_like(x_t)
                    x_t = mean + torch.sqrt(posterior_variance) * noise
                else:
                    x_t = mean

        results[num_steps] = x_t.detach().cpu()  # detach from the graph; no gradients needed

    return results
# Try different step counts
steps_list = [10, 50, 100, 500, 1000]
results = sample_with_steps(model, x_T, steps_list, alphas, betas, alphas_cumprod, device)

# Visualize
fig, axes = plt.subplots(1, len(steps_list), figsize=(15, 3))
for idx, steps in enumerate(steps_list):
    img = results[steps][0]  # take the first image
    img = img.permute(1, 2, 0).numpy()
    img = (img - img.min()) / (img.max() - img.min())  # normalize to [0, 1]
    axes[idx].imshow(img)
    axes[idx].set_title(f'{steps} steps')
    axes[idx].axis('off')
plt.tight_layout()
plt.show()
4. The DDIM Sampling Algorithm

4.1 DDIM vs. DDPM

DDIM (Denoising Diffusion Implicit Models) is an accelerated variant of DDPM.

| Property | DDPM | DDIM |
|---|---|---|
| Sampling type | Stochastic | Deterministic |
| Steps required | Many | Far fewer |
| Sampling speed | Slow | Fast |
| Generation quality | Slightly better | Comparable |
| Deterministic | No (different every run) | Yes (same input, same output) |

4.2 DDIM Mathematics

The core idea of DDIM: sampling does not have to follow the Markov chain, so timesteps can be skipped.
The DDIM update (deterministic case, \(\eta = 0\)):

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\;\underbrace{\frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}}_{\text{predicted } x_0} \;+\; \sqrt{1 - \bar{\alpha}_{t-1}}\,\epsilon_\theta(x_t, t)$$
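For reference, the DDIM paper also introduces a parameter \(\eta\) that interpolates between deterministic DDIM (\(\eta = 0\)) and DDPM-like stochastic sampling; the code in this chapter uses the deterministic \(\eta = 0\) case. The general update has the form

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2}\;\epsilon_\theta(x_t, t) + \sigma_t \mathbf{z},
\qquad
\sigma_t = \eta\,\sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}}\,\sqrt{1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}$$

where \(\hat{x}_0 = \big(x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)\big)/\sqrt{\bar{\alpha}_t}\) is the predicted \(x_0\) and \(\mathbf{z} \sim \mathcal{N}(0, \mathbf{I})\).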
4.3 DDIM Code Implementation

def ddim_sample(model, x_T, T, alphas, alphas_cumprod, num_steps=50, device='cpu'):
    """
    DDIM sampling algorithm.

    Args:
        model: the trained denoising model
        x_T: initial noise
        T: original total number of steps
        alphas, alphas_cumprod: the noise schedule
        num_steps: actual number of sampling steps
        device: device to run on

    Returns:
        The generated image x_0
    """
    model.eval()
    x_t = x_T.to(device)

    # Choose the sampling timesteps
    step_indices = torch.linspace(0, T - 1, num_steps, dtype=torch.long)

    with torch.no_grad():
        for i, t in enumerate(reversed(step_indices)):
            t = t.item()

            alpha_t_bar = alphas_cumprod[t]

            # Predict the noise
            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
            predicted_noise = model(x_t, t_tensor)

            # Predict x_0
            sqrt_alpha_t_bar = torch.sqrt(alpha_t_bar)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            x_0_pred = (x_t - sqrt_one_minus_alpha_t_bar * predicted_noise) / sqrt_alpha_t_bar

            # Compute x_{t-1}
            if i < len(step_indices) - 1:  # not the last step
                t_prev = step_indices[len(step_indices) - 2 - i].item()
                alpha_t_prev_bar = alphas_cumprod[t_prev]
                sqrt_alpha_t_prev_bar = torch.sqrt(alpha_t_prev_bar)
                sqrt_one_minus_alpha_t_prev_bar = torch.sqrt(1 - alpha_t_prev_bar)

                # Deterministic DDIM update (η = 0)
                x_t = sqrt_alpha_t_prev_bar * x_0_pred + sqrt_one_minus_alpha_t_prev_bar * predicted_noise
            else:
                x_t = x_0_pred

    return x_t
# Test DDIM
print("Testing DDIM sampling...")
for num_steps in [10, 20, 50, 100]:
    x_0_ddim = ddim_sample(model, x_T, T, alphas, alphas_cumprod, num_steps, device)
    print(f"DDIM sampling with {num_steps} steps finished, output shape: {x_0_ddim.shape}")
5. Sampling Optimization Techniques

5.1 Noise-Schedule Optimization

Different noise schedules affect generation quality:
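As a reference for the code below, the cosine schedule from Improved DDPM (Nichol & Dhariwal, 2021) defines the cumulative product \(\bar{\alpha}_t\) directly:

$$f(t) = \cos^2\!\left(\frac{t/T + s}{1 + s}\cdot\frac{\pi}{2}\right),
\qquad
\bar{\alpha}_t = \frac{f(t)}{f(0)},
\qquad
\beta_t = 1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}$$

which is exactly what the implementation below computes.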
def get_cosine_schedule(T, s=0.008):
    """
    Cosine noise schedule (Improved DDPM).

    Args:
        T: total number of steps
        s: small offset parameter

    Returns:
        alphas, betas, alphas_cumprod
    """
    steps = T + 1
    x = torch.linspace(0, T, steps)
    alphas_cumprod = torch.cos(((x / T) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])

    # Clip the beta values before deriving alphas, so the two stay consistent
    betas = torch.clip(betas, 0.0001, 0.9999)
    alphas = 1 - betas

    return alphas, betas, alphas_cumprod[:-1]
# Compare the two schedules
T = 1000

# Linear schedule
alphas_linear, betas_linear, alphas_cumprod_linear = get_schedule(T)

# Cosine schedule
alphas_cosine, betas_cosine, alphas_cumprod_cosine = get_cosine_schedule(T)

# Visualize
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.plot(betas_linear.numpy(), label='Linear')
plt.plot(betas_cosine.numpy(), label='Cosine')
plt.xlabel('Timestep t')
plt.ylabel('β_t')
plt.title('Beta schedule')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 2)
plt.plot(alphas_cumprod_linear.numpy(), label='Linear')
plt.plot(alphas_cumprod_cosine.numpy(), label='Cosine')
plt.xlabel('Timestep t')
plt.ylabel('ᾱ_t')
plt.title('Alpha-bar schedule')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 3)
plt.plot(np.sqrt(1 - alphas_cumprod_linear.numpy()), label='Linear')
plt.plot(np.sqrt(1 - alphas_cumprod_cosine.numpy()), label='Cosine')
plt.xlabel('Timestep t')
plt.ylabel('√(1-ᾱ_t)')
plt.title('Noise level')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
5.2 Classifier Guidance

Classifier guidance uses the gradients of a separately trained classifier to steer sampling toward a target class, which can improve sample quality:
def classifier_guided_sample(model, classifier, x_T, T, alphas, betas, alphas_cumprod,
                             guidance_scale=1.0, class_label=None, device='cpu'):
    """
    Classifier-guided sampling.

    Args:
        model: the denoising model
        classifier: the classifier
        x_T: initial noise
        T: total number of steps
        alphas, betas, alphas_cumprod: the noise schedule
        guidance_scale: guidance strength
        class_label: the target class
        device: device to run on

    Returns:
        The generated images
    """
    model.eval()
    classifier.eval()
    x_t = x_T.to(device)
    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

    with torch.no_grad():
        for t in reversed(range(T)):
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]

            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)

            # Predict the unconditional noise
            predicted_noise = model(x_t, t_tensor)

            # If a classifier and target class are given, compute the gradient guidance
            if classifier is not None and class_label is not None:
                # Gradients must be re-enabled locally, since sampling runs under no_grad
                with torch.enable_grad():
                    x_in = x_t.detach().requires_grad_(True)
                    logits = classifier(x_in, t_tensor)
                    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                    target_log_prob = log_probs[:, class_label].sum()
                    grad = torch.autograd.grad(target_log_prob, x_in)[0]

                # Shift the predicted noise using the classifier gradient
                predicted_noise = predicted_noise - guidance_scale * grad * torch.sqrt(1 - alpha_t_bar)

            # Update
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            mean = sqrt_recip_alpha_t * (
                x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            if t > 0:
                alpha_t_bar_prev = alphas_cumprod_prev[t]
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

    return x_t
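In equation form, the gradient step in the code above follows the guided-diffusion formulation: the classifier gradient shifts the predicted noise,

$$\hat{\epsilon}_\theta(x_t, t) = \epsilon_\theta(x_t, t) - s\,\sqrt{1 - \bar{\alpha}_t}\;\nabla_{x_t} \log p_\phi(y \mid x_t, t)$$

where \(s\) is guidance_scale and \(p_\phi\) is the classifier evaluated on the noisy input \(x_t\).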
5.3 Classifier-Free Guidance

A more commonly used form of guidance that does not require a separate classifier:
def classifier_free_guidance(model, x_T, T, alphas, betas, alphas_cumprod,
                             guidance_scale=7.5, condition=None, device='cpu'):
    """
    Classifier-free guidance sampling.

    Args:
        model: a model that supports conditional generation
        x_T: initial noise
        T: total number of steps
        alphas, betas, alphas_cumprod: the noise schedule
        guidance_scale: guidance strength (typically 5-15)
        condition: the condition (e.g. a text prompt embedding)
        device: device to run on

    Returns:
        The generated images
    """
    model.eval()
    x_t = x_T.to(device)
    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

    with torch.no_grad():
        for t in reversed(range(T)):
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]

            t_tensor = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)

            # Predict the conditional and unconditional noise
            noise_cond = model(x_t, t_tensor, condition)
            noise_uncond = model(x_t, t_tensor, None)

            # Combine the two predictions
            predicted_noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

            # Update
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            mean = sqrt_recip_alpha_t * (
                x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            if t > 0:
                alpha_t_bar_prev = alphas_cumprod_prev[t]
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

    return x_t
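The combination step in the code above corresponds to the classifier-free guidance formula

$$\hat{\epsilon}_\theta(x_t, t, c) = \epsilon_\theta(x_t, t, \varnothing) + s\,\big(\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \varnothing)\big)$$

where \(c\) is the condition, \(\varnothing\) denotes the unconditional (null) input, and \(s\) is guidance_scale.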
6. Evaluating Sample Quality

6.1 Visual Inspection

def visualize_samples(samples, nrow=4, title="Generated samples"):
    """
    Visualize generated samples.

    Args:
        samples: [batch_size, channels, height, width]
        nrow: number of images per row
        title: plot title
    """
    from torchvision.utils import make_grid

    # Normalize to [0, 1]
    samples = (samples - samples.min()) / (samples.max() - samples.min())

    # Build a grid
    grid = make_grid(samples, nrow=nrow, padding=2, normalize=False)

    # Show it
    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title(title)
    plt.axis('off')
    plt.show()

# Generate and visualize
samples = torch.randn(16, 3, 32, 32)  # stand-in for generated results
visualize_samples(samples, nrow=4, title="DDPM samples")
6.2 FID Score (Fréchet Inception Distance)

FID is a widely used metric for generation quality:
def calculate_fid(real_images, generated_images, batch_size=32):
    """
    Compute the FID score.

    Args:
        real_images: real images [N, C, H, W]
        generated_images: generated images [N, C, H, W]
        batch_size: batch size for feature extraction

    Returns:
        The FID score (lower is better)
    """
    from torchvision.models import inception_v3, Inception_V3_Weights
    from scipy import linalg

    # Load a pretrained Inception model and drop the classification head
    # to obtain 2048-dimensional pooled features
    inception = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
    inception.fc = torch.nn.Identity()  # replace the classifier; output 2048-dim features
    inception = inception.eval()

    def get_features(images):
        """Extract Inception features."""
        features = []
        with torch.no_grad():
            for i in range(0, len(images), batch_size):
                batch = images[i:i+batch_size]
                # Resize to 299x299 (the Inception input size)
                batch_resized = torch.nn.functional.interpolate(
                    batch, size=(299, 299), mode='bilinear', align_corners=False
                )
                feat = inception(batch_resized)
                features.append(feat)
        return torch.cat(features, dim=0).cpu().numpy()

    # Extract features
    real_features = get_features(real_images)
    gen_features = get_features(generated_images)

    # Compute means and covariances
    mu_real = np.mean(real_features, axis=0)
    mu_gen = np.mean(gen_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    sigma_gen = np.cov(gen_features, rowvar=False)

    # Compute FID
    diff = mu_real - mu_gen
    covmean, _ = linalg.sqrtm(sigma_real @ sigma_gen, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = diff @ diff + np.trace(sigma_real + sigma_gen - 2 * covmean)

    return fid
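A minimal usage sketch (the random tensors below are placeholders only; a meaningful FID requires real dataset images and model outputs scaled to the same range, and typically thousands of samples):

# Placeholder batches standing in for real and generated images in [0, 1]
real_images = torch.rand(64, 3, 32, 32)
generated_images = torch.rand(64, 3, 32, 32)
fid_score = calculate_fid(real_images, generated_images, batch_size=16)
print(f"FID: {fid_score:.2f}")  # lower is better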
7. Hands-On Project: A Complete Sampling Pipeline

class DiffusionSampler:
    """A sampler for diffusion models."""

    def __init__(self, model, T=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
        """
        Initialize the sampler.

        Args:
            model: the trained model
            T: total number of steps
            beta_start, beta_end: beta range
            device: device to run on
        """
        self.model = model
        self.T = T
        self.device = device

        # Build the noise schedule
        self.alphas, self.betas, self.alphas_cumprod = self._get_schedule(T, beta_start, beta_end)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), self.alphas_cumprod[:-1]])

    def _get_schedule(self, T, beta_start, beta_end):
        """Create the noise schedule."""
        betas = torch.linspace(beta_start, beta_end, T)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        return alphas, betas, alphas_cumprod
    def sample(self, batch_size=1, num_steps=None, method='ddpm', **kwargs):  # **kwargs collects extra keyword arguments
        """
        Sample from the model.

        Args:
            batch_size: batch size
            num_steps: number of sampling steps (None means use T)
            method: sampling method ('ddpm' or 'ddim')
            **kwargs: extra arguments

        Returns:
            Generated images [batch_size, channels, height, width]
        """
        # Create the initial noise
        x_T = torch.randn(batch_size, 3, 32, 32)  # assume a 32x32 input size

        if method == 'ddpm':
            return self._ddpm_sample(x_T, num_steps or self.T)
        elif method == 'ddim':
            return self._ddim_sample(x_T, num_steps or self.T)
        else:
            raise ValueError(f"Unknown sampling method: {method}")
    def _ddpm_sample(self, x_T, num_steps):
        """DDPM sampling."""
        self.model.eval()
        x_t = x_T.to(self.device)

        # Choose the timesteps (subsample if num_steps < T)
        if num_steps < self.T:
            step_indices = torch.linspace(0, self.T - 1, num_steps, dtype=torch.long)
        else:
            step_indices = torch.arange(self.T)

        with torch.no_grad():
            for i, t in enumerate(reversed(step_indices)):
                t = t.item()

                alpha_t = self.alphas[t]
                beta_t = self.betas[t]
                alpha_t_bar = self.alphas_cumprod[t]
                alpha_t_bar_prev = self.alphas_cumprod_prev[t]

                # Predict the noise
                t_tensor = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long)
                predicted_noise = self.model(x_t, t_tensor)

                # Compute the mean
                sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
                mean = sqrt_recip_alpha_t * (
                    x_t - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
                )

                # Add noise (not on the last step)
                if i < len(step_indices) - 1:
                    posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                    noise = torch.randn_like(x_t)
                    x_t = mean + torch.sqrt(posterior_variance) * noise
                else:
                    x_t = mean

        return x_t
    def _ddim_sample(self, x_T, num_steps):
        """DDIM sampling."""
        self.model.eval()
        x_t = x_T.to(self.device)

        # Choose the sampling timesteps
        step_indices = torch.linspace(0, self.T - 1, num_steps, dtype=torch.long)

        with torch.no_grad():
            for i, t in enumerate(reversed(step_indices)):
                t = t.item()

                alpha_t_bar = self.alphas_cumprod[t]

                # Predict the noise
                t_tensor = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long)
                predicted_noise = self.model(x_t, t_tensor)

                # Predict x_0
                sqrt_alpha_t_bar = torch.sqrt(alpha_t_bar)
                sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
                x_0_pred = (x_t - sqrt_one_minus_alpha_t_bar * predicted_noise) / sqrt_alpha_t_bar

                # Compute x_{t-1}
                if i < len(step_indices) - 1:
                    t_prev = step_indices[len(step_indices) - 2 - i].item()
                    alpha_t_prev_bar = self.alphas_cumprod[t_prev]
                    sqrt_alpha_t_prev_bar = torch.sqrt(alpha_t_prev_bar)
                    sqrt_one_minus_alpha_t_prev_bar = torch.sqrt(1 - alpha_t_prev_bar)
                    x_t = sqrt_alpha_t_prev_bar * x_0_pred + sqrt_one_minus_alpha_t_prev_bar * predicted_noise
                else:
                    x_t = x_0_pred

        return x_t
# Usage example
sampler = DiffusionSampler(model, T=1000, device=device)

# DDPM sampling
samples_ddpm = sampler.sample(batch_size=4, num_steps=1000, method='ddpm')
print(f"DDPM sampling finished, shape: {samples_ddpm.shape}")

# DDIM sampling (faster)
samples_ddim = sampler.sample(batch_size=4, num_steps=50, method='ddim')
print(f"DDIM sampling finished, shape: {samples_ddim.shape}")
8. Summary

8.1 Core Concepts Recap

| Concept | Description |
|---|---|
| DDPM sampling | Standard Markov-chain sampling; good quality but slow |
| DDIM sampling | Deterministic sampling; the number of steps can be reduced dramatically |
| Noise schedule | Controls how much noise is added at each step and affects generation quality |
| Guided sampling | Uses classifier or classifier-free guidance to improve generation |
| Number of steps | The key parameter balancing quality and speed |
8.2 Practical Advice

- Start with DDPM: understand the standard sampling algorithm first
- Try DDIM: accelerates sampling, well suited to fast iteration
- Tune the schedule: experiment with improvements such as the cosine schedule
- Use guidance: classifier-free guidance is the modern standard practice
- Evaluate quality: combine visual inspection with metrics such as FID
8.3 FAQ

Q: Are more sampling steps always better? A: Not necessarily. Too many steps waste time, while too few hurt quality. Tune the count for the task at hand.

Q: Which is better, DDIM or DDPM? A: It depends on the application. DDIM is faster and suits interactive use; DDPM is slightly higher quality and suits high-quality generation.

Q: How should the guidance scale be chosen? A: Typically between 5 and 15. Larger values make samples follow the condition more closely but can reduce diversity; see the sketch below.
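A small sketch of such a sweep, assuming a conditional model cond_model and a condition embedding cond (both hypothetical here) that match the signature expected by classifier_free_guidance:

# Hypothetical sweep over guidance scales; cond_model and cond are placeholders
# for a real conditional model and its condition (e.g. a text embedding)
for scale in [1.0, 3.0, 5.0, 7.5, 10.0, 15.0]:
    samples = classifier_free_guidance(
        cond_model, x_T, T, alphas, betas, alphas_cumprod,
        guidance_scale=scale, condition=cond, device=device,
    )
    visualize_samples(samples, nrow=2, title=f"guidance_scale = {scale}")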
9. Recommended Resources

Papers
- DDPM: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- DDIM: "Denoising Diffusion Implicit Models" (Song et al., 2020)
- Improved DDPM: "Improved Denoising Diffusion Probabilistic Models" (Nichol & Dhariwal, 2021)
Code Repositories
- Hugging Face Diffusers
- OpenAI's guided-diffusion
- CompVis/stable-diffusion
10. Self-Check Questions

- What are the main differences between DDPM and DDIM sampling?
- How does the number of sampling steps affect generation quality and speed?
- What is classifier-free guidance, and how does it work?
- What advantages does the cosine noise schedule have over the linear schedule?
- How can the quality of generated images be evaluated?

Next chapter: 06 - Training Techniques and Optimization, where you will learn how to train diffusion models efficiently