CN113240120A - 基于温习机制的知识蒸馏方法、装置、计算机设备和介质 - Google Patents

基于温习机制的知识蒸馏方法、装置、计算机设备和介质 Download PDF

Info

Publication number
CN113240120A
CN113240120A CN202110495734.1A CN202110495734A CN113240120A CN 113240120 A CN113240120 A CN 113240120A CN 202110495734 A CN202110495734 A CN 202110495734A CN 113240120 A CN113240120 A CN 113240120A
Authority
CN
China
Prior art keywords
network
student network
student
distillation
intermediate layer
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110495734.1A
Other languages
English (en)
Inventor
陈鹏光
刘枢
贾佳亚
沈小勇
吕江波
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Smartmore Technology Co Ltd
Shanghai Smartmore Technology Co Ltd
Original Assignee
Shenzhen Smartmore Technology Co Ltd
Shanghai Smartmore Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Smartmore Technology Co Ltd, Shanghai Smartmore Technology Co Ltd filed Critical Shenzhen Smartmore Technology Co Ltd
Priority to CN202110495734.1A priority Critical patent/CN113240120A/zh
Publication of CN113240120A publication Critical patent/CN113240120A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/02Knowledge representation; Symbolic representation
    • G06N5/022Knowledge engineering; Knowledge acquisition
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/02Knowledge representation; Symbolic representation
    • G06N5/027Frames
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements
    • G06V20/20Scenes; Scene-specific elements in augmented reality scenes

Abstract

本申请涉及一种基于温习机制的知识蒸馏方法、装置、计算机设备和存储介质。本申请能够将教师网络不同阶段之间的信息都能传递至学生网络,进而提高知识蒸馏的效果,提高分类准确度。该方法包括:获取训练数据;将训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将训练数据输入至学生网络,得到学生网络输出的学生网络预测结果和学生网络的各中间层的输出特征;根据学生网络中各中间层的输出特征与教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分;根据学生网络预测结果确定基础损失部分;基于蒸馏损失部分和基础损失部分,对学生网络进行训练。

Description

基于温习机制的知识蒸馏方法、装置、计算机设备和介质
技术领域
本申请涉及人工智能技术领域,特别是涉及一种基于温习机制的知识蒸馏方法、装置、计算机设备和存储介质。
背景技术
卷积神经网络(CNN,Convolutional neural network)已被广泛应用于计算机视觉任务,并取得了显著的成功。然而,CNN的成功伴随着大量计算资源的使用,包括硬件、软件以及时间资源,因此,如何减少神经网络的计算开销成为当前人工智能中的重要研究领域。
目前,可通过设计新架构、网络剪枝、量化和知识蒸馏等技术手段实现CNN的资源消耗。其中,知识蒸馏(Knowledge Distillation)最早是由Hinton等人提出的一种预训练方法,主要思想是利用已经训练完全的比较大的教师网络(teacher网络)去辅助一个资源消耗小的学生网络(student网络)的训练,以便在计算机视觉任务处理中减少资源消耗,并能达到与教师网络(teacher网络)同样的任务处理效果。目前普遍采用的知识蒸馏方法中,例如FitNet网络,只在teacher网络和student网络同一个阶段的特征中进行知识蒸馏,仅实现了局部的知识迁移,无法有效地利用teacher网络的所有信息,使得student网络在学习的时候只能学习到一部分teacher网络的知识,从而导致了知识蒸馏的效果有限。
发明内容
基于此,有必要针对上述技术问题,提供一种基于温习机制的知识蒸馏方法、装置、计算机设备和存储介质。
一种基于温习机制的知识蒸馏方法,所述方法包括:
获取训练数据;
将所述训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将所述训练数据输入至学生网络,得到所述学生网络输出的学生网络预测结果和所述学生网络的各中间层的输出特征;
根据所述学生网络中各中间层的输出特征与所述教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分;
根据所述学生网络预测结果确定基础损失部分;
基于所述蒸馏损失部分和基础损失部分,对所述学生网络进行训练。
在其中一个实施例中,所述基于所述蒸馏损失部分和基础损失部分,对所述学生网络进行训练,包括:
将所述蒸馏损失部分和基础损失部分的总和作为所述学生网络的总体损失;
基于所述总体损失调整所述学生网络的网络参数,直至所述总体损失满足预设条件。
在其中一个实施例中,所述蒸馏损失部分通过L2范数计算所述学生网络中各中间层的第一输出特征与所述教师网络中各中间层的第一输出特征的L2距离,并将各中间层对应的L2距离的累加作为所述蒸馏损失部分;所述第一输出特征是通过第一变换模块对原始输出特征进行变换得到的。
在其中一个实施例中,所述第一变换模块包括卷积层和最近插值层。
在其中一个实施例中,所述蒸馏损失部分使用基于温习机制的损失函数进行计算,所述基于温习机制的损失函数为:
Figure BDA0003054182450000021
Figure BDA0003054182450000022
其中,
Figure BDA0003054182450000023
表示学生网络第i个中间层的输出特征;
Figure BDA0003054182450000024
表示针对学生网络中第i个中间层使用特定变换模块进行处理后的所述第一输出特征;
Figure BDA0003054182450000025
表示教师网络第j个中间层的输出特征;
Figure BDA0003054182450000026
表示
Figure BDA0003054182450000027
Figure BDA0003054182450000028
之间的距离的累加;n为学生网络中中间层的总层数;
Figure BDA0003054182450000029
表示累加的距离;U为特征融合模块,
Figure BDA00030541824500000210
表示将学生网络中从第j层到第n层的输出特征进行融合得到的融合特征。
一种图像识别方法,所述方法包括:
利用上述任一种基于温习机制的知识蒸馏方法训练得到用于识别图像的所述学生网络;
获取图像;所述图像中包括待识别对象;
将所述图像输入至学生网络,以使所述学生网络输出所述待识别对象的类别标签。
一种基于温习机制的知识蒸馏装置,所述装置包括:
数据获取模块,用于获取训练数据;
中间层特征获取模块,用于将所述训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将所述训练数据输入至学生网络,得到所述学生网络输出的学生网络预测结果和所述学生网络的各中间层的输出特征;
蒸馏损失部分确定模块,用于根据所述学生网络中各中间层的输出特征与所述教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分;
基础损失部分确定模块,用于根据所述学生网络预测结果确定基础损失部分;
学生网络训练模块,用于基于所述蒸馏损失部分和基础损失部分,对所述学生网络进行训练。
一种图像识别装置,所述装置包括:
图像获取模块,用于获取图像;所述图像中包括不同类别的待识别对象;
类别标签输出模块,用于利用上述任一种基于温习机制的知识蒸馏方法训练得到用于识别图像的所述学生网络,将所述图像输入至学生网络,以使所述学生网络输出所述待识别对象的类别标签。
一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现上述任一种基于温习机制的知识蒸馏方法实施例以及图像识别方法实施例中的各步骤。
一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任一种基于温习机制的知识蒸馏方法实施例以及图像识别方法实施例中的各步骤。
上述基于温习机制的知识蒸馏方法、装置、计算机设备和存储介质,获取训练数据;将训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将训练数据输入至学生网络,得到学生网络输出的学生网络预测结果和学生网络的各中间层的输出特征;根据学生网络中各中间层的输出特征与教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分;根据学生网络预测结果确定基础损失部分;基于蒸馏损失部分和基础损失部分,对学生网络进行训练。该方法能够将教师网络不同阶段之间的信息都能传递至学生网络,使得教师网络中即使是早期阶段的知识也能被学生网络学习得到,达到“温故而知新”的效果,进而提高知识蒸馏的效果,提高分类准确度。
附图说明
图1(a)为一个实施例中基于温习机制的知识蒸馏方法的温习机制示意图;
图1(b)为另一个实施例中基于温习机制的知识蒸馏方法的温习机制示意图;
图1(c)为又一个实施例中基于温习机制的知识蒸馏方法的温习机制示意图;
图2为一个实施例中基于温习机制的知识蒸馏方法的流程示意图;
图3为一个实施例中图像识别方法的流程示意图;
图4为一个实施例中基于温习机制的知识蒸馏装置的结构框图;
图5为一个实施例中图像识别装置的结构框图;
图6为一个实施例中计算机设备的内部结构图;
图7为另一个实施例中计算机设备的内部结构图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
本申请提供的基于温习机制的知识蒸馏方法,可以按照如图1所示的温习机制示意图来辅助理解。知识蒸馏是利用从一个大型模型或模型集合汇总提取的知识来训练一个紧凑的神经网络,我们称大模型或模型集合为教师网络,而称小而紧凑的模型为学生网络,教师网络通常对硬件的要求较高,往往需要大型的服务器或者服务器集群,而学生网络可运行于各种个人计算机、笔记本电脑、智能手机、平板电脑和便携式可穿戴设备。
在一个实施例中,如图2所示,提供了一种基于温习机制的知识蒸馏方法,包括以下步骤:
步骤S201,获取训练数据;
其中,训练数据是指用于模型构建的数据,在本申请中可以指各种图片,这些图片中包含了已经被人工标记或者机器标记了对象类别的类别标签,例如一幅图片中包含了人和车两类对象。
步骤S202,将所述训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将上述训练数据输入至学生网络,得到学生网络输出的学生网络预测结果和学生网络的各中间层的输出特征。
其中,学生网络预测结果是指学生网络对给定输入图像中待识别对象的分类结果,例如识别出图像中包含的人和车,并在输出图像中标记出来。
如图1(a)所示,对于图像分类任务,给定一张输入图像X和学生网络S(student网络),可以用YS=(X)表示学生网络S的输出,即学生网络预测结果。S(X)可以被分为不同的阶段(S1、S2、……Sn,Sc),其中,Sc表示分类器,S1、S2、……Sn表示网络的不同阶段,即不同的中间层,这些阶段通过下采样层来区分。因此,学生网络S的输出过程可以表示为:
Figure BDA0003054182450000051
其中,空心圆°表示函数的嵌套。YS是学生网络S最终的输出,可以用
Figure BDA0003054182450000052
来表示学生网络S中各中间层的输出特征,即第i个中间层的输出可以通过以下方式得到:
Figure BDA0003054182450000061
对于教师网络T(teacher网络),这一过程是类似的,在此不再赘述。
步骤S203,根据上述学生网络总各中间层的输出特征与教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分。
其中,蒸馏损失部分是指衡量教师网络各中间层输出特征与学生网络各中间层输出特征之间的差异分布的函数。
具体地,对于单层知识蒸馏网络(SKD,Single Knowledge Distilling),其损失函数(即蒸馏损失部分)为:
Figure BDA0003054182450000062
其中,M表示变换模块,用于将网络的特征转化为特定的表示方法,比如注意力特征或概率特征等。
Figure BDA0003054182450000063
表示学生网络中第i个中间层输出特征经过变换模块变换后得到的表示方法,
Figure BDA0003054182450000064
表示教师网络中第i个中间层输出特征经过变换模块变换后得到的表示方法。D是距离函数,用于衡量学生网络与教师网络特征的差异。类似地,可以利用单层知识蒸馏网络进行拓展,从而得到多层知识蒸馏网络(MKD,Multi-Knowledge Distilling)的蒸馏损失函数(即蒸馏损失部分)表达如下:
Figure BDA0003054182450000065
其中,集合I中存放了要用来蒸馏的特征编号。
进一步地,还可以选择使用“温习”机制(Review)对上述知识蒸馏网络进行优化,所谓“温习”机制是指利用教师网络的浅层特征来指导学生网络的深层特征的学习,这种机制和单层知识蒸馏进行结合之后,其蒸馏损失函数可以表示为:
Figure BDA0003054182450000066
同理可得,这种温习机制与多层知识蒸馏进行结合之后,其蒸馏损失函数可以表示为:
Figure BDA0003054182450000071
其中,
Figure BDA0003054182450000072
表示从第i层输出特征到第j层输出特征的特征变换,
Figure BDA0003054182450000073
表示从教师网络第j层输出特征到第i层输出特征的特征变换。在多层知识蒸馏网络中,当i固定时,即对于学生网络的第i层特征
Figure BDA0003054182450000074
教师网络对应的特征是在前i层进行总循环。
步骤S204,根据学生网络预测结果确定基础损失部分。
具体地,训练学生网络时,蒸馏损失函数是直接与原来的任务中的损失函数(即基础损失部分)加起来一起优化的,基础损失部分是根据学生网络预测结果与真实标签之间的距离来衡量的,一般使用交叉熵函数LCE,即交叉熵损失(Cross Entropy),可以使用一个参数来平衡蒸馏损失部分和原任务损失函数,以分类任务为例,可以将总体损失表示为:
L=LCE+λLMKD_R
步骤S205,基于上述蒸馏损失部分和基础损失部分,对学生网络进行训练。
具体地,在对学生网络进行训练时,不断调整网络中的各个参数使得总体损失L达到预设条件,例如达到极小值,最终得到的学生网络即为被认可的最终结果。
上述实施例,获取训练数据;将训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将训练数据输入至学生网络,得到学生网络输出的学生网络预测结果和学生网络的各中间层的输出特征;根据学生网络中各中间层的输出特征与教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分;根据学生网络预测结果确定基础损失部分;基于蒸馏损失部分和基础损失部分,对学生网络进行训练。该方法能够教师网络不同阶段之间的信息都能传递至学生网络,使得教师网络中即使是早期阶段的知识也能被学生网络学习得到,学生网络在学习过程中,在学习教师网络的后面阶段的特征的同时,不断复习之前学过的知识,从中可以得到新的有用的信息,达到“温故而知新”的效果,进而提高知识蒸馏的效果,提高分类准确度。
在一实施例中,上述蒸馏损失部分通过L2范数计算学生网络中各中间层的第一输出特征与教师网络中各中间层的第一输出特征的L2距离,并将各中间层对应的L2距离的累加作为蒸馏损失部分;上述第一输出特征是通过第一变换模块对原始输出特征进行变换得到的。
具体地,可以使用L2距离作为距离函数D,L2距离即使用L2范数,也即欧几里得范数计算学生网络中各中间层的第一输出特征与教师网络中各中间层的第一输出特征的距离。
图1(a)展示了在单层知识蒸馏网络中实现温习机制的示意图,学生网络的特征被转换为与教师网络特征相同的大小。
图1(b)展示了多层知识蒸馏网络中的每个中间层都实现温习机制的示意图,多层指示蒸馏网络中的蒸馏损失函数为:
Figure BDA0003054182450000081
其中,M为第一变换模块,
Figure BDA0003054182450000082
为学生网络中对于第i层,通过第一变换模块对原始输出特征
Figure BDA0003054182450000083
进行变换后得到的第一输出特征,
Figure BDA0003054182450000084
为教师网络中对于第j层,通过第一变换模块对原始输出特征
Figure BDA0003054182450000085
进行变换后得到的第一输出特征。其中,第一变换模块可以简单地由卷积层和最近插值层组成,如图1中所示,可以是1×1卷积层、最近邻插值层以及3×3卷积层组成。
上述实施例给出了基于温习机制的知识蒸馏的具体细节,通过第一变换模块得到各个中间层的输出特征,并通过欧几里得距离计算学生网络和教师网络之间的距离,便于后续通过对该距离函数的处理得到训练好的学生网络。
在一实施例中,上述蒸馏损失部分还可以使用另一简化的基于温习机制的损失函数进行计算,所述基于温习机制的损失函数为:
Figure BDA0003054182450000091
Figure BDA0003054182450000092
其中,
Figure BDA0003054182450000093
表示学生网络第i个中间层的输出特征;
Figure BDA0003054182450000094
表示针对学生网络中第i个中间层使用特定变换模块进行处理后的所述第一输出特征;
Figure BDA0003054182450000095
表示教师网络第j个中间层的输出特征;
Figure BDA0003054182450000096
表示
Figure BDA0003054182450000097
Figure BDA0003054182450000098
之间的距离的累加;n为学生网络中中间层的总层数;
Figure BDA0003054182450000099
表示累加的距离;U为特征融合模块,
Figure BDA00030541824500000910
表示将学生网络中从第j层到第n层的输出特征进行融合得到的融合特征。
具体地,在图1(b)展示的多层知识蒸馏网络中,可以看出,当使用所有阶段(即所有中间层)的输出特征时,它给出了一个很复杂的过程,例如具有n个阶段(n个中间层)的网络需要计算n(n+1)/2对特征的L2距离,这比原始训练过程(即没有温习机制的训练过程)花费更多的资源。
为了简化这一过程,可将上述结构进行优化,将蒸馏损失函数定义为:
Figure BDA00030541824500000911
改变两个累加的顺序,可以得到:
Figure BDA00030541824500000912
当j固定时,可以近似地把距离的累加改为累加的距离,即上式中等号右边的部分可近似修改为累加的距离:
Figure BDA0003054182450000101
其中,
Figure BDA0003054182450000102
表示学生网络第i个中间层的输出特征;
Figure BDA0003054182450000103
表示针对学生网络中第i个中间层使用特定变换模块进行处理后的所述第一输出特征;
Figure BDA0003054182450000104
表示教师网络第j个中间层的输出特征;
Figure BDA0003054182450000105
表示
Figure BDA0003054182450000106
Figure BDA0003054182450000107
之间的距离的累加;n为学生网络中中间层的总层数;
Figure BDA0003054182450000108
表示累加的距离;U为特征融合模块,即图1(c)中的自定义符号
Figure BDA0003054182450000109
表示将学生网络中从第j层到第n层的输出特征进行融合得到的融合特征。
上述实施例,使用了一种特征融合模块U,可以将不同层的特征进行融合,且这个模块的计算量很小,可以有效减少图1(b)带来的额外代价。
在一实施例中,还提供了一种图像识别方法,如3所示,图3展示了该图像识别方法的流程示意图,该方法包括:
步骤S301,获取图像,该图像中包括待识别对象;
具体地,给定任意一幅图像,该图像可以是静态图片,还可以是视频流,图像中包含待识别对象,例如包含各类动物等。
步骤S302,将上述图像输入至学生网络,以使该学生网络输出待识别对象的类别标签,其中,该学生网络是利用上述各方法实施例中的方法训练得到的。
具体地,使用上述知识蒸馏方法训练得到一可用于分类任务的学生网络,将上述图像输入至该学生网络,该学生网络可识别出该图像中的待识别对象的类别标签。
上述实施例,通过使用上述各知识蒸馏方法实施例中的方法训练得到学生网络,该学生网络可以用于识别图像的对象并输出其类别标签,使得学生网络学到了更多教师网络提取的知识,进一步提高了网络性能。
进一步地,我们对各种任务进行实验。首先,我们在分类任务上将我们的方法与其他有关分类的知识蒸馏方法进行比较。我们尝试使用不同的设置来改变网络结构和数据集。同样,我们将我们的方法应用于目标检测和实例分割任务,我们的方法仍然极大地改进了其他方法的效果。
1.分类任务
数据集:(1)CIFAR-100数据集包含了50,000张训练图片,一共有100类,每类有500张图;还有10,000张测试图片,每类100张。(2)ImageNet数据集是目前最具有挑战性的图片分类数据集,它包含了超过1,200,000张图片,共有1,000类,每类有接近1,300张图片;还有50,000张测试图片,每类50张。
实验细节:在CIFAR-100数据集上,我们对不同的网络结构进行了实验,包括VGG,ResNet,WideResNet,MobileNet和ShuffleNet。我们采用了和之前方法相同的训练方法,只是线性增加了批次大小和学习率。具体的,我们一共训练240轮。初始学习率设置为0.1(MobielNet和ShuffleNet为0.02),从第150轮开始每30轮学习率下降10倍。在ImageNet数据集上,我们采用了标准的训练流程,一共训练100轮,每30轮将学习率下降10倍,初始学习率设置为0.1。
Figure BDA0003054182450000111
表1
表1总结了我们的方法在CIFAR-100上的实验结果。我们试验了多种不同的网络结构,涵盖了深度上的区别和宽度上的区别,可以看出我们的方法在所有的实验设置中都取得了最好的结果。我们的方法使用了多层知识蒸馏和温习机制,所采用的损失函数很简单,只是L2距离。相比较于同样使用了L2距离的FitNet方法,我们的方法取得了显著的提升,这说明了我们所提出的温习机制的优越性。
Figure BDA0003054182450000121
表2
表2中同样是CIFAR-100数据集上的结果,与表1不同的是,这里student网络和teacher网络具有不同的网络结构。这种实验设置是更加具有挑战性的。
在这种更加具有挑战性的实验中,我们的方法仍然取得了最好的结果。
Figure BDA0003054182450000122
表3中是ImageNet数据集的结果,ImageNet数据集包含了更多的种类和更大的图片,更加接近于真实图片的分布,是最具有代表性的分类数据集。在这个数据集上,我们同样进行了两种不同的设置。首先是对于teacher网络和student网络具有相同结构类型的实验,如(a)所示,我们的方法取得了最好的效果。其次是对于teacher网络和student网络具有不同的结构类型的实验,如(b)所示,我们的方法仍然具有显著的优势。
2.目标检测任务
我们还将我们的方法应用于其他计算机视觉任务。在目标检测方面,和分类任务的过程相似,我们在teacher网络和student网络的主干输出特征之间进行知识蒸馏。我们使用具有代表性的COCO2017数据集来评估我们的方法,并以最受欢迎的开源报告Detectron2作为我们的基础模型。我们使用Detrctron2提供的最好的预训练模型作为teacher网络。按照传统标准,我们使用标准训练政策对student网络进行训练。所有性能均在COCO2017验证集上进行评估。我们进行了两阶段检测器和一阶段检测器的实验。
Figure BDA0003054182450000131
表4
由于只有少数几种方法可用于目标检测任务,因此我们将挑选了其中最具有代表性和最新的方法进行比较。表4中给出了比较。我们注意到,在分类任务山传统的知识蒸馏方法(例如KD和FitNet)也可以提高检测性能。但是收益是有限的。FGFI是一种直接设计用于目标检测的方法,在此任务上比其他方法效果更好。尽管如此,我们的方法仍取得了优于它的效果。
我们还更改了实验设置以检查一般性。在两阶段检测器FasterRCNN上,我们更改了骨干结构。在同样式的体系结构之间的知识蒸馏中,我们将ResNet18和ResNet50的mAP分别提高了3.49和2.43。ResNet50和MobileNetV2之间的知识蒸馏仍将基线从29.47提高到33.71。在RetinaNet这种一阶段检测器上,student网络和teacher网络之间的精度差距很小,我们的方法仍将mAP提高了2.33。在具有挑战性目标检测任务上的成功证明了我们方法的普遍性和有效性。
3.实例分割
在本节中,我们将我们的方法应用于更具挑战性的实例分割任务。据我们所知,这是知识蒸馏方法首次应用于实例分割。我们仍然使用了Detectron2提供的强大基础模型。我们以Mask R-CNN为基础,并在不同的主干架构之间进行知识蒸馏。这些模型在COCO2017训练集上进行训练,并在其验证集上进行评估。结果示于表5。
Figure BDA0003054182450000141
表5
我们的方法仍然显着提高了实例分割任务的性能。对于相同样式的架构之间的知识蒸馏,我们将ResNet18和ResNet50的性能提高了2.37和1.74,并将student网络和teacher网络之间的差距相对减少了32%和51%。即使是对不同样式的体系结构的知识蒸馏,我们也将MobileNetV2改进了3.19。
我们的方法在所有图像分类,对象检测和实例分割任务上都表现出色,并超过了所有其他模型的结果,这一事实证明了我们方法的卓越功效和适用性。
应该理解的是,虽然图1-3的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,图1-3中的至少一部分步骤可以包括多个步骤或者多个阶段,这些步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤中的步骤或者阶段的至少一部分轮流或者交替地执行。
在一个实施例中,如图4所示,提供了一种基于温习机制的知识蒸馏装置400,包括:数据获取模块401、中间层特征获取模块402、蒸馏损失部分确定模块403、基础损失部分确定模块404和学生网络训练模块405,其中:
数据获取模块401,用于获取训练数据;
中间层特征获取模块402,用于将所述训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将所述训练数据输入至学生网络,得到所述学生网络输出的学生网络预测结果和所述学生网络的各中间层的输出特征;
蒸馏损失部分确定模块403,用于根据所述学生网络中各中间层的输出特征与所述教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分;
基础损失部分确定模块404,用于根据所述学生网络预测结果确定基础损失部分;
学生网络训练模块405,用于基于所述蒸馏损失部分和基础损失部分,对所述学生网络进行训练。
在一实施例中,上述学生网络训练模块405进一步用于:将所述蒸馏损失部分和基础损失部分的总和作为所述学生网络的总体损失;基于所述总体损失调整所述学生网络的网络参数,直至所述总体损失满足预设条件。
在一实施例中,上述蒸馏损失部分通过L2范数计算所述学生网络中各中间层的第一输出特征与所述教师网络中各中间层的第一输出特征的L2距离,并将各中间层对应的L2距离的累加作为所述蒸馏损失部分;所述第一输出特征是通过第一变换模块对原始输出特征进行变换得到的。
在一实施例中,上述第一变换模块包括卷积层和最近插值层。
在一实施例中,上述蒸馏损失部分使用基于温习机制的损失函数进行计算,所述基于温习机制的损失函数为:
Figure BDA0003054182450000161
Figure BDA0003054182450000162
其中,
Figure BDA0003054182450000163
表示学生网络第i个中间层的输出特征;
Figure BDA0003054182450000164
表示针对学生网络中第i个中间层使用特定变换模块进行处理后的所述第一输出特征;
Figure BDA0003054182450000165
表示教师网络第j个中间层的输出特征;
Figure BDA0003054182450000166
表示
Figure BDA0003054182450000167
Figure BDA0003054182450000168
之间的距离的累加;n为学生网络中中间层的总层数;
Figure BDA0003054182450000169
表示累加的距离;U为特征融合模块,
Figure BDA00030541824500001610
表示将学生网络中从第j层到第n层的输出特征进行融合得到的融合特征。
在一实施例中,还提供了一种图像识别装置500,如图5所示,该装置包括图像获取模块501和类别标签输出模块502,其中:
图像获取模块501,用于获取图像;所述图像中包括不同类别的待识别对象;
类别标签输出模块502,用于上述基于温习机制的知识蒸馏方法实施例中的步骤训练得到用于识别图像的所述学生网络,将所述图像输入至学生网络,以使所述学生网络输出所述待识别对象的类别标签。
关于基于温习机制的知识蒸馏装置和图像识别装置的具体限定可以参见上文中对于基于温习机制的知识蒸馏方法和图像识别方法的限定,在此不再赘述。上述基于温习机制的知识蒸馏装置和图像识别装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图6所示。该计算机设备包括通过系统总线连接的处理器、存储器和网络接口。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储各中间层输出特征数据以及图像预测结果。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种基于温习机制的知识蒸馏方法或图像识别方法。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是终端,其内部结构图可以如图7所示。该计算机设备包括通过系统总线连接的处理器、存储器、通信接口、显示屏和输入装置。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统和计算机程序。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的通信接口用于与外部的终端进行有线或无线方式的通信,无线方式可通过WIFI、运营商网络、NFC(近场通信)或其他技术实现。该计算机程序被处理器执行时以实现一种基于温习机制的知识蒸馏方法或图像识别方法。该计算机设备的显示屏可以是液晶显示屏或者电子墨水显示屏,该计算机设备的输入装置可以是显示屏上覆盖的触摸层,也可以是计算机设备外壳上设置的按键、轨迹球或触控板,还可以是外接的键盘、触控板或鼠标等。
本领域技术人员可以理解,图6-7中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,提供了一种计算机设备,包括存储器和处理器,存储器中存储有计算机程序,该处理器执行计算机程序时实现如上述的基于温习机制的知识蒸馏方法实施例和图像识别方法实施例中的步骤。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现如上述的基于温习机制的知识蒸馏方法实施例和图像识别方法实施例中的步骤。本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和易失性存储器中的至少一种。非易失性存储器可包括只读存储器(Read-Only Memory,ROM)、磁带、软盘、闪存或光存储器等。易失性存储器可包括随机存取存储器(Random Access Memory,RAM)或外部高速缓冲存储器。作为说明而非局限,RAM可以是多种形式,比如静态随机存取存储器(Static Random Access Memory,SRAM)或动态随机存取存储器(Dynamic Random Access Memory,DRAM)等。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请专利的保护范围应以所附权利要求为准。

Claims (10)

1.一种基于温习机制的知识蒸馏方法,其特征在于,所述方法包括:
获取训练数据;
将所述训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将所述训练数据输入至学生网络,得到所述学生网络输出的学生网络预测结果和所述学生网络的各中间层的输出特征;
根据所述学生网络中各中间层的输出特征与所述教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分;
根据所述学生网络预测结果确定基础损失部分;
基于所述蒸馏损失部分和基础损失部分,对所述学生网络进行训练。
2.根据权利要求1所述的方法,其特征在于,所述基于所述蒸馏损失部分和基础损失部分,对所述学生网络进行训练,包括:
将所述蒸馏损失部分和基础损失部分的总和作为所述学生网络的总体损失;
基于所述总体损失调整所述学生网络的网络参数,直至所述总体损失满足预设条件。
3.根据权利要求2所述的方法,其特征在于,所述蒸馏损失部分通过L2范数计算所述学生网络中各中间层的第一输出特征与所述教师网络中各中间层的第一输出特征的L2距离,并将各中间层对应的L2距离的累加作为所述蒸馏损失部分;所述第一输出特征是通过第一变换模块对原始输出特征进行变换得到的。
4.根据权利要求3所述的方法,其特征在于,所述第一变换模块包括卷积层和最近插值层。
5.根据权利要求1至4任一项所述的方法,其特征在于,所述蒸馏损失部分使用基于温习机制的损失函数进行计算,所述基于温习机制的损失函数为:
Figure FDA0003054182440000011
Figure FDA0003054182440000012
其中,
Figure FDA0003054182440000021
表示学生网络第i个中间层的输出特征;
Figure FDA0003054182440000022
表示针对学生网络中第i个中间层使用特定变换模块进行处理后的所述第一输出特征;
Figure FDA0003054182440000023
表示教师网络第j个中间层的输出特征;
Figure FDA0003054182440000024
表示
Figure FDA0003054182440000025
Figure FDA0003054182440000026
之间的距离的累加;n为学生网络中中间层的总层数;
Figure FDA0003054182440000027
表示累加的距离;U为特征融合模块,
Figure FDA0003054182440000028
表示将学生网络中从第j层到第n层的输出特征进行融合得到的融合特征。
6.一种图像识别方法,其特征在于,所述方法包括:
利用如权利要求1至5任一项所述的方法训练得到用于识别图像的所述学生网络;
获取图像;所述图像中包括待识别对象;
将所述图像输入至所述学生网络,以使所述学生网络输出所述待识别对象的类别标签。
7.一种基于温习机制的知识蒸馏装置,其特征在于,所述装置包括:
数据获取模块,用于获取训练数据;
中间层特征获取模块,用于将所述训练数据输入至教师网络,得到教师网络的各中间层的输出特征,以及将所述训练数据输入至学生网络,得到所述学生网络输出的学生网络预测结果和所述学生网络的各中间层的输出特征;
蒸馏损失部分确定模块,用于根据所述学生网络中各中间层的输出特征与所述教师网络中各中间层的输出特征的距离的累加确定蒸馏损失部分;
基础损失部分确定模块,用于根据所述学生网络预测结果确定基础损失部分;
学生网络训练模块,用于基于所述蒸馏损失部分和基础损失部分,对所述学生网络进行训练。
8.一种图像识别装置,其特征在于,所述装置包括:
图像获取模块,用于获取图像;所述图像中包括不同类别的待识别对象;
类别标签输出模块,用于利用如权利要求1至5任一项所述的方法训练得到用于识别图像的所述学生网络,将所述图像输入至学生网络,以使所述学生网络输出所述待识别对象的类别标签。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至6中任一项所述的方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至6中任一项所述的方法的步骤。
CN202110495734.1A 2021-05-07 2021-05-07 基于温习机制的知识蒸馏方法、装置、计算机设备和介质 Pending CN113240120A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110495734.1A CN113240120A (zh) 2021-05-07 2021-05-07 基于温习机制的知识蒸馏方法、装置、计算机设备和介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110495734.1A CN113240120A (zh) 2021-05-07 2021-05-07 基于温习机制的知识蒸馏方法、装置、计算机设备和介质

Publications (1)

Publication Number Publication Date
CN113240120A true CN113240120A (zh) 2021-08-10

Family

ID=77132331

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110495734.1A Pending CN113240120A (zh) 2021-05-07 2021-05-07 基于温习机制的知识蒸馏方法、装置、计算机设备和介质

Country Status (1)

Country Link
CN (1) CN113240120A (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113657483A (zh) * 2021-08-14 2021-11-16 北京百度网讯科技有限公司 模型训练方法、目标检测方法、装置、设备以及存储介质
CN114298224A (zh) * 2021-12-29 2022-04-08 云从科技集团股份有限公司 图像分类方法、装置以及计算机可读存储介质
CN115601536A (zh) * 2022-12-02 2023-01-13 荣耀终端有限公司(Cn) 一种图像处理方法及电子设备
CN116205290A (zh) * 2023-05-06 2023-06-02 之江实验室 一种基于中间特征知识融合的知识蒸馏方法和装置

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113657483A (zh) * 2021-08-14 2021-11-16 北京百度网讯科技有限公司 模型训练方法、目标检测方法、装置、设备以及存储介质
CN114298224A (zh) * 2021-12-29 2022-04-08 云从科技集团股份有限公司 图像分类方法、装置以及计算机可读存储介质
CN115601536A (zh) * 2022-12-02 2023-01-13 荣耀终端有限公司(Cn) 一种图像处理方法及电子设备
CN116205290A (zh) * 2023-05-06 2023-06-02 之江实验室 一种基于中间特征知识融合的知识蒸馏方法和装置
CN116205290B (zh) * 2023-05-06 2023-09-15 之江实验室 一种基于中间特征知识融合的知识蒸馏方法和装置

Similar Documents

Publication Publication Date Title
CN110580482B (zh) 图像分类模型训练、图像分类、个性化推荐方法及装置
CN113240120A (zh) 基于温习机制的知识蒸馏方法、装置、计算机设备和介质
WO2021022521A1 (zh) 数据处理的方法、训练神经网络模型的方法及设备
JP2022505775A (ja) 画像分類モデルの訓練方法、画像処理方法及びその装置、並びにコンピュータプログラム
CN111507378A (zh) 训练图像处理模型的方法和装置
WO2021139191A1 (zh) 数据标注的方法以及数据标注的装置
WO2021147325A1 (zh) 一种物体检测方法、装置以及存储介质
WO2022001805A1 (zh) 一种神经网络蒸馏方法及装置
CN111782840B (zh) 图像问答方法、装置、计算机设备和介质
Liu et al. Traffic-light sign recognition using Capsule network
CN112801236B (zh) 图像识别模型的迁移方法、装置、设备及存储介质
CN113177559B (zh) 结合广度和密集卷积神经网络的图像识别方法、系统、设备及介质
CN113516227B (zh) 一种基于联邦学习的神经网络训练方法及设备
Ye et al. Steering angle prediction YOLOv5-based end-to-end adaptive neural network control for autonomous vehicles
CN115601692A (zh) 数据处理方法、神经网络模型的训练方法及装置
US11651191B2 (en) Methods, apparatuses, and computer program products using a repeated convolution-based attention module for improved neural network implementations
CN116310318A (zh) 交互式的图像分割方法、装置、计算机设备和存储介质
CN115238909A (zh) 一种基于联邦学习的数据价值评估方法及其相关设备
CN114332484A (zh) 关键点检测方法、装置、计算机设备和存储介质
Arun Prasath et al. Prediction of sign language recognition based on multi layered CNN
CN115292439A (zh) 一种数据处理方法及相关设备
WO2023207531A1 (zh) 一种图像处理方法及相关设备
Cai et al. Pedestrian detection algorithm in traffic scene based on weakly supervised hierarchical deep model
CN115577768A (zh) 半监督模型训练方法和装置
Panda et al. Feedback through emotion extraction using logistic regression and CNN

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
CB03 Change of inventor or designer information

Inventor after: Chen Pengguang

Inventor after: Liu Shu

Inventor after: Shen Xiaoyong

Inventor after: Lv Jiangbo

Inventor before: Chen Pengguang

Inventor before: Liu Shu

Inventor before: Jia Jiaya

Inventor before: Shen Xiaoyong

Inventor before: Lv Jiangbo

CB03 Change of inventor or designer information
RJ01 Rejection of invention patent application after publication

Application publication date: 20210810

RJ01 Rejection of invention patent application after publication