CN112464981B - 基于空间注意力机制的自适应知识蒸馏方法 - Google Patents

基于空间注意力机制的自适应知识蒸馏方法 Download PDF

Info

Publication number
CN112464981B
CN112464981B CN202011165181.5A CN202011165181A CN112464981B CN 112464981 B CN112464981 B CN 112464981B CN 202011165181 A CN202011165181 A CN 202011165181A CN 112464981 B CN112464981 B CN 112464981B
Authority
CN
China
Prior art keywords
sample
network
distillation
loss function
loss
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202011165181.5A
Other languages
English (en)
Other versions
CN112464981A (zh
Inventor
王金桥
王童
朱优松
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Sinovision Jurong Technology Co ltd
Original Assignee
Sinovision Jurong Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Sinovision Jurong Technology Co ltd filed Critical Sinovision Jurong Technology Co ltd
Priority to CN202011165181.5A priority Critical patent/CN112464981B/zh
Publication of CN112464981A publication Critical patent/CN112464981A/zh
Application granted granted Critical
Publication of CN112464981B publication Critical patent/CN112464981B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/40Software arrangements specially adapted for pattern recognition, e.g. user interfaces or toolboxes therefor
    • G06F18/41Interactive pattern learning with a human teacher
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/94Hardware or software architectures specially adapted for image or video understanding
    • G06V10/95Hardware or software architectures specially adapted for image or video understanding structured as a network, e.g. client-server architectures

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Software Systems (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Multimedia (AREA)
  • Human Computer Interaction (AREA)
  • Image Analysis (AREA)
  • Image Processing (AREA)

Abstract

本发明公开了一种基于空间注意力机制的自适应知识蒸馏方法,属于计算机视觉和模式识别技术领域。该方法利用每个样本的分类损失值评估样本的难易程度,若分类损失值大于阈值,则该样本较难,给予该样本更多关注;若分类损失值小于阈值,则忽略该样本;然后,使用单调递增函数为每一个较难的样本赋予蒸馏权重值;将样本的蒸馏权重值映射到特征图对应的空间位置,得到和特征图大小相同的注意力图;最后在进行网络蒸馏时利用注意力图对特征图的每个位置进行加权,指导学生网络的学习过程。该方法基于空间注意力机制,能够引导学生网络关注比较困难的样本的特征学习,从而提升学生网络的性能。

Description

基于空间注意力机制的自适应知识蒸馏方法
技术领域
本发明属于计算机视觉和模式识别技术领域,具体涉及基于空间注意力机制的自适应知识蒸馏方法。
背景技术
目标检测是计算机视觉中的一项基础任务,是很多计算机视觉任务的前提。很多实际应用场景都对检测器的模型大小和推理时间提出了严格的要求。因此,如何设计一种小而精的检测器成为了目标检测领域的一个研究热点。
在深度学习时代,目标检测器主要可划分成两大类:双阶段检测器和单阶段检测器。双阶段检测器首先通过RPN(Region Proposal Network)产生候选框,然后利用RoIPooling得到的特征对在第二个阶段产生的候选框进行类别细分和边框微调。相比于双阶段检测器,单阶段的检测器结构更简单。单阶段检测器没有显式产生候选框的过程,而是直接对每一个anchor进行类别预测和边框回归。由于单阶段检测器没有第二个阶段的微调,所以速度上对比双阶段检测器更有优势。
为了进一步减小网络模型大小,提升网络推理速度。研究者提出了很多方法,比如:轻量级网络结构设计,网络剪枝,量化压缩,知识蒸馏等。通过设计轻量级的网络结构,模型的计算复杂度可以得到一定程度的削减,进而减少推理时间。网络剪枝通过裁剪冗余的,不重要的通道来减小模型大小以及推理时间。量化通过将网络的浮点数运算转换为定点数运算来减小网络的计算开销,同时尽可能的保持精度。知识蒸馏通过在训练小容量的学生网络的过程中引入大容量的教师网络的监督,来将教师网络学习到的知识“传授”给学生网络,从而提升学生网络的精度。相比于前三种方法,知识蒸馏不需要对网络的结构和参数进行改变,操作起来比较方便,并且可以和前三种方法结合起来使用。
发明内容
针对现有技术中存在的问题,本发明要解决的技术问题在于提供一种基于空间注意力机制的自适应知识蒸馏方法,该方法基于空间注意力机制,能够引导学生网络关注于比较困难的样本的特征学习,从而提升学生网络的性能。
为了解决上述问题,本发明所采用的技术方案如下:
基于空间注意力机制的自适应知识蒸馏方法,在特征蒸馏的过程中,通过引入空间注意力机制,学生网络会更加关注难度较大的样本对应的特征区域,忽略难度较小的样本对应的特征区域,从而使学生网络的知识蒸馏过程更加高效和有针对性。具体包括以下步骤:
(1)评估样本的难易程度:在训练检测器的过程中,网络会对每一个样本产生分类损失和回归损失,采用样本分类损失值评估样本的难易程度,如果样本分类损失值大于阈值,则该样本进入步骤(2);如果样本分类损失值小于阈值,则学生网络忽略该样本;
(2)构建空间注意力图:使用单调递增函数为每一个进入该步骤的样本赋予一个蒸馏权重值;将每一个样本的蒸馏权重值映射到教师网络特征图对应的空间位置,得到和特征图大小相同的空间注意力图;
(3)完成学习过程:利用空间注意力图对教师网络的特征图的每个位置进行加权,指导学生网络的学习过程。
所述基于空间注意力机制的自适应知识蒸馏方法,所述教师网络为在PASCAL VOC数据集上训练一个VGG16-SSD检测器得到的,在进行知识蒸馏过程之前对教师网络进行初始化,结构上只保留主干网络部分。
所述基于空间注意力机制的自适应知识蒸馏方法,所述学生网络的主干部分选用1/n的VGG16网络,并在学生网络的conv4-3,fc7,conv6-2,conv7-2,conv8-2和conv9-2层上添加检测头、注意力图产生结构以及和教师网络对应层之间的损失函数。
所述基于空间注意力机制的自适应知识蒸馏方法,所述采用如下公式为每一个样本赋予权重值:
wk为样本的蒸馏权重值,lk是样本Sk的分类损失值,α和β为超参数,wmax作为蒸馏权重的上限值。
所述基于空间注意力机制的自适应知识蒸馏方法,学生网络的通道数比教师网络的通道数少,于是使用一个1*1的卷积层加Relu对学生网络的特征层进行升维,使之达到和对应教师网络特征相同的维度。
所述基于空间注意力机制的自适应知识蒸馏方法,采用两个损失函数监督学生网络的学习过程,所述两个损失函数分别是目标检测器正常的检测损失函数和注意力机制引导的蒸馏损失函数,其中目标检测器正常的检测损失函数包括分类损失函数和回归损失函数。
所述基于空间注意力机制的自适应知识蒸馏方法,样本的总体损失函数为:
Ltotal=Ldet1Ldis
检测损失函数的形式为:
Ldet=Lcls2Lreg
注意力机制引导的蒸馏损失函数为一个加权的L2范数损失函数,具体形式为:
其中,Ltotal为样本的总体损失函数,Ldet为检测损失函数,Ldis为注意力机制引导的蒸馏损失函数,Lcls为分类损失函数,Lreg为回归损失函数,λ1和λ2为不同损失函数之间的平衡因子;M是检测器用来预测的特征层的个数,Am是属于第m个特征层的空间注意力图,Tm是教师网络的特征,R(Sm)是经过升维后的学生网络的特征。
所述基于空间注意力机制的自适应知识蒸馏方法,所述步骤(2),阈值为0.01。
有益效果:与现有的技术相比,本发明的优点包括:
(1)本发明提出了一种新的空间注意力机制,能够引导学生网络关注于比较困难的样本的特征学习,从而提升学生网络的性能。
(2)本发明方法可以帮助学生网络从随机初始值开始训练,省去了冗长的ImageNet预训练的过程。
附图说明
图1为基于空间注意力机制的自适应蒸馏方法流程示意图;
图2为不同α和β取值下的权重-损失值变化曲线图。
具体实施方式
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合具体实施例对本发明的具体实施方式做详细的说明。
实施例1
一种基于空间注意力机制的自适应知识蒸馏方法,具体流程如图1所示,该方法以经典的单阶段检测器SSD(“ssd:Single shot multibox detector”)为例介绍其工作原理。由图1可知,整个蒸馏框架由两部分组成,一部分是已经训练完毕的精度比较高的教师网络,另一部分是随机初始化的网络容量比较小的学生网络。该方法在特征蒸馏的过程中,通过引入空间注意力机制,学生网络会更加关注到难度较大的样本对应的特征区域,忽略难度较小的样本对应的特征区域,从而使学生网络知识蒸馏过程更加高效和有针对性。具体包括以下步骤:
(1)在PASCAL VOC数据集上训练一个VGG16-SSD检测器作为教师网络;并对教师网络进行初始化,结构上只保留主干网络部分;学生网络的主干部分选用1/n的VGG16网络(网络结构和原来的结构相同,但是每层的通道数为原网络的1/n),学生网络是一个完整的SSD检测器的结构;由于SSD检测器采取的是多层预测的方式,因此选择被用来进行预测的特征层作为蒸馏层进行特征蒸馏;并在学生网络的conv4-3,fc7,conv6-2,conv7-2,conv8-2和conv9-2层上添加检测头(分类分支+回归分支),注意力图产生结构以及和教师网络对应层之间的损失函数;学生网络的通道数比教师网络的通道数少,于是使用一个1*1的卷积层加Relu对学生网络的特征层进行升维,使之达到和对应教师网络特征相同的维度;
(2)评估样本的难易程度:在训练检测器的过程中,网络会对每一个样本产生分类损失和回归损失,该步骤利用已经得到的分类损失来评估样本的难易程度;如果样本分类损失值大于阈值0.01,说明当前网络对该样本的预测和真实值相差较远,该样本是比较困难的样本,则该样本进入步骤(2);如果样本分类损失值小于阈值0.01,则表明该样本对于当前网络来说比较简单,学生网络忽略该样本;
(3)构建空间注意力图:根据步骤(1)得到的分类损失值,使用单调递增函数为每一个进入该步骤的样本赋予一个蒸馏权重值;将每一个样本的蒸馏权重值映射到教师网络特征图对应的空间位置,得到和特征图大小相同的空间注意力图;采用如下公式为每一个样本赋予权重值:
wk为样本的蒸馏权重值,lk是样本Sk的分类损失值,α和β为超参数,wmax作为蒸馏权重的上限值;α和β为引入的两个超参数,用来控制每个样本的权重;为了避免网络训练初期过大的蒸馏权重导致网络的发散,引入了参数wmax作为蒸馏权重的上限值;当α和β取不同值时,蒸馏权重随着损失函数值的变化曲线图,如图2所示。图2的左图可以看出,当损失值比较大时,曲线对参数α比较敏感。而当损失值比较小时,曲线对参数β比较敏感。通过选取合适的参数α和β,可以控制不同难度样本之间的蒸馏权重差异。
对于若干个样本重合的特征区域中的像素点,其蒸馏权重取这些样本的蒸馏权重里的最大值。数学公式表示如下公式所示:
上式中的ai,j为空间注意力图A在i,j位置上的值,Gi,j是一个样本的集合,其中的每一个样本都包含空间位置i,j,wgk是由式(1)计算出来的样本gk的蒸馏权重;经过式(2),可以得到一个大小为1×1×H×W的空间注意力图,该空间注意力图与对应的特征图大小相同;对于特征图来说,其维度为1×C×H×W,所以需要将单通道的空间注意力图扩展到C个通道,使其维度和特征图的维度相匹配;
(4)完成学习过程:利用空间注意力图对教师网络的特征图的每个位置进行加权,在训练的过程中,教师网络的参数不进行更新,只负责为学生网络提供特征层面的监督信息,只更新学生网络的参数;采用两个损失函数来监督学生网络的训练,一个是目标检测器正常的检测损失函数(分类损失函数+回归损失函数),另一个是注意力机制引导的蒸馏损失函数;
其中检测损失函数的形式为:
Ldet=Lcls2Lreg
注意力机制引导的蒸馏损失函数为一个加权的L2范数损失函数,具体形式为:
样本的总体损失函数为:
Ltotal=Ldet1Ldis
其中,Ltotal为样本的总体损失函数,Ldet为检测损失函数,Ldis为注意力机制引导的蒸馏损失函数,Lcls为分类损失函数,Lreg为回归损失函数,λ1和λ2为不同损失函数之间的平衡因子;M是检测器用来预测的特征层的个数,Am是属于第m个特征层的空间注意力图,Tm是教师网络的特征,R(Sm)是经过升维后的学生网络的特征;该蒸馏损失函数的目的是让学生网络的输出特征去模拟教师网络的输出特征,并且通过产生的注意力图来引导学生网络的学习过程,使之更加关注那些比较难学习的样本对应的特征区域;
(5)网络测试,在测试阶段,用学生网络的骨干部分和检测头进行测试,教师网络的部分在测试阶段被丢弃。

Claims (4)

1.基于空间注意力机制的自适应知识蒸馏方法,其特征在于,在特征蒸馏的过程中,通过引入空间注意力机制,学生网络会更加关注难度较大的样本对应的特征区域,忽略难度较小的样本对应的特征区域,从而使学生网络的知识蒸馏过程更加高效和有针对性;该方法包括以下步骤:
(1)评估样本的难易程度:在训练检测器的过程中,网络会对每一个样本产生分类损失和回归损失,采用样本分类损失值评估样本的难易程度,如果样本分类损失值大于阈值,则该样本进入步骤(2);如果样本分类损失值小于阈值,则忽略该样本;
(2)构建空间注意力图:使用单调递增函数为每一个进入该步骤的样本赋予一个蒸馏权重值;将每一个样本的蒸馏权重值映射到教师网络特征图对应的空间位置,得到和特征图大小相同的空间注意力图;
(3)完成学习过程:利用空间注意力图对教师网络的特征图的每个位置进行加权,指导学生网络的学习过程;
所述教师网络为在PASCAL VOC数据集图像样本上训练一个VGG16-SSD检测器得到,在进行知识蒸馏过程之前对教师网络进行初始化,结构上只保留主干网络部分;
所述学生网络的主干部分选用1/n的VGG16网络,并在学生网络的conv4-3,fc7,conv6-2,conv7-2,conv8-2和conv9-2层上添加检测头、注意力图产生结构以及和教师网络对应层之间的损失函数;
采用如下公式为每一个样本赋予权重值:
wk为样本的蒸馏权重值,lk是样本Sk的分类损失值,α和β为超参数,wmax作为蒸馏权重的上限值;
学生网络的通道数比教师网络的通道数少,于是使用一个1*1的卷积层加Relu对学生网络的特征层进行升维,使之达到和对应教师网络特征相同的维度。
2.根据权利要求1所述基于空间注意力机制的自适应知识蒸馏方法,其特征在于,采用两个损失函数监督学生网络的学习过程,所述两个损失函数分别是目标检测器正常的检测损失函数和注意力机制引导的蒸馏损失函数,其中目标检测器正常的检测损失函数包括分类损失函数和回归损失函数。
3.根据权利要求2所述基于空间注意力机制的自适应知识蒸馏方法,其特征在于,样本的总体损失函数为:
Ltotal=Ldet1Ldis
检测损失函数的形式为:
Ldet=Lcls2Lreg
注意力机制引导的蒸馏损失函数为一个加权的L2范数损失函数,具体为:
其中,Ltotal为样本的总体损失函数,Ldet为检测损失函数,Ldis为注意力机制引导的蒸馏损失函数,Lcls为分类损失函数,Lreg为回归损失函数,λ1和λ2为不同损失函数之间的平衡因子;M是检测器用来预测的特征层的个数,Am是属于第m个特征层的空间注意力图,Tm是教师网络的特征,R(Sm)是经过升维后的学生网络的特征。
4.根据权利要求1所述基于空间注意力机制的自适应知识蒸馏方法,其特征在于,所述步骤(2),阈值为0.01。
CN202011165181.5A 2020-10-27 2020-10-27 基于空间注意力机制的自适应知识蒸馏方法 Active CN112464981B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011165181.5A CN112464981B (zh) 2020-10-27 2020-10-27 基于空间注意力机制的自适应知识蒸馏方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011165181.5A CN112464981B (zh) 2020-10-27 2020-10-27 基于空间注意力机制的自适应知识蒸馏方法

Publications (2)

Publication Number Publication Date
CN112464981A CN112464981A (zh) 2021-03-09
CN112464981B true CN112464981B (zh) 2024-02-06

Family

ID=74834930

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011165181.5A Active CN112464981B (zh) 2020-10-27 2020-10-27 基于空间注意力机制的自适应知识蒸馏方法

Country Status (1)

Country Link
CN (1) CN112464981B (zh)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113989577B (zh) * 2021-12-24 2022-04-05 中科视语(北京)科技有限公司 图像分类方法及装置
CN115294407B (zh) * 2022-09-30 2023-01-03 山东大学 基于预习机制知识蒸馏的模型压缩方法及系统
CN117315516B (zh) * 2023-11-30 2024-02-27 华侨大学 基于多尺度注意力相似化蒸馏的无人机检测方法及装置

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CA3076424A1 (en) * 2019-03-22 2020-09-22 Royal Bank Of Canada System and method for knowledge distillation between neural networks

Non-Patent Citations (4)

* Cited by examiner, † Cited by third party
Title
GAN-Knowledge Distillation for One-Stage Object Detection;Wanwei Wang等;《IEEE Access》;第8卷;60719 - 60727 *
Improving Object Detection with Inverted Attention;Zeyi Huang等;《2020 IEEE Winter Conference on Applications of Computer Vision (WACV)》;1294-1302 *
Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer;Sergey Zagoruyko等;《arXiv》;1-13 *
基于单目视觉夜间前方车辆检测与距离研究;齐春阳;《中国优秀硕士学位论文全文数据库 (工程科技Ⅱ辑)》(第8期);C035-407 *

Also Published As

Publication number Publication date
CN112464981A (zh) 2021-03-09

Similar Documents

Publication Publication Date Title
CN112464981B (zh) 基于空间注意力机制的自适应知识蒸馏方法
Shen et al. Wind speed prediction of unmanned sailboat based on CNN and LSTM hybrid neural network
US5924085A (en) Stochastic encoder/decoder/predictor
CN112150821B (zh) 轻量化车辆检测模型构建方法、系统及装置
US20220351019A1 (en) Adaptive Search Method and Apparatus for Neural Network
CN110472778A (zh) 一种基于Blending集成学习的短期负荷预测方法
CN111860982A (zh) 一种基于vmd-fcm-gru的风电场短期风电功率预测方法
CN111310672A (zh) 基于时序多模型融合建模的视频情感识别方法、装置及介质
CN108764298B (zh) 基于单分类器的电力图像环境影响识别方法
CN110428413B (zh) 一种用于灯诱设备下的草地贪夜蛾成虫图像检测方法
CN112766496B (zh) 基于强化学习的深度学习模型安全性保障压缩方法与装置
CN116562908A (zh) 一种基于双层vmd分解和ssa-lstm的电价预测方法
CN116307211A (zh) 一种风电消纳能力预测及优化方法及系统
CN115424177A (zh) 一种基于增量学习的孪生网络目标跟踪的方法
CN113255873A (zh) 一种聚类天牛群优化方法、系统、计算机设备和存储介质
CN117455855A (zh) 一种基于YOLOv8的轻量级管纱检测模型构建方法
CN117173449A (zh) 基于多尺度detr的航空发动机叶片缺陷检测方法
CN116485021A (zh) 一种煤炭企业技术技能人才人岗匹配预测方法与系统
CN115423091A (zh) 一种条件对抗神经网络训练方法、场景生成方法和系统
CN113807005A (zh) 基于改进fpa-dbn的轴承剩余寿命预测方法
CN110728292A (zh) 一种多任务联合优化下的自适应特征选择算法
CN111259860A (zh) 基于数据自驱动的多阶特征动态融合手语翻译方法
AU2021101713A4 (en) Remote sensing image building recognition method based on activated representational substitution
Wang et al. A Deep Neural Network Optimization Strategy based on Roofline Model
CN117835329B (zh) 车载边缘计算中基于移动性预测的服务迁移方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant