跳转至

21 - 元学习 (Meta-Learning)

元学习图

🎯 什么是元学习?

核心思想

Text Only
传统机器学习:
在大量数据上训练 → 在同类数据上测试

元学习(学习如何学习):
在多个任务上学习 → 快速适应新任务

类比:
传统ML:学习识别猫
元学习:学习如何快速学会识别新动物

为什么需要元学习?

Text Only
小样本学习 (Few-Shot Learning):
- 每个类别只有少量样本(1-5个)
- 传统方法会严重过拟合

快速适应:
- 新任务只需几步梯度更新
- 适合个性化、持续学习场景

📚 元学习分类

Text Only
元学习方法
├── 基于优化的方法
│   ├── MAML (模型无关元学习)
│   └── Meta-SGD
├── 基于度量的方法
│   ├── Siamese Network
│   ├── Matching Network
│   └── Prototypical Network
└── 基于记忆的方法
    ├── Memory-Augmented Neural Networks
    └── Neural Turing Machine

⚡ MAML (Model-Agnostic Meta-Learning)

核心思想

Text Only
目标:学习一个好的初始化参数
使得在新任务上只需少量梯度步就能达到好的性能

关键:
- 在支持集(support set)上计算梯度
- 在查询集(query set)上评估性能
- 优化初始化参数使得这种适应能力最强

算法流程

Text Only
外层循环(元训练):
  对于每个batch的任务:
    对于每个任务:
      1. 在支持集上计算损失
      2. 一步梯度下降:θ' = θ - α∇L
      3. 在查询集上评估θ'的性能

    4. 汇总所有任务的查询集损失
    5. 更新初始化参数:θ = θ - β∇ΣL'

代码实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MAML(nn.Module):  # 继承nn.Module定义神经网络层
    def __init__(self, model, inner_lr=0.01, num_inner_steps=1):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        self.model = model
        self.inner_lr = inner_lr
        self.num_inner_steps = num_inner_steps

    def inner_loop(self, support_x, support_y):
        """内循环:在支持集上适应"""
        # 创建任务特定的参数副本
        fast_weights = [p.clone() for p in self.model.parameters()]

        for _ in range(self.num_inner_steps):
            # 前向传播
            logits = self.forward_with_params(support_x, fast_weights)
            loss = F.cross_entropy(logits, support_y)

            # 计算梯度
            grads = torch.autograd.grad(loss, fast_weights, create_graph=True)

            # 更新快速权重
            fast_weights = [w - self.inner_lr * g
                          for w, g in zip(fast_weights, grads)]  # zip按位置配对多个可迭代对象

        return fast_weights

    def forward_with_params(self, x, params):
        """使用指定参数前向传播(简化为2层MLP示例)

        注意:这里使用简单的MLP结构来演示MAML核心原理。
        若使用SimpleCNN等复杂模型,需通过torch.nn.utils中的
        功能性API遍历每一层进行参数替换。
        """
        x = x.view(x.size(0), -1)  # 展平输入
        x = F.linear(x, params[0], params[1])
        x = F.relu(x)
        x = F.linear(x, params[2], params[3])
        return x

    def forward(self, batch_tasks, meta_optimizer):
        """元训练"""
        meta_loss = 0

        for task in batch_tasks:
            support_x, support_y = task['support']
            query_x, query_y = task['query']

            # 内循环适应
            fast_weights = self.inner_loop(support_x, support_y)

            # 在查询集上评估
            query_logits = self.forward_with_params(query_x, fast_weights)
            task_loss = F.cross_entropy(query_logits, query_y)
            meta_loss += task_loss

        # 元更新
        meta_optimizer.zero_grad()
        meta_loss.backward()  # 反向传播计算梯度
        meta_optimizer.step()  # 根据梯度更新模型参数

        return meta_loss.item()  # .item()将单元素张量转为Python数值

# 简单的4层CNN
class SimpleCNN(nn.Module):
    def __init__(self, in_channels=3, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

MAML的优缺点

Text Only
优点:
✅ 模型无关,可应用于任何基于梯度的模型
✅ 理论保证,收敛性好
✅ 适应速度快(1-5步梯度更新)

缺点:
❌ 需要二阶导数,计算开销大
❌ 内存消耗高(需要存储计算图)

🎯 基于度量的方法

Prototypical Networks (原型网络)

核心思想

Text Only
每个类别学习一个原型(prototype)
新样本分类:找到最近的原型

原型 = 该类所有支持样本嵌入的平均

算法流程

Text Only
1. 编码器将支持集和查询集映射到嵌入空间
2. 计算每个类别的原型:c_k = (1/|S_k|) Σ f(x_i)
3. 查询样本分类:基于到各原型的距离
   p(y=k|x) = softmax(-d(f(x), c_k))

代码实现

Python
class PrototypicalNetwork(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def euclidean_distance(self, x, y):
        """计算欧氏距离"""
        n = x.size(0)
        m = y.size(0)
        d = x.size(1)

        x = x.unsqueeze(1).expand(n, m, d)  # 链式调用,连续执行多个方法  # unsqueeze增加一个维度
        y = y.unsqueeze(0).expand(n, m, d)

        return torch.pow(x - y, 2).sum(2)

    def forward(self, support_x, support_y, query_x, n_way, k_shot):
        """
        support_x: (n_way * k_shot, ...)
        support_y: (n_way * k_shot,)
        query_x: (n_query, ...)
        """
        # 编码
        support_embed = self.encoder(support_x)  # (n_way*k_shot, embed_dim)
        query_embed = self.encoder(query_x)      # (n_query, embed_dim)

        # 计算原型
        prototypes = []
        for k in range(n_way):
            class_mask = (support_y == k)
            class_embed = support_embed[class_mask]
            prototype = class_embed.mean(dim=0)
            prototypes.append(prototype)

        prototypes = torch.stack(prototypes)  # (n_way, embed_dim)

        # 计算查询样本到原型的距离
        distances = self.euclidean_distance(query_embed, prototypes)

        # 转换为概率(距离越小,概率越大)
        logits = -distances

        return logits

# 训练
proto_net = PrototypicalNetwork(SimpleCNN())
optimizer = torch.optim.Adam(proto_net.parameters(), lr=1e-3)

for episode in range(num_episodes):
    # 采样一个episode
    support_x, support_y, query_x, query_y = sample_episode(dataset, n_way, k_shot, n_query)

    # 前向传播
    logits = proto_net(support_x, support_y, query_x, n_way, k_shot)

    # 计算损失
    loss = F.cross_entropy(logits, query_y)

    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Matching Networks

Text Only
核心思想:
查询样本的标签 = 支持集标签的加权平均
权重 = 查询样本与支持样本的注意力

注意力计算:
a(x̂, x_i) = softmax(cosine(f(x̂), g(x_i)))

其中f和g可以是同一个编码器

📊 任务构建

N-Way K-Shot 任务

Text Only
定义:
- N-way:N个类别
- K-shot:每个类别K个支持样本
- Q-query:每个类别Q个查询样本

示例(5-way 1-shot):
支持集:5个类别,每类1个样本 → 共5个样本
查询集:5个类别,每类15个样本 → 共75个样本

Episode采样

Python
class EpisodeSampler:
    def __init__(self, dataset, n_way, k_shot, n_query):
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query

        # 按类别组织数据
        self.class_to_indices = {}
        for idx, label in enumerate(dataset.labels):  # enumerate同时获取索引和元素
            if label not in self.class_to_indices:
                self.class_to_indices[label] = []
            self.class_to_indices[label].append(idx)

        self.classes = list(self.class_to_indices.keys())

    def sample_episode(self):
        """采样一个episode"""
        # 随机选择n_way个类别
        episode_classes = np.random.choice(self.classes, self.n_way, replace=False)

        support_x, support_y = [], []
        query_x, query_y = [], []

        for new_label, old_label in enumerate(episode_classes):
            # 获取该类所有样本索引
            indices = self.class_to_indices[old_label]

            # 随机选择k_shot + n_query个样本
            selected = np.random.choice(indices, self.k_shot + self.n_query, replace=False)

            # 分为支持集和查询集
            support_indices = selected[:self.k_shot]
            query_indices = selected[self.k_shot:]

            # 收集样本
            for idx in support_indices:
                support_x.append(self.dataset[idx][0])
                support_y.append(new_label)

            for idx in query_indices:
                query_x.append(self.dataset[idx][0])
                query_y.append(new_label)

        return (torch.stack(support_x), torch.tensor(support_y),
                torch.stack(query_x), torch.tensor(query_y))

🎯 常用数据集

Omniglot

Text Only
- 1623个字符类
- 每类20个样本
- 小尺寸:28x28
- 元学习标准基准

miniImageNet

Text Only
- 100个类别
- 每类600个样本
- 尺寸:84x84
- 从ImageNet选取

tieredImageNet

Text Only
- 608个类别
- 按类别层次组织
- 训练/验证/测试类别不重叠
- 更难的任务

💡 总结

Text Only
元学习核心:
学习如何快速学习新任务

主要方法:
1. 基于优化:MAML,学习好的初始化
2. 基于度量:Prototypical Networks,学习距离度量
3. 基于记忆:存储和检索经验

应用场景:
- 小样本图像分类
- 个性化推荐
- 快速适应的机器人
- 持续学习

实践建议:
1. 从Prototypical Networks开始(简单有效)
2. 理解MAML的二阶优化
3. 注意任务分布的一致性
4. 选择合适的N-way K-shot设置

🎓 完整学习路径回顾

恭喜!你已经完成了整个机器学习教程的学习:

Text Only
基础阶段:
00-数学基础 → 01-基础概念 → 02-监督学习 → 03-无监督学习

进阶阶段:
04-时序模型 → 05-深度学习 → 06-实践指南 → 07-模型评估与调优

专项阶段:
08-特征工程 → 09-深度学习进阶 → 10-强化学习基础 → 11-MLOps与部署

高阶阶段:
12-集成学习进阶 → 13-降维与流形学习 → 14-贝叶斯方法 → 15-其他监督学习算法
16-聚类算法进阶 → 17-图神经网络 → 18-NLP与Transformer详解
19-生成模型深度解析 → 20-自监督学习 → 21-元学习

下一步:
- 选择感兴趣的方向深入研究
- 参与实际项目积累经验
- 关注领域最新进展
- 持续学习,不断进步!

祝你在机器学习的道路上越走越远!🚀