CN112529178A - 一种适用于无预选框检测模型的知识蒸馏方法及系统 - Google Patents

一种适用于无预选框检测模型的知识蒸馏方法及系统 Download PDF

Info

Publication number
CN112529178A
CN112529178A CN202011429812.XA CN202011429812A CN112529178A CN 112529178 A CN112529178 A CN 112529178A CN 202011429812 A CN202011429812 A CN 202011429812A CN 112529178 A CN112529178 A CN 112529178A
Authority
CN
China
Prior art keywords
model
training
teacher
student
student model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202011429812.XA
Other languages
English (en)
Other versions
CN112529178B (zh
Inventor
张瑞琰
安军社
姜秀杰
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
National Space Science Center of CAS
Original Assignee
National Space Science Center of CAS
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by National Space Science Center of CAS filed Critical National Space Science Center of CAS
Priority to CN202011429812.XA priority Critical patent/CN112529178B/zh
Publication of CN112529178A publication Critical patent/CN112529178A/zh
Application granted granted Critical
Publication of CN112529178B publication Critical patent/CN112529178B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种适用于无预选框检测模型的知识蒸馏方法及系统,所述方法包括:分别建立教师模型和学生模型;所述教师模型采用参数固定的大型网络,所述学生模型采用参数可训练的小型网络;对教师模型进行训练得到训练好的教师模型;对学生模型进行预训练得到预训练后的学生模型;通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型。本方法针对以往的需要对整幅特征图进行学习,从而导致网络训练关注点过于分散的现状做出改进,为小模型的训练指明了拟合的方向,对资源受限的硬件部署十分友好,具备较高的实用价值。

Description

一种适用于无预选框检测模型的知识蒸馏方法及系统
技术领域
本发明涉及计算机视觉及光学遥感目标检测领域,尤其涉及一种适用于无预选框检测模型的知识蒸馏方法及系统。
背景技术
在目标检测领域,深度检测模型可根据有无预选框可分为基于预选框模型和无预选框模型。二者相比,无预选框模型的网络结构更为简单,不需要生成大量的预选框,从而降低了检测难度及提高检测速度,成为实际部署中的一个优先考虑方法。即便如此,无预选框模型仍旧具有深度神经网络模型固有的缺点:层数多,参数量大,计算复杂度高,从而难以在资源受限的硬件平台上部署(如移动设施、星载设备等等)。为了进一步缩减无预选框模型的网络规模,本发明采用参数量小计算量少的小检测模型代替原始模型。但小模型无法完美的反映出目标的类别和所处位置,其回归和分类能力都低于原始模型。因此,本发明通过所提基于热点图的知识蒸馏的方法提高小模型的检测性能,使之更好地反映目标特征。
以往的知识蒸馏的方法大多基于图像分类任务展开,对于检测任务的探究往往适用于基于预选框的检测模型,而本发明则提出适用于无预选框检测网络的知识蒸馏方法。这里的原始模型被称作教师模型,小模型则被称作学生模型。将知识蒸馏按照蒸馏位置分类可分为输出层蒸馏和中间层蒸馏,输出层蒸馏由Hint等人提出,将学习分类激活函数层的输入作为暗知识在教师模型和学生模型中传递,并利用学生模型和教师模型的KL散度进行训练。但是这里的输出层蒸馏仅适用于分类网络中的一维向量学习,不适用于无预选框检测中的三维向量学习,因此针对无预选框检测模型需要针对性的设计输出层的学习函数。而对于中间层蒸馏来说,以往的方法针对的是整张特征图的特征模拟,如AT方法将特征图的注意力图作为暗知识,而SP方法则将批量图像产生的激活矩阵作为暗知识。但由于光学遥感图像中目标的稀疏性差距很大、图像背景复杂,学习整张图片的特征往往不能取得优异的效果。综合上述问题可以看出,现有的知识蒸馏方法不完全适用于无预选框检测模型。因此,本发明将充分结合无预选框的网络结构特征,利用无预选框中的热点图作为暗知识,在网络的输出层和中间层均给予直接和间接的指导,进一步提高了小模型的检测精度,实现精度和速度的相对平衡。
发明内容
本发明的目的在于克服现有技术缺陷,提出了一种适用于无预选框检测模型的知识蒸馏方法及系统。
针对现有技术存在的缺陷和不足,主要基于无预选框的中心点检测模型展开,本发明将采用知识蒸馏的方法实现小模型的精度,具体将解决的技术问题为:
(1)在输出层蒸馏训练中,摒弃以往一维变量的学习方法,设计一种适合学习教师模型的三维变量的方法。
(2)在中间层蒸馏训练中,减少学习的总层数为单层。将目标的确切位置考虑在内,该位置由学生模型的真实热点图提供,设计一种将学生模型的注意力集中于显著位置上的学习方法。
为了实现上述目的,一种适用于无预选框检测模型的知识蒸馏方法,所述方法包括:
分别建立教师模型和学生模型;所述教师模型采用参数固定的大型网络,所述学生模型采用参数可训练的小型网络;
对教师模型进行训练得到训练好的教师模型;
对学生模型进行预训练得到预训练后的学生模型;
通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型。
作为上述方法的一种改进,所述教师模型和学生模型均为检测模型,均采用CenterNet网络,包括主干网络、上采样网络和检测分支网络,所述教师模型和学生模型的输入均为图片,输出均为图片检测结果,其中,
所述教师模型的主干网络为Shufflenet或ResNet或MobileNet,其中ResNet为18层,中间特征图的最大通道数为512;MobileNet的扩张系数为6;
所述学生模型的主干网络为Shufflenet或ResNet或MobileNet,其中ResNet为8层,中间特征图的最大通道数为256;MobileNet的扩张系数为3,最大通道数是教师模型MobileNet最大通道数的1/2的MobileNet。
作为上述方法的一种改进,所述对教师模式进行训练得到训练好的教师模型;具体包括:
构建训练集;
将训练集中的图片依次输入教师模型,采用损失函数L进行训练:
L=Lcls+λLwh+Loff
其中,Lcls为定位损失函数,Lwh为回归损失函数,Lreg为中心偏移损失函数,λ为调节系数,设置为0.1;
采用Adam作为训练优化器,设置初始学习率为1.25e-4,并在训练的第K次和第L次分别衰减学习率10倍,直至得到训练好的教师模型,其中K小于L。
作为上述方法的一种改进,所述构建训练集具体包括:
选取复杂背景航天遥感目标检测公开数据集NWPU VHR-10和通用数据集DOTAv1.0中有标注信息的图片作为数据集;
对数据集中的图片进行裁剪处理,裁剪后图片尺寸为640×640,并且每两张图片有140个像素的重叠区域;
检测裁剪后的图片,如果包含中心点,则保留目标框并调整标注的长宽;如果不包含中心点,则抛弃该目标框;
对检测后的图片进行数据增强操作,包括随机左右翻转,上下翻转及比例放缩,得到大小为512×512的图像,构成训练集。
作为上述方法的一种改进,所述对学生模型进行预训练得到预训练后的学生模型;具体包括:
将训练集中的图片依次输入学生模型,采用损失函数L进行训练:
L=Lcls+λLwh+Loff
其中,λ设置为0.1;
采用Adam作为训练优化器,设置初始学习率为1.25e-4,并在训练到第K次和第L次分别衰减学习率10倍,直至得到预训练好的学生模型。
作为上述方法的一种改进,所述通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型;具体包括:
采用预训练好的学生模型的参数值作为初始值,设置初始学习率为6.25e-4;
将训练集中的图片依次输入学生模型,并在训练到第P次和第Q次分别衰减学习率10倍,通过中间层蒸馏和输出层蒸馏对学生模型的参数进行调整,得到一次重训练后的学生模型;
将训练集中的图片依次输入一次重训练后的学生模型,并在训练到第P次和第Q次分别衰减学习率10倍,通过中间层蒸馏和输出层蒸馏对学生模型的参数进行调整,得到重训练好的的学生模型。
作为上述方法的一种改进,所述中间层蒸馏具体包括:
通过1x1卷积层Conv(·)和PRelu激活函数使得训练好的教师模型的通道自适应匹配预训练后的学生模型的通道
Figure BDA0002826243480000041
Figure BDA0002826243480000042
其中,PRelu为激活函数,
Figure BDA0002826243480000043
为教师模型第l层的未经过激活函数处理的教师特征图,下角标t代表教师模型,θ代表学生模型中的所有参数;
在教师模型和学生模型输出的特征图的第二维度上均采取L2正则化的方法:
Figure BDA0002826243480000044
其中,
Figure BDA0002826243480000045
为学生模型第l层的未经过激活函数处理的学生特征图,下角标s代表学生模型;
选取掩膜图,对掩膜图和教师模型的预测定位热点图进行求和操作,再采用自适应池化和平均池化处理,得到注意力系数图T*
Figure BDA0002826243480000046
其中,Tk为第k类目标的教师模型的预测定位热点图,C表示类别总数,Uk为第k类目标的学生方形掩膜图,adaptive_pool(·)表示自适应池化,avg_pool(·)表示平均池化;
蒸馏函数LMFD为:
Figure BDA0002826243480000047
其中,ρr为显著系数图。
作为上述方法的一种改进,所述输出层蒸馏具体包括:
将教师模型生成的掩膜作为软目标,对于正样本,采用交叉熵的方法计算损失函数;对于负样本,采用教师模型负样本点的激活值引导学生模型的负样本,为学生模型提供先验掩膜形状,定位损失函数Lcls_d为:
Figure BDA0002826243480000051
其中,N为所有目标类别的正样本总数,S(·)为输出的学生模型的定位预测热点图,T(·)为输出的教师模型的定位预测热点图,Pi,j,k为预测定位热点图的任一像素点,i和j为热点图的宽和高二维坐标索引,k为目标的类别,ρ为自定义指数,取值为2,Y为负样本占比的调节系数,用于调节学生和教师输出的分布匹配度,调节收敛速度,ω取2,采用差值二次方来降低位于目标内部的负样本的影响力;
将教师模型的回归预测热点图作为差错上界,当学生模型的预测差错大于差错上界时,采用硬标签的训练结果,否则,令回归损失函数Lwh_d的结果为0;
计算损失函数L为:
L=Lcls_d+λLwh_d+Loff
其中,λ为调节系数,Loff为中心偏移损失,中心偏移损失采用检测模型的原始训练方式。
一种适用于无预选框检测模型的知识蒸馏系统,其特征在于,所述系统包括:教师模型、学生模型、教师模型训练模块、学生模型预训练模块和学生模型重训练模块;其中,
所述教师模型采用参数固定的大型网络,所述学生模型采用参数可训练的小型网络;
所述教师模型训练模块,用于对教师模型进行训练得到训练好的教师模型;
所述学生模型预训练模块,用于对学生模型进行预训练得到预训练后的学生模型;
学生模型重训练模块,用于通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型。
与现有技术相比,本发明的优势在于:
1、本发明提出了基于中心点检测模型的知识蒸馏方法,针对以往的需要对整幅特征图进行学习,从而导致网络训练关注点过于分散的现状做出改进,为小模型的训练指明了拟合的方向;
2、本发明在输出层蒸馏训练中,提出了一种定位损失函数,令学生模型专注于学习教师模型关于负样本的推理结果,在输出层拟合目标的大致轮廓范围,放宽对学生模型的定位标准;
3、本发明在中间层蒸馏训练中,提出一种显著热点图损失函数,利用学生模型的自定义真实热点图划出包裹目标的外围框图,在训练学生模型时着重训练所划框图内部的数据,达到针对性训练;
4、本发明提出的方法在参数量裁剪到97%时,精度只降低1.5%mAP,实现有效提升小模型的设计目的,对资源受限的硬件部署十分友好,具备较高的实用价值。
附图说明
图1是本发明适用于无预选框检测模型的知识蒸馏方法基于热点图的总体蒸馏框架示意图;
图2是本发明适用于无预选框检测模型的知识蒸馏方法的流程图;
图3是本发明输出层蒸馏的流程图;
图4是本发明中间层蒸馏的流程图;
图5是本发明学生模型的中心点检测网络的结构示意图。
具体实施方式
本发明提供了一种适用于无预选框检测模型的知识蒸馏方法,根据教师模型对学生模型进行重训练进而获得训练好的学生模型,总体技术路线为:
分别建立教师模型和学生模型;所述教师模型采用参数固定的大型网络,所述学生模型采用参数可训练的小型网络;
对教师模型进行训练得到训练好的教师模型;
对学生模型进行预训练得到预训练后的学生模型;
通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型。
核心设计包括以下内容:
1、总体蒸馏框架设置
总体蒸馏的网络结构如图1所示,在学生模型中增加中间层蒸馏训练和输出层蒸馏训练。其中中间层蒸馏训练的位置放置在主干网络(下采样)和上采样网络之间,输出层蒸馏训练放置在定位检测分支网络的输出层处。
2、基于输出层蒸馏训练模块的检测模型损失函数设计
(1)输出层蒸馏训练的总体损失函数设计。
学生模型的定位损失函数Lcls_d和回归损失函数Lwh_d采用蒸馏的方法得到,中心偏移损失Loff则保持原网络训练方式。定位损失函数则是将蒸馏融入了原始损失函数,而不是在原始损失函数上添加正则惩罚项。
L=Lcls_d+λLwh_d+Loff
(2)输出层蒸馏训练的定位损失函数Lcls_d设计
原始定位方式是在目标中心点人工设置高斯掩膜,代表目标中心点周围像素点作为负样本对损失函数的影响程度,该掩膜还可以指导网络收敛方向。但这种人工设置的方式不够自主化,为此本发明将教师模型生成的掩膜作为软目标,直接利用教师模型负样本点的激活值引导学生模型的负样本,为学生模型提供先验掩膜形状,调节学生模型对负样本的学习程度,进一步可加强学生模型的收敛能力。这里采用如下式所示的定位损失函数。
Figure BDA0002826243480000071
这里N为所有目标类别的正样本总数,S(·)和T(·)分别为输出的学生模型和教师模型的定位预测热点图,ρ为指数,取值为2,Y为负样本占比的调节系数,用于调节学生和教师输出的分布匹配度,调节收敛速度。ω取2,Pi,j,k为预测定位热点图的任一像素点,i和j为热点图的宽和高二维坐标索引,k为目标的类别,采用差值二次方来降低位于目标内部的负样本的影响力;正样本仍采用交叉熵的方法计算损失函数,对于负样本来说,用教师模型的定位预测热点图指导学生。
(3)输出层蒸馏训练模块的回归损失函数Lwh_d设计。
回归损失只是将教师模型的标签作为差错上界,当差错小于这个上界时,采用真实标签的训练结果,当差错大于这个上界时,令损失函数损失函数的结果为0。
3、基于中间层蒸馏训练模块的检测模型损失函数设计
(1)学习的位置确定
这里只学习教师模型的主干网络的最后一个卷积层的输出特征图的信息。区别于以往学习多个层的方法,可降低训练资源。
f(x,θ)为一个包含Conv,BN和ReLU层的前馈卷积神经网络,x包含各层输入图片,θ代表模型中的所有参数。令该网络的第l层的输入为xl-1,则该层fl(xl-1l)的运算定义为:
Ml=BN(conv(Xl-1l)),
Xl=fl(xl-1l)=relu(Ml),
其中relu激活函数将小于0的数据全部舍弃,信息损失较大。教师模型中被裁掉的小于0的信息可能对学生模型有帮助,因此本发明对教师和学生模型的未经过激活函数的特征图M进行迁移处理。
(2)中间层蒸馏流程
1)首先解决通道未对齐问题。添加一个含有bias的1x1卷积层Conv(·)让教师模型的通道自适应匹配学生通道,卷积层后接入PRelu激活函数,既保留下负激活值点,又相应起到抑制作用。该过程定义为
Figure BDA0002826243480000081
其中,
Figure BDA0002826243480000082
为教师模型第l层的未经过激活函数处理的教师特征图。采用Prelu的原因是在于传递更多的激活信息。一个通道常融合多个特征,卷积层即使保证通道间大部分特征匹配,学生模型仍存在教师模型不包含的特征,这些特征在教师特征图中的激活值可能为负。
2)随后,在教师模型和学生模型的特征图的第二维度上均采取L2正则化的方法处理。
Figure BDA0002826243480000083
为学生模型第l层的未经过激活函数处理的学生特征图,
Figure BDA0002826243480000084
3)提出了像素重要性分配蒸馏法。本发明提供了一种方形掩膜的设设计方案,其宽高为目标宽高的一定倍数,为了覆盖目标周围的局部信息,这里的倍数设定为1.1。掩膜设计为目标中心点取1,其余点均取0.9。这里将生成的方形掩膜图记为U,U与热点图尺度一致,将U的各通道相加,以叠加各类目标的掩膜图。同时引入教师的预测定位热点图T。T的不同通道代表不同类别,共有C类,各通道的值均在范围(0,1)内。鉴于在教师预测定位热点图T中,某类目标也会出现在其他类的热点图上,因此不采用叠加的方式,而采用各通道求平均,以融合各类目标。接着对掩膜图和教师热点图进行求和操作,得到注意力系数图(限定其最大值为1)。
再对其自适应池化操作,变换后与学生网络的尺度相同。之后我们引入平滑处理模块,即采用平均池化处理,步长设置为1,感受野设置为3x3,保持图片尺度不变。该平均处理模块可以将数据分布由陡峭变为平缓,集中的数据可以相对扩散开来,使得注意力系数图中的数值变化更为平滑。该过程定义为:
Figure BDA0002826243480000091
之后采用指数函数求得最终的重要性系数值ρr。最终中间层的蒸馏函数如下式:
Figure BDA0002826243480000092
下面结合附图和实施例对本发明的技术方案进行详细的说明。
实施例1
如图2所示,本发明的实施例1提出了一种适用于无预选框检测模型的知识蒸馏方法,首先要对所选的数据集进行分割操作,减小单张图片的尺寸;然后将数据集划分为训练集、测试集和验证集;之后选取模型复杂、检测精度较高的教师模型和模型相对简单的学生模型,并对二者进行训练;然后将学生模型的输出层的两个损失函数(分类和定位函数)替换为本发明所提的算法,而中心偏移损失函数保持不变,并在学生模型对应的中间层添加中间层蒸馏训练模块;分类损失函数处,若输出热点图落在目标的中心点,则视为正样本,采用图中所示的正样本损失函数,否则则采用负样本损失函数,如图3所示;中间层蒸馏模型处,学生既要接收来自教师模型对应的中间层特征图,如图4所示,又要将教师模型的输出层预测图和自身的自定义真实热点图相结合得到显著图系数,并按照最终的中间层损失函数进行训练;最后,训练结束后,对重训练后的学生模型进行推理,即可得到其最终检测效果。
具体实施方法包含以下步骤:
1、光学遥感训练数据集及测试数据集的选取。
(1)数据集选取为复杂背景航天遥感目标检测公开数据集NWPU VHR-10和通用数据集DOTAv1.0。NWPU VHR-10有标注信息的图像650张,背景信息图像150张,包含十类目标。DOTAv1.0数据集包含2806张标注信息的光学遥感图像,共15个类别;
(2)对数据集进行裁剪处理。设定裁剪图片尺寸为图片裁剪为640×640,并保证两张图片有140个像素的重叠区域。在裁剪时,检测物体中心点是否在所得图像内,若包含中心点,则保留目标框并调整标注的长宽;若不包含中心点,则抛弃该目标框;
(3)测试集和训练集的选取。对NWPU VHR-10数据集来说,本文采用数据集中含标注信息的650张图片进行训练和测试。处理后共获得1743张图像,取其中的60%作为训练集(1045张),20%作为验证集(349张),剩余20%为测试集(349张)。对DOTA v1.0数据集来说,对其进行同样大小的裁剪,共得到33892张图片,随机取数据集的1/2为训练集(16946张),1/6为验证集(5649张),1/3为测试集(11297张);
(4)对图片进行数据增强操作,包括随机左右翻转,上下翻转及比例比例放缩等。最终输入网络的图像大小为512×512。
2、教师模型和学生模型的基准检测模型的选取和训练/测试
(1)教师模型和学生模型主要基于CenterNet网络实现,网络结构如图5所示,其分为三部分,主干网络(左侧)、上采样网络(右上)和检测分支网络(右下)。其中检测分支网络分为三类:定位检测分支网络、宽高检测分支网络和偏移检测分支网络。中心点网络将目标视作点,需要生成真实热点图作为最终的图像训练标签。
(2)对于主干网络,采用卷积层部分作为检测模型的主干网络。教师模型的主干网络为Shufflenet或ResNet或MobileNet,其中ResNet为18层,中间特征图的最大通道数为512;MobileNet的扩张系数为6;
学生模型的主干网络为Shufflenet或ResNet或MobileNet,其中ResNet为8层,中间特征图的最大通道数为256;MobileNet的扩张系数为3,最大通道数是教师模型MobileNet最大通道数的1/2的MobileNet。
(3)基准检测模型的损失函数如下式。其中Lcls为定位损失函数,Lwh为回归损失函数,以及Lreg为中心偏移损失函数。λ为调节系数,默认为0.1。这里的学生和教师的基准模型均采用该损失函数训练。
L=Lcls+λLwh+Loff
(2)对教师模型和学生模型的初始精度训练。二者均训练280次,初始学习率为1.25e-4,并在训练的第140次和240次衰减学习率十倍。训练优化器采用Adam。训练及测试的硬件平台为The GTX 1080 8G GPU和i7-7700K 4.20Ghz CPU。以此得到的模型检测精度为以后实验比较的基准值。
3、基于知识蒸馏的学生模型的训练设置及训练结果
(1)训练参数设置
本发明采用训练完备的基准学生模型的参数值来作为模型的初始值。初始学习率设置为6.25e-4,并采用重复两次训练的方法提高学生模型的检测精度。两次均训练170次,分别在第80次和140次处衰减学习率十倍。
(2)蒸馏结果
本发明的最终蒸馏结果如表1所示。其中包含了仅采用输出层蒸馏的效果,输出层加中间层蒸馏(显著系数图ρr=1)以及输出层加中间层蒸馏的最终效果。可以看出,学生模型的参数量十分少,而经过本发明的处理,可以明显的提高小模型的检测精度,缩短其与教师模型的差异。
表1本发明的蒸馏效果
Figure BDA0002826243480000111
实施例2
本发明的实施例2提出了一种适用于无预选框检测模型的知识蒸馏系统,该系统包括:教师模型、学生模型、教师模型训练模块、学生模型预训练模块和学生模型重训练模块;其中,
所述教师模型采用参数固定的大型网络,所述学生模型采用参数可训练的小型网络;
所述教师模型训练模块,用于对教师模型进行训练得到训练好的教师模型;
所述学生模型预训练模块,用于对学生模型进行预训练得到预训练后的学生模型;
学生模型重训练模块,用于通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型。
本发明提出了基于中心点检测模型的知识蒸馏方法。针对以往的需要对整幅特征图进行学习,从而导致网络训练关注点过于分散的现状做出改进,为小模型的训练指明了拟合的方向。最后经知识蒸馏方法重训练的小模型性能得到极大提高,以检测性能为目标时,参数量为2.09M的模型可以达到94.60%mAP的检测精度;而以压缩率为目标时,检测精度为91.89%mAP的模型参数量仅为0.45M,比原始网络的75.19%mAP的检测精度提高了16.7%mAP。这种参数量小的模型对资源受限的硬件部署十分友好,具备较高的实用价值。
最后所应说明的是,以上实施例仅用以说明本发明的技术方案而非限制。尽管参照实施例对本发明进行了详细说明,本领域的普通技术人员应当理解,对本发明的技术方案进行修改或者等同替换,都不脱离本发明技术方案的精神和范围,其均应涵盖在本发明的权利要求范围当中。

Claims (9)

1.一种适用于无预选框检测模型的知识蒸馏方法,所述方法包括:
分别建立教师模型和学生模型;所述教师模型采用参数固定的大型网络,所述学生模型采用参数可训练的小型网络;
对教师模型进行训练得到训练好的教师模型;
对学生模型进行预训练得到预训练后的学生模型;
通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型。
2.根据权利要求1所述的适用于无预选框检测模型的知识蒸馏方法,其特征在于,所述教师模型和学生模型均为检测模型,均采用CenterNet网络,包括主干网络、上采样网络和检测分支网络,所述教师模型和学生模型的输入均为图片,输出均为图片检测结果,其中,
所述教师模型的主干网络为Shufflenet或ResNet或MobileNet,其中ResNet为18层,中间特征图的最大通道数为512;MobileNet的扩张系数为6;
所述学生模型的主干网络为Shufflenet或ResNet或MobileNet,其中ResNet为8层,中间特征图的最大通道数为256;MobileNet的扩张系数为3,最大通道数是教师模型MobileNet最大通道数的1/2的MobileNet。
3.根据权利要求1所述的适用于无预选框检测模型的知识蒸馏方法,其特征在于,所述对教师模式进行训练得到训练好的教师模型;具体包括:
构建训练集;
将训练集中的图片依次输入教师模型,采用损失函数L进行训练:
L=Lcls+λLwh+Loff
其中,Lcls为定位损失函数,Lwh为回归损失函数,Lreg为中心偏移损失函数,λ为调节系数,设置为0.1;
采用Adam作为训练优化器,设置初始学习率为1.25e-4,并在训练的第K次和第L次分别衰减学习率10倍,直至得到训练好的教师模型,其中K小于L。
4.根据权利要求3所述的适用于无预选框检测模型的知识蒸馏方法,其特征在于,所述构建训练集具体包括:
选取复杂背景航天遥感目标检测公开数据集NWPU VHR-10和通用数据集DOTAv1.0中有标注信息的图片作为数据集;
对数据集中的图片进行裁剪处理,裁剪后图片尺寸为640×640,并且每两张图片有140个像素的重叠区域;
检测裁剪后的图片,如果包含中心点,则保留目标框并调整标注的长宽;如果不包含中心点,则抛弃该目标框;
对检测后的图片进行数据增强操作,包括随机左右翻转,上下翻转及比例放缩,得到大小为512×512的图像,构成训练集。
5.根据权利要求4所述的适用于无预选框检测模型的知识蒸馏方法,其特征在于,所述对学生模型进行预训练得到预训练后的学生模型;具体包括:
将训练集中的图片依次输入学生模型,采用损失函数L进行训练:
L=Lcls+λLwh+Loff
其中,λ设置为0.1;
采用Adam作为训练优化器,设置初始学习率为1.25e-4,并在训练到第K次和第L次分别衰减学习率10倍,直至得到预训练好的学生模型。
6.根据权利要求1所述的适用于无预选框检测模型的知识蒸馏方法,其特征在于,所述通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型;具体包括:
采用预训练好的学生模型的参数值作为初始值,设置初始学习率为6.25e-4;
将训练集中的图片依次输入学生模型,并在训练到第P次和第Q次分别衰减学习率10倍,通过中间层蒸馏和输出层蒸馏对学生模型的参数进行调整,得到一次重训练后的学生模型;
将训练集中的图片依次输入一次重训练后的学生模型,并在训练到第P次和第Q次分别衰减学习率10倍,通过中间层蒸馏和输出层蒸馏对学生模型的参数进行调整,得到重训练好的的学生模型。
7.根据权利要求6所述的适用于无预选框检测模型的知识蒸馏方法,其特征在于,所述中间层蒸馏具体包括:
通过1x1卷积层Conv(·)和PRelu激活函数使得训练好的教师模型的通道自适应匹配预训练后的学生模型的通道
Figure FDA0002826243470000021
Figure FDA0002826243470000022
其中,PRelu为激活函数,
Figure FDA0002826243470000023
为教师模型第l层的未经过激活函数处理的教师特征图,下角标t代表教师模型,θ代表学生模型中的所有参数;
在教师模型和学生模型输出的特征图的第二维度上均采取L2正则化的方法:
Figure FDA0002826243470000031
其中,
Figure FDA0002826243470000032
为学生模型第l层的未经过激活函数处理的学生特征图,下角标s代表学生模型;
选取掩膜图,对掩膜图和教师模型的预测定位热点图进行求和操作,再采用自适应池化和平均池化处理,得到注意力系数图T*
Figure FDA0002826243470000033
其中,Tk为第k类目标的教师模型的预测定位热点图,C表示类别总数,Uk为第k类目标的学生方形掩膜图,adaptive_pool(·)表示自适应池化,avg_pool(·)表示平均池化;
蒸馏函数LMFD为:
Figure FDA0002826243470000034
其中,ρr为显著系数图。
8.根据权利要求6所述的适用于无预选框检测模型的知识蒸馏方法,其特征在于,所述输出层蒸馏具体包括:
将教师模型生成的掩膜作为软目标,对于正样本,采用交叉熵的方法计算损失函数;对于负样本,采用教师模型负样本点的激活值引导学生模型的负样本,为学生模型提供先验掩膜形状,定位损失函数Lcls_d为:
Figure FDA0002826243470000035
其中,N为所有目标类别的正样本总数,S(·)为输出的学生模型的定位预测热点图,T(·)为输出的教师模型的定位预测热点图,Pi,j,k为预测定位热点图的任一像素点,i和j为热点图的宽和高二维坐标索引,k为目标的类别,ρ为自定义指数,取值为2,Y为负样本占比的调节系数,用于调节学生和教师输出的分布匹配度,调节收敛速度,ω取2,采用差值二次方来降低位于目标内部的负样本的影响力;
将教师模型的回归预测热点图作为差错上界,当学生模型的预测差错大于差错上界时,采用硬标签的训练结果,否则,令回归损失函数Lwh_d的结果为0;
计算损失函数L为:
L=Lcls_d+λLwh_d+Loff
其中,λ为调节系数,Loff为中心偏移损失,中心偏移损失采用检测模型的原始训练方式。
9.一种适用于无预选框检测模型的知识蒸馏系统,其特征在于,所述系统包括:教师模型、学生模型、教师模型训练模块、学生模型预训练模块和学生模型重训练模块;其中,
所述教师模型采用参数固定的大型网络,所述学生模型采用参数可训练的小型网络;
所述教师模型训练模块,用于对教师模型进行训练得到训练好的教师模型;
所述学生模型预训练模块,用于对学生模型进行预训练得到预训练后的学生模型;
学生模型重训练模块,用于通过知识蒸馏方法对预训练后的学生模型通过中间层蒸馏和输出层蒸馏进行重训练,得到训练好的学生模型。
CN202011429812.XA 2020-12-09 2020-12-09 一种适用于无预选框检测模型的知识蒸馏方法及系统 Active CN112529178B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011429812.XA CN112529178B (zh) 2020-12-09 2020-12-09 一种适用于无预选框检测模型的知识蒸馏方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011429812.XA CN112529178B (zh) 2020-12-09 2020-12-09 一种适用于无预选框检测模型的知识蒸馏方法及系统

Publications (2)

Publication Number Publication Date
CN112529178A true CN112529178A (zh) 2021-03-19
CN112529178B CN112529178B (zh) 2024-04-09

Family

ID=74998580

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011429812.XA Active CN112529178B (zh) 2020-12-09 2020-12-09 一种适用于无预选框检测模型的知识蒸馏方法及系统

Country Status (1)

Country Link
CN (1) CN112529178B (zh)

Cited By (17)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112949766A (zh) * 2021-04-07 2021-06-11 成都数之联科技有限公司 目标区域检测模型训练方法及系统及装置及介质
CN113255899A (zh) * 2021-06-17 2021-08-13 之江实验室 一种通道自关联的知识蒸馏方法与系统
CN113361710A (zh) * 2021-06-29 2021-09-07 北京百度网讯科技有限公司 学生模型训练方法、图片处理方法、装置及电子设备
CN113610126A (zh) * 2021-07-23 2021-11-05 武汉工程大学 基于多目标检测模型无标签的知识蒸馏方法及存储介质
CN113657483A (zh) * 2021-08-14 2021-11-16 北京百度网讯科技有限公司 模型训练方法、目标检测方法、装置、设备以及存储介质
CN113743514A (zh) * 2021-09-08 2021-12-03 庆阳瑞华能源有限公司 一种基于知识蒸馏的目标检测方法及目标检测终端
CN113744220A (zh) * 2021-08-25 2021-12-03 中国科学院国家空间科学中心 一种基于pynq的无预选框检测系统
CN114155436A (zh) * 2021-12-06 2022-03-08 大连理工大学 长尾分布的遥感图像目标识别逐步蒸馏学习方法
CN114241285A (zh) * 2021-11-25 2022-03-25 华南理工大学 一种基于知识蒸馏和半监督学习的船舶快速检测方法
CN115640809A (zh) * 2022-12-26 2023-01-24 湖南师范大学 一种基于正向引导知识蒸馏的文档级关系抽取方法
CN115965964A (zh) * 2023-01-29 2023-04-14 中国农业大学 一种鸡蛋新鲜度识别方法、系统及设备
CN116071625A (zh) * 2023-03-07 2023-05-05 北京百度网讯科技有限公司 深度学习模型的训练方法、目标检测方法及装置
CN116486285A (zh) * 2023-03-15 2023-07-25 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN116612379A (zh) * 2023-05-30 2023-08-18 中国海洋大学 一种基于多知识蒸馏的水下目标检测方法及系统
CN117521848A (zh) * 2023-11-10 2024-02-06 中国科学院空天信息创新研究院 面向资源受限场景的遥感基础模型轻量化方法、装置
CN117542085A (zh) * 2024-01-10 2024-02-09 湖南工商大学 基于知识蒸馏的园区场景行人检测方法、装置及设备
CN114155436B (zh) * 2021-12-06 2024-05-24 大连理工大学 长尾分布的遥感图像目标识别逐步蒸馏学习方法

Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108764462A (zh) * 2018-05-29 2018-11-06 成都视观天下科技有限公司 一种基于知识蒸馏的卷积神经网络优化方法
US20190122077A1 (en) * 2016-03-15 2019-04-25 Impra Europe S.A.S. Method for classification of unique/rare cases by reinforcement learning in neural networks
CN110443784A (zh) * 2019-07-11 2019-11-12 中国科学院大学 一种有效的显著性预测模型方法
CN110472730A (zh) * 2019-08-07 2019-11-19 交叉信息核心技术研究院(西安)有限公司 一种卷积神经网络的自蒸馏训练方法和可伸缩动态预测方法
CN110874634A (zh) * 2018-08-31 2020-03-10 阿里巴巴集团控股有限公司 神经网络的优化方法及装置、设备和存储介质
CN111275192A (zh) * 2020-02-28 2020-06-12 交叉信息核心技术研究院(西安)有限公司 一种同时提高神经网络精确度和鲁棒性的辅助训练方法
CN111626330A (zh) * 2020-04-23 2020-09-04 南京邮电大学 基于多尺度特征图重构和知识蒸馏的目标检测方法与系统
CN111680600A (zh) * 2020-05-29 2020-09-18 北京百度网讯科技有限公司 人脸识别模型处理方法、装置、设备和存储介质
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台
US20200364542A1 (en) * 2019-05-16 2020-11-19 Salesforce.Com, Inc. Private deep learning

Patent Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190122077A1 (en) * 2016-03-15 2019-04-25 Impra Europe S.A.S. Method for classification of unique/rare cases by reinforcement learning in neural networks
CN108764462A (zh) * 2018-05-29 2018-11-06 成都视观天下科技有限公司 一种基于知识蒸馏的卷积神经网络优化方法
CN110874634A (zh) * 2018-08-31 2020-03-10 阿里巴巴集团控股有限公司 神经网络的优化方法及装置、设备和存储介质
US20200364542A1 (en) * 2019-05-16 2020-11-19 Salesforce.Com, Inc. Private deep learning
CN110443784A (zh) * 2019-07-11 2019-11-12 中国科学院大学 一种有效的显著性预测模型方法
CN110472730A (zh) * 2019-08-07 2019-11-19 交叉信息核心技术研究院(西安)有限公司 一种卷积神经网络的自蒸馏训练方法和可伸缩动态预测方法
CN111275192A (zh) * 2020-02-28 2020-06-12 交叉信息核心技术研究院(西安)有限公司 一种同时提高神经网络精确度和鲁棒性的辅助训练方法
CN111626330A (zh) * 2020-04-23 2020-09-04 南京邮电大学 基于多尺度特征图重构和知识蒸馏的目标检测方法与系统
CN111680600A (zh) * 2020-05-29 2020-09-18 北京百度网讯科技有限公司 人脸识别模型处理方法、装置、设备和存储介质
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
J. YU 等: "Mobile Centernet for EmbeddedDeep Learning Object Detection", 《2020 IEEE INTERNATIONAL CONFERENCE ON MULTIMEDIA & EXPO WORKSHOPS (ICMEW), LONDON, UK, 2020》, 9 June 2020 (2020-06-09), pages 1 - 6 *
X. ZHOU 等: "Objects as Points", 《 ARXIV:1904.07850》, 31 December 2019 (2019-12-31), pages 1 - 12 *
张瑞琰 等: "面向光学遥感目标的全局上下文检测模型设计", 《中国光学》, vol. 16, no. 6, 22 October 2020 (2020-10-22), pages 1302 - 1313 *

Cited By (28)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112949766A (zh) * 2021-04-07 2021-06-11 成都数之联科技有限公司 目标区域检测模型训练方法及系统及装置及介质
CN113255899A (zh) * 2021-06-17 2021-08-13 之江实验室 一种通道自关联的知识蒸馏方法与系统
CN113255899B (zh) * 2021-06-17 2021-10-12 之江实验室 一种通道自关联的知识蒸馏方法与系统
CN113361710A (zh) * 2021-06-29 2021-09-07 北京百度网讯科技有限公司 学生模型训练方法、图片处理方法、装置及电子设备
CN113361710B (zh) * 2021-06-29 2023-11-24 北京百度网讯科技有限公司 学生模型训练方法、图片处理方法、装置及电子设备
CN113610126A (zh) * 2021-07-23 2021-11-05 武汉工程大学 基于多目标检测模型无标签的知识蒸馏方法及存储介质
CN113610126B (zh) * 2021-07-23 2023-12-05 武汉工程大学 基于多目标检测模型无标签的知识蒸馏方法及存储介质
CN113657483A (zh) * 2021-08-14 2021-11-16 北京百度网讯科技有限公司 模型训练方法、目标检测方法、装置、设备以及存储介质
CN113744220A (zh) * 2021-08-25 2021-12-03 中国科学院国家空间科学中心 一种基于pynq的无预选框检测系统
CN113744220B (zh) * 2021-08-25 2024-03-26 中国科学院国家空间科学中心 一种基于pynq的无预选框检测系统
CN113743514A (zh) * 2021-09-08 2021-12-03 庆阳瑞华能源有限公司 一种基于知识蒸馏的目标检测方法及目标检测终端
CN114241285B (zh) * 2021-11-25 2024-05-28 华南理工大学 一种基于知识蒸馏和半监督学习的船舶快速检测方法
CN114241285A (zh) * 2021-11-25 2022-03-25 华南理工大学 一种基于知识蒸馏和半监督学习的船舶快速检测方法
CN114155436A (zh) * 2021-12-06 2022-03-08 大连理工大学 长尾分布的遥感图像目标识别逐步蒸馏学习方法
CN114155436B (zh) * 2021-12-06 2024-05-24 大连理工大学 长尾分布的遥感图像目标识别逐步蒸馏学习方法
CN115640809B (zh) * 2022-12-26 2023-03-28 湖南师范大学 一种基于正向引导知识蒸馏的文档级关系抽取方法
CN115640809A (zh) * 2022-12-26 2023-01-24 湖南师范大学 一种基于正向引导知识蒸馏的文档级关系抽取方法
CN115965964A (zh) * 2023-01-29 2023-04-14 中国农业大学 一种鸡蛋新鲜度识别方法、系统及设备
CN115965964B (zh) * 2023-01-29 2024-01-23 中国农业大学 一种鸡蛋新鲜度识别方法、系统及设备
CN116071625A (zh) * 2023-03-07 2023-05-05 北京百度网讯科技有限公司 深度学习模型的训练方法、目标检测方法及装置
CN116486285A (zh) * 2023-03-15 2023-07-25 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN116486285B (zh) * 2023-03-15 2024-03-19 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN116612379B (zh) * 2023-05-30 2024-02-02 中国海洋大学 一种基于多知识蒸馏的水下目标检测方法及系统
CN116612379A (zh) * 2023-05-30 2023-08-18 中国海洋大学 一种基于多知识蒸馏的水下目标检测方法及系统
CN117521848A (zh) * 2023-11-10 2024-02-06 中国科学院空天信息创新研究院 面向资源受限场景的遥感基础模型轻量化方法、装置
CN117521848B (zh) * 2023-11-10 2024-05-28 中国科学院空天信息创新研究院 面向资源受限场景的遥感基础模型轻量化方法、装置
CN117542085A (zh) * 2024-01-10 2024-02-09 湖南工商大学 基于知识蒸馏的园区场景行人检测方法、装置及设备
CN117542085B (zh) * 2024-01-10 2024-05-03 湖南工商大学 基于知识蒸馏的园区场景行人检测方法、装置及设备

Also Published As

Publication number Publication date
CN112529178B (zh) 2024-04-09

Similar Documents

Publication Publication Date Title
CN112529178A (zh) 一种适用于无预选框检测模型的知识蒸馏方法及系统
US11581130B2 (en) Internal thermal fault diagnosis method of oil-immersed transformer based on deep convolutional neural network and image segmentation
CN109712165B (zh) 一种基于卷积神经网络的同类前景图像集分割方法
CN110796009A (zh) 基于多尺度卷积神经网络模型的海上船只检测方法及系统
CN107229932A (zh) 一种图像文本的识别方法和装置
CN113569667B (zh) 基于轻量级神经网络模型的内河船舶目标识别方法及系统
CN110135446B (zh) 文本检测方法及计算机存储介质
CN109740585A (zh) 一种文本定位方法及装置
CN113610905B (zh) 基于子图像匹配的深度学习遥感图像配准方法及应用
CN114821390A (zh) 基于注意力和关系检测的孪生网络目标跟踪方法及系统
CN116110022B (zh) 基于响应知识蒸馏的轻量化交通标志检测方法及系统
CN114359245A (zh) 一种工业场景下产品表面缺陷检测方法
CN110516512B (zh) 行人属性分析模型的训练方法、行人属性识别方法及装置
CN111931915A (zh) 一种基于diou损失函数的训练网络的方法
CN115565043A (zh) 结合多表征特征以及目标预测法进行目标检测的方法
CN109255382A (zh) 用于图片匹配定位的神经网络系统,方法及装置
CN114021704B (zh) 一种ai神经网络模型的训练方法及相关装置
CN114022727A (zh) 一种基于图像知识回顾的深度卷积神经网络自蒸馏方法
CN116416468B (zh) 一种基于神经架构搜索的sar目标检测方法
CN111428191A (zh) 基于知识蒸馏的天线下倾角计算方法、装置和存储介质
CN116543433A (zh) 一种基于改进YOLOv7模型的口罩佩戴检测方法和装置
CN116229217A (zh) 一种应用于复杂环境下的红外目标检测方法
CN115272755A (zh) 一种激光点云检测分割方法及系统
CN115240084A (zh) 一种无人机跟踪方法、装置和计算机可读存储介质
CN114049478A (zh) 基于改进Cascade R-CNN的红外船舶图像快速识别方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant