跳转至

19 - 生成模型深度解析

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

生成模型深度解析图

🎨 生成模型概述

生成vs判别模型

Text Only
判别模型:P(Y|X) - 学习决策边界
- 逻辑回归、SVM、神经网络
- 回答:"这是什么?"

生成模型:P(X) 或 P(X|Y) - 学习数据分布
- GAN、VAE、扩散模型
- 回答:"这可能是什么?"

生成模型应用

Text Only
图像生成:艺术创作、风格迁移、超分辨率
文本生成:对话系统、代码生成、创意写作
音频生成:音乐创作、语音合成
视频生成:深度伪造、动画制作
药物发现:分子生成

🎭 GAN (生成对抗网络)

核心思想

Text Only
两个网络对抗:
生成器 G:从噪声生成假样本
判别器 D:区分真假样本

min_G max_D V(D,G) = E[log D(x)] + E[log(1-D(G(z)))]

训练过程

Text Only
1. 训练判别器:
   - 真样本标签1,假样本标签0
   - 最大化区分能力

2. 训练生成器:
   - 假样本标签1(欺骗判别器)
   - 最小化被识破概率

3. 交替训练直到纳什均衡

基础GAN实现

Python
import torch
import torch.nn as nn

# 生成器
class Generator(nn.Module):
    def __init__(self, latent_dim, img_dim):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, img_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# 判别器
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)

# 训练
latent_dim = 100
img_dim = 784  # 28x28

generator = Generator(latent_dim, img_dim)
discriminator = Discriminator(img_dim)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
criterion = nn.BCELoss()

for epoch in range(epochs):
    for batch_idx, (real_imgs, _) in enumerate(dataloader):  # enumerate同时获取索引和元素
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.view(batch_size, -1)

        # 标签
        real_label = torch.ones(batch_size, 1)
        fake_label = torch.zeros(batch_size, 1)

        # ========== 训练判别器 ==========
        optimizer_D.zero_grad()

        # 真样本损失
        real_output = discriminator(real_imgs)
        d_real_loss = criterion(real_output, real_label)

        # 假样本损失
        z = torch.randn(batch_size, latent_dim)
        fake_imgs = generator(z)
        fake_output = discriminator(fake_imgs.detach())  # detach()从计算图分离,不参与梯度计算
        d_fake_loss = criterion(fake_output, fake_label)

        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()  # 反向传播计算梯度
        optimizer_D.step()

        # ========== 训练生成器 ==========
        optimizer_G.zero_grad()

        fake_output = discriminator(fake_imgs)
        g_loss = criterion(fake_output, real_label)  # 希望被判别为真

        g_loss.backward()
        optimizer_G.step()

DCGAN (深度卷积GAN)

Python
class DCGAN_Generator(nn.Module):
    def __init__(self, latent_dim=100, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # 输入: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 16 x 16
            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # channels x 32 x 32
        )

    def forward(self, z):
        return self.model(z.view(-1, z.size(1), 1, 1))  # view重塑张量形状(要求内存连续)

条件GAN (cGAN)

Python
class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, img_dim):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, img_dim),
            nn.Tanh()
        )

    def forward(self, z, labels):
        label_embedding = self.label_emb(labels)
        x = torch.cat([z, label_embedding], dim=1)
        return self.model(x)

🔄 VAE (变分自编码器)

核心思想

Text Only
传统自编码器:
输入 → 编码器 → 隐向量 → 解码器 → 输出

VAE改进:
输入 → 编码器 → (μ, σ) → 采样 → 解码器 → 输出

关键:隐空间学习概率分布

数学原理

Text Only
目标:最大化证据下界 (ELBO)
ELBO = E[log P(X|Z)] - KL(Q(Z|X) || P(Z))

重建损失:输入与输出的差异
KL散度:隐分布与先验分布的差异

重参数化技巧

Text Only
问题:采样操作不可导
解决:z = μ + σ × ε,其中 ε ~ N(0,1)

这样梯度可以反向传播

代码实现

Python
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()

        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # 解码器
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

# 损失函数
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld

# 训练
model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    for batch_idx, (data, _) in enumerate(dataloader):
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()  # 根据梯度更新模型参数

# 生成新样本
with torch.no_grad():
    z = torch.randn(64, 20)  # 从标准正态分布采样
    sample = model.decode(z)

🌊 扩散模型 (Diffusion Models)

核心思想

Text Only
前向过程(扩散):
逐步添加高斯噪声,直到纯噪声

反向过程(去噪):
学习逐步去噪,从噪声恢复数据

训练:预测噪声
生成:迭代去噪

DDPM (去噪扩散概率模型)

Python
import torch
import torch.nn as nn
import math

# 简化的U-Net用于噪声预测
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=256):
        super().__init__()
        # 时间嵌入
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        # 编码器
        self.enc1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.enc2 = nn.Conv2d(64, 128, 3, padding=1)

        # 解码器
        self.dec2 = nn.Conv2d(128, 64, 3, padding=1)
        self.dec1 = nn.Conv2d(64, out_channels, 3, padding=1)

    def forward(self, x, t):
        # 时间嵌入
        t_emb = self.time_embed(t.float().view(-1, 1) / 1000)

        # 编码
        h1 = torch.relu(self.enc1(x))
        h2 = torch.relu(self.enc2(h1))

        # 解码
        h = torch.relu(self.dec2(h2))
        out = self.dec1(h)

        return out

class DiffusionModel(nn.Module):
    def __init__(self, timesteps=1000):
        super().__init__()
        self.timesteps = timesteps

        # 定义beta调度
        self.betas = torch.linspace(0.0001, 0.02, timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # U-Net噪声预测网络
        self.unet = UNet()

    def forward_diffusion(self, x_0, t, noise=None):
        """前向加噪"""
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha_cumprod = torch.sqrt(self.alphas_cumprod[t])[:, None, None, None]
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - self.alphas_cumprod[t])[:, None, None, None]

        x_t = sqrt_alpha_cumprod * x_0 + sqrt_one_minus_alpha_cumprod * noise
        return x_t, noise

    def reverse_diffusion(self, x_t, t):
        """反向去噪(预测噪声)"""
        predicted_noise = self.unet(x_t, t)
        return predicted_noise

# 训练
def train_step(model, x_0, optimizer):
    optimizer.zero_grad()

    batch_size = x_0.size(0)
    t = torch.randint(0, model.timesteps, (batch_size,))
    noise = torch.randn_like(x_0)

    x_t, _ = model.forward_diffusion(x_0, t, noise)
    predicted_noise = model.reverse_diffusion(x_t, t)

    loss = nn.functional.mse_loss(predicted_noise, noise)
    loss.backward()
    optimizer.step()

    return loss.item()  # .item()将单元素张量转为Python数值

# 采样生成
@torch.no_grad()  # 装饰器:禁用梯度计算,减少内存占用并加速推理,采样时不需要反向传播
def sample(model, shape):
    device = next(model.parameters()).device
    x = torch.randn(shape).to(device)  # .to(device)将数据移至GPU/CPU

    for t in reversed(range(model.timesteps)):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        predicted_noise = model.reverse_diffusion(x, t_batch)

        alpha = model.alphas[t]
        alpha_cumprod = model.alphas_cumprod[t]
        beta = model.betas[t]

        # 计算x_{t-1}
        x = (x - beta / torch.sqrt(1 - alpha_cumprod) * predicted_noise) / torch.sqrt(alpha)

        if t > 0:
            noise = torch.randn_like(x)
            x = x + torch.sqrt(beta) * noise

    return x

Stable Diffusion

Text Only
核心创新:在隐空间进行扩散
1. 使用VAE编码到隐空间
2. 在隐空间进行扩散
3. 条件控制(文本、类别等)
4. 使用VAE解码回图像空间

优势:计算效率大幅提升

🎯 生成模型对比

模型 训练稳定性 生成质量 多样性 条件控制 训练速度
GAN
VAE
扩散模型 极高 极好

💡 总结

Text Only
生成模型演进:
VAE (2014) → GAN (2014) → Flow (2014) →
PixelCNN (2016) → VQ-VAE (2017) →
DDPM (2020) → Stable Diffusion (2022)

当前主流:
- 图像:Stable Diffusion
- 文本:GPT系列
- 多模态:DALL-E, Midjourney

实践建议:
1. 从VAE入门理解生成模型原理
2. 掌握GAN训练技巧
3. 学习扩散模型是当前重点
4. 使用预训练模型快速应用

🌊 流模型 (Flow-based Models)

核心思想

可逆变换:通过一系列可逆的神经网络变换,将简单分布(如高斯)映射到复杂数据分布。

\[x = f(z), \quad z = f^{-1}(x)\]

关键优势: - 精确的似然计算 - 可逆的编码/解码 - 高效的采样

变量变换公式

\[p_X(x) = p_Z(f^{-1}(x)) \left| \det \frac{\partial f^{-1}(x)}{\partial x} \right|\]

NICE (Non-linear Independent Components Estimation)

Python
import math

