跳转至

02 - 文本到图像生成

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习时间: 4.5小时 重要性: ⭐⭐⭐⭐⭐ 现代AI应用的核心技术


🎯 项目目标

完成本项目后,你将能够: - 理解文本到图像生成的原理 - 使用CLIP编码文本 - 实现文本条件扩散模型 - 训练模型生成高质量图像 - 实现无分类器引导采样


1. 项目概述

1.1 项目简介

本项目将实现一个文本到图像生成模型,能够根据文本描述生成对应的图像。

技术栈: - PyTorch: 深度学习框架 - CLIP: 文本和图像编码器 - UNet: 图像生成模型 - CIFAR-10: 训练数据集(带文本描述)

1.2 文本到图像生成流程

Text Only
文本描述
CLIP文本编码器
文本嵌入
文本条件UNet
扩散采样
生成图像

2. 环境准备

2.1 安装依赖

Bash
# 安装transformers(包含CLIP)
pip install transformers

# 安装其他依赖
pip install torch torchvision numpy matplotlib tqdm tensorboard

2.2 准备数据集

由于CIFAR-10没有文本描述,我们需要创建一个带文本描述的数据集:

Python
# prepare_text_dataset.py
import torch
from torchvision import datasets, transforms
import json
import os

# CIFAR-10类别的中文描述
class_descriptions = {
    0: "飞机 - 一架在空中飞行的飞机",
    1: "汽车 - 一辆在道路上行驶的汽车",
    2: "鸟 - 一只在天空中飞翔的鸟",
    3: "猫 - 一只可爱的猫咪",
    4: "鹿 - 一只在森林中的鹿",
    5: "狗 - 一只奔跑的狗",
    6: "青蛙 - 一只绿色的青蛙",
    7: "马 - 一匹在草原上奔跑的马",
    8: "船 - 一艘在海上航行的船",
    9: "卡车 - 一辆大型卡车"
}

def create_text_dataset(data_dir='./data'):
    """创建带文本描述的数据集"""

    # 加载CIFAR-10
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    train_dataset = datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=True,
        transform=transform
    )

    # 创建文本数据
    text_data = []
    for i, (img, label) in enumerate(train_dataset):  # enumerate同时获取索引和元素
        text_data.append({
            'index': i,
            'label': label,
            'text': class_descriptions[label]
        })

    # 保存文本数据
    os.makedirs(data_dir, exist_ok=True)
    with open(os.path.join(data_dir, 'cifar10_text.json'), 'w', encoding='utf-8') as f:  # with自动管理文件关闭
        json.dump(text_data, f, ensure_ascii=False, indent=2)

    print(f"创建文本数据完成,共 {len(text_data)} 条")
    print(f"保存到: {os.path.join(data_dir, 'cifar10_text.json')}")

if __name__ == '__main__':
    create_text_dataset()

运行脚本:

Bash
python prepare_text_dataset.py


3. 文本编码器

Python
# text_encoder.py
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer

class TextEncoder(nn.Module):  # 继承nn.Module定义网络层
    """CLIP文本编码器"""

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        super().__init__()  # super()调用父类方法

        # 加载CLIP模型
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_model = CLIPTextModel.from_pretrained(model_name)

        # 冻结CLIP参数
        for param in self.text_model.parameters():
            param.requires_grad = False

    def forward(self, text_prompts):
        """
        编码文本提示

        参数:
            text_prompts: 文本提示列表

        返回:
            text_embeddings: [batch_size, seq_len, embedding_dim]
        """
        # Tokenize
        inputs = self.tokenizer(
            text_prompts,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt"
        )

        # 编码
        outputs = self.text_model(**inputs)
        text_embeddings = outputs.last_hidden_state

        return text_embeddings

# 测试
if __name__ == '__main__':
    encoder = TextEncoder()

    prompts = [
        "一只可爱的猫",
        "一辆在道路上行驶的汽车",
        "一只在天空中飞翔的鸟"
    ]

    embeddings = encoder(prompts)
    print(f"文本嵌入形状: {embeddings.shape}")
    print(f"嵌入维度: {embeddings.shape[-1]}")  # [-1]负索引取最后元素

4. 文本条件UNet

Python
# text_unet.py
import torch
import torch.nn as nn
import math

class SinusoidalPositionEmbedding(nn.Module):
    """正弦位置编码"""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # torch.cat沿已有维度拼接张量
        return emb

class CrossAttention(nn.Module):
    """交叉注意力模块"""

    def __init__(self, query_dim, context_dim, heads=4, dim_head=64):
        super().__init__()

        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        # Q, K, V投影
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        # 输出投影
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context):
        """
        参数:
            x: [batch_size, height*width, query_dim]
            context: [batch_size, seq_len, context_dim]

        返回:
            输出: [batch_size, height*width, query_dim]
        """
        batch_size, seq_len, _ = context.shape
        h, w, _ = x.shape[1], x.shape[2], x.shape[3]
        x = x.view(batch_size, -1, x.shape[-1])  # 重塑张量形状

        # 计算Q, K, V
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        # 重塑为多头
        q = q.view(batch_size, -1, self.heads, -1).transpose(1, 2)
        k = k.view(batch_size, -1, self.heads, -1).transpose(1, 2)
        v = v.view(batch_size, -1, self.heads, -1).transpose(1, 2)

        # 计算注意力
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = torch.softmax(attn, dim=-1)

        # 应用注意力
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch_size, -1, out.shape[-1])

        # 输出投影
        out = self.to_out(out)

        return out.view(batch_size, h, w, -1)

class TextConditionedUNet(nn.Module):
    """文本条件UNet"""

    def __init__(self, in_channels=3, out_channels=3, model_dim=128,
                 text_embedding_dim=768):
        super().__init__()

        # 时间步嵌入
        self.time_embedding = SinusoidalPositionEmbedding(model_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(model_dim, model_dim * 4),
            nn.SiLU(),
            nn.Linear(model_dim * 4, model_dim)
        )

        # 文本嵌入投影
        self.text_proj = nn.Linear(text_embedding_dim, model_dim)

        # 初始卷积
        self.conv_in = nn.Conv2d(in_channels, model_dim, 3, padding=1)

        # 下采样
        self.down_blocks = nn.ModuleList([
            self._make_down_block(model_dim, model_dim * 2, model_dim),
            self._make_down_block(model_dim * 2, model_dim * 4, model_dim),
        ])

        # 中间层
        self.mid_block1 = self._make_mid_block(model_dim * 4, model_dim)
        self.mid_block2 = self._make_mid_block(model_dim * 4, model_dim)

        # 上采样
        self.up_blocks = nn.ModuleList([
            self._make_up_block(model_dim * 4, model_dim * 2, model_dim),
            self._make_up_block(model_dim * 2, model_dim, model_dim),
        ])

        # 输出卷积
        self.conv_out = nn.Conv2d(model_dim, out_channels, 3, padding=1)

    def _make_down_block(self, in_channels, out_channels, emb_dim):
        """创建下采样块"""
        return nn.ModuleList([
            ResidualBlock(in_channels, out_channels, emb_dim),
            ResidualBlock(out_channels, out_channels, emb_dim),
            nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
        ])

    def _make_up_block(self, in_channels, out_channels, emb_dim):
        """创建上采样块"""
        return nn.ModuleList([
            nn.ConvTranspose2d(in_channels, in_channels, 4, stride=2, padding=1),
            ResidualBlock(in_channels, out_channels, emb_dim),
            ResidualBlock(out_channels, out_channels, emb_dim)
        ])

    def _make_mid_block(self, channels, emb_dim):
        """创建中间块"""
        return nn.ModuleList([
            ResidualBlock(channels, channels, emb_dim),
            ResidualBlock(channels, channels, emb_dim)
        ])

    def forward(self, x, t, text_embeddings):
        """
        前向传播

        参数:
            x: [batch_size, in_channels, height, width]
            t: [batch_size] 时间步
            text_embeddings: [batch_size, seq_len, text_dim]

        返回:
            输出: [batch_size, out_channels, height, width]
        """
        batch_size = x.shape[0]

        # 时间步嵌入
        t_emb = self.time_embedding(t)
        t_emb = self.time_mlp(t_emb)

        # 文本嵌入投影
        text_emb = self.text_proj(text_embeddings)

        # 合并嵌入
        emb = t_emb + text_emb.mean(dim=1)

        # 初始卷积
        h = self.conv_in(x)

        # 下采样
        skips = []
        for down_block in self.down_blocks:
            for layer in down_block:
                if isinstance(layer, ResidualBlock):  # isinstance检查类型
                    h = layer(h, emb)
                else:
                    h = layer(h)
            skips.append(h)

        # 中间层
        for layer in self.mid_block1:
            h = layer(h, emb)
        for layer in self.mid_block2:
            h = layer(h, emb)

        # 上采样
        for i, up_block in enumerate(self.up_blocks):
            for layer in up_block:
                if isinstance(layer, nn.ConvTranspose2d):
                    h = layer(h)
                    h = h + skips[-(i+1)]
                else:
                    h = layer(h, emb)

        # 输出
        h = self.conv_out(h)
        return h

class ResidualBlock(nn.Module):
    """残差块"""

    def __init__(self, in_channels, out_channels, emb_dim):
        super().__init__()

        self.norm1 = nn.GroupNorm(8, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # 嵌入投影
        self.emb_proj = nn.Linear(emb_dim, out_channels)

        # 残差连接
        if in_channels != out_channels:
            self.skip_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip_conv = nn.Identity()

    def forward(self, x, emb):
        h = self.norm1(x)
        h = nn.SiLU()(h)
        h = self.conv1(h)

        # 添加嵌入
        emb = self.emb_proj(emb)
        h = h + emb[:, :, None, None]

        h = self.norm2(h)
        h = nn.SiLU()(h)
        h = self.conv2(h)

        return h + self.skip_conv(x)

5. 训练脚本

Python
# train_text2image.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import json
import os
from tqdm import tqdm

from text_encoder import TextEncoder
from text_unet import TextConditionedUNet

class TextCIFAR10Dataset(Dataset):
    """带文本描述的CIFAR-10数据集"""

    def __init__(self, data_dir='./data', split='train', transform=None):
        """
        参数:
            data_dir: 数据目录
            split: 'train' 或 'test'
            transform: 数据变换
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        # 加载CIFAR-10
        self.cifar_dataset = datasets.CIFAR10(
            root=data_dir,
            train=(split == 'train'),
            download=True,
            transform=None
        )

        # 加载文本描述
        text_file = os.path.join(data_dir, 'cifar10_text.json')
        with open(text_file, 'r', encoding='utf-8') as f:
            self.text_data = json.load(f)

    def __len__(self):  # __len__定义len()行为
        return len(self.cifar_dataset)

    def __getitem__(self, idx):  # __getitem__定义索引访问行为
        # 获取图像
        img, label = self.cifar_dataset[idx]

        # 应用变换
        if self.transform:
            img = self.transform(img)

        # 获取文本
        text = self.text_data[idx]['text']

        return img, text

def get_schedule(T, beta_start, beta_end):
    """创建噪声调度表"""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, betas, alphas_cumprod

def train():
    """训练函数"""
    # 配置
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    T = 1000
    beta_start = 0.0001
    beta_end = 0.02
    batch_size = 64
    num_epochs = 50
    learning_rate = 1e-4

    print(f"使用设备: {device}")

    # 创建数据加载器
    print("加载数据...")
    transform = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    train_dataset = TextCIFAR10Dataset(
        data_dir='./data',
        split='train',
        transform=transform
    )

    train_loader = DataLoader(  # DataLoader批量加载数据
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    print(f"训练集大小: {len(train_dataset)}")

    # 创建模型
    print("创建模型...")
    text_encoder = TextEncoder().to(device)  # 移至GPU/CPU
    model = TextConditionedUNet(
        in_channels=3,
        out_channels=3,
        model_dim=128,
        text_embedding_dim=768
    ).to(device)

    print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

    # 创建优化器
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

    # 创建噪声调度
    alphas, betas, alphas_cumprod = get_schedule(T, beta_start, beta_end)
    alphas_cumprod = alphas_cumprod.to(device)

    # 训练循环
    print(f"\n开始训练,共 {num_epochs} 轮")

    for epoch in range(num_epochs):
        model.train()  # train()训练模式
        epoch_loss = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for x_0, texts in pbar:
            x_0 = x_0.to(device)

            # 编码文本
            text_embeddings = text_encoder(texts).to(device)

            # 随机采样时间步
            batch_size = x_0.shape[0]
            t = torch.randint(0, T, (batch_size,), device=device)

            # 生成噪声
            noise = torch.randn_like(x_0)

            # 计算加噪后的图像
            sqrt_alpha_t_bar = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
            x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise

            # 模型预测噪声
            predicted_noise = model(x_t, t, text_embeddings)

            # 计算损失
            loss = nn.functional.mse_loss(predicted_noise, noise)

            # 反向传播
            optimizer.zero_grad()  # 清零梯度
            loss.backward()  # 反向传播计算梯度
            optimizer.step()  # 更新参数

            epoch_loss += loss.item()  # 将单元素张量转为Python数值
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        epoch_loss /= len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

        # 定期保存模型
        if (epoch + 1) % 10 == 0:
            os.makedirs('./checkpoints', exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }, f'./checkpoints/text2image_epoch_{epoch+1}.pth')
            print(f"  ✓ 保存检查点")

    print("\n训练完成!")

if __name__ == '__main__':
    train()

6. 采样脚本

Python
# sample_text2image.py
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os

from text_encoder import TextEncoder
from text_unet import TextConditionedUNet

def get_schedule(T, beta_start, beta_end):
    """创建噪声调度表"""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, betas, alphas_cumprod

def sample(model, text_encoder, text_prompts, T, alphas, betas, alphas_cumprod, device):
    """
    文本到图像采样

    参数:
        model: 文本条件扩散模型
        text_encoder: 文本编码器
        text_prompts: 文本提示列表
        T: 总步数
        alphas, betas, alphas_cumprod: 调度表
        device: 设备

    返回:
        生成的图像
    """
    model.eval()

    # 编码文本
    text_embeddings = text_encoder(text_prompts).to(device)

    # 创建初始噪声
    batch_size = len(text_prompts)
    x_T = torch.randn(batch_size, 3, 32, 32).to(device)

    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

    with torch.no_grad():  # 禁用梯度计算,节省内存
        for t in reversed(range(T)):
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]
            alpha_t_bar_prev = alphas_cumprod_prev[t]

            # 预测噪声
            t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
            predicted_noise = model(x_T, t_tensor, text_embeddings)

            # 计算均值
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            mean = sqrt_recip_alpha_t * (
                x_T - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            # 添加噪声
            if t > 0:
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_T)
                x_T = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_T = mean

    return x_T

def main():
    """主函数"""
    # 配置
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    T = 1000
    beta_start = 0.0001
    beta_end = 0.02

    print(f"使用设备: {device}")

    # 创建模型
    print("创建模型...")
    text_encoder = TextEncoder().to(device)
    model = TextConditionedUNet(
        in_channels=3,
        out_channels=3,
        model_dim=128,
        text_embedding_dim=768
    ).to(device)

    # 加载检查点
    checkpoint_path = './checkpoints/text2image_epoch_50.pth'
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"加载检查点: {checkpoint_path}")
        print(f"  Epoch: {checkpoint['epoch']}")
        print(f"  Loss: {checkpoint['loss']:.4f}")
    else:
        print(f"未找到检查点: {checkpoint_path}")
        return

    # 创建噪声调度
    alphas, betas, alphas_cumprod = get_schedule(T, beta_start, beta_end)
    alphas_cumprod = alphas_cumprod.to(device)

    # 文本提示
    text_prompts = [
        "一只可爱的猫",
        "一辆在道路上行驶的汽车",
        "一只在天空中飞翔的鸟",
        "一艘在海上航行的船",
        "一只奔跑的狗",
        "一匹在草原上奔跑的马",
        "一架在空中飞行的飞机",
        "一只绿色的青蛙"
    ]

    # 生成图像
    print(f"\n生成 {len(text_prompts)} 个图像...")
    samples = sample(model, text_encoder, text_prompts, T, alphas, betas, alphas_cumprod, device)

    # 反归一化
    samples = (samples + 1) / 2
    samples = samples.clamp(0, 1)

    # 可视化
    grid = make_grid(samples, nrow=4, padding=2, normalize=False)

    plt.figure(figsize=(16, 8))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title('Text-to-Image Generation')
    plt.axis('off')

    save_path = './samples/text2image_samples.png'
    os.makedirs('./samples', exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"样本已保存到: {save_path}")
    plt.show()

if __name__ == '__main__':
    main()

7. 运行项目

7.1 准备数据

Bash
# 准备文本数据集
python prepare_text_dataset.py

7.2 训练模型

Bash
# 训练文本到图像模型
python train_text2image.py

7.3 生成图像

Bash
# 生成图像
python sample_text2image.py

8. 项目总结

8.1 完成的工作

  1. ✅ 创建了带文本描述的数据集
  2. ✅ 实现了CLIP文本编码器
  3. ✅ 实现了文本条件UNet
  4. ✅ 训练了文本到图像生成模型
  5. ✅ 实现了文本到图像采样

8.2 关键技术点

技术点 说明
CLIP编码 将文本转换为向量表示
文本条件 将文本嵌入添加到模型中
交叉注意力 让模型关注文本信息
无分类器引导 提高条件生成的质量
文本提示 使用自然语言控制生成

8.3 改进方向

  1. 使用更大的数据集:如LAION-5B
  2. 使用更大的模型:如Stable Diffusion
  3. 使用更好的文本编码器:如T5
  4. 实现无分类器引导:提高生成质量
  5. 使用DDIM加速:减少采样步数

9. 常见问题

Q1: 生成的图像不符合文本描述?

A: 可以尝试: - 增加训练轮数 - 使用更大的模型 - 改进文本描述 - 使用无分类器引导

Q2: 训练时间太长?

A: 可以尝试: - 减少batch size - 使用混合精度训练 - 使用更小的模型 - 减少训练轮数

Q3: 如何提高生成质量?

A: 可以尝试: - 使用更大的数据集 - 使用更大的模型 - 使用更好的文本编码器 - 使用无分类器引导


10. 下一步

完成本项目后,你可以:

  1. 尝试其他数据集:如LAION-5B
  2. 实现无分类器引导:提高生成质量
  3. 使用预训练模型:如Stable Diffusion
  4. 实现图像编辑:基于文本编辑图像
  5. 优化模型:使用DDIM、LDM等技术

下一章: 03-图像修复与编辑 - 实现图像修复和编辑功能