CN115953643A - 基于知识蒸馏的模型训练方法、装置及电子设备 - Google Patents

基于知识蒸馏的模型训练方法、装置及电子设备 Download PDF

Info

Publication number
CN115953643A
CN115953643A CN202211608782.8A CN202211608782A CN115953643A CN 115953643 A CN115953643 A CN 115953643A CN 202211608782 A CN202211608782 A CN 202211608782A CN 115953643 A CN115953643 A CN 115953643A
Authority
CN
China
Prior art keywords
model
teacher
student
output
distillation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211608782.8A
Other languages
English (en)
Inventor
于铭扬
唐三立
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shanghai Goldway Intelligent Transportation System Co Ltd
Original Assignee
Shanghai Goldway Intelligent Transportation System Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shanghai Goldway Intelligent Transportation System Co Ltd filed Critical Shanghai Goldway Intelligent Transportation System Co Ltd
Priority to CN202211608782.8A priority Critical patent/CN115953643A/zh
Publication of CN115953643A publication Critical patent/CN115953643A/zh
Pending legal-status Critical Current

Links

Images

Landscapes

  • Image Analysis (AREA)

Abstract

本申请实施例提供一种基于知识蒸馏的模型训练方法、装置及电子设备,涉及机器学习技术领域,实现了在单阶段检测网络的学生模型以及两阶段检测网络的教师模型的情况下,其中学生模型与教师模型为异构模型,以知识蒸馏的方式训练用于目标检测的学生模型。该方法包括:获取已经训练好的教师模型;将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,并确定特征蒸馏损失;将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,并确定输出蒸馏损失;依据特征蒸馏损失及输出蒸馏损失改进学生模型的损失函数;基于改进后的损失函数训练学生模型,得到训练完成的模型。

Description