class NICE(nn.Module):
    """
    NICE: 非线性独立成分估计
    使用加性耦合层实现可逆变换
    """

    def __init__(self, input_dim, hidden_dim=1000, num_coupling_layers=4):
        super().__init__()
        self.input_dim = input_dim
        self.num_coupling_layers = num_coupling_layers

        # 分割维度
        self.dim1 = input_dim // 2
        self.dim2 = input_dim - self.dim1

        # 耦合层网络
        self.coupling_layers = nn.ModuleList([
            self._make_coupling_net(hidden_dim)
            for _ in range(num_coupling_layers)
        ])

        # 缩放参数(对角缩放)
        self.scaling_diag = nn.Parameter(torch.zeros(input_dim))

    def _make_coupling_net(self, hidden_dim):
        """创建耦合网络 m(x) """
        return nn.Sequential(
            nn.Linear(self.dim1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.dim2)
        )

    def forward(self, x):
        """
        前向:x → z (编码)
        返回: z, log_det_jacobian
        """
        log_det_jacobian = 0

        # 耦合层变换
        for i, coupling_net in enumerate(self.coupling_layers):
            x1, x2 = x[:, :self.dim1], x[:, self.dim1:]

            if i % 2 == 0:
                # x2 = x2 + m(x1)
                x2 = x2 + coupling_net(x1)
            else:
                # x1 = x1 + m(x2)
                x1 = x1 + coupling_net(x2)

            x = torch.cat([x1, x2], dim=1)

        # 对角缩放
        z = x * torch.exp(self.scaling_diag)
        log_det_jacobian += torch.sum(self.scaling_diag)

        return z, log_det_jacobian

    def inverse(self, z):
        """
        反向:z → x (解码/生成)
        """
        # 逆缩放
        x = z * torch.exp(-self.scaling_diag)

        # 逆耦合层(反向顺序)
        for i in range(self.num_coupling_layers - 1, -1, -1):
            x1, x2 = x[:, :self.dim1], x[:, self.dim1:]

            if i % 2 == 0:
                x2 = x2 - self.coupling_layers[i](x1)
            else:
                x1 = x1 - self.coupling_layers[i](x2)

            x = torch.cat([x1, x2], dim=1)

        return x

    def log_prob(self, x):
        """计算对数似然"""
        z, log_det_jacobian = self.forward(x)

        # 先验分布(标准高斯)的对数概率
        log_prior = -0.5 * torch.sum(z ** 2, dim=1) - 0.5 * self.input_dim * math.log(2 * math.pi)

        # 变量变换公式
        log_prob = log_prior + log_det_jacobian

        return log_prob

    def sample(self, num_samples, device='cpu'):
        """从模型采样"""
        # 从先验采样
        z = torch.randn(num_samples, self.input_dim, device=device)
        # 逆变换
        x = self.inverse(z)
        return x

RealNVP (Real-valued Non-Volume Preserving)

Python
class RealNVP(nn.Module):
    """
    RealNVP: 使用仿射耦合层的流模型
    比NICE更灵活,允许缩放
    """

    def __init__(self, input_dim, hidden_dim=256, num_coupling_layers=6):
        super().__init__()
        self.input_dim = input_dim
        self.num_coupling_layers = num_coupling_layers

        self.dim1 = input_dim // 2
        self.dim2 = input_dim - self.dim1

        # 每个耦合层有两个网络:s(x) 和 t(x)
        self.s_nets = nn.ModuleList()
        self.t_nets = nn.ModuleList()

        for _ in range(num_coupling_layers):
            self.s_nets.append(self._make_net(hidden_dim, self.dim1, self.dim2))
            self.t_nets.append(self._make_net(hidden_dim, self.dim1, self.dim2))

    def _make_net(self, hidden_dim, in_dim, out_dim):
        """创建耦合网络"""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        """x → z"""
        log_det_jacobian = 0

        for i in range(self.num_coupling_layers):
            x1, x2 = x[:, :self.dim1], x[:, self.dim1:]

            if i % 2 == 0:
                s = self.s_nets[i](x1)
                t = self.t_nets[i](x1)
                # 仿射变换: x2 = exp(s) * x2 + t
                x2 = torch.exp(s) * x2 + t
                log_det_jacobian += torch.sum(s, dim=1)
            else:
                s = self.s_nets[i](x2)
                t = self.t_nets[i](x2)
                x1 = torch.exp(s) * x1 + t
                log_det_jacobian += torch.sum(s, dim=1)

            x = torch.cat([x1, x2], dim=1)

        return x, log_det_jacobian

    def inverse(self, z):
        """z → x"""
        x = z

        for i in range(self.num_coupling_layers - 1, -1, -1):
            x1, x2 = x[:, :self.dim1], x[:, self.dim1:]

            if i % 2 == 0:
                s = self.s_nets[i](x1)
                t = self.t_nets[i](x1)
                x2 = (x2 - t) * torch.exp(-s)
            else:
                s = self.s_nets[i](x2)
                t = self.t_nets[i](x2)
                x1 = (x1 - t) * torch.exp(-s)

            x = torch.cat([x1, x2], dim=1)

        return x

