
Project 1: Image Classification in Practice

Difficulty: ⭐⭐⭐ Intermediate | Time: 10-15 hours | Topics: CNNs, PyTorch, data augmentation, transfer learning


📖 Project Overview

(Figure: image classification pipeline)

Background

Image classification is one of the core tasks in computer vision, with wide applications in medical diagnosis, autonomous driving, security surveillance, and more. This project walks you through building a complete image classification system from scratch, covering the full pipeline: data preparation, model design, training and optimization, and deployment.

Goals

Build a complete image classification system that can:

  • Handle large-scale image datasets
  • Implement multiple CNN architectures
  • Apply data augmentation techniques
  • Use transfer learning to boost performance
  • Evaluate and visualize model results
  • Deploy the model as a web service

Tech Stack

  • Deep learning framework: PyTorch
  • Data processing: torchvision, PIL
  • Visualization: matplotlib, tensorboard
  • Web framework: FastAPI
  • Frontend: Streamlit

🏗️ Project Structure

Text Only
image-classification/
├── data/                      # Data directory
│   ├── raw/                  # Raw data
│   ├── processed/            # Processed data
│   └── augment/              # Augmented data
├── models/                   # Model definitions
│   ├── __init__.py
│   ├── cnn.py               # CNN model
│   ├── resnet.py            # ResNet model
│   └── vgg.py               # VGG model
├── utils/                    # Utility functions
│   ├── __init__.py
│   ├── data_loader.py       # Data loading
│   ├── augmentation.py      # Data augmentation
│   ├── metrics.py           # Evaluation metrics
│   └── visualization.py     # Visualization
├── train.py                  # Training script
├── evaluate.py               # Evaluation script
├── inference.py              # Inference script
├── app.py                    # Web application
├── config.py                 # Configuration
├── requirements.txt          # Dependencies
└── README.md                # Project README

🎯 Core Features

1. Data Processing

(Figure: data augmentation comparison)

  • Data loading: load image data efficiently
  • Preprocessing: normalization and resizing
  • Augmentation: rotation, flipping, cropping, etc.
  • Splitting: training, validation, and test sets

2. Model Design

(Figure: CNN architecture)

  • Basic CNN: a custom CNN architecture
  • ResNet: residual networks
  • VGG: the classic VGG architecture
  • Transfer learning: fine-tune pretrained models

(Figure: transfer learning workflow)

3. Training and Optimization

  • Loss function: cross-entropy loss
  • Optimizers: SGD, Adam
  • Learning rate scheduling: StepLR, CosineAnnealingLR
  • Early stopping: prevent overfitting

4. Evaluation and Analysis

(Figure: confusion matrix)

  • Accuracy: Top-1 and Top-5 accuracy
  • Confusion matrix: visualize classification results
  • ROC curves: assess classifier performance
  • Error analysis: inspect misclassified samples

(Figure: ROC curve)

5. Model Deployment

  • Model saving: checkpoint the best model
  • Model loading: restore trained weights
  • Batch inference: process images in batches
  • Web service: expose a REST API

💻 Code Implementation

1. Configuration (config.py)

Python
"""
Configuration for the image classification project.
"""
import torch
from dataclasses import dataclass

@dataclass  # @dataclass auto-generates __init__ and friends
class Config:
    """Configuration container."""

    # Data
    data_dir: str = "./data"
    batch_size: int = 32
    num_workers: int = 4
    image_size: tuple[int, int] = (224, 224)

    # Model
    num_classes: int = 10
    model_type: str = "resnet18"  # cnn, resnet18, resnet50, vgg16
    pretrained: bool = True

    # Training
    num_epochs: int = 50
    learning_rate: float = 0.001
    weight_decay: float = 1e-4
    momentum: float = 0.9

    # Learning rate schedule
    lr_scheduler: str = "step"  # step, cosine
    step_size: int = 10
    gamma: float = 0.1

    # Early stopping
    early_stopping: bool = True
    patience: int = 10

    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Output paths
    checkpoint_dir: str = "./checkpoints"
    log_dir: str = "./logs"

config = Config()
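One nice property of a dataclass-based config is that experiment variants are cheap to create. A minimal sketch (using a trimmed-down stand-in for the `Config` above, so it runs on its own):

```python
from dataclasses import dataclass, replace

# Trimmed stand-in for the project's Config (only a few fields)
@dataclass
class Config:
    batch_size: int = 32
    learning_rate: float = 0.001
    model_type: str = "resnet18"

config = Config()

# dataclasses.replace creates a modified copy -- handy for ablation runs
small_lr = replace(config, learning_rate=1e-4)
print(small_lr.learning_rate, config.learning_rate)  # 0.0001 0.001
```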

2. Data Loader (utils/data_loader.py)

Python
"""
Data loaders.
"""
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

class CustomDataset(Dataset):
    """Custom image dataset."""

    def __init__(self, root_dir, transform=None):  # __init__ is called automatically on construction
        """
        Initialize the dataset.

        Args:
            root_dir: root data directory
            transform: transform applied to each image
        """
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # Scan the directory tree
        self._load_data()

    def _load_data(self):
        """Collect image paths and labels (one subdirectory per class)."""
        classes = sorted(os.listdir(self.root_dir))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}  # enumerate yields (index, element) pairs

        for class_name in classes:
            class_dir = os.path.join(self.root_dir, class_name)
            class_idx = self.class_to_idx[class_name]

            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                self.images.append(img_path)
                self.labels.append(class_idx)

    def __len__(self):  # __len__ defines the behavior of len()
        return len(self.images)

    def __getitem__(self, idx):  # __getitem__ defines index access
        """Return one (image, label) pair."""
        img_path = self.images[idx]
        label = self.labels[idx]

        # Load the image
        image = Image.open(img_path).convert('RGB')

        # Apply transforms
        if self.transform:
            image = self.transform(image)

        return image, label

def get_data_loaders(config):
    """
    Build the data loaders.

    Args:
        config: configuration object

    Returns:
        Train, validation, and test data loaders
    """
    # Training transforms (with augmentation)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(config.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])

    # Validation/test transforms
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])

    # Datasets
    train_dataset = CustomDataset(
        os.path.join(config.data_dir, 'train'),
        transform=train_transform
    )

    val_dataset = CustomDataset(
        os.path.join(config.data_dir, 'val'),
        transform=test_transform
    )

    test_dataset = CustomDataset(
        os.path.join(config.data_dir, 'test'),
        transform=test_transform
    )

    # Loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True
    )

    return train_loader, val_loader, test_loader

3. CNN Models (models/cnn.py)

Python
"""
CNN models.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):  # subclass nn.Module to define a network
    """A simple CNN."""

    def __init__(self, num_classes=10):
        """
        Initialize the model.

        Args:
            num_classes: number of classes
        """
        super(SimpleCNN, self).__init__()

        # Feature extractor
        self.features = nn.Sequential(
            # Conv block 1
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Classifier head (224/2/2/2 = 28, hence 128 * 28 * 28 inputs)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 28 * 28, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Forward pass."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

class ResidualBlock(nn.Module):
    """Residual block."""

    def __init__(self, in_channels, out_channels, stride=1):
        """
        Initialize the residual block.

        Args:
            in_channels: number of input channels
            out_channels: number of output channels
            stride: stride of the first convolution
        """
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                             stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                             stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                         stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Forward pass."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    """ResNet model."""

    def __init__(self, block, num_blocks, num_classes=10):
        """
        Initialize the ResNet.

        Args:
            block: residual block class
            num_blocks: list with the number of blocks per stage
            num_classes: number of classes
        """
        super(ResNet, self).__init__()

        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                             bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack num_blocks residual blocks into one stage."""
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

def resnet18(num_classes=10):
    """Build a ResNet-18."""
    return ResNet(ResidualBlock, [2, 2, 2, 2], num_classes)

4. Training Script (train.py)

(Figure: training curves)

Python
"""
Training script.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter

from config import config
from utils.data_loader import get_data_loaders
from models.cnn import SimpleCNN, resnet18
from utils.metrics import AverageMeter, accuracy

class EarlyStopping:
    """Early stopping."""

    def __init__(self, patience=10, min_delta=0):
        """
        Initialize early stopping.

        Args:
            patience: epochs to wait without improvement
            min_delta: smallest change that counts as improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss):  # __call__ lets an instance be called like a function
        """Check whether to stop early."""
        if self.best_score is None:
            self.best_score = val_loss
        elif val_loss > self.best_score - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_loss
            self.counter = 0

def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Train for one epoch."""
    model.train()  # train() enables training mode

    losses = AverageMeter()
    accs = AverageMeter()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)  # .to(device) moves data to GPU/CPU

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()  # backpropagate to compute gradients
        optimizer.step()  # update parameters from the gradients

        # Accuracy (already a Python float, in percent)
        acc = accuracy(outputs, labels, topk=(1,))[0]

        # Update running statistics
        losses.update(loss.item(), images.size(0))
        accs.update(acc, images.size(0))

        # Update the progress bar
        pbar.set_postfix({
            'loss': f'{losses.avg:.4f}',
            'acc': f'{accs.avg:.2f}%'
        })

    return losses.avg, accs.avg

def validate(model, val_loader, criterion, device):
    """Evaluate on the validation set."""
    model.eval()  # eval() enables evaluation mode (disables Dropout etc.)

    losses = AverageMeter()
    accs = AverageMeter()

    with torch.no_grad():  # disable gradient tracking to save memory
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Accuracy (already a Python float, in percent)
            acc = accuracy(outputs, labels, topk=(1,))[0]

            # Update running statistics
            losses.update(loss.item(), images.size(0))
            accs.update(acc, images.size(0))

    return losses.avg, accs.avg

def train():
    """Train the model."""
    # Create output directories
    os.makedirs(config.checkpoint_dir, exist_ok=True)
    os.makedirs(config.log_dir, exist_ok=True)

    # Data loaders
    train_loader, val_loader, _ = get_data_loaders(config)

    # Model
    if config.model_type == 'cnn':
        model = SimpleCNN(config.num_classes)
    elif config.model_type == 'resnet18':
        model = resnet18(config.num_classes)
    else:
        raise ValueError(f"Unknown model type: {config.model_type}")

    model = model.to(config.device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=config.learning_rate,
        momentum=config.momentum,
        weight_decay=config.weight_decay
    )

    # Learning rate scheduler
    if config.lr_scheduler == 'step':
        scheduler = StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
    elif config.lr_scheduler == 'cosine':
        scheduler = CosineAnnealingLR(optimizer, T_max=config.num_epochs)
    else:
        scheduler = None

    # Early stopping
    early_stopping = EarlyStopping(patience=config.patience) if config.early_stopping else None

    # TensorBoard
    writer = SummaryWriter(config.log_dir)

    # Training loop
    best_val_acc = 0.0

    for epoch in range(config.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.num_epochs}")

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, config.device, epoch
        )

        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, config.device)

        # Step the scheduler
        if scheduler:
            scheduler.step()

        # Log metrics
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/val', val_acc, epoch)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, os.path.join(config.checkpoint_dir, 'best_model.pth'))
            print(f"✓ Saved best model (Val Acc: {val_acc:.2f}%)")

        # Early stopping check
        if early_stopping:
            early_stopping(val_loss)
            if early_stopping.early_stop:
                print("Early stopping triggered; halting training")
                break

    writer.close()
    print(f"\nTraining finished! Best validation accuracy: {best_val_acc:.2f}%")

if __name__ == "__main__":
    train()
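train.py saves the best checkpoint as a dict; evaluate.py and inference.py can restore it the same way. A minimal save/load roundtrip sketch (a tiny stand-in model replaces the real classifier):

```python
import torch
import torch.nn as nn

# Stand-in model; in the project this would be SimpleCNN or resnet18
model = nn.Linear(4, 2)
torch.save({"epoch": 3,
            "model_state_dict": model.state_dict(),
            "val_acc": 0.91},
           "best_model.pth")

# Later (e.g. in evaluate.py or inference.py), restore it:
ckpt = torch.load("best_model.pth", weights_only=True)  # weights_only avoids arbitrary-pickle execution
model.load_state_dict(ckpt["model_state_dict"])
print(ckpt["epoch"], ckpt["val_acc"])  # 3 0.91
```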

5. Evaluation Metrics (utils/metrics.py)

Python
"""
Evaluation metrics.
"""
import torch
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

class AverageMeter:
    """Track the current value and the running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Add n samples with value val."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy.

    Args:
        output: model outputs (logits)
        target: ground-truth labels
        topk: tuple of k values

    Returns:
        List of top-k accuracies as Python floats, in percent
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))  # view reshapes a tensor (requires contiguous memory)

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape works without the contiguity requirement
            res.append(correct_k.mul_(100.0 / batch_size).item())
        return res

def plot_confusion_matrix(y_true, y_pred, classes, save_path=None):
    """
    Plot a confusion matrix.

    Args:
        y_true: ground-truth labels
        y_pred: predicted labels
        classes: class names
        save_path: where to save the figure (optional)
    """
    cm = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=classes, yticklabels=classes
    )
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def print_classification_report(y_true, y_pred, classes):
    """
    Print a per-class precision/recall/F1 report.

    Args:
        y_true: ground-truth labels
        y_pred: predicted labels
        classes: class names
    """
    report = classification_report(y_true, y_pred, target_names=classes)
    print(report)
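The evaluation checklist mentions ROC curves, which metrics.py does not implement. For a binary (or one-vs-rest) setting, scikit-learn's `roc_curve` and `auc` cover it; a minimal sketch on synthetic labels and scores standing in for real model outputs:

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

# Synthetic binary labels and predicted scores stand in for real model outputs
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# roc_curve sweeps the decision threshold; auc integrates the curve
fpr, tpr, thresholds = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)
print(f"AUC = {roc_auc:.2f}")  # AUC = 0.75
```

For the 10-class setting in this project you would binarize the labels (one class vs. the rest) and draw one curve per class, feeding each class's softmax probability as the score.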

6. Streamlit App (app.py)

Python
"""
Image classification web app.
"""
import os
import streamlit as st
import torch
from PIL import Image
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt

from config import config
from models.cnn import SimpleCNN, resnet18

# Page configuration
st.set_page_config(
    page_title="Image Classification",
    page_icon="🖼️",
    layout="wide"
)

# Title
st.title("🖼️ Image Classification System")
st.markdown("---")

# Sidebar
st.sidebar.header("Model Settings")

# Model selection
model_type = st.sidebar.selectbox(
    "Model",
    ["resnet18", "cnn"],
)

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
st.sidebar.text(f"Device: {device}")

# Model loading
@st.cache_resource
def load_model(model_type, num_classes):
    """Load the model and its weights."""
    if model_type == 'resnet18':
        model = resnet18(num_classes)
    else:
        model = SimpleCNN(num_classes)

    # Load weights
    checkpoint_path = f"{config.checkpoint_dir}/best_model.pth"
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        st.sidebar.success("✓ Model loaded")
    else:
        st.sidebar.warning("⚠ Model weights not found")

    model = model.to(device)
    model.eval()
    return model

# Class names (CIFAR-10)
class_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Load the model
model = load_model(model_type, len(class_names))

# Image preprocessing
def preprocess_image(image):
    """Preprocess an image for inference."""
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)  # unsqueeze adds a batch dimension

# Main layout
col1, col2 = st.columns(2)

with col1:
    st.subheader("Upload Image")

    # File uploader
    uploaded_file = st.file_uploader(
        "Choose an image",
        type=['jpg', 'jpeg', 'png'],
    )

    if uploaded_file is not None:
        # Show the image
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded image", use_container_width=True)

        # Predict
        with st.spinner("Predicting..."):
            input_tensor = preprocess_image(image).to(device)

            with torch.no_grad():
                outputs = model(input_tensor)
                probabilities = F.softmax(outputs, dim=1)[0]
                confidence, predicted = torch.max(probabilities, 0)

            # Show the result
            with col2:
                st.subheader("Prediction")

                # Predicted class and confidence
                st.metric(
                    "Predicted class",
                    class_names[predicted.item()],
                    f"Confidence: {confidence.item():.2%}"
                )

                # Per-class probabilities
                st.write("Class probabilities:")
                probs_data = []
                for class_name, prob in zip(class_names, probabilities):  # zip pairs up multiple iterables by position
                    probs_data.append({
                        "Class": class_name,
                        "Probability": f"{prob.item():.2%}"
                    })

                st.dataframe(
                    probs_data,
                    use_container_width=True,
                    hide_index=True
                )

                # Probability bar chart
                st.write("Probability distribution:")
                fig, ax = plt.subplots(figsize=(10, 6))
                ax.bar(class_names, probabilities.cpu().numpy())
                ax.set_xlabel('Class')
                ax.set_ylabel('Probability')
                ax.set_title('Predicted probability per class')
                plt.xticks(rotation=45)
                st.pyplot(fig)

🧪 Testing

1. Unit Tests

Python
"""
Unit test examples.
"""
import pytest
import torch
from models.cnn import SimpleCNN, resnet18

def test_cnn_forward():
    """Test the CNN forward pass."""
    model = SimpleCNN(num_classes=10)
    x = torch.randn(2, 3, 224, 224)
    output = model(x)

    assert output.shape == (2, 10)  # assert raises if the condition is False
    print("✓ CNN forward pass test passed")

def test_resnet_forward():
    """Test the ResNet forward pass."""
    model = resnet18(num_classes=10)
    x = torch.randn(2, 3, 224, 224)
    output = model(x)

    assert output.shape == (2, 10)
    print("✓ ResNet forward pass test passed")

2. Integration Tests

Python
"""
Integration test example.
"""
def test_training_pipeline():
    """Smoke-test the training pipeline on one batch."""
    from config import config
    from utils.data_loader import get_data_loaders
    from models.cnn import SimpleCNN

    # Load data
    train_loader, val_loader, _ = get_data_loaders(config)

    # Build the model
    model = SimpleCNN(config.num_classes)
    model = model.to(config.device)

    # Run one batch
    images, labels = next(iter(train_loader))
    images, labels = images.to(config.device), labels.to(config.device)

    # Forward pass
    outputs = model(images)

    assert outputs.shape[0] == config.batch_size
    assert outputs.shape[1] == config.num_classes

    print("✓ Training pipeline test passed")

📊 Extension Ideas

1. Features

  • Multi-label classification: multiple labels per image
  • Fine-grained classification: distinguish similar classes
  • Zero-shot classification: use models like CLIP
  • Interpretability: visualize with Grad-CAM

2. Performance

  • Model compression: quantization and pruning
  • Knowledge distillation: train a smaller student model
  • Mixed-precision training: use FP16
  • Distributed training: multiple GPUs
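The mixed-precision bullet can be sketched with `torch.autocast`. On CUDA you would pair it with `torch.cuda.amp.GradScaler` around the optimizer step; the cast itself also works on CPU with bfloat16, which keeps this sketch runnable without a GPU:

```python
import torch
import torch.nn as nn

# Tiny stand-in model and batch, so the sketch runs anywhere
model = nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))

# Inside autocast, matmul-heavy ops run in reduced precision
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = model(x)
    loss = nn.functional.cross_entropy(out, y)

print(out.dtype)  # torch.bfloat16
loss.backward()   # gradients still land on the float32 parameters
opt.step()
```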

3. Deployment

  • ONNX export: convert the model to ONNX
  • TensorRT: accelerate inference with TensorRT
  • Mobile: deploy to phones
  • Edge devices: deploy to a Raspberry Pi, etc.

📚 What You Will Learn

After finishing this project, you will have practiced:

  • ✅ Designing and implementing CNN architectures
  • ✅ Data augmentation techniques
  • ✅ Transfer learning
  • ✅ Model training and optimization
  • ✅ Model evaluation and analysis
  • ✅ Building a Streamlit web app
  • ✅ Delivering an end-to-end image classification project

Estimated time: 10-15 hours | Difficulty: ⭐⭐⭐ Intermediate | Recommendation: ⭐⭐⭐⭐⭐