CN115829029A - 一种基于通道注意力的自蒸馏实现方法 - Google Patents

一种基于通道注意力的自蒸馏实现方法 Download PDF

Info

Publication number
CN115829029A
CN115829029A CN202211184522.2A CN202211184522A CN115829029A CN 115829029 A CN115829029 A CN 115829029A CN 202211184522 A CN202211184522 A CN 202211184522A CN 115829029 A CN115829029 A CN 115829029A
Authority
CN
China
Prior art keywords
network
distillation
training
channel attention
neural network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211184522.2A
Other languages
English (en)
Inventor
朱隆熙
刘宁钟
吴磊
王淑君
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Jiangsu Lemote Technology Corp ltd
Nanjing University of Aeronautics and Astronautics
Original Assignee
Jiangsu Lemote Technology Corp ltd
Nanjing University of Aeronautics and Astronautics
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Jiangsu Lemote Technology Corp ltd, Nanjing University of Aeronautics and Astronautics filed Critical Jiangsu Lemote Technology Corp ltd
Priority to CN202211184522.2A priority Critical patent/CN115829029A/zh
Publication of CN115829029A publication Critical patent/CN115829029A/zh
Pending legal-status Critical Current

Links

Images

Landscapes

  • Image Analysis (AREA)

Abstract

本发明公开了一种基于通道注意力的自蒸馏实现方法,该方法包括:首先下载CIFAR数据集,并对其进行划分和增广;然后在残差网络网络结构的基础上,使用四个阶段特征加入通道注意力后分别作为学生网络和教师网络,构造出新的蒸馏框架;将划分后的数据集送入神经网络进行训练,直至网络收敛,获得权重文件;最后利用训练好的神经网络和权重文件来检测测试图像,并输出分类结果。本发明能够很好地解决了目前蒸馏框架中教师网络预训练耗时和小模型精度不达标的问题,提高了蒸馏下模型的准确率。

Description

一种基于通道注意力的自蒸馏实现方法
技术领域
本发明涉及一种基于通道注意力的自蒸馏实现方法,属于计算机视觉技术领域。
背景技术
近年来,深度学习在学术界和工业界取得了巨大的成功,根本原因在于其可拓展性和编码大规模数据的能力。但是,深度学习的主要挑战在于,受限制于资源容量,深度神经模型很难部署在资源受限制的设备上。如嵌入式设备和移动设备。因此,涌现出了大量的模型压缩和加速技术,知识蒸馏是其中的代表,可以有效的从大型的教师模型中学习到小型的学生模型。
传统知识蒸馏可以分为基于响应的知识蒸馏和基于特征的知识蒸馏。基于反应的知识通常指教师模型最后一个输出层的神经反应。其主要思想是直接模拟教师模型的最终预测。基于反应的知识蒸馏是一种简单而有效的模型压缩方法,在不同的任务和应用中得到了广泛的应用。
基于特征的知识蒸馏来自于中间层,是基于响应的知识的一个很好的扩展,利用中间层的特征图可以作为监督学生模型训练的知识。深度神经网络善于学习到不同层级的表征,因此中间层和输出层的都可以被用作知识来训练学生模型,中间层的特征对于响应是一个很好的补充,其主要思想是将教师和学生的特征激活直接匹配起来。但是,上述两种经典方法有两个缺点包括:第一个缺点是知识转移效率低,这意味着学生模型几乎没有利用教师模型中的所有知识。一个杰出的学生模型其表现优于其教师模式,仍然是罕见的;另一个缺点是如何设计和培训合适的教师模式。现有的蒸馏框架需要大量的努力和实验才能找到最佳的教师模型架构,这需要相对较长的时间,例如传统蒸馏方法在CIFAR100上对教师网络ResNet152训练需要14.67小时,第二步对学生网络ResNet50训练需要12.31小时。
发明内容
本发明目的在于针对上述现有技术的不足,提出了一种基于通道注意力的自蒸馏实现方法,该方法很好地解决了目前蒸馏框架中教师网络预训练耗时和小模型精度不达标的问题,提高了蒸馏下模型的准确率。
本发明解决其技术问题所采取的技术方案是:一种基于通道注意力的自蒸馏实现方法,该方法包括以下步骤:
步骤1:数据集获取过程;
使用CIFAR10和CIFAR100数据集,并根据五比一的比例划分训练集和测试集;
步骤2:构建神经网络过程;
使用残差网络作为骨干网络,首先目标卷积神经神经网络根据其深度和原始结构划分为几个浅段,浅层网络可以视为学生模型,在概念上深层网络可以被视为教师模型;
步骤3:神经网络训练过程;
将划分后的CIFAR数据集送入上述步骤2构建的神经网络进行训练,直至网络收敛;
步骤4:测试图像检测过程;
采用训练好的神经网络和权重文件来检测测试图像中的准确率。
进一步地,本发明所述步骤2包括以下步骤:
步骤2-1:在残差网络中对于不同浅层网络的预测结果,将其当作学生网络,在每个浅层block之后,设置仅用于训练和可在推理中去除的瓶颈层和全连接层;
步骤2-2:对深层网络的特征加入通道注意力机制,使其中重要特征更加突出,非重要特征淡化,使知识的传递更加充分。
进一步地,本发明所述步骤3包括以下步骤:
步骤3-1:针对数据集中目标的大小,使用随机裁剪和随机水平翻转的数据增强方法;
步骤3-2:使用随机梯度下降的方法进行优化,学习率进行两次衰减,从初始值进行衰减,使神经网络能够达到更好的蒸馏结果;
步骤3-3:在神经网络上尝试不同的训练超参数,进行训练,当损失函数收敛或者达到最大迭代次数时,停止训练得到蒸馏后的网络文件和权重文件;
进一步地,本发明所述步骤3-1中对原始图像进行随机裁剪,裁剪填充大小为4。
进一步地,本发明所述步骤3-2是在训练过程中的不同阶段进行学习率衰减。
进一步地,本发明所述步骤4包括以下步骤:
步骤4-1:将测试图像送入改进的残差网络主干网络中,获取四个阶段的预测结果;
步骤4-2:将四个阶段的结果进行加权平均;
步骤4-3:对比五者结果,选择预测准确率高的作为最终结果。
有益效果:
1、本发明在残差网络主干网网络的基础上,采用深层网络作为教师网络来对浅层的学生网络进行蒸馏,能够让浅层学习到更深层的语义信息,增强了模型的分类精度。
2、本发明通过添加通道注意力的方法,突出深层特征的重要部分,能够更有效地利用教师网络的暗知识,提升了目标图片分类的准确率。
附图说明
图1为本发明的方法流程图。
图2为本发明实施例步骤2的方法流程图。
图3为本发明实施例步骤3的方法流程图。
图4为本发明实施例步骤4的方法流程图。
图5为本发明实施例中的测试结果图。
具体实施方式
下面结合说明书附图对本发明创造作进一步地详细说明。
如图1所示,本发明提供了一种基于通道注意力的自蒸馏实现方法,该方法包括以下步骤:
步骤1:获取数据集,并对CIFAR10和CIFAR100数据集进行划分,分成训练集和测试集;
步骤2:构建神经网络,使用残差网络作为骨干网络,构建网络时将四个阶段的特征作为分支,添加bottleneck层和FC层作为学生网络的预测,使用最后一层作为教师网络来蒸馏;
步骤3:训练神经网络,将划分后的CIFAR数据集送入神经网络进行训练,直至网络收敛;
步骤4:分类测试,利用训练好的神经网络和权重文件检测测试图像中的类别来验证蒸馏效果;
在本实施例中,本发明具体采用以下技术方案:
步骤1)从CIFAR数据集官网下载数据,并对数据进行划分;
步骤2)首先,增加四个分支进行特征的提取,再利用bottleneck层更有效地提取特征,最终通过FC层进行预测。
如图2所示,本发明步骤2包括如下步骤:
步骤201)将残差网络第一层的特征到第三层的特征进行抽取,并添加通道注意力,让网络学习重要特征;
步骤202)再利用bottleneck层对特征进行提取;
步骤203)最后利用FC层对提取的特征进行预测;
如图3所示,本发明步骤3包括如下步骤:
步骤301:在对网络进行训练前,重计算数据集的均值和方差,对数据进行归一化;
步骤302:使用随机权重作为初始权重,设置学习率、迭代次数、batch_size等;并在100和150轮,对学习率从初始值进行衰减,使神经网络能够达到更好的检测结果;
步骤303:对输入图像进行增广,进行训练,当损失函数收敛或者达到最大迭代次数时,停止训练获得自蒸馏后的权重文件。
如图4所示,本发明步骤4包括如下步骤:
步骤401:将测试图像送入改进的残差网络主干网络中,获取四个阶段的卷积特征;
步骤402:将四个阶段的特征分别进行预测;
步骤403:通过简单的加权平均获得四个阶段集合的预测结果,对比五者的结果取最优;
图5为使用本发明方法的检测结果,训练与测试在一张TITAN XP显卡上进行,在蒸馏时蒸馏温度设置为4.0,随机梯度下降算法中的权重衰减设置为0.0001,在每一轮的训练中都会将损失函数的值输出在终端,方便观察整体的收敛情况,并在每轮结束的时候使用测试集进行验证,在训练过程中还会将每个分支的预测结果进行输出,如Acc1-4表示为当前四层中第一层分支的预测结果,ensemble则表示加权不同分支后取平均的结果,验证准确率时对残差网络的第四层的分类结果进行比较,如果当前验证结果大于历史最优准确率则更新权重,经检验,本发明可在CIFAR100上达到78.76%的分类精确度。
以上所述实施例仅为说明本发明的优选实施方式,不能以此限定本发明的保护范围,凡是按照本发明提出的技术思想,在技术方案基础上所做的任何改动,均落入本发明保护范围之内。

Claims (6)

1.一种基于通道注意力的自蒸馏实现方法,其特征在于,所述方法包括以下步骤:
步骤1:数据集获取过程;
下载CIFAR数据集,并对其进行划分为训练集和测试集以及数据增广;
步骤2:构建神经网络过程;
在残差网络网络结构的基础上,使用四个阶段特征加入通道注意力后分别作为学生网络和教师网络,构造出新的蒸馏框架;
步骤3:神经网络训练过程;
将增广划分后的CIFAR数据集送入步骤2构建的神经网络进行训练,直至网络收敛;
步骤4:测试图像检测过程;
采用训练好的神经网络和权重文件来检测测试图像中的分类准确率。
2.根据权利要求1所述的一种基于通道注意力的自蒸馏实现方法,其特征在于,所述步骤2包括以下步骤:
步骤2-1:在残差网络中对于不同浅层网络的预测结果,将其当作学生网络,在每个浅层block之后,设置仅用于训练和在推理中去除的瓶颈层和全连接层;
步骤2-2:对原本的特征加入通道注意力,使其中重要特征更加凸显,并对不重要的特征进行过滤,让网络学习到更充分更重要的特征。
3.根据权利要求1所述的一种基于通道注意力的自蒸馏实现方法,其特征在于,所述步骤3包括以下步骤:
步骤3-1:针对数据集中目标的大小,使用随机裁剪和随机水平翻转的数据增强方法;
步骤3-2:使用随机梯度下降的方法进行优化,学习率进行两次衰减,从初始值进行衰减,使神经网络能够达到更好的蒸馏结果;
步骤3-3:在神经网络上尝试不同的训练超参数,进行训练,当损失函数收敛或者达到最大迭代次数时,停止训练得到蒸馏后的网络文件和权重文件。
4.根据权利要求3所述的一种基于通道注意力的自蒸馏实现方法,其特征在于,所述步骤3-1中对原始图像进行随机裁剪,裁剪填充大小为4。
5.根据权利要求3所述的一种基于通道注意力的自蒸馏实现方法,其特征在于,所述步骤3-2是在训练过程中的不同阶段进行学习率衰减。
6.根据权利要求1所述的一种基于通道注意力的自蒸馏实现方法,其特征在于,所述步骤4包括以下步骤:
步骤4-1:将测试图像送入改进的残差网络主干网络中,获取四个阶段的预测结果;
步骤4-2:将四个阶段的结果进行加权平均;
步骤4-3:对比五者结果,选择预测准确率高的作为最终结果。
CN202211184522.2A 2022-09-27 2022-09-27 一种基于通道注意力的自蒸馏实现方法 Pending CN115829029A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211184522.2A CN115829029A (zh) 2022-09-27 2022-09-27 一种基于通道注意力的自蒸馏实现方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211184522.2A CN115829029A (zh) 2022-09-27 2022-09-27 一种基于通道注意力的自蒸馏实现方法

Publications (1)

Publication Number Publication Date
CN115829029A true CN115829029A (zh) 2023-03-21

Family

ID=85524077

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211184522.2A Pending CN115829029A (zh) 2022-09-27 2022-09-27 一种基于通道注意力的自蒸馏实现方法

Country Status (1)

Country Link
CN (1) CN115829029A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116384439A (zh) * 2023-06-06 2023-07-04 深圳市南方硅谷半导体股份有限公司 一种基于自蒸馏的目标检测方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116384439A (zh) * 2023-06-06 2023-07-04 深圳市南方硅谷半导体股份有限公司 一种基于自蒸馏的目标检测方法
CN116384439B (zh) * 2023-06-06 2023-08-25 深圳市南方硅谷半导体股份有限公司 一种基于自蒸馏的目标检测方法

Similar Documents

Publication Publication Date Title
CN108648188B (zh) 一种基于生成对抗网络的无参考图像质量评价方法
CN110533631B (zh) 基于金字塔池化孪生网络的sar图像变化检测方法
CN112508085B (zh) 基于感知神经网络的社交网络链路预测方法
CN109857871B (zh) 一种基于社交网络海量情景数据的用户关系发现方法
CN112381097A (zh) 一种基于深度学习的场景语义分割方法
CN112418292B (zh) 一种图像质量评价的方法、装置、计算机设备及存储介质
CN110532452B (zh) 一种基于gru神经网络的新闻网站通用爬虫设计方法
CN115170874A (zh) 一种基于解耦蒸馏损失的自蒸馏实现方法
CN111460818A (zh) 一种基于增强胶囊网络的网页文本分类方法及存储介质
CN113628059A (zh) 一种基于多层图注意力网络的关联用户识别方法及装置
CN115829029A (zh) 一种基于通道注意力的自蒸馏实现方法
CN116935128A (zh) 一种基于可学习提示的零样本异常图像检测方法
CN113435588B (zh) 基于深度卷积神经网络bn层尺度系数的卷积核嫁接方法
CN115587616A (zh) 网络模型训练方法、装置、存储介质及计算机设备
CN115331081A (zh) 图像目标检测方法与装置
CN115759225A (zh) 一种基于对比学习的自蒸馏实现方法
CN113822339B (zh) 一种自知识蒸馏和无监督方法相结合的自然图像分类方法
CN115511059B (zh) 一种基于卷积神经网络通道解耦的网络轻量化方法
CN115267883B (zh) 地震响应预测模型训练及预测方法、系统、设备及介质
CN113917370B (zh) 一种基于油中溶解气体小样本数据的变压器故障诊断方法
CN117315400A (zh) 一种基于特征频率的自蒸馏实现方法
CN112417447B (zh) 一种恶意代码分类结果的精确度验证方法及装置
CN113822339A (zh) 一种自知识蒸馏和无监督方法相结合的自然图像分类方法
CN117351279A (zh) 一种时空蒸馏融合的自蒸馏实现方法
CN115174421A (zh) 基于自监督解缠绕超图注意力的网络故障预测方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication