跳转至

03 - 数据加载与预处理

学习时间: 2.5小时 重要性: ⭐⭐⭐⭐⭐ 高质量数据是训练成功的基础


🎯 学习目标

完成本章后,你将能够: - 理解扩散模型对数据的要求 - 掌握常用的数据集加载方法 - 实现高效的数据预处理流程 - 创建自定义数据集类 - 优化数据加载性能


1. 数据要求概述

1.1 扩散模型对数据的要求

要求 说明 原因
图像尺寸 通常32×32到512×512 影响模型大小和训练速度
像素值范围 [0, 1] 或 [-1, 1] 便于模型学习
数据质量 高质量、多样化 影响生成质量
数据量 越多越好 提高泛化能力
标注 可选(用于条件生成) 实现可控生成

1.2 常用数据集

数据集 图像数量 图像尺寸 特点
CIFAR-10 60,000 32×32 适合快速实验
ImageNet 1.28M 224×224 大规模、高质量
LSUN 数百万 256×256 场景图像
FFHQ 70,000 1024×1024 高质量人脸
LAION-5B 50亿 多种尺寸 超大规模

2. 基础数据加载

2.1 CIFAR-10数据集

CIFAR-10是学习和实验扩散模型的理想数据集:

Python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

def load_cifar10(batch_size=128, image_size=32, augment=True):
    """
    加载CIFAR-10数据集

    参数:
        batch_size: 批量大小
        image_size: 目标图像大小
        augment: 是否使用数据增强

    返回:
        train_loader, test_loader
    """
    # 训练数据变换
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    # 测试数据变换
    test_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # 加载数据集
    train_dataset = datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=train_transform
    )

    test_dataset = datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=test_transform
    )

    # 创建数据加载器
    train_loader = DataLoader(  # DataLoader批量加载数据
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    return train_loader, test_loader

# 使用示例
train_loader, test_loader = load_cifar10(batch_size=128, image_size=32)

print(f"训练集大小: {len(train_loader.dataset)}")
print(f"测试集大小: {len(test_loader.dataset)}")
print(f"批量大小: {train_loader.batch_size}")
print(f"训练批次数: {len(train_loader)}")

2.2 可视化数据集

Python
def visualize_cifar10(dataloader, num_samples=16, save_path='cifar10_samples.png'):
    """
    可视化CIFAR-10样本

    参数:
        dataloader: 数据加载器
        num_samples: 显示的样本数
        save_path: 保存路径
    """
    # 获取一个batch
    images, labels = next(iter(dataloader))

    # 反归一化
    images = images * 0.5 + 0.5  # 从[-1, 1]到[0, 1]
    images = images.clamp(0, 1)

    # 类别名称
    classes = ['飞机', '汽车', '鸟', '猫', '鹿',
               '狗', '青蛙', '马', '船', '卡车']

    # 创建网格
    from torchvision.utils import make_grid
    grid = make_grid(images[:num_samples], nrow=4, padding=2, normalize=False)

    # 显示
    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0))
    plt.title('CIFAR-10 样本', fontsize=16)
    plt.axis('off')
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

    # 打印标签
    print("样本标签:")
    for i in range(min(num_samples, len(labels))):
        print(f"  {i+1}. {classes[labels[i]]}")

# 可视化
visualize_cifar10(train_loader, num_samples=16)

3. ImageNet数据集

3.1 加载ImageNet

Python
import os
from torch.utils.data import Dataset
from PIL import Image

class ImageNetDataset(Dataset):
    """自定义ImageNet数据集"""

    def __init__(self, root_dir, split='train', transform=None):
        """
        参数:
            root_dir: 数据集根目录
            split: 'train' 或 'val'
            transform: 数据变换
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        # 获取图像路径
        self.image_paths = []
        self.labels = []

        if split == 'train':
            data_dir = os.path.join(root_dir, 'train')
        else:
            data_dir = os.path.join(root_dir, 'val')

        # 遍历类别文件夹
        for class_idx, class_name in enumerate(sorted(os.listdir(data_dir))):  # enumerate同时获取索引和元素
            class_dir = os.path.join(data_dir, class_name)

            if os.path.isdir(class_dir):
                for img_name in os.listdir(class_dir):
                    if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                        img_path = os.path.join(class_dir, img_name)
                        self.image_paths.append(img_path)
                        self.labels.append(class_idx)

    def __len__(self):  # __len__定义len()行为
        return len(self.image_paths)

    def __getitem__(self, idx):  # __getitem__定义索引访问行为
        # 加载图像
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        # 应用变换
        if self.transform:
            image = self.transform(image)

        return image, label

def load_imagenet(root_dir, batch_size=128, image_size=256, split='train'):
    """
    加载ImageNet数据集

    参数:
        root_dir: 数据集根目录
        batch_size: 批量大小
        image_size: 目标图像大小
        split: 'train' 或 'val'

    返回:
        dataloader
    """
    # 数据变换
    if split == 'train':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(int(image_size * 1.14)),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    # 创建数据集
    dataset = ImageNetDataset(root_dir, split=split, transform=transform)

    # 创建数据加载器
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=8,
        pin_memory=True,
        drop_last=(split == 'train')
    )

    return dataloader

# 使用示例
# train_loader = load_imagenet('/path/to/imagenet', batch_size=128, split='train')
# val_loader = load_imagenet('/path/to/imagenet', batch_size=128, split='val')

4. 自定义数据集

4.1 从文件夹加载图像

Python
class ImageFolderDataset(Dataset):
    """从文件夹加载图像的自定义数据集"""

    def __init__(self, root_dir, transform=None, extensions=('.png', '.jpg', '.jpeg', '.bmp')):
        """
        参数:
            root_dir: 图像文件夹根目录
            transform: 数据变换
            extensions: 支持的图像扩展名
        """
        self.root_dir = root_dir
        self.transform = transform
        self.extensions = extensions

        # 收集所有图像路径
        self.image_paths = []

        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(extensions):
                    self.image_paths.append(os.path.join(root, file))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 加载图像
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')

        # 应用变换
        if self.transform:
            image = self.transform(image)

        return image, 0  # 返回图像和虚拟标签

# 使用示例
custom_dataset = ImageFolderDataset(
    root_dir='./my_images',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
)

custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)

4.2 带标签的自定义数据集

Python
class LabeledImageDataset(Dataset):
    """带标签的自定义图像数据集"""

    def __init__(self, data_list, transform=None):
        """
        参数:
            data_list: 数据列表,每个元素为 (image_path, label)
            transform: 数据变换
        """
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        img_path, label = self.data_list[idx]

        # 加载图像
        image = Image.open(img_path).convert('RGB')

        # 应用变换
        if self.transform:
            image = self.transform(image)

        return image, label

# 创建数据列表
data_list = [
    ('./images/cat1.jpg', 0),
    ('./images/cat2.jpg', 0),
    ('./images/dog1.jpg', 1),
    ('./images/dog2.jpg', 1),
]

# 创建数据集
labeled_dataset = LabeledImageDataset(
    data_list,
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
)

labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)

5. 数据预处理技巧

5.1 图像归一化

Python
def compute_dataset_statistics(dataloader):
    """
    计算数据集的均值和标准差

    参数:
        dataloader: 数据加载器

    返回:
        mean, std
    """
    mean = 0.0
    std = 0.0
    total_images = 0

    for images, _ in dataloader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)  # 重塑张量形状
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_images += batch_samples

    mean /= total_images
    std /= total_images

    return mean, std

# 计算CIFAR-10的统计信息
# 注意:这需要先加载未归一化的数据
no_norm_transform = transforms.Compose([
    transforms.ToTensor()
])

temp_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=no_norm_transform
)

temp_loader = DataLoader(temp_dataset, batch_size=128, shuffle=False)
mean, std = compute_dataset_statistics(temp_loader)

print(f"数据集均值: {mean}")
print(f"数据集标准差: {std}")

5.2 数据增强策略

Python
def get_advanced_augmentation(image_size=256):
    """
    获取高级数据增强

    参数:
        image_size: 目标图像大小

    返回:
        数据增强变换
    """
    return transforms.Compose([
        # 随机裁剪和缩放
        transforms.RandomResizedCrop(
            image_size,
            scale=(0.8, 1.0),
            ratio=(0.9, 1.1)
        ),

        # 随机水平翻转
        transforms.RandomHorizontalFlip(p=0.5),

        # 随机旋转
        transforms.RandomRotation(degrees=15),

        # 颜色抖动
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1
        ),

        # 随机仿射变换
        transforms.RandomAffine(
            degrees=0,
            translate=(0.1, 0.1),
            scale=(0.9, 1.1),
            shear=5
        ),

        # 转换为张量
        transforms.ToTensor(),

        # 归一化
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5]
        )
    ])

# 使用示例
augmented_transform = get_advanced_augmentation(image_size=256)

5.3 自定义增强

Python
import random
import numpy as np

class RandomGaussianBlur:
    """随机高斯模糊"""

    def __init__(self, p=0.5, kernel_size=5):
        self.p = p
        self.kernel_size = kernel_size

    def __call__(self, img):  # __call__使实例可像函数一样调用
        if random.random() < self.p:
            from PIL import ImageFilter
            img = img.filter(ImageFilter.GaussianBlur(self.kernel_size))
        return img

class RandomCutout:
    """随机遮挡"""

    def __init__(self, p=0.5, scale=(0.02, 0.2)):
        self.p = p
        self.scale = scale

    def __call__(self, img):
        if random.random() < self.p:
            w, h = img.size
            cutout_size = random.uniform(*self.scale)
            cutout_w = int(w * cutout_size)
            cutout_h = int(h * cutout_size)

            x = random.randint(0, w - cutout_w)
            y = random.randint(0, h - cutout_h)

            img_array = np.array(img)  # np.array创建NumPy数组
            img_array[y:y+cutout_h, x:x+cutout_w] = 0
            img = Image.fromarray(img_array)

        return img

# 使用自定义增强
custom_augmentation = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    RandomGaussianBlur(p=0.3, kernel_size=3),
    RandomCutout(p=0.3, scale=(0.05, 0.15)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

6. 性能优化

6.1 多进程数据加载

Python
def optimized_dataloader(dataset, batch_size=128, num_workers=4, pin_memory=True):
    """
    优化的数据加载器

    参数:
        dataset: 数据集
        batch_size: 批量大小
        num_workers: 工作进程数
        pin_memory: 是否使用固定内存

    返回:
        优化的数据加载器
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True,  # 保持工作进程
        prefetch_factor=2  # 预取因子
    )

# 使用示例
train_loader = optimized_dataloader(
    train_dataset,
    batch_size=128,
    num_workers=4,
    pin_memory=True
)

6.2 缓存数据

Python
class CachedDataset(Dataset):
    """缓存数据集到内存"""

    def __init__(self, dataset, transform=None, cache_size=10000):
        """
        参数:
            dataset: 原始数据集
            transform: 数据变换
            cache_size: 缓存大小
        """
        self.dataset = dataset
        self.transform = transform
        self.cache_size = cache_size
        self.cache = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # 检查缓存
        if idx in self.cache:
            return self.cache[idx]

        # 从原始数据集获取
        data = self.dataset[idx]

        # 应用变换
        if self.transform:
            data = self.transform(data)

        # 缓存
        if len(self.cache) < self.cache_size:
            self.cache[idx] = data

        return data

# 使用示例
cached_dataset = CachedDataset(
    train_dataset,
    transform=train_transform,
    cache_size=10000
)

7. 数据验证

7.1 检查数据完整性

Python
def validate_dataset(dataset, num_samples=100):
    """
    验证数据集完整性

    参数:
        dataset: 数据集
        num_samples: 检查的样本数
    """
    print("开始验证数据集...")

    errors = []
    for i in range(min(num_samples, len(dataset))):
        try:  # try/except捕获异常
            image, label = dataset[i]

            # 检查图像
            if isinstance(image, torch.Tensor):  # isinstance检查类型
                if torch.isnan(image).any():  # any()任一为True则返回True
                    errors.append(f"样本 {i}: 图像包含NaN")
                if torch.isinf(image).any():
                    errors.append(f"样本 {i}: 图像包含Inf")
            else:
                if image is None:
                    errors.append(f"样本 {i}: 图像为None")

        except Exception as e:
            errors.append(f"样本 {i}: {str(e)}")

    if errors:
        print(f"发现 {len(errors)} 个错误:")
        for error in errors[:10]:  # 只显示前10个
            print(f"  - {error}")
    else:
        print("✓ 数据集验证通过!")

    return len(errors) == 0

# 验证数据集
is_valid = validate_dataset(train_dataset, num_samples=100)

7.2 数据统计信息

Python
def analyze_dataset(dataloader):
    """
    分析数据集统计信息

    参数:
        dataloader: 数据加载器
    """
    print("分析数据集...")

    all_images = []
    all_labels = []

    for images, labels in dataloader:
        all_images.append(images)
        all_labels.append(labels)

        if len(all_images) * images.size(0) >= 1000:  # 分析1000个样本
            break

    all_images = torch.cat(all_images, dim=0)  # torch.cat沿已有维度拼接张量
    all_labels = torch.cat(all_labels, dim=0)

    # 图像统计
    print(f"\n图像统计:")
    print(f"  形状: {all_images.shape}")
    print(f"  最小值: {all_images.min().item():.4f}")  # 将单元素张量转为Python数值
    print(f"  最大值: {all_images.max().item():.4f}")
    print(f"  均值: {all_images.mean().item():.4f}")
    print(f"  标准差: {all_images.std().item():.4f}")

    # 标签统计
    if all_labels.numel() > 0:
        print(f"\n标签统计:")
        unique_labels, counts = torch.unique(all_labels, return_counts=True)
        for label, count in zip(unique_labels, counts):  # zip按位置配对
            print(f"  类别 {label.item()}: {count.item()} 个样本")

# 分析数据集
analyze_dataset(train_loader)

8. 完整的数据加载流程

Python
def prepare_data(dataset_name='cifar10', batch_size=128, image_size=32,
                 num_workers=4, augment=True):
    """
    完整的数据准备流程

    参数:
        dataset_name: 数据集名称 ('cifar10', 'imagenet', 'custom')
        batch_size: 批量大小
        image_size: 目标图像大小
        num_workers: 工作进程数
        augment: 是否使用数据增强

    返回:
        train_loader, test_loader
    """
    print(f"准备 {dataset_name} 数据集...")

    if dataset_name == 'cifar10':
        train_loader, test_loader = load_cifar10(
            batch_size=batch_size,
            image_size=image_size,
            augment=augment
        )

    elif dataset_name == 'imagenet':
        train_loader = load_imagenet(
            root_dir='/path/to/imagenet',
            batch_size=batch_size,
            image_size=image_size,
            split='train'
        )
        test_loader = load_imagenet(
            root_dir='/path/to/imagenet',
            batch_size=batch_size,
            image_size=image_size,
            split='val'
        )

    elif dataset_name == 'custom':
        # 自定义数据集
        train_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5) if augment else transforms.Lambda(lambda x: x),  # lambda匿名函数
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        train_dataset = ImageFolderDataset(
            root_dir='./train_images',
            transform=train_transform
        )

        test_dataset = ImageFolderDataset(
            root_dir='./test_images',
            transform=train_transform
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )

    else:
        raise ValueError(f"未知的数据集: {dataset_name}")

    # 打印信息
    print(f"训练集大小: {len(train_loader.dataset)}")
    print(f"测试集大小: {len(test_loader.dataset)}")
    print(f"批量大小: {batch_size}")
    print(f"训练批次数: {len(train_loader)}")
    print(f"测试批次数: {len(test_loader)}")

    return train_loader, test_loader

# 使用示例
train_loader, test_loader = prepare_data(
    dataset_name='cifar10',
    batch_size=128,
    image_size=32,
    num_workers=4,
    augment=True
)

9. 总结

9.1 核心概念回顾

概念 说明
数据加载器 批量、打乱、多进程加载
数据增强 提高模型泛化能力
数据归一化 标准化输入范围
自定义数据集 处理特定格式的数据
性能优化 多进程、缓存、预取

9.2 最佳实践

  1. 选择合适的数据集:根据任务选择合适规模的数据集
  2. 使用数据增强:提高模型泛化能力
  3. 优化加载性能:使用多进程和缓存
  4. 验证数据质量:检查数据完整性和统计信息
  5. 保持一致性:训练和测试使用相同的归一化

9.3 学习建议

  1. 从简单开始:先用CIFAR-10等小数据集练习
  2. 逐步复杂:再尝试ImageNet等大数据集
  3. 自定义数据:学习处理自己的数据
  4. 性能调优:优化数据加载速度

10. 推荐资源

文档

  • PyTorch DataLoader文档
  • torchvision.datasets文档
  • PIL/Pillow文档

工具

  • albumentations: 高级图像增强库
  • imgaug: 图像增强库
  • fastai: 简化的数据加载API

11. 自测问题

  1. 数据增强有哪些常用方法?
  2. 如何优化数据加载的性能?
  3. 为什么需要归一化图像数据?
  4. 如何创建自定义数据集?
  5. 如何验证数据集的完整性?

下一章: 04-完整DDPM实现 - 将所有组件整合在一起