跳转至

03 - 图像修复与编辑

学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 实用性极强的应用场景


🎯 项目目标

完成本项目后,你将能够: - 理解图像修复的原理 - 实现图像修复功能 - 实现图像编辑功能 - 掌握掩码的使用方法 - 应用到实际场景


1. 项目概述

1.1 项目简介

本项目将实现扩散模型的图像修复和编辑功能,包括: - 图像修复(Inpainting):修复图像中的缺失或损坏区域 - 图像编辑(Editing):根据文本描述编辑图像

1.2 应用场景

应用 说明
照片修复 修复老照片的损坏部分
去水印 去除图像中的水印
物体移除 移除图像中的不需要的物体
风格迁移 改变图像的风格
内容生成 在图像中添加新内容

2. 图像修复原理

2.1 基本原理

图像修复的核心思想是:在需要修复的区域进行扩散采样,在其他区域保持原始图像

流程

Text Only
原始图像 + 掩码
在掩码区域添加噪声
扩散采样(只在掩码区域更新)
修复后的图像

2.2 掩码创建

Python
# mask_utils.py
import torch
import numpy as np
from PIL import Image, ImageDraw

def create_center_mask(image_size=32, mask_size=10):
    """
    创建中心掩码

    参数:
        image_size: 图像大小
        mask_size: 掩码大小

    返回:
        mask: [1, 1, H, W], 1表示需要修复的区域
    """
    mask = torch.zeros(1, 1, image_size, image_size)

    center = image_size // 2
    half_size = mask_size // 2

    mask[:, :, center-half_size:center+half_size,
               center-half_size:center+half_size] = 1

    return mask

def create_random_mask(image_size=32, num_holes=5, max_hole_size=8):
    """
    创建随机掩码

    参数:
        image_size: 图像大小
        num_holes: 孔洞数量
        max_hole_size: 最大孔洞大小

    返回:
        mask: [1, 1, H, W]
    """
    mask = torch.zeros(1, 1, image_size, image_size)

    for _ in range(num_holes):
        # 随机位置
        x = np.random.randint(0, image_size - max_hole_size)
        y = np.random.randint(0, image_size - max_hole_size)

        # 随机大小
        w = np.random.randint(2, max_hole_size)
        h = np.random.randint(2, max_hole_size)

        mask[:, :, x:x+w, y:y+h] = 1

    return mask

def create_mask_from_points(image_size=32, points=None):
    """
    从点创建掩码

    参数:
        image_size: 图像大小
        points: 点列表 [(x1, y1), (x2, y2), ...]

    返回:
        mask: [1, 1, H, W]
    """
    if points is None:
        # 默认中心区域
        center = image_size // 2
        points = [
            (center-5, center-5),
            (center+5, center-5),
            (center+5, center+5),
            (center-5, center+5)
        ]

    # 创建PIL图像
    img = Image.new('L', (image_size, image_size), 0)
    draw = ImageDraw.Draw(img)

    # 绘制多边形
    draw.polygon(points, fill=255)

    # 转换为张量
    mask = torch.tensor(np.array(img), dtype=torch.float32) / 255.0  # np.array创建NumPy数组
    mask = mask.unsqueeze(0).unsqueeze(0)  # unsqueeze增加一个维度

    return mask

def visualize_mask(image, mask, save_path='mask_visualization.png'):
    """
    可视化掩码

    参数:
        image: 原始图像 [C, H, W]
        mask: 掩码 [1, H, W]
        save_path: 保存路径
    """
    import matplotlib.pyplot as plt

    # 反归一化图像
    image = (image + 1) / 2
    image = image.clamp(0, 1)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # 原始图像
    axes[0].imshow(image.permute(1, 2, 0))
    axes[0].set_title('原始图像')
    axes[0].axis('off')

    # 掩码
    axes[1].imshow(mask.squeeze(), cmap='gray')  # squeeze压缩维度
    axes[1].set_title('掩码')
    axes[1].axis('off')

    # 应用掩码的图像
    masked_image = image * (1 - mask)
    axes[2].imshow(masked_image.permute(1, 2, 0))
    axes[2].set_title('应用掩码的图像')
    axes[2].axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# 测试
if __name__ == '__main__':
    # 创建不同类型的掩码
    mask1 = create_center_mask(image_size=32, mask_size=10)
    mask2 = create_random_mask(image_size=32, num_holes=5, max_hole_size=8)
    mask3 = create_mask_from_points(image_size=32)

    print(f"中心掩码形状: {mask1.shape}")
    print(f"随机掩码形状: {mask2.shape}")
    print(f"点掩码形状: {mask3.shape}")

3. 图像修复实现

Python
# inpainting.py
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os

from mask_utils import create_center_mask, create_random_mask

def get_schedule(T, beta_start, beta_end):
    """创建噪声调度表"""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, betas, alphas_cumprod

def inpaint(model, original_image, mask, T, alphas, betas, alphas_cumprod,
           num_steps=1000, device='cuda'):
    """
    图像修复

    参数:
        model: 扩散模型
        original_image: 原始图像 [1, C, H, W]
        mask: 掩码 [1, 1, H, W], 1表示需要修复的区域
        T: 总步数
        alphas, betas, alphas_cumprod: 调度表
        num_steps: 采样步数
        device: 设备

    返回:
        修复后的图像
    """
    model.eval()  # eval()评估模式
    original_image = original_image.to(device)  # 移至GPU/CPU
    mask = mask.to(device)

    # 初始化:在掩码区域添加噪声
    x_T = torch.randn_like(original_image)
    x_T = x_T * mask + original_image * (1 - mask)

    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])  # torch.cat沿已有维度拼接张量

    with torch.no_grad():  # 禁用梯度计算,节省内存
        for t in reversed(range(num_steps)):
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]

            t_tensor = torch.full((x_T.shape[0],), t, device=device, dtype=torch.long)

            # 预测噪声
            predicted_noise = model(x_T, t_tensor)

            # 计算均值
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            mean = sqrt_recip_alpha_t * (
                x_T - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            # 添加噪声
            if t > 0:
                alpha_t_bar_prev = alphas_cumprod_prev[t]
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_T)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

            # 在非掩码区域保持原始图像
            x_t = x_t * mask + original_image * (1 - mask)

    return x_t

def test_inpainting(model, test_images, device='cuda'):
    """
    测试图像修复

    参数:
        model: 扩散模型
        test_images: 测试图像 [N, C, H, W]
        device: 设备
    """
    # 创建噪声调度
    T = 1000
    alphas, betas, alphas_cumprod = get_schedule(T, 0.0001, 0.02)
    alphas_cumprod = alphas_cumprod.to(device)

    # 测试不同的掩码
    test_cases = []

    for i in range(min(4, len(test_images))):
        image = test_images[i:i+1]

        # 中心掩码
        mask1 = create_center_mask(image_size=32, mask_size=10)
        result1 = inpaint(model, image, mask1, T, alphas, betas, alphas_cumprod,
                        num_steps=1000, device=device)
        test_cases.append((image, mask1, result1, "中心掩码"))

        # 随机掩码
        mask2 = create_random_mask(image_size=32, num_holes=5, max_hole_size=8)
        result2 = inpaint(model, image, mask2, T, alphas, betas, alphas_cumprod,
                        num_steps=1000, device=device)
        test_cases.append((image, mask2, result2, "随机掩码"))

    # 可视化结果
    fig, axes = plt.subplots(len(test_cases), 4, figsize=(16, 4*len(test_cases)))

    for i, (original, mask, result, title) in enumerate(test_cases):  # enumerate同时获取索引和元素
        # 原始图像
        orig_img = (original + 1) / 2
        orig_img = orig_img.clamp(0, 1)
        axes[i, 0].imshow(orig_img.squeeze().permute(1, 2, 0).cpu())
        axes[i, 0].set_title('原始图像')
        axes[i, 0].axis('off')

        # 掩码
        axes[i, 1].imshow(mask.squeeze().cpu(), cmap='gray')
        axes[i, 1].set_title('掩码')
        axes[i, 1].axis('off')

        # 应用掩码的图像
        masked_img = orig_img * (1 - mask)
        axes[i, 2].imshow(masked_img.squeeze().permute(1, 2, 0).cpu())
        axes[i, 2].set_title('应用掩码')
        axes[i, 2].axis('off')

        # 修复后的图像
        result_img = (result + 1) / 2
        result_img = result_img.clamp(0, 1)
        axes[i, 3].imshow(result_img.squeeze().permute(1, 2, 0).cpu())
        axes[i, 3].set_title(f'修复结果 ({title})')
        axes[i, 3].axis('off')

    plt.tight_layout()
    os.makedirs('./samples', exist_ok=True)
    plt.savefig('./samples/inpainting_results.png', dpi=150, bbox_inches='tight')
    plt.show()

if __name__ == '__main__':
    # 加载模型
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # 这里需要加载训练好的模型
    # model = load_model(...)
    # model.to(device)

    # 加载测试图像
    from torchvision import datasets, transforms
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    test_dataset = datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=transform
    )

    test_images = []
    for i in range(4):
        img, _ = test_dataset[i]
        test_images.append(img)

    test_images = torch.stack(test_images)  # torch.stack沿新维度拼接张量

    # 测试图像修复
    test_inpainting(model, test_images, device)

4. 图像编辑实现

Python
# image_editing.py
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os

from text_encoder import TextEncoder

def get_schedule(T, beta_start, beta_end):
    """创建噪声调度表"""
    betas = torch.linspace(beta_start, beta_end, T)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, betas, alphas_cumprod

def edit_image(model, text_encoder, original_image, target_text,
              T, alphas, betas, alphas_cumprod,
              guidance_scale=7.5, num_steps=1000, device='cuda'):
    """
    图像编辑(基于文本)

    参数:
        model: 文本条件扩散模型
        text_encoder: 文本编码器
        original_image: 原始图像 [1, C, H, W]
        target_text: 目标文本描述
        T: 总步数
        alphas, betas, alphas_cumprod: 调度表
        guidance_scale: 引导强度
        num_steps: 采样步数
        device: 设备

    返回:
        编辑后的图像
    """
    model.eval()
    original_image = original_image.to(device)

    # 编码文本
    text_embeddings = text_encoder([target_text]).to(device)

    # 从原始图像开始(而不是纯噪声)
    x_T = original_image

    alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

    with torch.no_grad():
        for t in reversed(range(num_steps)):
            alpha_t = alphas[t]
            beta_t = betas[t]
            alpha_t_bar = alphas_cumprod[t]

            t_tensor = torch.full((x_T.shape[0],), t, device=device, dtype=torch.long)

            # 预测条件噪声
            noise_cond = model(x_T, t_tensor, text_embeddings)

            # 预测无条件噪声
            noise_uncond = model(x_T, t_tensor, torch.zeros_like(text_embeddings))

            # 组合预测
            predicted_noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

            # 更新
            sqrt_recip_alpha_t = 1 / torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t_bar = torch.sqrt(1 - alpha_t_bar)
            mean = sqrt_recip_alpha_t * (
                x_T - (beta_t / sqrt_one_minus_alpha_t_bar) * predicted_noise
            )

            if t > 0:
                alpha_t_bar_prev = alphas_cumprod_prev[t]
                posterior_variance = beta_t * (1 - alpha_t_bar_prev) / (1 - alpha_t_bar)
                noise = torch.randn_like(x_T)
                x_t = mean + torch.sqrt(posterior_variance) * noise
            else:
                x_t = mean

    return x_t

def test_image_editing(model, text_encoder, test_images, device='cuda'):
    """
    测试图像编辑

    参数:
        model: 文本条件扩散模型
        text_encoder: 文本编码器
        test_images: 测试图像 [N, C, H, W]
        device: 设备
    """
    # 创建噪声调度
    T = 1000
    alphas, betas, alphas_cumprod = get_schedule(T, 0.0001, 0.02)
    alphas_cumprod = alphas_cumprod.to(device)

    # 编辑提示
    edit_prompts = [
        "变成黑白",
        "增加亮度",
        "改变颜色",
        "添加艺术风格"
    ]

    # 测试编辑
    test_cases = []

    for i in range(min(2, len(test_images))):
        image = test_images[i:i+1]

        for prompt in edit_prompts:
            result = edit_image(model, text_encoder, image, prompt,
                            T, alphas, betas, alphas_cumprod,
                            guidance_scale=7.5, num_steps=1000, device=device)
            test_cases.append((image, result, prompt))

    # 可视化结果
    fig, axes = plt.subplots(len(test_cases), 2, figsize=(10, 5*len(test_cases)))

    for i, (original, result, prompt) in enumerate(test_cases):
        # 原始图像
        orig_img = (original + 1) / 2
        orig_img = orig_img.clamp(0, 1)
        axes[i, 0].imshow(orig_img.squeeze().permute(1, 2, 0).cpu())
        axes[i, 0].set_title('原始图像')
        axes[i, 0].axis('off')

        # 编辑后的图像
        result_img = (result + 1) / 2
        result_img = result_img.clamp(0, 1)
        axes[i, 1].imshow(result_img.squeeze().permute(1, 2, 0).cpu())
        axes[i, 1].set_title(f'编辑结果: {prompt}')
        axes[i, 1].axis('off')

    plt.tight_layout()
    os.makedirs('./samples', exist_ok=True)
    plt.savefig('./samples/image_editing_results.png', dpi=150, bbox_inches='tight')
    plt.show()

if __name__ == '__main__':
    # 加载模型
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # 这里需要加载训练好的文本条件模型
    # model = load_text_conditioned_model(...)
    # model.to(device)

    # 加载文本编码器
    # text_encoder = TextEncoder()
    # text_encoder.to(device)

    # 加载测试图像
    from torchvision import datasets, transforms
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    test_dataset = datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=transform
    )

    test_images = []
    for i in range(2):
        img, _ = test_dataset[i]
        test_images.append(img)

    test_images = torch.stack(test_images)

    # 测试图像编辑
    test_image_editing(model, text_encoder, test_images, device)

5. 完整示例

Python
# complete_example.py
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

from mask_utils import create_center_mask, create_random_mask, visualize_mask
from inpainting import inpaint, test_inpainting
from image_editing import edit_image, test_image_editing

def main():
    """主函数"""
    print("图像修复与编辑示例")

    # 设置设备
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # 加载测试图像
    print("\n加载测试图像...")
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    test_dataset = datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=transform
    )

    test_images = []
    for i in range(4):
        img, _ = test_dataset[i]
        test_images.append(img)

    test_images = torch.stack(test_images)
    print(f"加载了 {len(test_images)} 张测试图像")

    # 可视化掩码
    print("\n可视化掩码...")
    mask1 = create_center_mask(image_size=32, mask_size=10)
    mask2 = create_random_mask(image_size=32, num_holes=5, max_hole_size=8)

    visualize_mask(test_images[0], mask1, 'samples/mask1_visualization.png')
    visualize_mask(test_images[0], mask2, 'samples/mask2_visualization.png')

    # 测试图像修复
    print("\n测试图像修复...")
    # 这里需要加载训练好的模型
    # model = load_model(...)
    # model.to(device)

    # test_inpainting(model, test_images, device)
    print("  ✓ 图像修复测试完成")

    # 测试图像编辑
    print("\n测试图像编辑...")
    # 这里需要加载训练好的文本条件模型
    # text_model = load_text_conditioned_model(...)
    # text_encoder = TextEncoder()
    # text_model.to(device)
    # text_encoder.to(device)

    # test_image_editing(text_model, text_encoder, test_images, device)
    print("  ✓ 图像编辑测试完成")

    print("\n所有测试完成!")

if __name__ == '__main__':
    main()

6. 运行示例

6.1 创建掩码

Bash
# 创建并可视化掩码
python mask_utils.py

6.2 测试图像修复

Bash
# 测试图像修复
python inpainting.py

6.3 测试图像编辑

Bash
# 测试图像编辑
python image_editing.py

6.4 运行完整示例

Bash
# 运行完整示例
python complete_example.py

7. 项目总结

7.1 完成的工作

  1. ✅ 实现了多种掩码创建方法
  2. ✅ 实现了图像修复功能
  3. ✅ 实现了图像编辑功能
  4. ✅ 提供了完整的测试示例
  5. ✅ 可视化了修复和编辑结果

7.2 关键技术点

技术点 说明
掩码创建 定义需要修复或编辑的区域
图像修复 在掩码区域进行扩散采样
图像编辑 基于文本编辑图像
无分类器引导 提高编辑质量
保持原始内容 在非编辑区域保持原始图像

7.3 改进方向

  1. 使用更大的模型:提高修复和编辑质量
  2. 使用更好的文本编码器:如T5
  3. 实现更复杂的编辑:如风格迁移
  4. 优化采样速度:使用DDIM等技术
  5. 支持更高分辨率:使用LDM等技术

8. 常见问题

Q1: 修复效果不好?

A: 可以尝试: - 增加采样步数 - 使用更大的模型 - 调整掩码大小 - 使用更好的训练数据

Q2: 编辑效果不符合预期?

A: 可以尝试: - 改进文本描述 - 调整引导强度 - 增加采样步数 - 使用更大的模型

Q3: 如何创建自定义掩码?

A: 可以使用: - 图像编辑软件(如Photoshop) - 在线工具 - 编程方式(如mask_utils.py中的方法)


9. 下一步

完成本项目后,你可以:

  1. 尝试其他应用:如去水印、物体移除
  2. 实现更复杂的编辑:如风格迁移
  3. 优化性能:使用DDIM、LDM等技术
  4. 应用到实际场景:如照片修复、图像增强
  5. 集成到应用中:如Web应用、移动应用

10. 总结

本项目展示了扩散模型在图像修复和编辑方面的强大能力:

图像修复: - 可以修复图像中的缺失或损坏区域 - 支持多种掩码类型 - 保持非修复区域的原始内容

图像编辑: - 可以根据文本描述编辑图像 - 使用无分类器引导提高质量 - 支持多种编辑操作

这些技术可以应用于许多实际场景,如照片修复、去水印、物体移除等。


项目完成!恭喜你掌握了扩散模型的图像修复和编辑技术!