跳转至

01 - 环境搭建与数据准备

学习时间: 2小时 重要性: ⭐⭐⭐⭐⭐ 动手实践的第一步


🎯 学习目标

完成本章后,你将能够: - 搭建完整的扩散模型开发环境 - 准备和处理训练数据 - 理解数据预处理的重要性 - 实现高效的数据加载器


1. 环境配置

1.1 硬件要求

最低配置: - GPU: NVIDIA GTX 1060 6GB 或更高 - 内存: 16GB RAM - 存储: 50GB 可用空间

推荐配置: - GPU: NVIDIA RTX 3060 12GB 或更高 - 内存: 32GB RAM - 存储: 100GB SSD

1.2 软件环境

Bash
# 创建conda环境
conda create -n diffusion python=3.9
conda activate diffusion

# 安装PyTorch(根据你的CUDA版本选择)
# CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# 或 CUDA 12.1
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# 安装其他依赖
pip install numpy matplotlib pillow tqdm tensorboard einops

# 可选:安装wandb用于实验跟踪
pip install wandb

1.3 验证安装

Python
import torch
import torchvision

print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA版本: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"显存: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

2. 数据集准备

2.1 数据集选择

初学者推荐: - MNIST: 手写数字,28x28,10类,适合快速验证 - CIFAR-10: 彩色小图像,32x32,10类,标准基准 - CelebA: 人脸图像,64x64或更高,无条件生成

进阶数据集: - ImageNet: 大规模图像分类数据集 - LSUN: 场景理解数据集 - 自定义数据集: 你自己的图像数据

2.2 下载MNIST数据集

Python
import torchvision
import torchvision.transforms as transforms

# 定义数据变换
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量 [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1, 1]
])

# 下载训练集
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# 下载测试集
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
print(f"图像形状: {train_dataset[0][0].shape}")
print(f"标签: {train_dataset[0][1]}")

2.3 下载CIFAR-10数据集

Python
# CIFAR-10数据变换
transform_cifar = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # 数据增强
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # RGB三通道
])

# 下载CIFAR-10
train_dataset_cifar = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_cifar
)

test_dataset_cifar = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_cifar
)

print(f"训练集大小: {len(train_dataset_cifar)}")
print(f"类别: {train_dataset_cifar.classes}")

2.4 自定义数据集

Python
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomImageDataset(Dataset):
    """
    自定义图像数据集
    """
    def __init__(self, root_dir, transform=None):
        """
        参数:
            root_dir: 图像文件夹路径
            transform: 数据变换
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [
            f for f in os.listdir(root_dir)
            if f.endswith(('.png', '.jpg', '.jpeg'))
        ]

    def __len__(self):  # __len__定义len()行为
        return len(self.image_files)

    def __getitem__(self, idx):  # __getitem__定义索引访问行为
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0  # 返回图像和虚拟标签

# 使用示例
# custom_dataset = CustomImageDataset(
#     root_dir='path/to/your/images',
#     transform=transform_cifar
# )

3. 数据预处理

3.1 为什么需要预处理

  1. 归一化: 将像素值从 [0, 255] 转换到 [-1, 1] 或 [0, 1]
  2. 统一尺寸: 确保所有图像大小一致
  3. 数据增强: 增加数据多样性,防止过拟合
  4. 去噪: 清理低质量图像

3.2 数据变换流程

Python
def get_transforms(image_size=32, train=True):
    """
    获取数据变换

    参数:
        image_size: 输出图像尺寸
        train: 是否为训练模式

    返回:
        transform: 组合的数据变换
    """
    transforms_list = []

    # 调整图像大小
    transforms_list.append(transforms.Resize(image_size))

    # 中心裁剪或随机裁剪
    if train:
        transforms_list.append(transforms.RandomCrop(image_size, padding=4))
    else:
        transforms_list.append(transforms.CenterCrop(image_size))

    # 随机水平翻转(仅训练)
    if train:
        transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))

    # 转换为张量
    transforms_list.append(transforms.ToTensor())

    # 归一化到 [-1, 1]
    transforms_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

    return transforms.Compose(transforms_list)

# 测试变换
test_transform = get_transforms(image_size=32, train=True)
print(test_transform)

3.3 数据可视化

Python
import matplotlib.pyplot as plt
import numpy as np

def denormalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """
    反归一化,将张量转换回可显示的图像
    """
    mean = torch.tensor(mean).view(3, 1, 1)  # 重塑张量形状
    std = torch.tensor(std).view(3, 1, 1)
    tensor = tensor * std + mean
    return torch.clamp(tensor, 0, 1)

def visualize_dataset(dataset, num_samples=16, figsize=(12, 12)):
    """
    可视化数据集样本
    """
    fig, axes = plt.subplots(4, 4, figsize=figsize)
    axes = axes.flatten()

    for i in range(num_samples):
        idx = np.random.randint(len(dataset))
        image, label = dataset[idx]

        # 反归一化
        image = denormalize(image)

        # 转换为numpy数组
        image_np = image.permute(1, 2, 0).numpy()

        axes[i].imshow(image_np)
        if hasattr(dataset, 'classes'):  # hasattr检查对象是否有某属性
            axes[i].set_title(f'Class: {dataset.classes[label]}')
        else:
            axes[i].set_title(f'Label: {label}')
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig('dataset_visualization.png', dpi=150)
    plt.show()

# 可视化CIFAR-10
# visualize_dataset(train_dataset_cifar)

4. 数据加载器

4.1 创建DataLoader

Python
from torch.utils.data import DataLoader

def create_dataloaders(
    dataset_name='cifar10',
    batch_size=128,
    num_workers=4,
    image_size=32
):
    """
    创建数据加载器

    参数:
        dataset_name: 数据集名称 ('mnist', 'cifar10', 'custom')
        batch_size: 批次大小
        num_workers: 数据加载线程数
        image_size: 图像尺寸

    返回:
        train_loader, test_loader
    """
    # 获取数据变换
    train_transform = get_transforms(image_size, train=True)
    test_transform = get_transforms(image_size, train=False)

    # 加载数据集
    if dataset_name == 'mnist':
        train_dataset = torchvision.datasets.MNIST(
            root='./data', train=True, download=True, transform=train_transform
        )
        test_dataset = torchvision.datasets.MNIST(
            root='./data', train=False, download=True, transform=test_transform
        )
    elif dataset_name == 'cifar10':
        train_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=train_transform
        )
        test_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=test_transform
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # 创建数据加载器
    train_loader = DataLoader(  # DataLoader批量加载数据
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, test_loader

# 创建数据加载器
train_loader, test_loader = create_dataloaders(
    dataset_name='cifar10',
    batch_size=128,
    num_workers=4,
    image_size=32
)

print(f"训练批次数量: {len(train_loader)}")
print(f"测试批次数量: {len(test_loader)}")

# 测试数据加载
for batch_idx, (images, labels) in enumerate(train_loader):  # enumerate同时获取索引和元素
    print(f"批次 {batch_idx}:")
    print(f"  图像形状: {images.shape}")
    print(f"  标签形状: {labels.shape}")
    print(f"  图像范围: [{images.min():.2f}, {images.max():.2f}]")
    break

4.2 优化数据加载

Python
class InfiniteDataLoader:
    """
    无限数据加载器,用于持续训练
    """
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.data_iter = iter(dataloader)

    def __next__(self):
        try:  # try/except捕获异常
            batch = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.dataloader)
            batch = next(self.data_iter)
        return batch

    def __iter__(self):
        return self

# 使用示例
# infinite_loader = InfiniteDataLoader(train_loader)
# for i in range(1000000):
#     images, labels = next(infinite_loader)
#     # 训练代码...

5. 数据管道完整代码

Python
# data_utils.py
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np

class DiffusionDataModule:
    """
    扩散模型数据模块
    """
    def __init__(
        self,
        dataset_name='cifar10',
        data_dir='./data',
        batch_size=128,
        num_workers=4,
        image_size=32,
        train_split=0.9
    ):
        self.dataset_name = dataset_name
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size
        self.train_split = train_split

        # 数据变换
        self.train_transform = self._get_transforms(train=True)
        self.test_transform = self._get_transforms(train=False)

    def _get_transforms(self, train=True):
        """获取数据变换"""
        transforms_list = [
            transforms.Resize(self.image_size),
        ]

        if train:
            transforms_list.extend([
                transforms.RandomCrop(self.image_size, padding=4),
                transforms.RandomHorizontalFlip(p=0.5),
            ])
        else:
            transforms_list.append(transforms.CenterCrop(self.image_size))

        transforms_list.extend([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        return transforms.Compose(transforms_list)

    def setup(self):
        """设置数据集"""
        if self.dataset_name == 'mnist':
            self.train_dataset = torchvision.datasets.MNIST(
                self.data_dir, train=True, download=True, transform=self.train_transform
            )
            self.test_dataset = torchvision.datasets.MNIST(
                self.data_dir, train=False, download=True, transform=self.test_transform
            )
            self.num_classes = 10
            self.channels = 1

        elif self.dataset_name == 'cifar10':
            self.train_dataset = torchvision.datasets.CIFAR10(
                self.data_dir, train=True, download=True, transform=self.train_transform
            )
            self.test_dataset = torchvision.datasets.CIFAR10(
                self.data_dir, train=False, download=True, transform=self.test_transform
            )
            self.num_classes = 10
            self.channels = 3

        elif self.dataset_name == 'celeba':
            # CelebA需要单独下载
            self.train_dataset = torchvision.datasets.CelebA(
                self.data_dir, split='train', download=True, transform=self.train_transform
            )
            self.test_dataset = torchvision.datasets.CelebA(
                self.data_dir, split='test', download=True, transform=self.test_transform
            )
            self.num_classes = None
            self.channels = 3
        else:
            raise ValueError(f"Unknown dataset: {self.dataset_name}")

    def train_dataloader(self):
        """训练数据加载器"""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True
        )

    def test_dataloader(self):
        """测试数据加载器"""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def get_batch(self, split='train'):
        """获取一个批次"""
        loader = self.train_dataloader() if split == 'train' else self.test_dataloader()
        return next(iter(loader))

# 测试
if __name__ == "__main__":
    # 创建数据模块
    data_module = DiffusionDataModule(
        dataset_name='cifar10',
        batch_size=16,
        image_size=32
    )

    # 设置数据集
    data_module.setup()

    # 获取数据加载器
    train_loader = data_module.train_dataloader()

    # 测试
    images, labels = next(iter(train_loader))
    print(f"图像批次形状: {images.shape}")
    print(f"标签批次形状: {labels.shape}")
    print(f"图像范围: [{images.min():.3f}, {images.max():.3f}]")
    print(f"数据集类别数: {data_module.num_classes}")
    print(f"图像通道数: {data_module.channels}")

6. 本章总结

核心要点

  1. 环境配置
  2. PyTorch + CUDA
  3. 必要的Python包
  4. 验证GPU可用性

  5. 数据准备

  6. 选择合适的数据集
  7. 下载和预处理
  8. 自定义数据集支持

  9. 数据变换

  10. 归一化到 [-1, 1]
  11. 数据增强
  12. 统一尺寸

  13. 数据加载

  14. 高效的DataLoader
  15. 多线程加载
  16. 无限数据流

关键代码

Python
# 数据变换
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 数据加载器
DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

📝 自测问题

基础问题

  1. 数据预处理
  2. 为什么需要将图像归一化到 [-1, 1]?
  3. 数据增强在扩散模型中起什么作用?
  4. 如何选择合适的图像尺寸?

  5. 数据加载

  6. num_workers参数的作用是什么?
  7. 为什么要使用pin_memory=True?
  8. drop_last=True有什么影响?

编程练习

  1. 实现一个自定义数据集类
  2. 可视化不同数据增强的效果
  3. 实现数据加载的性能测试

思考题

  1. 如果图像不归一化,会对训练产生什么影响?
  2. 如何设计适合高分辨率图像的数据管道?
  3. 数据质量对扩散模型有多重要?

🔗 下一步

环境搭建完成后,我们将学习UNet架构详解,这是扩散模型的核心网络结构。

→ 下一步:02-UNet架构详解.md