21 - 元学习 (Meta-Learning)¶
🎯 什么是元学习?¶
核心思想¶
Text Only
传统机器学习:
在大量数据上训练 → 在同类数据上测试
元学习(学习如何学习):
在多个任务上学习 → 快速适应新任务
类比:
传统ML:学习识别猫
元学习:学习如何快速学会识别新动物
为什么需要元学习?¶
Text Only
小样本学习 (Few-Shot Learning):
- 每个类别只有少量样本(1-5个)
- 传统方法会严重过拟合
快速适应:
- 新任务只需几步梯度更新
- 适合个性化、持续学习场景
📚 元学习分类¶
Text Only
元学习方法
├── 基于优化的方法
│ ├── MAML (模型无关元学习)
│ └── Meta-SGD
├── 基于度量的方法
│ ├── Siamese Network
│ ├── Matching Network
│ └── Prototypical Network
└── 基于记忆的方法
├── Memory-Augmented Neural Networks
└── Neural Turing Machine
⚡ MAML (Model-Agnostic Meta-Learning)¶
核心思想¶
Text Only
目标:学习一个好的初始化参数
使得在新任务上只需少量梯度步就能达到好的性能
关键:
- 在支持集(support set)上计算梯度
- 在查询集(query set)上评估性能
- 优化初始化参数使得这种适应能力最强
算法流程¶
Text Only
外层循环(元训练):
对于每个batch的任务:
对于每个任务:
1. 在支持集上计算损失
2. 一步梯度下降:θ' = θ - α∇L
3. 在查询集上评估θ'的性能
4. 汇总所有任务的查询集损失
5. 更新初始化参数:θ = θ - β∇ΣL'
代码实现¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
class MAML(nn.Module): # 继承nn.Module定义神经网络层
def __init__(self, model, inner_lr=0.01, num_inner_steps=1): # __init__构造方法,创建对象时自动调用
super().__init__() # super()调用父类方法
self.model = model
self.inner_lr = inner_lr
self.num_inner_steps = num_inner_steps
def inner_loop(self, support_x, support_y):
"""内循环:在支持集上适应"""
# 创建任务特定的参数副本
fast_weights = [p.clone() for p in self.model.parameters()]
for _ in range(self.num_inner_steps):
# 前向传播
logits = self.forward_with_params(support_x, fast_weights)
loss = F.cross_entropy(logits, support_y)
# 计算梯度
grads = torch.autograd.grad(loss, fast_weights, create_graph=True)
# 更新快速权重
fast_weights = [w - self.inner_lr * g
for w, g in zip(fast_weights, grads)] # zip按位置配对多个可迭代对象
return fast_weights
def forward_with_params(self, x, params):
"""使用指定参数前向传播(简化为2层MLP示例)
注意:这里使用简单的MLP结构来演示MAML核心原理。
若使用SimpleCNN等复杂模型,需通过torch.nn.utils中的
功能性API遍历每一层进行参数替换。
"""
x = x.view(x.size(0), -1) # 展平输入
x = F.linear(x, params[0], params[1])
x = F.relu(x)
x = F.linear(x, params[2], params[3])
return x
def forward(self, batch_tasks, meta_optimizer):
"""元训练"""
meta_loss = 0
for task in batch_tasks:
support_x, support_y = task['support']
query_x, query_y = task['query']
# 内循环适应
fast_weights = self.inner_loop(support_x, support_y)
# 在查询集上评估
query_logits = self.forward_with_params(query_x, fast_weights)
task_loss = F.cross_entropy(query_logits, query_y)
meta_loss += task_loss
# 元更新
meta_optimizer.zero_grad()
meta_loss.backward() # 反向传播计算梯度
meta_optimizer.step() # 根据梯度更新模型参数
return meta_loss.item() # .item()将单元素张量转为Python数值
# 简单的4层CNN
class SimpleCNN(nn.Module):
def __init__(self, in_channels=3, num_classes=5):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(in_channels, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.classifier = nn.Linear(64, num_classes)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
MAML的优缺点¶
🎯 基于度量的方法¶
Prototypical Networks (原型网络)¶
核心思想¶
算法流程¶
Text Only
1. 编码器将支持集和查询集映射到嵌入空间
2. 计算每个类别的原型:c_k = (1/|S_k|) Σ f(x_i)
3. 查询样本分类:基于到各原型的距离
p(y=k|x) = softmax(-d(f(x), c_k))
代码实现¶
Python
class PrototypicalNetwork(nn.Module):
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
def euclidean_distance(self, x, y):
"""计算欧氏距离"""
n = x.size(0)
m = y.size(0)
d = x.size(1)
x = x.unsqueeze(1).expand(n, m, d) # 链式调用,连续执行多个方法 # unsqueeze增加一个维度
y = y.unsqueeze(0).expand(n, m, d)
return torch.pow(x - y, 2).sum(2)
def forward(self, support_x, support_y, query_x, n_way, k_shot):
"""
support_x: (n_way * k_shot, ...)
support_y: (n_way * k_shot,)
query_x: (n_query, ...)
"""
# 编码
support_embed = self.encoder(support_x) # (n_way*k_shot, embed_dim)
query_embed = self.encoder(query_x) # (n_query, embed_dim)
# 计算原型
prototypes = []
for k in range(n_way):
class_mask = (support_y == k)
class_embed = support_embed[class_mask]
prototype = class_embed.mean(dim=0)
prototypes.append(prototype)
prototypes = torch.stack(prototypes) # (n_way, embed_dim)
# 计算查询样本到原型的距离
distances = self.euclidean_distance(query_embed, prototypes)
# 转换为概率(距离越小,概率越大)
logits = -distances
return logits
# 训练
proto_net = PrototypicalNetwork(SimpleCNN())
optimizer = torch.optim.Adam(proto_net.parameters(), lr=1e-3)
for episode in range(num_episodes):
# 采样一个episode
support_x, support_y, query_x, query_y = sample_episode(dataset, n_way, k_shot, n_query)
# 前向传播
logits = proto_net(support_x, support_y, query_x, n_way, k_shot)
# 计算损失
loss = F.cross_entropy(logits, query_y)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
Matching Networks¶
Text Only
核心思想:
查询样本的标签 = 支持集标签的加权平均
权重 = 查询样本与支持样本的注意力
注意力计算:
a(x̂, x_i) = softmax(cosine(f(x̂), g(x_i)))
其中f和g可以是同一个编码器
📊 任务构建¶
N-Way K-Shot 任务¶
Text Only
定义:
- N-way:N个类别
- K-shot:每个类别K个支持样本
- Q-query:每个类别Q个查询样本
示例(5-way 1-shot):
支持集:5个类别,每类1个样本 → 共5个样本
查询集:5个类别,每类15个样本 → 共75个样本
Episode采样¶
Python
class EpisodeSampler:
def __init__(self, dataset, n_way, k_shot, n_query):
self.dataset = dataset
self.n_way = n_way
self.k_shot = k_shot
self.n_query = n_query
# 按类别组织数据
self.class_to_indices = {}
for idx, label in enumerate(dataset.labels): # enumerate同时获取索引和元素
if label not in self.class_to_indices:
self.class_to_indices[label] = []
self.class_to_indices[label].append(idx)
self.classes = list(self.class_to_indices.keys())
def sample_episode(self):
"""采样一个episode"""
# 随机选择n_way个类别
episode_classes = np.random.choice(self.classes, self.n_way, replace=False)
support_x, support_y = [], []
query_x, query_y = [], []
for new_label, old_label in enumerate(episode_classes):
# 获取该类所有样本索引
indices = self.class_to_indices[old_label]
# 随机选择k_shot + n_query个样本
selected = np.random.choice(indices, self.k_shot + self.n_query, replace=False)
# 分为支持集和查询集
support_indices = selected[:self.k_shot]
query_indices = selected[self.k_shot:]
# 收集样本
for idx in support_indices:
support_x.append(self.dataset[idx][0])
support_y.append(new_label)
for idx in query_indices:
query_x.append(self.dataset[idx][0])
query_y.append(new_label)
return (torch.stack(support_x), torch.tensor(support_y),
torch.stack(query_x), torch.tensor(query_y))
🎯 常用数据集¶
Omniglot¶
miniImageNet¶
tieredImageNet¶
💡 总结¶
Text Only
元学习核心:
学习如何快速学习新任务
主要方法:
1. 基于优化:MAML,学习好的初始化
2. 基于度量:Prototypical Networks,学习距离度量
3. 基于记忆:存储和检索经验
应用场景:
- 小样本图像分类
- 个性化推荐
- 快速适应的机器人
- 持续学习
实践建议:
1. 从Prototypical Networks开始(简单有效)
2. 理解MAML的二阶优化
3. 注意任务分布的一致性
4. 选择合适的N-way K-shot设置
🎓 完整学习路径回顾¶
恭喜!你已经完成了整个机器学习教程的学习:
Text Only
基础阶段:
00-数学基础 → 01-基础概念 → 02-监督学习 → 03-无监督学习
进阶阶段:
04-时序模型 → 05-深度学习 → 06-实践指南 → 07-模型评估与调优
专项阶段:
08-特征工程 → 09-深度学习进阶 → 10-强化学习基础 → 11-MLOps与部署
高阶阶段:
12-集成学习进阶 → 13-降维与流形学习 → 14-贝叶斯方法 → 15-其他监督学习算法
16-聚类算法进阶 → 17-图神经网络 → 18-NLP与Transformer详解
19-生成模型深度解析 → 20-自监督学习 → 21-元学习
下一步:
- 选择感兴趣的方向深入研究
- 参与实际项目积累经验
- 关注领域最新进展
- 持续学习,不断进步!
祝你在机器学习的道路上越走越远!🚀