CN115984640A - 一种基于组合蒸馏技术的目标检测方法、系统和存储介质 - Google Patents

一种基于组合蒸馏技术的目标检测方法、系统和存储介质 Download PDF

Info

Publication number
CN115984640A
CN115984640A CN202211504333.9A CN202211504333A CN115984640A CN 115984640 A CN115984640 A CN 115984640A CN 202211504333 A CN202211504333 A CN 202211504333A CN 115984640 A CN115984640 A CN 115984640A
Authority
CN
China
Prior art keywords
target
trained
target detection
distillation
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202211504333.9A
Other languages
English (en)
Other versions
CN115984640B (zh
Inventor
常雨喆
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shumei Tianxia Beijing Technology Co ltd
Beijing Nextdata Times Technology Co ltd
Original Assignee
Shumei Tianxia Beijing Technology Co ltd
Beijing Nextdata Times Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shumei Tianxia Beijing Technology Co ltd, Beijing Nextdata Times Technology Co ltd filed Critical Shumei Tianxia Beijing Technology Co ltd
Priority to CN202211504333.9A priority Critical patent/CN115984640B/zh
Publication of CN115984640A publication Critical patent/CN115984640A/zh
Application granted granted Critical
Publication of CN115984640B publication Critical patent/CN115984640B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Image Analysis (AREA)

Abstract

本发明公开了一种基于组合蒸馏技术的目标检测方法、系统和存储介质,包括:通过包含改进的特征图蒸馏、改进的定位蒸馏和改进的分类蒸馏在内的组合蒸馏方式,利用训练好的教师模型对学生模型进行知识蒸馏,得到学生模型的目标损失函数;将每个训练样本输入至训练好的教师模型中,得到每个训练样本对应的中间特征图、目标定位框和目标分类概率值;基于每个训练样本、每个训练样本对应的中间特征图、目标定位框和目标分类概率值,以及目标损失函数,对学生模型进行迭代训练,直至得到训练好的学生模型;将待测图像输入训练好的学生模型,得到待测图像的目标检测结果。本发明在保证目标检测准确率的同时,降低了模型的参数量,实现了模型的压缩。

Description

一种基于组合蒸馏技术的目标检测方法、系统和存储介质
背景技术
知识蒸馏是一种通用的模型压缩算法,其基本思路是把最终用于实时检测的模型称为学生模型,然后找到一个比使用模型更大的训练完毕模型,称为教师模型,教师模型的准召率都明显高于学生模型。在学生模型训练时,除了利用标注好的数据进行有监督训练以外,还通过算法设计,将教师模型学到的“知识”传递给学生模型,最终得到一个参数量远小于教师模型,准召率都接近教师模型的学生模型。对于传递的知识一般有两种方法:学生模型的中间网络输出的特征图对教师模型中间网络输出的特征图的模仿,叫做featureimitation;学生模型的最终输出对教师模型最终输出的模仿,叫做logit mimicking。常用的知识蒸馏是采用其中一种方式,尽管能够通过对学生模型的输出添加新的约束条件达到效果的提升,但提升的效果有限。
因此,亟需提供一种技术方案解决上述技术问题。
发明内容
为解决上述技术问题,本发明提供了一种基于组合蒸馏技术的目标检测方法、系统和存储介质。
本发明的一种基于组合蒸馏技术的目标检测方法的技术方案如下:
通过包含改进的特征图蒸馏、改进的定位蒸馏和改进的分类蒸馏在内的组合蒸馏方式,利用训练好的目标检测教师模型对待训练的目标检测学生模型进行知识蒸馏,得到所述待训练的目标检测学生模型的目标损失函数;
将每个训练样本分别输入至所述训练好的目标检测教师模型中,得到每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值;
基于每个训练样本、每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型;
将待测图像输入所述训练好的目标检测学生模型,得到所述待测图像的目标检测结果。
本发明的一种基于组合蒸馏技术的目标检测方法的有益效果如下:
本发明的方法通过采用组合蒸馏的方式并利用训练好的教师模型生成学生模型,在保证目标检测准确率的同时,降低了模型的参数量,实现了对复杂结构模型的压缩。
在上述方案的基础上,本发明的一种基于组合蒸馏技术的目标检测方法还可以做如下改进。
进一步,还包括:
基于所述多个训练样本,对待训练的目标检测教师模型进行训练,得到所述训练好的目标检测教师模型。
进一步,所述目标损失函数为:
Figure BDA0003967618410000021
其中,L为所述目标损失函数,Loriginal为所述待训练的目标检测学生模型的原始损失函数,Lfea为所述改进的特征图蒸馏所对应的特征图蒸馏损失函数,LLD为所述改进的定位蒸馏所对应的定位蒸馏损失函数,
Figure BDA0003967618410000022
为所述改进的分类蒸馏所对应的分类蒸馏损失函数;
其中,
Figure BDA0003967618410000023
Figure BDA0003967618410000024
为训练样本的标注框,i和j为特征图上的像素点,
Figure BDA0003967618410000031
Hr为所述标注框的高度,Wr为所述标注框的宽度,Nbg为所有背景像素点的个数,FT为所述训练好的目标检测教师模型输出的第一中间特征图,FS为所述待训练的目标检测学生模型输出的第二中间特征图,C为第一中间特征图和第二中间特征图的通道数,H为第一中间特征图和第二中间特征图的高度,W为第一中间特征图和第二中间特征图的宽度,f为辅助网络,用于将所述待训练的目标检测学生模型的第二中间特征图的通道数放缩至与和所述训练好的目标检测教师模型的第一中间特征图相同,α和β为用于平衡中间特征图的前景和背景之间损失的超参项;
其中,
Figure BDA0003967618410000032
TCKD为训练样本对应的第一目标分类概率值中的标注类别概率的蒸馏,NCKD为训练样本对应的第一目标分类概率值中的其他非标注类别概率的蒸馏,m和n为可调节的超参数;
其中,
Figure BDA0003967618410000033
Figure BDA0003967618410000034
e为目标定位文本框的任意一边,
Figure BDA0003967618410000035
为所述任意一边的定位蒸馏损失函数,ZS为所述待训练的目标检测学生模型的所述任意一边的n个预测值,
Figure BDA0003967618410000036
为所述待训练的目标检测学生模型的所述任意一边的n个预测值经过softmax的值,ZT为所述训练好的目标检测教师模型的所述任意一边的n个预测值,
Figure BDA0003967618410000037
为所述训练好的目标检测教师模型的所述任意一边的n个预测值经过softmax的值;BS为所述待训练的目标检测学生模型输出的第二目标定位框,BT为所述训练好的目标检测教师模型的输出的第一目标定位框。
进一步,所述基于每个训练样本、每个训练样本对应的中间特征图、目标定位框和目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型的步骤,包括:
将任一训练样本输入所述待训练的目标检测学生模型,得到所述任一训练样本的第二中间特征图、第二目标定位框和第二目标分类概率值;
基于所述目标损失函数、所述任一训练样本的第一中间特征图、第一目标定位框、第一目标分类概率值、第二中间特征图、第二目标定位框和第二目标分类概率值,得到所述任一训练样本的目标损失值,直至得到每个训练样本的目标损失值;
基于所有的目标损失值,对所述待训练的目标检测学生模型的参数进行优化,得到优化后的目标检测学生模型,将所述优化后的目标检测学生模型作为所述待训练的目标检测学生模型并返回执行将任一训练样本输入所述待训练的目标检测学生模型的步骤,直至满足预设迭代训练条件时,将所述优化后的目标检测学生模型确定为所述训练好的目标检测学生模型。
进一步,所述目标检测教师模型采用yolov5l模型,所述目标检测学生模型采用yolov5s模型。
本发明的一种基于组合蒸馏技术的目标检测系统的技术方案如下:
包括:构建模块、处理模块、训练模块和检测模块;
所述构建模块用于:通过包含改进的特征图蒸馏、改进的定位蒸馏和改进的分类蒸馏在内的组合蒸馏方式,利用训练好的目标检测教师模型对待训练的目标检测学生模型进行知识蒸馏,得到所述待训练的目标检测学生模型的目标损失函数;
所述处理模块用于:将每个训练样本分别输入至所述训练好的目标检测教师模型中,得到每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值;
所述训练模块用于:基于每个训练样本、每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型;
所述检测模块用于:将待测图像输入所述训练好的目标检测学生模型,得到所述待测图像的目标检测结果。
本发明的一种基于组合蒸馏技术的目标检测系统的有益效果如下:
本发明的系统通过采用组合蒸馏的方式并利用训练好的教师模型生成学生模型,在保证目标检测准确率的同时,降低了模型的参数量,实现了对复杂结构模型的压缩。
在上述方案的基础上,本发明的一种基于组合蒸馏技术的目标检测系统还可以做如下改进。
进一步,还包括:预训练模块;
所述预训练模块用于:基于所述多个训练样本,对待训练的目标检测教师模型进行训练,得到所述训练好的目标检测教师模型。
进一步,所述目标损失函数为:
Figure BDA0003967618410000051
其中,L为所述目标损失函数,Loriginal为所述待训练的目标检测学生模型的原始损失函数,Lfea为所述改进的特征图蒸馏所对应的特征图蒸馏损失函数,LLD为所述改进的定位蒸馏所对应的定位蒸馏损失函数,
Figure BDA0003967618410000052
为所述改进的分类蒸馏所对应的分类蒸馏损失函数;
其中,
Figure BDA0003967618410000053
Figure BDA0003967618410000054
为训练样本的标注框,i和j为特征图上的像素点,
Figure BDA0003967618410000055
Hr为所述标注框的高度,Wr为所述标注框的宽度,Nbg为所有背景像素点的个数,FT为所述训练好的目标检测教师模型输出的第一中间特征图,FS为所述待训练的目标检测学生模型输出的第二中间特征图,C为第一中间特征图和第二中间特征图的通道数,H为第一中间特征图和第二中间特征图的高度,W为第一中间特征图和第二中间特征图的宽度,f为辅助网络,用于将所述待训练的目标检测学生模型的第二中间特征图的通道数放缩至与和所述训练好的目标检测教师模型的第一中间特征图相同,α和β为用于平衡中间特征图的前景和背景之间损失的超参项;
其中,
Figure BDA0003967618410000061
TCKD为训练样本对应的第一目标分类概率值中的标注类别概率的蒸馏,NCKD为训练样本对应的第一目标分类概率值中的其他非标注类别概率的蒸馏,m和n为可调节的超参数;
其中,
Figure BDA0003967618410000062
Figure BDA0003967618410000063
e为目标定位文本框的任意一边,
Figure BDA0003967618410000064
为所述任意一边的定位蒸馏损失函数,ZS为所述待训练的目标检测学生模型的所述任意一边的n个预测值,
Figure BDA0003967618410000065
为所述待训练的目标检测学生模型的所述任意一边的n个预测值经过softmax的值,ZT为所述训练好的目标检测教师模型的所述任意一边的n个预测值,
Figure BDA0003967618410000066
为所述训练好的目标检测教师模型的所述任意一边的n个预测值经过softmax的值;BS为所述待训练的目标检测学生模型输出的第二目标定位框,BT为所述训练好的目标检测教师模型的输出的第一目标定位框。
进一步,所述训练模块具体用于:
将任一训练样本输入所述待训练的目标检测学生模型,得到所述任一训练样本的第二中间特征图、第二目标定位框和第二目标分类概率值;
基于所述目标损失函数、所述任一训练样本的第一中间特征图、第一目标定位框、第一目标分类概率值、第二中间特征图、第二目标定位框和第二目标分类概率值,得到所述任一训练样本的目标损失值,直至得到每个训练样本的目标损失值;
基于所有的目标损失值,对所述待训练的目标检测学生模型的参数进行优化,得到优化后的目标检测学生模型,将所述优化后的目标检测学生模型作为所述待训练的目标检测学生模型并返回执行将任一训练样本输入所述待训练的目标检测学生模型的步骤,直至满足预设迭代训练条件时,将所述优化后的目标检测学生模型确定为所述训练好的目标检测学生模型。
本发明的一种存储介质的技术方案如下:
存储介质中存储有指令,当计算机读取所述指令时,使所述计算机执行如本发明的一种基于组合蒸馏技术的目标检测方法的步骤。
附图说明
图1示出了本发明的一种基于组合蒸馏技术的目标检测方法的第一实施例的流程示意图;
图2示出了本发明的一种基于组合蒸馏技术的目标检测方法的第一实施例中步骤130的流程示意图;
图3示出了本发明的一种基于组合蒸馏技术的目标检测方法的第二实施例的流程示意图;
图4示出了本发明的一种基于组合蒸馏技术的目标检测系统的实施例的结构示意图。
具体实施方式
图1示出了本发明的一种基于组合蒸馏技术的目标检测方法的第一实施例的流程示意图。如图1所示,包括如下步骤:
步骤110:通过包含改进的特征图蒸馏、改进的定位蒸馏和改进的分类蒸馏在内的组合蒸馏方式,利用训练好的目标检测教师模型对待训练的目标检测学生模型进行知识蒸馏,得到所述待训练的目标检测学生模型的目标损失函数。
其中,①特征图蒸馏的方式为:学生模型的中间网络所输出的特征图对教师模型中间网络所输出的特征图的模仿。②定位蒸馏的方式为:学生模型所输出的目标定位框对教师模型所输出的目标定位框的模仿。③分类蒸馏的方式为:学生模型所输出的目标分类概率值对教师模型所输出的目标分类概率值的模仿。④知识蒸馏的过程为:通过预设蒸馏方式,利用训练好的教师模型对学生模型进行蒸馏,得到蒸馏后的学生模型的蒸馏损失函数。⑤目标检测教师模型和目标检测学生模型的具体结构不设限制,仅需能够实现目标检测即可;通常教师模型的网络结构比学生模型的网络结构复杂。在本实施例中,目标检测教师模型采用yolov5l模型,目标检测学生模型采用yolov5s模型。⑥经过知识蒸馏所得到的学生模型的损失函数一般由原始损失函数和蒸馏损失函数构成。在本实施例中,目标损失函数为:
Figure BDA0003967618410000081
Figure BDA0003967618410000087
需要说明的是,L为所述目标损失函数,Loriginal为所述待训练的目标检测学生模型的原始损失函数,Lfea为所述改进的特征图蒸馏所对应的特征图蒸馏损失函数,LLD为所述改进的定位蒸馏所对应的定位蒸馏损失函数,
Figure BDA0003967618410000082
为所述改进的分类蒸馏所对应的分类蒸馏损失函数。
其中,
Figure BDA0003967618410000083
Figure BDA0003967618410000084
为训练样本的标注框,i和j为特征图上的像素点,
Figure BDA0003967618410000085
Hr为所述标注框的高度,Wr为所述标注框的宽度,Nbg为所有背景像素点的个数,FT为所述训练好的目标检测教师模型输出的第一中间特征图,FS为所述待训练的目标检测学生模型输出的第二中间特征图,C为第一中间特征图和第二中间特征图的通道数,H为第一中间特征图和第二中间特征图的高度,W为第一中间特征图和第二中间特征图的宽度,f为辅助网络,用于将所述待训练的目标检测学生模型的第二中间特征图的通道数放缩至与和所述训练好的目标检测教师模型的第一中间特征图相同,α和β为用于平衡中间特征图的前景和背景之间损失的超参项;
其中,
Figure BDA0003967618410000086
TCKD为训练样本对应的第一目标分类概率值中的标注类别概率的蒸馏,NCKD为训练样本对应的第一目标分类概率值中的其他非标注类别概率的蒸馏,m和n为可调节的超参数;
其中,
Figure BDA0003967618410000091
Figure BDA0003967618410000092
e为目标定位文本框的任意一边,
Figure BDA0003967618410000093
为所述任意一边的定位蒸馏损失函数,ZS为所述待训练的目标检测学生模型的所述任意一边的n个预测值,
Figure BDA0003967618410000094
为所述待训练的目标检测学生模型的所述任意一边的n个预测值经过softmax的值,ZT为所述训练好的目标检测教师模型的所述任意一边的n个预测值,
Figure BDA0003967618410000095
为所述训练好的目标检测教师模型的所述任意一边的n个预测值经过softmax的值;BS为所述待训练的目标检测学生模型输出的第二目标定位框,BT为所述训练好的目标检测教师模型的输出的第一目标定位框。
步骤120:将每个训练样本分别输入至所述训练好的目标检测教师模型中,得到每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值。
其中,①训练样本为:包含待测物体种类的图像。例如:当目标检测的对象为猫时,训练样本则为包含猫的图像。②第一中间特征图为:训练样本经过训练好的目标检测教师模型所输出的中间特征图。③第一目标定位框为:训练样本经过训练好的目标检测教师模型所输出的目标定位框。④第一目标分类概率值为:训练样本经过训练好的目标检测教师模型所输出的目标分类概率值。
具体地,将任一训练样本输入至所述训练好的目标检测教师模型中,得到该训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值,重复上述方式,直至得到每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值。
步骤130:基于每个训练样本、每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型。
具体地,如图2所示,步骤130包括:
步骤131:将任一训练样本输入所述待训练的目标检测学生模型,得到所述任一训练样本的第二中间特征图、第二目标定位框和第二目标分类概率值。
其中,①第二中间特征图为:训练样本经过待训练的目标检测学生模型所输出的中间特征图。②第二目标定位框为:训练样本经过待训练的目标检测学生模型所输出的目标定位框。③第二目标分类概率值为:训练样本经过待训练的目标检测学生模型所输出的目标分类概率值。
步骤132:基于所述目标损失函数、所述任一训练样本的第一中间特征图、第一目标定位框、第一目标分类概率值、第二中间特征图、第二目标定位框和第二目标分类概率值,得到所述任一训练样本的目标损失值,直至得到每个训练样本的目标损失值。
具体地,将任一训练样本的一中间特征图、第一目标定位框、第一目标分类概率值、第二中间特征图、第二目标定位框和第二目标分类概率值输入至目标损失函数,得到该训练样本的目标损失值,重复上述方式,直至得到每个训练样本的目标损失值。
需要说明的是:第一目标定位框为4个值所组成的定位框,而本实施例中的(目标检测学生模型)第二目标定位框是由4×n组成,每条边预测多个间隔相同的值,对于值的预测属于回归任务,由于回归任务对于边界模糊的目标,优化困难,因此本实施例中的目标检测学生模型改为对每条边预测多个间隔相同的值,将回归任务改成分类任务,n为一条边的值的个数,4条边就是4×n。
步骤133:基于所有的目标损失值,对所述待训练的目标检测学生模型的参数进行优化,得到优化后的目标检测学生模型,将所述优化后的目标检测学生模型作为所述待训练的目标检测学生模型并返回执行步骤131,直至满足预设迭代训练条件时,将所述优化后的目标检测学生模型确定为所述训练好的目标检测学生模型。
其中,预设迭代训练条件包括但不限于:最大迭代次数、损失函数收敛等。
步骤140:将待测图像输入所述训练好的目标检测学生模型,得到所述待测图像的目标检测结果。
其中,待测图像为:任意选取的图像。例如,当目标检测学生模型是用于检测图像中是否包含猫的模型,此时待测图像可以是包含猫的图像,也可以是包含狗的图像,在此不设限制。
本实施例的技术方案通过采用组合蒸馏的方式并利用训练好的教师模型生成学生模型,在保证目标检测准确率的同时,降低了模型的参数量,实现了对复杂结构模型的压缩。
图3示出了本发明的一种基于组合蒸馏技术的目标检测方法的第二实施例的流程示意图。如图3所示,包括如下步骤:
步骤210:基于所述多个训练样本,对待训练的目标检测教师模型进行训练,得到所述训练好的目标检测教师模型。
具体地,将任一训练样本输入至待训练的目标检测教师模型中,得到该训练样本对应的损失值,重复上述方式,直至得到每个训练样本的损失值,并根据所有的损失值对待训练的目标检测教师模型的参数进行优化,得到优化后的目标检测教师模型,并返回执行上述训练过程,直至满足预设迭代训练条件时,将优化后的目标检测教师模型确定为训练好的目标检测教师模型。
需要说明的是,训练目标检测教师模型的样本与训练目标检测学生模型的样本可以相同,也可以不同,在此不设限制。
步骤220:通过包含改进的特征图蒸馏、改进的定位蒸馏和改进的分类蒸馏在内的组合蒸馏方式,利用训练好的目标检测教师模型对待训练的目标检测学生模型进行知识蒸馏,得到所述待训练的目标检测学生模型的目标损失函数。
步骤230:将每个训练样本分别输入至所述训练好的目标检测教师模型中,得到每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值。
步骤240:基于每个训练样本、每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型。
步骤250:将待测图像输入所述训练好的目标检测学生模型,得到所述待测图像的目标检测结果。
本实施例的技术方案进一步通过对教师模型进行训练,并基于组合蒸馏的方式和训练好的教师模型生成学生模型,在保证目标检测准确率的同时,降低了模型的参数量,实现了对复杂结构模型的压缩。
图4示出了本发明的一种基于组合蒸馏技术的目标检测系统的实施例的结构示意图。如图4所示,该系统300包括:构建模块310、处理模块320、训练模块330和检测模块340;
所述构建模块310用于:通过包含改进的特征图蒸馏、改进的定位蒸馏和改进的分类蒸馏在内的组合蒸馏方式,利用训练好的目标检测教师模型对待训练的目标检测学生模型进行知识蒸馏,得到所述待训练的目标检测学生模型的目标损失函数;
所述处理模块320用于:将每个训练样本分别输入至所述训练好的目标检测教师模型中,得到每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值;
所述训练模块330用于:基于每个训练样本、每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型;
所述检测模块340用于:将待测图像输入所述训练好的目标检测学生模型,得到所述待测图像的目标检测结果。
较优地,还包括:预训练模块;
所述预训练模块用于:基于所述多个训练样本,对待训练的目标检测教师模型进行训练,得到所述训练好的目标检测教师模型。
较优地,所述目标损失函数为:
Figure BDA0003967618410000131
其中,L为所述目标损失函数,Loroginal为所述待训练的目标检测学生模型的原始损失函数,Lfea为所述改进的特征图蒸馏所对应的特征图蒸馏损失函数,LLD为所述改进的定位蒸馏所对应的定位蒸馏损失函数,
Figure BDA0003967618410000132
为所述改进的分类蒸馏所对应的分类蒸馏损失函数;
其中,
Figure BDA0003967618410000133
Figure BDA0003967618410000134
为训练样本的标注框,i和j为特征图上的像素点,
Figure BDA0003967618410000135
Hr为所述标注框的高度,Wr为所述标注框的宽度,Nbg为所有背景像素点的个数,FT为所述训练好的目标检测教师模型输出的第一中间特征图,FS为所述待训练的目标检测学生模型输出的第二中间特征图,C为第一中间特征图和第二中间特征图的通道数,H为第一中间特征图和第二中间特征图的高度,W为第一中间特征图和第二中间特征图的宽度,f为辅助网络,用于将所述待训练的目标检测学生模型的第二中间特征图的通道数放缩至与和所述训练好的目标检测教师模型的第一中间特征图相同,α和β为用于平衡中间特征图的前景和背景之间损失的超参项;
其中,
Figure BDA0003967618410000136
TCKD为训练样本对应的第一目标分类概率值中的标注类别概率的蒸馏,NCKD为训练样本对应的第一目标分类概率值中的其他非标注类别概率的蒸馏,m和n为可调节的超参数;
其中,
Figure BDA0003967618410000141
Figure BDA0003967618410000142
e为目标定位文本框的任意一边,
Figure BDA0003967618410000143
为所述任意一边的定位蒸馏损失函数,ZS为所述待训练的目标检测学生模型的所述任意一边的n个预测值,
Figure BDA0003967618410000144
为所述待训练的目标检测学生模型的所述任意一边的n个预测值经过softmax的值,ZT为所述训练好的目标检测教师模型的所述任意一边的n个预测值,
Figure BDA0003967618410000145
为所述训练好的目标检测教师模型的所述任意一边的n个预测值经过softmax的值;BS为所述待训练的目标检测学生模型输出的第二目标定位框,BT为所述训练好的目标检测教师模型的输出的第一目标定位框。
较优地,所述训练模块330具体用于:
将任一训练样本输入所述待训练的目标检测学生模型,得到所述任一训练样本的第二中间特征图、第二目标定位框和第二目标分类概率值;
基于所述目标损失函数、所述任一训练样本的第一中间特征图、第一目标定位框、第一目标分类概率值、第二中间特征图、第二目标定位框和第二目标分类概率值,得到所述任一训练样本的目标损失值,直至得到每个训练样本的目标损失值;
基于所有的目标损失值,对所述待训练的目标检测学生模型的参数进行优化,得到优化后的目标检测学生模型,将所述优化后的目标检测学生模型作为所述待训练的目标检测学生模型并返回执行将任一训练样本输入所述待训练的目标检测学生模型的步骤,直至满足预设迭代训练条件时,将所述优化后的目标检测学生模型确定为所述训练好的目标检测学生模型。
本实施例的技术方案通过采用组合蒸馏的方式并利用训练好的教师模型生成学生模型,在保证目标检测准确率的同时,降低了模型的参数量,实现了对复杂结构模型的压缩。
上述关于本实施例的一种基于组合蒸馏技术的目标检测系统300中的各参数和各个模块实现相应功能的步骤,可参考上文中关于一种基于组合蒸馏技术的目标检测方法的实施例中的各参数和步骤,在此不做赘述。
本发明实施例提供的一种存储介质,包括:存储介质中存储有指令,当计算机读取所述指令时,使所述计算机执行如一种基于组合蒸馏技术的目标检测方法的步骤,具体可参考上文中一种基于组合蒸馏技术的目标检测方法的实施例中的各参数和步骤,在此不做赘述。
计算机存储介质例如:优盘、移动硬盘等。
所属技术领域的技术人员知道,本发明可以实现为方法、系统和存储介质。
因此,本发明可以具体实现为以下形式,即:可以是完全的硬件、也可以是完全的软件(包括固件、驻留软件、微代码等),还可以是硬件和软件结合的形式,本文一般称为“电路”、“模块”或“系统”。此外,在一些实施例中,本发明还可以实现为在一个或多个计算机可读介质中的计算机程序产品的形式,该计算机可读介质中包含计算机可读的程序代码。可以采用一个或多个计算机可读的介质的任意组合。计算机可读介质可以是计算机可读信号介质或者计算机可读存储介质。计算机可读存储介质例如可以是但不限于——电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机存取存储器(RAM),只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本文件中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。尽管上面已经示出和描述了本发明的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本发明的限制,本领域的普通技术人员在本发明的范围内可以对上述实施例进行变化、修改、替换和变型。

Claims (10)

1.一种基于组合蒸馏技术的目标检测方法,其特征在于,包括:
通过包含改进的特征图蒸馏、改进的定位蒸馏和改进的分类蒸馏在内的组合蒸馏方式,利用训练好的目标检测教师模型对待训练的目标检测学生模型进行知识蒸馏,得到所述待训练的目标检测学生模型的目标损失函数;
将每个训练样本分别输入至所述训练好的目标检测教师模型中,得到每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值;
基于每个训练样本、每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型;
将待测图像输入所述训练好的目标检测学生模型,得到所述待测图像的目标检测结果。
2.根据权利要求1所述的基于组合蒸馏技术的目标检测方法,其特征在于,还包括:
基于所述多个训练样本,对待训练的目标检测教师模型进行训练,得到所述训练好的目标检测教师模型。
3.根据权利要求1所述的基于组合蒸馏技术的目标检测方法,其特征在于,所述目标损失函数为:
Figure FDA0003967618400000011
其中,L为所述目标损失函数,Lorginal为所述待训练的目标检测学生模型的原始损失函数,Lfea为所述改进的特征图蒸馏所对应的特征图蒸馏损失函数,LLD为所述改进的定位蒸馏所对应的定位蒸馏损失函数,
Figure FDA0003967618400000012
为所述改进的分类蒸馏所对应的分类蒸馏损失函数;
其中,
Figure FDA0003967618400000013
Figure FDA0003967618400000014
r为训练样本的标注框,i和j为特征图上的像素点,
Figure FDA0003967618400000021
Hr为所述标注框的高度,Wr为所述标注框的宽度,Nbg为所有背景像素点的个数,FT为所述训练好的目标检测教师模型输出的第一中间特征图,FS为所述待训练的目标检测学生模型输出的第二中间特征图,C为第一中间特征图和第二中间特征图的通道数,H为第一中间特征图和第二中间特征图的高度,W为第一中间特征图和第二中间特征图的宽度,f为辅助网络,用于将所述待训练的目标检测学生模型的第二中间特征图的通道数放缩至与和所述训练好的目标检测教师模型的第一中间特征图相同,α和β为用于平衡中间特征图的前景和背景之间损失的超参项;
其中,
Figure FDA0003967618400000022
TCKD为训练样本对应的第一目标分类概率值中的标注类别概率的蒸馏,NCKD为训练样本对应的第一目标分类概率值中的其他非标注类别概率的蒸馏,m和n为可调节的超参数;
其中,
Figure FDA0003967618400000023
Figure FDA0003967618400000024
e为目标定位文本框的任意一边,
Figure FDA0003967618400000025
为所述任意一边的定位蒸馏损失函数,ZS为所述待训练的目标检测学生模型的所述任意一边的n个预测值,
Figure FDA0003967618400000026
为所述待训练的目标检测学生模型的所述任意一边的n个预测值经过softmax的值,ZT为所述训练好的目标检测教师模型的所述任意一边的n个预测值,
Figure FDA0003967618400000027
为所述训练好的目标检测教师模型的所述任意一边的n个预测值经过softmax的值;BS为所述待训练的目标检测学生模型输出的第二目标定位框,BT为所述训练好的目标检测教师模型的输出的第一目标定位框。
4.根据权利要求1所述的基于组合蒸馏技术的目标检测方法,其特征在于,所述基于每个训练样本、每个训练样本对应的中间特征图、目标定位框和目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型的步骤,包括:
将任一训练样本输入所述待训练的目标检测学生模型,得到所述任一训练样本的第二中间特征图、第二目标定位框和第二目标分类概率值;
基于所述目标损失函数、所述任一训练样本的第一中间特征图、第一目标定位框、第一目标分类概率值、第二中间特征图、第二目标定位框和第二目标分类概率值,得到所述任一训练样本的目标损失值,直至得到每个训练样本的目标损失值;
基于所有的目标损失值,对所述待训练的目标检测学生模型的参数进行优化,得到优化后的目标检测学生模型,将所述优化后的目标检测学生模型作为所述待训练的目标检测学生模型并返回执行将任一训练样本输入所述待训练的目标检测学生模型的步骤,直至满足预设迭代训练条件时,将所述优化后的目标检测学生模型确定为所述训练好的目标检测学生模型。
5.根据权利要求1-4任一项所述的基于组合蒸馏技术的目标检测方法,其特征在于,所述目标检测教师模型采用yolov5l模型,所述目标检测学生模型采用yolov5s模型。
6.一种基于组合蒸馏技术的目标检测系统,其特征在于,包括:构建模块、处理模块、训练模块和检测模块;
所述构建模块用于:通过包含改进的特征图蒸馏、改进的定位蒸馏和改进的分类蒸馏在内的组合蒸馏方式,利用训练好的目标检测教师模型对待训练的目标检测学生模型进行知识蒸馏,得到所述待训练的目标检测学生模型的目标损失函数;
所述处理模块用于:将每个训练样本分别输入至所述训练好的目标检测教师模型中,得到每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值;
所述训练模块用于:基于每个训练样本、每个训练样本对应的第一中间特征图、第一目标定位框和第一目标分类概率值,以及所述目标损失函数,对所述待训练的目标检测学生模型进行迭代训练,直至得到训练好的目标检测学生模型;
所述检测模块用于:将待测图像输入所述训练好的目标检测学生模型,得到所述待测图像的目标检测结果。
7.根据权利要求6所述的基于组合蒸馏技术的目标检测系统,其特征在于,还包括:预训练模块;
所述预训练模块用于:基于所述多个训练样本,对待训练的目标检测教师模型进行训练,得到所述训练好的目标检测教师模型。
8.根据权利要求6所述的基于组合蒸馏技术的目标检测系统,其特征在于,所述目标损失函数为:
Figure FDA0003967618400000041
其中,L为所述目标损失函数,Loriginal为所述待训练的目标检测学生模型的原始损失函数,Lfea为所述改进的特征图蒸馏所对应的特征图蒸馏损失函数,LLD为所述改进的定位蒸馏所对应的定位蒸馏损失函数,
Figure FDA0003967618400000042
为所述改进的分类蒸馏所对应的分类蒸馏损失函数;
其中,
Figure FDA0003967618400000043
Figure FDA0003967618400000044
r为训练样本的标注框,i和j为特征图上的像素点,
Figure FDA0003967618400000045
Hr为所述标注框的高度,Wr为所述标注框的宽度,Nbg为所有背景像素点的个数,FT为所述训练好的目标检测教师模型输出的第一中间特征图,FS为所述待训练的目标检测学生模型输出的第二中间特征图,C为第一中间特征图和第二中间特征图的通道数,H为第一中间特征图和第二中间特征图的高度,W为第一中间特征图和第二中间特征图的宽度,f为辅助网络,用于将所述待训练的目标检测学生模型的第二中间特征图的通道数放缩至与和所述训练好的目标检测教师模型的第一中间特征图相同,α和β为用于平衡中间特征图的前景和背景之间损失的超参项;
其中,
Figure FDA0003967618400000051
TCKD为训练样本对应的第一目标分类概率值中的标注类别概率的蒸馏,NCKD为训练样本对应的第一目标分类概率值中的其他非标注类别概率的蒸馏,m和n为可调节的超参数;
其中,
Figure FDA0003967618400000052
Figure FDA0003967618400000053
e为目标定位文本框的任意一边,
Figure FDA0003967618400000054
为所述任意一边的定位蒸馏损失函数,ZS为所述待训练的目标检测学生模型的所述任意一边的n个预测值,
Figure FDA0003967618400000055
为所述待训练的目标检测学生模型的所述任意一边的n个预测值经过softmax的值,ZT为所述训练好的目标检测教师模型的所述任意一边的n个预测值,
Figure FDA0003967618400000056
为所述训练好的目标检测教师模型的所述任意一边的n个预测值经过softmax的值;BS为所述待训练的目标检测学生模型输出的第二目标定位框,BT为所述训练好的目标检测教师模型的输出的第一目标定位框。
9.根据权利要求6所述的基于组合蒸馏技术的目标检测系统,其特征在于,所述训练模块具体用于:
将任一训练样本输入所述待训练的目标检测学生模型,得到所述任一训练样本的第二中间特征图、第二目标定位框和第二目标分类概率值;
基于所述目标损失函数、所述任一训练样本的第一中间特征图、第一目标定位框、第一目标分类概率值、第二中间特征图、第二目标定位框和第二目标分类概率值,得到所述任一训练样本的目标损失值,直至得到每个训练样本的目标损失值;
基于所有的目标损失值,对所述待训练的目标检测学生模型的参数进行优化,得到优化后的目标检测学生模型,将所述优化后的目标检测学生模型作为所述待训练的目标检测学生模型并返回执行将任一训练样本输入所述待训练的目标检测学生模型的步骤,直至满足预设迭代训练条件时,将所述优化后的目标检测学生模型确定为所述训练好的目标检测学生模型。
10.一种存储介质,其特征在于,所述存储介质中存储有指令,当计算机读取所述指令时,使所述计算机执行如权利要求1至5中任一项所述的基于组合蒸馏技术的目标检测方法。
CN202211504333.9A 2022-11-28 2022-11-28 一种基于组合蒸馏技术的目标检测方法、系统和存储介质 Active CN115984640B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211504333.9A CN115984640B (zh) 2022-11-28 2022-11-28 一种基于组合蒸馏技术的目标检测方法、系统和存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211504333.9A CN115984640B (zh) 2022-11-28 2022-11-28 一种基于组合蒸馏技术的目标检测方法、系统和存储介质

Publications (2)

Publication Number Publication Date
CN115984640A true CN115984640A (zh) 2023-04-18
CN115984640B CN115984640B (zh) 2023-06-23

Family

ID=85974831

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211504333.9A Active CN115984640B (zh) 2022-11-28 2022-11-28 一种基于组合蒸馏技术的目标检测方法、系统和存储介质

Country Status (1)

Country Link
CN (1) CN115984640B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116778300A (zh) * 2023-06-25 2023-09-19 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、系统和存储介质

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113487614A (zh) * 2021-09-08 2021-10-08 四川大学 胎儿超声标准切面图像识别网络模型的训练方法和装置
CN113610069A (zh) * 2021-10-11 2021-11-05 北京文安智能技术股份有限公司 基于知识蒸馏的目标检测模型训练方法
JP2022058915A (ja) * 2021-05-27 2022-04-12 ベイジン バイドゥ ネットコム サイエンス テクノロジー カンパニー リミテッド 画像認識モデルをトレーニングするための方法および装置、画像を認識するための方法および装置、電子機器、記憶媒体、並びにコンピュータプログラム
CN115147687A (zh) * 2022-07-07 2022-10-04 浙江啄云智能科技有限公司 学生模型训练方法、装置、设备及存储介质
CN115376195A (zh) * 2022-10-09 2022-11-22 珠海大横琴科技发展有限公司 训练多尺度网络模型的方法及人脸关键点检测方法

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP2022058915A (ja) * 2021-05-27 2022-04-12 ベイジン バイドゥ ネットコム サイエンス テクノロジー カンパニー リミテッド 画像認識モデルをトレーニングするための方法および装置、画像を認識するための方法および装置、電子機器、記憶媒体、並びにコンピュータプログラム
CN113487614A (zh) * 2021-09-08 2021-10-08 四川大学 胎儿超声标准切面图像识别网络模型的训练方法和装置
CN113610069A (zh) * 2021-10-11 2021-11-05 北京文安智能技术股份有限公司 基于知识蒸馏的目标检测模型训练方法
CN115147687A (zh) * 2022-07-07 2022-10-04 浙江啄云智能科技有限公司 学生模型训练方法、装置、设备及存储介质
CN115376195A (zh) * 2022-10-09 2022-11-22 珠海大横琴科技发展有限公司 训练多尺度网络模型的方法及人脸关键点检测方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
楚玉春等: "基于YOLOv4的目标检测知识蒸馏算法研究", 计算机科学 *

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116778300A (zh) * 2023-06-25 2023-09-19 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、系统和存储介质
CN116778300B (zh) * 2023-06-25 2023-12-05 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、系统和存储介质

Also Published As

Publication number Publication date
CN115984640B (zh) 2023-06-23

Similar Documents

Publication Publication Date Title
US20190361972A1 (en) Method, apparatus, device for table extraction based on a richly formatted document and medium
CN108038107B (zh) 基于卷积神经网络的语句情感分类方法、装置及其设备
CN110598620B (zh) 基于深度神经网络模型的推荐方法和装置
CN111210446B (zh) 一种视频目标分割方法、装置和设备
CN111368634B (zh) 基于神经网络的人头检测方法、系统及存储介质
CN110969200A (zh) 基于一致性负样本的图像目标检测模型训练方法及装置
CN113065013B (zh) 图像标注模型训练和图像标注方法、系统、设备及介质
CN111242922A (zh) 一种蛋白质图像分类方法、装置、设备及介质
CN116823793A (zh) 设备缺陷检测方法、装置、电子设备和可读存储介质
CN115984640A (zh) 一种基于组合蒸馏技术的目标检测方法、系统和存储介质
CN114078197A (zh) 一种基于支撑样本特征增强的小样本目标检测方法及装置
CN114330588A (zh) 一种图片分类方法、图片分类模型训练方法及相关装置
CN114565803A (zh) 用于提取难样本的方法、装置及机械设备
CN116797973A (zh) 应用于环卫智慧管理平台的数据挖掘方法及系统
CN111832435A (zh) 基于迁移与弱监督的美丽预测方法、装置及存储介质
CN115052154B (zh) 一种模型训练和视频编码方法、装置、设备及存储介质
CN114118410A (zh) 图结构的节点特征提取方法、设备及存储介质
CN114067099A (zh) 学生图像识别网络的训练方法及图像识别方法
US11177018B2 (en) Stable genes in comparative transcriptomics
CN113033397A (zh) 目标跟踪方法、装置、设备、介质及程序产品
CN116778300B (zh) 一种基于知识蒸馏的小目标检测方法、系统和存储介质
CN116385844B (zh) 一种基于多教师模型的特征图蒸馏方法、系统和存储介质
CN113283345B (zh) 板书书写行为检测方法、训练方法、装置、介质及设备
CN114970955B (zh) 基于多模态预训练模型的短视频热度预测方法及装置
WO2024135112A1 (ja) 機械学習システム

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant