02 - 文本到图像生成¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习时间: 4.5小时 重要性: ⭐⭐⭐⭐⭐ 现代AI应用的核心技术
🎯 项目目标¶
完成本项目后,你将能够:

- 理解文本到图像生成的原理
- 使用CLIP编码文本
- 实现文本条件扩散模型
- 训练模型生成图像
- 理解无分类器引导采样(本项目将其实现列为8.3中的扩展方向)
1. 项目概述¶
1.1 项目简介¶
本项目将实现一个文本到图像生成模型,能够根据文本描述生成对应的图像。
技术栈:

- PyTorch:深度学习框架
- CLIP:文本和图像编码器(本项目只使用其文本分支)
- UNet:噪声预测主干网络
- CIFAR-10:训练数据集(自行添加文本描述)
1.2 文本到图像生成流程¶
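整体流程:文本提示 → CLIP文本编码器得到文本嵌入 → 文本条件UNet以文本嵌入为条件,从纯噪声逐步去噪 → 输出与文本对应的图像。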
2. 环境准备¶
2.1 安装依赖¶
Bash
# 安装transformers(包含CLIP)
pip install transformers
# 安装其他依赖
pip install torch torchvision numpy matplotlib tqdm tensorboard
2.2 准备数据集¶
由于CIFAR-10没有文本描述,我们需要创建一个带文本描述的数据集:
Python
# prepare_text_dataset.py
from torchvision import datasets
import json
import os
# CIFAR-10类别的中文描述
class_descriptions = {
0: "飞机 - 一架在空中飞行的飞机",
1: "汽车 - 一辆在道路上行驶的汽车",
2: "鸟 - 一只在天空中飞翔的鸟",
3: "猫 - 一只可爱的猫咪",
4: "鹿 - 一只在森林中的鹿",
5: "狗 - 一只奔跑的狗",
6: "青蛙 - 一只绿色的青蛙",
7: "马 - 一匹在草原上奔跑的马",
8: "船 - 一艘在海上航行的船",
9: "卡车 - 一辆大型卡车"
}
def create_text_dataset(data_dir='./data'):
"""创建带文本描述的数据集"""
    # 加载CIFAR-10(此处只需要类别标签,不需要图像变换)
    train_dataset = datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=True
    )
    # 创建文本数据(直接读取targets,避免逐张解码图像,速度更快)
    text_data = []
    for i, label in enumerate(train_dataset.targets):  # enumerate同时获取索引和标签
        text_data.append({
            'index': i,
            'label': label,
            'text': class_descriptions[label]
        })
# 保存文本数据
os.makedirs(data_dir, exist_ok=True)
with open(os.path.join(data_dir, 'cifar10_text.json'), 'w', encoding='utf-8') as f: # with自动管理文件关闭
json.dump(text_data, f, ensure_ascii=False, indent=2)
print(f"创建文本数据完成,共 {len(text_data)} 条")
print(f"保存到: {os.path.join(data_dir, 'cifar10_text.json')}")
if __name__ == '__main__':
create_text_dataset()
运行脚本:
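Bash
python prepare_text_dataset.py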
3. 文本编码器¶
Python
# text_encoder.py
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer
class TextEncoder(nn.Module): # 继承nn.Module定义网络层
"""CLIP文本编码器"""
def __init__(self, model_name="openai/clip-vit-base-patch32"):
super().__init__() # super()调用父类方法
# 加载CLIP模型
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
self.text_model = CLIPTextModel.from_pretrained(model_name)
# 冻结CLIP参数
for param in self.text_model.parameters():
param.requires_grad = False
def forward(self, text_prompts):
"""
编码文本提示
参数:
text_prompts: 文本提示列表
返回:
text_embeddings: [batch_size, seq_len, embedding_dim]
"""
# Tokenize
inputs = self.tokenizer(
text_prompts,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt"
)
        # 将输入张量移到模型所在设备(tokenizer输出默认在CPU,GPU训练时否则会设备不匹配)
        inputs = {k: v.to(self.text_model.device) for k, v in inputs.items()}
        # 编码
        outputs = self.text_model(**inputs)
text_embeddings = outputs.last_hidden_state
return text_embeddings
# 测试
if __name__ == '__main__':
encoder = TextEncoder()
prompts = [
"一只可爱的猫",
"一辆在道路上行驶的汽车",
"一只在天空中飞翔的鸟"
]
embeddings = encoder(prompts)
print(f"文本嵌入形状: {embeddings.shape}")
print(f"嵌入维度: {embeddings.shape[-1]}") # [-1]负索引取最后元素
4. 文本条件UNet¶
Python
# text_unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SinusoidalPositionEmbedding(nn.Module):
"""正弦位置编码"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # torch.cat沿已有维度拼接张量
return emb
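# 用法示意:SinusoidalPositionEmbedding(128)(torch.tensor([0, 10, 100])) 的输出形状为 [3, 128];
# 前一半维度为 sin(t*w_i)、后一半为 cos(t*w_i),其中 w_i = 10000^(-i/(half_dim-1))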
class CrossAttention(nn.Module):
"""交叉注意力模块"""
def __init__(self, query_dim, context_dim, heads=4, dim_head=64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
# Q, K, V投影
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
# 输出投影
self.to_out = nn.Linear(inner_dim, query_dim)
    def forward(self, x, context):
        """
        参数:
            x: [batch_size, height*width, query_dim] 图像特征序列
            context: [batch_size, seq_len, context_dim] 文本嵌入
        返回:
            输出: [batch_size, height*width, query_dim]
        """
        batch_size = x.shape[0]
        # 计算Q, K, V
        q = self.to_q(x)          # [B, HW, inner_dim]
        k = self.to_k(context)    # [B, seq_len, inner_dim]
        v = self.to_v(context)    # [B, seq_len, inner_dim]
        # 重塑为多头: [B, heads, len, dim_head](view中只能有一个-1)
        def split_heads(t):
            return t.view(batch_size, -1, self.heads, t.shape[-1] // self.heads).transpose(1, 2)
        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        # 计算注意力权重: [B, heads, HW, seq_len]
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = torch.softmax(attn, dim=-1)
        # 加权求和并合并多头
        out = torch.matmul(attn, v)             # [B, heads, HW, dim_head]
        out = out.transpose(1, 2).contiguous()  # [B, HW, heads, dim_head]
        out = out.view(batch_size, -1, out.shape[-2] * out.shape[-1])
        # 输出投影
        return self.to_out(out)
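# 用法示意:CrossAttention(query_dim=128, context_dim=768) 接收 x: [B, H*W, 128] 与
# context: [B, 77, 768],输出 [B, H*W, 128]。
# 注意:为保持实现简单,下方的TextConditionedUNet用均值池化的文本嵌入做条件,
# 并未接入该模块;读者可尝试在下/上采样块之间插入交叉注意力,更细粒度地利用文本信息。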
class TextConditionedUNet(nn.Module):
"""文本条件UNet"""
def __init__(self, in_channels=3, out_channels=3, model_dim=128,
text_embedding_dim=768):
super().__init__()
# 时间步嵌入
self.time_embedding = SinusoidalPositionEmbedding(model_dim)
self.time_mlp = nn.Sequential(
nn.Linear(model_dim, model_dim * 4),
nn.SiLU(),
nn.Linear(model_dim * 4, model_dim)
)
# 文本嵌入投影
self.text_proj = nn.Linear(text_embedding_dim, model_dim)
# 初始卷积
self.conv_in = nn.Conv2d(in_channels, model_dim, 3, padding=1)
# 下采样
self.down_blocks = nn.ModuleList([
self._make_down_block(model_dim, model_dim * 2, model_dim),
self._make_down_block(model_dim * 2, model_dim * 4, model_dim),
])
# 中间层
self.mid_block1 = self._make_mid_block(model_dim * 4, model_dim)
self.mid_block2 = self._make_mid_block(model_dim * 4, model_dim)
# 上采样
self.up_blocks = nn.ModuleList([
self._make_up_block(model_dim * 4, model_dim * 2, model_dim),
self._make_up_block(model_dim * 2, model_dim, model_dim),
])
# 输出卷积
self.conv_out = nn.Conv2d(model_dim, out_channels, 3, padding=1)
def _make_down_block(self, in_channels, out_channels, emb_dim):
"""创建下采样块"""
return nn.ModuleList([
ResidualBlock(in_channels, out_channels, emb_dim),
ResidualBlock(out_channels, out_channels, emb_dim),
nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
])
def _make_up_block(self, in_channels, out_channels, emb_dim):
"""创建上采样块"""
return nn.ModuleList([
nn.ConvTranspose2d(in_channels, in_channels, 4, stride=2, padding=1),
ResidualBlock(in_channels, out_channels, emb_dim),
ResidualBlock(out_channels, out_channels, emb_dim)
])
def _make_mid_block(self, channels, emb_dim):
"""创建中间块"""
return nn.ModuleList([
ResidualBlock(channels, channels, emb_dim),
ResidualBlock(channels, channels, emb_dim)
])
def forward(self, x, t, text_embeddings):
"""
前向传播
参数:
x: [batch_size, in_channels, height, width]
t: [batch_size] 时间步
text_embeddings: [batch_size, seq_len, text_dim]
返回:
输出: [batch_size, out_channels, height, width]
"""
batch_size = x.shape[0]
# 时间步嵌入
t_emb = self.time_embedding(t)
t_emb = self.time_mlp(t_emb)
# 文本嵌入投影
text_emb = self.text_proj(text_embeddings)
# 合并嵌入
emb = t_emb + text_emb.mean(dim=1)
# 初始卷积
h = self.conv_in(x)
        # 下采样(在降低分辨率之前保存跳跃连接特征,保证与上采样阶段形状匹配)
        skips = []
        for down_block in self.down_blocks:
            res1, res2, downsample = down_block
            h = res1(h, emb)
            h = res2(h, emb)
            skips.append(h)   # 保存下采样前的特征
            h = downsample(h)
# 中间层
for layer in self.mid_block1:
h = layer(h, emb)
for layer in self.mid_block2:
h = layer(h, emb)
# 上采样
for i, up_block in enumerate(self.up_blocks):
for layer in up_block:
if isinstance(layer, nn.ConvTranspose2d):
h = layer(h)
h = h + skips[-(i+1)]
else:
h = layer(h, emb)
# 输出
h = self.conv_out(h)
return h
class ResidualBlock(nn.Module):
"""残差块"""
def __init__(self, in_channels, out_channels, emb_dim):
super().__init__()
self.norm1 = nn.GroupNorm(8, in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.norm2 = nn.GroupNorm(8, out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
# 嵌入投影
self.emb_proj = nn.Linear(emb_dim, out_channels)
# 残差连接
if in_channels != out_channels:
self.skip_conv = nn.Conv2d(in_channels, out_channels, 1)
else:
self.skip_conv = nn.Identity()
    def forward(self, x, emb):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)
        # 注入时间+文本条件嵌入(广播到空间维度)
        emb = self.emb_proj(emb)
        h = h + emb[:, :, None, None]
        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)
        return h + self.skip_conv(x)
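
# 形状自检(示意):验证UNet输入输出分辨率与通道一致
if __name__ == '__main__':
    model = TextConditionedUNet()
    x = torch.randn(2, 3, 32, 32)
    t = torch.randint(0, 1000, (2,))
    text_emb = torch.randn(2, 77, 768)  # 模拟CLIP文本嵌入
    print(model(x, t, text_emb).shape)  # 预期: torch.Size([2, 3, 32, 32])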
5. 训练脚本¶
Python
# train_text2image.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import json
import os
from tqdm import tqdm
from text_encoder import TextEncoder
from text_unet import TextConditionedUNet
class TextCIFAR10Dataset(Dataset):
"""带文本描述的CIFAR-10数据集"""
def __init__(self, data_dir='./data', split='train', transform=None):
"""
参数:
data_dir: 数据目录
split: 'train' 或 'test'
transform: 数据变换
"""
self.data_dir = data_dir
self.split = split
self.transform = transform
# 加载CIFAR-10
self.cifar_dataset = datasets.CIFAR10(
root=data_dir,
train=(split == 'train'),
download=True,
transform=None
)
        # 加载文本描述(注意:cifar10_text.json 仅由训练集生成,split='test' 时索引与文本并不对应)
        text_file = os.path.join(data_dir, 'cifar10_text.json')
        with open(text_file, 'r', encoding='utf-8') as f:
            self.text_data = json.load(f)
def __len__(self): # __len__定义len()行为
return len(self.cifar_dataset)
def __getitem__(self, idx): # __getitem__定义索引访问行为
# 获取图像
img, label = self.cifar_dataset[idx]
# 应用变换
if self.transform:
img = self.transform(img)
# 获取文本
text = self.text_data[idx]['text']
return img, text
def get_schedule(T, beta_start, beta_end):
"""创建噪声调度表"""
betas = torch.linspace(beta_start, beta_end, T)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
return alphas, betas, alphas_cumprod
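# 说明:上面的调度表给出DDPM前向加噪的闭式系数:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
# 训练目标即让模型从 (x_t, t, 文本条件) 预测出 eps,见下方训练循环。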
def train():
"""训练函数"""
# 配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
T = 1000
beta_start = 0.0001
beta_end = 0.02
batch_size = 64
num_epochs = 50
learning_rate = 1e-4
print(f"使用设备: {device}")
# 创建数据加载器
print("加载数据...")
transform = transforms.Compose([
transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = TextCIFAR10Dataset(
data_dir='./data',
split='train',
transform=transform
)
train_loader = DataLoader( # DataLoader批量加载数据
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True
)
print(f"训练集大小: {len(train_dataset)}")
# 创建模型
print("创建模型...")
text_encoder = TextEncoder().to(device) # 移至GPU/CPU
model = TextConditionedUNet(
in_channels=3,
out_channels=3,
model_dim=128,
text_embedding_dim=768
).to(device)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")
# 创建优化器
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# 创建噪声调度
alphas, betas, alphas_cumprod = get_schedule(T, beta_start, beta_end)
alphas_cumprod = alphas_cumprod.to(device)
# 训练循环
print(f"\n开始训练,共 {num_epochs} 轮")
for epoch in range(num_epochs):
model.train() # train()训练模式
epoch_loss = 0
pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
for x_0, texts in pbar:
x_0 = x_0.to(device)
            # 编码文本(CLIP参数已冻结,包在no_grad里可节省显存;输出已在text_encoder所在设备上)
            with torch.no_grad():
                text_embeddings = text_encoder(list(texts))
# 随机采样时间步
batch_size = x_0.shape[0]
t = torch.randint(0, T, (batch_size,), device=device)
# 生成噪声
noise = torch.randn_like(x_0)
# 计算加噪后的图像
sqrt_alpha_t_bar = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)
sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
# 模型预测噪声
predicted_noise = model(x_t, t, text_embeddings)
# 计算损失
loss = nn.functional.mse_loss(predicted_noise, noise)
# 反向传播
optimizer.zero_grad() # 清零梯度
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新参数
epoch_loss += loss.item() # 将单元素张量转为Python数值
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
epoch_loss /= len(train_loader)
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
# 定期保存模型
if (epoch + 1) % 10 == 0:
os.makedirs('./checkpoints', exist_ok=True)
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': epoch_loss,
}, f'./checkpoints/text2image_epoch_{epoch+1}.pth')
print(f" ✓ 保存检查点")
print("\n训练完成!")
if __name__ == '__main__':
train()
6. 采样脚本¶
Python
# sample_text2image.py
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os
from text_encoder import TextEncoder
from text_unet import TextConditionedUNet
def get_schedule(T, beta_start, beta_end):
"""创建噪声调度表"""
betas = torch.linspace(beta_start, beta_end, T)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
return alphas, betas, alphas_cumprod
def sample(model, text_encoder, text_prompts, T, alphas, betas, alphas_cumprod, device):
"""
文本到图像采样
参数:
model: 文本条件扩散模型
text_encoder: 文本编码器
text_prompts: 文本提示列表
T: 总步数
alphas, betas, alphas_cumprod: 调度表
device: 设备
返回:
生成的图像
"""
model.eval()
# 编码文本
text_embeddings = text_encoder(text_prompts).to(device)
# 创建初始噪声
batch_size = len(text_prompts)
x_T = torch.randn(batch_size, 3, 32, 32).to(device)
    # 与alphas_cumprod同设备,避免GPU上torch.cat时CPU/GPU混用报错
    alphas_cumprod_prev = torch.cat([torch.ones(1, device=alphas_cumprod.device), alphas_cumprod[:-1]])
with torch.no_grad(): # 禁用梯度计算,节省内存
for t in reversed(range(T)):
alpha_t = alphas[t]
beta_t = betas[t]
alpha_t_bar = alphas_cumprod[t]
alpha_t_bar_prev = alphas_cumprod_prev[t]
# 预测噪声
t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
predicted_noise = model(x_T, t_tensor, text_embeddings)
# 计算均值
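            # DDPM后验均值: mu_t = (1/sqrt(alpha_t)) * (x_t - beta_t/sqrt(1-alpha_bar_t) * eps_theta)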
sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
mean = sqrt_recip_alpha_t * (
x_T - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
)
# 添加噪声
if t > 0:
posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
noise = torch.randn_like(x_T)
x_T = mean + torch.sqrt(posterior_variance) * noise
else:
x_T = mean
return x_T
def main():
"""主函数"""
# 配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
T = 1000
beta_start = 0.0001
beta_end = 0.02
print(f"使用设备: {device}")
# 创建模型
print("创建模型...")
text_encoder = TextEncoder().to(device)
model = TextConditionedUNet(
in_channels=3,
out_channels=3,
model_dim=128,
text_embedding_dim=768
).to(device)
# 加载检查点
checkpoint_path = './checkpoints/text2image_epoch_50.pth'
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"加载检查点: {checkpoint_path}")
print(f" Epoch: {checkpoint['epoch']}")
print(f" Loss: {checkpoint['loss']:.4f}")
else:
print(f"未找到检查点: {checkpoint_path}")
return
# 创建噪声调度
alphas, betas, alphas_cumprod = get_schedule(T, beta_start, beta_end)
alphas_cumprod = alphas_cumprod.to(device)
# 文本提示
text_prompts = [
"一只可爱的猫",
"一辆在道路上行驶的汽车",
"一只在天空中飞翔的鸟",
"一艘在海上航行的船",
"一只奔跑的狗",
"一匹在草原上奔跑的马",
"一架在空中飞行的飞机",
"一只绿色的青蛙"
]
# 生成图像
print(f"\n生成 {len(text_prompts)} 个图像...")
samples = sample(model, text_encoder, text_prompts, T, alphas, betas, alphas_cumprod, device)
# 反归一化
samples = (samples + 1) / 2
samples = samples.clamp(0, 1)
# 可视化
grid = make_grid(samples, nrow=4, padding=2, normalize=False)
plt.figure(figsize=(16, 8))
plt.imshow(grid.permute(1, 2, 0).cpu())
plt.title('Text-to-Image Generation')
plt.axis('off')
save_path = './samples/text2image_samples.png'
os.makedirs('./samples', exist_ok=True)
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"样本已保存到: {save_path}")
plt.show()
if __name__ == '__main__':
main()
7. 运行项目¶
7.1 准备数据¶
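Bash
python prepare_text_dataset.py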
7.2 训练模型¶
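Bash
python train_text2image.py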
7.3 生成图像¶
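Bash
python sample_text2image.py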
8. 项目总结¶
8.1 完成的工作¶
- ✅ 创建了带文本描述的数据集
- ✅ 实现了CLIP文本编码器
- ✅ 实现了文本条件UNet
- ✅ 训练了文本到图像生成模型
- ✅ 实现了文本到图像采样
8.2 关键技术点¶
| 技术点 | 说明 |
|---|---|
| CLIP编码 | 将文本转换为向量表示 |
| 文本条件 | 将文本嵌入添加到模型中 |
| 交叉注意力 | 让模型按空间位置关注文本信息(本项目提供实现,未接入主干) |
| 无分类器引导 | 提高条件生成质量(本项目作为扩展方向) |
| 文本提示 | 使用自然语言控制生成 |
8.3 改进方向¶
- 使用更大的数据集:如LAION-5B
- 使用更大的模型:如Stable Diffusion
- 使用更好的文本编码器:如T5
- 实现无分类器引导:提高生成质量(核心一步的示意见本节末尾)
- 使用DDIM加速:减少采样步数
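其中"无分类器引导"的核心一步可用下面的极简示意说明(假设训练时以一定概率把文本嵌入替换为"空文本"嵌入 null_emb,使同一模型兼具条件与无条件预测能力;函数与参数名均为示意,非本项目已有代码):

Python
import torch

def cfg_predict(model, x_t, t, text_emb, null_emb, guidance_scale=3.0):
    """无分类器引导的噪声预测(示意)"""
    eps_cond = model(x_t, t, text_emb)       # 条件噪声预测
    eps_uncond = model(x_t, t, null_emb)     # 无条件噪声预测
    # 沿"条件 - 无条件"方向外推,guidance_scale越大越贴近文本
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)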
9. 常见问题¶
Q1: 生成的图像不符合文本描述?¶
A: 可以尝试:

- 增加训练轮数
- 使用更大的模型
- 改进文本描述(中文提示下CLIP编码较弱,可改用英文提示,见第3节说明)
- 使用无分类器引导
Q2: 训练时间太长?¶
A: 可以尝试:

- 增大batch size(在显存允许时减少迭代次数)
- 使用混合精度训练(最小改动示意见下方)
- 使用更小的模型
- 减少训练轮数
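其中"混合精度训练"的最小改动示意如下(假设在CUDA上训练;training_step为示意函数,对应train()循环中的前向与反向几行):

Python
import torch
import torch.nn as nn

scaler = torch.cuda.amp.GradScaler()  # 梯度缩放,防止fp16下梯度下溢

def training_step(model, optimizer, x_t, t, text_embeddings, noise):
    """单步混合精度训练(示意)"""
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        predicted_noise = model(x_t, t, text_embeddings)
        loss = nn.functional.mse_loss(predicted_noise, noise)
    optimizer.zero_grad()
    scaler.scale(loss).backward()   # 对缩放后的损失反向传播
    scaler.step(optimizer)          # 自动反缩放梯度并更新参数
    scaler.update()
    return loss.item()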
Q3: 如何提高生成质量?¶
A: 可以尝试:

- 使用更大的数据集
- 使用更大的模型
- 使用更好的文本编码器
- 使用无分类器引导
10. 下一步¶
完成本项目后,你可以:
- 尝试其他数据集:如LAION-5B
- 实现无分类器引导:提高生成质量
- 使用预训练模型:如Stable Diffusion
- 实现图像编辑:基于文本编辑图像
- 优化模型:使用DDIM、LDM等技术(DDIM一步更新的示意见下方)
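其中DDIM(取eta=0的确定性形式)的一步更新可写成如下示意(ddim_step为示意函数,符号与本章一致,alpha_bar为alpha的累积乘积):

Python
import torch

def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    """DDIM确定性一步(eta=0,示意):先由预测噪声反推x_0,再映射到上一时间步"""
    x0_pred = (x_t - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)
    return torch.sqrt(alpha_bar_prev) * x0_pred + torch.sqrt(1 - alpha_bar_prev) * eps

由于每一步不再依赖逐步的马尔可夫链,DDIM可以在时间步序列上大步跳跃,把采样从1000步压缩到几十步。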
下一章: 03-图像修复与编辑 - 实现图像修复和编辑功能