03 - 数据加载与预处理¶
学习时间: 2.5小时 重要性: ⭐⭐⭐⭐⭐ 高质量数据是训练成功的基础
🎯 学习目标¶
完成本章后,你将能够: - 理解扩散模型对数据的要求 - 掌握常用的数据集加载方法 - 实现高效的数据预处理流程 - 创建自定义数据集类 - 优化数据加载性能
1. 数据要求概述¶
1.1 扩散模型对数据的要求¶
| 要求 | 说明 | 原因 |
|---|---|---|
| 图像尺寸 | 通常32×32到512×512 | 影响模型大小和训练速度 |
| 像素值范围 | [0, 1] 或 [-1, 1] | 便于模型学习 |
| 数据质量 | 高质量、多样化 | 影响生成质量 |
| 数据量 | 越多越好 | 提高泛化能力 |
| 标注 | 可选(用于条件生成) | 实现可控生成 |
1.2 常用数据集¶
| 数据集 | 图像数量 | 图像尺寸 | 特点 |
|---|---|---|---|
| CIFAR-10 | 60,000 | 32×32 | 适合快速实验 |
| ImageNet | 1.28M | 224×224 | 大规模、高质量 |
| LSUN | 数百万 | 256×256 | 场景图像 |
| FFHQ | 70,000 | 1024×1024 | 高质量人脸 |
| LAION-5B | 50亿 | 多种尺寸 | 超大规模 |
2. 基础数据加载¶
2.1 CIFAR-10数据集¶
CIFAR-10是学习和实验扩散模型的理想数据集:
Python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
def load_cifar10(batch_size=128, image_size=32, augment=True):
"""
加载CIFAR-10数据集
参数:
batch_size: 批量大小
image_size: 目标图像大小
augment: 是否使用数据增强
返回:
train_loader, test_loader
"""
# 训练数据变换
if augment:
train_transform = transforms.Compose([
transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
else:
train_transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 测试数据变换
test_transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 加载数据集
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=train_transform
)
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=test_transform
)
# 创建数据加载器
train_loader = DataLoader( # DataLoader批量加载数据
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True
)
return train_loader, test_loader
# 使用示例
train_loader, test_loader = load_cifar10(batch_size=128, image_size=32)
print(f"训练集大小: {len(train_loader.dataset)}")
print(f"测试集大小: {len(test_loader.dataset)}")
print(f"批量大小: {train_loader.batch_size}")
print(f"训练批次数: {len(train_loader)}")
2.2 可视化数据集¶
Python
def visualize_cifar10(dataloader, num_samples=16, save_path='cifar10_samples.png'):
"""
可视化CIFAR-10样本
参数:
dataloader: 数据加载器
num_samples: 显示的样本数
save_path: 保存路径
"""
# 获取一个batch
images, labels = next(iter(dataloader))
# 反归一化
images = images * 0.5 + 0.5 # 从[-1, 1]到[0, 1]
images = images.clamp(0, 1)
# 类别名称
classes = ['飞机', '汽车', '鸟', '猫', '鹿',
'狗', '青蛙', '马', '船', '卡车']
# 创建网格
from torchvision.utils import make_grid
grid = make_grid(images[:num_samples], nrow=4, padding=2, normalize=False)
# 显示
plt.figure(figsize=(12, 12))
plt.imshow(grid.permute(1, 2, 0))
plt.title('CIFAR-10 样本', fontsize=16)
plt.axis('off')
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()
# 打印标签
print("样本标签:")
for i in range(min(num_samples, len(labels))):
print(f" {i+1}. {classes[labels[i]]}")
# 可视化
visualize_cifar10(train_loader, num_samples=16)
3. ImageNet数据集¶
3.1 加载ImageNet¶
Python
import os
from torch.utils.data import Dataset
from PIL import Image
class ImageNetDataset(Dataset):
"""自定义ImageNet数据集"""
def __init__(self, root_dir, split='train', transform=None):
"""
参数:
root_dir: 数据集根目录
split: 'train' 或 'val'
transform: 数据变换
"""
self.root_dir = root_dir
self.split = split
self.transform = transform
# 获取图像路径
self.image_paths = []
self.labels = []
if split == 'train':
data_dir = os.path.join(root_dir, 'train')
else:
data_dir = os.path.join(root_dir, 'val')
# 遍历类别文件夹
for class_idx, class_name in enumerate(sorted(os.listdir(data_dir))): # enumerate同时获取索引和元素
class_dir = os.path.join(data_dir, class_name)
if os.path.isdir(class_dir):
for img_name in os.listdir(class_dir):
if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(class_dir, img_name)
self.image_paths.append(img_path)
self.labels.append(class_idx)
def __len__(self): # __len__定义len()行为
return len(self.image_paths)
def __getitem__(self, idx): # __getitem__定义索引访问行为
# 加载图像
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
label = self.labels[idx]
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
def load_imagenet(root_dir, batch_size=128, image_size=256, split='train'):
"""
加载ImageNet数据集
参数:
root_dir: 数据集根目录
batch_size: 批量大小
image_size: 目标图像大小
split: 'train' 或 'val'
返回:
dataloader
"""
# 数据变换
if split == 'train':
transform = transforms.Compose([
transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
else:
transform = transforms.Compose([
transforms.Resize(int(image_size * 1.14)),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 创建数据集
dataset = ImageNetDataset(root_dir, split=split, transform=transform)
# 创建数据加载器
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=(split == 'train'),
num_workers=8,
pin_memory=True,
drop_last=(split == 'train')
)
return dataloader
# 使用示例
# train_loader = load_imagenet('/path/to/imagenet', batch_size=128, split='train')
# val_loader = load_imagenet('/path/to/imagenet', batch_size=128, split='val')
4. 自定义数据集¶
4.1 从文件夹加载图像¶
Python
class ImageFolderDataset(Dataset):
"""从文件夹加载图像的自定义数据集"""
def __init__(self, root_dir, transform=None, extensions=('.png', '.jpg', '.jpeg', '.bmp')):
"""
参数:
root_dir: 图像文件夹根目录
transform: 数据变换
extensions: 支持的图像扩展名
"""
self.root_dir = root_dir
self.transform = transform
self.extensions = extensions
# 收集所有图像路径
self.image_paths = []
for root, _, files in os.walk(root_dir):
for file in files:
if file.lower().endswith(extensions):
self.image_paths.append(os.path.join(root, file))
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# 加载图像
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
# 应用变换
if self.transform:
image = self.transform(image)
return image, 0 # 返回图像和虚拟标签
# 使用示例
custom_dataset = ImageFolderDataset(
root_dir='./my_images',
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
)
custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)
4.2 带标签的自定义数据集¶
Python
class LabeledImageDataset(Dataset):
"""带标签的自定义图像数据集"""
def __init__(self, data_list, transform=None):
"""
参数:
data_list: 数据列表,每个元素为 (image_path, label)
transform: 数据变换
"""
self.data_list = data_list
self.transform = transform
def __len__(self):
return len(self.data_list)
def __getitem__(self, idx):
img_path, label = self.data_list[idx]
# 加载图像
image = Image.open(img_path).convert('RGB')
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
# 创建数据列表
data_list = [
('./images/cat1.jpg', 0),
('./images/cat2.jpg', 0),
('./images/dog1.jpg', 1),
('./images/dog2.jpg', 1),
]
# 创建数据集
labeled_dataset = LabeledImageDataset(
data_list,
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
)
labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)
5. 数据预处理技巧¶
5.1 图像归一化¶
Python
def compute_dataset_statistics(dataloader):
"""
计算数据集的均值和标准差
参数:
dataloader: 数据加载器
返回:
mean, std
"""
mean = 0.0
std = 0.0
total_images = 0
for images, _ in dataloader:
batch_samples = images.size(0)
images = images.view(batch_samples, images.size(1), -1) # 重塑张量形状
mean += images.mean(2).sum(0)
std += images.std(2).sum(0)
total_images += batch_samples
mean /= total_images
std /= total_images
return mean, std
# 计算CIFAR-10的统计信息
# 注意:这需要先加载未归一化的数据
no_norm_transform = transforms.Compose([
transforms.ToTensor()
])
temp_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=no_norm_transform
)
temp_loader = DataLoader(temp_dataset, batch_size=128, shuffle=False)
mean, std = compute_dataset_statistics(temp_loader)
print(f"数据集均值: {mean}")
print(f"数据集标准差: {std}")
5.2 数据增强策略¶
Python
def get_advanced_augmentation(image_size=256):
"""
获取高级数据增强
参数:
image_size: 目标图像大小
返回:
数据增强变换
"""
return transforms.Compose([
# 随机裁剪和缩放
transforms.RandomResizedCrop(
image_size,
scale=(0.8, 1.0),
ratio=(0.9, 1.1)
),
# 随机水平翻转
transforms.RandomHorizontalFlip(p=0.5),
# 随机旋转
transforms.RandomRotation(degrees=15),
# 颜色抖动
transforms.ColorJitter(
brightness=0.2,
contrast=0.2,
saturation=0.2,
hue=0.1
),
# 随机仿射变换
transforms.RandomAffine(
degrees=0,
translate=(0.1, 0.1),
scale=(0.9, 1.1),
shear=5
),
# 转换为张量
transforms.ToTensor(),
# 归一化
transforms.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]
)
])
# 使用示例
augmented_transform = get_advanced_augmentation(image_size=256)
5.3 自定义增强¶
Python
import random
import numpy as np
class RandomGaussianBlur:
"""随机高斯模糊"""
def __init__(self, p=0.5, kernel_size=5):
self.p = p
self.kernel_size = kernel_size
def __call__(self, img): # __call__使实例可像函数一样调用
if random.random() < self.p:
from PIL import ImageFilter
img = img.filter(ImageFilter.GaussianBlur(self.kernel_size))
return img
class RandomCutout:
"""随机遮挡"""
def __init__(self, p=0.5, scale=(0.02, 0.2)):
self.p = p
self.scale = scale
def __call__(self, img):
if random.random() < self.p:
w, h = img.size
cutout_size = random.uniform(*self.scale)
cutout_w = int(w * cutout_size)
cutout_h = int(h * cutout_size)
x = random.randint(0, w - cutout_w)
y = random.randint(0, h - cutout_h)
img_array = np.array(img) # np.array创建NumPy数组
img_array[y:y+cutout_h, x:x+cutout_w] = 0
img = Image.fromarray(img_array)
return img
# 使用自定义增强
custom_augmentation = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
RandomGaussianBlur(p=0.3, kernel_size=3),
RandomCutout(p=0.3, scale=(0.05, 0.15)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
6. 性能优化¶
6.1 多进程数据加载¶
Python
def optimized_dataloader(dataset, batch_size=128, num_workers=4, pin_memory=True):
"""
优化的数据加载器
参数:
dataset: 数据集
batch_size: 批量大小
num_workers: 工作进程数
pin_memory: 是否使用固定内存
返回:
优化的数据加载器
"""
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=True, # 保持工作进程
prefetch_factor=2 # 预取因子
)
# 使用示例
train_loader = optimized_dataloader(
train_dataset,
batch_size=128,
num_workers=4,
pin_memory=True
)
6.2 缓存数据¶
Python
class CachedDataset(Dataset):
"""缓存数据集到内存"""
def __init__(self, dataset, transform=None, cache_size=10000):
"""
参数:
dataset: 原始数据集
transform: 数据变换
cache_size: 缓存大小
"""
self.dataset = dataset
self.transform = transform
self.cache_size = cache_size
self.cache = {}
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
# 检查缓存
if idx in self.cache:
return self.cache[idx]
# 从原始数据集获取
data = self.dataset[idx]
# 应用变换
if self.transform:
data = self.transform(data)
# 缓存
if len(self.cache) < self.cache_size:
self.cache[idx] = data
return data
# 使用示例
cached_dataset = CachedDataset(
train_dataset,
transform=train_transform,
cache_size=10000
)
7. 数据验证¶
7.1 检查数据完整性¶
Python
def validate_dataset(dataset, num_samples=100):
"""
验证数据集完整性
参数:
dataset: 数据集
num_samples: 检查的样本数
"""
print("开始验证数据集...")
errors = []
for i in range(min(num_samples, len(dataset))):
try: # try/except捕获异常
image, label = dataset[i]
# 检查图像
if isinstance(image, torch.Tensor): # isinstance检查类型
if torch.isnan(image).any(): # any()任一为True则返回True
errors.append(f"样本 {i}: 图像包含NaN")
if torch.isinf(image).any():
errors.append(f"样本 {i}: 图像包含Inf")
else:
if image is None:
errors.append(f"样本 {i}: 图像为None")
except Exception as e:
errors.append(f"样本 {i}: {str(e)}")
if errors:
print(f"发现 {len(errors)} 个错误:")
for error in errors[:10]: # 只显示前10个
print(f" - {error}")
else:
print("✓ 数据集验证通过!")
return len(errors) == 0
# 验证数据集
is_valid = validate_dataset(train_dataset, num_samples=100)
7.2 数据统计信息¶
Python
def analyze_dataset(dataloader):
"""
分析数据集统计信息
参数:
dataloader: 数据加载器
"""
print("分析数据集...")
all_images = []
all_labels = []
for images, labels in dataloader:
all_images.append(images)
all_labels.append(labels)
if len(all_images) * images.size(0) >= 1000: # 分析1000个样本
break
all_images = torch.cat(all_images, dim=0) # torch.cat沿已有维度拼接张量
all_labels = torch.cat(all_labels, dim=0)
# 图像统计
print(f"\n图像统计:")
print(f" 形状: {all_images.shape}")
print(f" 最小值: {all_images.min().item():.4f}") # 将单元素张量转为Python数值
print(f" 最大值: {all_images.max().item():.4f}")
print(f" 均值: {all_images.mean().item():.4f}")
print(f" 标准差: {all_images.std().item():.4f}")
# 标签统计
if all_labels.numel() > 0:
print(f"\n标签统计:")
unique_labels, counts = torch.unique(all_labels, return_counts=True)
for label, count in zip(unique_labels, counts): # zip按位置配对
print(f" 类别 {label.item()}: {count.item()} 个样本")
# 分析数据集
analyze_dataset(train_loader)
8. 完整的数据加载流程¶
Python
def prepare_data(dataset_name='cifar10', batch_size=128, image_size=32,
num_workers=4, augment=True):
"""
完整的数据准备流程
参数:
dataset_name: 数据集名称 ('cifar10', 'imagenet', 'custom')
batch_size: 批量大小
image_size: 目标图像大小
num_workers: 工作进程数
augment: 是否使用数据增强
返回:
train_loader, test_loader
"""
print(f"准备 {dataset_name} 数据集...")
if dataset_name == 'cifar10':
train_loader, test_loader = load_cifar10(
batch_size=batch_size,
image_size=image_size,
augment=augment
)
elif dataset_name == 'imagenet':
train_loader = load_imagenet(
root_dir='/path/to/imagenet',
batch_size=batch_size,
image_size=image_size,
split='train'
)
test_loader = load_imagenet(
root_dir='/path/to/imagenet',
batch_size=batch_size,
image_size=image_size,
split='val'
)
elif dataset_name == 'custom':
# 自定义数据集
train_transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.RandomHorizontalFlip(p=0.5) if augment else transforms.Lambda(lambda x: x), # lambda匿名函数
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = ImageFolderDataset(
root_dir='./train_images',
transform=train_transform
)
test_dataset = ImageFolderDataset(
root_dir='./test_images',
transform=train_transform
)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
else:
raise ValueError(f"未知的数据集: {dataset_name}")
# 打印信息
print(f"训练集大小: {len(train_loader.dataset)}")
print(f"测试集大小: {len(test_loader.dataset)}")
print(f"批量大小: {batch_size}")
print(f"训练批次数: {len(train_loader)}")
print(f"测试批次数: {len(test_loader)}")
return train_loader, test_loader
# 使用示例
train_loader, test_loader = prepare_data(
dataset_name='cifar10',
batch_size=128,
image_size=32,
num_workers=4,
augment=True
)
9. 总结¶
9.1 核心概念回顾¶
| 概念 | 说明 |
|---|---|
| 数据加载器 | 批量、打乱、多进程加载 |
| 数据增强 | 提高模型泛化能力 |
| 数据归一化 | 标准化输入范围 |
| 自定义数据集 | 处理特定格式的数据 |
| 性能优化 | 多进程、缓存、预取 |
9.2 最佳实践¶
- 选择合适的数据集:根据任务选择合适规模的数据集
- 使用数据增强:提高模型泛化能力
- 优化加载性能:使用多进程和缓存
- 验证数据质量:检查数据完整性和统计信息
- 保持一致性:训练和测试使用相同的归一化
9.3 学习建议¶
- 从简单开始:先用CIFAR-10等小数据集练习
- 逐步复杂:再尝试ImageNet等大数据集
- 自定义数据:学习处理自己的数据
- 性能调优:优化数据加载速度
10. 推荐资源¶
文档¶
- PyTorch DataLoader文档
- torchvision.datasets文档
- PIL/Pillow文档
工具¶
- albumentations: 高级图像增强库
- imgaug: 图像增强库
- fastai: 简化的数据加载API
11. 自测问题¶
- 数据增强有哪些常用方法?
- 如何优化数据加载的性能?
- 为什么需要归一化图像数据?
- 如何创建自定义数据集?
- 如何验证数据集的完整性?
下一章: 04-完整DDPM实现 - 将所有组件整合在一起