CN111507472A - 一种基于重要性剪枝的精度估计参数搜索方法 - Google Patents
一种基于重要性剪枝的精度估计参数搜索方法 Download PDFInfo
- Publication number
- CN111507472A CN111507472A CN202010259224.XA CN202010259224A CN111507472A CN 111507472 A CN111507472 A CN 111507472A CN 202010259224 A CN202010259224 A CN 202010259224A CN 111507472 A CN111507472 A CN 111507472A
- Authority
- CN
- China
- Prior art keywords
- training
- parameter
- hyper
- importance
- neural networks
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种基于重要性剪枝的精度估计参数搜索方法,其引入动态剪枝的思想,因此能在一定程度上解决神经网络结构检索难以训练的问题,在速度以及精度上,皆达到了最优。
Description
技术领域
本发明涉及神经架构搜索技术领域,具体涉及一种基于重要性剪枝的精度估计参数搜索方法。
背景技术
近年来,随着人工智能以及深度学习的发展,人们对于定制化的深度学习网络结构开始出现指数级别的增长。用户更多的希望深度学习对于当前自身的任务,产生定制化的网络结构以及参数,这就引导了神经网络结构检索系统的产生。给定数据集,神经架构搜索(NAS)旨在通过搜索算法在巨大的搜索空间中发现高性能卷积架构。NAS在各种深度学习任务的自动化架构搜索中取得了很大成功,例如图像分类,语言建模和语义分割。
神经架构搜索方法由三部分组成:搜索空间,搜索策略和性能评估。传统的NAS算法通过搜索策略对特定的卷积架构进行采样并估计性能,同时,性能可以被视为更新搜索策略的目标函数。尽管取得了显着进步,但传统的神经网络结构搜索方法仍然受到密集计算和内存成本的限制。例如,基于强化学习方法,需要在4天内训练并评估在500个GPU中超过20,000个神经网络。最近的工作通过以可微分的方式制来提高可扩展性,其中搜索空间被放宽到连续的空间,从而可以通过梯度下降的验证集上的性能来优化体系结构。然而,可微分神经网络结构搜索仍然受到高GPU内存消耗的影响,其随着候选搜索集的大小线性增长。因此现有技术存在的问题是,神经网络结构搜索受到高GPU内存消耗的影响,导致神经网络结构搜索在有限的GPU下,存在搜索效率低、精度低的问题。
发明内容
针对现有技术存在的问题,本发明的目的在于提供一种基于重要性剪枝的精度估计参数搜索方法,以此提高神经网络结构搜索算法的精度。
为实现上述目的,本发明采用的技术方案是:
一种基于重要性剪枝的精度估计参数搜索方法,其包括以下步骤:
步骤1、设定神经网络结构搜索空间,超参数搜索空间、训练次数T、训练集和验证集;
步骤2、在神经网络搜索空间中,随机采样n个神经网络,在全精度训练条件f下利用训练集训练这n个神经网络,并且在验证集上进行验证操作,得到n个全精度训练后的神经网络以及相应的精度集合;
步骤3、在超参数搜索空间中,随机采样训练超参数集b;其中,采样的公式为:
其中,FLOPs为一个超参数所占用的计算复杂度,Θi,j为每个具体的训练超参数;
步骤4、将步骤3得到的训练超参数集b对步骤2中采样得到的n个神经网络进行训练以及验证,得到n个超参训练后的神经网络,以及相应的精度集合;
计算损失rs(Rf,Rb);其中,rs为斯皮尔曼相关系数,Rf为步骤2中n个全精度训练后的神经网络按照在验证集上误差排序,Rb为步骤3得到的超参训练后的神经网络按照在验证集上误差的排序;
步骤5、根据步骤3的训练超参数集b以及损失来构建随机森林,获取每一训练超参数的重要程度;
步骤6、根据计算好的重要程度进行剪枝:将重要程度最低的训练超参数进行剪枝,直接设置为时间消耗最低的训练超参数,得到训练后的训练超参数集b’;
步骤7、循环步骤3-6,直至达到固定的训练次数T;
步骤8、从T个训练超参数集b’中选出性能最优的训练超参数集b’。
所述步骤5具体如下:
针对于训练超参数的切分变量和切分点的选择,采用穷举法,遍历每个特征和每个特征的所有取值,最后从中找出最好的切分变量和切分点;针对于切分变量和切分点的好坏,以切分后节点的不纯度来衡量,即各个子节点不纯度的加权和;每个节点的不纯度计算公式如下:
其中,xi为某一个切分变量,vij为切分变量的一个切分值,nleft,nright,Ns分别为切分后左子节点的训练样本个数、右子节点的训练样本个数以及当前节点所有训练样本个数,Xleft,Xright分为左右子节点的训练样本集合,H(X)为衡量节点不纯度的函数,该函数为:
根据随机森林计算出具体的重要程度,每一训练超参数的重要程度为:
Im=|Qm|H(Qm)-(Q{keft,m}|H(Q{left,m})-|Q{right,m}|H(Q{right,m}).。
采用上述方案后,本发明具有以下有益效果:本发明针对时间消耗比较小的样本点有更高概率的采样,且对重要性比较低的超参数进行剪枝,因此可以分配更多的资源在重要性更高的超参数上,因此搜索到的超参数会更好,从而提高所有神经网络结构搜索算法的精度。
附图说明
图1为本发明的方法流程图;
图2为本发明与现有神经网络结构检索方法之间的比较。
具体实施方式
如图1所示,本发明揭示了一种基于重要性剪枝的精度估计参数搜索方法,其包括以下步骤:
步骤1、设定神经网络结构搜索空间,超参数搜索空间、训练次数T、训练集和验证集;
步骤2、在神经网络搜索空间中,随机采样n个神经网络,在全精度训练条件f下利用训练集训练这n个神经网络,并且在验证集上进行验证操作,得到n个全精度训练后的神经网络以及相应的精度集合;
步骤3、在超参数搜索空间中,随机采样训练超参数集b;其中,采样的公式为:
其中,FLOPs为一个超参数所占用的计算复杂度,Θi,j为每个具体的训练超参数;该公式对于那些计算复杂度比较低的超参数,有更高的采样概率;
步骤4、将步骤3得到的训练超参数集b对步骤2中采样得到的n个神经网络进行训练以及验证,得到n个超参训练后的神经网络,以及相应的精度集合;
计算损失rs(Rf,Rb);其中,rs为斯皮尔曼相关系数,Rf为步骤2中n个全精度训练后的神经网络按照在验证集上误差排序,Rb为步骤3得到的超参训练后的神经网络按照在验证集上误差的排序;
步骤5、根据步骤3的训练超参数集b以及损失来构建随机森林,获取每一训练超参数的重要程度;
针对于训练超参数的切分变量和切分点的选择,采用穷举法,即遍历每个特征和每个特征的所有取值,最后从中找出最好的切分变量和切分点;针对于切分变量和切分点的好坏,以切分后节点的不纯度来衡量,即各个子节点不纯度的加权和,即每个节点的不纯度计算公式如下:
其中,xi为某一个切分变量,vij为切分变量的一个切分值,nleft,nright,Ns分别为切分后左子节点的训练样本个数、右子节点的训练样本个数以及当前节点所有训练样本个数,Xleft,Xright分为左右子节点的训练样本集合,H(X)为衡量节点不纯度的函数,该函数为:
我们认为,一个节点的减去左右节点的不纯度越高,重要程度就越高,根据随机森林计算出具体的重要程度,每一训练超参数的重要程度为具体值的加权求和:
Im=|Qm|H(Qm)-|Q{left,m}|H(Q{left,m})-|Q{right,m}|H(Q{right,m});
步骤6、根据计算好的重要程度进行剪枝:将重要程度最低的训练超参数进行剪枝,直接设置为时间消耗最低的训练超参数,得到训练后的训练超参数集b’;
步骤7、循环步骤3-6,直至达到固定的训练次数T。
步骤8、从T个训练超参数集b’中选出性能最优的训练超参数集b’。利用该训练超参数集b’搜索搜索时,能够提高搜索精度。
上述效果通过以下仿真实验作进一步的说明。
1.仿真条件
本发明在Pycharm平台上进行开发,开发的深度学习框架基于Pytorch。本发明中主要用的语言为Python,并且利用OpenCV实现本发明中用到的传统视觉算法。
2.仿真内容
我们在Cifar10以及ILSVRC2012数据集上进行仿真,CIFAR-10数据集由10类32x32的彩色图片组成,一共包含60000张图片,每一类包含6000图片。其中50000张图片作为训练集,10000张图片作为测试集。CIFAR-10数据集被划分成了5个训练的batch和1个测试的batch,每个batch均包含10000张图片。测试集batch的图片是从每个类别中随机挑选的1000张图片组成的,训练集batch以随机的顺序包含剩下的50000张图片。不过一些训练集batch可能出现包含某一类图片比其他类的图片数量多的情况。训练集batch包含来自每一类的5000张图片,一共50000张训练图片。
3、仿真结果
图2为仿真的结果,图2中,第一列为具体的神经网络模型,Resnet-18为18层的残差神经网络。Senet为基于注意力机制的网络,densenet为稠密网络,其余的为不同的搜索方法的组合。第二列为搜索的神经网络在CIFAR-10数据集中的误差,误差越小,表示搜索精度越高;第三列为模型大小。第四列为搜索时间消耗(GPU/天),搜索时间消耗越小,表示搜索速度越快;第五列为搜索方法,其中manual为手工设计,RL为基于强化学习,evolutional为基于进化计算的方法,gradient为基于梯度的方法,random search为随机搜索的方法。BPE-1与BPE-2为本发明中找到的最优训练超参数组合。
通过图2可以发现,相比于其他的方法,本发明的精度更高速度更快。
以上所述,仅是本发明实施例而已,并非对本发明的技术范围作任何限制,故凡是依据本发明的技术实质对以上实施例所作的任何细微修改、等同变化与修饰,均仍属于本发明技术方案的范围内。
Claims (2)
1.一种基于重要性剪枝的精度估计参数搜索方法,其特征在于:所述方法包括以下步骤:
步骤1、设定神经网络结构搜索空间,超参数搜索空间、训练次数T、训练集和验证集;
步骤2、在神经网络搜索空间中,随机采样n个神经网络,在全精度训练条件f下利用训练集训练这n个神经网络,并且在验证集上进行验证操作,得到n个全精度训练后的神经网络以及相应的精度集合;
步骤3、在超参数搜索空间中,随机采样训练超参数集b;其中,采样的公式为:
其中,FLOPs为一个超参数所占用的计算复杂度,Θi,j为每个具体的训练超参数;
步骤4、将步骤3得到的训练超参数集b对步骤2中采样得到的n个神经网络进行训练以及验证,得到n个超参训练后的神经网络,以及相应的精度集合;
计算损失rs(Rf,Rb);其中,rs为斯皮尔曼相关系数,Rf为步骤2中n个全精度训练后的神经网络按照在验证集上误差排序,Rb为步骤3得到的超参训练后的神经网络按照在验证集上误差的排序;
步骤5、根据步骤3的训练超参数集b以及损失来构建随机森林,获取每一训练超参数的重要程度;
步骤6、根据计算好的重要程度进行剪枝:将重要程度最低的训练超参数进行剪枝,直接设置为时间消耗最低的训练超参数,得到训练后的训练超参数集b’;
步骤7、循环步骤3-6,直至达到固定的训练次数T;
步骤8、从T个训练超参数集b’中选出性能最优的训练超参数集b’。
2.根据权利要求1所述的一种基于重要性剪枝的精度估计参数搜索方法,其特征在于:所述步骤5具体如下:
针对于训练超参数的切分变量和切分点的选择,采用穷举法,遍历每个特征和每个特征的所有取值,最后从中找出最好的切分变量和切分点;针对于切分变量和切分点的好坏,以切分后节点的不纯度来衡量,即各个子节点不纯度的加权和;每个节点的不纯度计算公式如下:
其中,xi为某一个切分变量,υij为切分变量的一个切分值,nleft,nright,Ns分别为切分后左子节点的训练样本个数、右子节点的训练样本个数以及当前节点所有训练样本个数,Xleft,Xright分为左右子节点的训练样本集合,H(X)为衡量节点不纯度的函数,该函数为:
根据随机森林计算出具体的重要程度,每一训练超参数的重要程度为:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010259224.XA CN111507472A (zh) | 2020-04-03 | 2020-04-03 | 一种基于重要性剪枝的精度估计参数搜索方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010259224.XA CN111507472A (zh) | 2020-04-03 | 2020-04-03 | 一种基于重要性剪枝的精度估计参数搜索方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN111507472A true CN111507472A (zh) | 2020-08-07 |
Family
ID=71874151
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010259224.XA Pending CN111507472A (zh) | 2020-04-03 | 2020-04-03 | 一种基于重要性剪枝的精度估计参数搜索方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111507472A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113657592A (zh) * | 2021-07-29 | 2021-11-16 | 中国科学院软件研究所 | 一种软件定义卫星自适应剪枝模型压缩方法 |
CN114896436A (zh) * | 2022-06-14 | 2022-08-12 | 厦门大学 | 一种基于表征互信息的网络结构搜索方法 |
-
2020
- 2020-04-03 CN CN202010259224.XA patent/CN111507472A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113657592A (zh) * | 2021-07-29 | 2021-11-16 | 中国科学院软件研究所 | 一种软件定义卫星自适应剪枝模型压缩方法 |
CN113657592B (zh) * | 2021-07-29 | 2024-03-05 | 中国科学院软件研究所 | 一种软件定义卫星自适应剪枝模型压缩方法 |
CN114896436A (zh) * | 2022-06-14 | 2022-08-12 | 厦门大学 | 一种基于表征互信息的网络结构搜索方法 |
CN114896436B (zh) * | 2022-06-14 | 2024-04-30 | 厦门大学 | 一种基于表征互信息的网络结构搜索方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113190699B (zh) | 一种基于类别级语义哈希的遥感图像检索方法及装置 | |
CN109376242B (zh) | 基于循环神经网络变体和卷积神经网络的文本分类方法 | |
CN109635083B (zh) | 一种用于搜索ted演讲中话题式查询的文档检索方法 | |
CN109948029A (zh) | 基于神经网络自适应的深度哈希图像搜索方法 | |
CN110490320B (zh) | 基于预测机制和遗传算法融合的深度神经网络结构优化方法 | |
CN111898689A (zh) | 一种基于神经网络架构搜索的图像分类方法 | |
CN112464004A (zh) | 一种多视角深度生成图像聚类方法 | |
CN110674326A (zh) | 一种基于多项式分布学习的神经网络结构检索方法 | |
CN112686376A (zh) | 一种基于时序图神经网络的节点表示方法及增量学习方法 | |
CN113191445B (zh) | 基于自监督对抗哈希算法的大规模图像检索方法 | |
CN110866134A (zh) | 一种面向图像检索的分布一致性保持度量学习方法 | |
CN111507472A (zh) | 一种基于重要性剪枝的精度估计参数搜索方法 | |
CN114611670A (zh) | 一种基于师生协同的知识蒸馏方法 | |
CN113239211A (zh) | 一种基于课程学习的强化学习知识图谱推理方法 | |
CN114548591A (zh) | 一种基于混合深度学习模型和Stacking的时序数据预测方法及系统 | |
CN113282747B (zh) | 一种基于自动机器学习算法选择的文本分类方法 | |
CN111310820A (zh) | 基于交叉验证深度cnn特征集成的地基气象云图分类方法 | |
CN114860973A (zh) | 一种面向小样本场景的深度图像检索方法 | |
CN113095229B (zh) | 一种无监督域自适应行人重识别系统及方法 | |
CN114241267A (zh) | 基于结构熵采样的多目标架构搜索骨质疏松图像识别方法 | |
CN112651499A (zh) | 一种基于蚁群优化算法和层间信息的结构化模型剪枝方法 | |
CN111079840B (zh) | 基于卷积神经网络和概念格的图像语义完备标注方法 | |
CN111507383A (zh) | 一种基于进化算法的神经网络自动剪枝方法 | |
CN117034060A (zh) | 基于ae-rcnn的洪水分级智能预报方法 | |
CN114757433B (zh) | 一种饮用水源抗生素抗性相对风险快速识别方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
WD01 | Invention patent application deemed withdrawn after publication | ||
WD01 | Invention patent application deemed withdrawn after publication |
Application publication date: 20200807 |