02-自监督学习¶
学习时间: 约6-8小时 难度级别: ⭐⭐⭐⭐⭐ 高级 前置知识: CNN/Transformer基础、对比学习概念、PyTorch 学习目标: 理解自监督学习的核心范式(对比学习、掩码建模、自蒸馏),掌握SimCLR、BYOL、MAE等方法
目录¶
1. 自监督学习概述¶
1.1 学习范式对比¶
| 范式 | 标签需求 | 代表方法 |
|---|---|---|
| 监督学习 | 需要大量人工标注 | ResNet, BERT(微调) |
| 无监督学习 | 不需要标签 | K-Means, AE |
| 自监督学习 | 从数据自身构造监督信号 | SimCLR, MAE, BERT(预训练) |
1.2 自监督学习的核心思想¶
从未标注数据中自动构造监督信号("代理任务"),学习通用的特征表示,然后迁移到下游任务。
1.3 主要范式¶
Text Only
自监督学习
├── 对比学习 (Contrastive Learning)
│ ├── SimCLR
│ ├── MoCo
│ └── CLIP
├── 自蒸馏 (Self-Distillation)
│ ├── BYOL
│ └── DINO
└── 掩码建模 (Masked Modeling)
├── BERT (NLP)
├── MAE (Vision)
└── BEiT
2. 前置任务(Pretext Tasks)¶
2.1 经典前置任务¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
# 1. 拼图 (Jigsaw Puzzle)
class JigsawTransform:
    """Jigsaw-puzzle pretext task: cut an image into a grid of patches and shuffle them.

    Recovering the permutation forces a model to learn spatial layout.

    Args:
        num_permutations: Size of the permutation vocabulary used by a downstream
            permutation-classification head. Stored on the instance (the original
            accepted it but silently discarded it); the transform itself still
            samples a uniformly random permutation on every call.
    """

    def __init__(self, num_permutations=100):
        self.grid_size = 3
        self.num_permutations = num_permutations

    def __call__(self, img):
        """Cut `img` into grid_size x grid_size crops and shuffle them.

        Args:
            img: PIL-style image exposing `.size` -> (w, h) and `.crop(box)`.

        Returns:
            (shuffled, perm): the shuffled list of crops and the permutation
            tensor such that shuffled[i] == pieces[perm[i]] (row-major order).
        """
        # Fix: the loops below honor self.grid_size instead of a hard-coded 3.
        g = self.grid_size
        pieces = []
        w, h = img.size
        pw, ph = w // g, h // g
        for i in range(g):
            for j in range(g):
                pieces.append(img.crop((j * pw, i * ph, (j + 1) * pw, (i + 1) * ph)))
        perm = torch.randperm(g * g)
        shuffled = [pieces[p] for p in perm]
        return shuffled, perm
# 2. 旋转预测
class RotationTransform:
    """Rotation-prediction pretext task: rotate the image by 0/90/180/270 degrees.

    Returns the rotated image together with the rotation class index (0-3),
    which serves as the self-supervised label.
    """

    def __call__(self, img):
        label = torch.randint(0, 4, (1,)).item()  # sample one of four rotation classes
        return transforms.functional.rotate(img, label * 90), label
# 3. 颜色化 (Colorization) — 从灰度图预测颜色
3. 对比学习¶
3.1 核心思想¶
拉近同一样本不同增强视图的特征(正对),推远不同样本的特征(负对)。
3.2 InfoNCE 损失¶
\[\mathcal{L} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)}\]
其中 \(\text{sim}(u, v) = \frac{u^T v}{\|u\| \|v\|}\) 是余弦相似度,\(\tau\) 是温度参数。
3.3 SimCLR 完整实现¶
Python
class SimCLRAugmentation:
    """Produce two independently augmented views of the same image (SimCLR recipe)."""

    def __init__(self, size=32):
        # CIFAR-10 channel statistics used for normalization.
        cifar_mean = [0.4914, 0.4822, 0.4465]
        cifar_std = [0.2470, 0.2435, 0.2616]
        pipeline = [
            transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, x):
        # Two stochastic passes through the same pipeline -> two correlated views.
        return self.transform(x), self.transform(x)
class ProjectionHead(nn.Module):
    """SimCLR projection head: a two-layer MLP (Linear -> BN -> ReLU -> Linear)."""

    def __init__(self, input_dim, hidden_dim=256, output_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected
class SimCLR(nn.Module):
"""SimCLR: A Simple Framework for Contrastive Learning"""
def __init__(self, backbone, feature_dim, proj_dim=128, temperature=0.5):
super().__init__()
self.backbone = backbone
self.projector = ProjectionHead(feature_dim, 256, proj_dim)
self.temperature = temperature
def forward(self, x1, x2):
"""x1, x2: 同一批样本的两个增强视图"""
# 提取特征
h1 = self.backbone(x1) # (batch, feature_dim)
h2 = self.backbone(x2)
# 投影
z1 = self.projector(h1) # (batch, proj_dim)
z2 = self.projector(h2)
return h1, h2, z1, z2
def contrastive_loss(self, z1, z2):
"""NT-Xent (Normalized Temperature-scaled Cross Entropy) Loss"""
batch_size = z1.size(0)
# L2 归一化
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# 合并
z = torch.cat([z1, z2], dim=0) # (2N, dim)
# 相似度矩阵
sim_matrix = torch.mm(z, z.t()) / self.temperature # (2N, 2N)
# 掩码:排除自身
mask = ~torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1) # 链式调用,连续执行多个方法 # view重塑张量形状(要求内存连续)
# 正对标签
labels = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(0, batch_size)
]).to(z.device)
# 调整标签(因为去掉了对角线)
labels = labels - (labels > torch.arange(2*batch_size, device=z.device)).long()
loss = F.cross_entropy(sim_matrix, labels)
return loss
def train_simclr(model, train_loader, epochs=200, device='cuda'):
    """Self-supervised SimCLR training loop.

    Args:
        model: SimCLR-style module exposing forward(x1, x2) -> (h1, h2, z1, z2)
            and a contrastive_loss(z1, z2) method.
        train_loader: yields ((view1, view2), _) batches; labels are ignored.
        epochs: number of passes over the data.
        device: device to place the model and batches on.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for (view_a, view_b), _ in train_loader:
            view_a, view_b = view_a.to(device), view_b.to(device)
            _, _, proj_a, proj_b = model(view_a, view_b)
            loss = model.contrastive_loss(proj_a, proj_b)
            # Standard step: clear stale grads, backprop, apply the update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()  # cosine LR decay, once per epoch
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")
3.4 MoCo(动量对比)¶
Python
import copy
class MoCo(nn.Module):
    """MoCo v2: momentum-contrast learning with a large negative-key queue.

    A query encoder is trained by backprop; a key encoder is an exponential
    moving average (EMA) of it. Keys from past batches are kept in a FIFO
    ring buffer and serve as negatives.

    Args:
        backbone: feature extractor producing `feature_dim`-dim vectors.
        feature_dim: backbone output dimensionality.
        queue_size: number of negative keys kept (K). Assumes batch_size <= K.
        momentum: EMA coefficient for the key encoder.
        temperature: softmax temperature for the InfoNCE logits.
    """

    def __init__(self, backbone, feature_dim, queue_size=65536,
                 momentum=0.999, temperature=0.07):
        super().__init__()
        self.K = queue_size
        self.m = momentum
        self.T = temperature
        # Query encoder (trained by gradient descent).
        self.encoder_q = backbone
        self.proj_q = ProjectionHead(feature_dim, 256, 128)
        # Key encoder (EMA copy; deepcopy avoids sharing parameters with the query encoder).
        self.encoder_k = copy.deepcopy(backbone)
        self.proj_k = ProjectionHead(feature_dim, 256, 128)
        # Start the key encoder identical to the query encoder and freeze it.
        for param_q, param_k in zip(
            list(self.encoder_q.parameters()) + list(self.proj_q.parameters()),
            list(self.encoder_k.parameters()) + list(self.proj_k.parameters())
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        # Negative-key queue stored column-wise: (dim, K), unit-norm columns.
        self.register_buffer("queue", F.normalize(torch.randn(128, queue_size), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update(self):
        """EMA step: key_params <- m * key_params + (1 - m) * query_params."""
        for param_q, param_k in zip(
            list(self.encoder_q.parameters()) + list(self.proj_q.parameters()),
            list(self.encoder_k.parameters()) + list(self.proj_k.parameters())
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Replace the oldest queue entries with the current batch of keys.

        Fix: the original single slice assignment raised a shape-mismatch error
        whenever ptr + batch_size exceeded K (i.e. K not divisible by the batch
        size). The write is now split so it wraps around the ring buffer.
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        if ptr + batch_size <= self.K:
            self.queue[:, ptr:ptr + batch_size] = keys.T
        else:
            tail = self.K - ptr
            self.queue[:, ptr:] = keys[:tail].T
            self.queue[:, :batch_size - tail] = keys[tail:].T
        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def forward(self, x_q, x_k):
        """InfoNCE loss for a batch of (query view, key view) pairs."""
        q = F.normalize(self.proj_q(self.encoder_q(x_q)), dim=1)
        # Keys come from the EMA encoder and never receive gradients.
        with torch.no_grad():
            self._momentum_update()
            k = F.normalize(self.proj_k(self.encoder_k(x_k)), dim=1)
        # Positive logits: per-sample dot product q·k, expanded to (N, 1) so it
        # can be concatenated with the (N, K) negatives into (N, 1+K) logits.
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits against every queued key.
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T  # (N, 1+K)
        # The positive is always column 0.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        loss = F.cross_entropy(logits, labels)
        self._dequeue_and_enqueue(k)
        return loss
4. BYOL与自蒸馏¶
4.1 BYOL — 不需要负样本¶
Grill et al.(2020)的 BYOL 证明对比学习不一定需要负样本!
Python
class BYOL(nn.Module):
    """Bootstrap Your Own Latent: self-distillation without negative pairs.

    An online network (encoder -> projector -> predictor) learns to predict the
    projection produced by a target network, which is an exponential moving
    average of the online weights. The predictor asymmetry plus the slow-moving
    target is what prevents representational collapse.
    """

    def __init__(self, backbone, feature_dim, hidden_dim=256, proj_dim=128, momentum=0.996):
        super().__init__()
        self.momentum = momentum
        # Online network (trained by gradient descent).
        self.online_encoder = backbone
        self.online_projector = ProjectionHead(feature_dim, hidden_dim, proj_dim)
        # Predictor exists only on the online side (architectural asymmetry).
        self.predictor = nn.Sequential(
            nn.Linear(proj_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, proj_dim),
        )
        # Target network — EMA copy, never updated by backprop.
        # Fix: deepcopy instead of `type(backbone)()`, which required a zero-arg
        # constructor and dropped non-parameter state (e.g. BatchNorm running
        # statistics, which _init_target does not copy). Consistent with MoCo.
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = ProjectionHead(feature_dim, hidden_dim, proj_dim)
        self._init_target()

    def _init_target(self):
        """Start the target network as an exact copy of the online one, frozen."""
        for online_params, target_params in zip(
            list(self.online_encoder.parameters()) + list(self.online_projector.parameters()),
            list(self.target_encoder.parameters()) + list(self.target_projector.parameters())
        ):
            target_params.data.copy_(online_params.data)
            target_params.requires_grad = False

    @torch.no_grad()
    def update_target(self):
        """EMA step: target <- momentum * target + (1 - momentum) * online."""
        for online_params, target_params in zip(
            list(self.online_encoder.parameters()) + list(self.online_projector.parameters()),
            list(self.target_encoder.parameters()) + list(self.target_projector.parameters())
        ):
            target_params.data = self.momentum * target_params.data + (1 - self.momentum) * online_params.data

    def forward(self, x1, x2):
        """Symmetric BYOL loss for two augmented views of the same batch."""
        # Online branch: encode, project, predict.
        online_proj1 = self.online_projector(self.online_encoder(x1))
        online_proj2 = self.online_projector(self.online_encoder(x2))
        online_pred1 = self.predictor(online_proj1)
        online_pred2 = self.predictor(online_proj2)
        # Target branch: no gradients flow into the EMA network.
        with torch.no_grad():
            target_proj1 = self.target_projector(self.target_encoder(x1))
            target_proj2 = self.target_projector(self.target_encoder(x2))
        # Each view's prediction regresses the *other* view's target projection.
        loss = (self._regression_loss(online_pred1, target_proj2) +
                self._regression_loss(online_pred2, target_proj1))
        return loss / 2

    @staticmethod
    def _regression_loss(x, y):
        """2 - 2 * cosine_similarity(x, y), averaged over the batch."""
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return 2 - 2 * (x * y).sum(dim=-1).mean()
5. 掩码建模¶
5.1 视觉掩码建模思路¶
随机遮挡图像的一部分,训练模型预测被遮挡的内容。
Python
class SimpleMIM(nn.Module):
    """Simplified masked image modeling (MAE-style).

    A random subset of patches is marked as masked and a lightweight decoder
    reconstructs their raw pixels; the loss is the MSE averaged over the
    masked patches only.

    Args:
        encoder: feature extractor; assumed to output per-patch features of
            shape (B, num_patches, decoder_dim) — TODO confirm for the
            concrete encoder plugged in here.
        decoder_dim: channel dimension the decoder expects from the encoder.
        patch_size: side length of a square patch, in pixels.
        img_size: side length of the (square) input image.
        mask_ratio: fraction of patches that are masked.
    """

    def __init__(self, encoder, decoder_dim=256, patch_size=16,
                 img_size=224, mask_ratio=0.75):
        super().__init__()
        self.encoder = encoder
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # Lightweight MLP decoder: features -> flattened RGB pixels per patch.
        self.decoder = nn.Sequential(
            nn.Linear(decoder_dim, 256),
            nn.GELU(),
            nn.Linear(256, patch_size * patch_size * 3)
        )

    def patchify(self, imgs):
        """Convert (B, C, H, W) images into a (B, h*w, p*p*C) patch sequence."""
        p = self.patch_size
        B, C, H, W = imgs.shape
        h, w = H // p, W // p
        # Split H and W into (grid, patch) axes, move channels last, then
        # flatten each p x p x C patch into one vector.
        patches = imgs.reshape(B, C, h, p, w, p)
        patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(B, h*w, p*p*C)
        return patches

    def forward(self, imgs):
        """Return the reconstruction loss over the masked patches."""
        patches = self.patchify(imgs)
        B, N, _ = patches.shape
        # Sample a random mask: 1 = masked, 0 = visible.
        num_masked = int(N * self.mask_ratio)
        noise = torch.rand(B, N, device=imgs.device)
        ids_shuffle = noise.argsort(dim=1)
        mask = torch.zeros(B, N, device=imgs.device)
        mask.scatter_(1, ids_shuffle[:, :num_masked], 1)
        # Encode the full image, then predict pixels for every patch.
        features = self.encoder(imgs)
        pred = self.decoder(features)
        # Per-patch MSE, averaged over masked patches only.
        # Fix: the original fed zero-filled tensors to F.mse_loss, which
        # averaged over *all* patches and silently scaled the loss by the
        # mask ratio instead of restricting it to the masked positions.
        per_patch = ((pred - patches) ** 2).mean(dim=-1)           # (B, N)
        loss = (per_patch * mask).sum() / mask.sum().clamp(min=1)  # guard mask_ratio=0
        return loss
6. 自监督学习在NLP中的应用¶
6.1 BERT风格:掩码语言建模¶
Python
class MLMHead(nn.Module):
    """Masked-language-modeling head: hidden states -> vocabulary logits."""

    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden_states):
        # Transform -> nonlinearity -> normalize -> project to vocab.
        transformed = self.dense(hidden_states)
        transformed = self.activation(transformed)
        normalized = self.layer_norm(transformed)
        return self.decoder(normalized)
def create_mlm_batch(input_ids, vocab_size, mask_token_id, mask_prob=0.15):
    """Build a BERT-style MLM training batch (80/10/10 corruption scheme).

    Args:
        input_ids: integer token-id tensor of any shape (batch, seq_len).
        vocab_size: vocabulary size used for random-token replacement.
        mask_token_id: id of the [MASK] token.
        mask_prob: fraction of tokens selected for prediction.

    Returns:
        (corrupted_ids, labels): labels are -100 (ignored by cross-entropy)
        everywhere except the selected positions, which keep the original ids.

    Fix: the original mutated the caller's `input_ids` in place while also
    returning it; we now work on a clone so the input batch is left untouched.
    """
    input_ids = input_ids.clone()
    labels = input_ids.clone()
    # Select ~mask_prob of the tokens for prediction.
    probability_matrix = torch.full(input_ids.shape, mask_prob)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # loss is computed only on selected positions
    # 80% of selected tokens -> [MASK]
    replace_mask = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    input_ids[replace_mask] = mask_token_id
    # 10% -> random token (half of the remaining 20%)
    random_mask = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~replace_mask
    random_tokens = torch.randint(vocab_size, input_ids.shape)
    input_ids[random_mask] = random_tokens[random_mask]
    # Remaining 10% keep their original token.
    return input_ids, labels
6.2 GPT风格:因果语言建模¶
Python
class CLMHead(nn.Module):
    """Causal (autoregressive) LM head: predict token t+1 from hidden state t."""

    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, hidden_states, input_ids):
        logits = self.lm_head(hidden_states)
        # Align position t's logits with token t+1: drop the last prediction
        # and the first target before flattening for cross-entropy.
        shifted_logits = logits[..., :-1, :].contiguous()
        shifted_targets = input_ids[..., 1:].contiguous()
        flat_logits = shifted_logits.view(-1, shifted_logits.size(-1))
        flat_targets = shifted_targets.view(-1)
        return F.cross_entropy(flat_logits, flat_targets)
7. 下游任务迁移¶
7.1 线性探测(Linear Probing)¶
Python
class LinearProbe(nn.Module):
    """Linear probe: frozen pretrained encoder + trainable linear classifier."""

    def __init__(self, encoder, feature_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        # Freeze every encoder parameter; only the classifier learns.
        for param in self.encoder.parameters():
            param.requires_grad_(False)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        with torch.no_grad():  # encoder is frozen; skip building its graph
            feats = self.encoder(x)
        logits = self.classifier(feats)
        return logits
def evaluate_representations(encoder, feature_dim, train_loader, test_loader,
                             num_classes=10, epochs=100, device='cuda'):
    """Measure representation quality via linear probing.

    Trains only a linear classifier on top of the frozen encoder, then
    reports top-1 accuracy on the test set.

    Returns:
        Test accuracy in percent.
    """
    probe = LinearProbe(encoder, feature_dim, num_classes).to(device)
    optimizer = torch.optim.Adam(probe.classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    # --- train the probe ---
    for _ in range(epochs):
        probe.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            loss = criterion(probe(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # --- evaluate ---
    probe.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predictions = probe(images).argmax(1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    accuracy = 100. * correct / total
    print(f"Linear Probe 准确率: {accuracy:.2f}%")
    return accuracy
8. 练习与自我检查¶
练习题¶
- SimCLR:在 CIFAR-10 上训练 SimCLR,然后用线性探测评估表示质量。
- 数据增强消融:去掉 SimCLR 的不同增强操作,分析哪些增强最重要。
- 对比损失:实现 InfoNCE 损失,实验不同温度参数 \(\tau\) 的影响。
- BYOL vs SimCLR:在同一设置下对比两者的性能。
- 预训练效果:对比随机初始化、ImageNet 监督预训练、自监督预训练在小数据集上的迁移效果。
自我检查清单¶
- 理解自监督学习的核心动机
- 能区分对比学习、自蒸馏、掩码建模三种范式
- 理解 InfoNCE 损失的推导
- 知道 SimCLR 的关键设计(增强、投影头、大batch)
- 理解 BYOL 为什么不需要负样本
- 了解 MLM 和 CLM 的区别
- 能用线性探测评估表示质量
下一篇: 实战项目/03-文本生成实战




