01 - 环境搭建与数据准备¶
学习时间: 2小时 重要性: ⭐⭐⭐⭐⭐ 动手实践的第一步
🎯 学习目标¶
完成本章后,你将能够: - 搭建完整的扩散模型开发环境 - 准备和处理训练数据 - 理解数据预处理的重要性 - 实现高效的数据加载器
1. 环境配置¶
1.1 硬件要求¶
最低配置: - GPU: NVIDIA GTX 1060 6GB 或更高 - 内存: 16GB RAM - 存储: 50GB 可用空间
推荐配置: - GPU: NVIDIA RTX 3060 12GB 或更高 - 内存: 32GB RAM - 存储: 100GB SSD
1.2 软件环境¶
Bash
# 创建conda环境
conda create -n diffusion python=3.9
conda activate diffusion
# 安装PyTorch(根据你的CUDA版本选择)
# CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# 或 CUDA 12.1
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# 安装其他依赖
pip install numpy matplotlib pillow tqdm tensorboard einops
# 可选:安装wandb用于实验跟踪
pip install wandb
1.3 验证安装¶
Python
import torch
import torchvision
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA版本: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"显存: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
2. 数据集准备¶
2.1 数据集选择¶
初学者推荐: - MNIST: 手写数字,28x28,10类,适合快速验证 - CIFAR-10: 彩色小图像,32x32,10类,标准基准 - CelebA: 人脸图像,64x64或更高,无条件生成
进阶数据集: - ImageNet: 大规模图像分类数据集 - LSUN: 场景理解数据集 - 自定义数据集: 你自己的图像数据
2.2 下载MNIST数据集¶
Python
import torchvision
import torchvision.transforms as transforms
# 定义数据变换
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量 [0, 1]
transforms.Normalize((0.5,), (0.5,)) # 归一化到 [-1, 1]
])
# 下载训练集
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
# 下载测试集
test_dataset = torchvision.datasets.MNIST(
root='./data',
train=False,
download=True,
transform=transform
)
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
print(f"图像形状: {train_dataset[0][0].shape}")
print(f"标签: {train_dataset[0][1]}")
2.3 下载CIFAR-10数据集¶
Python
# CIFAR-10数据变换
transform_cifar = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5), # 数据增强
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # RGB三通道
])
# 下载CIFAR-10
train_dataset_cifar = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform_cifar
)
test_dataset_cifar = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform_cifar
)
print(f"训练集大小: {len(train_dataset_cifar)}")
print(f"类别: {train_dataset_cifar.classes}")
2.4 自定义数据集¶
Python
from torch.utils.data import Dataset
from PIL import Image
import os
class CustomImageDataset(Dataset):
"""
自定义图像数据集
"""
def __init__(self, root_dir, transform=None):
"""
参数:
root_dir: 图像文件夹路径
transform: 数据变换
"""
self.root_dir = root_dir
self.transform = transform
self.image_files = [
f for f in os.listdir(root_dir)
if f.endswith(('.png', '.jpg', '.jpeg'))
]
def __len__(self): # __len__定义len()行为
return len(self.image_files)
def __getitem__(self, idx): # __getitem__定义索引访问行为
img_path = os.path.join(self.root_dir, self.image_files[idx])
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, 0 # 返回图像和虚拟标签
# 使用示例
# custom_dataset = CustomImageDataset(
# root_dir='path/to/your/images',
# transform=transform_cifar
# )
3. 数据预处理¶
3.1 为什么需要预处理¶
- 归一化: 将像素值从 [0, 255] 转换到 [-1, 1] 或 [0, 1]
- 统一尺寸: 确保所有图像大小一致
- 数据增强: 增加数据多样性,防止过拟合
- 去噪: 清理低质量图像
3.2 数据变换流程¶
Python
def get_transforms(image_size=32, train=True):
"""
获取数据变换
参数:
image_size: 输出图像尺寸
train: 是否为训练模式
返回:
transform: 组合的数据变换
"""
transforms_list = []
# 调整图像大小
transforms_list.append(transforms.Resize(image_size))
# 中心裁剪或随机裁剪
if train:
transforms_list.append(transforms.RandomCrop(image_size, padding=4))
else:
transforms_list.append(transforms.CenterCrop(image_size))
# 随机水平翻转(仅训练)
if train:
transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))
# 转换为张量
transforms_list.append(transforms.ToTensor())
# 归一化到 [-1, 1]
transforms_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
return transforms.Compose(transforms_list)
# 测试变换
test_transform = get_transforms(image_size=32, train=True)
print(test_transform)
3.3 数据可视化¶
Python
import matplotlib.pyplot as plt
import numpy as np
def denormalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
"""
反归一化,将张量转换回可显示的图像
"""
mean = torch.tensor(mean).view(3, 1, 1) # 重塑张量形状
std = torch.tensor(std).view(3, 1, 1)
tensor = tensor * std + mean
return torch.clamp(tensor, 0, 1)
def visualize_dataset(dataset, num_samples=16, figsize=(12, 12)):
"""
可视化数据集样本
"""
fig, axes = plt.subplots(4, 4, figsize=figsize)
axes = axes.flatten()
for i in range(num_samples):
idx = np.random.randint(len(dataset))
image, label = dataset[idx]
# 反归一化
image = denormalize(image)
# 转换为numpy数组
image_np = image.permute(1, 2, 0).numpy()
axes[i].imshow(image_np)
if hasattr(dataset, 'classes'): # hasattr检查对象是否有某属性
axes[i].set_title(f'Class: {dataset.classes[label]}')
else:
axes[i].set_title(f'Label: {label}')
axes[i].axis('off')
plt.tight_layout()
plt.savefig('dataset_visualization.png', dpi=150)
plt.show()
# 可视化CIFAR-10
# visualize_dataset(train_dataset_cifar)
4. 数据加载器¶
4.1 创建DataLoader¶
Python
from torch.utils.data import DataLoader
def create_dataloaders(
dataset_name='cifar10',
batch_size=128,
num_workers=4,
image_size=32
):
"""
创建数据加载器
参数:
dataset_name: 数据集名称 ('mnist', 'cifar10', 'custom')
batch_size: 批次大小
num_workers: 数据加载线程数
image_size: 图像尺寸
返回:
train_loader, test_loader
"""
# 获取数据变换
train_transform = get_transforms(image_size, train=True)
test_transform = get_transforms(image_size, train=False)
# 加载数据集
if dataset_name == 'mnist':
train_dataset = torchvision.datasets.MNIST(
root='./data', train=True, download=True, transform=train_transform
)
test_dataset = torchvision.datasets.MNIST(
root='./data', train=False, download=True, transform=test_transform
)
elif dataset_name == 'cifar10':
train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=train_transform
)
test_dataset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=test_transform
)
else:
raise ValueError(f"Unknown dataset: {dataset_name}")
# 创建数据加载器
train_loader = DataLoader( # DataLoader批量加载数据
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
drop_last=True
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
return train_loader, test_loader
# 创建数据加载器
train_loader, test_loader = create_dataloaders(
dataset_name='cifar10',
batch_size=128,
num_workers=4,
image_size=32
)
print(f"训练批次数量: {len(train_loader)}")
print(f"测试批次数量: {len(test_loader)}")
# 测试数据加载
for batch_idx, (images, labels) in enumerate(train_loader): # enumerate同时获取索引和元素
print(f"批次 {batch_idx}:")
print(f" 图像形状: {images.shape}")
print(f" 标签形状: {labels.shape}")
print(f" 图像范围: [{images.min():.2f}, {images.max():.2f}]")
break
4.2 优化数据加载¶
Python
class InfiniteDataLoader:
"""
无限数据加载器,用于持续训练
"""
def __init__(self, dataloader):
self.dataloader = dataloader
self.data_iter = iter(dataloader)
def __next__(self):
try: # try/except捕获异常
batch = next(self.data_iter)
except StopIteration:
self.data_iter = iter(self.dataloader)
batch = next(self.data_iter)
return batch
def __iter__(self):
return self
# 使用示例
# infinite_loader = InfiniteDataLoader(train_loader)
# for i in range(1000000):
# images, labels = next(infinite_loader)
# # 训练代码...
5. 数据管道完整代码¶
Python
# data_utils.py
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
class DiffusionDataModule:
"""
扩散模型数据模块
"""
def __init__(
self,
dataset_name='cifar10',
data_dir='./data',
batch_size=128,
num_workers=4,
image_size=32,
train_split=0.9
):
self.dataset_name = dataset_name
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.image_size = image_size
self.train_split = train_split
# 数据变换
self.train_transform = self._get_transforms(train=True)
self.test_transform = self._get_transforms(train=False)
def _get_transforms(self, train=True):
"""获取数据变换"""
transforms_list = [
transforms.Resize(self.image_size),
]
if train:
transforms_list.extend([
transforms.RandomCrop(self.image_size, padding=4),
transforms.RandomHorizontalFlip(p=0.5),
])
else:
transforms_list.append(transforms.CenterCrop(self.image_size))
transforms_list.extend([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
return transforms.Compose(transforms_list)
def setup(self):
"""设置数据集"""
if self.dataset_name == 'mnist':
self.train_dataset = torchvision.datasets.MNIST(
self.data_dir, train=True, download=True, transform=self.train_transform
)
self.test_dataset = torchvision.datasets.MNIST(
self.data_dir, train=False, download=True, transform=self.test_transform
)
self.num_classes = 10
self.channels = 1
elif self.dataset_name == 'cifar10':
self.train_dataset = torchvision.datasets.CIFAR10(
self.data_dir, train=True, download=True, transform=self.train_transform
)
self.test_dataset = torchvision.datasets.CIFAR10(
self.data_dir, train=False, download=True, transform=self.test_transform
)
self.num_classes = 10
self.channels = 3
elif self.dataset_name == 'celeba':
# CelebA需要单独下载
self.train_dataset = torchvision.datasets.CelebA(
self.data_dir, split='train', download=True, transform=self.train_transform
)
self.test_dataset = torchvision.datasets.CelebA(
self.data_dir, split='test', download=True, transform=self.test_transform
)
self.num_classes = None
self.channels = 3
else:
raise ValueError(f"Unknown dataset: {self.dataset_name}")
def train_dataloader(self):
"""训练数据加载器"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True
)
def test_dataloader(self):
"""测试数据加载器"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True
)
def get_batch(self, split='train'):
"""获取一个批次"""
loader = self.train_dataloader() if split == 'train' else self.test_dataloader()
return next(iter(loader))
# 测试
if __name__ == "__main__":
# 创建数据模块
data_module = DiffusionDataModule(
dataset_name='cifar10',
batch_size=16,
image_size=32
)
# 设置数据集
data_module.setup()
# 获取数据加载器
train_loader = data_module.train_dataloader()
# 测试
images, labels = next(iter(train_loader))
print(f"图像批次形状: {images.shape}")
print(f"标签批次形状: {labels.shape}")
print(f"图像范围: [{images.min():.3f}, {images.max():.3f}]")
print(f"数据集类别数: {data_module.num_classes}")
print(f"图像通道数: {data_module.channels}")
6. 本章总结¶
核心要点¶
- 环境配置
- PyTorch + CUDA
- 必要的Python包
-
验证GPU可用性
-
数据准备
- 选择合适的数据集
- 下载和预处理
-
自定义数据集支持
-
数据变换
- 归一化到 [-1, 1]
- 数据增强
-
统一尺寸
-
数据加载
- 高效的DataLoader
- 多线程加载
- 无限数据流
关键代码¶
Python
# 数据变换
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 数据加载器
DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
📝 自测问题¶
基础问题¶
- 数据预处理
- 为什么需要将图像归一化到 [-1, 1]?
- 数据增强在扩散模型中起什么作用?
-
如何选择合适的图像尺寸?
-
数据加载
- num_workers参数的作用是什么?
- 为什么要使用pin_memory=True?
- drop_last=True有什么影响?
编程练习¶
- 实现一个自定义数据集类
- 可视化不同数据增强的效果
- 实现数据加载的性能测试
思考题¶
- 如果图像不归一化,会对训练产生什么影响?
- 如何设计适合高分辨率图像的数据管道?
- 数据质量对扩散模型有多重要?
🔗 下一步¶
环境搭建完成后,我们将学习UNet架构详解,这是扩散模型的核心网络结构。
→ 下一步:02-UNet架构详解.md