跳转至

02 - 条件生成与引导

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 控制生成内容的核心技术


🎯 学习目标

完成本章后,你将能够: - 理解条件扩散模型的原理 - 掌握Classifier Guidance和Classifier-Free Guidance - 实现文本条件生成 - 理解CFG Scale的作用和调节方法


1. 条件生成概述

1.1 什么是条件生成

无条件生成:从随机噪声生成样本

Text Only
noise → model → image

条件生成:根据给定条件生成样本

Text Only
noise + condition → model → image

条件类型: - 类别标签(class label) - 文本描述(text prompt) - 图像(image-to-image) - 语义分割图(semantic map) - 关键点(keypoints)

1.2 条件扩散模型的数学形式

条件前向过程(与无条件相同): $\(q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I)\)$

条件反向过程: $\(p_\theta(x_{t-1} | x_t, c) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t, c), \Sigma_\theta(x_t, t, c))\)$

其中 \(c\) 是条件信息。

1.3 条件注入方式

方式1:条件编码器

Python
# 将条件编码为向量
condition_embedding = condition_encoder(c)  # [B, cond_dim]

# 注入到UNet
output = unet(x_t, t, condition_embedding)

方式2:交叉注意力(Cross-Attention)

Python
# 在UNet的每一层使用交叉注意力
# Query: 图像特征
# Key/Value: 条件特征
attn_output = cross_attention(image_features, condition_features)

方式3:自适应归一化(AdaGN)

Python
# 使用条件调节GroupNorm的参数
scale, shift = condition_mlp(condition_embedding)
x = group_norm(x) * (1 + scale) + shift


2. Classifier Guidance

2.1 核心思想

问题:如何引导生成过程朝向特定条件?

解决方案:使用分类器的梯度来引导

数学推导

根据贝叶斯定理: $\(p(x_t | c) = \frac{p(c | x_t) p(x_t)}{p(c)}\)$

取对数: $\(\log p(x_t | c) = \log p(c | x_t) + \log p(x_t) - \log p(c)\)$

\(x_t\) 求梯度: $\(\nabla_{x_t} \log p(x_t | c) = \nabla_{x_t} \log p(c | x_t) + \nabla_{x_t} \log p(x_t)\)$

Classifier Guidance: $\(\hat{\epsilon}(x_t, c) = \epsilon(x_t) - w \cdot \sigma_t \cdot \nabla_{x_t} \log p_\phi(c | x_t)\)$

其中: - \(p_\phi(c | x_t)\):分类器 - \(w\):引导强度(guidance scale) - \(\sigma_t\):噪声水平