基于知识蒸馏的模型训练方法、装置及电子设备
技术领域
本申请涉及机器学习技术领域,尤其涉及一种基于知识蒸馏的模型训练方法、装置及电子设备。
背景技术
目标检测的任务是找出图像中感兴趣的目标,确定它们的位置和类别。实现目标检测的途径通常是训练出目标检测模型并通过训练好的目标检测模型实现目标检测功能。对于训练得到的目标检测模型而言,模型的复杂度通常会随着模型精度的提高而越加复杂,进而对于部署资源的要求也就会越高。知识蒸馏是一种通过引入教师模型来诱导学生模型训练的模型压缩方法,可以实现从教师模型对学生模型的知识迁移。故而,通过知识蒸馏的方法,可以将训练得到的复杂模型作为教师模型,并将教师模型的知识迁移至结构更为简单的学生模型,使得学生模型在保留教师模型精度的同时,解决模型部署资源不足问题。
上述基于知识蒸馏所提及的教师模型与学生模型可以为同构模型,也可以为异构模型。在本申请中,目标检测模型训练过程所使用的两阶段检测网络的教师模型与单阶段检测网络的学生模型,其中学生模型与教师模型为异构模型。而异构模型所带来的预测框尺寸不匹配、类别得分意义不匹配及特征尺寸不匹配等问题目前尚未有完整的解决方法。因此,本申请针提出了一种基于知识蒸馏的模型训练方法、装置、电子设备及存储介质,以解决上述提及的一系列异构蒸馏问题。
发明内容
本申请提供一种基于知识蒸馏的模型训练方法、装置及电子设备,可以解决采用知识蒸馏的方法训练目标检测模型时,对于单阶段检测网络的学生模型以及两阶段检测网络的教师模型,也即学生模型与教师模型为异构模型的情况下,预测框尺寸不匹配、类别得分意义不匹配及特征尺寸不匹配等一系列异构蒸馏问题。
为达到上述目的,本申请实施例采用如下技术方案:
第一方面,本申请提供了一种基于知识蒸馏的模型训练方法,该方法中,知识蒸馏中所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络,也即教师模型与学生模型为异构模型,将教师模型中的知识蒸馏迁移至学生模型,经过知识蒸馏过程的学生模型为训练得到的模型,用于对输入图像进行目标检测,该方法包括:获取已经训练好的教师模型;将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,并确定特征蒸馏损失,特征蒸馏损失表示特征尺度对齐后,教师模型与学生模型特在特征层的差异程度;将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,并确定输出蒸馏损失,输出蒸馏损失表示预测框及对应的概率分布对齐后,教师模型与学生模型在输出层的差异程度,概率分布为预测框的输出的概率分布;依据特征蒸馏损失及输出蒸馏损失改进学生模型的损失函数;基于改进后的损失函数训练学生模型,得到训练完成的模型。
本申请实施例提供的技术方案至少带来以下有益效果:
本申请采用知识蒸馏的方式训练得到用于目标检测的学生模型,且知识蒸馏中所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络,也即教师模型与学生模型为异构模型。在训练时,本申请通过将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,解决了学生模型与教师模型输出信息不对齐的问题;通过将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,解决了学生模型与教师模型特征信息不对齐的问题;通过确定特征蒸馏损失和输出蒸馏损失对学生模型的损失函数进行改进,并通过改进后的损失函数指导学生模型的训练,可以使学生模型学习到教师模型的知识,在不改变学生模型结构的前提下,使学生模型拥有接近教师模型的泛化能力和功能。综上所述,本申请一方面提出了完整的异构蒸馏方案,在所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络,也即教师模型与学生模型为异构模型的情况下,采用知识蒸馏的方式实现了目标检测模型的训练。另一方面,通过利用结构更复杂,功能更完善的教师模型对学生模型进行知识蒸馏,可以在不改变学生模型结构的情况下,令学生模型具有与教师模型相近的能力。同时,由于学生模型结构简单,对于部署资源性能的要求较低,训练得到的用于目标检测的学生模型的应用场景也会更广阔。
在一种可能的实现方式中,上述将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,包括:获取教师模型的教师特征集合;其中,教师特征集合中的教师特征是教师模型的输入的第一图像中目标区域的特征,目标区域为基于锚框确定的区域,目标区域的特征包括目标区域的前景区域特征及背景区域特征;在学生模型的各层特征中获取的学生特征集合;其中,学生特征集合中的学生特征是目标区域在学生模型中的特征;将教师特征集合与学生特征集合中的特征尺度转换至同一维度。
在该种可能的实现方式中,提出了一种解决知识蒸馏过程中教师模型与学生模型特征尺度不匹配问题的方法,将学生模型及教师模型的特征信息转换至同一维度,有助于提升方案的可实施性。
在一种可能的实现方式中,上述在学生模型的各层特征中获取的学生特征集合,包括:基于输入学生模型的第一图像中的目标区域提取各层的学生区域特征,得到学生特征集合。
在一种可能的实现方式中,上述确定在特征层的特征蒸馏损失,包括:基于注意力机制或余弦相似度算法计算教师特征集合中的各教师特征与学生特征集合中的各学生特征之间的特征相似度,并依据特征相似度确定特征蒸馏损失。
在该种可能的实现方式中,可以得到教师特征集合中的各特征与学生集合中的各特征两两之间的特征相似度,据此可以进一步计算出特征对齐过程中的特征蒸馏损失,以便后续基于该特征蒸馏损失改进学生模型的损失函数,使训练得到的学生模型具有更好的目标检测性能。
在一种可能的实现方式中,上述将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,包括:计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框;获取目标预测框对应的目标概率分布,目标概率分布为目标预测框的输出的概率分布;将目标概率分布与教师模式输出层的预测框对应的概率分布对齐。
在该种可能的实现方式中,通过计算各教师预测框与学生预测框的交并比,选出与教师预测框最匹配的学生预测框,并将选出的学生预测框的概率分布与对应的教师预测框的概率分布进行含义对齐,解决了异构蒸馏中预测框尺寸不一致以及类别得分意义不对齐的问题。
在一种可能的实现方式中,上述计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框,包括:将输入图像对应的特征图划分为多个网格,并计算各教师预测框的中心点在特征图中的中心点所属网格;在每一个教师预测框对应的中心点所属网格中,计算该中心点所属网格中各学生预测框与教师预测框的交并比;选择数值最大的所述交并比对应的学生预测框作为目标预测框。
在该种可能的实现方式中,确定每个教师预测框的中心点所属的网格,计算各教师预测框与所有在其中心点所属网格中的各学生预测框的交并比,并根据计算出的交并比选择出最匹配的学生预测框,提出了方案的可实施性。此外,本方式将匹配范围限制在了同一特征网格内,可以应对部分学生模型训练过程中的偏移值有区间限制的情况。
在一种可能的实现方式中,上述计计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框,包括:对于每一个教师预测框,计算该教师预测框与特征图中各学生预测框的交并比;选择数值最大的交并比对应的学生预测框作为目标预测框。
在该种可能的实现方式中,通过计算各教师预测框与特征图中所有学生预测框的交并比来为各教师预测框选择最匹配的学生预测框,可以实现学生模型预测框没有区间限制情况下的预测框对齐,有助于提升方案的可实施性。
在一种可能的实现方式中,上述确定输出蒸馏损失,包括:基于学生模型的输出概率分布与前景置信度的乘积以及教师模型输出概率分布,计算输出蒸馏损失。
在该种可能的实现方式中,提出了一种确定输出蒸馏损失的具体实现方法,提升了方案的可实施性。此外,通过计算输出蒸馏损失,可在后续基于该输出蒸馏损失对学生模型的损失函数进行改进,以使训练得到的学生模型具有更好的目标检测性能。
在一种可能的实现方式中,上述确定在输出层的输出蒸馏损失,包括:对所述目标概率分布对应的所述学生模型的输出概率与所述教师模型的输出概率归一化,并基于归一化的结果确定所述输出蒸馏损失。
在该种可能的实现方式中,提出了另一种确定输出蒸馏损失的具体实现方法,提升了方案的可实施性。此外,通过计算输出蒸馏损失,可在后续基于该输出蒸馏损失对学生模型的损失函数进行改进,以使训练得到的学生模型具有更好的目标检测性能。
在一种可能的实现方式中,上述依据特征蒸馏损失及输出蒸馏损失改进学生模型的损失函数,包括:依据学生模型的改进前的检测损失、输出蒸馏损失以及特征蒸馏损失,算学生模型的改进后的损失函数。
在该种可能的实现方式中,可以实现学生模型改进损失函数的计算,以便基于改进后的损失函数指导学生模型的训练,以使学生模型具有更好的目标检测功能,有助于提升方案的可实施性。
第二方面,本申请实施例提供一种基于知识蒸馏的模型训练装置,该装置具有实现上述第一方面中任一项的基于知识蒸馏的模型训练方法的功能。该功能可以通过硬件实现,也可以通过硬件执行相应的软件实现。该硬件或软件包括一个或多个与上述功能相对应的模块。
第三方面,本申请提供一种电子设备,该电子设备包括存储器和处理器。上述存储器和处理器耦合。该存储器用于存储计算机程序代码,该计算机程序代码包括计算机指令。当处理器执行该计算机指令时,使得电子设备执行如第一方面及其任一种可能的设计方式所述的基于知识蒸馏的模型训练方法。
第四方面,本申请提供一种计算机可读存储介质,该计算机可读存储介质存储有计算机指令,当所述计算机指令在电子设备上运行时,使得电子设备执行如第一方面及其任一种可能的设计方式所述的基于知识蒸馏的模型训练方法。
第五方面,本申请提供一种计算机程序产品,该计算机程序产品包括计算机指令,当计算机指令在电子设备上运行时,使得电子设备执行如第一方面及其任一种可能的设计方式所述的基于知识蒸馏的模型训练方法。
本申请中第二方面到第五方面及其各种实现方式的具体描述,可以参考第一方面及其各种实现方式中的详细描述;并且,第二方面到第五方面及其各种实现方式的有益效果,可以参考第一方面及其各种实现方式中的有益效果分析,此处不再赘述。
本申请的这些方面或其他方面在以下的描述中会更加简明易懂。
附图说明
图1为现有技术中教师模型与学生模型特征尺度不一致的示意图;
图2为本申请实施例提供的一种基于知识蒸馏的模型训练方法所涉及的实施环境示意图;
图3为本申请实施例提供的一种基于知识蒸馏的模型训练方法的流程图;
图4为本申请实施例提供的一种基于知识蒸馏的模型训练方法中在教师模型中确定输入图像的目标区域的示意图;
图5为本申请实施例提供的一种基于知识蒸馏的模型训练方法中基于注意力机制计算特征相似度及匹配得分的示意图;
图6为本申请实施例提供的一种基于知识蒸馏的模型训练方法中一种确定目标预测框的方法的示意图;
图7为本申请实施例提供的一种基于知识蒸馏的模型训练方法的具体实施例的异构蒸馏整体框架的示意图;
图8为本申请实施例提供的一种基于知识蒸馏的模型训练方法的具体实施例的流程图;
图9为本申请实施例提供的一种基于知识蒸馏的目标检测模型训练装置的结构示意图;
图10为本申请实施例提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本申请的描述中,除非另有说明,“多个”的含义是两个或两个以上。
在本申请实施例中,“示例性的”或者“例如”等词用于表示作例子、例证或说明。本申请实施例中被描述为“示例性的”或者“例如”的任何实施例或设计方案不应被解释为比其它实施例或设计方案更优选或更具优势。确切而言,使用“示例性的”或者“例如”等词旨在以具体方式呈现相关概念。
首先,对本申请实施例涉及的技术术语进行介绍:
1、知识蒸馏
知识蒸馏是模型压缩的一种常用方法,不同于剪枝和量化,知识蒸馏是通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息来训练上述构建的小模型,以期达到更好的性能和精度。其中,上述用于训练的性能更好的大模型为教师模型,被训练的轻量化的小模型为学生模型。
2、目标检测
目标检测的任务是找出图像中所有感兴趣的目标,确定它们的类别和位置,是计算机视觉领域的核心问题之一。
知识蒸馏可以将学习能力相对更强的教师模型的知识传递到学习能力相对较弱的学生模型。现有的知识蒸馏技术多为教师模型与学生模型为同构模型的情况,而当教师模型和学生模型为异构模型,例如知识蒸馏中所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络时,还未有完整方案来解决异构蒸馏中预测框尺寸不匹配、类别得分意义不匹配、特征尺寸不匹配等问题。
其中,候选框尺寸不一致是指在异构蒸馏中,学生模型与教师模型的预测框尺寸完全不同,无法直接对应匹配;类别得分意义不对齐是指在异构蒸馏中,教师模型与学生模型的类别输出意义不一致。以教师模型为两阶段模型,学生模型为单阶段模型为例,教师模型的类别得分输出为C+1维,其中C为类别数,额外的一维是当前区域为背景区域的得分,使用softmax函数(柔性最大传递函数,为激活函数)激活;部分单阶段学生模型的类别得分输出为C维,且额外同过conf分支(用于区别当前区域为前景区域还是背景区域的置信度分支)输出当前区域是前景的得分,均使用sigmoid函数(一类具有指数函数形状的激活函数,具有指数函数形状,在物理意义上最为接近生物神经元)激活。激活函数的不同意味着二者的功能含义不同;特征尺度不对齐指的是,当异构蒸馏目标为neck(一系列混合和组合图像特征的网络层)特征时,教师模型下采样倍数与学生模型下采样倍数无法对齐,如图1所示。而在实际问题中,教师模型与学模型特征尺度不对齐问题更加严重。例如,在一种可能的情况下,教师模型neck输出5个尺度的特征,而学生模型neck只输出3个甚至1个尺度的特征。若按照传统的同构特征蒸馏方案,至少有2层教师neck特征无法参与蒸馏,教师特征知识的利用率较低。
为了解决现有技术中存在的上述问题,本申请提出了一种基于知识蒸馏的模型训练方法,下面将结合附图对本申请实施例的实施方式进行详细描述:
请参考图2,其示出本申请实施例提供的一种基于知识蒸馏的模型训练方法的实施环境示意图。如图2所示,该实施环境可以包括:终端210、终端220以及服务器230。其中,终端210及终端220用于部署教师模型或学生模型,服务器230用于异构蒸馏及基于异构蒸馏过程训练学生模型。
示例性地,上述终端210用于部署教师模型,上述终端220用于部署学生模型,其中,上述教师模型与学生模型为异构模型,则上述基于知识蒸馏的模型训练方法包括:将已经训练好的用于目标检测的教师模型部署于终端210,将用于目标检测的学生模型部署于终端220,服务器230从终端210中提取教师模型的特征信息,从终端220中提取学生模型的特征信息,且在提取学生模型的特征信息时,需将学生模型的特征信息与教师模型的特征信息进行匹配,并计算得到特征蒸馏损失;服务器230从终端210中提取教师模型的输出信息,从终端220中提取学生模型的输出信息,并将学生模型的输出信息与教师模型的输出信息进行匹配,计算得到输出蒸馏损失;服务器230基于上述计算得到的特征蒸馏损失函数与输出蒸馏损失函数对学生模型的损失函数进行改进,并基于改进后的损失函数指导学生模型的训练,以使训练得到的学生模型拥有与教师模型相近的目标检测功能。
需要说明的是,本申请实施例中的终端210及终端220可以是手机、平板电脑、笔记本电脑、以及其他具有数据处理能力的设备等,本申请实施例对该终端210及终端220的具体形态不作特殊限制。
本申请实施例训练得到的学生模型可用于计算机视觉应用领域的各种场景,或者是部署到边缘设备上(例如移动电话、可穿戴设备、计算节点等)的基于神经网络模型的处理系统,或者是由于有限资源和时延要求需要对神经网络模型进行压缩的应用场景。
示例性地,上述终端220可以为智能手机,上述学生模型既保留了与教师模型相近的目标检测功能,又因结构简单可部署于有限资源的终端,故可将上述学生模型部署于该智能手机终端220,用户可以使用手机自动抓取人脸、动物等目标,从而可以在拍照时帮助手机自动对焦、美化等,给用户带来更好的用户体验。此外,上述终端220还可用于自动驾驶场景分割的应用场景。自动驾驶车辆的摄像头捕捉到道路画面后需要对画面进行分割,从中分出路面、路基、车辆、行人等不同物体,从而保持车辆行驶在正确的区域。本申请实施例可以依据具体应用环境进行调整和改进,此处不做具体限定。
下面结合图3所示的流程,对本申请实施例提供的基于知识蒸馏的模型训练方法进行详细说明,如图3所示,该方法可以包括S301-S305。
S301、获取已经训练好的所述教师模型。
其中,上述教师模型可以用于对所述输入图像进行目标检测。
本示例实施方式所提出的方法用于在教师模型和学生模型为异构模型,例如知识蒸馏中所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络时,将知识蒸馏用于目标检测模型的训练。
上述教师模型为知识蒸馏中功能更强的用于提供监督信息的大模型,上述学生模型为学习教师模型知识的小模型。示例性地,上述教师模型可以为对锚框进行两阶段或多阶段优化的模型,如Faster RCNN、Cascade RCNN等模型,上述学生模型可以为仅对锚框进行单阶段优化的模型,如RetinaNet、YOLO等模型,因教师模型与学生模型的优化阶段不同,锚框尺寸也不一致,在进行知识蒸馏时无法直接进行匹配,本实施例提出的方案用于解决这一异构蒸馏问题。
需要说明的是,上述情况只是一种示例性说明,其他教师模型与学生模型为异构模型的情况也属于本示例实施方式的保护范畴。
在本示例实施方式中,上述教师模型为预先训练好的模型,该教师模型具有较强的目标检测能力,但结构也相对复杂,对部署资源的要求相对更高。上述输入图像指输入该教师模型的图像数据。上述该教师模型用于对输入图像进行目标检测可以理解为该教师模型对输入模型的图像数据中的物体进行分类及定位。示例性地,上述输入图像可以为一副包含人、狗、树等多种物体的照片,将该照片输入教师模型,教师模型可以检测出图像中的目标物体为人、狗还是树,可以定位人在照片中的位置。需要说明的是,上述场景只是一种示例性说明,本示例实施方式对此不做特殊限定。
S302、将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,并确定在特征层的特征蒸馏损失。
其中,特征蒸馏损失表示特征尺度对齐后,教师模型与学生模型特在特征层的差异程度。
在本示例实施方式中,上述将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,可以具体实现为以下步骤1至步骤3:
步骤1、获取教师模型的教师特征集合。
其中,教师特征集合中的教师特征是教师模型的输入的第一图像中目标区域的特征,目标区域为基于锚框确定的区域,目标区域的特征包括目标区域的前景区域特征及背景区域特征。
可选的,可以基于锚框在输入图像中确定一个目标区域,并获取该目标区域对应的目标区域特征,作为教师特征集合,该目标区域特征既包括目标区域的前景特征,也包括目标区域的背景特征,如图4所示,401为图像的一个锚框,该锚框确定的区域即为目标区域,该区域中的前景特征与背景特征均为目标区域特征。
需要说明的是,上述场景只是一种示例性说明,本示例实施方式的保护范畴并不以此为限,上述目标区域也可以图像中的其他区域。
步骤2、在学生模型的各层特征中获取学生特征集合。
其中,上述学生特征集合中的学生特征是该目标区域在学生模型中的特征。
可选的,可以确定上述锚框确定的教师模型中输入图像(例如上述第一图像)的目标区域的位置,并在该图像(例如上述第一图像)输入学生模型时,在该图像的对应位置中获取学生模型的区域特征,得到学生特征集合。
步骤3、将教师特征集合与学生特征集合中的特征尺度转换至同一维度。
示例性地,上述将教师特征集合与学生特征集合中的特征尺度转换至同一维度可以实现如下:依次对上述学生特征集合进行尺度变换和通道维度变换,使得教师模型与学生模型的特征尺度变换到同一纬度。具体地,上述转换过程可通过维度转化模块实现,该维度转化模块可以包括RoIAlign(用于将任意尺寸的目标区域的特征图转换为具有固定尺寸的小特征图)和Adaptive Layer(自适应层)。
上述确定在特征层的特征蒸馏损失可以实现如下:计算教师特征集合中的各教师特征与学生特征集合中的各学生特征之间的特征相似度,并归一化得到对应匹配得分;依据各匹配得分以及与各匹配得分对应的蒸馏损失计算得到特征蒸馏损失。
示例性地,上述计算教师特征集合中的各教师特征与学生特征集合中的各学生特征之间的特征相似度,并归一化得到对应匹配得分可以采用基于注意力机制的方式或余弦相似度实现。
具体地,基于注意力机制的方式可以如图5所示的流程实现:教师模型的特征信息,也即教师特征集合FeatT依次经过avg pooling(平均滤波卷积)、全连接层Query layer全连接层,且与教师特征集合对应的目标区域位于对应区域的学生特征FeatS经过avgpooling(平均滤波卷积)、全连接层Key layer后,对教师特征与学生特征进行相似度计算,并通过softmax函数进行归一化得到匹配得分α。
可选的,基于余弦相似度算法的方式则可具体实现如下:教师特征和学生特征分别经过avg pooling(平均滤波卷积)后,计算教师特征和学生特征间的余弦相似度,并通过softmax函数归一化得到匹配得分α。其中,上述匹配过程为教师特征集合与学生特征集合中的特征两两之间进行匹配,得到对应的匹配得分α。
示例性地,上述依据各匹配得分以及与各匹配得分对应的蒸馏损失计算得到特征蒸馏损失可基于以下公式计算得到:
L1=∑ijαij·LMSE(fi,fj)
其中,fi和fj分别表示单层教师特征和学生特征,α表示上述过程计算得到的匹配得分。LMSE(,)表示MSE损失函数。
S303、将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,并确定在输出层的输出蒸馏损失。
其中,上述输出蒸馏损失表示预测框及对应的概率分布对齐后,教师模型与学生模型在输出层的差异程度,概率分布为预测框的输出的概率分布。
在本示例实施方式中,该步骤首先获取教师模型输出层的预测框及对应的概率分布,然后将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,并计算输出蒸馏损失,以完成对输出层的知识蒸馏。
在本示例实施方式中,上述教师预测框与学生预测框即为用于在上述输入图像中定位的锚框,上述教师模型的概率分布可以为对应教师预测框输出的类别概率分布,上述学生模型的概率分布可以为对应学生预测框输出的类别概率分布。
示例性地,上述将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐可以实现如下:计算学生预测框与教师预测框的交并比,并依据交并比在多个学生预测框中选择与对应的教师预测框匹配的目标预测框;获取上述目标预测框对应的目标概率分布;将目标概率分布与相匹配的教师预测框对应的概率分布进行含义对齐。
在一种可能的实现方式中,上述计算学生预测框与教师预测框的交并比,并依据交并比在多个学生预测框中选择与对应的教师预测框匹配的目标预测框可实现如下:
将输入图像对应的特征图划分为多个网格;计算各教师预测框的中心点在特征图中的中心点所属网格;在每一个教师预测框对应的中心点所属网格中,计算该中心点所属网格中各学生预测框与教师预测框的交并比;选择数值最大的交并比对应的学生预测框作为目标预测框。
示例性的,以图6为例对上述过程进行更加详细的说明:如图6所示,输入图像被划分为多个网格,对于每一个教师预测框,计算其中心点在图像中所属的网格,假设当前教师预测框所属网格为图6中的网格601,则获取该网格601中所有的学生预测框,并计算各学生预测框与当前教师框的交并比,选取数值最大的交并比对应的学生预测框作为当前教师预测框对应的目标预测框,对其他所有教师预测框重复上述过程获取对应的目标预测框。
需要说明的是,上述可能的实现方式将匹配范围限制在了同一特征网格内,可以应对部分学生模型训练过程中预测框的偏移值有区间限制的情况。该情况下教师提供的监督值中没有区间限制,而学生模型有区间上限,无法完全拟合。
在另一种可能的实现方式中,上述计算学生预测框与教师预测框的交并比,并依据交并比在多个学生预测框中选择与对应的教师预测框匹配的目标预测框可实现如下:
对于每一个教师预测框,计算该教师预测框与特征图中各学生预测框的交并比;选择数值最大的交并比对应的学生预测框作为目标预测框。
在该种可能的实现方式中,将输入图像的特征图作为一个整体,不进行网格划分。对于每一个教师预测框,获取该特征图中的全部学生预测框,并计算教师预测框与全部学生预测框的交并比,选取数值最大的交并比对应的学生预测框作为目标预测框。需要说明的是,上述可能的实现方式适用于学生模型预测框的偏移值没有区间限制的情况。
进一步地,在一些实施例中,上述确定输出蒸馏损失的可以实现如下:将目标概率分布与学生模型的前景置信度相乘,以将目标概率分布与相匹配的教师预测框对应的概率分布进行含义对齐;基于目标概率分布与前景置信度的乘积、教师模型及学生模型的输出概率分布计算输出蒸馏损失,输出蒸馏损失的具体计算公式可如下:
L2=h(softmax(p),conf*sigmoid(q))
其中,sigmoid(q)为学生模型的输出概率分布,softmax(p)为教师模型的输出概率分布,softmax、sigmoid函数的操作与网络训练过程操作保持一致,conf为学生模型的前景置信度,该前景置信度用于表示当前预测框确定的区域为前景区域的概率,损失函数h()可以是KL损失、ce损失等形式。
此外,在另一实施例中,上述确定输出蒸馏损失的还可以实现如下:对目标概率分布对应的学生模型的输出概率与教师模型的输出概率归一化,以将目标概率分布与相匹配的教师预测框对应的概率分布进行含义对齐;基于目标概率分布对应的学生模型的输出概率分布与教师模型的输出概率分布计算输出蒸馏损失,输出蒸馏损失的具体计算公式可如下:
L2=h(softmax(p),softmax(q))
其中,softmax(p)为学生模型的输出概率分布,softmax(q)为教师模型的输出概率分布,softmax、sigmoid为归一化函数,损失函数h()可以是KL损失、ce损失等形式。
S304、依据特征蒸馏损失及输出蒸馏损失改进学生模型的损失函数。
在本示例实施方式中,该步骤用于改进学生模型的损失函数,并基于改进后的损失函数指导学生模型的训练,以获取与教师模型具有相近目标检测功能的学生模型。该过程可以实现如下:依据学生模型的原始检测损失,上述计算得到的特征蒸馏损失L1以及输出蒸馏损失L2计算得到学生模型的改进后的损失函数。
示例性地,上述计算改进后的损失函数可实现为:将输出蒸馏损失与对应的输出权重参数相乘,得到第一乘积;将特征蒸馏损失与对应的特征权重参数相乘,得到第二乘积;将原始检测损失、第一乘积及第二乘积相加,得到改进后的损失函数,改进后的损失函数的具体计算公式如下:
L=Ldet+β·L1+γ·L2
其中,Ldet为学生模型的原始检测损失,L1和L2分别为特征异构蒸馏损失和输出异构蒸馏损失,β和γ分别为特征异构蒸馏损失和输出异构蒸馏损失对应的损失权重超参。
S305、基于改进后的所述损失函数训练学生模型,得到训练完成的模型。
下面,结合图7与图8所示的具体应用场景,对上述基于知识蒸馏的模型训练方法进行完整详细地说明:
图7为上述基于知识蒸馏的模型训练方法的异构蒸馏整体框架,应用于教师模型与学生模型为异构模型时,通过知识蒸馏的方式训练目标检测模型。该框架包括两阶段教师模型及单阶段学生模型,分别从教师模型及学生模型中提取对应的特征层信息及输出层信息,并进行特征异构蒸馏及输出异构蒸馏来改进学生模型。
对图7对应的具体实施过程如图8所示,包括以下步骤:
S801、获取已经训练好的教师模型,该教师模型用于对输入图像进行目标检测。
该步骤用于获取预先训练好的教师模型,该教师模型为知识蒸馏中功能更强的用于提供监督信息的大模型,用于实现目标检测。
S802、对教师模型与学生模型进行特征异构蒸馏。
该步骤用于对教师模型与学生模型进行特征异构蒸馏,并计算特征异构蒸馏损失。具体实现如下:
S8021、获取上述教师模型的目标区域特征,既包含前景区域,也包含背景区域,得到教师模型对应的教师特征集合FeatT。
S8022:在学生模型的各层特征中得到教师模型的目标区域所对应的区域特征,得到学生特征集合FeatS。
S8023:将教师特征集合与学生特征集合中特征的尺度通过维度转换模块变换到同一维度。
具体地,上述维度转换模块可以包含RoIAlign和Adaptive Layer(1x1 conv),该步骤可实现为:通过该维度转换模块依次对FeatS进行尺度变换和通道维度转换,从而使教师特征集合与学生特征集合的特征尺度变换到同一维度。
S8024:针对变换后的特征集合做多对多的匹配,匹配得分用来对损失进行加权。
实现方式1:可以采用注意力机制计算特征之间的匹配得分。教师特征集合依次经过avg pooling(平均滤波卷积)、全连接层Query layer全连接层,且与教师特征集合对应的目标区域位于对应区域的学生特征经过avg pooling(平均滤波卷积)、全连接层Keylayer后,对教师特征与学生特征进行相似度计算,并通过softmax函数进行归一化得到匹配得分α。
实现方式2:采用余弦相似度计算匹配得分。教师特征集合和学生特征集合分别经过avg pooling后,计算教师特征和学生特征间的余弦相似度,最后通过softmax归一化得到匹配得分α。
S8025:匹配得分与对应特征蒸馏损失相乘作为得到最终的特征异构蒸馏损失。
L1=∑ijαij·LMSE(fi,fj)
其中,fi和fj分别表示单层教师特征和学生特征,α表示上述过程计算得到的匹配得分。LMSE(,)表示MSE损失函数。
S803、对教师模型与学生模型进行输出异构蒸馏。
该步骤用于对教师模型与学生模型进行输出异构蒸馏,并计算输出异构蒸馏损失。具体实现如下:
S8031:获取教师模型和学生模型的预测框。
S8032:对教师模型的预测框和学生模型的预测框进行位置匹配,预测框匹配方式为:计算学生模型和教师模型的预测框的交并比,并选择交并比最大的预测框在训练过程中在线匹配。示例性地,具体实现可以如下:
实现方式1:将特征图划分为SxS(S为正整数)的网格;计算每个教师预测框中心点所属的网格;每个教师预测框与中心点所属的网格内的所有学生预测框计算交并比;选择交并比最大的学生预测框与其组成匹配对。
该实现方式将匹配范围限制在了同一特征网格内,可以应对部分学生模型训练过程中预测框的偏移值有区间限制的情况。该情况下教师提供的监督值中没有区间限制,而学生模型有区间上限,无法完全拟合。
实现方式2:每个教师预测框与特征图中的所有学生预测框计算交并比;选择交并比最大的学生预测框与其组成匹配对。该实现方式适用于学生模型预测框的偏移值没有区间限制的情况。
S8033:获取已匹配的教师模型预测框及学生模型预测框对应的概率分布。
S8034:将教师模型预测框的概率分布与学生模型预测框的概率分布进行意义对齐,并计算输出蒸馏损失。
实现方式1:具体的输出蒸馏损失形式由下式所示:
L2=h(softmax(p),conf*sigmoid(q))
其中,q为学生模型的输出概率分布,p为教师模型的输出概率分布,softmax、sigmoid函数的操作与网络训练过程操作保持一致,conf为学生模型的前景置信度,该前景置信度用于表示当前预测框确定的区域为前景区域的概率,将学生的前景置信度得分conf与分类概率分布相乘,可以使概率分布含义对齐,损失函数h()可以是KL损失、ce损失等形式。
实现方式2:
L2=h(softmax(p),softmax(q))
学生模型与教师模型的输出概率的归一化形式一致时,统一归一化得到意义一致的输出概率分布,并通过损失函数h()计算蒸馏损失。其中,q为学生模型的输出概率分布,p为教师模型的输出概率分布,softmax、sigmoid为归一化函数,损失函数h()可以是KL损失、ce损失等形式。
S804、改进学生模型的损失函数,并基于改进后的损失函数训练学生模型。
改进后的损失函数的具体计算公式如下:
L=Ldet+β·L1+γ·L2
其中,Ldet为学生模型的原始检测损失,L1和L2分别为特征异构蒸馏损失和输出异构蒸馏损失,β和γ分别为特征异构蒸馏损失和输出异构蒸馏损失对应的损失权重超参。
本示例实施方式通过提供完整的异构蒸馏方案,可以在目标检测模型的训练过程中,使用性能更好的两阶段检测网络作为教师模型,从而提高了蒸馏的性能上限,拓展了教师的选择范围和蒸馏算法的应用范围。此外,本实施例方案已在Adas(Advanced DriverAssistance System,先进驾驶辅助系统)业务场景、公开数据集VOC均验证有效。
上述主要从方法的角度对本申请实施例提供的方案进行了介绍。为了实现上述功能,其包含了执行各个功能相应的硬件结构和/或软件模块。本领域技术目标应该很容易意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,本申请能够以硬件或硬件和计算机软件的结合形式来实现。某个功能究竟以硬件还是计算机软件驱动硬件的方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术目标可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
本申请实施例还提供一种基于知识蒸馏的目标检测模型训练装置,其中,教师模型与学生模型为异构模型,例如知识蒸馏中所涉及的教师模型为两阶段检测网络、学生模型为单阶段检测网络。如图9所示,为本申请实施例提供的一种基于知识蒸馏的模型训练装置900的结构示意图。该装置900可以包括:教师模型获取模块901、特征蒸馏模块902、输出蒸馏模块903和学生模型训练模块904,其中:
教师模型获取模块901,可以用于获取已经训练好的所述教师模型。
特征蒸馏模块902,可以用于将学生模型特征层的特征尺度与教师模型特征层的特征尺度对齐,并确定特征蒸馏损失,特征蒸馏损失表示特征尺度对齐后,教师模型与学生模型特在特征层的差异程度。
输出蒸馏模块903,可以用于将学生模型输出层的预测框及对应的概率分布与教师模型输出层的预测框及对应的概率分布对齐,并确定输出蒸馏损失,输出蒸馏损失表示预测框及对应的概率分布对齐后,教师模型与学生模型在输出层的差异程度,概率分布为预测框的输出的概率分布。
学生模型训练模块904,可以用于依据特征蒸馏损失及输出蒸馏损失改进学生模型的损失函数;并基于改进后的损失函数训练学生模型,得到训练完成的模型。
在一种可能的实现方式中,上述特征蒸馏模块具体用于:获取教师模型的教师特征集合;其中,教师特征集合中的教师特征是教师模型的输入的第一图像中目标区域的特征,目标区域为基于锚框确定的区域,目标区域的特征包括目标区域的前景区域特征及背景区域特征;在学生模型的各层特征中获取学生特征集合;其中,学生特征集合中的学生特征是目标区域在学生模型中的特征;将教师特征集合与学生特征集合中的特征尺度转换至同一维度。
具体地,上述学生模型的各层特征中获取的学生特征集合,包括:基于输入学生模型的第一图像中的目标区域提取各层的学生区域特征,得到学生特征集合。
具体地,上述确定在特征层的特征蒸馏损失,包括:基于注意力机制或余弦相似度算法计算教师特征集合中的各教师特征与学生特征集合中的各学生特征之间的特征相似度,并归一化得到对应匹配得分;依据各匹配得分以及与各匹配得分对应的蒸馏损失计算得到特征蒸馏损失。
在一种可能的实现方式中,上述输出蒸馏模块具体用于:计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框;获取目标预测框对应的目标概率分布,目标概率分布为目标预测框的输出的概率分布;将目标概率分布与教师模式输出层的预测框对应的概率分布对齐。
在一种可能的实现方式中,上述计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框,包括:将输入图像对应的特征图划分为多个网格,并计算各教师预测框的中心点在特征图中的中心点所属网格;在每一个教师预测框对应的中心点所属网格中,计算该中心点所属网格中各学生预测框与教师预测框的交并比;选择数值最大的所述交并比对应的学生预测框作为目标预测框。
在另一种可能的实现方式中,上述计算学生模型输出层的预测框与教师模式输出层的预测框的交并比,并依据交并比在多个学生模型输出层的预测框中选择与对应的教师模式输出层的预测框匹配的目标预测框,包括:对于每一个教师预测框,计算该教师预测框与特征图中各学生预测框的交并比;选择数值最大的交并比对应的学生预测框作为目标预测框。
在一种可能的实现方式中,上述确定输出蒸馏损失,包括:基于学生模型的输出概率分布与前景置信度的乘积以及教师模型输出概率分布,计算输出蒸馏损失。
在另一种可能的实现方式中,上述确定输出蒸馏损失,包括:对目标概率分布对应的学生模型的输出概率与教师模型的输出概率归一化,以将目标概率分布与相匹配的教师预测框对应的概率分布进行含义对齐;基于目标概率分布对应的学生模型的输出概率分布与教师模型的输出概率分布计算输出蒸馏损失。
在一种可能的实现方式中,上述依据特征蒸馏损失及输出蒸馏损失改进学生模型的损失函数,包括:依据学生模型的改进前的检测损失、输出蒸馏损失以及特征蒸馏损失,算学生模型的改进后的损失函数。
当然,本申请实施例提供的基于知识蒸馏的模型训练装置900包括但不限于上述模块。
本申请另一实施例还提供一种电子设备。如图10所示,电子设备1000包括存储器1001和处理器1002;存储器1001和处理器1002耦合;存储器1001用于存储计算机程序代码,计算机程序代码包括计算机指令。其中,当处理器1002执行计算机指令时,使得电子设备900执行上述方法实施例所示的方法流程中电子设备执行的各个步骤。
在实际实现时,教师模型获取模块901、特征蒸馏模块902、输出蒸馏模块903和学生模型训练模块904可以由图10所示的处理器1002调用存储器1001中的计算机程序代码来实现。其具体的执行过程可参考上述方法部分的描述,这里不再赘述。
本申请另一实施例还提供一种计算机可读存储介质,该计算机可读存储介质中存储有计算机指令,当计算机指令在电子设备上运行时,使得电子设备执行上述方法实施例所示的方法流程中电子设备执行的各个步骤。
在本申请另一实施例中还提供一种计算机程序产品,该计算机程序产品包括计算机指令,当计算机指令在电子设备上运行时,使得电子设备执行上述方法实施例所示的方法流程中电子设备执行的各个步骤。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件程序实现时,可以全部或部分地以计算机程序产品的形式来实现。该计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行计算机执行指令时,全部或部分地产生按照本申请实施例的流程或功能。计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,计算机指令可以从一个网站站点、计算机、服务器或者数据中心通过有线(例如同轴电缆、光纤、数字用户线(digitalsubscriber line,DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可以用介质集成的服务器、数据中心等数据存储设备。可用介质可以是磁性介质(例如,软盘、硬盘、磁带),光介质(例如,DVD)等。
以上所述,仅为本申请的具体实施方式。熟悉本技术领域的技术人员根据本申请提供的具体实施方式,可想到变化或替换,都应涵盖在本申请的保护范围之内。

Claims (13)

1.一种基于知识蒸馏的模型训练方法,其特征在于,知识蒸馏中所涉及的教师模型与学生模型为异构模型,所述方法包括:
获取已经训练好的所述教师模型;
将所述学生模型特征层的特征尺度与所述教师模型特征层的特征尺度对齐,并确定特征蒸馏损失,所述特征蒸馏损失表示特征尺度对齐后,所述教师模型与所述学生模型特在特征层的差异程度;
将所述学生模型输出层的预测框及对应的概率分布与所述教师模型输出层的预测框及对应的概率分布对齐,并确定输出蒸馏损失,所述输出蒸馏损失表示预测框及对应的概率分布对齐后,所述教师模型与所述学生模型在输出层的差异程度,所述概率分布为所述预测框的输出的概率分布;
依据所述特征蒸馏损失及所述输出蒸馏损失改进所述学生模型的损失函数;
基于改进后的所述损失函数训练所述学生模型,得到训练完成的模型。
2.根据权利要求1所述的方法,其特征在于,所述将所述学生模型特征层的特征尺度与所述教师模型特征层的特征尺度对齐,包括:
获取所述教师模型的教师特征集合;其中,所述教师特征集合中的教师特征是所述教师模型的输入的第一图像中目标区域的特征,所述目标区域为基于锚框确定的区域,所述目标区域的特征包括所述目标区域的前景区域特征及背景区域特征;
在所述学生模型的各层特征中获取学生特征集合;其中,所述学生特征集合中的学生特征是所述目标区域在所述学生模型中的特征;
将所述教师特征集合与所述学生特征集合中的特征尺度转换至同一维度。
3.根据权利要求2所述的方法,其特征在于,所述在所述学生模型的各层特征中获取的学生特征集合,包括:
基于输入所述学生模型的所述第一图像中的所述目标区域提取所述各层的学生区域特征,得到所述学生特征集合。
4.根据权利要求2或3所述的方法,其特征在于,所述确定在所述特征层的特征蒸馏损失,包括:
基于注意力机制或余弦相似度算法计算所述教师特征集合中的各教师特征与所述学生特征集合中的各学生特征之间的特征相似度,并依据所述特征相似度确定所述特征蒸馏损失。
5.根据权利要求1所述的方法,其特征在于,所述将所述学生模型输出层的预测框及对应的概率分布与所述教师模型输出层的预测框及对应的概率分布对齐,包括:
计算所述学生模型输出层的预测框与所述教师模式输出层的预测框的交并比,并依据所述交并比在多个所述学生模型输出层的预测框中选择与对应的所述教师模式输出层的预测框匹配的目标预测框;
获取所述目标预测框对应的目标概率分布,所述目标概率分布为所述目标预测框的输出的概率分布;
将所述目标概率分布与所述教师模式输出层的预测框对应的概率分布对齐。
6.根据权利要求5所述的方法,其特征在于,所述计算所述学生模型输出层的预测框与所述教师模式输出层的预测框的交并比,并依据所述交并比在多个所述学生模型输出层的预测框中选择与对应的所述教师模式输出层的预测框匹配的目标预测框,包括:
将所述输入图像对应的特征图划分为多个网格,并计算各所述教师预测框的中心点在所述特征图中的中心点所属网格;
在每一个所述教师预测框对应的所述中心点所属网格中,计算该所述中心点所属网格中各所述学生预测框与所述教师预测框的交并比;
选择数值最大的所述交并比对应的所述学生预测框作为所述目标预测框。
7.根据权利要求5所述的方法,其特征在于,所述计算所述学生模型输出层的预测框与所述教师模式输出层的预测框的交并比,并依据所述交并比在多个所述学生模型输出层的预测框中选择与对应的所述教师模式输出层的预测框匹配的目标预测框,包括:
对于每一个所述教师预测框,计算该所述教师预测框与所述特征图中各所述学生预测框的交并比;
选择数值最大的所述交并比对应的所述学生预测框作为所述目标预测框。
8.根据权利要求5所述的方法,其特征在于,所述确定输出蒸馏损失,包括:基于所述学生模型的输出概率分布与所述前景置信度的乘积以及所述教师模型输出概率分布,计算所述输出蒸馏损失。
9.根据权利要求5所述的方法,其特征在于,所述确定输出蒸馏损失,包括:
对所述学生模型的输出概率与所述教师模型的输出概率归一化,并基于归一化的结果确定所述输出蒸馏损失。
10.根据权利要求1-3或5-9中任一项所述的方法,其特征在于,所述依据所述特征蒸馏损失及所述输出蒸馏损失改进所述学生模型的损失函数,包括:
依据所述学生模型的改进前的检测损失、所述输出蒸馏损失以及所述特征蒸馏损失,算所述学生模型的改进后的损失函数。
11.一种基于知识蒸馏的模型训练装置,其特征在于,知识蒸馏中所涉及的教师模型与学生模型为异构模型,所述装置包括:
教师模型获取模块,用于获取已经训练好的所述教师模型;
特征蒸馏模块,用于将所述学生模型特征层的特征尺度与所述教师模型特征层的特征尺度对齐,并确定特征蒸馏损失,所述特征蒸馏损失表示特征尺度对齐后,所述教师模型与所述学生模型特在特征层的差异程度;
输出蒸馏模块,用于将所述学生模型输出层的预测框及对应的概率分布与所述教师模型输出层的预测框及对应的概率分布对齐,并确定输出蒸馏损失,所述输出蒸馏损失表示预测框及对应的概率分布对齐后,所述教师模型与所述学生模型在输出层的差异程度,所述概率分布为所述预测框的输出的概率分布;
学生模型训练模块,用于依据所述特征蒸馏损失及所述输出蒸馏损失改进所述学生模型的损失函数;并基于改进后的所述损失函数训练所述学生模型,得到训练完成的模型。
12.一种电子设备,其特征在于,所述电子设备包括存储器和处理器;所述存储器和所述处理器耦合;所述存储器用于存储计算机程序代码,所述计算机程序代码包括计算机指令;
其中,当所述处理器执行所述计算机指令时,使得所述电子设备执行如权利要求1-10中任意一项所述的基于知识蒸馏的模型训练方法。
13.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机指令,当所述计算机指令在电子设备上运行时,使得所述电子设备执行如权利要求1-10中任一项所述的基于知识蒸馏的模型训练方法。
CN202211608782.8A 2022-12-14 2022-12-14 基于知识蒸馏的模型训练方法、装置及电子设备 Pending CN115953643A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211608782.8A CN115953643A (zh) 2022-12-14 2022-12-14 基于知识蒸馏的模型训练方法、装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211608782.8A CN115953643A (zh) 2022-12-14 2022-12-14 基于知识蒸馏的模型训练方法、装置及电子设备

Publications (1)

Publication Number Publication Date
CN115953643A true CN115953643A (zh) 2023-04-11

Family

ID=87286974

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211608782.8A Pending CN115953643A (zh) 2022-12-14 2022-12-14 基于知识蒸馏的模型训练方法、装置及电子设备

Country Status (1)

Country Link
CN (1) CN115953643A (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116778300A (zh) * 2023-06-25 2023-09-19 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、系统和存储介质
CN117372685A (zh) * 2023-12-08 2024-01-09 深圳须弥云图空间科技有限公司 目标检测方法、装置、电子设备及存储介质
CN117576381A (zh) * 2024-01-16 2024-02-20 深圳华付技术股份有限公司 目标检测训练方法及电子设备、计算机可读存储介质

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116778300A (zh) * 2023-06-25 2023-09-19 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、系统和存储介质
CN116778300B (zh) * 2023-06-25 2023-12-05 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、系统和存储介质
CN117372685A (zh) * 2023-12-08 2024-01-09 深圳须弥云图空间科技有限公司 目标检测方法、装置、电子设备及存储介质
CN117372685B (zh) * 2023-12-08 2024-04-16 深圳须弥云图空间科技有限公司 目标检测方法、装置、电子设备及存储介质
CN117576381A (zh) * 2024-01-16 2024-02-20 深圳华付技术股份有限公司 目标检测训练方法及电子设备、计算机可读存储介质

Similar Documents

Publication Publication Date Title
CN111797893B (zh) 一种神经网络的训练方法、图像分类系统及相关设备
CN115953643A (zh) 基于知识蒸馏的模型训练方法、装置及电子设备
CN108197326B (zh) 一种车辆检索方法及装置、电子设备、存储介质
WO2020042658A1 (zh) 数据处理方法、装置、设备和系统
CN112396106B (zh) 内容识别方法、内容识别模型训练方法及存储介质
CN109919073B (zh) 一种具有光照鲁棒性的行人再识别方法
CN112668588B (zh) 车位信息生成方法、装置、设备和计算机可读介质
CN113361710B (zh) 学生模型训练方法、图片处理方法、装置及电子设备
WO2023273628A1 (zh) 一种视频循环识别方法、装置、计算机设备及存储介质
CN115699082A (zh) 缺陷检测方法及装置、存储介质及电子设备
CN115082752A (zh) 基于弱监督的目标检测模型训练方法、装置、设备及介质
CN115018039A (zh) 一种神经网络蒸馏方法、目标检测方法以及装置
CN113326826A (zh) 网络模型的训练方法、装置、电子设备及存储介质
CN114550053A (zh) 一种交通事故定责方法、装置、计算机设备及存储介质
CN111242176A (zh) 计算机视觉任务的处理方法、装置及电子系统
CN116362294B (zh) 一种神经网络搜索方法、装置和可读存储介质
CN112288702A (zh) 一种基于车联网的道路图像检测方法
CN114170484B (zh) 图片属性预测方法、装置、电子设备和存储介质
WO2022127576A1 (zh) 站点模型更新方法及系统
CN112364946B (zh) 图像确定模型的训练方法、图像确定的方法、装置和设备
CN111461228B (zh) 图像推荐方法和装置及存储介质
CN112001211B (zh) 对象检测方法、装置、设备及计算机可读存储介质
CN113762042A (zh) 视频识别方法、装置、设备以及存储介质
CN111860331A (zh) 无人机在安防未知域的人脸识别系统
CN113627241B (zh) 一种用于行人重识别的背景抑制方法与系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination