跳转至

04-神经架构搜索(NAS)

学习时间: 约5-7小时 难度级别: ⭐⭐⭐⭐⭐ 高级 前置知识: 深度学习基础、CNN/Transformer架构、强化学习概念、PyTorch 学习目标: 理解NAS的核心范式(搜索空间、搜索策略、评估策略),掌握DARTS等可微分方法的原理与实现


目录


1. 为什么需要NAS

1.1 手工设计的局限

挑战 说明
搜索空间巨大 层数、通道数、连接方式的组合呈指数级增长
依赖专家经验 好的架构往往来自直觉和大量试错
任务特异性 不同任务/数据集可能需要不同最优架构
人力成本高 手工调整耗费大量研究者时间

1.2 NAS的目标

自动在搜索空间中找到在目标任务上表现最优的神经网络架构。

1.3 NAS发展里程碑

年份 方法 贡献 GPU天数
2017 NASNet RL + Cell搜索 1800
2018 ENAS 权重共享 0.5
2019 DARTS 可微分搜索 1
2020 OFA 一次训练弹性推理 2
2021+ Zero-cost NAS 无需训练的代理指标 <0.01

2. NAS框架概述

NAS搜索策略

2.1 三大组件

Text Only
NAS 框架
├── 搜索空间 (Search Space)
│   ├── 定义候选操作(卷积、池化、跳连等)
│   └── 定义网络拓扑结构
├── 搜索策略 (Search Strategy)
│   ├── 强化学习
│   ├── 进化算法
│   ├── 贝叶斯优化
│   └── 梯度方法(可微分)
└── 评估策略 (Performance Estimation)
    ├── 完整训练(昂贵)
    ├── 缩短训练 / 低保真代理
    ├── 权重共享 (Weight Sharing)
    └── 零代价代理 (Zero-cost Proxy)

2.2 核心权衡

NAS权衡

  • 搜索空间越大:表达能力越强,但搜索难度越大
  • 评估越精确:搜索越准确,但计算代价越高
  • 搜索策略:效率与全局最优之间的平衡

3. 搜索空间设计

NAS搜索空间

3.1 宏搜索 vs 微搜索

类型 说明 代表
宏搜索 搜索整个网络架构(层数、每层类型) 原始NAS
微搜索(Cell-based) 只搜索一个Cell结构,然后堆叠 NASNet, DARTS

3.2 Cell-based搜索空间

Text Only
一个Cell包含:
- N个有序节点 (通常N=4~7)
- 每对节点之间选择一个操作
- 候选操作集合: {3x3 conv, 5x5 conv, 3x3 sep_conv,
                  5x5 sep_conv, max_pool, avg_pool,
                  skip_connect, none}

3.3 两种Cell类型

  • Normal Cell:空间分辨率不变,提取特征
  • Reduction Cell:空间分辨率减半,下采样

最终网络由这两种Cell按预定模式堆叠。


4. 基于强化学习的NAS

4.1 核心思想

将架构搜索建模为序列决策问题:

  • 控制器(RNN)生成架构编码(序列 of tokens)
  • 环境:训练生成的架构并返回验证精度作为奖励
  • 优化:用 REINFORCE 更新控制器参数

4.2 流程

Text Only
控制器(RNN) → 采样架构 → 训练子网络 → 验证精度(R) → 更新控制器
      ↑___________________________________|

4.3 REINFORCE梯度

\[\nabla_\theta J(\theta) = \mathbb{E}[\nabla_\theta \log P(a; \theta) \cdot (R - b)]\]

其中 \(b\) 为基线(历史平均奖励),用于降低方差。

4.4 局限性

  • 需要采样并训练大量子网络,计算极度昂贵
  • NASNet 使用 500 个 GPU 搜索了 4 天

5. 基于进化算法的NAS

5.1 基本流程

Text Only
初始化种群 (随机架构)
while 未达到预算:
    1. 选择: 锦标赛选择适应度高的个体
    2. 变异: 随机修改架构(改变操作、连接)
    3. 评估: 训练并计算验证精度
    4. 更新: 将优秀个体加入种群,淘汰差的

5.2 AmoebaNet的变异操作

变异类型 说明
替换操作 将某条边的操作换为另一个候选操作
改变连接 将某条边的输入节点改为另一个节点
添加/删除边 增加或移除节点之间的连接

5.3 优缺点

优点:简单直观,容易并行化,不需要梯度 缺点:同样需要大量评估,搜索效率低于梯度方法


6. 可微分架构搜索(DARTS)

DARTS可微分架构搜索

6.1 核心创新

将离散的架构选择松弛为连续优化问题:

对每条边上的操作混合使用 softmax 加权:

\[\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o'} \exp(\alpha_{o'}^{(i,j)})} \cdot o(x)\]

其中 \(\alpha\) 为可学习的架构参数。

6.2 双层优化

\[\min_\alpha \quad \mathcal{L}_{val}(w^*(\alpha), \alpha)$$ $$\text{s.t.} \quad w^*(\alpha) = \arg\min_w \mathcal{L}_{train}(w, \alpha)\]
  • 内层优化:固定 \(\alpha\),更新网络权重 \(w\)
  • 外层优化:固定 \(w\),更新架构参数 \(\alpha\)

6.3 近似求解

完整双层优化不可行,DARTS用一阶近似交替优化:

Text Only
for each step:
    1. 在训练集上更新 w(一步SGD)
    2. 在验证集上更新 α(一步SGD)

6.4 架构离散化

搜索完成后,对每条边保留权重最大的操作:

\[o^{(i,j)} = \arg\max_{o \in \mathcal{O}} \alpha_o^{(i,j)}\]

7. 高效NAS方法

ENAS高效架构搜索

7.1 权重共享(ENAS)

所有候选架构共享同一套权重,避免从头训练每个架构:

方面 无权重共享 权重共享
每个架构评估 完整训练 继承共享权重
GPU天数 1800+ 0.5
精度 更可靠 可能有偏差

7.2 One-Shot NAS

训练一个超网络(Supernet),所有候选架构是其子图:

Text Only
超网络训练 → 架构评估(通过子图采样)→ 导出最优架构

7.3 Zero-Cost NAS

无需任何训练,使用代理指标在初始化时评估架构:

代理指标 计算内容 思想
grad_norm 梯度范数 梯度流越通畅,架构越好
snip 连接敏感度 重要连接对loss影响大
synflow 参数保守度 保持信号传播的架构更优

8. 实战:DARTS简化实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

# 候选操作集合
OPS = {
    'sep_conv_3x3': lambda C: SepConv(C, C, 3, 1),  # lambda匿名函数,简洁定义单行函数
    'sep_conv_5x5': lambda C: SepConv(C, C, 5, 2),
    'max_pool_3x3': lambda C: nn.MaxPool2d(3, 1, 1),
    'avg_pool_3x3': lambda C: nn.AvgPool2d(3, 1, 1),
    'skip_connect': lambda C: nn.Identity(),
    'none':         lambda C: Zero(),
}

class SepConv(nn.Module):  # 继承nn.Module定义神经网络层
    """深度可分离卷积"""
    def __init__(self, C_in, C_out, kernel, padding):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_in, kernel, 1, padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, bias=False),
            nn.BatchNorm2d(C_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.op(x)

class Zero(nn.Module):
    def forward(self, x):
        return x * 0

class MixedOp(nn.Module):
    """混合操作:所有候选操作的加权和"""
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList([
            OPS[name](C) for name in OPS.keys()  # 列表推导式,简洁创建列表
        ])

    def forward(self, x, weights):
        return sum(w * op(x) for w, op in zip(weights, self.ops))  # zip按位置配对多个可迭代对象

class DARTSCell(nn.Module):
    """一个DARTS搜索单元"""
    def __init__(self, C, n_nodes=4):
        super().__init__()
        self.n_nodes = n_nodes
        self.ops = nn.ModuleList()
        # 每对节点之间都有一组混合操作
        for i in range(n_nodes):
            for j in range(i + 2):  # 包含2个输入节点
                self.ops.append(MixedOp(C))
        # 拼接n_nodes个中间节点后通道数为C*n_nodes,需要降回C
        self.channel_reduce = nn.Sequential(
            nn.Conv2d(C * n_nodes, C, 1, bias=False),
            nn.BatchNorm2d(C),
        )

    def forward(self, s0, s1, alphas):
        states = [s0, s1]
        offset = 0
        for i in range(self.n_nodes):
            s = sum(
                self.ops[offset + j](states[j], F.softmax(alphas[offset + j], dim=-1))
                for j in range(len(states))
            )
            offset += len(states)
            states.append(s)
        # 拼接中间节点后用1x1卷积将通道数从C*n_nodes降回C
        return self.channel_reduce(torch.cat(states[2:], dim=1))

class DARTSNetwork(nn.Module):
    """简化的DARTS搜索网络"""
    def __init__(self, C=16, n_classes=10, n_cells=4, n_nodes=4):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, 3, 1, 1, bias=False),
            nn.BatchNorm2d(C),
        )
        self.cells = nn.ModuleList([
            DARTSCell(C, n_nodes) for _ in range(n_cells)
        ])
        self.classifier = nn.Linear(C, n_classes)

        # 架构参数
        n_edges = sum(i + 2 for i in range(n_nodes))
        n_ops = len(OPS)
        self.alphas = nn.Parameter(torch.randn(n_edges, n_ops) * 1e-3)

    def forward(self, x):
        s0 = s1 = self.stem(x)
        for cell in self.cells:
            s0, s1 = s1, cell(s0, s1, self.alphas)
        out = F.adaptive_avg_pool2d(s1, 1).flatten(1)  # 链式调用,连续执行多个方法
        return self.classifier(out)

# 训练示意
def train_darts():
    model = DARTSNetwork(C=16, n_classes=10)
    w_optim = torch.optim.SGD(
        [p for n, p in model.named_parameters() if n != 'alphas'],
        lr=0.025, momentum=0.9, weight_decay=3e-4
    )
    a_optim = torch.optim.Adam([model.alphas], lr=3e-4, weight_decay=1e-3)

    # 交替优化(伪代码)
    for epoch in range(50):
        for x_train, y_train, x_val, y_val in zip_loaders():
            # 步骤1: 更新网络权重w
            w_optim.zero_grad()
            loss_train = F.cross_entropy(model(x_train), y_train)
            loss_train.backward()  # 反向传播计算梯度
            w_optim.step()

            # 步骤2: 更新架构参数α
            a_optim.zero_grad()
            loss_val = F.cross_entropy(model(x_val), y_val)
            loss_val.backward()
            a_optim.step()

    # 导出离散架构
    ops_names = list(OPS.keys())
    best_ops = [ops_names[i] for i in model.alphas.argmax(dim=-1)]
    print("搜索到的架构:", best_ops)

9. 面试高频题

Q1: NAS的三大组件是什么?

:搜索空间(定义候选操作和网络拓扑)、搜索策略(如何在空间中寻找最优架构,包括RL、进化、梯度方法等)、评估策略(如何高效估计候选架构的性能,包括完整训练、权重共享、零代价代理等)。

Q2: DARTS的核心思想是什么?

:DARTS将离散的架构选择松弛为连续优化问题。在搜索阶段,每条边上所有候选操作通过softmax加权混合;通过双层优化交替更新网络权重和架构参数;搜索完成后取每条边上权重最大的操作进行离散化。这样可以用梯度下降高效搜索,将搜索时间从数千GPU天降至约1GPU天。

Q3: 权重共享的优点和潜在问题?

:优点是大幅减少计算量(所有子架构共享权重,无需单独训练)。潜在问题:(1) 共享权重下的架构排名可能与独立训练后的排名不一致;(2) 超网络的训练质量影响所有子架构的评估;(3) 可能存在权重耦合,导致某些架构被不公平地评估。

Q4: 如何设计好的搜索空间?

:好的搜索空间应:(1) 包含已知有效的操作(如残差连接、深度可分离卷积);(2) 大小适中——太小限制表达能力,太大搜索困难;(3) 使用Cell-based设计减少搜索复杂度;(4) 考虑硬件约束(如延迟、内存)。

Q5: Zero-Cost NAS的原理是什么?可靠吗?

:Zero-Cost NAS在网络初始化时使用代理指标(如梯度范数、连接敏感度)来评估架构潜力,无需任何训练。其优点是搜索速度极快(秒级),但可靠性有限——代理指标与最终性能的相关性取决于搜索空间和任务,某些场景下效果可能不稳定。


10. 练习与自我检查

编程练习

  1. 基础:实现一个简单的随机搜索NAS,在CIFAR-10上搜索卷积核大小和层数
  2. 进阶:实现简化版DARTS,搜索一个4节点Cell
  3. 挑战:加入硬件感知约束(如FLOPs限制),实现多目标NAS

检查清单

  • 能说明NAS的三大组件(搜索空间、搜索策略、评估策略)
  • 理解Cell-based搜索空间的设计思想
  • 知道RL-based NAS的工作流程和局限性
  • 能解释DARTS的连续松弛和双层优化
  • 理解权重共享的原理及潜在问题
  • 了解One-Shot NAS和Zero-Cost NAS的基本思想
  • 能阅读和理解DARTS的核心代码
  • 知道NAS在实际工业部署中的应用现状

扩展阅读: - Zoph & Le, 2017: Neural Architecture Search with Reinforcement Learning - Liu et al., 2019: DARTS: Differentiable Architecture Search - Cai et al., 2020: Once-for-All: Train One Network and Specialize it for Efficient Deployment - White et al., 2023: Neural Architecture Search: Insights from 1000 Papers