CN113610126B - 基于多目标检测模型无标签的知识蒸馏方法及存储介质 - Google Patents
基于多目标检测模型无标签的知识蒸馏方法及存储介质 Download PDFInfo
- Publication number
- CN113610126B CN113610126B CN202110838933.8A CN202110838933A CN113610126B CN 113610126 B CN113610126 B CN 113610126B CN 202110838933 A CN202110838933 A CN 202110838933A CN 113610126 B CN113610126 B CN 113610126B
- Authority
- CN
- China
- Prior art keywords
- network
- teacher
- loss
- distillation
- student
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 43
- 238000000034 method Methods 0.000 title claims abstract description 38
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 28
- 238000004821 distillation Methods 0.000 claims abstract description 40
- 238000012549 training Methods 0.000 claims abstract description 38
- 230000008569 process Effects 0.000 claims abstract description 9
- 230000003044 adaptive effect Effects 0.000 claims description 5
- 238000005259 measurement Methods 0.000 claims description 5
- 230000006870 function Effects 0.000 claims description 4
- 238000004590 computer program Methods 0.000 claims description 3
- 238000000605 extraction Methods 0.000 claims description 3
- 238000011478 gradient descent method Methods 0.000 claims description 3
- 239000011159 matrix material Substances 0.000 claims description 3
- 238000004088 simulation Methods 0.000 claims description 3
- 210000003128 head Anatomy 0.000 description 18
- 238000012360 testing method Methods 0.000 description 12
- 238000010586 diagram Methods 0.000 description 10
- 230000008859 change Effects 0.000 description 6
- 238000002474 experimental method Methods 0.000 description 5
- 238000004364 calculation method Methods 0.000 description 4
- 238000011160 research Methods 0.000 description 4
- 230000006835 compression Effects 0.000 description 2
- 238000007906 compression Methods 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 230000018109 developmental process Effects 0.000 description 2
- 210000000887 face Anatomy 0.000 description 2
- 238000012544 monitoring process Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 238000010200 validation analysis Methods 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 238000013459 approach Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000005021 gait Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 210000001747 pupil Anatomy 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Probability & Statistics with Applications (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于多目标检测模型无标签的知识蒸馏方法及存储介质,属于计算机视觉目标检测领域,该方法包括以下步骤:S1、获取多类别数据集;S2、利用不同类别的数据集训练出不同的教师网络,将无标签的图片输入至学生网络和多个教师网络,引导学生网络训练;学生网络的类别为多个教师网络类别的组合;S3、设计全局蒸馏损失以及自适应损失,平衡多个不同教师网络和学生网络之间的蒸馏损失,平衡不同教师网络之间的语言特性,优化学生网络训练过程。本发明能够有效提炼出不同教师网络中的多类别信息,进行完整类别的目标检测,并且在单一数据集的制定类别上与教师网络持平甚至超越。
Description
技术领域
本发明属于计算机视觉目标检测领域,具体涉及一种基于多目标检测模型无标签的知识蒸馏方法及存储介质。
背景技术
在人工智能发展迅速的今天,目标检测是计算机视觉和数字图像处理的一个热门方向,广泛应用于机器人导航、智能视频监控、工业检测、航空航天等诸多领域,通过计算机视觉减少对人力资本的消耗,具有重要的现实意义。因此,目标检测也就成为了近年来理论和应用的研究热点,它是图像处理和计算机视觉学科的重要分支,也是智能监控系统的核心部分,同时目标检测也是泛身份识别领域的一个基础性的算法,对后续的人脸识别、步态识别、人群计数、实例分割等任务起着至关重要的作用。得益于深度学习CNN网络架构的发展,目标检测任务的性能逐步提高。然而,现有的目标检测框架针对完全标注的监督学习模式设计,对于半标注和无标注的数据集,现有框架难以抽取出数据集中有效地信息并加以训练。目标检测任务在实际应用中存在目标域变换或目标类别变化,并且对模型大小和推理速度都有更加苛刻的要求。针对这一问题,基于知识蒸馏的目标检测被证实为一种行之有效的方案。
知识蒸馏(Knowledge distillation,KD)于2015年提出,被广泛应用于迁移学习和模型压缩中,知识蒸馏可以将一个或多个网络的知识转移到另一个同构或者异构的网络。知识蒸馏需要先训练一个或多个教师网络,然后使用这些教师网络的输出和数据的真实标签联合训练学生网络。知识蒸馏可以用于将网络从大的教师网络转化成一个小的学生网络,实现模型的压缩并保留接近于大网络的性能;也可以将多个教师网络的知识转移到一个学生网络中,使得单个网络的性能接近emsemble的结果。
现阶段大多数基于知识蒸馏的目标检测方法大多在单一数据集上展开,从大的教师网络中指导小的学生网络训练,用以获取挖掘学生模型性能,但是很少有跨数据集和类别的目标检测网络蒸馏研究。
实际场景中有很多类似的需求,往往需要同时检测出多个关注的类别。然而现有的开源数据大多针对通用场景下构建数据集,大多包含其中的一个类别或者多个类别,并不能包含关注的所有类别,因此要获得一个能够检测实际场景中的所有类别是一项研究的难点。假设A数据集中包含物体{a1,a2,…,an}类别但不包含{b1,b2,…bn}类别,B数据集中包含{b1,b2,…bn}等类别但不包含{a1,a2,…,an}类别,然而实际场景需要{a1,a2,…,an,b1,b2,…bn}检测所有类别的模型,如何更好地使用现有数据集获取检测完整类别的目标检测网络是一个重要的需求和难点。
发明内容
本发明的目的在于,提供一种基于多目标检测模型无标签的知识蒸馏方法及存储介质,获取检测完整类别的目标检测网络,实现跨数据集和跨类别的目标检测。
本发明提供的技术方案如下:
一种基于多目标检测模型无标签的知识蒸馏方法,包括以下步骤:
S1、获取多类别数据集;
S2、先利用不同类别的数据集训练出不同的教师网络模型,而后将无标签的图片输入至学生网络和多个教师网络模型,从而使教师网络模型引导学生网络训练;其中,学生网络的类别为多个教师网络类别的组合;
S3、设计全局蒸馏损失以及自适应损失,平衡多个不同教师网络和学生网络之间的蒸馏损失,平衡不同教师网络之间的语言特性,优化学生网络训练过程。
进一步地,数据集的类别大于等于2。
进一步地,多类别数据集包括CrowdHuman数据集、WiderFace数据集、SHWD数据集。
进一步地,步骤S2包括:
利用不同类别的数据集训练出不同的教师网络模型,将无标签的图片输入至学生网络和多个教师网络模型,将学生网络头部输出和不同教师网络头部输出计算蒸馏损失,采用反向传播梯度下降方法引导学生网络训练;学生网络的类别为不同教师网络类别的组合,类别通道数一一对应并分别结算分类损失,同理,相继计算出回归损失以及偏置损失。
进一步地,利用不同类别的数据集训练出不同的教师网络模型具体为:通过Teacher-i网络中backbone模块和Neck模块得到相应的头部输出,包括heatmap图,即对应的分类信息Ki和宽高的回归信息、中心点坐标的偏置信息;其中,i∈1,2,3…n,n表示总类别数。
进一步地,引导学生网络训练时,在学生网络的分类头部中加入分类注意力模块。
进一步地,引导学生网络训练,先将分类预测头的输入通过卷积层转化为类别特征图C×H×W,其中C为目标类别数,H和W为特征图的高度和长度尺寸,再通过卷积层构建类内注意力图HW×1×1,经过Softmax层归一化,并与原特征图进行矩阵乘法,获得类间特征图C×1×1,并通过Excitation操作,最后将类间注意力图C×1×1通过广播逐元素加法叠加到原特征图中,完成类别特征的提取。
进一步地,步骤S3包括:
对类别预测头的输出使用Leaky ReLU进行约束,再进行教师和学生网络间的模拟,类别蒸馏损失如下:
式中,S表示学生网络,T表示教师网络,Hijc为网络的分类头部输出,k,h,w分别对应着特征图的类别通道数、高和宽,l()代表Leaky ReLU约束;
中心点偏移量蒸馏损失如下:
式中,N表示该幅图像中关键点的个数,即正样本个数,O表示目标中心点的偏置量,所有的类别共享相同的偏移量预测,Loffset采用L1损失,只对目标中心点位置进行约束,忽略所有其他位置,并将/>处的特征/>作为权重叠加在对应位置;
尺度蒸馏损失如下:
式中,Si为学生或教师网络中宽高预测头输出的对应位置的预测结果,将特征作为权重叠加;
每一个教师网络和学生网络之间分别计算蒸馏损失,最后将不同教师网络的损失进行加权求和,总的损失函数为:
其中,λt是教师和学生网络之间蒸馏权重,αt、βt和γt为不同蒸馏损失间的权重。
进一步地,根据每次迭代间的损失变化比例控制损失在指定的区间内,自适应性损失为:
其中,损失指定区间为[α,β],r为上一个迭代与当前迭代的损失比例,包括Lcls、Loffset以及Lsize各自的损失,[rs,rl]为损失变化比例限定空间。
一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现上述的基于多目标检测模型无标签的知识蒸馏方法。
本发明的有益效果为:
本发明的基于多目标检测模型无标签的知识蒸馏方法及存储介质,能够有效提炼出不同教师网络中的多类别信息,进行完整类别的目标检测,并且在单一数据集的制定类别上与教师网络持平甚至超越。
附图说明
图1是本发明实施例的多模型蒸馏无标签的目标检测框架结构图。
图2是CH+WF数据集上未采用自适应损失训练时的损失变化图。
图3是CH+WF数据集上采用本方法提出的自适应损失训练时的损失变化图。
图4是本发明在CrowdHuman验证集和SHWD测试集上的部分检测结果图。
具体实施方式
为了使本发明目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。此外,下面所描述的本发明各个实施方式中所涉及到的技术特征只要彼此之间未构成冲突就可以相互组合。
本发明利用现有的多种数据集去预训练得到教师网络模型,同时将无标签的图片输入学生网络和多个教师网络模型,引导学生网络训练;然后设计全局蒸馏损失以及自适应损失,以平衡多个不同教师模型和学生之间的蒸馏损失,平衡不同教师之间的语言特性,优化训练过程;最后设计对比实验,利用不同的数据集训练的教师网络模型对比学生网络产生的结果影响。本发明能够有效提炼出不同教师网络中的多类别信息,并且在单一数据集的制定类别上与教师网络持平甚至超越。
本发明实施例的基于多目标检测模型无标签的知识蒸馏方法,包括以下步骤:
S1、获取多类别数据集。
本发明实施例从实际场景的广泛需求出发,构建了三种分类组合:行人+人脸;行人+安全帽;行人+人脸+安全帽。其中,行人数据集采用包含大量复杂密集场景的CrowdHuman数据集,人体的遮挡、多尺度等困难样本大量存在,更加贴近于实际应用场景,也进一步增大了检测的难度。CrowdHuman数据集拥有15000张用于训练的图像,4370张用于验证的图像和5000张用于测试的图像,训练和验证数据集共包含470K个人类实例。人脸数据集采用WiderFace数据集。WiderFace由32203张图像,393703个标注人脸组成,由于数据集中包含大量的尺度变化、姿态变化以及不同程度的遮挡,WiderFace贴近于实际应用场景。安全帽数据集采用安全帽佩戴检测数据集(Safety Helmet Wearing Dataset,SHWD)。SHWD数据集包含7581张图像,9044个戴安全帽的目标标注和11154个普通未带安全帽的目标标注,大量数据采集自实际工地场景。
S2、在不同的数据集上训练教师网络模型,而后将无标签的图片输入学生网络和多个教师网络,从而引导学生网络训练。
相较于多阶段目标检测网络的知识蒸馏,由于RPN网络输出的提议框(Proposal)的不确定性,蒸馏过程中难以处理教师网络和学生网络的区域提议。为了更有效的使用现有数据集获取能够检测实际场景中完整类别的目标检测网络,构建无标签知识蒸馏网络,本发明采用如下方法:在多个数据集上训练教师网络,而后将无标签的图片输入学生网络和多个教师网络,将学生网络头部输出和不同教师网络头部输出计算蒸馏损失,采用反向传播梯度下降方法引导学生网络训练。学生网络的类别为不同教师网络类别的组合,类别通道数一一对应并分别结算分类损失,同理,相继计算出回归损失以及偏置损失。
下面将结合附图1,对本发明的网络结构进行详细的介绍说明。首先针对于不同的数据集,去训练出不同的教师模型,具体是通过Teacher-i(i∈1,2,3...n)中backbone模块和Neck模块得到相应的头部输出,包括heatmap图,即对应的分类信息Ki,i∈1,2,3…n,和宽高的回归信息、中心点坐标的偏置信息。对于每一个教师模型,它们拥有着不同类别的丰富信息,即最后各自得到的权重向量Headi,i∈1,2,3...n。所以在学生网络训练阶段,将学生网络头部输出和不同教师网络头部输出之间计算蒸馏损失,引导学生网络训练。学生网络的类别为教师网络类别的组合,类别通道数一一对应并分别结算分类损失。
为了挖掘不同类别间的深层语义关系,在学生网络的分类头部中加入分类注意力模块。为了有效地挖掘类内和类间的语义关联,特别是相距较远的目标之间的语义关联,更加关注类别间的关系,故先将分类预测头的输入通过卷积层转化为类别特征图C×H×W,其中C为目标类别数,H和W为特征图的高度和长度尺寸,再通过卷积层构建类内注意力图HW×1×1,经过Softmax层归一化,并与原特征图进行矩阵乘法,获得类间特征图C×1×1,并通过Excitation操作,最后将类间注意力图C×1×1通过广播逐元素加法叠加到原特征图中,完成类别特征的提取,设置蒸馏权重为1.00。
S3、设计全局蒸馏损失以及自适应损失,以平衡多个不同教师模型和学生网络之间的蒸馏损失,平衡不同教师网络之间的语言特性,优化训练过程。
将上述步骤中所获得的不同数据集下的教师网络权重向量Headi,i∈1,2,3...n在损失函数作为引导的情况下,可以实现在几乎不牺牲计算复杂度的情况下,学生网络获得教师网络的有效知识。
考虑到教师模型在目标编码时将所有的目标通过高斯核的方式编码进热力图中,特定类别占据特定通道,故而目标的分类头部输出应该限定在0-1之间。为了一定程度上减少教师模型的错误预测的影响,在计算不同教师模型和学生模型之间的距离度量之前,对类别预测头的输出使用Leaky ReLU(L-ReLU)进行约束,再进行教师和学生模型间的模拟,类别蒸馏损失如下:
其中,为学生网络的分类头部输出,/>为教师网络的输出,k,h,w分别对应特征图的类别通道数,高和宽,l()代表L-ReLU约束。由于学生网络包含多个教师网络的类别,计算类别蒸馏损失时抽取学生网络和教师网络中对应类别通道的特征图进行计算。
为了更加蒸馏计算出中心点位置偏移量信息,引入中心点偏移量蒸馏损失:
式中,N表示该幅图像中关键点的个数,即正样本个数,O表示目标中心点的偏置量。所有的类别共享相同的偏移量预测,Loffset采用L1损失,只对目标中心点位置进行约束,忽略所有其他位置,并将/>处的特征/>作为权重叠加在对应位置,使得教师模型给出置信度更高的目标更大的权重,进而优化蒸馏的过程。
同样为了蒸馏出目标的宽高预测,引入尺度蒸馏损失:
式中,Si为学生或教师网络中宽高预测头输出的对应位置的预测结果,计算损失时也仅有计算目标中心位置参与计算,并将特征作为权重叠加。
每一个的教师和学生网络之间分别计算蒸馏损失。最后将不同教师网络的损失进行加权求和,总的损失函数为:
其中λt是教师和学生网络之间蒸馏权重,αt、βt和γt为不同蒸馏损失间的权重,便于学生网络学习到了教师网络的有效知识。
由于多模型和多数据集之间域的不同,导致模型难以训练,训练中损失图如附图2所示。为平衡多个不同教师模型与学生之间的蒸馏损失,本发明提出自适应损失,以自适应的平衡不同教师之间的语义鸿沟,优化训练过程。根据每次教师网络模型引导学生网络迭代的过程中,将损失变化比例控制在指定的区间内,调整因损失的剧烈变化导致训练失控。自适应性损失为:
其中,损失指定区间为[α,β],r为上一个迭代与当前迭代的损失比例,包括Lcls,Loffset以及Lsize各自的损失,[rs,rl]为损失变化比例限定空间。在添加了自适应损失后,多模型知识蒸馏的过程平稳的进行,并逐步收敛,训练过程中损失的曲线如附图3所示。
S4、设计对比实验,利用不同的数据集训练的教师网络模型对比学生网络产生的结果影响。
在本发明实例中,针对该实际应用场景在CrowdHuman、SHWD和WiderFace上展开研究,以分别在多个数据集上训练的以ResNet-50为骨架网络的自编码器作为教师模型,蒸馏一个同时检测人、人脸和安全帽的以ResNet-18为骨架网路目标检测模型。
表1行人和人脸组合相关的实验结果对比
其中第一组实验如表1所示,其中CH为CrowdHuman数据集,WF为WiderFace数据集。为了验证其优越性,本发明将MMKD方法与Ignore Label和Pseudo Label的方案,以及单一模型上训练的模型在对应数据集上进行比较。实验结果表明,Resnet-18-MMKD在CrowdHuman的测试集上的AP为32.3%,在WiderFace的测试集上的AP为32.4%,相较于Ignore Label的方案提高了3.0%和8.2%,相较于Pseudo Label的方法提高了3.3%和4.4%,在精度和泛化性能都高于单一的ResNet-18方法。
表2行人和安全帽组合相关的实验结果对比
第二组实验如表2所示,实验结果表明,Resnet-18-MMKD在CrowdHuman的测试集上的AP为33.2%,在SHWD测试集上Helmet类别的AP为61.7%,Head类别的AP为37.6%,相较于Ignore Label的方案提高了3.8%、6.7%和6.1%,相较于Pseudo Label的方法提高了4.7%、3.5%和4.2%。由于网络蒸馏了教师网络的有效知识且训练数据量扩大,学生网络在精度和泛化性能都高于单一的ResNet-18方法。
表3行人、人脸和安全帽组合相关的实验结果对比
第三组实验如表3所示,实验结果表明,Resnet-18-MMKD在CrowdHuman的测试集上的AP为30.4%,在WiderFace的测试集上的AP为30.7%,在SHWD测试集上Helmet类别的AP为59.5%,Head类别的AP为30.4%,相较于Ignore Label的方案提高了1.9%、7.5%、9.9%和0.8%,相较于Pseudo Label的方法提高了2.0%、6.3%、0.9%和0.7%。由于网络蒸馏了教师网络的有效知识且训练数据量扩大,学生网络在精度和泛化性能都高于单一的ResNet-18方法。
在CrowdHuman验证集、WiderFace测试集以及SHWD测试集上的部分检测结果如附图4所示。
本发明还提供一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现上述的基于多目标检测模型无标签的知识蒸馏方法。
需要指出,根据实施的需要,可将本申请中描述的各个步骤/部件拆分为更多步骤/部件,也可将两个或多个步骤/部件或者步骤/部件的部分操作组合成新的步骤/部件,以实现本发明的目的。
本领域的技术人员容易理解,以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明的保护范围之内。
Claims (9)
1.一种基于多目标检测模型无标签的知识蒸馏方法,其特征在于,包括以下步骤:
S1、获取多类别数据集;
S2、先利用不同类别的数据集训练出不同的教师网络模型,而后将无标签的图片输入至学生网络和多个教师网络模型,从而使教师网络模型引导学生网络训练;其中,学生网络的类别为多个教师网络类别的组合;
S3、设计全局蒸馏损失以及自适应损失,平衡多个不同教师网络和学生网络之间的蒸馏损失,平衡不同教师网络之间的语言特性,优化学生网络训练过程;包括:
对类别预测头的输出使用Leaky ReLU进行约束,再进行教师和学生网络间的模拟,类别蒸馏损失如下:
式中,S表示学生网络,T表示教师网络,Hijc为网络的分类头部输出,k,h,w分别对应着特征图的类别通道数、高和宽,l()代表Leaky ReLU约束;
中心点偏移量蒸馏损失如下:
式中,N表示图像中关键点的个数,即正样本个数,O表示目标中心点的偏置量,所有的类别共享相同的偏移量预测,Loffset采用L1损失,只对目标中心点位置进行约束,忽略所有其他位置,并将/>处的特征/>作为权重叠加在对应位置;
尺度蒸馏损失如下:
式中,Si为学生或教师网络中宽高预测头输出的对应位置的预测结果,将特征作为权重叠加;
每一个教师网络和学生网络之间分别计算蒸馏损失,最后将不同教师网络的损失进行加权求和,总的损失函数为:
其中,λt是教师和学生网络之间蒸馏权重,αt、βt和γt为不同蒸馏损失间的权重。
2.根据权利要求1所述的基于多目标检测模型无标签的知识蒸馏方法,其特征在于,数据集的类别大于等于2。
3.根据权利要求1所述的基于多目标检测模型无标签的知识蒸馏方法,其特征在于,多类别数据集包括CrowdHuman数据集、WiderFace数据集、SHWD数据集。
4.根据权利要求1所述的基于多目标检测模型无标签的知识蒸馏方法,其特征在于,步骤S2包括:
利用不同类别的数据集训练出不同的教师网络模型,将图片输入至学生网络和多个教师网络模型,将学生网络头部输出和不同教师网络头部输出计算蒸馏损失,采用反向传播梯度下降方法引导学生网络训练;学生网络的类别为不同教师网络类别的组合,类别通道数一一对应并分别结算分类损失,同理,相继计算出回归损失以及偏置损失。
5.根据权利要求4所述的基于多目标检测模型无标签的知识蒸馏方法,其特征在于,利用不同类别的数据集训练出不同的教师网络模型具体为:通过Teacher-i网络中backbone模块和Neck模块得到相应的头部输出,包括heatmap图,即对应的分类信息Ki和宽高的回归信息、中心点坐标的偏置信息;其中,i∈1,2,3…n,n表示总类别数。
6.根据权利要求5所述的基于多目标检测模型无标签的知识蒸馏方法,其特征在于,引导学生网络训练时,在学生网络的分类头部中加入分类注意力模块。
7.根据权利要求6所述的基于多目标检测模型无标签的知识蒸馏方法,其特征在于,引导学生网络训练时,先将分类预测头的输入通过卷积层转化为类别特征图C×H×W,其中C为目标类别数,H和W为特征图的高度和长度尺寸,再通过卷积层构建类内注意力图HW×1×1,经过Softmax层归一化,并与原特征图进行矩阵乘法,获得类间特征图C×1×1,并通过Excitation操作,最后将类间注意力图C×1×1通过广播逐元素加法叠加到原特征图中,完成类别特征的提取。
8.根据权利要求1所述的基于多目标检测模型无标签的知识蒸馏方法,其特征在于,根据每次迭代间的损失变化比例控制损失在指定的区间内,自适应性损失为:
其中,损失指定区间为[α,β],r为上一个迭代与当前迭代的损失比例,包括Lcls、Loffset以及Lsize各自的损失,[rs,rl]为损失变化比例限定空间。
9.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,计算机程序被处理器执行时实现权利要求1至8中任一项所述的基于多目标检测模型无标签的知识蒸馏方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110838933.8A CN113610126B (zh) | 2021-07-23 | 2021-07-23 | 基于多目标检测模型无标签的知识蒸馏方法及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110838933.8A CN113610126B (zh) | 2021-07-23 | 2021-07-23 | 基于多目标检测模型无标签的知识蒸馏方法及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113610126A CN113610126A (zh) | 2021-11-05 |
CN113610126B true CN113610126B (zh) | 2023-12-05 |
Family
ID=78338219
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110838933.8A Active CN113610126B (zh) | 2021-07-23 | 2021-07-23 | 基于多目标检测模型无标签的知识蒸馏方法及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113610126B (zh) |
Families Citing this family (24)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114119959A (zh) * | 2021-11-09 | 2022-03-01 | 盛视科技股份有限公司 | 一种基于视觉的垃圾桶满溢检测方法及装置 |
CN114022494B (zh) * | 2021-11-14 | 2024-03-29 | 北京工业大学 | 一种基于轻型卷积神经网络和知识蒸馏的中医舌图像自动分割方法 |
CN114067411A (zh) * | 2021-11-19 | 2022-02-18 | 厦门市美亚柏科信息股份有限公司 | 一种人脸检测对齐网络知识蒸馏方法及装置 |
CN114095447B (zh) * | 2021-11-22 | 2024-03-12 | 成都中科微信息技术研究院有限公司 | 一种基于知识蒸馏与自蒸馏的通信网络加密流量分类方法 |
CN113822254B (zh) * | 2021-11-24 | 2022-02-25 | 腾讯科技(深圳)有限公司 | 一种模型训练方法及相关装置 |
CN114120065B (zh) * | 2021-11-30 | 2024-08-06 | 江苏集萃智能光电系统研究所有限公司 | 一种高内聚低耦合列车故障检测方法 |
CN113888538B (zh) * | 2021-12-06 | 2022-02-18 | 成都考拉悠然科技有限公司 | 一种基于内存分块模型的工业异常检测方法 |
CN114494776A (zh) * | 2022-01-24 | 2022-05-13 | 北京百度网讯科技有限公司 | 一种模型训练方法、装置、设备以及存储介质 |
CN114863248B (zh) * | 2022-03-02 | 2024-04-26 | 武汉大学 | 一种基于深监督自蒸馏的图像目标检测方法 |
CN114743243B (zh) * | 2022-04-06 | 2024-05-31 | 平安科技(深圳)有限公司 | 基于人工智能的人脸识别方法、装置、设备及存储介质 |
CN114445670B (zh) * | 2022-04-11 | 2022-07-12 | 腾讯科技(深圳)有限公司 | 图像处理模型的训练方法、装置、设备及存储介质 |
CN114926471B (zh) * | 2022-05-24 | 2023-03-28 | 北京医准智能科技有限公司 | 一种图像分割方法、装置、电子设备及存储介质 |
CN115131627B (zh) * | 2022-07-01 | 2024-02-20 | 贵州大学 | 一种轻量化植物病虫害目标检测模型的构建和训练方法 |
CN114882228B (zh) * | 2022-07-08 | 2022-09-09 | 海门市三德体育用品有限公司 | 基于知识蒸馏的健身场所布局优化方法 |
CN115019180B (zh) * | 2022-07-28 | 2023-01-17 | 北京卫星信息工程研究所 | Sar图像舰船目标检测方法、电子设备及存储介质 |
CN116204770B (zh) * | 2022-12-12 | 2023-10-13 | 中国公路工程咨询集团有限公司 | 一种用于桥梁健康监测数据异常检测的训练方法及装置 |
CN115797794A (zh) * | 2023-01-17 | 2023-03-14 | 南京理工大学 | 基于知识蒸馏的卫星视频多目标跟踪方法 |
CN116416212B (zh) * | 2023-02-03 | 2023-12-08 | 中国公路工程咨询集团有限公司 | 路面破损检测神经网络训练方法及路面破损检测神经网络 |
CN116486285B (zh) * | 2023-03-15 | 2024-03-19 | 中国矿业大学 | 一种基于类别掩码蒸馏的航拍图像目标检测方法 |
CN117315617B (zh) * | 2023-09-06 | 2024-06-07 | 武汉理工大学 | 基于师徒模式的网络优化方法、系统、电子设备及介质 |
CN116935168B (zh) * | 2023-09-13 | 2024-01-30 | 苏州魔视智能科技有限公司 | 目标检测的方法、装置、计算机设备及存储介质 |
CN117274724B (zh) * | 2023-11-22 | 2024-02-13 | 电子科技大学 | 基于可变类别温度蒸馏的焊缝缺陷分类方法 |
CN117807235B (zh) * | 2024-01-17 | 2024-05-10 | 长春大学 | 一种基于模型内部特征蒸馏的文本分类方法 |
CN118627571A (zh) * | 2024-07-12 | 2024-09-10 | 腾讯科技(深圳)有限公司 | 模型训练方法、装置、电子设备及计算机可读存储介质 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021023202A1 (zh) * | 2019-08-07 | 2021-02-11 | 交叉信息核心技术研究院(西安)有限公司 | 一种卷积神经网络的自蒸馏训练方法、设备和可伸缩动态预测方法 |
CN112529178A (zh) * | 2020-12-09 | 2021-03-19 | 中国科学院国家空间科学中心 | 一种适用于无预选框检测模型的知识蒸馏方法及系统 |
CN112560693A (zh) * | 2020-12-17 | 2021-03-26 | 华中科技大学 | 基于深度学习目标检测的高速公路异物识别方法和系统 |
CN112766087A (zh) * | 2021-01-04 | 2021-05-07 | 武汉大学 | 一种基于知识蒸馏的光学遥感图像舰船检测方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180268292A1 (en) * | 2017-03-17 | 2018-09-20 | Nec Laboratories America, Inc. | Learning efficient object detection models with knowledge distillation |
-
2021
- 2021-07-23 CN CN202110838933.8A patent/CN113610126B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021023202A1 (zh) * | 2019-08-07 | 2021-02-11 | 交叉信息核心技术研究院(西安)有限公司 | 一种卷积神经网络的自蒸馏训练方法、设备和可伸缩动态预测方法 |
CN112529178A (zh) * | 2020-12-09 | 2021-03-19 | 中国科学院国家空间科学中心 | 一种适用于无预选框检测模型的知识蒸馏方法及系统 |
CN112560693A (zh) * | 2020-12-17 | 2021-03-26 | 华中科技大学 | 基于深度学习目标检测的高速公路异物识别方法和系统 |
CN112766087A (zh) * | 2021-01-04 | 2021-05-07 | 武汉大学 | 一种基于知识蒸馏的光学遥感图像舰船检测方法 |
Non-Patent Citations (1)
Title |
---|
基于增强监督知识蒸馏的交通标识分类;赵胜伟;葛仕明;叶奇挺;罗朝;李强;;中国科技论文(20);第78-83页 * |
Also Published As
Publication number | Publication date |
---|---|
CN113610126A (zh) | 2021-11-05 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113610126B (zh) | 基于多目标检测模型无标签的知识蒸馏方法及存储介质 | |
CN108846384A (zh) | 融合视频感知的多任务协同识别方法及系统 | |
Jiang et al. | An eight-layer convolutional neural network with stochastic pooling, batch normalization and dropout for fingerspelling recognition of Chinese sign language | |
CN112036276A (zh) | 一种人工智能视频问答方法 | |
Wang et al. | SemCKD: Semantic calibration for cross-layer knowledge distillation | |
Dai et al. | Hybrid deep model for human behavior understanding on industrial internet of video things | |
CN113705218A (zh) | 基于字符嵌入的事件元素网格化抽取方法、存储介质及电子装置 | |
Wang et al. | A residual-attention offline handwritten Chinese text recognition based on fully convolutional neural networks | |
CN112818889A (zh) | 基于动态注意力的超网络融合视觉问答答案准确性的方法 | |
CN113609326A (zh) | 基于外部知识和目标间关系的图像描述生成方法 | |
Yin et al. | Self-paced active learning for deep CNNs via effective loss function | |
Gajurel et al. | A fine-grained visual attention approach for fingerspelling recognition in the wild | |
Liu et al. | Zero-shot learning with attentive region embedding and enhanced semantics | |
CN116136870A (zh) | 基于增强实体表示的智能社交对话方法、对话系统 | |
CN116089874A (zh) | 一种基于集成学习和迁移学习的情感识别方法及装置 | |
Choi et al. | Combining deep convolutional neural networks with stochastic ensemble weight optimization for facial expression recognition in the wild | |
Sun et al. | Updatable Siamese tracker with two-stage one-shot learning | |
Guo et al. | JAC-Net: Joint learning with adaptive exploration and concise attention for unsupervised domain adaptive person re-identification | |
CN115796029A (zh) | 基于显式及隐式特征解耦的nl2sql方法 | |
Ji et al. | A recognition method for Italian alphabet gestures based on convolutional neural network | |
Yu et al. | UnifiedTT: Visual tracking with unified transformer | |
Sheng et al. | Weakly supervised coarse-to-fine learning for human action segmentation in HCI videos | |
Shi | Image Recognition of Skeletal Action for Online Physical Education Class based on Convolutional Neural Network | |
Zhou et al. | Unit Correlation With Interactive Feature for Robust and Effective Tracking | |
Kindiroglu et al. | Transfer Learning for Cross-dataset Isolated Sign Language Recognition in Under-Resourced Datasets |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |