CN113837284B - 基于深度学习的双支路滤波器剪枝方法 - Google Patents
基于深度学习的双支路滤波器剪枝方法 Download PDFInfo
- Publication number
- CN113837284B CN113837284B CN202111128830.9A CN202111128830A CN113837284B CN 113837284 B CN113837284 B CN 113837284B CN 202111128830 A CN202111128830 A CN 202111128830A CN 113837284 B CN113837284 B CN 113837284B
- Authority
- CN
- China
- Prior art keywords
- network
- double
- layer
- model
- pruning
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T7/00—Image analysis
- G06T7/0002—Inspection of images, e.g. flaw detection
- G06T7/0004—Industrial image inspection
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20081—Training; Learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20084—Artificial neural networks [ANN]
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computational Linguistics (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Bioinformatics & Computational Biology (AREA)
- General Health & Medical Sciences (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- Software Systems (AREA)
- Quality & Reliability (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种基于深度学习的双支路滤波器剪枝方法,包括下列步骤:划分数据集,划分为训练集和测试集;基于原始VGG‑16进行重新搭建,得到需要剪枝的原始网络,之后加入双支路模块,得到新的网络模型;利用数据集中的训练集训练加入双支路模块的网络模型,使模型在测试集上的测试准确率达到最高,此时获得一个最优模型;再次把训练集输入最优模型,利用双支路模块和输入图片,得到每个滤波器的激活值;将单网络层滤波器的权重方差进行排序,获得每层滤波器的重要程度排序;按照预设定的每层的剪枝比率获得阈值,将大于阈值的数保留然后返回对应滤波器即得到网络的剪枝结果。
Description
技术领域
本方法涉及图像处理中模型轻量化领域,适用于计算资源较少的平台,具体涉及一种双支路滤波器剪枝方法。
背景技术
图像分类与检测的研究,是整个计算机视觉研究的基石,是解决跟踪、分割、场景理解等其他复杂视觉问题的基础。鉴于图像分类与检测在计算机视觉领域的重要地位,研究鲁棒、准确的图像分类与检测算法,无疑有着重要的理论意义和实际意义。神经网络具有可从海量数据中学习数据之间相关性和差异性的优点,可以避免人工设计提取特征的麻烦,且分类准确度高,训练难度较低,对特征中的噪声有着较强的鲁棒性和容错能力,能够充分拟合分类任务中需要的复杂非线性关系,所以神经网络是当前的一个热点研究问题。并且在许多现实场景的应用程序需要实时的设备处理能力。如在自动驾驶领域中,智能控制系统必须实时观察道路,当出现突发情况,必须及时预警停车。在这种情况下,需要能够在系统上实时地处理视觉信息并及时做出决策,并且要保证决策的准确性。
深度学习网络以其高准确率和稳定性能的优势被越来越多得应用于现实设备中。一般来说,深度学习网络越深越具有更强的表达能力。凭着这一基本准则卷积神经分类网络自Alexnet[1]的7层发展到了VGG[2]的16乃至19层,后来更有了Googlenet[3]的22层。可后来发现直连式卷积神经网络达到一定深度后再一味地增加层数并不能带来进一步地分类性能提高,反而会导致网络收敛变得更慢,测试集的分类准确率也变得更差。排除数据集过小带来的模型过拟合等问题后,过深的网络仍然还会使分类准确度下降。在深度学习网络中,随着网络深度的加深,梯度消失问题会愈加明显,为了缓解梯度消失问题,针对这个问题,通过改进网络结构并增加网络深度从而进一步提高分类模型准确率。如ResNet[4]通过使用多个有参层来学习输入输出之间的残差表示,而不是像一般VGG网络那样使用有参层来直接尝试学习输入、输出之间的映射,缓解了梯度消失问题。DenseNet[5]以密集的方式将每个模块与其他模块连接起来,每一层的输入来自前面所有层的输出。
ResNet和DenseNet在一定程度上缓解了梯度消失的问题,网络性能明显提高,但是由于大量输入输出的连接,造成了部分网络结构和参数冗余,不利于模型在资源受限的设备上进行部署,并且由于网络模型计算量过多,也无法保证实时性。因此在高性能的网络上进行模型压缩成为亟待解决的问题。
现有的注意力网络只包含一个跟随卷积块的注意力模块,这使得注意力模块只能从当前的特征地图中学习。因此,独立的注意力模块不能有效地决定要注意什么,并且当前的注意力模块很难调整对重点区域的关注,甚至在不同阶段会发生显著变化。注意力模块的学习能力明显不足。一个合理的解释是,从当前层中学习的额外信息的缺乏影响了它的辨别能力。因此需要一种新的设计,将前项信息与当前层信息融合起来,从而允许注意力模块之间相互合作。
参考文献:
[1]Krizhevsky A,Sutskever I,Hinton G E.Imagenet classification withdeep convolutional neural networks[J].Advances in neural informationprocessing systems,2012,25:1097-1105.
[2]Simonyan K,Zisserman A.Very deep convolutional networks for large-scale image recognition[J].arXiv preprint arXiv:1409.1556,2014.
[3]Szegedy C,Liu W,Jia Y,et al.Going deeper with convolutions[C]//Proceedings of the IEEE conference on computer vision and patternrecognition.2015:1-9.
[4]He K,Zhang X,Ren S,et al.Deep residual learning for imagerecognition[C]//Proceedings of the IEEE conference on computer vision andpattern recognition.2016:770-778.
[5]Huang G,Liu Z,Van Der Maaten L,et al.Densely connectedconvolutional networks[C]//Proceedings of the IEEE conference on computervision and pattern recognition.2017:4700-4708.
[6]Krizhevsky A,Hinton G.Learning multiple layers of features fromtiny images[J].2009.
发明内容
本发明的目的是提供一种基于深度学习的双支路滤波器剪枝方法,可以保留对输入特征敏感的滤波器,剪除对输入特征不敏感的滤波器使得网络参数进一步降低。技术方案如下:
一种基于深度学习的双支路滤波器剪枝方法,包括下列步骤:
第一步:划分数据集,划分为训练集和测试集;
第二步:基于原始VGG-16进行重新搭建,把预测部分的三个全连接层变为两个,交互神经元个数由原来的4096减少到512,得到需要剪枝的原始网络,之后加入双支路模块,得到新的网络模型;
第三步:利用数据集中的训练集训练加入双支路模块的网络模型,使模型在测试集上的测试准确率达到最高,此时获得一个最优模型;
第四步:再次把训练集输入最优模型,在不改变网络中的任何参数的情况下,利用双支路模块和输入图片,得到每个滤波器的激活值,网络不同层中不同滤波器的权重向量wi通过下列公式得到:
其中i代表第i层,Sigmoid代表规一化函数,H和W为输入特征的空间分辨率,Ii-1为输入特征,Ii为原始模块输出特征,W1表示支路中全连接层的权值;
第五步:计算网络中所有滤波器对于所有输入训练集图片的权值的方差,将单网络层滤波器的权重方差进行排序,获得每层滤波器的重要程度排序;按照预设定的每层的剪枝比率获得阈值,将大于阈值的数保留然后返回对应滤波器即得到网络的剪枝结果;
第六步:将裁剪后的原始网络使用数据集重新训练,恢复精度,经过训练收敛后的网络模型为最终得到轻量化的最优模型。
本发明的有益效果如下:
1、本发明提出的双支路模块中,一个支路根据输入的浅层特征,在细节上增强有用的特征,并在一定程度上抑制无用的特征。另一个支路根据本层网络输出特征图对高层特征进行增强与抑制,从而实现了联合注意机制。
2、本发明使用双支路模块来学习深层网络中对应层不同滤波器的权重,统计单个卷积核重要性的方差,方差越大的滤波器越活跃、对输入特征越敏感。该方法可以针对性的剪去网络中对输入特征不敏感的滤波器来获得分类所需的最小卷积核数目。
附图说明
图1为所提出的双支路滤波器剪枝算法的网络结构图。
具体实施方式
本发明提出了一种基于深度学习的双支路滤波器剪枝算法,由于浅层网络能够提取输入图片的低级特征,并且具有大量细节信息,深层网络提取输入图片语义级别的特征。在网络传播的过程中,上一层的细节信息很可能在下一层丢失,因此本发明内其中一个支路根据输入的特征图生成每个滤波器的激活值,对低层特征信息进行特征增强与抑制。另一个支路根据本层网络输出特征图来生成每个滤波器的激活值,对高层语义特征进行特征增强与抑制。
在支路的设计上,为了匹配空间分辨率,发明中采用平均池化层(Avg-pooling),将特征图分辨率压缩到1×1。最大池化层(Max-pooling)虽然也有很好的效果,但是它只考虑了部分信息而不是全部的注意力信息。在低层特征上仅进行特征压缩和归一化就能够提取浅层重要信息,因此设计比较简单。在高层特征上进行特征压缩和归一化操作的同时,还加入了一层全连接层。虽然现有的方法一般采取两个全连接层来补充更多的非线性信息,但由于前一个支路的存在,仅使用单层全连接层就能达到很好的效果。
在训练和测试高分辨率的图像时,直接把所提出的双支路模块加到直连型网络VGG上,可以在一定程度上缓解梯度消失的问题,并且模型的准确率具有一定提升。因此可以证明提出的双支路模块是一种有效的设计,可以在不改变内部结构的情况下增强CNN模型的准确率,并且模块可以以最小的计算开销适用所有的CNN架构。由于其对网络具有一定的增益效果,说明模块对特征进行增强和抑制后对分类结果有一定的指导作用。因此可根据其输出结果判断网络中的滤波器的活跃程度,对不活跃的滤波器进行剪除,从而减小模型的计算量和参数量。
下面将结合附图中VGG16网络对实施方式进一步的详细描述:
(1)数据准备:
(a)划分数据集,本方法采用的是分类通用数据集Cifar10,数据集中包含10个类别,分别为飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车,每类图像均有6千张,其中没有任何的重叠情况,也不会在同一张照片中出现两类事物。数据集共有6万张图像,其中5万张图像用于训练,1万张图像用于测试,且图像大小均为32×32。
(2)网络的搭建:本发明的网络结构主要为VGG-16和双支路模块。下面将结合附图1,对本发明搭建的网络结构进行详细的介绍说明。
(a)Cifar10数据集具有十个类别,但在Pytorch库中的预训练好的VGG-16是以千分类构造的网络,并且原始VGG-16中预测部分的三个全连接层具有巨大的参数量和计算量。本发明中在原始VGG-16的基础上进行了部分修改,把预测部分的三个全连接层变为两个,交互通道数由原来的4096减少到512,得到需要剪枝的原始网络模型。并且在每层网络中加入双支路模块。
(b)训练加入双支路模块的网络模型,网络层中不同层不同通道分别针对卷积核的权重向量wi计算如下,
其中i代表第i层,Sigmoid代表规一化函数,W1表示模块中全连接层的权值。H和W为输入特征的空间分辨率,Ii-1为输入特征,Ii为原始模块输出特征。wi表示神经网络中第l层的滤波器的激活值。
(c)加入双支路模块的网络模型通过将wi和xi相乘来生成加权特征。
加权后的特征既是当前层网络的输出,同时也是下一层网络层的输入,之后的网络可以根据经过增强和抑制后的输入特征图来补充增强细节特征。
(d)利用训练集训练加入双支路模块的网络模型,使模型在测试集上的测试准确率达到最高,此时获得一个最优模型。
(e)计算最优模型中所有滤波器对于所有训练集图片激活值的方差,将单网络层滤波器的方差进行排序,可以获得每层滤波器的重要程度排序。按照预设定的每层的剪枝比率获得阈值,将大于阈值的数保留然后返回对应滤波器即得到网络的剪枝结果。根据剪枝结果计算每一层剪枝之后剩余的通道数,改变之前在VGG-16上进行了部分修改的原始网络模型每一层的通道数,得到裁剪后的网络。裁剪后的网络再次使用训练集重新训练,恢复精度。
(3)模型训练:学习率设为0.1;衰减间隔为80,120,160,180,总共训练200次。采用交叉熵函数作为损失函数;采用SGD优化方法,权重衰减率为0.1,动量值为0.9。
(4)评价指标:本发明实验采用分类准确率衡量算法效果。
(5)提出了适用于图片分类的剪枝算法,原始VGG-16网络的分类准确率为93.65%,加入双支路模块后的准确率为93.95%,剪枝后的轻量化模型网络参数量从原网络14.98M降为1.56M;浮点运算次数(floating-point operations,FLOPs)从原网络313.73M降为104.61M;剪枝后的分类准确率为93.01%,证明该算法在降低计算量和参数量的基础上对原网络的分类性能没有较大损失。
Claims (1)
1.一种基于深度学习的双支路滤波器剪枝方法,包括下列步骤:
第一步:划分数据集,划分为训练集和测试集;
第二步:基于原始VGG-16进行重新搭建,把预测部分的三个全连接层变为两个,交互神经元个数由原来的4096减少到512,得到需要剪枝的原始网络,之后加入双支路模块,得到新的网络模型;
第三步:利用数据集中的训练集训练加入双支路模块的网络模型,使模型在测试集上的测试准确率达到最高,此时获得一个最优模型;
第四步:再次把训练集输入最优模型,在不改变网络中的任何参数的情况下,利用双支路模块和输入图片,得到每个滤波器的激活值,网络不同层中不同滤波器的权重向量wi通过下列公式得到:
其中i代表第i层,Sigmoid代表规一化函数,H和W为输入特征的空间分辨率,Ii-1为输入特征,Ii为原始模块输出特征,W1表示支路中全连接层的权值;
第五步:计算网络中所有滤波器对于所有输入训练集图片的权值的方差,将单网络层滤波器的权重方差进行排序,获得每层滤波器的重要程度排序;按照预设定的每层的剪枝比率获得阈值,将大于阈值的数保留然后返回对应滤波器即得到网络的剪枝结果;
第六步:将裁剪后的原始网络使用数据集重新训练,恢复精度,经过训练收敛后的网络模型为最终得到轻量化的最优模型。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111128830.9A CN113837284B (zh) | 2021-09-26 | 2021-09-26 | 基于深度学习的双支路滤波器剪枝方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111128830.9A CN113837284B (zh) | 2021-09-26 | 2021-09-26 | 基于深度学习的双支路滤波器剪枝方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113837284A CN113837284A (zh) | 2021-12-24 |
CN113837284B true CN113837284B (zh) | 2023-09-15 |
Family
ID=78970407
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111128830.9A Active CN113837284B (zh) | 2021-09-26 | 2021-09-26 | 基于深度学习的双支路滤波器剪枝方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113837284B (zh) |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110263841A (zh) * | 2019-06-14 | 2019-09-20 | 南京信息工程大学 | 一种基于滤波器注意力机制和bn层缩放系数的动态结构化网络剪枝方法 |
CN110532859A (zh) * | 2019-07-18 | 2019-12-03 | 西安电子科技大学 | 基于深度进化剪枝卷积网的遥感图像目标检测方法 |
CN110619385A (zh) * | 2019-08-31 | 2019-12-27 | 电子科技大学 | 基于多级剪枝的结构化网络模型压缩加速方法 |
CN111310615A (zh) * | 2020-01-23 | 2020-06-19 | 天津大学 | 基于多尺度信息和残差网络的小目标交通标志检测方法 |
CN111444760A (zh) * | 2020-02-19 | 2020-07-24 | 天津大学 | 一种基于剪枝与知识蒸馏的交通标志检测与识别方法 |
CN113052211A (zh) * | 2021-03-11 | 2021-06-29 | 天津大学 | 一种基于特征的秩和通道重要性的剪枝方法 |
-
2021
- 2021-09-26 CN CN202111128830.9A patent/CN113837284B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110263841A (zh) * | 2019-06-14 | 2019-09-20 | 南京信息工程大学 | 一种基于滤波器注意力机制和bn层缩放系数的动态结构化网络剪枝方法 |
CN110532859A (zh) * | 2019-07-18 | 2019-12-03 | 西安电子科技大学 | 基于深度进化剪枝卷积网的遥感图像目标检测方法 |
CN110619385A (zh) * | 2019-08-31 | 2019-12-27 | 电子科技大学 | 基于多级剪枝的结构化网络模型压缩加速方法 |
CN111310615A (zh) * | 2020-01-23 | 2020-06-19 | 天津大学 | 基于多尺度信息和残差网络的小目标交通标志检测方法 |
CN111444760A (zh) * | 2020-02-19 | 2020-07-24 | 天津大学 | 一种基于剪枝与知识蒸馏的交通标志检测与识别方法 |
CN113052211A (zh) * | 2021-03-11 | 2021-06-29 | 天津大学 | 一种基于特征的秩和通道重要性的剪枝方法 |
Non-Patent Citations (3)
Title |
---|
卷积神经网络模型剪枝结合张量分解压缩方法;巩凯强;张春梅;曾光华;计算机应用;第40卷(第011期);全文 * |
基于GoogLeNet模型的剪枝算法;彭冬亮;王天兴;控制与决策(第006期);全文 * |
基于滤波器注意力机制与特征缩放系数的动态网络剪枝;卢海伟;夏海峰;袁晓彤;小型微型计算机系统(第009期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN113837284A (zh) | 2021-12-24 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110472483B (zh) | 一种面向sar图像的小样本语义特征增强的方法及装置 | |
CN110084281B (zh) | 图像生成方法、神经网络的压缩方法及相关装置、设备 | |
CN108133188B (zh) | 一种基于运动历史图像与卷积神经网络的行为识别方法 | |
CN109753903B (zh) | 一种基于深度学习的无人机检测方法 | |
WO2021043112A1 (zh) | 图像分类方法以及装置 | |
US20220215227A1 (en) | Neural Architecture Search Method, Image Processing Method And Apparatus, And Storage Medium | |
CN112116030A (zh) | 一种基于向量标准化和知识蒸馏的图像分类方法 | |
Wang et al. | A deep-learning-based sea search and rescue algorithm by UAV remote sensing | |
CN110222718B (zh) | 图像处理的方法及装置 | |
CN113011338B (zh) | 一种车道线检测方法及系统 | |
CN111223087B (zh) | 一种基于生成对抗网络的桥梁裂缝自动检测方法 | |
CN116343330A (zh) | 一种红外-可见光图像融合的异常行为识别方法 | |
CN113011562A (zh) | 一种模型训练方法及装置 | |
CN111476133B (zh) | 面向无人驾驶的前背景编解码器网络目标提取方法 | |
CN112419333B (zh) | 一种遥感影像自适应特征选择分割方法及系统 | |
CN112464745A (zh) | 一种基于语义分割的地物识别与分类方法和装置 | |
CN115115924A (zh) | 基于ir7-ec网络的混凝土图像裂缝类型迅捷智能识别方法 | |
CN115393690A (zh) | 一种轻量化神经网络的空对地观测多目标识别方法 | |
CN116194933A (zh) | 处理系统、处理方法以及处理程序 | |
CN113850373B (zh) | 一种基于类别的滤波器剪枝方法 | |
CN116152678A (zh) | 小样本条件下基于孪生神经网络的海洋承灾体识别方法 | |
CN115861756A (zh) | 基于级联组合网络的大地背景小目标识别方法 | |
CN114332075A (zh) | 基于轻量化深度学习模型的结构缺陷快速识别与分类方法 | |
CN114049532A (zh) | 基于多阶段注意力深度学习的风险道路场景识别方法 | |
CN112132207A (zh) | 基于多分支特征映射目标检测神经网络构建方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |