03 - 图像修复与编辑¶
学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 实用性极强的应用场景
🎯 项目目标¶
完成本项目后,你将能够: - 理解图像修复的原理 - 实现图像修复功能 - 实现图像编辑功能 - 掌握掩码的使用方法 - 应用到实际场景
1. 项目概述¶
1.1 项目简介¶
本项目将实现扩散模型的图像修复和编辑功能,包括: - 图像修复(Inpainting):修复图像中的缺失或损坏区域 - 图像编辑(Editing):根据文本描述编辑图像
1.2 应用场景¶
| 应用 | 说明 |
|---|---|
| 照片修复 | 修复老照片的损坏部分 |
| 去水印 | 去除图像中的水印 |
| 物体移除 | 移除图像中的不需要的物体 |
| 风格迁移 | 改变图像的风格 |
| 内容生成 | 在图像中添加新内容 |
2. 图像修复原理¶
2.1 基本原理¶
图像修复的核心思想是:在需要修复的区域进行扩散采样,在其他区域保持原始图像。
流程:
2.2 掩码创建¶
Python
# mask_utils.py
import torch
import numpy as np
from PIL import Image, ImageDraw
def create_center_mask(image_size=32, mask_size=10):
"""
创建中心掩码
参数:
image_size: 图像大小
mask_size: 掩码大小
返回:
mask: [1, 1, H, W], 1表示需要修复的区域
"""
mask = torch.zeros(1, 1, image_size, image_size)
center = image_size // 2
half_size = mask_size // 2
mask[:, :, center-half_size:center+half_size,
center-half_size:center+half_size] = 1
return mask
def create_random_mask(image_size=32, num_holes=5, max_hole_size=8):
"""
创建随机掩码
参数:
image_size: 图像大小
num_holes: 孔洞数量
max_hole_size: 最大孔洞大小
返回:
mask: [1, 1, H, W]
"""
mask = torch.zeros(1, 1, image_size, image_size)
for _ in range(num_holes):
# 随机位置
x = np.random.randint(0, image_size - max_hole_size)
y = np.random.randint(0, image_size - max_hole_size)
# 随机大小
w = np.random.randint(2, max_hole_size)
h = np.random.randint(2, max_hole_size)
mask[:, :, x:x+w, y:y+h] = 1
return mask
def create_mask_from_points(image_size=32, points=None):
"""
从点创建掩码
参数:
image_size: 图像大小
points: 点列表 [(x1, y1), (x2, y2), ...]
返回:
mask: [1, 1, H, W]
"""
if points is None:
# 默认中心区域
center = image_size // 2
points = [
(center-5, center-5),
(center+5, center-5),
(center+5, center+5),
(center-5, center+5)
]
# 创建PIL图像
img = Image.new('L', (image_size, image_size), 0)
draw = ImageDraw.Draw(img)
# 绘制多边形
draw.polygon(points, fill=255)
# 转换为张量
mask = torch.tensor(np.array(img), dtype=torch.float32) / 255.0 # np.array创建NumPy数组
mask = mask.unsqueeze(0).unsqueeze(0) # unsqueeze增加一个维度
return mask
def visualize_mask(image, mask, save_path='mask_visualization.png'):
"""
可视化掩码
参数:
image: 原始图像 [C, H, W]
mask: 掩码 [1, H, W]
save_path: 保存路径
"""
import matplotlib.pyplot as plt
# 反归一化图像
image = (image + 1) / 2
image = image.clamp(0, 1)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# 原始图像
axes[0].imshow(image.permute(1, 2, 0))
axes[0].set_title('原始图像')
axes[0].axis('off')
# 掩码
axes[1].imshow(mask.squeeze(), cmap='gray') # squeeze压缩维度
axes[1].set_title('掩码')
axes[1].axis('off')
# 应用掩码的图像
masked_image = image * (1 - mask)
axes[2].imshow(masked_image.permute(1, 2, 0))
axes[2].set_title('应用掩码的图像')
axes[2].axis('off')
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()
# 测试
if __name__ == '__main__':
# 创建不同类型的掩码
mask1 = create_center_mask(image_size=32, mask_size=10)
mask2 = create_random_mask(image_size=32, num_holes=5, max_hole_size=8)
mask3 = create_mask_from_points(image_size=32)
print(f"中心掩码形状: {mask1.shape}")
print(f"随机掩码形状: {mask2.shape}")
print(f"点掩码形状: {mask3.shape}")
3. 图像修复实现¶
Python
# inpainting.py
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os
from mask_utils import create_center_mask, create_random_mask
def get_schedule(T, beta_start, beta_end):
"""创建噪声调度表"""
betas = torch.linspace(beta_start, beta_end, T)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
return alphas, betas, alphas_cumprod
def inpaint(model, original_image, mask, T, alphas, betas, alphas_cumprod,
num_steps=1000, device='cuda'):
"""
图像修复
参数:
model: 扩散模型
original_image: 原始图像 [1, C, H, W]
mask: 掩码 [1, 1, H, W], 1表示需要修复的区域
T: 总步数
alphas, betas, alphas_cumprod: 调度表
num_steps: 采样步数
device: 设备
返回:
修复后的图像
"""
model.eval() # eval()评估模式
original_image = original_image.to(device) # 移至GPU/CPU
mask = mask.to(device)
# 初始化:在掩码区域添加噪声
x_T = torch.randn_like(original_image)
x_T = x_T * mask + original_image * (1 - mask)
alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]]) # torch.cat沿已有维度拼接张量
with torch.no_grad(): # 禁用梯度计算,节省内存
for t in reversed(range(num_steps)):
alpha_t = alphas[t]
beta_t = betas[t]
alpha_t_bar = alphas_cumprod[t]
t_tensor = torch.full((x_T.shape[0],), t, device=device, dtype=torch.long)
# 预测噪声
predicted_noise = model(x_T, t_tensor)
# 计算均值
sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
mean = sqrt_recip_alpha_t * (
x_T - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
)
# 添加噪声
if t > 0:
alpha_t_bar_prev = alphas_cumprod_prev[t]
posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
noise = torch.randn_like(x_T)
x_t = mean + torch.sqrt(posterior_variance) * noise
else:
x_t = mean
# 在非掩码区域保持原始图像
x_t = x_t * mask + original_image * (1 - mask)
return x_t
def test_inpainting(model, test_images, device='cuda'):
"""
测试图像修复
参数:
model: 扩散模型
test_images: 测试图像 [N, C, H, W]
device: 设备
"""
# 创建噪声调度
T = 1000
alphas, betas, alphas_cumprod = get_schedule(T, 0.0001, 0.02)
alphas_cumprod = alphas_cumprod.to(device)
# 测试不同的掩码
test_cases = []
for i in range(min(4, len(test_images))):
image = test_images[i:i+1]
# 中心掩码
mask1 = create_center_mask(image_size=32, mask_size=10)
result1 = inpaint(model, image, mask1, T, alphas, betas, alphas_cumprod,
num_steps=1000, device=device)
test_cases.append((image, mask1, result1, "中心掩码"))
# 随机掩码
mask2 = create_random_mask(image_size=32, num_holes=5, max_hole_size=8)
result2 = inpaint(model, image, mask2, T, alphas, betas, alphas_cumprod,
num_steps=1000, device=device)
test_cases.append((image, mask2, result2, "随机掩码"))
# 可视化结果
fig, axes = plt.subplots(len(test_cases), 4, figsize=(16, 4*len(test_cases)))
for i, (original, mask, result, title) in enumerate(test_cases): # enumerate同时获取索引和元素
# 原始图像
orig_img = (original + 1) / 2
orig_img = orig_img.clamp(0, 1)
axes[i, 0].imshow(orig_img.squeeze().permute(1, 2, 0).cpu())
axes[i, 0].set_title('原始图像')
axes[i, 0].axis('off')
# 掩码
axes[i, 1].imshow(mask.squeeze().cpu(), cmap='gray')
axes[i, 1].set_title('掩码')
axes[i, 1].axis('off')
# 应用掩码的图像
masked_img = orig_img * (1 - mask)
axes[i, 2].imshow(masked_img.squeeze().permute(1, 2, 0).cpu())
axes[i, 2].set_title('应用掩码')
axes[i, 2].axis('off')
# 修复后的图像
result_img = (result + 1) / 2
result_img = result_img.clamp(0, 1)
axes[i, 3].imshow(result_img.squeeze().permute(1, 2, 0).cpu())
axes[i, 3].set_title(f'修复结果 ({title})')
axes[i, 3].axis('off')
plt.tight_layout()
os.makedirs('./samples', exist_ok=True)
plt.savefig('./samples/inpainting_results.png', dpi=150, bbox_inches='tight')
plt.show()
if __name__ == '__main__':
# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 这里需要加载训练好的模型
# model = load_model(...)
# model.to(device)
# 加载测试图像
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.Resize(32),
transforms.CenterCrop(32),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
test_images = []
for i in range(4):
img, _ = test_dataset[i]
test_images.append(img)
test_images = torch.stack(test_images) # torch.stack沿新维度拼接张量
# 测试图像修复
test_inpainting(model, test_images, device)
4. 图像编辑实现¶
Python
# image_editing.py
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os
from text_encoder import TextEncoder
def get_schedule(T, beta_start, beta_end):
"""创建噪声调度表"""
betas = torch.linspace(beta_start, beta_end, T)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
return alphas, betas, alphas_cumprod
def edit_image(model, text_encoder, original_image, target_text,
T, alphas, betas, alphas_cumprod,
guidance_scale=7.5, num_steps=1000, device='cuda'):
"""
图像编辑(基于文本)
参数:
model: 文本条件扩散模型
text_encoder: 文本编码器
original_image: 原始图像 [1, C, H, W]
target_text: 目标文本描述
T: 总步数
alphas, betas, alphas_cumprod: 调度表
guidance_scale: 引导强度
num_steps: 采样步数
device: 设备
返回:
编辑后的图像
"""
model.eval()
original_image = original_image.to(device)
# 编码文本
text_embeddings = text_encoder([target_text]).to(device)
# 从原始图像开始(而不是纯噪声)
x_T = original_image
alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])
with torch.no_grad():
for t in reversed(range(num_steps)):
alpha_t = alphas[t]
beta_t = betas[t]
alpha_t_bar = alphas_cumprod[t]
t_tensor = torch.full((x_T.shape[0],), t, device=device, dtype=torch.long)
# 预测条件噪声
noise_cond = model(x_T, t_tensor, text_embeddings)
# 预测无条件噪声
noise_uncond = model(x_T, t_tensor, torch.zeros_like(text_embeddings))
# 组合预测
predicted_noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
# 更新
sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
mean = sqrt_recip_alpha_t * (
x_T - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
)
if t > 0:
alpha_t_bar_prev = alphas_cumprod_prev[t]
posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
noise = torch.randn_like(x_T)
x_t = mean + torch.sqrt(posterior_variance) * noise
else:
x_t = mean
return x_t
def test_image_editing(model, text_encoder, test_images, device='cuda'):
"""
测试图像编辑
参数:
model: 文本条件扩散模型
text_encoder: 文本编码器
test_images: 测试图像 [N, C, H, W]
device: 设备
"""
# 创建噪声调度
T = 1000
alphas, betas, alphas_cumprod = get_schedule(T, 0.0001, 0.02)
alphas_cumprod = alphas_cumprod.to(device)
# 编辑提示
edit_prompts = [
"变成黑白",
"增加亮度",
"改变颜色",
"添加艺术风格"
]
# 测试编辑
test_cases = []
for i in range(min(2, len(test_images))):
image = test_images[i:i+1]
for prompt in edit_prompts:
result = edit_image(model, text_encoder, image, prompt,
T, alphas, betas, alphas_cumprod,
guidance_scale=7.5, num_steps=1000, device=device)
test_cases.append((image, result, prompt))
# 可视化结果
fig, axes = plt.subplots(len(test_cases), 2, figsize=(10, 5*len(test_cases)))
for i, (original, result, prompt) in enumerate(test_cases):
# 原始图像
orig_img = (original + 1) / 2
orig_img = orig_img.clamp(0, 1)
axes[i, 0].imshow(orig_img.squeeze().permute(1, 2, 0).cpu())
axes[i, 0].set_title('原始图像')
axes[i, 0].axis('off')
# 编辑后的图像
result_img = (result + 1) / 2
result_img = result_img.clamp(0, 1)
axes[i, 1].imshow(result_img.squeeze().permute(1, 2, 0).cpu())
axes[i, 1].set_title(f'编辑结果: {prompt}')
axes[i, 1].axis('off')
plt.tight_layout()
os.makedirs('./samples', exist_ok=True)
plt.savefig('./samples/image_editing_results.png', dpi=150, bbox_inches='tight')
plt.show()
if __name__ == '__main__':
# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 这里需要加载训练好的文本条件模型
# model = load_text_conditioned_model(...)
# model.to(device)
# 加载文本编码器
# text_encoder = TextEncoder()
# text_encoder.to(device)
# 加载测试图像
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.Resize(32),
transforms.CenterCrop(32),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
test_images = []
for i in range(2):
img, _ = test_dataset[i]
test_images.append(img)
test_images = torch.stack(test_images)
# 测试图像编辑
test_image_editing(model, text_encoder, test_images, device)
5. 完整示例¶
Python
# complete_example.py
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
from mask_utils import create_center_mask, create_random_mask, visualize_mask
from inpainting import inpaint, test_inpainting
from image_editing import edit_image, test_image_editing
def main():
"""主函数"""
print("图像修复与编辑示例")
# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 加载测试图像
print("\n加载测试图像...")
transform = transforms.Compose([
transforms.Resize(32),
transforms.CenterCrop(32),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
test_images = []
for i in range(4):
img, _ = test_dataset[i]
test_images.append(img)
test_images = torch.stack(test_images)
print(f"加载了 {len(test_images)} 张测试图像")
# 可视化掩码
print("\n可视化掩码...")
mask1 = create_center_mask(image_size=32, mask_size=10)
mask2 = create_random_mask(image_size=32, num_holes=5, max_hole_size=8)
visualize_mask(test_images[0], mask1, 'samples/mask1_visualization.png')
visualize_mask(test_images[0], mask2, 'samples/mask2_visualization.png')
# 测试图像修复
print("\n测试图像修复...")
# 这里需要加载训练好的模型
# model = load_model(...)
# model.to(device)
# test_inpainting(model, test_images, device)
print(" ✓ 图像修复测试完成")
# 测试图像编辑
print("\n测试图像编辑...")
# 这里需要加载训练好的文本条件模型
# text_model = load_text_conditioned_model(...)
# text_encoder = TextEncoder()
# text_model.to(device)
# text_encoder.to(device)
# test_image_editing(text_model, text_encoder, test_images, device)
print(" ✓ 图像编辑测试完成")
print("\n所有测试完成!")
if __name__ == '__main__':
main()
6. 运行示例¶
6.1 创建掩码¶
6.2 测试图像修复¶
6.3 测试图像编辑¶
6.4 运行完整示例¶
7. 项目总结¶
7.1 完成的工作¶
- ✅ 实现了多种掩码创建方法
- ✅ 实现了图像修复功能
- ✅ 实现了图像编辑功能
- ✅ 提供了完整的测试示例
- ✅ 可视化了修复和编辑结果
7.2 关键技术点¶
| 技术点 | 说明 |
|---|---|
| 掩码创建 | 定义需要修复或编辑的区域 |
| 图像修复 | 在掩码区域进行扩散采样 |
| 图像编辑 | 基于文本编辑图像 |
| 无分类器引导 | 提高编辑质量 |
| 保持原始内容 | 在非编辑区域保持原始图像 |
7.3 改进方向¶
- 使用更大的模型:提高修复和编辑质量
- 使用更好的文本编码器:如T5
- 实现更复杂的编辑:如风格迁移
- 优化采样速度:使用DDIM等技术
- 支持更高分辨率:使用LDM等技术
8. 常见问题¶
Q1: 修复效果不好?¶
A: 可以尝试: - 增加采样步数 - 使用更大的模型 - 调整掩码大小 - 使用更好的训练数据
Q2: 编辑效果不符合预期?¶
A: 可以尝试: - 改进文本描述 - 调整引导强度 - 增加采样步数 - 使用更大的模型
Q3: 如何创建自定义掩码?¶
A: 可以使用: - 图像编辑软件(如Photoshop) - 在线工具 - 编程方式(如mask_utils.py中的方法)
9. 下一步¶
完成本项目后,你可以:
- 尝试其他应用:如去水印、物体移除
- 实现更复杂的编辑:如风格迁移
- 优化性能:使用DDIM、LDM等技术
- 应用到实际场景:如照片修复、图像增强
- 集成到应用中:如Web应用、移动应用
10. 总结¶
本项目展示了扩散模型在图像修复和编辑方面的强大能力:
图像修复: - 可以修复图像中的缺失或损坏区域 - 支持多种掩码类型 - 保持非修复区域的原始内容
图像编辑: - 可以根据文本描述编辑图像 - 使用无分类器引导提高质量 - 支持多种编辑操作
这些技术可以应用于许多实际场景,如照片修复、去水印、物体移除等。
项目完成!恭喜你掌握了扩散模型的图像修复和编辑技术!