CN115170874A - 一种基于解耦蒸馏损失的自蒸馏实现方法 - Google Patents
一种基于解耦蒸馏损失的自蒸馏实现方法 Download PDFInfo
- Publication number
- CN115170874A CN115170874A CN202210740525.3A CN202210740525A CN115170874A CN 115170874 A CN115170874 A CN 115170874A CN 202210740525 A CN202210740525 A CN 202210740525A CN 115170874 A CN115170874 A CN 115170874A
- Authority
- CN
- China
- Prior art keywords
- distillation
- network
- training
- decoupling
- neural network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/20—Image preprocessing
- G06V10/26—Segmentation of patterns in the image field; Cutting or merging of image elements to establish the pattern region, e.g. clustering-based techniques; Detection of occlusion
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Multimedia (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于解耦蒸馏损失的自蒸馏实现方法,该方法包括:首先下载CIFAR数据集,并对其进行划分和增广;然后在残差网络网络结构的基础上,使用四个阶段特征分别作为学生网络和教师网络,构造出新的蒸馏框架;将划分后的数据集送入神经网络进行训练,直至网络收敛,获得权重文件;最后利用训练好的神经网络和权重文件来检测测试图像,并输出分类结果。本发明很好地解决了目前蒸馏框架中教师网络预训练耗时和小模型精度不达标的问题,提高了蒸馏下模型的准确率。
Description
技术领域
本发明涉及一种基于解耦蒸馏损失的自蒸馏实现方法,属于计算机视觉技术领域。
背景技术
近深度学习取得了巨大进步,但是受限于庞大的计算量和参数量很难实际应用与资源受限设备上。为了使深度模型更加高效,人们探索知识蒸馏这个领域。2006年,Bucilua等人最先提出将大模型的知识迁移到小模型的想法。2015年,Hinton才正式提出广为人知的知识蒸馏的概念。知识蒸馏的主要的想法是:学生模型通过模仿教师模型来获得和教师模型相当的精度,关键问题是如何将教师模型的知识迁移到学生模型。
传统知识蒸馏可以分为基于响应的知识蒸馏和基于特征的知识蒸馏。基于反应的知识通常指教师模型最后一个输出层的神经反应。其主要思想是直接模拟教师模型的最终预测。基于反应的知识蒸馏是一种简单而有效的模型压缩方法,在不同的任务和应用中得到了广泛的应用。
基于特征的知识蒸馏来自于中间层,是基于响应的知识的一个很好的扩展,利用中间层的特征图可以作为监督学生模型训练的知识。最直接的想法是匹配中间特征的激活函数值,特别地,Zagoruyko和Komodakis(2017)提出用attention map来表示知识;为了匹配教师和学生之间的语义信息,Chen et al.(2021)提出cross-layer KD,通过注意力定位自适应地为每个学生网络中的层分配教师网络中的层。但是,上述两种经典方法有两个缺点包括:第一个缺点是知识转移效率低,这意味着学生模型几乎没有利用教师模型中的所有知识。一个杰出的学生模型其表现优于其教师模式,仍然是罕见的;另一个缺点是如何设计和培训合适的教师模式。现有的蒸馏框架需要大量的努力和实验才能找到最佳的教师模型架构,这需要相对较长的时间,例如传统蒸馏方法在CIFAR100上对教师网络ResNet152训练需要14.67小时,第二步对学生网络ResNet50训练需要12.31小时。
发明内容
本发明目的在于针对上述现有技术的不足,提出了一种基于解耦蒸馏损失的自蒸馏实现方法,该方法很好地解决了目前蒸馏框架预训练教师网络耗时长、教师网络学生网络规模差异大导致学生精度差的问题。
本发明解决其技术问题所采取的技术方案是:一种基于解耦蒸馏损失的自蒸馏实现方法,该方法包括以下步骤:
步骤1:数据集获取过程;
使用CIFAR10和CIFAR100数据集,并根据五比一的比例划分训练集和测试集;
步骤2:构建神经网络过程;
使用残差网络作为骨干网络,首先目标卷积神经神经网络根据其深度和原始结构划分为几个浅段,浅层网络可以视为学生模型,在概念上深层网络可以被视为教师模型;
步骤3:神经网络训练过程;
将划分后的CIFAR数据集送入步骤2构建的神经网络进行训练,直至网络收敛;
步骤4:测试图像检测过程;
用训练好的神经网络和权重文件来检测测试图像中的准确率。
进一步地,本发明所述步骤2包括以下步骤:
步骤2-1:在残差网络网络中对于不同浅层网络的预测结果,将其当作学生网络,在每个浅层block之后,设置仅用于训练和可在推理中去除的瓶颈层和全连接层;
步骤2-2:对原本的基于响应的知识蒸馏损失进行分解,拆解成目标类别和非目标类别的二分类损失以及非目标类别概率分布,并将二分类损失和非目标类别概率分布的权重解耦出来。
进一步地,本发明所述步骤3包括以下步骤:
步骤3-1:针对数据集中目标的大小,使用随机裁剪和随机水平翻转的数据增强方法;
步骤3-2:使用随机梯度下降的方法进行优化,学习率进行两次衰减,从初始值进行衰减,使神经网络能够达到更好的蒸馏结果;
步骤3-3:在神经网络上尝试不同的训练超参数,进行训练,当损失函数收敛或者达到最大迭代次数时,停止训练得到蒸馏后的网络文件和权重文件;
进一步地,本发明所述步骤3-1中对原始图像进行随机裁剪,裁剪填充大小为4。
进一步地,本发明所述步骤3-2是在训练过程中的不同阶段进行学习率衰减。
进一步地,本发明所述步骤4包括以下步骤:
步骤4-1:将测试图像送入改进的残差网络主干网络中,获取四个阶段的预测结果;
步骤4-2:将四个阶段的结果进行加权平均;
步骤4-3:对比五者结果,选择预测准确率高的作为最终结果。
有益效果:
1、本发明在残差网络主干网网络的基础上,采用深层网络作为教师网络来对浅层的学生网络进行蒸馏,可以让浅层学习到更深层的语义信息,增强了模型的分类精度。
2、本发明通过改进方式蒸馏损失的方法,使用解耦知识蒸馏,能更有效地利用非目标类别所蕴含的暗知识,提升了目标图片分类的准确率。
附图说明
图1为本发明实施例的方法流程图。
图2为本发明实施例步骤2的方法流程图。
图3为本发明实施例步骤3的方法流程图。
图4为本发明实施例步骤4的方法流程图。
图5为本发明实施例中的测试结果图。
具体实施方式
下面结合说明书附图对本发明创造作进一步地详细说明。
如图1所示,本发明提供了一种基于解耦蒸馏损失的自蒸馏实现方法,该方法包括以下步骤:
步骤1:获取数据集,并对CIFAR10和CIFAR100数据集进行划分,分成训练集和测试集;
步骤2:构建神经网络,使用残差网络作为骨干网络,构建网络时将四个阶段的特征作为分支,添加bottleneck层和FC层作为学生网络的预测,使用最后一层作为教师网络来蒸馏;
步骤3:训练神经网络,将划分后的CIFAR数据集送入神经网络进行训练,直至网络收敛;
步骤4:分类测试,利用训练好的神经网络和权重文件检测测试图像中的类别来验证蒸馏效果;
在本实施例中,本发明具体采用以下技术方案:
步骤1)从CIFAR数据集官网下载数据,并对数据进行划分;
步骤2)首先,增加四个分支进行特征的提取,再利用bottleneck层更有效地提取特征,最终通过FC层进行预测。
如图2所示,本发明步骤2包括如下步骤:
步骤201)将残差网络网络第一层的特征到第三层的特征进行抽取,并添加注意力,让网络学习重要特征;
步骤202)再利用bottleneck层对特征进行提取;
步骤203)最后利用FC层对提取的特征进行预测;
如图3所示,本发明步骤3包括如下步骤:
步骤301:在对网络进行训练前,重计算数据集的均值和方差,对数据进行归一化;
步骤302:使用随机权重作为初始权重,设置学习率、迭代次数、batch_size等;并在100和150轮,对学习率从初始值进行衰减,使神经网络能够达到更好的检测结果;
步骤303:对输入图像进行增广,进行训练,当损失函数收敛或者达到最大迭代次数时,停止训练获得自蒸馏后的权重文件。
如图4所示,本发明步骤4包括如下步骤:
步骤401:将测试图像送入改进的残差网络主干网络中,获取四个阶段的卷积特征;
步骤402:将四个阶段的特征分别进行预测;
步骤403:通过简单的加权平均获得四个阶段集合的预测结果,对比五者的结果取最优;
图5为使用本发明方法的检测结果,训练与测试在一张TITAN XP显卡上进行,在蒸馏时蒸馏温度设置为4.0,随机梯度下降算法中的权重衰减设置为0.0001,在每一轮的训练中都会将损失函数的值输出在终端,方便观察整体的收敛情况,并在每轮结束的时候使用测试集进行验证,在训练过程中还会将每个分支的预测结果进行输出,如Acc1-4表示为当前四层中第一层分支的预测结果,ensemble则表示加权不同分支后取平均的结果,验证准确率时对残差网络的第四层的分类结果进行比较,如果当前验证结果大于历史最优准确率则更新权重,经检验,本发明可在CIFAR100上达到78.94%的分类精确度。
以上所述实施例仅为说明本发明的优选实施方式,不能以此限定本发明的保护范围,凡是按照本发明提出的技术思想,在技术方案基础上所做的任何改动,均落入本发明保护范围之内。
Claims (6)
1.一种基于解耦蒸馏损失的自蒸馏实现方法,其特征在于,所述方法包括以下步骤:
步骤1:数据集获取过程;
首先下载CIFAR数据集,并对其进行划分为训练集和测试集以及数据增广;
步骤2:构建神经网络过程;
在残差网络网络结构的基础上,使用四个阶段特征分别作为学生网络和教师网络,构造出新的蒸馏框架;
步骤3:神经网络训练过程;
将增广划分后的CIFAR数据集送入步骤2构建的神经网络进行训练,直至网络收敛;
步骤4:测试图像检测过程;
利用训练好的神经网络和权重文件来检测测试图像中的分类准确率。
2.根据权利要求1所述的一种基于解耦蒸馏损失的自蒸馏实现方法,其特征在于,所述步骤2包括以下步骤:
步骤2-1:在残差网络网络中对于不同浅层网络的预测结果,将其当作学生网络,在每个浅层block之后,设置仅用于训练和可在推理中去除的瓶颈层和全连接层;
步骤2-2:对原本的基于响应的知识蒸馏损失进行分解,拆解成目标类别和非目标类别的二分类损失以及非目标类别概率分布,并将二分类损失和非目标类别概率分布的权重解耦出来。
3.根据权利要求1所述的一种基于解耦蒸馏损失的自蒸馏实现方法,其特征在于,所述步骤3包括以下步骤:
步骤3-1:针对数据集中目标的大小,使用随机裁剪和随机水平翻转的数据增强方法;
步骤3-2:使用随机梯度下降的方法进行优化,学习率进行两次衰减,从初始值进行衰减,使神经网络能够达到更好的蒸馏结果;
步骤3-3:在神经网络上尝试不同的训练超参数,进行训练,当损失函数收敛或者达到最大迭代次数时,停止训练得到蒸馏后的网络文件和权重文件。
4.根据权利要求3所述的一种基于解耦蒸馏损失的自蒸馏实现方法,其特征在于,所述步骤3-1中对原始图像进行随机裁剪,裁剪填充大小为4。
5.根据权利要求3所述的一种基于解耦蒸馏损失的自蒸馏实现方法,其特征在于,所述步骤3-2是在训练过程中的不同阶段进行学习率衰减。
6.根据权利要求1所述的一种基于解耦蒸馏损失的自蒸馏实现方法,其特征在于,所述步骤4包括以下步骤:
步骤4-1:将测试图像送入改进的残差网络主干网络中,获取四个阶段的预测结果;
步骤4-2:将四个阶段的结果进行加权平均;
步骤4-3:对比五者结果,选择预测准确率高的作为最终结果。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210740525.3A CN115170874A (zh) | 2022-06-27 | 2022-06-27 | 一种基于解耦蒸馏损失的自蒸馏实现方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210740525.3A CN115170874A (zh) | 2022-06-27 | 2022-06-27 | 一种基于解耦蒸馏损失的自蒸馏实现方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115170874A true CN115170874A (zh) | 2022-10-11 |
Family
ID=83487289
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210740525.3A Pending CN115170874A (zh) | 2022-06-27 | 2022-06-27 | 一种基于解耦蒸馏损失的自蒸馏实现方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115170874A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116384439A (zh) * | 2023-06-06 | 2023-07-04 | 深圳市南方硅谷半导体股份有限公司 | 一种基于自蒸馏的目标检测方法 |
CN117708726A (zh) * | 2024-02-05 | 2024-03-15 | 成都浩孚科技有限公司 | 网络模型解耦的开集合类别训练方法、装置及其存储介质 |
-
2022
- 2022-06-27 CN CN202210740525.3A patent/CN115170874A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116384439A (zh) * | 2023-06-06 | 2023-07-04 | 深圳市南方硅谷半导体股份有限公司 | 一种基于自蒸馏的目标检测方法 |
CN116384439B (zh) * | 2023-06-06 | 2023-08-25 | 深圳市南方硅谷半导体股份有限公司 | 一种基于自蒸馏的目标检测方法 |
CN117708726A (zh) * | 2024-02-05 | 2024-03-15 | 成都浩孚科技有限公司 | 网络模型解耦的开集合类别训练方法、装置及其存储介质 |
CN117708726B (zh) * | 2024-02-05 | 2024-04-16 | 成都浩孚科技有限公司 | 网络模型解耦的开集合类别训练方法、装置及其存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110533631B (zh) | 基于金字塔池化孪生网络的sar图像变化检测方法 | |
CN109583501B (zh) | 图片分类、分类识别模型的生成方法、装置、设备及介质 | |
CN114092832B (zh) | 一种基于并联混合卷积网络的高分辨率遥感影像分类方法 | |
CN110781406B (zh) | 一种基于变分自动编码器的社交网络用户多属性推断方法 | |
CN115170874A (zh) | 一种基于解耦蒸馏损失的自蒸馏实现方法 | |
CN109857871B (zh) | 一种基于社交网络海量情景数据的用户关系发现方法 | |
CN110837602A (zh) | 基于表示学习和多模态卷积神经网络的用户推荐方法 | |
CN110751698A (zh) | 一种基于混和网络模型的文本到图像的生成方法 | |
CN113628059A (zh) | 一种基于多层图注意力网络的关联用户识别方法及装置 | |
CN114898121A (zh) | 基于图注意力网络的混凝土坝缺陷图像描述自动生成方法 | |
CN112560948A (zh) | 数据偏差下的眼底图分类方法及成像方法 | |
CN117556369B (zh) | 一种动态生成的残差图卷积神经网络的窃电检测方法及系统 | |
CN115829029A (zh) | 一种基于通道注意力的自蒸馏实现方法 | |
CN110390050B (zh) | 一种基于深度语义理解的软件开发问答信息自动获取方法 | |
CN111505706A (zh) | 基于深度T-Net网络的微地震P波初至拾取方法及装置 | |
US20230186091A1 (en) | Method and device for determining task-driven pruning module, and computer readable storage medium | |
CN113889274A (zh) | 一种孤独症谱系障碍的风险预测模型构建方法及装置 | |
CN115759225A (zh) | 一种基于对比学习的自蒸馏实现方法 | |
CN117315400A (zh) | 一种基于特征频率的自蒸馏实现方法 | |
CN113822339B (zh) | 一种自知识蒸馏和无监督方法相结合的自然图像分类方法 | |
CN117351279A (zh) | 一种时空蒸馏融合的自蒸馏实现方法 | |
CN116610770B (zh) | 一种基于大数据的司法领域类案推送方法 | |
CN117496162B (zh) | 一种红外卫星遥感影像薄云去除方法、装置及介质 | |
CN113705873B (zh) | 影视作品评分预测模型的构建方法及评分预测方法 | |
CN116863190A (zh) | 一种图像识别方法和计算机设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication |