跳转至

20 - 自监督学习 (Self-Supervised Learning)

自监督学习图

🎯 什么是自监督学习?

学习范式对比

Text Only
监督学习:
数据:(图像, 标签) → 学习映射:图像 → 标签
缺点:需要大量标注数据

无监督学习:
数据:(图像) → 发现结构:聚类、降维
缺点:没有明确任务目标

自监督学习:
数据:(图像) → 构造伪标签 → 学习表示
优点:利用无标签数据学习通用表示

核心思想

Text Only
1. 从数据本身构造监督信号(伪标签)
2. 通过预训练任务学习特征表示
3. 将学到的表示迁移到下游任务

关键优势:
- 无需人工标注
- 可扩展性强
- 学到的表示通用性好

🖼️ 计算机视觉中的自监督学习

1. 基于对比学习的方法

SimCLR

Text Only
核心思想:同一张图像的不同增强视图应该相似,不同图像应该不同

步骤:
1. 对每张图像做两种随机增强 → 得到两个视图
2. 编码器提取特征
3. 对比损失:拉近正样本,推远负样本

损失函数(NT-Xent):
L = -log(exp(sim(z_i, z_j)/τ) / Σ exp(sim(z_i, z_k)/τ))
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class SimCLR(nn.Module):  # 继承nn.Module定义神经网络层
    def __init__(self, base_encoder, projection_dim=128):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        self.encoder = base_encoder
        self.projector = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = F.normalize(self.projector(h), dim=1)
        return h, z

# NT-Xent损失
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """
    z_i, z_j: (batch_size, projection_dim)
    """
    batch_size = z_i.size(0)
    z = torch.cat([z_i, z_j], dim=0)  # (2*batch, dim)

    # 计算相似度矩阵
    sim_matrix = torch.mm(z, z.t()) / temperature

    # 正样本对
    positives = torch.cat([
        torch.diag(sim_matrix, batch_size),
        torch.diag(sim_matrix, -batch_size)
    ]).view(2*batch_size, 1)  # view重塑张量形状(要求内存连续)

    # 负样本
    mask = torch.eye(2*batch_size, device=z.device).bool()
    negatives = sim_matrix[~mask].view(2*batch_size, -1)

    # 计算损失
    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(2*batch_size, device=z.device, dtype=torch.long)

    return F.cross_entropy(logits, labels)

MoCo (Momentum Contrast)

Text Only
核心创新:
1. 动态字典:维护一个大的负样本队列
2. 动量编码器:缓慢更新的编码器提供一致的键

优势:
- 可以构建更大的字典
- 负样本更丰富
- 训练更稳定
Python
class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999):
        super().__init__()
        self.K = K
        self.m = m

        # 查询编码器
        self.encoder_q = base_encoder
        self.projector_q = nn.Linear(512, dim)

        # 键编码器(动量更新)— 必须深拷贝,不能共享引用
        self.encoder_k = copy.deepcopy(base_encoder)
        self.projector_k = copy.deepcopy(self.projector_q)

        # 冻结键编码器参数(仅通过动量更新)
        for param_k in self.encoder_k.parameters():
            param_k.requires_grad = False
        for param_k in self.projector_k.parameters():
            param_k.requires_grad = False

        # 队列
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()  # 禁用梯度计算,节省内存
    def _momentum_update_key_encoder(self):
        """动量更新键编码器(包括projector)"""
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                     self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        # 同步更新projector_k
        for param_q, param_k in zip(self.projector_q.parameters(),
                                     self.projector_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """更新队列"""
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        # 替换队列中的样本
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

2. 基于掩码的方法

MAE (Masked Autoencoder)

Text Only
核心思想:
1. 随机掩码图像的大部分区域(如75%)
2. 编码器只处理可见区域
3. 解码器重建完整图像

优势:
- 非对称编码器-解码器设计
- 计算效率高
- 学到的表示质量好
Python
class MAE(nn.Module):
    def __init__(self, encoder, decoder, mask_ratio=0.75):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.mask_ratio = mask_ratio

    def random_masking(self, x, mask_ratio):
        """随机掩码"""
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # 保留的token
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # unsqueeze增加一个维度

        # 生成掩码
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, x):
        # 分块
        x = self.patch_embed(x)

        # 添加位置编码
        x = x + self.pos_embed

        # 随机掩码
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # 编码
        x = self.encoder(x)

        # 解码
        x = self.decoder(x, ids_restore)

        return x, mask

3. 基于预测的方法

旋转预测

Text Only
任务:预测图像旋转角度(0°, 90°, 180°, 270°)
Python
class RotationPrediction(nn.Module):
    def __init__(self, base_encoder):
        super().__init__()
        self.encoder = base_encoder
        self.classifier = nn.Linear(512, 4)  # 4个旋转角度

    def forward(self, x):
        features = self.encoder(x)
        logits = self.classifier(features)
        return logits

# 数据增强
def rotate_batch(images):
    """随机旋转图像并返回标签"""
    batch_size = images.size(0)
    labels = torch.randint(0, 4, (batch_size,))

    rotated = []
    for i, label in enumerate(labels):  # enumerate同时获取索引和元素
        angle = label * 90
        rotated.append(transforms.functional.rotate(images[i], angle.item()))  # .item()将单元素张量转为Python数值

    return torch.stack(rotated), labels

📝 自然语言处理中的自监督学习

掩码语言模型 (MLM)

已在 18-NLP与Transformer详解.md 中详细介绍。

Text Only
BERT的MLM:
- 随机掩码15%的token
- 80%用[MASK]替换
- 10%用随机token替换
- 10%保持不变

自回归语言模型

Text Only
GPT系列:
- 从左到右预测下一个token
- 标准语言建模目标

优点:
- 天然适合生成任务
- 可以处理任意长度序列

对比学习 (SimCSE)

Python
class SimCSE(nn.Module):
    def __init__(self, pretrained_model):
        super().__init__()
        self.encoder = pretrained_model

    def forward(self, input_ids, attention_mask):
        # 同一输入过两次(dropout不同)
        output1 = self.encoder(input_ids, attention_mask)
        output2 = self.encoder(input_ids, attention_mask)

        # [CLS]向量作为句子表示
        z1 = output1.last_hidden_state[:, 0]
        z2 = output2.last_hidden_state[:, 0]

        return z1, z2

🎯 下游任务应用

线性探测 (Linear Probing)

Python
# 冻结预训练编码器,只训练分类头
class LinearProbe(nn.Module):
    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        self.encoder.eval()  # eval()开启评估模式(关闭Dropout等)

        # 冻结编码器参数
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x)
        return self.classifier(features)

微调 (Fine-tuning)

Python
# 端到端微调
class FineTunedModel(nn.Module):
    def __init__(self, pretrained_model, num_classes):
        super().__init__()
        self.model = pretrained_model
        self.model.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        return self.model(x)

# 使用较小学习率微调
optimizer = torch.optim.SGD([
    {'params': model.model.fc.parameters(), 'lr': 0.01},
    {'params': model.model.layer4.parameters(), 'lr': 0.001},
    {'params': model.model.layer1.parameters(), 'lr': 0.0001}
])

📊 方法对比

方法 预训练任务 主要优势 适用场景
SimCLR 对比学习 简单有效 通用表示学习
MoCo 对比学习 大字典 大规模数据
MAE 掩码重建 计算高效 视觉表示
BEiT 掩码预测 离散表示 视觉-语言
BERT 掩码语言模型 双向上下文 NLP
GPT 自回归 生成能力 文本生成

💡 总结

Text Only
自监督学习的核心:
从无标签数据中构造监督信号

主要方法:
1. 对比学习:学习判别性表示
2. 掩码方法:学习重建能力
3. 预测方法:学习预测能力

未来趋势:
- 多模态自监督学习
- 更大规模的预训练
- 与生成模型结合

实践建议:
1. 从SimCLR或MAE开始
2. 使用预训练模型加速下游任务
3. 根据任务选择合适的预训练方法

下一步:学习 21-元学习.md,掌握快速适应新任务的能力!