CN117095217A - 多阶段对比知识蒸馏方法 - Google Patents
多阶段对比知识蒸馏方法 Download PDFInfo
- Publication number
- CN117095217A CN117095217A CN202311064055.4A CN202311064055A CN117095217A CN 117095217 A CN117095217 A CN 117095217A CN 202311064055 A CN202311064055 A CN 202311064055A CN 117095217 A CN117095217 A CN 117095217A
- Authority
- CN
- China
- Prior art keywords
- model
- loss
- student
- teacher
- network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 43
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 16
- 230000000052 comparative effect Effects 0.000 title claims description 9
- 230000008569 process Effects 0.000 title description 14
- 238000012549 training Methods 0.000 claims abstract description 116
- 238000012545 processing Methods 0.000 claims abstract description 32
- 238000013145 classification model Methods 0.000 claims abstract description 9
- 238000009499 grossing Methods 0.000 claims description 25
- 230000000694 effects Effects 0.000 abstract description 4
- 230000006870 function Effects 0.000 description 25
- 238000013527 convolutional neural network Methods 0.000 description 8
- 238000003062 neural network model Methods 0.000 description 7
- 230000006835 compression Effects 0.000 description 4
- 238000007906 compression Methods 0.000 description 4
- 238000000605 extraction Methods 0.000 description 4
- 238000013528 artificial neural network Methods 0.000 description 3
- 238000012512 characterization method Methods 0.000 description 3
- 238000006243 chemical reaction Methods 0.000 description 3
- 238000013459 approach Methods 0.000 description 2
- 230000008859 change Effects 0.000 description 2
- 230000002708 enhancing effect Effects 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000011176 pooling Methods 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 239000002131 composite material Substances 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000004821 distillation Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种多阶段对比知识蒸馏方法。该方法包括:获取训练样本图像和与训练样本图像对应的理论分类标签,并对训练样本图像进行数据增强,得到至少一张待处理样本图像,基于至少一张待处理样本图像和理论分类标签,构建训练样本;基于预先训练完成的教师网络对训练样本进行处理,得到第一平滑化概率分布和与多个教师子模型对应的第一模型输出;将训练样本输入至待训练学生网络中,得到实际输出结果以及与多个学生子模型对应的第二模型输出;确定目标模型损失,并基于目标模型损失对待训练学生模型进行模型参数调整,得到图像分类模型。本技术方案,实现了训练完成的学生网络的性能进一步接近甚至超过教师网络性能的效果。
Description
技术领域
本发明涉及知识蒸馏技术领域,尤其涉及一种多阶段对比知识蒸馏方法。
背景技术
知识蒸馏是一种主流的模型压缩方法,其本质是从高性能的复杂网络(即,教师网络)中,抽取训练数据的概率分布传递给低性能的简单网络(即,学生网络),实现教师模型对学生模型的训练过程的指导,从而,提高学生模型的性能。
相关技术中,传统的知识蒸馏方法通常是应用教师网络输出的概率分布来训练学生网络,这种训练方式可能会存在无法充分将知识迁移到学生网络,难以充分利用教师网络蕴含的知识来提高轻量化网络的表征能力,使得知识迁移效果受限,神经网络模型压缩的准确度较低。
发明内容
本发明提供了一种多阶段对比知识蒸馏方法,以实现从教师网络传递了更加丰富的知识给学生网络,让学生网络的性能进一步接近甚至超过教师网络性能的效果,从而达到了简单模型高性能的目标。
根据本发明的一方面,提供了一种多阶段对比知识蒸馏方法,该方法包括:
获取训练样本图像和与所述训练样本图像对应的理论分类标签,并对所述训练样本图像进行数据增强,得到至少一张待处理样本图像,基于所述至少一张待处理样本图像和所述理论分类标签,构建训练样本;
基于预先训练完成的教师网络对所述训练样本进行处理,得到第一平滑化概率分布和与多个教师子模型对应的第一模型输出;
将所述训练样本输入至待训练学生网络中,得到实际输出结果以及与多个学生子模型对应的第二模型输出,其中,所述实际输出结果中包括实际分类结果和第二平滑化概率分布;
根据所述第一平滑化概率分布、多个所述第一模型输出、所述实际输出结果、多个所述第二模型输出以及所述理论分类标签,确定目标模型损失,并基于所述目标模型损失对所述待训练学生模型进行模型参数调整,得到图像分类模型。
本发明实施例的技术方案,通过获取训练样本图像和与训练样本图像对应的理论分类标签,并对训练样本图像进行数据增强,得到至少一张待处理样本图像,基于至少一张待处理样本图像和理论分类标签,构建训练样本。之后,基于预先训练完成的教师网络对训练样本进行处理,得到第一平滑化概率分布和与多个教师子模型对应的第一模型输出。进一步的,将训练样本输入至待训练学生网络中,得到实际输出结果以及与多个学生子模型对应的第二模型输出。最后,根据第一平滑化概率分布、多个第一模型输出、实际输出结果、多个第二模型输出以及理论分类标签,确定目标模型损失,并基于目标模型损失对待训练学生模型进行模型参数调整,得到图像分类模型,解决了相关技术中无法充分将知识迁移到学生网络,难以充分利用教师网络蕴含的知识来提高轻量化网络的表征能力,使得知识迁移效果受限,神经网络模型压缩的准确度较低等问题,实现了从教师网络传递了更加丰富的知识给学生网络,让学生网络的性能进一步接近甚至超过教师网络性能的效果,从而达到了简单模型高性能的目标。
应当理解,本部分所描述的内容并非旨在标识本发明的实施例的关键或重要特征,也不用于限制本发明的范围。本发明的其它特征将通过以下的说明书而变得容易理解。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是根据本发明实施例一提供的一种多阶段对比知识蒸馏方法的流程图;
图2是根据本发明实施例一提供的待训练教师网络训练过程的流程图;
图3是根据本发明实施例一提供的一种多阶段对比知识蒸馏方法的流程图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
实施例一
图1是本发明实施例一提供的一种多阶段对比知识蒸馏方法的流程图,本实施例可适用于基于预先训练完成的教师网络对待训练学生网络进行训练情况,该方法可以由多阶段对比知识蒸馏装置来执行,该多阶段对比知识蒸馏装置可以采用硬件和/或软件的形式实现,该多阶段对比知识蒸馏装置可配置于终端和/或服务器中。如图1所示,该方法包括:
S110、获取训练样本图像和与训练样本图像对应的理论分类标签,并对训练样本图像进行数据增强,得到至少一张待处理样本图像,基于至少一张待处理样本图像和理论分类标签,构建训练样本。
其中,训练样本图像可以是可以为摄像装置拍摄的图像;或者,图像重构模型重构出的图像;或者,在存储空间中预先存储的图像等。在本实施例中,训练样本图像可以是不同领域中开源的图像分类数据集中获取的图像。示例性的,图像分类数据集可以为CIFAR-100数据集或者ImageNet数据集等。理论分类标签可以是训练样本图像的真实所属类别。
在本实施例中,数据增强是指通过对已有数据添加微小改动或在已有数据的基础上重新创建合成数据,以增加数据量的方法。可选的,数据增强可以包括图像旋转、随机裁剪、图像翻转、颜色抖动以及高斯噪声等方法。
在实际应用中,在对待训练学生网络进行训练之前,可以构建多个训练样本。进而,可以根据已构建的训练样本来训练模型。具体来说,首先,可以获取训练样本图像,并对已获取的训练样本图像进行处理,得到与训练样本图像对应的理论分类标签。进一步的,可以根据预设数据增强方式对已获取的训练样本图像进行数据增强处理,得到至少一张待处理样本图像。之后,可以将至少一张待处理样本图像,以及与待处理样本图像对应的训练样本图像的理论分类标签作为一组训练样本,从而,可以基于上述方式构建丰富的训练样本。
需要说明的是,为了提高模型的准确性,可以尽可能多而丰富的获取训练样本。
还需说明的是,对已获取的训练样本图像进行数据增强,并根据数据增强后的图像和理论分类标签构建训练样本的好处在于:通过扩展了数据领域生成新的监督信号让教师网络生成了可以蕴含更多隐藏信息的特征图,从而提高了教师网络的性能,即分类准确率,同样也就提高了学生网络性能的可学习上限。
S120、基于预先训练完成的教师网络对训练样本进行处理,得到第一平滑化概率分布和与多个教师子模型对应的第一模型输出。
其中,教师网络可以是预先训练完成的,且与学生网络完成相同任务的高性能复杂网络模型。教师网络可以用于辅助训练对应的学生网络。教师网络可以是任意结构的神经网络模型,可选的,可以是深度卷积神经网络模型。在本实施例中,教师网络可以包括多个教师子模型、全连接层和分类器。其中,教师子模型可以包括教师主干网络模块和教师辅助网络模块。教师主干网络模块可以是包括多个卷积模块和多个池化层的卷积神经网络模型。教师辅助网络模块可以是与教师主干网络模块相连接的辅助分支网络。示例性的,教师辅助网络模块可以是包括一个卷积层的卷积神经网络。在实际应用中,每个教师子模型中的教师主干神经网络模块依次连接,进而,与全连接层和分类器连接;同时,对于每个教师子模型,教师子模型中的教师主干网络模块与教师辅助网络模块相连接,也就是说,教师主干神经网络的模型输出可以作为相应教师辅助网络模块的模型输入,从而,教师辅助网络模块的模型输出可以作为该教师子模型对应的第一模型输出。
需要说明的是,教师网络中所包括的教师子模型的数量可以是任意值,可选的,可以为3个。
在本实施例中,第一平滑化概率分布可以是经过平滑化处理的分类概率分布。在实际应用中,可以在分类器中预先部署平滑化处理算法,进而,可以在得到全连接层的模型输出之后,可以根据分类器中的平滑化处理算法对该模型输出进行平滑处理,从而,可以得到平滑化概率分布。本领域技术人员应用理解,在基于神经网络模型执行分类任务的过程中,一般情况下,在经过分类器处理后,可以输出至少一个类别对应的概率值,进而,可以将概率值最大的类别作为分类结果输出。然而,在知识蒸馏过程中,需要抽取教师网络中的模型输出对学生网络进行训练,在得到至少一个类别对应的概率值的情况下,虽然,任意类别对应的概率值小到可以忽略,但是,这些概率值也可以传递教师网络学习到的知识,因此,为了可以将分类器输出的全部类别对应的概率值应用在学生网络的训练过程中,使得训练完成的学生网络具有更强的泛化能力,可以对概率值进行平滑处理,以在不改变原有概率分布的情况下,改变概率值的大小,使其具有可比性。在经过平滑处理之后,平滑处理中所涉及的温度系数越高,分类器输出的平滑化概率分布越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将会更加关注负标签。
在实际应用中,在得到训练样本之后,可以将训练样本输入至预先训练完成的教师网络中,进而,可以根据教师网络对训练样本中的待处理样本图像进行处理,得到第一平滑化概率分布以及与多个教师子模型对应的第一模型输出。
可选的,基于预先训练完成的教师网络对训练样本进行处理,得到第一平滑化概率分布和与多个教师子模型对应的第一模型输出,包括:基于教师网络中的多个教师主干网络模块、全连接层以及分类器依次对训练样本进行处理,得到第一平滑化概率分布;针对每个教师子模型,将教师子模型中的教师主干网络模块的模型输出作为教师子模型的教师辅助网络模块的模型输入,并将教师辅助网络模块的模型输出作为教师子模型对应的第一模型输出。
在实际应用中,在得到训练样本之后,可以将训练样本输入至教师网络中,首先,基于首个教师子模型中的教师主干网络模块对训练样本中的待处理样本图像进行特征提取,得到第一教师特征,之后,基于第二个教师子模型中的教师主干网络模块对第一教师特征进行处理,得到第二教师特征,之后,基于第三个教师子模型中的教师主干网络模块对第二教师特征进行处理,得到第三教师特征,以此类推,在经过教师网络中所包括的多个教师主干网络模块处理之后,可以得到待处理教师特征。进一步的,可以基于全连接层和分类器依次对待处理教师特征进行处理,得到第一平滑化概率分布。同时,对于教师网络中包括的每个教师子模型,在得到教师子模型中教师主干网络模块的模型输出之后,可以将该模型输出输入至与该教师主干网络模块相连接的教师辅助网络模块中,以基于教师辅助网络模块对其进行处理,从而,可以得到与教师子模型对应的第一模型输出。具体来说,以首个教师子模型为例,在首个教师子模型中的教师主干网络模块对待处理样本图像进行特征提取,得到第一教师特征的情况下,在将第一教师特征输入至下一教师子模型中的教师主干网络模块的同时,也可以将第一教师特征输入至首个教师子模型中的教师辅助网络模块中。进而,可以根据教师辅助网络模块对第一教师特征进行处理,得到与首个教师子模型对应的第一模型输出。需要说明的是,对于教师网络中所包括的每个教师子模型,其对应的第一模型输出的确定方式均是相同的,本实施例在此不再具体赘述。
需要说明的是,在应用本实施例提供的教师网络之前,可以先对待训练教师网络进行训练。下面可以对待训练教师网络的训练过程进行说明:1、获取训练样本图像和与训练样本图像对应的理论分类标签,并对训练样本图像进行数据增强,得到至少一张待处理样本图像,基于至少一张待处理样本图像和理论分类标签,构建训练样本;2、将训练样本输入至待训练教师网络中,得到实际分类结果和与多个教师子模型对应的模型输出;3、根据实际分类结果与理论分类标签,确定第一模型损失;4、根据多个教师子模型对应的模型输出,确定第二模型损失;5、根据第一模型损失和第二模型损失,确定目标模型损失,并基于目标模型损失对待训练教师网络的模型参数进行调整,得到训练完成的教师网络。
示例性的,如图2所示,即为待训练教师网络的训练过程流程图:1、对训练样本图像进行数据增强,并构建训练样本;2、将训练样本输入至待训练教师网络中,经过待训练教师网络中的主干卷积神经网络以及辅助网络对训练样本进行特征提取;3、输出分类概率分布以及多个辅助网络对应的模型输出;4、确定目标模型损失,并根据梯度下降算法对模型参数进行更新,得到训练完成的高性能教师网络。
S130、将训练样本输入至待训练学生网络中,得到实际输出结果以及与多个学生子模型对应的第二模型输出。
在本实施例中,待训练学生网络中的模型参数可以是默认值。通过训练样本和预先训练完成的教师网络对待训练学生网络中的模型参数进行修正,以得到训练完成的学生网络。学生网络可以是低性能且模型结构简单的神经网络模型。待训练学生网络可以是与教师网络具有相同网络结构的神经网络模型。在教师网络为深层卷积神经网络模型的情况下,待训练学生网络同样为深层卷积神经网络模型。其中,待训练学生网络可以包括多个学生子模型、全连接层以及分类器。学生子模型可以包括学生主干网络模块和学生辅助网络模块。学生主干网络模块可以是包括多个卷积模块和多个池化层的卷积神经网络模型。学生辅助网络模块可以是与主干网络模块相连接的辅助分支网络。示例性的,学生辅助网络模块可以是包括一个卷积层的卷积神经网络。在实际应用中,每个学生子模型中的学生主干网络模块依次连接,进而,与全连接层和分类器连接;同时,对于每个学生子模型,学生子模型中的学生主干网络模块与学生辅助网络模块相连接,也就是说,学生主干神经网络的模型输出可以作为相应学生辅助网络模块的模型输入,从而,学生辅助网络模块的模型输出可以作为该学生子模型对应的第二模型输出。
在本实施例中,实际输出结果中包括实际分类结果和第二平滑化概率分布。实际分类结果是将训练样本输入至待训练学生网络后输出的图像分类类别。第二平滑化概率分布可以是将训练样本输入至待训练学生网络后输出的,经过平滑化处理的分类概率分布。
在实际应用中,在得到训练样本之后,可以将训练样本输入至待训练学生网络中,进而,可以基于待训练学生网络对训练样本中的待训练样本图像进行处理,得到实际输出结果和与每个学生子模型对应的第二模型输出。
可选的,将训练样本输入至待训练学生网络中,得到实际输出结果以及与多个学生子模型对应的第二模型输出,包括:基于待训练学生网络中的多个学生主干网络模块、全连接层以及分类器对训练样本进行处理,得到实际输出结果;针对每个学生子模型,将学生子模型中的学生主干网络模块的模型输出作为学生子模型的学生辅助网络模块的模型输入,并将学生辅助网络模块的模型输出作为学生子模型对应的第二模型输出。
在实际应用中,在得到训练样本之后,可以将训练样本输入至待训练学生网络中,以基于待训练学生网络对训练样本进行处理。首先,基于首个学生子模型中的学生主干网络模块对训练样本中的待处理样本图像进行特征提取,得到第一学生特征,之后,基于第二个学生子模型中的学生主干网络模块对第一学生特征进行处理,得到第二学生特征,之后,基于第三个学生子模型中的学生主干网络模块对第二学生特征进行处理,得到第三学生特征。以此类推,在经过待训练学生网络中所包括的多个学生主干网络模块处理之后,可以得到待处理学生特征。进一步的,可以基于全连接层和分类器依次对待处理学生特征进行护理,得到实际输出结果。同时,对于待训练学生网络中包括的每个学生子模型,在得到学生子模型中学生主干网络模块的模型输出之后,可以将该模型输出输入至与该学生主干网络模块相连接的学生辅助网络模块中,以基于学生辅助网络模块对其进行处理。从而,可以得到与学生子模型对应的第二模型输出。具体来说,以首个学生子模型为例,在首个学生子模型中的学生主干网络模块对待处理样本图像进行特征提取,得到第一学生特征的情况下,在将第一学生特征输入至下一学生子模型中的学生主干网络模块的同时,也可以将第一学生特征输入至首个学生子模型中的学生辅助网络模块中。进而,可以根据学生辅助网络模块对第一学生特征进行处理,得到与首个学生子模型对应的第二模型输出。需要说明的是,对于待训练学生网络中所包括的每个学生子模型,其对应的第二模型输出的确定方式均是相同的,本实施例在此不再具体赘述。
S140、根据第一平滑化概率分布、多个第一模型输出、实际输出结果、多个第二模型输出以及理论分类标签,确定目标模型损失,并基于目标模型损失对待训练学生模型进行模型参数调整,得到图像分类模型。
在本实施例中,在得到第一平滑化概率分布、与每个教师子模型对应的第一模型输出、实际输出结果以及与每个学生子模型对应的第二模型输出之后,即可根据第一平滑化概率分布、与每个教师子模型对应的第一模型输出、实际输出结果、与每个学生子模型对应的第二模型输出以及训练样本中的理论分类标签,确定目标模型损失。其中,目标模型损失可以理解为待训练学生网络在进行模型参数修正时所依据的损失值。
在实际应用中,实际输出结果中包括实际分类结果和第二平滑化概率分布。在基于预先训练完成的教师网络对待训练学生网络进行训练的情况下,可以根据实际分类结果与理论分类标签之间的差异值,确定待训练学生网络对应的分类损失;根据第一平滑化概率分布和第二平滑化概率分布之间的差异,确定教师网络与待训练学生网络之间针对分类预测结果的对比损失;根据多个第一模型输出与多个第二模型输出之间的差异,确定教师网络与待训练学生网络之间针对辅助网络模块的输出的对比损失;根据多个第二模型输出,确定待训练学生网络辅助网络模块到的输出与自监督标签之间的额外损失。进一步的,可以根据上述四种损失,确定目标模型损失。
可选的,根据第一平滑化概率分布、多个第一模型输出、实际输出结果、多个第二模型输出以及理论分类标签,确定目标模型损失,包括:根据实际输出结果中的实际分类结果和理论分类标签,确定第一模型损失;根据多个第二模型输出,确定第二模型损失;根据第一平滑化概率分布和实际输出结果中的第二平滑化概率分布,确定第三模型损失;根据多个第一模型输出和多个第二模型输出,确定第四模型损失;根据第一模型损失、第二模型损失、第三模型损失以及第四模型损失,得到目标模型损失。
在本实施例中,第一模型损失可以为实际分类结果和理论分类标签之间的差异值。第二模型损失可以为待训练学生网络中多个辅助网络模块的第二模型输出与自监督标签之间的差异值。第三模型损失可以为教师网络输出的第一平滑化概率分布和待训练学生网络输出的第二平滑化概率分布之间的差异值。第四模型损失可以为教师网络辅助网络模块的第一模型输出与待训练学生网络辅助网络模块的第二模型输出之间的差异值。
可选的,根据实际输出结果中的实际分类结果和理论分类标签,确定第一模型损失,包括:根据预先设置的第一损失函数对实际分类结果和理论分类标签进行损失处理,得到第一模型损失。
在本实施例中,第一损失函数可以是任意损失函数,可选的,可以为交叉熵损失函数。
在实际应用中,在确定实际分类结果的情况下,可以根据预先设置的第一损失函数对实际分类结果和训练样本中的理论分类标签进行损失处理。进而,可以得到损失值,可以将该损失值作为第一模型损失。
示例性的,假设第一损失函数可以为交叉熵损失函数,第一模型损失可以根据如下公式确定:
其中,表示第一模型损失;/>表示温度系数为1的训练样本图像xi的概率分布(即,实际分类结果);yi表示理论分类标签。
可选的,根据多个第二模型输出,确定第二模型损失,包括:根据第一损失函数对多个第二模型输出进行损失处理,得到第二模型损失。
在实际应用中,为了提高细粒度分类任务的预测结果,可以基于辅助网络模块的输出对待训练学生网络进行自蒸馏训练。因此,在得到待训练学生网络中每个辅助网络模块输出的第二模型输出的情况下,可以根据第一损失函数对多个第二模型输出和预先确定的自监督标签进行损失处理,得到损失值,并可以将该损失值作为第二模型损失。
示例性的,假设第一损失函数为交叉熵损失函数,第二模型损失可以根据如下公式确定:
其中,表示第二模型损失;U表示对训练样本图像xi进行数据增强的变换数;l表示第几个辅助网络模块;trans(·)表示任意数据增强方法;τ表示对第二模型输出进行平滑化操作的温度系数;cj表示数据增强得到的图像与训练样本图像之间关系的标签。
可选的,根据第一平滑化概率分布和实际输出结果中的第二平滑化概率分布,确定第三模型损失,包括:根据预先设置的第二损失函数对第一平滑化概率分布和第二平滑化概率分布进行损失处理,得到第三模型损失。
在本实施例中,第二损失函数可以是任意损失函数,可选的,可以是KL散度损失函数。其中,KL散度损失函数可以表示一个概率分布相对于另一个概率分布的差异程度。
在实际应用中,在得到教师网络输出的第一平滑化概率分布,以及待训练学生网络输出的第二平滑化概率分布的情况下,可以根据第二损失函数对第一平滑化概率分布和第二平滑化概率分布进行损失处理,得到损失值,可以将该损失值作为第三模型损失。
示例性的,假设第二损失函数为KL散度损失函数,第三模型损失可以根据如下公式确定:
其中,LKL_trans表示第三模型损失;U表示对训练样本图像xi进行数据增强的变换数;τ表示平滑化操作的温度系数;DKL(·)表示KL散度;pT(transj(x);τ)表示第一平滑化概率分布;pS(transj(x);τ)表示第二平滑化概率分布。
可选的,根据多个第一模型输出和多个第二模型输出,确定第四模型损失,包括:根据第二损失函数对多个第一模型输出和多个第二模型输出进行损失处理,得到第四模型损失。
在实际应用中,在得到教师网络中每个辅助网络模块输出的第一模型输出,以及待训练学生网络中每个辅助网络模块输出的第二模型输出的情况下,可以根据第二损失函数对多个第一模型输出和多个第二模型输出进行损失处理,得到损失值,并将该损失值作为第四模型损失。
示例性的,假设第二损失函数为KL散度损失函数,第四模型损失可以根据如下公式确定:
其中,LKL_con表示第四模型损失;U表示对训练样本图像xi进行数据增强的变换数;τ表示平滑化操作的温度系数;DKL(·)表示KL散度;l表示第几个辅助网络模块;表示第一模型输出;/>表示第二模型输出。
在实际应用中,在确定第一模型损失、第二模型损失、第三模型损失以及第四模型损失之后,即可根据第一模型损失以及相应的权重、第二模型损失以及相应的权重、第三模型损失以及相应的权重和第四模型损失以及相应的权重,确定目标模型损失。
可选的,根据第一模型损失、第二模型损失、第三模型损失以及第四模型损失,得到目标模型损失,包括:分别确定与第一模型损失对应的第一权重值、与第二模型损失对应的第二权重值、与第三模型损失对应的第三权重值以及与第四模型损失对应的第四权重值;将第一模型损失和第一权重值相乘,得到第一待处理损失值,将第二模型损失和第二权重值相乘,得到第二待处理损失值,将第三模型损失与第三权重值相乘,得到第三待处理损失值,以及,将第四模型损失与第四权重值相乘,得到第四待处理损失值;将第一待处理损失值、第二待处理损失值、第三待处理损失值以及第四待处理损失值相加,得到目标模型损失。
在本实施例中,第一权重值可以是表征第一模型损失在目标模型损失中所占比重的数值。第二权重值可以是表征第二模型损失在目标模型损失中所占比重的数值。第三权重值可以是表征第三模型损失在目标模型损失中所占比重的数值。第四权重值可以是表征第四模型损失在目标模型损失中所占比重的数值。
在实际应用中,分别确定第一模型损失对应的第一权重值、第二模型损失对应的第二权重值、第三模型损失对应的第三权重值以及第四模型损失对应的第四权重值。进一步的,可以将第一模型损失与第一权重值进行相乘处理,得到一个数值,可以将该数值作为第一待处理损失值;将第二模型损失和第二权重值进行相乘处理,得到一个数值,可以将该数值作为第二待处理损失值;将第三模型损失和第三权重值进行相乘处理,得到一个数值,可以将该数值作为第三待处理损失值;将第四模型损失与第四权重值进行相乘处理,得到一个数值,可以将该数值作为第四待处理损失值。之后,可以将第一待处理损失值、第二待处理损失值、第三待处理损失值和第四待处理损失值进行相加处理,从而,可以最终得到目标模型损失。
示例性的,可以基于如下公式确定目标模型损失:
其中,L表示目标模型损失;α1表示第一权重值;表示第一模型损失;α2表示第二权重值;/>表示第二模型损失;α3表示第三权重值;LKL_trans表示第三模型损失;α4表示第四权重值;LKL_con表示第四模型损失。
进一步的,可以根据目标模型损失对待训练学生网络的模型参数进行调整,从而,可以得到图像分类模型。
在具体实施中,可以将目标模型损失对应的损失函数收敛作为训练目标,比如,训练误差是否小于预设误差;或者,误差变化是否趋于稳定;或者,当前的迭代次数是否等于预设次数。若检测到满足收敛条件,比如,目标模型损失对应的损失函数的训练误差小于预设误差;或者,误差变化趋势趋于稳定,可以表明待训练学生网络训练完成,此时,可以停止迭代训练。若检测到当前未达到收敛条件,可以进一步获取其他训练样本对待训练学生网络继续训练,直至目标模型损失的训练误差在预设范围之内。当目标模型损失的训练误差达到收敛时,即可将训练完成的待训练学生网络作为图像分类模型,即,此时,将待分类图像输入至图像分类模型中后,即可准确的得到待分类图像对应的类别。
示例性的,如图3所示,即为多阶段对比知识蒸馏方法的示意图。
本发明实施例的技术方案,通过获取训练样本图像和与训练样本图像对应的理论分类标签,并对训练样本图像进行数据增强,得到至少一张待处理样本图像,基于至少一张待处理样本图像和理论分类标签,构建训练样本。之后,基于预先训练完成的教师网络对训练样本进行处理,得到第一平滑化概率分布和与多个教师子模型对应的第一模型输出。进一步的,将训练样本输入至待训练学生网络中,得到实际输出结果以及与多个学生子模型对应的第二模型输出。最后,根据第一平滑化概率分布、多个第一模型输出、实际输出结果、多个第二模型输出以及理论分类标签,确定目标模型损失,并基于目标模型损失对待训练学生模型进行模型参数调整,得到图像分类模型,解决了相关技术中无法充分将知识迁移到学生网络,难以充分利用教师网络蕴含的知识来提高轻量化网络的表征能力,使得知识迁移效果受限,神经网络模型压缩的准确度较低等问题,实现了从教师网络传递了更加丰富的知识给学生网络,让学生网络的性能进一步接近甚至超过教师网络性能的效果,从而达到了简单模型高性能的目标。
上述具体实施方式,并不构成对本发明保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本发明的精神和原则之内所作的修改、等同替换和改进等,均应包含在本发明保护范围之内。
Claims (9)
1.一种多阶段对比知识蒸馏方法,其特征在于,包括:
获取训练样本图像和与所述训练样本图像对应的理论分类标签,并对所述训练样本图像进行数据增强,得到至少一张待处理样本图像,基于所述至少一张待处理样本图像和所述理论分类标签,构建训练样本;
基于预先训练完成的教师网络对所述训练样本进行处理,得到第一平滑化概率分布和与多个教师子模型对应的第一模型输出;
将所述训练样本输入至待训练学生网络中,得到实际输出结果以及与多个学生子模型对应的第二模型输出,其中,所述实际输出结果中包括实际分类结果和第二平滑化概率分布;
根据所述第一平滑化概率分布、多个所述第一模型输出、所述实际输出结果、多个所述第二模型输出以及所述理论分类标签,确定目标模型损失,并基于所述目标模型损失对所述待训练学生模型进行模型参数调整,得到图像分类模型。
2.根据权利要求1所述的方法,其特征在于,所述教师网络包括多个教师子模型、全连接层以及分类器,所述教师子模型包括教师主干网络模块和教师辅助网络模块,所述基于预先训练完成的教师网络对所述训练样本进行处理,得到第一平滑化概率分布和与多个教师子模型对应的第一模型输出,包括:
基于所述教师网络中的所述多个教师主干网络模块、全连接层以及分类器依次对所述训练样本进行处理,得到所述第一平滑化概率分布;
针对每个教师子模型,将所述教师子模型中的教师主干网络模块的模型输出作为所述教师子模型的教师辅助网络模块的模型输入,并将所述教师辅助网络模块的模型输出作为所述教师子模型对应的第一模型输出。
3.根据权利要求1所述的方法,其特征在于,所述待训练学生网络包括多个学生子模型、全连接层以及分类器,所述学生子模型包括学生主干网络模块和学生辅助网络模块,所述将所述训练样本输入至待训练学生网络中,得到实际输出结果以及与多个学生子模型对应的第二模型输出,包括:
基于所述待训练学生网络中的所述多个学生主干网络模块、全连接层以及分类器对所述训练样本进行处理,得到所述实际输出结果;
针对每个学生子模型,将所述学生子模型中的学生主干网络模块的模型输出作为所述学生子模型的学生辅助网络模块的模型输入,并将所述学生辅助网络模块的模型输出作为所述学生子模型对应的第二模型输出。
4.根据权利要求1所述的方法,其特征在于,所述根据所述第一平滑化概率分布、多个所述第一模型输出、所述实际输出结果、多个所述第二模型输出以及所述理论分类标签,确定目标模型损失,包括:
根据所述实际输出结果中的实际分类结果和所述理论分类标签,确定第一模型损失;
根据多个所述第二模型输出,确定第二模型损失;
根据所述第一平滑化概率分布和所述实际输出结果中的第二平滑化概率分布,确定第三模型损失;
根据多个所述第一模型输出和多个所述第二模型输出,确定第四模型损失;
根据所述第一模型损失、所述第二模型损失、所述第三模型损失以及所述第四模型损失,得到所述目标模型损失。
5.根据权利要求4所述的方法,其特征在于,所述根据所述实际输出结果中的实际分类结果和所述理论分类标签,确定第一模型损失,包括:
根据预先设置的第一损失函数对所述实际分类结果和所述理论分类标签进行损失处理,得到所述第一模型损失。
6.根据权利要求5所述的方法,其特征在于,所述根据多个所述第二模型输出,确定第二模型损失,包括:
根据所述第一损失函数对多个所述第二模型输出进行损失处理,得到所述第二模型损失。
7.根据权利要求4所述的方法,其特征在于,所述根据所述第一平滑化概率分布和所述实际输出结果中的第二平滑化概率分布,确定第三模型损失,包括:
根据预先设置的第二损失函数对所述第一平滑化概率分布和所述第二平滑化概率分布进行损失处理,得到所述第三模型损失。
8.根据权利要求7所述的方法,其特征在于,所述根据多个所述第一模型输出和多个所述第二模型输出,确定第四模型损失,包括:
根据所述第二损失函数对多个所述第一模型输出和多个所述第二模型输出进行损失处理,得到所述第四模型损失。
9.根据权利要求4所述的方法,其特征在于,所述根据所述第一模型损失、所述第二模型损失、所述第三模型损失以及所述第四模型损失,得到所述目标模型损失,包括:
分别确定与所述第一模型损失对应的第一权重值、与所述第二模型损失对应的第二权重值、与所述第三模型损失对应的第三权重值以及与所述第四模型损失对应的第四权重值;
将所述第一模型损失和所述第一权重值相乘,得到第一待处理损失值,将所述第二模型损失和所述第二权重值相乘,得到第二待处理损失值,将所述第三模型损失与所述第三权重值相乘,得到第三待处理损失值,以及,将所述第四模型损失与所述第四权重值相乘,得到第四待处理损失值;
将所述第一待处理损失值、所述第二待处理损失值、所述第三待处理损失值以及所述第四待处理损失值相加,得到所述目标模型损失。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311064055.4A CN117095217A (zh) | 2023-08-22 | 2023-08-22 | 多阶段对比知识蒸馏方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311064055.4A CN117095217A (zh) | 2023-08-22 | 2023-08-22 | 多阶段对比知识蒸馏方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117095217A true CN117095217A (zh) | 2023-11-21 |
Family
ID=88783154
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311064055.4A Pending CN117095217A (zh) | 2023-08-22 | 2023-08-22 | 多阶段对比知识蒸馏方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117095217A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118233222A (zh) * | 2024-05-24 | 2024-06-21 | 浙江大学 | 一种基于知识蒸馏的工控网络入侵检测方法及装置 |
-
2023
- 2023-08-22 CN CN202311064055.4A patent/CN117095217A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118233222A (zh) * | 2024-05-24 | 2024-06-21 | 浙江大学 | 一种基于知识蒸馏的工控网络入侵检测方法及装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110288030B (zh) | 基于轻量化网络模型的图像识别方法、装置及设备 | |
CN109191382B (zh) | 图像处理方法、装置、电子设备及计算机可读存储介质 | |
CN109949255B (zh) | 图像重建方法及设备 | |
CN110223292B (zh) | 图像评估方法、装置及计算机可读存储介质 | |
CN113570508A (zh) | 图像修复方法及装置、存储介质、终端 | |
CN110363068B (zh) | 一种基于多尺度循环生成式对抗网络的高分辨行人图像生成方法 | |
CN110148088B (zh) | 图像处理方法、图像去雨方法、装置、终端及介质 | |
WO2021042857A1 (zh) | 图像分割模型的处理方法和处理装置 | |
CN111144214B (zh) | 基于多层堆栈式自动编码器的高光谱图像解混方法 | |
CN108197669B (zh) | 卷积神经网络的特征训练方法及装置 | |
CN111898482B (zh) | 基于渐进型生成对抗网络的人脸预测方法 | |
CN116089883B (zh) | 用于提高已有类别增量学习新旧类别区分度的训练方法 | |
CN113920043A (zh) | 基于残差通道注意力机制的双流遥感图像融合方法 | |
CN112784929A (zh) | 一种基于双元组扩充的小样本图像分类方法及装置 | |
CN117095217A (zh) | 多阶段对比知识蒸馏方法 | |
CN110598848A (zh) | 一种基于通道剪枝的迁移学习加速方法 | |
CN111461978A (zh) | 一种基于注意力机制的逐分辨率提升图像超分辨率复原方法 | |
CN114239861A (zh) | 基于多教师联合指导量化的模型压缩方法及系统 | |
CN112270366A (zh) | 基于自适应多特征融合的微小目标检测方法 | |
CN114897711A (zh) | 一种视频中图像处理方法、装置、设备及存储介质 | |
CN114492581A (zh) | 基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 | |
CN114581789A (zh) | 一种高光谱图像分类方法及系统 | |
CN112528077B (zh) | 基于视频嵌入的视频人脸检索方法及系统 | |
CN116888605A (zh) | 神经网络模型的运算方法、训练方法及装置 | |
CN116416212B (zh) | 路面破损检测神经网络训练方法及路面破损检测神经网络 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |