CN116384439B - 一种基于自蒸馏的目标检测方法 - Google Patents
一种基于自蒸馏的目标检测方法 Download PDFInfo
- Publication number
- CN116384439B CN116384439B CN202310658974.8A CN202310658974A CN116384439B CN 116384439 B CN116384439 B CN 116384439B CN 202310658974 A CN202310658974 A CN 202310658974A CN 116384439 B CN116384439 B CN 116384439B
- Authority
- CN
- China
- Prior art keywords
- layer
- candidate network
- distillation
- training
- self
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000004821 distillation Methods 0.000 title claims abstract description 78
- 238000001514 detection method Methods 0.000 title claims abstract description 56
- 238000012549 training Methods 0.000 claims abstract description 79
- 238000003062 neural network model Methods 0.000 claims abstract description 14
- 238000010845 search algorithm Methods 0.000 claims abstract description 8
- 230000006978 adaptation Effects 0.000 claims abstract description 7
- 230000005484 gravity Effects 0.000 claims abstract description 3
- 238000000034 method Methods 0.000 claims description 29
- 238000002372 labelling Methods 0.000 claims description 6
- 230000003044 adaptive effect Effects 0.000 claims description 4
- 238000007781 pre-processing Methods 0.000 claims description 2
- 238000005516 engineering process Methods 0.000 abstract description 5
- 238000013135 deep learning Methods 0.000 abstract description 4
- 238000013528 artificial neural network Methods 0.000 description 14
- 230000006870 function Effects 0.000 description 12
- 238000013140 knowledge distillation Methods 0.000 description 11
- 230000008569 process Effects 0.000 description 11
- 238000013461 design Methods 0.000 description 6
- 238000004364 calculation method Methods 0.000 description 4
- 230000006835 compression Effects 0.000 description 3
- 238000007906 compression Methods 0.000 description 3
- 238000012546 transfer Methods 0.000 description 3
- 230000001133 acceleration Effects 0.000 description 2
- 230000004913 activation Effects 0.000 description 2
- 238000013459 approach Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000002474 experimental method Methods 0.000 description 2
- 238000007477 logistic regression Methods 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000013139 quantization Methods 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 101100465000 Mus musculus Prag1 gene Proteins 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 230000009193 crawling Effects 0.000 description 1
- 238000000354 decomposition reaction Methods 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000013100 final test Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000003278 mimic effect Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 230000001537 neural effect Effects 0.000 description 1
- 238000013138 pruning Methods 0.000 description 1
- 238000005292 vacuum distillation Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/042—Knowledge-based neural networks; Logical representations of neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/01—Dynamic search techniques; Heuristics; Dynamic trees; Branch-and-bound
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
- Air Conditioning Control Device (AREA)
Abstract
本发明公开了一种基于自蒸馏的目标检测方法,涉及深度学习技术领域,解决了目前用于目标检测的自蒸馏技术不够灵活,导致自蒸馏效率低,从而导致目标检测结果的准确率不高。该方法包括S1、构建用于目标检测的神经网络模型,选择候选网络层并添加适配结构;S2、通过初始训练,获取每个候选网络层目标检测结果的均值平均精度和每个样本的错误率;S3、根据基于引力搜索算法的匹配条件和均值平均精度,当前候选网络层自动匹配其他候选网络层;S4、根据样本错误率,自动匹配的两个候选网络层进行自蒸馏训练;S5、则得到训练好的目标检测模型。本发明实现了网络层的知识跨层高效蒸馏,从而提高自蒸馏效率。
Description
技术领域
本发明涉及深度学习技术领域,尤其涉及一种基于自蒸馏的目标检测方法。
背景技术
深度学习在计算机视觉尤其是目标检测领域取得了令人难以置信的性能。然而,为了获得良好的性能,现代卷积神经网络总是需要大量的参数,并进行长时间的训练,这造成了模型性能、训练成本和模型存储成本、计算代价之间的矛盾。
近年来,许多模型压缩和加速方法被提出来解决这个问题。典型的方法包括剪枝、量化、轻量级神经网络设计、低秩分解和知识蒸馏。其中,知识蒸馏是最有效的方法之一,它首先训练一个过度参数化的神经网络作为教师,然后训练一个小的学生网络来模仿教师网络的输出。由于学生模型继承了教师的知识,因此可以替代过度参数化的教师模型,实现模型压缩和快速推理。然而,传统的知识蒸馏一直存在两个问题——教师模型的选择和知识转移的效率。研究人员发现,教师模型的选择对学生模型的准确性有很大影响,准确率最高的老师并不是蒸馏的最佳教师。因此,需要大量的实验来寻找最合适的蒸馏教师模型,这可能非常耗时。知识蒸馏的第二个问题是,学生模型不能总是像教师模型那样达到那么高的精度,这可能会导致推理期间不可接受的精度下降。
针对这些问题,自蒸馏技术应运而生。最早关于自蒸馏学习的工作发表在ICCV2019会议上,思路是在卷积神经网络的中间每一层接入一个提前预测分类结果的分类器,由模型最后的主分类器输出的logits函数引导中间各层的早期预测。自蒸馏技术不通过新增一个大模型的方式找到一个教师模型,同样可以提供有效增益信息给学生模型,这里的教师模型往往不会比学生模型复杂,但提供的增益信息对于学生模型是有效的增量信息,以提升学生模型效率。该方式可以避免使用更复杂的模型,也可以避免通过一些聚类或者是元计算的步骤生成伪标签。目前该方法在学术界较为新颖,从2020年开始逐渐有顶会浮现相关论文,主要探索任务也较为丰富,包括计算机视觉、自然语言处理、图神经网络等。
自蒸馏在同一个神经网络模型的不同层次或者不同训练轮次间进行蒸馏。与传统知识蒸馏相比,自蒸馏减少了训练开销。由于所提出的自我蒸馏中的教师模型和学生模型都是同一神经网络中的分类器,因此可以避免在常规知识蒸馏中搜索教师模型的大量实验。自我蒸馏是一种单阶段训练方法,其中教师模型和学生模型可以一起训练。自蒸馏的一级特性进一步降低了训练开销。与传统知识蒸馏相比,自蒸馏可实现更高的精度、加速度和压缩。与传统的知识蒸馏侧重于不同模型之间的知识转移不同,自我蒸馏在一个模型中转移知识。实验表明,自蒸馏比其他知识蒸馏方法的效果好得多。此外,业内研究还发现,自蒸馏和常规知识蒸馏方法可以一起使用,以达到更好的效果。
自蒸馏具有训练轻量、知识迁移效率高的特点,可以有效提高神经神经网络性能,受到研究者的重视。然而,目前自蒸馏技术还存在比较明显的缺点,现有自蒸馏技术不够灵活。在用深层网络层蒸馏浅层网络层的过程中,具体不同层之间的匹配关系并没有太明确且统一的范式和准则,一般还是根据经验手动选择学生层和教师层进行自蒸馏,对相应规律总体仍处于探索阶段。
在实现本发明过程中,发明人发现现有技术中至少存在如下问题:
目前用于目标检测的自蒸馏技术不够灵活,不同层之间的匹配关系不统一,需要手动选择学生层和教师层,导致自蒸馏效率低,从而导致目标检测结果的准确率不高。
发明内容
本发明的目的在于提供一种基于自蒸馏的目标检测方法,以解决现有技术中存在的目前用于目标检测的自蒸馏技术不够灵活,不同层之间的匹配关系不统一,需要手动选择学生层和教师层,导致自蒸馏效率低,从而导致目标检测结果的准确率不高的技术问题。本发明提供的诸多技术方案中的优选技术方案所能产生的诸多技术效果详见下文阐述。
为实现上述目的,本发明提供了以下技术方案:
本发明提供的一种基于自蒸馏的目标检测方法,包括:
S1、构建用于目标检测的神经网络模型,选择候选网络层并为所述候选网络层添加适配结构;
S2、通过对待训练样本集进行初始训练,获取每个所述候选网络层的适配结构目标检测结果的均值平均精度和所述待训练样本集中每个样本的错误率;
S3、根据基于引力搜索算法的匹配条件和所述均值平均精度,当前所述候选网络层自动匹配其他所述候选网络层;
S4、根据所述样本的错误率,自动匹配的两个所述候选网络层作为学生层和教师层进行自蒸馏训练,并更新所述候选网络层的均值平均精度和样本错误率;
S5、判断是否完成所有轮次的训练,若是,则得到训练好的目标检测模型;否则,执行步骤S3;
步骤S1之前,需要收集用于目标检测的数据样本,所述数据样本为图片;对所述数据样本进行预处理、标注信息后得到所述待训练样本集;
第n个所述样本的错误率为:;
若所述待训练样本集中的某数据样本含有多个标注目标,只要有未正确检测的所述标注目标,就判定所述数据样本检测错误;
步骤S3包括:
S31、设置所述学生层、教师层自动匹配的间隔轮次和引力阈值;
S32、根据所述均值平均精度计算当前所述候选网络层和其他所述候选网络层的引力值大小;
S33、将其他所述候选网络层中满足所述匹配条件的,匹配给当前所述候选网络层;其中,将自动匹配的两个所述候选网络层中,浅层的作为学生层,深层的作为教师层;
S34、判断训练轮次间隔是否达到所述间隔轮次,若是,则执行步骤S32;否则,执行步骤S4;
所述匹配条件为:在所述引力值大于所述引力阈值的其他所述候选网络层中,匹配相隔层数最小的所述候选网络层;
所述引力值为:,
其中,G为引力参数;mAPi为当前候选网络层i的均值平均精度;mAPj为其他候选网络层j的均值平均精度;△L为所述候选网络层i、候选网络层j之间的间隔层数。
优选的,若某一当前所述候选网络层没有满足匹配条件的其他所述候选网络层,则当前所述候选网络层不参加本批次的蒸馏训练。
优选的,步骤S4包括:
S41、在当前批次训练中,选取BatchSize个所述样本,通过所述样本的错误率,获取当前批次所述样本的训练难易程度值;
S42、根据所述训练难易程度值,获取温度系数;
S43、所述学生层、教师层通过所述温度系数进行自蒸馏训练,并更新所述候选网络层的均值平均精度和样本错误率。
优选的,所述训练难易程度值为:,
其中,为本批次BatchSize个所述样本的所述错误率之和。
优选的,所述温度系数为:,
其中,为所述训练难易程度值;/>为预设的温度参数。
优选的,进行自蒸馏训练所用的Sigmoid函数为:,
其中,e为自然常数,T为所述温度系数,X为函数自变量。
实施本发明上述技术方案中的一个技术方案,具有如下优点或有益效果:
本发明引入引力搜索算法,结合目标检测神经网络的训练特征,制定学生层和教师层之间的匹配规则,使得不同网络层之间可以按照匹配规则进行学生层和教师层的自动匹配,无需再手动选择匹配;实现了网络层的知识跨层高效蒸馏,从而提高自蒸馏效率。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单的介绍,显而易见,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图,附图中:
图1是本发明实施例一种基于自蒸馏的目标检测方法的流程图;
图2是本发明实施例候选网络层添加适配结构的示意图;
图3是本发明实施例一种基于自蒸馏的目标检测方法步骤S3的流程图;
图4是本发明实施例一种基于自蒸馏的目标检测方法步骤S4的流程图;
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,下文将要描述的各种示例性实施例将要参考相应的附图,这些附图构成了示例性实施例的一部分,其中描述了实现本发明可能采用的各种示例性实施例。除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本公开相一致的所有实施方式。应明白,它们仅是与如所附权利要求书中所详述的、本发明公开的一些方面相一致的流程、方法和装置等的例子,还可使用其他的实施例,或者对本文列举的实施例进行结构和功能上的修改,而不会脱离本发明的范围和实质。
在本发明的描述中,需要理解的是,术语“中心”、“纵向”、“横向”等指示的是基于附图所示的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的元件必须具有的特定的方位、以特定的方位构造和操作。术语“第一”、“第二”等仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。术语“多个”的含义是两个或两个以上。术语“相连”、“连接”应做广义理解,例如,可以是固定连接、可拆卸连接、一体连接、机械连接、电连接、通信连接、直接相连、通过中间媒介间接相连,可以是两个元件内部的连通或两个元件的相互作用关系。术语“和/或”包括一个或多个相关的所列项目的任意的和所有的组合。对于本领域的普通技术人员而言,可以根据具体情况理解上述术语在本发明中的具体含义。
为了说明本发明所述的技术方案,下面通过具体实施例来进行说明,仅示出了与本发明实施例相关的部分。
实施例一:如图1所示,本发明提供了一种基于自蒸馏的目标检测方法,包括:
S1、构建用于目标检测神经网络模型,选择候选网络层并为候选网络层添加适配结构;
S2、通过对待训练样本集进行初始训练,获取每个候选网络层的适配结构目标检测结果的均值平均精度和待训练样本集中每个样本的错误率;
S3、根据基于引力搜索算法的匹配条件和均值平均精度,作为学生层的候选网络层自动匹配教师层;
S4、根据样本的错误率,自动匹配的两个候选网络层作为学生层和教师层进行自蒸馏训练,并更新候选网络层的均值平均精度和样本错误率;
S5、判断是否完成所有轮次的训练,若是,则得到训练好的目标检测模型;否则,执行步骤S3。
本实施例引入引力搜索算法,结合目标检测神经网络的训练特征,制定学生层和教师层之间的匹配规则,使得不同网络层之间可以按照匹配规则进行学生层和教师层的自动匹配,无需再手动选择匹配;实现了网络层的知识跨层高效蒸馏,从而提高自蒸馏效率。
在步骤S1构建神经网络模型之前,需要构建目标检测的数据集。在目标检测中,收集的数据样本是图像。数据的采集方法有很多,其中主要的采集方法有人工收集、系统采集、网络爬取、虚拟仿真、对抗生成、开源数据等。选定数据集后,要对数据进行预处理、标注信息,最后将数据集划分为训练集、验证集和测试集。训练集一般是用来对神经网络模型进行训练的,验证集用来验证训练的结果是否达标,测试集用来在模型经过验证合格后再进行的最后测试。
准备好数据集后,根据实际工作需要,构建神经网络模型和设置神经网络训练过程中的超参数。在这个过程中,既可以完全从零开始设计神经网络模型,也可以已有网络模型为主干网络进行设计。神经网络模型的设计主要包括结构设计、激活函数设计、损失函数设计、优化器设计等。如果是从零开始设计神经网格模型,最好为模型选择相对合理的初始化权重和参数。
在步骤S1中,根据构建的神经网络模型,选取有可能成为自蒸馏点的网络层,即为选择候选网络层的过程。选出的网络层既有可能作为学生层,又有可能作为教师层参与自蒸馏过程。
基于深度学习的目标检测算法中有三个组件,分别为Backbone、Neck和Head。其中,Backbone是模型的主干网络,其作用就是提取图片中的特征信息,供网络的其它部分使用。这些网络经常使用的是残差网络或者VGG(Visual Geometry Group,视觉几何组)等,这些网络已经证明了在分类等问题上的特征提取能力是很强的。Head是获取网络输出内容的网络,利用之前提取的特征,做出预测。Neck是放在Backbone和Head之间的,是为了更好的利用Backbone提取特征。本实施例中,为每个候选网络层添加的适配结构,指的是为神经网络模型Backbone部分的每个候选网络层添加Neck组件和Head组件,通过这两个组件适配后,每个候选网络层都可以作为一个弱目标检测器输出目标检测结果,每个目标检测器具有不同的准确度和响应性能,如图2所示。为候选网络层添加的这些适配结构,在完成神经网络训练后都可以删除,并不影响神经网络模型的响应时间。
本实施例采用的是是多点蒸馏方法,即在神经网络层选取多组学生-教师层同时进行跨层自蒸馏。和单点蒸馏相比,多个学生层从多个教师层中获得的信息更多,通常认为它们可以表现出更好的知识迁移效果。但是这也引出了明显的问题。那就是怎样为学生层匹配合适的教师层。如果选择的模型不匹配,则学生层不能更好的从教师层学到知识,不能达到自蒸馏的效果。具体来说,学生层和教师层既不能距离太远也不能具体太近。如果学生层和教师层相隔距离太远,则由于教师层提供的知识对学生层太抽象,会出现知识蒸馏效率下降的情况。如果学生层和教师层距离太近,则由于两者知识差距不大,学生层学习知识的速度太慢,也会影响蒸馏效率。
另外,在神经网络中,需要选出多少对学生层和教师层,目前也没有特别明确的规范。如果选择的学生-教师层太少,则会明显影响蒸馏效率;但是,如果选择的学生-教师层太密集,则容易造成学生层对于教师层的过度知识拟合,造成模型性能降低。传统的神经网络跨层匹配方式有两种,第一种是skip模式,即每隔几层去学习一个中间层,其中具体间隔层数取固定值。第二种是last模式,即学习教师模型的最后几层。这两种方式都过于简单和机械化。不同于传统的手动选择学生层和教师层的方式,本实施例引入引力搜索算法,制定学生-教师层自动匹配的规则,使学生层和教师层能够自动匹配。
由于在匹配过程中,既不希望两个网络层距离太近,又不希望两个网络层距离太远。如果学生层和教师层距离太近,仅间隔很少的层数,两个网络层之间的知识很相似,即使进行了蒸馏浅层网络也不能从深层网络层那获取太多知识,反而浪费了蒸馏过程所需要的算力,造成蒸馏效率低下。如果学生层和教师层距离太远,则由于深层网络层包含的知识太抽象,不适合浅层网络层学习,也不利于蒸馏。因此,在选取浅层网络和深层网络进行匹配时,既要保证深层网络层包含浅层网络还没有的知识,又有需要保证这部分知识适合浅层网络层学习。因此,本实施例引入不同网络层之间的“引力”概念,将每个网络层的输出mAP看作本层网络的“质量”,将不同网络层的间隔层数看作两个网络层之间的“距离”,根据时间任务定义“引力常数”G ,这样就可以成功计算出网络层之间的引力F。
为了获取网络层之间引力,需要步骤S2中,通过初始t个轮次的训练,计算每个候选网络层的适配结构Head目标检测结果的均值平均精度。
然后进行步骤S3,如图3所示,步骤S3包括:
S31、设置学生层、教师层自动匹配的间隔轮次和引力阈值;
S32、根据均值平均精度计算当前候选网络层和其他候选网络层的引力值大小;引力值为:,其中,G为引力参数,可根据实际训练情况设置;mAPi为当前候选网络层i的均值平均精度;mAPj为其他候选网络层j的均值平均精度;△L为候选网络层i、候选网络层j之间的间隔层数。
S33、将其他候选网络层中满足匹配条件的,匹配给当前候选网络层;其中,将自动匹配的两个候选网络层中,浅层的作为学生层,深层的作为教师层;设定一个学生层最多只匹配一个教师层。匹配条件为:在引力值大于引力阈值的其他候选网络层中,匹配相隔层数最小的候选网络层。如,设置引力阈值为8,浅层的第2层网络层与深层的第6层网络层、第7层网络层、第8层网络层Q的引力值分别为7、16、12,则第2层网络层和第7层网络层匹配为学生-教师层,且第2层网络层为学生层,第7层网络层为教师层。
S34、判断训练轮次间隔是否达到间隔轮次,若是,则执行步骤S32;否则,执行步骤S4。
若某一候选网络层没有满足匹配条件的其他候选网络层,则候选网络层不参加本批次的蒸馏训练。
知识来源于数据集的数据样本,区分不同知识来源的重要性,也就是区分不同数据样本的重要性。数据样本根据训练的难易程度不同可以分为难训练样本和易训练样本。在训练过程中,难训练样本是指在多轮次训练中错误次数较多的样本,易训练样本是指在多轮次训练中错误次数偏少的样本。有理由相信,当模型训练的输入为难训练样本时,来自教师层的数据包含更多的“知识”,当模型训练的输入是容易训练样本时,来自教师层的数据包含相对较少的“知识”。
从另一角度讲,根据每次训练中样本的难易程度,确定是高温蒸馏还是低温蒸馏。就像在学习过程中,遇到较难的知识点,需要老师重点讲解,学生重点学习;遇到简单的知识点,由于学生自学就能学会,不太需要教师提供太多的指导,如果教师过分强调简单的知识点,有时还会误导一些不太聪明的学生,使得这部分学生的思路僵化。
因此,本实施例引入自适应蒸馏温度的调节方法,以实现对数据集中难训练样本和易训练样本的自适应学习。蒸馏温度的高低改变的是学生层对负标签的关注程度:温度较低时,困难样本携带的信息会被相对减少,对难训练样本的关注较少,难训练样本的概率越低,关注越少;而温度较高时,难训练样本的概率值会相对增大,难训练样本携带的信息会被相对地放大,学生网络会更多关注到负标签。为了充分利用教师模型负类别的 dark信息,一般会选用一个较高的温度系数。温度系数的作用就是它控制了模型对难易样本的区分度,错误率高时提高温度系数。
在目标检测任务中,类别预测检测头一般将原始图像分类任务的单标签分类函数改为多标签分类,也就是将检测头中的softmax层换成用于多标签多分类的逻辑回归层,逻辑回归层主要用sigmoid函数,该函数可以将输入约束在0到1的范围内,以计算每个检测目标所属的类别概率。因此本实施例进行自蒸馏训练的激活函数采用Sigmoid函数:,其中,e为自然常数,T为温度系数,X为函数自变量。T越高,Sigmoid函数的输出概率分布越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。
在开始训练之前,需要计算每个样本的错误率,第n个样本的计算公式为:
。
如果目标检测数据训练集中的某样本含有多个标注目标(比如,在行人检测数据训练集中,一张图片中可能包含多个被标注的行人),本实施例设定,只要有未正确检测的标注目标,就判定该数据样本检测错误。此处是基于数据样本定义的,而不是目标检测模型的正确率等指标。显然,在开始训练时,每个数据样本的错误率都为0。
开始进行训练,如图4所示,步骤S4包括:
S41、在当前批次训练中,选取BatchSize个样本,通过样本的错误率,获取当前批次样本的训练难易程度值;本实施例用样本的错误率Rerror量化表示数据样本的训练难度。将当前批次所有样本错误率Rerror的均值,作为当前批次样本的训练难易程度值TDL。其计算公式为:,其中,/>为本批次BatchSize个样本的错误率之和。TDL的取值范围为[0,1]。
S42、根据训练难易程度值,获取温度系数;其计算公式为:,其中,TDL为训练难易程度值;T0为预设的温度参数,且规定T0>1。本实施例的温度系数可随着每批次样本错误率的更新动态调整,当第一轮次训练时,TDL=0,T=T0;从第二轮次开始,对于错误率高的批次样本,其TDL会增大,甚至接近1,T接近2T0;随着轮次的增大,TDL逐渐减小,动态温度接近于T0。在训练过程中,动态蒸馏温度系数T的变化规律如下:如果批次z中的样本都比较容易训练,则TDL比较小,动态温度T相对较小;如果批次z中的样本都比较难训练,则TDL比较大,动态温度T相对较大。以上规律符合困难样本需要“高温”蒸馏,容易样本“低温”蒸馏的实际训练需求。实现了在不同难度样本间进行不同程度的蒸馏,提高了蒸馏效率。
S43、学生层、教师层通过温度系数进行自蒸馏训练,并更新候选网络层的均值平均精度和样本错误率。
本实施例引入引力搜索算法,结合目标检测神经网络的训练特征,制定学生层和教师层之间的匹配规则,使得不同网络层之间可以按照匹配规则进行学生层和教师层的自动匹配,实现了网络层的知识跨层高效蒸馏,从而提高自蒸馏效率;同时考虑到不同网络层包含知识的丰富程度和相对其他网络层的学习难度,设计可以动态调整蒸馏温度的方法,针对目标检测训练过程中样本训练难度不均衡的常见情况,可以实现在不同难度样本间进行不同程度的蒸馏,提高了蒸馏效率,降低了训练目标检测神经网络模型所需要的时间成本。在同等训练轮次下,通过提高蒸馏效率,本实施例方法可以在一定程度上提高神经网络目标检测的准确率,提高网络性能。
实施例仅是一个特例,并不表明本发明就这样一种实现方式。
以上所述仅为本发明的较佳实施例而已,本领域技术人员知悉,在不脱离本发明的精神和范围的情况下,可以对这些特征和实施例进行各种改变或等同替换。另外,在本发明的教导下,可以对这些特征和实施例进行修改以适应具体的情况及材料而不会脱离本发明的精神和范围。因此,本发明不受此处所公开的具体实施例的限制,所有落入本申请的权利要求范围内的实施例都属于本发明的保护范围。
Claims (6)
1.一种基于自蒸馏的目标检测方法,其特征在于,包括:
S1、构建用于目标检测的神经网络模型,选择候选网络层并为所述候选网络层添加适配结构;
S2、通过对待训练样本集进行初始训练,获取每个所述候选网络层的适配结构目标检测结果的均值平均精度和所述待训练样本集中每个样本的错误率;
S3、根据基于引力搜索算法的匹配条件和所述均值平均精度,当前所述候选网络层自动匹配其他所述候选网络层;
S4、根据所述样本的错误率,自动匹配的两个所述候选网络层作为学生层和教师层进行自蒸馏训练,并更新所述候选网络层的均值平均精度和样本错误率;
S5、判断是否完成所有轮次的训练,若是,则得到训练好的目标检测模型;否则,执行步骤S3;
步骤S1之前,需要收集用于目标检测的数据样本,所述数据样本为图片;对所述数据样本进行预处理、标注信息后得到所述待训练样本集;
第n个所述样本的错误率为:;
若所述待训练样本集中的某数据样本含有多个标注目标,只要有未正确检测的所述标注目标,就判定所述数据样本检测错误;
步骤S3包括:
S31、设置所述学生层、教师层自动匹配的间隔轮次和引力阈值;
S32、根据所述均值平均精度计算当前所述候选网络层和其他所述候选网络层的引力值大小;
S33、将其他所述候选网络层中满足所述匹配条件的,匹配给当前所述候选网络层;其中,将自动匹配的两个所述候选网络层中,浅层的作为学生层,深层的作为教师层;
S34、判断训练轮次间隔是否达到所述间隔轮次,若是,则执行步骤S32;否则,执行步骤S4;
所述匹配条件为:在所述引力值大于所述引力阈值的其他所述候选网络层中,匹配相隔层数最小的所述候选网络层;
所述引力值为:,
其中,G为引力参数;mAPi为当前候选网络层i的均值平均精度;mAPj为其他候选网络层j的均值平均精度;△L为所述候选网络层i、候选网络层j之间的间隔层数。
2.根据权利要求1所述的一种基于自蒸馏的目标检测方法,其特征在于,若某一当前所述候选网络层没有满足匹配条件的其他所述候选网络层,则当前所述候选网络层不参加本批次的蒸馏训练。
3.根据权利要求1所述的一种基于自蒸馏的目标检测方法,其特征在于,步骤S4包括:
S41、在当前批次训练中,选取BatchSize个所述样本,通过所述样本的错误率,获取当前批次所述样本的训练难易程度值;
S42、根据所述训练难易程度值,获取温度系数;
S43、所述学生层、教师层通过所述温度系数进行自蒸馏训练,并更新所述候选网络层的均值平均精度和样本错误率。
4.根据权利要求3所述的一种基于自蒸馏的目标检测方法,其特征在于,所述训练难易程度值为:,
其中,为本批次BatchSize个所述样本的所述错误率之和。
5.根据权利要求3所述的一种基于自蒸馏的目标检测方法,其特征在于,所述温度系数为:,
其中,为所述训练难易程度值;/>为预设的温度参数。
6.根据权利要求5所述的一种基于自蒸馏的目标检测方法,其特征在于,进行自蒸馏训练所用的Sigmoid函数为:,
其中,e为自然常数,T为所述温度系数,X为函数自变量。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310658974.8A CN116384439B (zh) | 2023-06-06 | 2023-06-06 | 一种基于自蒸馏的目标检测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310658974.8A CN116384439B (zh) | 2023-06-06 | 2023-06-06 | 一种基于自蒸馏的目标检测方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116384439A CN116384439A (zh) | 2023-07-04 |
CN116384439B true CN116384439B (zh) | 2023-08-25 |
Family
ID=86963756
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310658974.8A Active CN116384439B (zh) | 2023-06-06 | 2023-06-06 | 一种基于自蒸馏的目标检测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116384439B (zh) |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110472730A (zh) * | 2019-08-07 | 2019-11-19 | 交叉信息核心技术研究院(西安)有限公司 | 一种卷积神经网络的自蒸馏训练方法和可伸缩动态预测方法 |
CN115170874A (zh) * | 2022-06-27 | 2022-10-11 | 江苏中科梦兰电子科技有限公司 | 一种基于解耦蒸馏损失的自蒸馏实现方法 |
CN115829029A (zh) * | 2022-09-27 | 2023-03-21 | 江苏中科梦兰电子科技有限公司 | 一种基于通道注意力的自蒸馏实现方法 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
GB2610319A (en) * | 2020-12-17 | 2023-03-01 | Zhejiang Lab | Automatic compression method and platform for multilevel knowledge distillation-based pre-trained language model |
WO2022203729A1 (en) * | 2021-03-26 | 2022-09-29 | Google Llc | Self-adaptive distillation |
-
2023
- 2023-06-06 CN CN202310658974.8A patent/CN116384439B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110472730A (zh) * | 2019-08-07 | 2019-11-19 | 交叉信息核心技术研究院(西安)有限公司 | 一种卷积神经网络的自蒸馏训练方法和可伸缩动态预测方法 |
CN115170874A (zh) * | 2022-06-27 | 2022-10-11 | 江苏中科梦兰电子科技有限公司 | 一种基于解耦蒸馏损失的自蒸馏实现方法 |
CN115829029A (zh) * | 2022-09-27 | 2023-03-21 | 江苏中科梦兰电子科技有限公司 | 一种基于通道注意力的自蒸馏实现方法 |
Non-Patent Citations (1)
Title |
---|
基于 Transformer 的旋转机械故障诊断方法研究;曹丰;《中国优秀硕士学位论文全文数据库 (工程科技Ⅱ辑)》(第02期);C029-214 * |
Also Published As
Publication number | Publication date |
---|---|
CN116384439A (zh) | 2023-07-04 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114241282B (zh) | 一种基于知识蒸馏的边缘设备场景识别方法及装置 | |
CN109447140B (zh) | 一种基于神经网络深度学习的图像识别并推荐认知的方法 | |
CN113326731B (zh) | 一种基于动量网络指导的跨域行人重识别方法 | |
CN108021947B (zh) | 一种基于视觉的分层极限学习机目标识别方法 | |
CN113128620B (zh) | 一种基于层次关系的半监督领域自适应图片分类方法 | |
CN108345866B (zh) | 一种基于深度特征学习的行人再识别方法 | |
CN111967325A (zh) | 一种基于增量优化的无监督跨域行人重识别方法 | |
CN111239137B (zh) | 基于迁移学习与自适应深度卷积神经网络的谷物质量检测方法 | |
CN110991516A (zh) | 一种基于风格迁移的侧扫声呐图像目标分类方法 | |
CN115563327A (zh) | 基于Transformer网络选择性蒸馏的零样本跨模态检索方法 | |
CN111695640A (zh) | 地基云图识别模型训练方法及地基云图识别方法 | |
CN110909158A (zh) | 基于改进萤火虫算法和k近邻的文本分类方法 | |
CN117152503A (zh) | 一种基于伪标签不确定性感知的遥感图像跨域小样本分类方法 | |
CN111126155B (zh) | 一种基于语义约束生成对抗网络的行人再识别方法 | |
CN114357221B (zh) | 一种基于图像分类的自监督主动学习方法 | |
CN115439715A (zh) | 基于反标签学习的半监督少样本图像分类学习方法及系统 | |
CN116824216A (zh) | 一种无源无监督域适应图像分类方法 | |
CN116258990A (zh) | 一种基于跨模态亲和力的小样本参考视频目标分割方法 | |
CN113095229B (zh) | 一种无监督域自适应行人重识别系统及方法 | |
CN117830616A (zh) | 基于渐进式伪标签的遥感图像无监督跨域目标检测方法 | |
CN110533074B (zh) | 一种基于双深度神经网络的图片类别自动标注方法及系统 | |
CN116384439B (zh) | 一种基于自蒸馏的目标检测方法 | |
CN116433909A (zh) | 基于相似度加权多教师网络模型的半监督图像语义分割方法 | |
CN113313178B (zh) | 一种跨域图像示例级主动标注方法 | |
CN113626537B (zh) | 一种面向知识图谱构建的实体关系抽取方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |