跳转至

02-自监督学习

学习时间: 约6-8小时 难度级别: ⭐⭐⭐⭐⭐ 高级 前置知识: CNN/Transformer基础、对比学习概念、PyTorch 学习目标: 理解自监督学习的核心范式(对比学习、掩码建模、自蒸馏),掌握SimCLR、BYOL、MAE等方法


目录


1. 自监督学习概述

1.1 学习范式对比

| 范式 | 标签需求 | 代表方法 |
| --- | --- | --- |
| 监督学习 | 需要大量人工标注 | ResNet, BERT(微调) |
| 无监督学习 | 不需要标签 | K-Means, AE |
| 自监督学习 | 从数据自身构造监督信号 | SimCLR, MAE, BERT(预训练) |

1.2 自监督学习的核心思想

从未标注数据中自动构造监督信号("代理任务"),学习通用的特征表示,然后迁移到下游任务。

1.3 主要范式

自监督学习范式

Text Only
自监督学习
├── 对比学习 (Contrastive Learning)
│   ├── SimCLR
│   ├── MoCo
│   └── CLIP
├── 自蒸馏 (Self-Distillation)
│   ├── BYOL
│   └── DINO
└── 掩码建模 (Masked Modeling)
    ├── BERT (NLP)
    ├── MAE (Vision)
    └── BEiT

2. 前置任务(Pretext Tasks)

2.1 经典前置任务

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

# 1. Jigsaw Puzzle pretext task
class JigsawTransform:
    """Cut an image into a grid of patches and shuffle them.

    Pretext task: a network must predict the permutation that was applied,
    which forces it to learn spatial/semantic structure.

    Args:
        num_permutations: size of the candidate permutation set. Kept for
            API compatibility; this simplified version samples a fully
            random permutation rather than drawing from a fixed subset.
        grid_size: patches per side (grid_size x grid_size patches total).
            Previously hard-coded to 3; now parameterized (default unchanged).
    """
    def __init__(self, num_permutations=100, grid_size=3):
        self.num_permutations = num_permutations
        self.grid_size = grid_size

    def __call__(self, img):
        """Return (shuffled_pieces, permutation).

        img: PIL Image (anything exposing ``.size`` and ``.crop``).
        """
        g = self.grid_size
        w, h = img.size
        pw, ph = w // g, h // g
        # Row-major crop order: piece index = row * g + col.
        pieces = [
            img.crop((j * pw, i * ph, (j + 1) * pw, (i + 1) * ph))
            for i in range(g)
            for j in range(g)
        ]

        perm = torch.randperm(g * g)
        shuffled = [pieces[p] for p in perm]
        return shuffled, perm

# 2. Rotation prediction pretext task
class RotationTransform:
    """Rotate an image by a random multiple of 90 degrees (0/90/180/270).

    Returns the rotated image together with the rotation class index
    (0-3), which serves as the self-supervised label.
    """
    def __call__(self, img):
        label = int(torch.randint(0, 4, (1,)))
        return transforms.functional.rotate(img, label * 90), label

# 3. 颜色化 (Colorization) — 从灰度图预测颜色

3. 对比学习

SimCLR框架

3.1 核心思想

拉近同一样本不同增强视图的特征(正对),推远不同样本的特征(负对)。

3.2 InfoNCE 损失

\[\mathcal{L} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)}\]

其中 \(\text{sim}(u, v) = \frac{u^T v}{\|u\| \|v\|}\) 是余弦相似度,\(\tau\) 是温度参数。

3.3 SimCLR 完整实现

Python
class SimCLRAugmentation:
    """SimCLR augmentation pipeline producing two independent random views."""

    def __init__(self, size=32):
        color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        pipeline = [
            transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3),
            transforms.ToTensor(),
            # CIFAR-10 channel statistics
            transforms.Normalize([0.4914, 0.4822, 0.4465],
                                 [0.2470, 0.2435, 0.2616]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, x):
        # Two stochastic passes through the same pipeline yield a positive pair.
        view_a = self.transform(x)
        view_b = self.transform(x)
        return view_a, view_b

class ProjectionHead(nn.Module):
    """SimCLR-style MLP projection head: Linear -> BatchNorm -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim=256, output_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

class SimCLR(nn.Module):
    """SimCLR: A Simple Framework for Contrastive Learning.

    Wraps a backbone with a projection head and provides the NT-Xent
    contrastive loss; ``temperature`` scales the similarity logits.
    """
    def __init__(self, backbone, feature_dim, proj_dim=128, temperature=0.5):
        super().__init__()
        self.backbone = backbone
        self.projector = ProjectionHead(feature_dim, 256, proj_dim)
        self.temperature = temperature

    def forward(self, x1, x2):
        """x1, x2: two augmented views of the same batch of samples.

        Returns (h1, h2, z1, z2): backbone features and their projections.
        """
        # Extract backbone features for both views
        h1 = self.backbone(x1)  # (batch, feature_dim)
        h2 = self.backbone(x2)

        # Project into the space where the contrastive loss is applied
        z1 = self.projector(h1)  # (batch, proj_dim)
        z2 = self.projector(h2)

        return h1, h2, z1, z2

    def contrastive_loss(self, z1, z2):
        """NT-Xent (Normalized Temperature-scaled Cross Entropy) Loss"""
        batch_size = z1.size(0)

        # L2-normalize so the dot products below are cosine similarities
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        # Stack both views: rows 0..N-1 are view 1, rows N..2N-1 are view 2
        z = torch.cat([z1, z2], dim=0)  # (2N, dim)

        # Pairwise cosine-similarity matrix, scaled by temperature
        sim_matrix = torch.mm(z, z.t()) / self.temperature  # (2N, 2N)

        # Drop self-similarity: remove the diagonal, leaving 2N-1 logits per row
        mask = ~torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)  # (2N, 2N-1)

        # Positive-pair targets: row i (view 1) pairs with row i+N (view 2)
        # and vice versa
        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(0, batch_size)
        ]).to(z.device)

        # Removing each row's diagonal element shifts every column to its
        # right one position left, so any target index greater than its row
        # index must be decremented by one.
        labels = labels - (labels > torch.arange(2*batch_size, device=z.device)).long()

        loss = F.cross_entropy(sim_matrix, labels)
        return loss

def train_simclr(model, train_loader, epochs=200, device='cuda'):
    """Train a SimCLR model end to end.

    Uses Adam with weight decay plus cosine learning-rate annealing.
    ``train_loader`` must yield ((view1, view2), label) batches; the labels
    are ignored — supervision comes entirely from the paired views.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for (view_a, view_b), _ in train_loader:
            view_a = view_a.to(device)
            view_b = view_b.to(device)

            # Only the projections feed the contrastive loss.
            _, _, proj_a, proj_b = model(view_a, view_b)
            loss = model.contrastive_loss(proj_a, proj_b)

            # clear gradients -> backprop -> parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()  # anneal the learning rate once per epoch
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")

MoCo动量对比

3.4 MoCo(动量对比)

Python
import copy

class MoCo(nn.Module):
    """MoCo v2: Momentum Contrast for self-supervised representation learning.

    A query encoder is trained with gradients; a key encoder is an EMA
    (momentum) copy of it. Keys from past batches are kept in a FIFO queue
    and serve as negatives, decoupling the negative count from batch size.

    Args:
        backbone: feature extractor producing (N, feature_dim) tensors.
        feature_dim: backbone output dimension.
        queue_size: number of negative keys kept in the queue (K).
        momentum: EMA coefficient for the key-encoder update.
        temperature: softmax temperature for the InfoNCE logits.
    """
    def __init__(self, backbone, feature_dim, queue_size=65536,
                 momentum=0.999, temperature=0.07):
        super().__init__()
        self.K = queue_size
        self.m = momentum
        self.T = temperature

        # Query encoder (receives gradients)
        self.encoder_q = backbone
        self.proj_q = ProjectionHead(feature_dim, 256, 128)

        # Key encoder: deepcopy so it shares no parameters with the query
        # encoder; it is only ever updated via EMA, never by backprop.
        self.encoder_k = copy.deepcopy(backbone)
        self.proj_k = ProjectionHead(feature_dim, 256, 128)

        # Initialize key weights from the query weights and freeze them.
        for param_q, param_k in zip(
            list(self.encoder_q.parameters()) + list(self.proj_q.parameters()),
            list(self.encoder_k.parameters()) + list(self.proj_k.parameters())
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Negative-key queue stored column-wise: (proj_dim, K), L2-normalized.
        self.register_buffer("queue", F.normalize(torch.randn(128, queue_size), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update(self):
        """EMA update: key_params <- m * key_params + (1 - m) * query_params."""
        for param_q, param_k in zip(
            list(self.encoder_q.parameters()) + list(self.proj_q.parameters()),
            list(self.encoder_k.parameters()) + list(self.proj_k.parameters())
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue entries with the new keys (FIFO).

        Handles wrap-around when the batch does not divide the queue size
        evenly; the previous version silently assumed K % batch_size == 0
        and would fail with a shape mismatch otherwise. Assumes
        batch_size <= K.
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        if ptr + batch_size <= self.K:
            self.queue[:, ptr:ptr + batch_size] = keys.T
        else:
            # Split the write across the end and the start of the buffer.
            first = self.K - ptr
            self.queue[:, ptr:] = keys[:first].T
            self.queue[:, :batch_size - first] = keys[first:].T

        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def forward(self, x_q, x_k):
        """Return the InfoNCE loss for a batch of (query, key) augmented views."""
        # Query encoding (with gradients)
        q = F.normalize(self.proj_q(self.encoder_q(x_q)), dim=1)

        # Key encoding (no gradients; update the key encoder first)
        with torch.no_grad():
            self._momentum_update()
            k = F.normalize(self.proj_k(self.encoder_k(x_k)), dim=1)

        # Positive logits: per-sample dot product q_i . k_i -> (N, 1)
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # (N, 1)

        # Negative logits against the whole queue -> (N, K)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])  # (N, K)

        logits = torch.cat([l_pos, l_neg], dim=1) / self.T  # (N, 1+K)
        # The positive is always column 0, so every label is 0.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        loss = F.cross_entropy(logits, labels)

        self._dequeue_and_enqueue(k)

        return loss

4. BYOL与自蒸馏

4.1 BYOL — 不需要负样本

Grill et al.(2020)的 BYOL 证明对比学习不一定需要负样本!

Python
class BYOL(nn.Module):
    """Bootstrap Your Own Latent: negative-free self-supervised learning.

    An online network (encoder -> projector -> predictor) is trained to
    predict the output of a target network (encoder -> projector) whose
    weights are an exponential moving average of the online weights.
    """
    def __init__(self, backbone, feature_dim, hidden_dim=256, proj_dim=128, momentum=0.996):
        super().__init__()
        self.momentum = momentum

        # Online network (trained by gradient descent)
        self.online_encoder = backbone
        self.online_projector = ProjectionHead(feature_dim, hidden_dim, proj_dim)
        self.predictor = nn.Sequential(
            nn.Linear(proj_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, proj_dim),
        )

        # Target network — EMA copy of the online network.
        # deepcopy (rather than `type(backbone)()`) guarantees an identical
        # architecture without shared parameters and also works for backbones
        # whose constructors require arguments.
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = ProjectionHead(feature_dim, hidden_dim, proj_dim)

        # Sync target weights with online weights and freeze them.
        self._init_target()

    def _init_target(self):
        """Copy online weights into the target network and disable its gradients."""
        for online_params, target_params in zip(
            list(self.online_encoder.parameters()) + list(self.online_projector.parameters()),
            list(self.target_encoder.parameters()) + list(self.target_projector.parameters())
        ):
            target_params.data.copy_(online_params.data)
            target_params.requires_grad = False

    @torch.no_grad()
    def update_target(self):
        """EMA update of the target network; call once per optimization step."""
        for online_params, target_params in zip(
            list(self.online_encoder.parameters()) + list(self.online_projector.parameters()),
            list(self.target_encoder.parameters()) + list(self.target_projector.parameters())
        ):
            target_params.data = self.momentum * target_params.data + (1 - self.momentum) * online_params.data

    def forward(self, x1, x2):
        """Return the symmetric BYOL loss for two augmented views x1, x2."""
        # Online branch: encode -> project -> predict
        online_proj1 = self.online_projector(self.online_encoder(x1))
        online_proj2 = self.online_projector(self.online_encoder(x2))
        online_pred1 = self.predictor(online_proj1)
        online_pred2 = self.predictor(online_proj2)

        # Target branch provides the (stop-gradient) regression targets.
        with torch.no_grad():
            target_proj1 = self.target_projector(self.target_encoder(x1))
            target_proj2 = self.target_projector(self.target_encoder(x2))

        # Symmetric loss: each view's prediction regresses the other view's target.
        loss = (self._regression_loss(online_pred1, target_proj2) +
                self._regression_loss(online_pred2, target_proj1))
        return loss / 2

    @staticmethod
    def _regression_loss(x, y):
        """2 - 2*cosine(x, y): MSE between L2-normalized vectors, averaged over the batch."""
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return 2 - 2 * (x * y).sum(dim=-1).mean()

5. 掩码建模

MAE掩码图像建模

5.1 视觉掩码建模思路

随机遮挡图像的一部分,训练模型预测被遮挡的内容。

Python
class SimpleMIM(nn.Module):
    """Simplified masked image modeling (MAE-style).

    Randomly masks a fraction of patches and trains the model to
    reconstruct the pixel values of the masked patches.

    NOTE(review): the encoder here sees the *full* image; a real MAE
    encoder processes only the visible patches. The encoder is expected
    to return (B, num_patches, decoder_dim) features — confirm against
    the concrete encoder used.
    """
    def __init__(self, encoder, decoder_dim=256, patch_size=16,
                 img_size=224, mask_ratio=0.75):
        super().__init__()
        self.encoder = encoder
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Lightweight decoder: maps features back to raw patch pixels.
        self.decoder = nn.Sequential(
            nn.Linear(decoder_dim, 256),
            nn.GELU(),
            nn.Linear(256, patch_size * patch_size * 3)
        )

    def patchify(self, imgs):
        """(B, C, H, W) -> (B, num_patches, patch_size*patch_size*C)."""
        p = self.patch_size
        B, C, H, W = imgs.shape
        h, w = H // p, W // p
        # Split H and W into (h, p) and (w, p) blocks, then flatten each patch.
        patches = imgs.reshape(B, C, h, p, w, p)
        patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(B, h*w, p*p*C)
        return patches

    def forward(self, imgs):
        """Return the reconstruction loss over the masked patches."""
        patches = self.patchify(imgs)
        B, N, _ = patches.shape

        # Random mask: 1 marks a masked patch.
        num_masked = int(N * self.mask_ratio)
        noise = torch.rand(B, N, device=imgs.device)
        ids_shuffle = noise.argsort(dim=1)
        mask = torch.zeros(B, N, device=imgs.device)
        mask.scatter_(1, ids_shuffle[:, :num_masked], 1)

        # Encode + decode
        features = self.encoder(imgs)
        pred = self.decoder(features)

        # MAE-style loss: mean squared error averaged over *masked* patches
        # only. (The previous version averaged over all patches, which
        # silently scaled the loss down by the mask ratio.)
        per_patch_err = ((pred - patches) ** 2).mean(dim=-1)  # (B, N)
        loss = (per_patch_err * mask).sum() / mask.sum().clamp(min=1)
        return loss

6. 自监督学习在NLP中的应用

6.1 BERT风格:掩码语言建模

Python
class MLMHead(nn.Module):
    """Masked-language-modeling head.

    Transforms hidden states (dense -> GELU -> LayerNorm) and projects
    them onto the vocabulary to produce per-token logits.
    """

    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden_states):
        transformed = self.dense(hidden_states)
        transformed = self.activation(transformed)
        normalized = self.layer_norm(transformed)
        return self.decoder(normalized)

def create_mlm_batch(input_ids, vocab_size, mask_token_id, mask_prob=0.15):
    """Build an MLM training batch following the BERT 80/10/10 recipe.

    Args:
        input_ids: (batch, seq_len) long tensor of token ids. The caller's
            tensor is NOT modified; a corrupted copy is returned.
        vocab_size: vocabulary size, used to sample random replacement tokens.
        mask_token_id: id of the [MASK] token.
        mask_prob: fraction of tokens selected for prediction (default 15%).

    Returns:
        (corrupted_input_ids, labels): labels hold the original token id at
        selected positions and -100 (PyTorch's cross-entropy ignore_index)
        everywhere else.
    """
    labels = input_ids.clone()
    # Work on a copy so the caller's tensor is left untouched
    # (the previous version corrupted the argument in place).
    input_ids = input_ids.clone()

    # Select ~mask_prob of the tokens for prediction.
    probability_matrix = torch.full(input_ids.shape, mask_prob)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # loss is computed only at selected positions

    # 80% of the selected tokens -> [MASK]
    replace_mask = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    input_ids[replace_mask] = mask_token_id

    # 10% -> random token (half of the remaining 20%)
    random_mask = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~replace_mask
    random_tokens = torch.randint(vocab_size, input_ids.shape)
    input_ids[random_mask] = random_tokens[random_mask]

    # Final 10% stay unchanged (the model must also represent clean tokens).

    return input_ids, labels

6.2 GPT风格:因果语言建模

Python
class CLMHead(nn.Module):
    """Causal (autoregressive) language-modeling head.

    Projects hidden states to vocabulary logits and returns the
    next-token cross-entropy against the input shifted by one position.
    """

    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, hidden_states, input_ids):
        logits = self.lm_head(hidden_states)
        # Position t predicts token t+1: drop the last logit and the first label.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        vocab = shift_logits.size(-1)
        flat_logits = shift_logits.view(-1, vocab)
        flat_labels = shift_labels.view(-1)
        return F.cross_entropy(flat_logits, flat_labels)

7. 下游任务迁移

7.1 线性探测(Linear Probing)

Python
class LinearProbe(nn.Module):
    """Linear probing: freeze a pretrained encoder and train only a linear
    classification head on top of its features."""

    def __init__(self, encoder, feature_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        # Freeze every encoder parameter; only the classifier learns.
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        # no_grad keeps the frozen feature extraction off the autograd graph.
        with torch.no_grad():
            feats = self.encoder(x)
        return self.classifier(feats)

def evaluate_representations(encoder, feature_dim, train_loader, test_loader,
                             num_classes=10, epochs=100, device='cuda'):
    """Measure representation quality via linear probing.

    Trains a linear classifier on frozen encoder features, then reports
    and returns top-1 accuracy (percent) on the test set.
    """
    probe = LinearProbe(encoder, feature_dim, num_classes).to(device)
    optimizer = torch.optim.Adam(probe.classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # --- train the probe ---
    for epoch in range(epochs):
        probe.train()
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            loss = criterion(probe(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # --- evaluate ---
    probe.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = probe(images).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = 100. * correct / total
    print(f"Linear Probe 准确率: {accuracy:.2f}%")
    return accuracy

8. 练习与自我检查

练习题

  1. SimCLR:在 CIFAR-10 上训练 SimCLR,然后用线性探测评估表示质量。
  2. 数据增强消融:去掉 SimCLR 的不同增强操作,分析哪些增强最重要。
  3. 对比损失:实现 InfoNCE 损失,实验不同温度参数 \(\tau\) 的影响。
  4. BYOL vs SimCLR:在同一设置下对比两者的性能。
  5. 预训练效果:对比随机初始化、ImageNet 监督预训练、自监督预训练在小数据集上的迁移效果。

自我检查清单

  • 理解自监督学习的核心动机
  • 能区分对比学习、自蒸馏、掩码建模三种范式
  • 理解 InfoNCE 损失的推导
  • 知道 SimCLR 的关键设计(增强、投影头、大batch)
  • 理解 BYOL 为什么不需要负样本
  • 了解 MLM 和 CLM 的区别
  • 能用线性探测评估表示质量

自监督学习应用

下一篇: 实战项目/03-文本生成实战