跳转至

第14章 自监督学习

自监督学习图

📚 章节概述

本章介绍自监督学习的核心技术,包括SimCLR、MoCo、MAE等。自监督学习能够利用大量无标注数据,是当前AI研究的热点。

学习时间:5-7天 难度等级:⭐⭐⭐⭐⭐ 前置知识:第5-6章

🎯 学习目标

完成本章后,你将能够: - 理解自监督学习的基本原理 - 掌握对比学习方法 - 了解掩码自编码器 - 能够进行自监督预训练 - 完成自监督学习项目


14.1 对比学习

14.1.1 SimCLR

核心思想: - 数据增强 - 编码器提取特征 - 对比损失

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimCLR(nn.Module):  # 继承nn.Module定义网络层
    def __init__(self, encoder, projection_dim=128):
        super(SimCLR, self).__init__()
        self.encoder = encoder
        self.projection = nn.Sequential(
            nn.Linear(encoder.output_dim, encoder.output_dim),
            nn.ReLU(),
            nn.Linear(encoder.output_dim, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projection(h)
        return z

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Normalized Temperature-scaled Cross Entropy Loss

    注意:必须屏蔽对角线(自相似度=1.0),否则模型会
    "作弊"匹配自身而非正样本对,导致训练无效。
    """
    batch_size = z_i.shape[0]
    N = 2 * batch_size

    # 拼接正负样本
    z = torch.cat([z_i, z_j], dim=0)  # torch.cat沿已有维度拼接张量

    # 计算相似度矩阵 (2N x 2N)
    sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)  # unsqueeze增加一个维度  # F.xxx PyTorch函数式API

    # ⚠️ 屏蔽对角线(自相似度),设为 -inf 使其在 softmax 中贡献为 0
    mask = torch.eye(N, dtype=torch.bool, device=z_i.device)
    sim = sim.masked_fill(mask, float('-inf'))

    # 创建标签:z_i[k] 的正样本是 z_j[k](索引 batch_size+k)
    # 屏蔽对角线后标签索引需要调整(每行少了一个自身元素)
    # 但 cross_entropy 会忽略 -inf 位置,标签仍指向原始列索引
    labels = torch.arange(batch_size, device=z_i.device)
    labels = torch.cat([labels + batch_size, labels], dim=0)

    # 温度缩放
    sim = sim / temperature

    # 损失
    loss = F.cross_entropy(sim, labels)  # F.cross_entropy PyTorch函数式交叉熵损失

    return loss

14.1.2 MoCo (Momentum Contrast)

核心创新: - 动量编码器 - 队列机制

Python
import copy

class MoCo(nn.Module):
    def __init__(self, encoder, K=65536, m=0.999, temperature=0.07):
        super(MoCo, self).__init__()
        self.K = K
        self.m = m
        self.T = temperature

        # Query编码器
        self.encoder_q = encoder

        # Key编码器(动量更新)— 必须深拷贝,不能共享引用
        self.encoder_k = copy.deepcopy(encoder)
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):  # zip按位置配对
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # 队列
        self.register_buffer("queue", torch.randn(encoder.output_dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()  # 禁用梯度计算,节省内存
    def _momentum_update_key_encoder(self):
        """动量更新key编码器"""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """更新队列"""
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)

        # 更新队列
        self.queue[:, ptr:ptr + batch_size] = keys.T

        # 更新指针
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_k):
        # Query
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)

        # Key(不计算梯度)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)

        # 计算相似度
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])  # 分离计算图,不参与梯度计算

        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T

        # 标签(正样本在第一个位置)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(q.device)

        loss = F.cross_entropy(logits, labels)

        # 更新队列
        self._dequeue_and_enqueue(k)

        return loss

14.2 掩码自编码器

14.2.1 MAE (Masked Autoencoder)

Python
class MAE(nn.Module):
    def __init__(self, encoder, decoder, mask_ratio=0.75):
        super(MAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.mask_ratio = mask_ratio

    def random_masking(self, x, mask_ratio):
        """随机掩码"""
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        # 随机噪声
        noise = torch.rand(N, L, device=x.device)

        # 选择保留的patch
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # 生成mask
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, x):
        # 随机掩码
        x_masked, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # 编码
        latent = self.encoder(x_masked)

        # 解码
        pred = self.decoder(latent, ids_restore)

        # 损失(只计算被掩码的部分)
        loss = (pred - x) ** 2
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum() if mask.sum() > 0 else loss.sum()

        return loss, pred, mask

14.3 实战案例:自监督预训练

Python
import torch.optim as optim
from torchvision import datasets, transforms

# 数据增强
class SimCLRTransform:
    def __init__(self, size=224):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=int(0.1 * size), sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __call__(self, x):  # __call__使实例可像函数一样调用
        return self.transform(x), self.transform(x)

# 数据集
transform = SimCLRTransform()
dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)  # DataLoader批量加载数据

# 模型
encoder = ResNet18()
model = SimCLR(encoder).cuda()

# 优化器
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# 训练
def train_simclr(model, dataloader, epochs=100):
    model.train()  # train()训练模式
    for epoch in range(epochs):
        total_loss = 0.0
        for (x_i, x_j), _ in dataloader:
            x_i, x_j = x_i.cuda(), x_j.cuda()

            # 前向传播
            z_i = model(x_i)
            z_j = model(x_j)

            # 损失
            loss = nt_xent_loss(z_i, z_j)

            # 反向传播
            optimizer.zero_grad()  # 清零梯度
            loss.backward()  # 反向传播计算梯度
            optimizer.step()  # 更新参数

            total_loss += loss.item()  # 将单元素张量转为Python数值

        print(f'Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}')

# 训练
train_simclr(model, dataloader, epochs=100)

14.4 练习题

基础题

  1. 简答题
  2. 什么是自监督学习?

    自监督学习是一种无需人工标注的表征学习方法,通过从数据本身构造监督信号(如预测遮挡区域、对比不同增强视图等)来学习通用特征。与监督学习相比,它可利用海量无标注数据,学到的表征泛化能力强,在下游任务微调后往往能媲美甚至超越监督预训练。主要范式包括对比学习(SimCLR、MoCo)和掩码预测(MAE、BEiT)。

  3. SimCLR和MoCo有什么区别?

    SimCLR使用端到端训练,正负样本全部来自当前mini-batch,依赖超大batch size(如4096)提供足够负样本,因此对GPU内存要求高。MoCo引入动量编码器和队列(负样本字典),用动量更新的编码器生成一致的负样本表征,将负样本数量与batch size解耦,可在小batch(如256)下获得大量高质量负样本。MoCo更节省显存,SimCLR架构更简洁。

进阶题

  1. 编程题
  2. 实现一个简单的对比学习。
  3. 使用MAE进行预训练。

14.5 面试准备

大厂面试题

Q1: 自监督学习和监督学习有什么区别?

参考答案: - 监督学习:需要标注数据 - 自监督学习:从数据本身生成标签 - 优势: - 无需标注 - 数据量大 - 泛化能力强

Q2: 对比学习的核心思想是什么?

参考答案: - 拉近正样本对 - 推远负样本对 - 对比损失 - 数据增强关键


14.6 本章小结

核心知识点

  1. 对比学习:SimCLR、MoCo
  2. MAE:掩码自编码器
  3. 自监督预训练:无标注数据

下一步

下一章15-模型部署与优化.md - 学习模型部署


恭喜完成第14章! 🎉