首页 > 新闻资讯 > 正文

【文献分享】NAS系列文章第一期:DARTS

作者:时间:2022-09-13点击数:


本期开始为大家分享NASNeural Architecture Search)系列文章,NAS的目的是发现神经网络的最佳结构,以满足特定的需要,其本质是将人工调整神经网络的过程变成自动执行任务以发现最佳架构的过程。

深度学习模型的使用越来越大众化,但是,高效神经网络的实现通常需要架构的知识和大量的时间,并且针对不同的任务和不同需求需要设定特定的结构。一般情况下,人们会利用过去的经验或技术知识来创建和设计神经网络,通过试错的方式设计这些网络,其过程是耗时且乏味的。基于此,人们就想到能不能通过机器自己来找到最优的神经网络架构呢?从而NAS应运而出。

神经网络架构搜索之前主流的方法主要包括:强化学习,进化学习,其搜索空间都是不可微的。今天给大家分享的第一篇NAS的论文是可微分神经网络架构搜索DARTS: DIFFERENTIABLE ARCHITECTURE SEARCHDARTS是使用梯度下降来解决架构搜索的问题,所以在搜索效率上比之前不可微的方法快几个数量级。


一、架构搜索

DARTS整体的架构是通过把基本单元Cell进行堆叠从而搭建的,故DARTS的目的可以简述为:训练出较好的Cell,然后把Cell相连构成一个大网络。DARTS的基本单元Cell有两类,分别为Normal-CellReduction-Cell,同一个网络中相同类型的Cell结构是共享的,所以网络只需要预测Cell的结构即可。对于CIFAR10ImageNet的分类任务,DARTS采用N个Normal-Cell1个Reduction-Cell的方式进行次堆叠搭建架构



1.基本单位Cell

DARTS的基本单元Cell的组成:两个输入节点,四个中间节点,一个输出节点。

对于第K个Cell而言,第K-1K-2Cell的输出作为其自身两个输入,另外其自身四个中间节点的输出进行concat操作作为该Cell的自身输出。

每个中间节点都是由有向无环图中所有的前继节点计算得来。例如下图中N0N1N2N3四个中间节点,其中N0有两个前继节点,分别为两个输入节点;N1有三个前继节点节点,分别为两个输入节点和N0N2有四个前继节点,分别为两个输入节点、N0N1N3有五个前继节点,分别为两个输入节点、N0N1N2

另外,如下图所示,节点与节点之间的连线称作为边,边代表的是operation。对于一般网络而言,每一层代表着一种操作,该操作可能是卷积、池化、激活等函数,但在DARTS中每一组边代表着多种操作,并且给每个操作添加对应权重(可视为概率)。DARTS则是通过不断迭代更新各操作对应权重,从而确定出每组边最终的离散操作。图中右侧表示为每个节点之间有对应8个不同的预定义操作,共同构成一组边,另外图中最右侧为其对应的权重α值。


2.节点计算

如上所述,节点与节点之间均有一组边,每一组边均有8个候选操作。其中节点到节点的输出如下公式所示,首先通过用softmax归一化权重α,然后将各候选的离散操作的输出通过α进行加权平均从而得到节点到节点的输出。


由于每个中间节点均有多个前继节点,故对于每个中间节点的总输出如下公式所示,将所有与前继节点间的输出进行累加作为当前中间节点的总输出。


3.搜索思想

对于每一组边的候选操作都赋予权重,称为架构参数α;对于卷积、池化等操作本身的类似卷积核的参数,称为模型参数w。DARTS是通过让每组边上的操作均存在并参与训练,最后对其加权平均,将搜索空间连续松弛化,从而不断训练架构参数α,其中模型训练希望效果最好的离散操作其对应的α权重最大。

另外因为Cell的边太多,故结构太复杂,参数太多不好训练,则作者希望能生成一个更简单的网络结构,故对最终Cell结构设置两条原则:1.每组边上只保留权重最大的操作;2.每个中间节点只保留两个前继节点(根据节点间边的最大权重进行比较,保留权重最大的两组边)。


二、损失函数

DARTS的优化目标是在验证集上的损失函数,如下公式所示,公式中星号上标代表最优。


公式表明要想找到最优的架构参数α则需要Lval最小,但这需要先找到最优的模型参数w,也即需要Ltrain最小,但优化w又和架构参数α有关,所以是两级最优化问题。

作者通过一种近似的迭代优化步骤来交替更新两个参数,从而来解决上述两级最优化问题。如下公式所示,ξ是模型参数w的学习率,这种近似在架构于训练集上达到局部极值点时,w则达到最优。这表明该种近似实际上是需要在训练集上对权重执行一次梯度下降来近似最优权重w*(α)

故其核心思想是每次更新α时需要w在训练集上先进行一步优化来近似w*

DARTS的训练过程:1.在验证集Lval损失上梯度下降更新架构参数α(需要w在训练集上先进行一步优化来近似w*);2.在训练集Ltrain损失上梯度下降更新模型参数w。上述过程作为一个周期,之后不断如此循环迭代更新架构参数α和模型参数w。

最后,如果ξ=0,则称为一阶近似,此时则是直接将w作为最优,而不用在更新架构参数α前对w进行一次优化,故训练时间会有效减少,但其实验效果不佳;如果ξ>0,则需要在更新架构参数α前对w进行一次优化,其实验效果较好。


三、实验结果

上表给出了用于卷积架构的CIFAR-10结果。值得注意的是,DARTS技术取得了与目前水平相当的结果,同时使用少三个数量级的计算资源。此外,在搜索时间稍长的情况下,DARTS表现优于ENAS。

DARTS在训练好Cell后再将多个Cell进行网络结构堆叠(堆叠更多的模块可以提取更多的视觉特征,同时带来模型参数量和显存占用的增加),以CIFAR10为例,首先通过8个Cells的网络进行训练,在训练好Cell后,按照一定顺序将20Cell进行堆叠,如下图。

另外对于大数据集搜索Cell结构中,可通过先在小的数据集上(如Cifar-10)搜索Cell结构,等搜索结果出来后,再堆叠更多的Cell,应用在大数据集上(如ImageNet)。这样在搜索的过程中,子网络模型训练的时间便大幅减小,提高搜索的效率。

上表结果表明,在CIFAR-10上学习到的Cell确实可以转移到ImageNet。


四、论文小结

1.Normal-CellReduction-Cell的区别是:Normal-Cell输入到输出特征尺寸不变,而Reduction-Cell输出的特征尺寸会减半,具体实现方法:对于第K个Cell,1.首先会对两个输入进行预处理,通过判断第K-1个Cell是否是Reduction-Cell来判断是否需要对来自K-2个Cell的输入进行特征减半(保证了两个输入的特征尺寸一致);2.判断自身是否为Reduction-Cell如果是则会将stride设置为2,从而实现输出特征尺寸减半。

2.以CIFAR-10的分类架构为例,数据在架构的前向传播:首先数据会先预处理成S0,S1(S0=S1)(解决Cell 0 和Cell 1的两个输入节点问题),然后以此经过多个Normal-CellReduction-Cell将最后一个Cell的输出进行全局最大池化,最后连接10分类的全连接层。

3.对于分辨率较大的数据,可以先通过卷积池化后再输入连接到Cell结构。

4.整体训练过程:首先训练N个epoch,不断迭代更新架构参数α,每个epoch都会根据α得到对应的best Cell,根据验证集的准确率,选择最佳epoch的best Cell。根据best Cell搭建更多Cell的网络架构(此时每组边都只有一个最佳离散操作),然后训练模型参数w,最后得到最终的最优网络架构。


、论文扩展

DARTS整个过程可以划分为两步:1.拿少量的Cell通过梯度下降更新架构参数α和模型参数w,得到最优的Cell结构;2.Cell进行堆叠,此时整体架构固定,故仅更新模型参数w。

1.P-DARTS

DARTS中,搜索时候是以 8 cells with 50 epochs 来进行的,evaluate时却是用20 cells来进行的,这bias造成了精度大幅度下降P-DARTS 以渐进的方式 5 cells, 11 cells, 17 cells 分别 25 epochs 来进行,这样更能接近evaluate时的情况,故性能也更好,如下图

P-DARTS的整体过程如下图所示:

(a) Cells=5时,每组边5个候选操作,当训练好了25 epochs后,会有对应的softmax置信度。

(b) 接着进行 Cells=11的搜索,虽然深度加了一倍多,但这时每组边的候选操作将会减少接近一半,即把(a)中最后置信度较低的离散操作直接delete掉。

(c) 同样,最后进行 Cells=17的搜索,再砍掉置信度低的一半离散操,通过这样的方式来tradeoff depthmemory。

其实验结构如下图所示,P-DARTSTest Error2.50%,相比DARTS2.83%得到了有效提升。


2.PC-DARTS

PC-DARTS致力于大规模节省计算量和memory,从而进行快速且大batchsize的搜索。相对于DARTS的两点改变:

1.设计了基于channel的sampling机制,故每次只有小部分1/K channel的node来进行operation search,减少了(K-1)/K 的memory,故batchsize可增大为K倍。

2.为了解决上述channel采样导致的不稳定性,提出了边缘正规化(edge normalization),在搜索时通过学习edge-level超参来减少不确定性。

其中部分通道连接的公式:

通过sampling机制能够减少网络搜索过程中的内存占用。

边缘归一化为基本单元中的第i层网络的每一个输入分配了一个β参数,公式表示如下:

其实验结构如下图所示,P-DARTSSearch Cost0.1,相比DARTSP-DARTS0.40.3得到了极大提升,并且其Test ErrorDARTS


声明:分享源自于个人理解仅供参考,更多相关研究具体详情见原文。

原文链接:https://arxiv.org/abs/1806.09055
撰稿人:AISLE 22级研究生—吴昊  

校对人:AISLE 22级研究生—刘心月


Copyright© 人工智能与统计学习团队 All Rights Reserved. 

鄂ICP备13013419号