2.2 实现代码

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassifierGuidedDiffusion:
    """
    Classifier Guidance实现
    """
    def __init__(self, diffusion_model, classifier, guidance_scale=1.0):
        """
        参数:
            diffusion_model: 无条件扩散模型
            classifier: 噪声图像分类器
            guidance_scale: 引导强度w
        """
        self.model = diffusion_model
        self.classifier = classifier
        self.guidance_scale = guidance_scale

    def guided_score(self, x_t, t, c):
        """
        计算引导后的得分

        参数:
            x_t: 噪声图像
            t: 时间步
            c: 目标类别

        返回:
            guided_eps: 引导后的噪声预测
        """
        # 需要梯度
        x_t.requires_grad_(True)

        # 无条件预测
        with torch.no_grad():  # 禁用梯度计算,节省内存
            eps_uncond = self.model(x_t, t)

        # 分类器预测
        logits = self.classifier(x_t, t)
        log_prob = F.log_softmax(logits, dim=-1)  # F.xxx PyTorch函数式API
        selected_log_prob = log_prob[range(len(c)), c]

        # 计算梯度
        grad = torch.autograd.grad(
            selected_log_prob.sum(),
            x_t,
            create_graph=False
        )[0]

        # 计算噪声水平
        sigma_t = torch.sqrt(1 - self.model.alphas_cumprod[t]).view(-1, 1, 1, 1)  # 重塑张量形状

        # Classifier Guidance
        guided_eps = eps_uncond - self.guidance_scale * sigma_t * grad

        return guided_eps.detach()  # 分离计算图,不参与梯度计算

    def sample(self, shape, c, device='cuda'):
        """
        使用Classifier Guidance采样

        参数:
            shape: 输出形状
            c: 目标类别 [B]
            device: 计算设备
        """
        x = torch.randn(shape, device=device)

        for t in reversed(range(self.model.timesteps)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

            # 获取引导后的噪声预测
            eps = self.guided_score(x, t_batch, c)

            # 去噪步骤(使用DDPM采样)
            alpha_t = self.model.alphas[t]
            alpha_cumprod_t = self.model.alphas_cumprod[t]
            beta_t = self.model.betas[t]

            # 预测x_0
            x_0_pred = (x - torch.sqrt(1 - alpha_cumprod_t) * eps) / torch.sqrt(alpha_cumprod_t)
            x_0_pred = torch.clamp(x_0_pred, -1, 1)

            # 计算x_{t-1}(使用DDPM后验分布)
            if t > 0:
                alpha_cumprod_t_prev = self.model.alphas_cumprod[t - 1]

                # 后验均值: μ_t = √(ᾱ_{t-1}) * β_t / (1-ᾱ_t) * x_0 + √(α_t) * (1-ᾱ_{t-1}) / (1-ᾱ_t) * x_t
                posterior_mean = (
                    torch.sqrt(alpha_cumprod_t_prev) * beta_t / (1 - alpha_cumprod_t) * x_0_pred +
                    torch.sqrt(alpha_t) * (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * x
                )

                # 后验方差: β̃_t = (1-ᾱ_{t-1}) / (1-ᾱ_t) * β_t
                posterior_variance = (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * beta_t

                noise = torch.randn_like(x)
                x = posterior_mean + torch.sqrt(posterior_variance) * noise
            else:
                x = x_0_pred

        return x

class NoiseClassifier(nn.Module):  # 继承nn.Module定义网络层
    """
    噪声图像分类器(用于Classifier Guidance)
    """
    def __init__(self, num_classes=10, time_emb_dim=256):
        super().__init__()  # super()调用父类方法

        # 时间嵌入
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # 图像编码器
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )

        # 分类器
        self.classifier = nn.Sequential(
            nn.Linear(256 + time_emb_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, t):
        """
        参数:
            x: 噪声图像 [B, C, H, W]
            t: 时间步 [B]

        返回:
            logits: 分类logits [B, num_classes]
        """
        # 时间嵌入
        t_emb = self.time_embed(t.float().view(-1, 1) / 1000)

        # 图像特征
        h = self.encoder(x)
        h = h.view(h.size(0), -1)

        # 拼接并分类
        h = torch.cat([h, t_emb], dim=1)  # torch.cat沿已有维度拼接张量
        logits = self.classifier(h)

        return logits

2.3 Classifier Guidance的优缺点

优点: - 可以使用预训练的分类器 - 引导强度可灵活调节 - 适用于各种条件类型

缺点: - 需要额外训练分类器 - 分类器需要在噪声图像上训练 - 梯度计算增加计算成本 - 可能导致对抗样本问题


3. Classifier-Free Guidance (CFG)

3.1 核心思想

问题:能否不使用分类器实现引导?

解决方案:训练时随机丢弃条件,推理时结合条件和无条件预测

数学公式: $\(\hat{\epsilon}(x_t, c) = \epsilon(x_t, \emptyset) + w \cdot (\epsilon(x_t, c) - \epsilon(x_t, \emptyset))\)$

其中: - \(\epsilon(x_t, c)\):条件预测 - \(\epsilon(x_t, \emptyset)\):无条件预测 - \(w\):引导强度(通常1.5-10)

数学解释(为什么CFG有效)

回忆噪声预测与得分函数的关系:\(\epsilon_\theta(x_t, c) \propto -\nabla_{x_t} \log p(x_t|c)\)

由贝叶斯定理:\(\nabla_{x_t} \log p(x_t|c) = \nabla_{x_t} \log p(x_t) + \nabla_{x_t} \log p(c|x_t)\)

因此隐式分类器梯度为: $\(\nabla_{x_t} \log p(c|x_t) = \nabla_{x_t} \log p(x_t|c) - \nabla_{x_t} \log p(x_t) \propto \epsilon(x_t, \emptyset) - \epsilon(x_t, c)\)$

将CFG公式改写为得分函数形式: $\(\hat{s}(x_t, c) = \nabla_{x_t} \log p(x_t) + w \cdot \nabla_{x_t} \log p(c|x_t)\)$

这正是Classifier Guidance的形式,但分类器梯度 \(\nabla \log p(c|x_t)\) 是由模型自身(条件与无条件预测的差值)隐式估计的,而非来自外部分类器。当 \(w > 1\) 时,相当于放大了分类器梯度,使生成结果更强烈地符合条件 \(c\),但过大的 \(w\) 会降低多样性甚至导致伪影。

直观理解: - 无条件预测提供基础方向 - 条件预测提供目标方向 - 引导强度控制"朝向条件的偏移量"

3.2 训练策略

随机丢弃条件

Python
# 训练时以概率p_uncond丢弃条件
if random.random() < p_uncond:
    c = None  # 无条件

# 模型需要处理None条件
if c is None:
    c = null_token  # 学习一个空条件嵌入

代码实现

Python
class CFGDiffusionModel(nn.Module):
    """
    支持Classifier-Free Guidance的扩散模型
    """
    def __init__(self, unet, num_classes=10, p_uncond=0.1):
        super().__init__()
        self.unet = unet
        self.num_classes = num_classes
        self.p_uncond = p_uncond

        # 类别嵌入
        self.class_embed = nn.Embedding(num_classes + 1, 256)  # +1 for null token
        self.null_token_id = num_classes  # 空条件的ID

    def forward(self, x_t, t, c=None):
        """
        参数:
            x_t: 噪声图像
            t: 时间步
            c: 类别标签(None表示无条件)

        返回:
            eps_pred: 预测的噪声
        """
        # 处理条件
        if c is None:
            c = torch.full((x_t.size(0),), self.null_token_id,
                          device=x_t.device, dtype=torch.long)

        # 获取条件嵌入
        c_emb = self.class_embed(c)

        # UNet预测
        eps_pred = self.unet(x_t, t, c_emb)

        return eps_pred

class CFGTrainer:
    """
    Classifier-Free Guidance训练器
    """
    def __init__(self, model, p_uncond=0.1):
        self.model = model
        self.p_uncond = p_uncond

    def training_step(self, x_0, c):
        """
        训练步骤

        参数:
            x_0: 原始图像
            c: 类别标签
        """
        batch_size = x_0.size(0)

        # 随机选择时间步
        t = torch.randint(0, 1000, (batch_size,), device=x_0.device)

        # 加噪
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)

        # 随机丢弃条件
        mask = torch.rand(batch_size) < self.p_uncond
        c_train = c.clone()
        c_train[mask] = self.model.null_token_id

        # 预测
        eps_pred = self.model(x_t, t, c_train)

        # 损失
        loss = F.mse_loss(eps_pred, noise)

        return loss

3.3 CFG采样

Python
def cfg_sample(model, shape, c, guidance_scale=7.5, device='cuda'):
    """
    使用Classifier-Free Guidance采样

    参数:
        model: CFG扩散模型
        shape: 输出形状
        c: 条件(类别标签或文本嵌入)
        guidance_scale: 引导强度w
        device: 计算设备
    """
    model.eval()  # eval()评估模式

    # 初始化
    x = torch.randn(shape, device=device)

    # 创建无条件标签
    null_c = torch.full((shape[0],), model.null_token_id,
                        device=device, dtype=torch.long)

    for t in reversed(range(model.timesteps)):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

        with torch.no_grad():
            # 条件预测
            eps_cond = model(x, t_batch, c)

            # 无条件预测
            eps_uncond = model(x, t_batch, null_c)

            # CFG
            eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

            # 去噪(DDPM步骤)
            alpha_t = model.alphas[t]
            alpha_cumprod_t = model.alphas_cumprod[t]

            x_0_pred = (x - torch.sqrt(1 - alpha_cumprod_t) * eps) / torch.sqrt(alpha_cumprod_t)
            x_0_pred = torch.clamp(x_0_pred, -1, 1)

            if t > 0:
                alpha_cumprod_t_prev = model.alphas_cumprod[t - 1]
                beta_t = 1 - alpha_t

                # 后验均值
                posterior_mean = (
                    torch.sqrt(alpha_cumprod_t_prev) * beta_t / (1 - alpha_cumprod_t) * x_0_pred +
                    torch.sqrt(alpha_t) * (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * x
                )

                # 后验方差
                posterior_variance = (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * beta_t

                noise = torch.randn_like(x)
                x = posterior_mean + torch.sqrt(posterior_variance) * noise
            else:
                x = x_0_pred

    return x

3.4 Guidance Scale的选择

Guidance Scale 效果 适用场景
1.0 无引导,多样性高 探索性生成
3-5 轻度引导,平衡 一般应用
7-10 强引导,条件忠实 标准设置
15-20 超强引导,可能过饱和 需要精确控制

调节建议: - 从7.5开始尝试 - 增加guidance_scale提高条件忠实度 - 降低guidance_scale增加多样性


4. 文本条件生成

4.1 CLIP文本编码

Python
import clip

class TextConditionedDiffusion:
    """
    文本条件扩散模型
    """
    def __init__(self, unet, clip_model_name='ViT-B/32'):
        # 加载CLIP
        self.clip_model, self.preprocess = clip.load(clip_model_name)
        self.clip_model.eval()

        self.unet = unet

        # 文本投影
        self.text_proj = nn.Linear(512, 256)  # CLIP维度到UNet条件维度

    def encode_text(self, text):
        """
        编码文本

        参数:
            text: 文本列表

        返回:
            text_emb: 文本嵌入 [B, 256]
        """
        with torch.no_grad():
            text_tokens = clip.tokenize(text).to(self.clip_model.device)
            text_features = self.clip_model.encode_text(text_tokens)
            text_features = text_features.float()

        # 投影
        text_emb = self.text_proj(text_features)
        return text_emb

    def forward(self, x_t, t, text):
        """
        前向传播

        参数:
            x_t: 噪声图像
            t: 时间步
            text: 文本描述
        """
        # 编码文本
        text_emb = self.encode_text(text)

        # UNet预测
        eps_pred = self.unet(x_t, t, text_emb)

        return eps_pred

4.2 文本CFG采样

Python
def text_cfg_sample(model, shape, prompt, negative_prompt="",
                    guidance_scale=7.5, num_inference_steps=50):
    """
    文本条件CFG采样(类似Stable Diffusion)

    参数:
        model: 文本条件模型
        shape: 输出形状
        prompt: 正向提示词
        negative_prompt: 负向提示词
        guidance_scale: 引导强度
        num_inference_steps: 采样步数
    """
    model.eval()

    # 编码文本
    text_emb = model.encode_text([prompt])
    uncond_emb = model.encode_text([negative_prompt])

    # 初始化
    x = torch.randn(shape, device=text_emb.device)

    # 时间步
    timesteps = torch.linspace(999, 0, num_inference_steps, dtype=torch.long)

    for t in timesteps:
        t_batch = t.unsqueeze(0).expand(shape[0])  # unsqueeze增加一个维度

        with torch.no_grad():
            # 条件预测
            eps_cond = model.unet(x, t_batch, text_emb)

            # 无条件预测
            eps_uncond = model.unet(x, t_batch, uncond_emb)

            # CFG
            eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

            # DDIM步骤(更快)
            alpha_t = model.alphas_cumprod[t]
            x_0_pred = (x - torch.sqrt(1 - alpha_t) * eps) / torch.sqrt(alpha_t)

            if t > 0:
                alpha_t_prev = model.alphas_cumprod[t - 1]
                x = (
                    torch.sqrt(alpha_t_prev) * x_0_pred +
                    torch.sqrt(1 - alpha_t_prev) * eps
                )
            else:
                x = x_0_pred

    return x

5. 高级引导技术

5.1 负向提示(Negative Prompting)

Python
def sample_with_negative_prompt(model, shape, prompt, negative_prompt,
                                guidance_scale=7.5):
    """
    使用负向提示引导

    原理:远离负向提示描述的内容
    """
    # 编码正负提示
    pos_emb = model.encode_text([prompt])
    neg_emb = model.encode_text([negative_prompt])

    # CFG公式扩展
    # eps = eps_uncond + w_pos * (eps_pos - eps_uncond) - w_neg * (eps_neg - eps_uncond)

    # 采样过程...

5.2 多条件组合

Python
def multi_condition_guidance(model, x_t, t, conditions, weights):
    """
    多条件加权引导

    参数:
        conditions: 条件列表 [c1, c2, c3, ...]
        weights: 权重列表 [w1, w2, w3, ...]
    """
    eps_uncond = model(x_t, t, None)

    eps_guided = eps_uncond.clone()
    for c, w in zip(conditions, weights):  # zip按位置配对
        eps_c = model(x_t, t, c)
        eps_guided += w * (eps_c - eps_uncond)

    return eps_guided

6. 本章总结

核心概念

  1. 条件扩散模型
  2. 条件注入方式:编码器、交叉注意力、AdaGN
  3. 条件类型:类别、文本、图像等

  4. Classifier Guidance

  5. 使用分类器梯度引导
  6. 需要预训练分类器
  7. 计算成本高

  8. Classifier-Free Guidance (CFG)

  9. 训练时随机丢弃条件
  10. 推理时结合条件和无条件预测
  11. 效果更好,无需额外分类器

  12. Guidance Scale

  13. 控制条件忠实度vs多样性
  14. 常用范围:7-10
  15. 可调节生成效果

关键公式

技术 公式
Classifier Guidance \(\hat{\epsilon} = \epsilon - w \cdot \sigma_t \cdot \nabla \log p(c \| x_t)\)
CFG \(\hat{\epsilon} = \epsilon_{\text{uncond}} + w \cdot (\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})\)

实现要点

Python
# CFG核心代码
eps_cond = model(x_t, t, c)
eps_uncond = model(x_t, t, null_c)
eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

📝 自测问题

基础问题

  1. 条件生成
  2. 条件扩散模型与无条件的区别?
  3. 有哪些条件注入方式?
  4. 文本条件如何实现?

  5. Classifier Guidance

  6. 为什么需要分类器梯度?
  7. 引导强度如何影响生成?
  8. 有哪些缺点?

  9. CFG

  10. CFG为什么不需要分类器?
  11. 训练时为什么要随机丢弃条件?
  12. Guidance Scale的作用?

编程练习

  1. 实现Classifier Guidance采样
  2. 实现CFG训练和采样
  3. 实现文本条件生成
  4. 实验不同guidance_scale的效果

思考题

  1. Classifier Guidance vs CFG,哪个更好?为什么?
  2. 如何选择合适的guidance_scale?
  3. 负向提示的原理是什么?

🔗 下一步

理解了条件生成后,我们将学习潜空间扩散模型LDM,了解Stable Diffusion的核心技术。

→ 下一步:03-潜空间扩散模型LDM.md