Glow

Python
import numpy as np
import torch.nn.functional as F

class ActNorm(nn.Module):
    """
    激活归一化层
    初始化时使每个通道的均值为0,方差为1
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.log_scale = nn.Parameter(torch.zeros(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.initialized = False

    def forward(self, x):
        if not self.initialized:
            # 数据依赖初始化
            with torch.no_grad():
                mean = torch.mean(x, dim=[0, 2, 3], keepdim=True)
                std = torch.std(x, dim=[0, 2, 3], keepdim=True)
                self.bias.data = -mean.squeeze()  # squeeze去除大小为1的维度
                self.log_scale.data = -torch.log(std.squeeze())
                self.initialized = True

        x = x + self.bias.view(1, -1, 1, 1)
        x = x * torch.exp(self.log_scale).view(1, -1, 1, 1)
        log_det = torch.sum(self.log_scale) * x.size(2) * x.size(3)
        return x, log_det

    def inverse(self, x):
        x = x * torch.exp(-self.log_scale).view(1, -1, 1, 1)
        x = x - self.bias.view(1, -1, 1, 1)
        return x

class Invertible1x1Conv(nn.Module):
    """
    可逆1x1卷积
    使用LU分解实现高效求逆
    """

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

        # 初始化随机正交矩阵
        w_init = np.linalg.qr(np.random.randn(num_channels, num_channels))[0]
        w_init = torch.from_numpy(w_init).float()

        # LU分解参数化: W = P @ L @ (U_strict + diag(s))
        # 使用scipy进行LU分解以正确初始化
        from scipy.linalg import lu as scipy_lu
        P_np, L_np, U_np = scipy_lu(w_init.numpy())
        s_np = np.diag(U_np)

        self.register_buffer('P', torch.from_numpy(P_np).float())
        self.L = nn.Parameter(torch.from_numpy(L_np).float())
        self.U = nn.Parameter(torch.from_numpy(np.triu(U_np, 1)).float())
        self.s = nn.Parameter(torch.from_numpy(s_np).float())

    def get_w(self):
        """获取权重矩阵"""
        L = self.L.tril(-1) + torch.eye(self.num_channels, device=self.L.device)
        U = self.U.triu(1) + torch.diag(self.s)
        W = self.P @ L @ U
        # unsqueeze(-1)两次将2D权重矩阵(C,C)扩展为(C,C,1,1)的4D张量,匹配F.conv2d要求的(out,in,kH,kW)卷积核格式
        return W.unsqueeze(-1).unsqueeze(-1)

    def forward(self, x):
        W = self.get_w()
        x = F.conv2d(x, W)
        log_det = torch.sum(torch.log(torch.abs(self.s))) * x.size(2) * x.size(3)
        return x, log_det

    def inverse(self, x):
        W = self.get_w()
        W_inv = torch.inverse(W.squeeze())
        x = F.conv2d(x, W_inv.unsqueeze(-1).unsqueeze(-1))
        return x

class GlowStep(nn.Module):
    """Glow的一个步骤"""

    def __init__(self, num_channels):
        super().__init__()
        self.actnorm = ActNorm(num_channels)
        self.conv = Invertible1x1Conv(num_channels)
        # 耦合层(使用上面RealNVP中定义的仿射耦合层)
        # AffineCoupling可基于RealNVP的仿射变换实现
        self.coupling = AffineCoupling(num_channels)  # 需根据RealNVP风格自行实现

    def forward(self, x):
        log_det = 0

        x, ld = self.actnorm(x)
        log_det += ld

        x, ld = self.conv(x)
        log_det += ld

        x, ld = self.coupling(x)
        log_det += ld

        return x, log_det

    def inverse(self, x):
        x = self.coupling.inverse(x)
        x = self.conv.inverse(x)
        x = self.actnorm.inverse(x)
        return x

⚡ 一致性模型 (Consistency Models)

核心思想

一致性模型是扩散模型的加速版本,通过直接学习从噪声到数据的映射,实现单步生成。

核心概念: - 一致性函数 \(f(x_t, t)\):将任意时间步 \(t\) 的噪声样本映射到时间步0的数据 - 一致性性质:\(f(x_t, t) = f(x_{t'}, t')\) 对于所有 \(t, t'\) 应该映射到同一个数据点

与扩散模型的对比

特性 扩散模型 一致性模型
采样步数 50-1000步 1-4步
生成速度 极快
训练方式 噪声预测 一致性蒸馏
图像质量 极高

数学原理

一致性映射

\[f(x_t, t) = x_0\]

训练目标

对于相邻时间步 \((t, t')\),要求一致性:

\[\mathcal{L} = \mathbb{E}_{x_0, t, t'} \left[ \lambda(t) \cdot d(f(x_t, t), f(x_{t'}, t')) \right]\]

其中 \(d(\cdot, \cdot)\) 是距离度量(如L2距离)。

代码实现

Python
class ConsistencyModel(nn.Module):
    """
    一致性模型
    学习从任意噪声水平直接映射到干净数据
    """

    def __init__(self, input_dim, num_timesteps=1000, sigma_min=0.002, sigma_max=80.0):
        super().__init__()
        self.input_dim = input_dim
        self.num_timesteps = num_timesteps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

        # U-Net风格的网络
        self.network = UNetConsistency(input_dim)

        # 时间步嵌入
        self.time_embed_dim = 256
        self.time_mlp = nn.Sequential(
            nn.Linear(1, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim)
        )

    def get_sigma_schedule(self, t):
        """
        噪声调度(EDM风格)
        """
        # 对数线性调度
        sigma = self.sigma_max ** (t / self.num_timesteps) * self.sigma_min ** (1 - t / self.num_timesteps)
        return sigma

    def forward(self, x_t, t):
        """
        一致性函数: (x_t, t) → x_0

        Args:
            x_t: 噪声样本
            t: 时间步 (可以是连续的)
        """
        # 时间嵌入
        t_normalized = t.float() / self.num_timesteps
        t_emb = self.time_mlp(t_normalized.view(-1, 1))

        # 预测干净数据
        x_0_pred = self.network(x_t, t_emb)

        return x_0_pred

    def consistency_loss(self, x_0):
        """
        一致性训练损失

        要求 f(x_t, t) ≈ f(x_{t-1}, t-1)
        """
        batch_size = x_0.size(0)

        # 随机采样时间步
        t = torch.randint(1, self.num_timesteps, (batch_size,), device=x_0.device)
        t_prev = t - 1

        # 获取噪声水平
        sigma_t = self.get_sigma_schedule(t).view(-1, 1, 1, 1)
        sigma_t_prev = self.get_sigma_schedule(t_prev).view(-1, 1, 1, 1)

        # 添加噪声
        noise = torch.randn_like(x_0)
        x_t = x_0 + sigma_t * noise
        x_t_prev = x_0 + sigma_t_prev * noise

        # 一致性预测
        f_t = self.forward(x_t, t.float())
        f_t_prev = self.forward(x_t_prev, t_prev.float())

        # 一致性损失
        loss = F.mse_loss(f_t, f_t_prev)

        return loss

    @torch.no_grad()
    def sample(self, batch_size, device='cpu', num_steps=1):
        """
        采样生成

        Args:
            num_steps: 采样步数(1为单步生成)
        """
        # 从先验采样(最大噪声)
        x = torch.randn(batch_size, *self.input_dim, device=device) * self.sigma_max

        if num_steps == 1:
            # 单步生成
            x_0 = self.forward(x, torch.tensor([self.num_timesteps] * batch_size, device=device).float())
            return x_0
        else:
            # 多步生成(提高质量)
            timesteps = torch.linspace(self.num_timesteps, 0, num_steps + 1, device=device).long()

            for i in range(num_steps):
                t = torch.full((batch_size,), timesteps[i], device=device).float()
                x_0_pred = self.forward(x, t)

                if i < num_steps - 1:
                    # 添加少量噪声进行下一步
                    sigma_next = self.get_sigma_schedule(timesteps[i + 1])
                    noise = torch.randn_like(x)
                    x = x_0_pred + sigma_next * noise
                else:
                    x = x_0_pred

            return x

class UNetConsistency(nn.Module):
    """用于一致性模型的简化U-Net"""

    def __init__(self, input_dim, base_channels=128):
        super().__init__()
        in_channels, height, width = input_dim

        # 编码器
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.GroupNorm(8, base_channels),
            nn.SiLU()
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
            nn.GroupNorm(8, base_channels * 2),
            nn.SiLU()
        )

        # 中间层
        self.middle = nn.Sequential(
            nn.Conv2d(base_channels * 2, base_channels * 2, 3, padding=1),
            nn.GroupNorm(8, base_channels * 2),
            nn.SiLU()
        )

        # 解码器
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(base_channels * 2, base_channels, 4, stride=2, padding=1),
            nn.GroupNorm(8, base_channels),
            nn.SiLU()
        )
        self.dec1 = nn.Conv2d(base_channels, in_channels, 3, padding=1)

    def forward(self, x, t_emb):
        # 编码
        h1 = self.enc1(x)
        h2 = self.enc2(h1)

        # 中间(可以加入时间嵌入)
        h = self.middle(h2)

        # 解码
        h = self.dec2(h)
        out = self.dec1(h)

        return out

def train_consistency_model():
    """训练一致性模型"""
    print("\n" + "=" * 60)
    print("一致性模型训练")
    print("=" * 60)

    # 配置
    input_dim = (3, 32, 32)  # CIFAR-10
    model = ConsistencyModel(input_dim, num_timesteps=1000)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # 模拟训练
    print("\n训练配置:")
    print(f"  输入维度: {input_dim}")
    print(f"  时间步数: {model.num_timesteps}")
    print(f"  噪声范围: [{model.sigma_min}, {model.sigma_max}]")

    # 模拟一个训练步骤
    x_0 = torch.randn(4, *input_dim)
    loss = model.consistency_loss(x_0)
    print(f"\n示例损失: {loss.item():.4f}")

    # 采样
    with torch.no_grad():
        samples_1step = model.sample(4, num_steps=1)
        samples_4step = model.sample(4, num_steps=4)

    print(f"\n单步生成样本形状: {samples_1step.shape}")
    print(f"四步生成样本形状: {samples_4step.shape}")

# 运行示例
if __name__ == "__main__":
    train_consistency_model()

一致性蒸馏 (Consistency Distillation)

Python
class ConsistencyDistillation:
    """
    一致性蒸馏:从预训练扩散模型蒸馏一致性模型
    """

    def __init__(self, teacher_diffusion_model, student_consistency_model):
        self.teacher = teacher_diffusion_model
        self.student = student_consistency_model

        # 冻结教师模型
        for param in self.teacher.parameters():
            param.requires_grad = False

    def distillation_loss(self, x_0):
        """
        蒸馏损失

        学生模型预测应该与教师模型的PF-ODE解一致
        """
        batch_size = x_0.size(0)
        device = x_0.device

        # 随机采样时间步
        t = torch.randint(1, self.student.num_timesteps, (batch_size,), device=device)

        # 获取噪声水平
        sigma_t = self.student.get_sigma_schedule(t).view(-1, 1, 1, 1)

        # 添加噪声
        noise = torch.randn_like(x_0)
        x_t = x_0 + sigma_t * noise

        # 学生预测: f(x_t, t)
        student_pred = self.student(x_t, t.float())

        # 教师预测: 使用教师模型去噪一步
        with torch.no_grad():
            # 教师模型预测噪声
            predicted_noise = self.teacher(x_t, t)
            # 去噪一步
            alpha_t = self.teacher.alphas[t].view(-1, 1, 1, 1)
            alpha_cumprod = self.teacher.alphas_cumprod[t].view(-1, 1, 1, 1)
            beta_t = self.teacher.betas[t].view(-1, 1, 1, 1)

            teacher_pred = (x_t - beta_t / torch.sqrt(1 - alpha_cumprod) * predicted_noise) / torch.sqrt(alpha_t)

        # 蒸馏损失
        loss = F.mse_loss(student_pred, teacher_pred)

        return loss

🎬 视频生成模型

视频扩散模型

Python
class VideoDiffusionModel(nn.Module):
    """
    视频扩散模型
    扩展图像扩散模型到时空维度
    """

    def __init__(self, num_frames=16, height=64, width=64, channels=3):
        super().__init__()
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.channels = channels

        # 3D U-Net (时空卷积)
        self.unet = SpatioTemporalUNet(
            in_channels=channels,
            num_frames=num_frames
        )

    def forward(self, video, t):
        """
        预测视频噪声

        Args:
            video: (B, C, T, H, W)
            t: 时间步
        """
        return self.unet(video, t)

class SpatioTemporalUNet(nn.Module):
    """时空U-Net"""

    def __init__(self, in_channels, num_frames, base_channels=64):
        super().__init__()

        # 空间卷积
        self.spatial_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # 时间注意力
        self.temporal_attn = nn.MultiheadAttention(base_channels, num_heads=8)

        # 时空卷积块
        self.st_blocks = nn.ModuleList([
            SpatioTemporalBlock(base_channels, base_channels * 2),
            SpatioTemporalBlock(base_channels * 2, base_channels * 4),
        ])

    def forward(self, x, t):
        # x: (B, C, T, H, W)
        B, C, T, H, W = x.shape

        # 处理每一帧
        x = x.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)
        x = x.reshape(B * T, C, H, W)  # reshape重塑张量形状

        # 空间特征
        x = self.spatial_conv(x)  # (B*T, base_channels, H, W)

        # 重塑为时间序列
        x = x.view(B, T, -1, H * W).permute(0, 3, 1, 2)  # (B, H*W, T, C)
        x = x.reshape(B * H * W, T, -1)

        # 时间注意力
        x, _ = self.temporal_attn(x, x, x)

        # 恢复形状
        x = x.view(B, H, W, T, -1).permute(0, 4, 3, 1, 2)

        return x

class SpatioTemporalBlock(nn.Module):
    """时空卷积块"""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # 空间卷积
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.SiLU()
        )

        # 时间卷积 (1D卷积沿时间轴)
        self.temporal_conv = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.SiLU()
        )

    def forward(self, x):
        # x: (B, C, T, H, W)
        B, C, T, H, W = x.shape

        # 空间卷积
        x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        x = self.spatial_conv(x)
        C_new = x.size(1)
        x = x.view(B, T, C_new, H, W).permute(0, 2, 1, 3, 4)

        # 时间卷积
        x = x.permute(0, 3, 4, 1, 2).reshape(B * H * W, C_new, T)
        x = self.temporal_conv(x)
        x = x.view(B, H, W, C_new, T).permute(0, 3, 4, 1, 2)

        return x

🧬 分子生成

Python
class MoleculeGenerator(nn.Module):
    """
    分子生成模型
    用于药物发现
    """

    def __init__(self, atom_types=10, max_atoms=50):
        super().__init__()
        self.atom_types = atom_types
        self.max_atoms = max_atoms

        # 图神经网络
        self.gnn = GraphNeuralNetwork(atom_types, hidden_dim=256)

        # 生成器
        self.atom_generator = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, atom_types)
        )

        self.bond_generator = nn.Sequential(
            nn.Linear(256 * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 4)  # 无键、单键、双键、三键
        )

    def generate_molecule(self, num_atoms=None):
        """
        生成新分子

        Returns:
            atoms: 原子类型列表
            bonds: 邻接矩阵
        """
        if num_atoms is None:
            num_atoms = torch.randint(10, self.max_atoms, (1,)).item()

        # 初始化随机原子
        atoms = torch.randint(0, self.atom_types, (num_atoms,))
        positions = torch.randn(num_atoms, 3)  # 3D坐标

        # 使用GNN优化结构
        for _ in range(10):
            atom_features = self.gnn(atoms, positions)

            # 更新原子类型
            atom_logits = self.atom_generator(atom_features)
            atoms = torch.argmax(atom_logits, dim=-1)

        # 生成分子图
        bonds = self._generate_bonds(atom_features)

        return atoms, bonds

    def _generate_bonds(self, atom_features):
        """生成化学键"""
        num_atoms = atom_features.size(0)
        bonds = torch.zeros(num_atoms, num_atoms, dtype=torch.long)

        for i in range(num_atoms):
            for j in range(i + 1, num_atoms):
                # 预测键类型
                pair_features = torch.cat([atom_features[i], atom_features[j]])
                bond_logits = self.bond_generator(pair_features)
                bond_type = torch.argmax(bond_logits).item()
                bonds[i, j] = bond_type
                bonds[j, i] = bond_type

        return bonds

class GraphNeuralNetwork(nn.Module):
    """用于分子的图神经网络"""

    def __init__(self, atom_types, hidden_dim):
        super().__init__()
        self.atom_embedding = nn.Embedding(atom_types, hidden_dim)

        # 消息传递层
        self.message_layers = nn.ModuleList([
            MessagePassingLayer(hidden_dim) for _ in range(3)
        ])

    def forward(self, atoms, positions):
        # 原子嵌入
        x = self.atom_embedding(atoms)

        # 基于距离构建图
        distances = torch.cdist(positions, positions)
        adjacency = (distances < 2.0).float()  # 2埃阈值

        # 消息传递
        for layer in self.message_layers:
            x = layer(x, adjacency)

        return x

class MessagePassingLayer(nn.Module):
    """消息传递层"""

    def __init__(self, hidden_dim):
        super().__init__()
        self.message_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.update_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x, adjacency):
        # 聚合邻居消息
        num_nodes = x.size(0)
        messages = []

        for i in range(num_nodes):
            neighbors = adjacency[i].nonzero(as_tuple=True)[0]
            if len(neighbors) > 0:
                neighbor_features = x[neighbors]
                message = self.message_net(
                    torch.cat([x[i].unsqueeze(0).expand(len(neighbors), -1), neighbor_features], dim=-1)
                ).mean(dim=0)
            else:
                message = torch.zeros_like(x[i])
            messages.append(message)

        messages = torch.stack(messages)

        # 更新节点特征
        x_new = self.update_net(torch.cat([x, messages], dim=-1))
        return x_new

📊 生成模型评估指标

Python
class GenerationMetrics:
    """
    生成模型评估指标
    """

    @staticmethod  # @staticmethod静态方法,无需实例即可调用
    def inception_score(images, inception_model, splits=10):
        """
        Inception Score (IS)

        评估生成图像质量和多样性
        """
        from torchvision.models import inception_v3

        preds = []
        for img in images:
            pred = inception_model(img.unsqueeze(0))
            preds.append(F.softmax(pred, dim=1).cpu().numpy())

        preds = np.concatenate(preds, axis=0)

        # 计算KL散度
        scores = []
        for i in range(splits):
            part = preds[i * len(preds) // splits: (i + 1) * len(preds) // splits]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, axis=0), 0)))
            kl = np.mean(np.sum(kl, axis=1))
            scores.append(np.exp(kl))

        return np.mean(scores), np.std(scores)

    @staticmethod
    def frechet_inception_distance(real_images, fake_images, inception_model):
        """
        Fréchet Inception Distance (FID)

        衡量生成图像与真实图像分布的差异
        越低越好
        """
        from scipy import linalg

        def get_activations(images):
            activations = []
            for img in images:
                with torch.no_grad():
                    act = inception_model(img.unsqueeze(0))
                activations.append(act.cpu().numpy())
            return np.concatenate(activations, axis=0)

        act_real = get_activations(real_images)
        act_fake = get_activations(fake_images)

        # 计算均值和协方差
        mu_real, sigma_real = np.mean(act_real, axis=0), np.cov(act_real, rowvar=False)
        mu_fake, sigma_fake = np.mean(act_fake, axis=0), np.cov(act_fake, rowvar=False)

        # 计算FID
        diff = mu_real - mu_fake
        covmean, _ = linalg.sqrtm(sigma_real @ sigma_fake, disp=False)

        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)

        return fid

    @staticmethod
    def precision_recall(real_features, fake_features, k=3):
        """
        Precision and Recall

        Precision: 生成样本中有多少是真实的
        Recall: 真实样本中有多少被生成
        """
        from sklearn.neighbors import NearestNeighbors

        # 计算Precision
        nn_real = NearestNeighbors(n_neighbors=k).fit(real_features)
        distances_fake_to_real, _ = nn_real.kneighbors(fake_features)
        precision = np.mean(distances_fake_to_real < np.median(distances_fake_to_real))

        # 计算Recall
        nn_fake = NearestNeighbors(n_neighbors=k).fit(fake_features)
        distances_real_to_fake, _ = nn_fake.kneighbors(real_features)
        recall = np.mean(distances_real_to_fake < np.median(distances_real_to_fake))

        return precision, recall

def evaluate_generation_example():
    """生成模型评估示例"""
    print("\n" + "=" * 60)
    print("生成模型评估指标")
    print("=" * 60)

    metrics = GenerationMetrics()

    # 模拟数据
    print("\n1. Inception Score (IS)")
    print("-" * 40)
    print("  IS衡量: 生成图像的质量和多样性")
    print("  高分表示: 高质量且多样化")
    print("  典型值: 真实图像IS≈233 (ImageNet)")

    print("\n2. Fréchet Inception Distance (FID)")
    print("-" * 40)
    print("  FID衡量: 生成分布与真实分布的距离")
    print("  越低越好")
    print("  优秀: FID < 10")
    print("  良好: FID < 50")

    print("\n3. Precision and Recall")
    print("-" * 40)
    print("  Precision: 生成样本的真实性")
    print("  Recall: 覆盖真实分布的程度")
    print("  理想: Precision ≈ 1, Recall ≈ 1")

# 运行示例
if __name__ == "__main__":
    evaluate_generation_example()

💡 总结

Text Only
生成模型演进:
VAE (2014) → GAN (2014) → Flow (2014) →
PixelCNN (2016) → VQ-VAE (2017) →
DDPM (2020) → Stable Diffusion (2022) →
Consistency Models (2023) → Video Diffusion (2023-2024)

模型对比:
┌─────────────┬──────────┬──────────┬──────────┬──────────┐
│    模型     │ 训练稳定 │ 生成质量 │ 采样速度 │ 精确似然 │
├─────────────┼──────────┼──────────┼──────────┼──────────┤
│    VAE      │    ✓     │    △     │    ✓     │    ✓     │
│    GAN      │    ✗     │    ✓     │    ✓     │    ✗     │
│    Flow     │    ✓     │    △     │    ✓     │    ✓     │
│  扩散模型   │    ✓     │    ✓✓    │    ✗     │    ✗     │
│ 一致性模型  │    ✓     │    ✓     │    ✓✓    │    ✗     │
└─────────────┴──────────┴──────────┴──────────┴──────────┘

当前主流:
- 图像:Stable Diffusion, DALL-E 3, Midjourney
- 视频:Sora, Runway Gen-2, Pika
- 文本:GPT-4, Claude, LLaMA
- 分子:AlphaFold, DiffDock

实践建议:
1. 从VAE入门理解生成模型原理
2. 掌握GAN训练技巧(难但重要)
3. 学习扩散模型是当前重点
4. 关注一致性模型(未来方向)
5. 使用预训练模型快速应用
6. 了解评估指标选择合适模型

下一步:学习 20-自监督学习.md,掌握无标签数据学习方法!