20 - 自监督学习 (Self-Supervised Learning)¶
🎯 什么是自监督学习?¶
学习范式对比¶
Text Only
监督学习:
数据:(图像, 标签) → 学习映射:图像 → 标签
缺点:需要大量标注数据
无监督学习:
数据:(图像) → 发现结构:聚类、降维
缺点:没有明确任务目标
自监督学习:
数据:(图像) → 构造伪标签 → 学习表示
优点:利用无标签数据学习通用表示
核心思想¶
🖼️ 计算机视觉中的自监督学习¶
1. 基于对比学习的方法¶
SimCLR¶
Text Only
核心思想:同一张图像的不同增强视图应该相似,不同图像应该不同
步骤:
1. 对每张图像做两种随机增强 → 得到两个视图
2. 编码器提取特征
3. 对比损失:拉近正样本,推远负样本
损失函数(NT-Xent):
L = -log(exp(sim(z_i, z_j)/τ) / Σ exp(sim(z_i, z_k)/τ))
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
class SimCLR(nn.Module): # 继承nn.Module定义神经网络层
def __init__(self, base_encoder, projection_dim=128): # __init__构造方法,创建对象时自动调用
super().__init__() # super()调用父类方法
self.encoder = base_encoder
self.projector = nn.Sequential(
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, projection_dim)
)
def forward(self, x):
h = self.encoder(x)
z = F.normalize(self.projector(h), dim=1)
return h, z
# NT-Xent损失
def nt_xent_loss(z_i, z_j, temperature=0.5):
"""
z_i, z_j: (batch_size, projection_dim)
"""
batch_size = z_i.size(0)
z = torch.cat([z_i, z_j], dim=0) # (2*batch, dim)
# 计算相似度矩阵
sim_matrix = torch.mm(z, z.t()) / temperature
# 正样本对
positives = torch.cat([
torch.diag(sim_matrix, batch_size),
torch.diag(sim_matrix, -batch_size)
]).view(2*batch_size, 1) # view重塑张量形状(要求内存连续)
# 负样本
mask = torch.eye(2*batch_size, device=z.device).bool()
negatives = sim_matrix[~mask].view(2*batch_size, -1)
# 计算损失
logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(2*batch_size, device=z.device, dtype=torch.long)
return F.cross_entropy(logits, labels)
MoCo (Momentum Contrast)¶
Python
class MoCo(nn.Module):
def __init__(self, base_encoder, dim=128, K=65536, m=0.999):
super().__init__()
self.K = K
self.m = m
# 查询编码器
self.encoder_q = base_encoder
self.projector_q = nn.Linear(512, dim)
# 键编码器(动量更新)— 必须深拷贝,不能共享引用
self.encoder_k = copy.deepcopy(base_encoder)
self.projector_k = copy.deepcopy(self.projector_q)
# 冻结键编码器参数(仅通过动量更新)
for param_k in self.encoder_k.parameters():
param_k.requires_grad = False
for param_k in self.projector_k.parameters():
param_k.requires_grad = False
# 队列
self.register_buffer("queue", torch.randn(dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad() # 禁用梯度计算,节省内存
def _momentum_update_key_encoder(self):
"""动量更新键编码器(包括projector)"""
for param_q, param_k in zip(self.encoder_q.parameters(),
self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
# 同步更新projector_k
for param_q, param_k in zip(self.projector_q.parameters(),
self.projector_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
"""更新队列"""
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
# 替换队列中的样本
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr
2. 基于掩码的方法¶
MAE (Masked Autoencoder)¶
Text Only
核心思想:
1. 随机掩码图像的大部分区域(如75%)
2. 编码器只处理可见区域
3. 解码器重建完整图像
优势:
- 非对称编码器-解码器设计
- 计算效率高
- 学到的表示质量好
Python
class MAE(nn.Module):
def __init__(self, encoder, decoder, mask_ratio=0.75):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.mask_ratio = mask_ratio
def random_masking(self, x, mask_ratio):
"""随机掩码"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# 保留的token
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # unsqueeze增加一个维度
# 生成掩码
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward(self, x):
# 分块
x = self.patch_embed(x)
# 添加位置编码
x = x + self.pos_embed
# 随机掩码
x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
# 编码
x = self.encoder(x)
# 解码
x = self.decoder(x, ids_restore)
return x, mask
3. 基于预测的方法¶
旋转预测¶
Python
class RotationPrediction(nn.Module):
def __init__(self, base_encoder):
super().__init__()
self.encoder = base_encoder
self.classifier = nn.Linear(512, 4) # 4个旋转角度
def forward(self, x):
features = self.encoder(x)
logits = self.classifier(features)
return logits
# 数据增强
def rotate_batch(images):
"""随机旋转图像并返回标签"""
batch_size = images.size(0)
labels = torch.randint(0, 4, (batch_size,))
rotated = []
for i, label in enumerate(labels): # enumerate同时获取索引和元素
angle = label * 90
rotated.append(transforms.functional.rotate(images[i], angle.item())) # .item()将单元素张量转为Python数值
return torch.stack(rotated), labels
📝 自然语言处理中的自监督学习¶
掩码语言模型 (MLM)¶
已在 18-NLP与Transformer详解.md 中详细介绍。
自回归语言模型¶
对比学习 (SimCSE)¶
Python
class SimCSE(nn.Module):
def __init__(self, pretrained_model):
super().__init__()
self.encoder = pretrained_model
def forward(self, input_ids, attention_mask):
# 同一输入过两次(dropout不同)
output1 = self.encoder(input_ids, attention_mask)
output2 = self.encoder(input_ids, attention_mask)
# [CLS]向量作为句子表示
z1 = output1.last_hidden_state[:, 0]
z2 = output2.last_hidden_state[:, 0]
return z1, z2
🎯 下游任务应用¶
线性探测 (Linear Probing)¶
Python
# 冻结预训练编码器,只训练分类头
class LinearProbe(nn.Module):
def __init__(self, encoder, num_classes):
super().__init__()
self.encoder = encoder
self.encoder.eval() # eval()开启评估模式(关闭Dropout等)
# 冻结编码器参数
for param in self.encoder.parameters():
param.requires_grad = False
self.classifier = nn.Linear(512, num_classes)
def forward(self, x):
with torch.no_grad():
features = self.encoder(x)
return self.classifier(features)
微调 (Fine-tuning)¶
Python
# 端到端微调
class FineTunedModel(nn.Module):
def __init__(self, pretrained_model, num_classes):
super().__init__()
self.model = pretrained_model
self.model.fc = nn.Linear(512, num_classes)
def forward(self, x):
return self.model(x)
# 使用较小学习率微调
optimizer = torch.optim.SGD([
{'params': model.model.fc.parameters(), 'lr': 0.01},
{'params': model.model.layer4.parameters(), 'lr': 0.001},
{'params': model.model.layer1.parameters(), 'lr': 0.0001}
])
📊 方法对比¶
| 方法 | 预训练任务 | 主要优势 | 适用场景 |
|---|---|---|---|
| SimCLR | 对比学习 | 简单有效 | 通用表示学习 |
| MoCo | 对比学习 | 大字典 | 大规模数据 |
| MAE | 掩码重建 | 计算高效 | 视觉表示 |
| BEiT | 掩码预测 | 离散表示 | 视觉-语言 |
| BERT | 掩码语言模型 | 双向上下文 | NLP |
| GPT | 自回归 | 生成能力 | 文本生成 |
💡 总结¶
Text Only
自监督学习的核心:
从无标签数据中构造监督信号
主要方法:
1. 对比学习:学习判别性表示
2. 掩码方法:学习重建能力
3. 预测方法:学习预测能力
未来趋势:
- 多模态自监督学习
- 更大规模的预训练
- 与生成模型结合
实践建议:
1. 从SimCLR或MAE开始
2. 使用预训练模型加速下游任务
3. 根据任务选择合适的预训练方法
下一步:学习 21-元学习.md,掌握快速适应新任务的能力!