CN112766463A - 基于知识蒸馏技术优化神经网络模型的方法 - Google Patents
基于知识蒸馏技术优化神经网络模型的方法 Download PDFInfo
- Publication number
- CN112766463A CN112766463A CN202110098053.1A CN202110098053A CN112766463A CN 112766463 A CN112766463 A CN 112766463A CN 202110098053 A CN202110098053 A CN 202110098053A CN 112766463 A CN112766463 A CN 112766463A
- Authority
- CN
- China
- Prior art keywords
- network model
- network
- loss
- layer
- output
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 64
- 238000003062 neural network model Methods 0.000 title claims abstract description 26
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 23
- 238000005516 engineering process Methods 0.000 title claims abstract description 20
- 238000012549 training Methods 0.000 claims abstract description 48
- 230000005012 migration Effects 0.000 claims abstract description 10
- 238000013508 migration Methods 0.000 claims abstract description 10
- 230000006870 function Effects 0.000 claims description 50
- 238000012545 processing Methods 0.000 claims description 21
- 238000004821 distillation Methods 0.000 claims description 16
- 238000004364 calculation method Methods 0.000 claims description 12
- 238000004590 computer program Methods 0.000 claims description 7
- 238000013528 artificial neural network Methods 0.000 claims description 6
- 238000005096 rolling process Methods 0.000 claims description 4
- 238000012047 cause and effect analysis Methods 0.000 claims description 3
- 238000012043 cost effectiveness analysis Methods 0.000 claims description 3
- 230000008569 process Effects 0.000 abstract description 6
- 239000010410 layer Substances 0.000 description 90
- 238000010586 diagram Methods 0.000 description 16
- 238000004422 calculation algorithm Methods 0.000 description 4
- 238000010606 normalization Methods 0.000 description 4
- 230000003287 optical effect Effects 0.000 description 4
- 230000002085 persistent effect Effects 0.000 description 4
- 230000000694 effects Effects 0.000 description 3
- 230000006872 improvement Effects 0.000 description 3
- 230000004913 activation Effects 0.000 description 2
- 238000012512 characterization method Methods 0.000 description 2
- 238000011478 gradient descent method Methods 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 238000005070 sampling Methods 0.000 description 2
- 230000002159 abnormal effect Effects 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 239000002355 dual-layer Substances 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 230000002441 reversible effect Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本申请是关于一种基于知识蒸馏技术优化神经网络模型的方法、电子设备及存储介质,通过根据动态知识迁移的策略,基于第一网络模型的预测值以及第一数据集的真实标签值,得到动态权重;根据所述动态权重实时更新所述第一交叉熵损失以及所述第二交叉熵损失,得到综合交叉熵损失,以此利用第一网络模型的输出结果动态监督第二网络模型的训练过程,相比于原本直接训练第二网络模型的方法,机器人识别模块在保持计算复杂度的情况下,使得第二网络模型的精度提升,解决了传统的知识蒸馏并没有充分地将大网络的知识迁移到轻量型网络中,轻量型网络的精度尚存在提高空间的问题。
Description
技术领域
本申请涉及人工智能技术领域,尤其涉及基于知识蒸馏技术优化神经网络模型的方法、电子设备及存储介质。
背景技术
随着深度学习的发展,基于深度神经网络的识别算法得到普及,同时这类算法计算量庞大的问题,导致这类算法在机器人等边缘端设备很难运用,受限其有限的计算资源,需要对算法复杂度优化,对网络结构精剪。
知识蒸馏(Knowledge distillation)是一种通过知识迁移的想法,通过训练好的大模型(Teacher Model)来指导小模型(Student Model)训练的方法,缩小小模型与大模型精度差异的效果。但是,传统的知识蒸馏并没有充分地将大网络的知识迁移到轻量型网络中,轻量型网络的精度尚存在提高空间。
因此,期望借助知识蒸馏的技术,实现小模型的精度提升。
发明内容
为克服相关技术中存在的问题,本申请提供一种基于知识蒸馏技术优化神经网络模型的方法、电子设备及存储介质,旨在解决在机器人端因算力不足而导致神经网络性能下降的问题。同时,借助知识蒸馏技术,提出了一种无监督学习的策略,大大增加业务小模型训练数据量,降低标注成本。
本申请解决上述技术问题的技术方案如下:一种基于知识蒸馏技术优化神经网络模型的方法,其特征在于,包括以下步骤:步骤1,基于Darknet53网络架构构建第一网络模型,得到所述第一网络模型的第一交叉熵损失;步骤2,将第一数据集输入所述第一网络模型,根据所述第一交叉熵损失训练所述第一网络模型,得到第一网络模型的网络参数;步骤3,基于Mobilenet网络构建第二网络模型;步骤4,根据所述第一交叉熵损失,基于知识蒸馏,得到所述第二网络模型的第二交叉熵损失;步骤5,根据动态知识迁移的策略,基于第一网络模型的预测值以及第一数据集的真实标签值,得到动态权重;步骤6,根据所述动态权重实时更新所述第一交叉熵损失以及所述第二交叉熵损失,得到综合交叉熵损失。
优选的,在步骤6之后还包括以下步骤:S101,将所述第一数据集和一第二数据集按照数目比1:5混合,得到第二网络模型的训练集;S102,将所述训练集输入到所述第二网络模型中,根据训练集的数据类型来选择所述综合交叉熵损失的计算,得到所述第二网络模型的损失类型;S103,根据所述第二网络模型的损失类型,采用随机梯度下降法SGD和动量法,训练所述第二网络模型;S104,当第二网络模型的第二交叉熵损失与所述第一网络模型的第一交叉熵损失比例小于1%时,停止训练所述第二网络模型,得到所述第二网络模型的网络参数。
优选的,所述第一网络模型的结构按计算单元处理顺序为:第一卷积块,用于对输入到所述第一网络模型的第一数据集进行卷积处理,所述第一卷积块由52层卷积堆叠而成;第一嵌入层,所述第一嵌入层由16倍下采样层构成,用于对所述第一卷积块的输出进行下采样;第二卷积块,用于对所述第一嵌入层的输出进行卷积处理;第一全连接层,用于将所述第二卷积块的输出经过所述第一全连接层后输入到一第一网络输出层。
优选的,所述第二网络模型的结构按计算单元处理顺序为:第一组卷积块,用于对输入到所述第二网络模型的训练集进行卷积处理,所述第一组卷积块由19层组卷积模块构成;第二嵌入层,所述第二嵌入层由16倍下采样层构成,用于对所述第一组卷积块的输出进行下采样;第二组卷积块,用于对所述第二嵌入层的输出进行卷积处理;第二全连接层,用于将所述第二组卷积块的输出经过所述第二全连接层后输入到一第二网络输出层。
优选的,基于知识蒸馏,在所述第一网络输出层中添加温度参数T,形成一第一网络温度输出层;以及在所述第二网络输出层中添加温度参数T,形成一第二网络温度输出层,所述第一网络温度输出层以及所述第二网络温度输出层用于控制网络软标签的值。
优选的,所述网络软标签的函数为:
当T=1,所述网络软标签退化成原始标签,Zi表示神经网络得到的概率分布。
优选的,所述综合交叉熵损失的函数为:
Loss=LossCE+weight*(Losssoft+Lossembed);
LossCE=∑Yilog(yi);
Losssoft=∑Ytilog(ysi);
Lossembed=1-featt*feats
featt=Norm(featt)
feats=Norm(feats);
其中,weight为动态权重,LossCE为第二交叉熵损失函数,Losssoft为网络软标签的损失函数,Lossembed为嵌入层的损失函数,featt为第一嵌入层的输出,feats为第二嵌入层的输出,Norm表示对特征层按照特征层均值标准差归一化,第二网络模型的原始标签预测值yi,数据集的真实标签值Yi,第二网络模型的网络软标签的输出值ysi,第一网络模型的网络软标签的输出值Yti。
优选的,当在所述第二数据集上,无监督训练时,所述综合交叉熵损失的函数退化成:
Loss=weight*(Losssoft+Lossembed);
其中,数据集的真实标签值为Yi,第一网络模型的原始标签预测值为yti。
优选的,所述第一数据集为已有标注的数据集,所述第二数据集为无标准的数据集。
本申请实施例的第二方面提供了一种电子设备,包括:
处理器;以及一个或多个处理器;一个或多个程序,其中所述一个或多个程序存储在所述存储器中并被配置为由所述一个或多个处理器执行,所述一个或多个程序包括用于执行如上所述的方法的指令。
本申请第三方面提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上所述的方法。
本申请提供一种基于知识蒸馏技术优化神经网络模型的方法、电子设备及存储介质,通过根据动态知识迁移的策略,基于第一网络模型的预测值以及第一数据集的真实标签值,得到动态权重;根据所述动态权重实时更新所述第一交叉熵损失以及所述第二交叉熵损失,得到综合交叉熵损失后,根据训练集的数据类型来选择所述综合交叉熵损失的计算,得到所述第二网络模型的损失类型;再然后根据所述第二网络模型的损失类型,采用随机梯度下降法SGD和动量法,训练所述第二网络模型,以此利用第一网络模型的输出结果动态监督第二网络模型的训练过程,相比于原本直接训练第二网络模型的方法,机器人识别模块在保持计算复杂度的情况下,使得第二网络模型的精度提升,解决了传统的知识蒸馏并没有充分地将大网络的知识迁移到轻量型网络中,轻量型网络的精度尚存在提高空间的问题。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本申请。
附图说明
通过结合附图对本申请示例性实施方式进行更详细的描述,本申请的上述以及其它目的、特征和优势将变得更加明显,其中,在本申请示例性实施方式中,相同的参考标号通常代表相同部件。
图1是本申请实施例示出的基于知识蒸馏技术优化神经网络模型的方法的流程示意图;
图2是本申请实施例示出的第一网络模型的结构示意图;
图3是本申请实施例示出的组卷积模块的结构示意图;
图4是本申请实施例示出的第二网络模型的结构示意图;
图5是本申请实施例示出的基于知识蒸馏技术优化神经网络模型的方法的另一流程示意图;
图6是本申请实施例示出的基于知识蒸馏技术优化神经网络模型的整体结构示意图;
图7是本申请实施例示出的电子设备的结构示意图。
具体实施方式
下面将参照附图更详细地描述本申请的优选实施方式。虽然附图中显示了本申请的优选实施方式,然而应该理解,可以以各种形式实现本申请而不应被这里阐述的实施方式所限制。相反,提供这些实施方式是为了使本申请更加透彻和完整,并且能够将本申请的范围完整地传达给本领域的技术人员。
在本申请使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本申请。在本申请和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。还应当理解,本文中使用的术语“和/或”是指并包含一个或多个相关联的列出项目的任何或所有可能组合。
应当理解,尽管在本申请可能采用术语“第一”、“第二”、“第三”等来描述各种信息,但这些信息不应限于这些术语。这些术语仅用来将同一类型的信息彼此区分开。例如,在不脱离本申请范围的情况下,第一信息也可以被称为第二信息,类似地,第二信息也可以被称为第一信息。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本申请的描述中,“多个”的含义是两个或两个以上,除非另有明确具体的限定。
在资源受限的器人等边缘端设备上部署这类大的网络模型较为困难,仅能部署规模较小的网络。然而,直接训练小网络得到的模型性能远低于大网络的性能。
针对上述问题,本申请实施例提供一种基于知识蒸馏技术优化神经网络模型的方法,解决在机器人端因算力不足而导致神经网络性能下降的问题,并且基于知识蒸馏技术优化神经网络模型的方法可以增加业务小模型训练数据量,降低标注成本,解决如何在标注数据有限的情况下,扩充机器人训练样本集的问题。
以下结合附图详细描述本申请实施例的技术方案。
请参阅图1,图1为本申请第一实施例示出的基于知识蒸馏技术优化神经网络模型的方法的流程示意图,如图1所示,方法包括以下步骤:
步骤S1,基于Darknet53网络架构构建第一网络模型,得到所述第一网络模型的第一交叉熵损失。
请参阅图2,图2为第一网络模型的结构示意图,具体的,所述第一网络模型(第一网络模型也被作为教师模型Teacher Model),Darknet53网络架构包括:堆叠52层卷积块,一层分类层以及一层网络输出层Softmax。本实施例中,选取中间16倍下采样层作为嵌入层(Embedding layer),所述第一网络模型的结构按计算单元处理顺序为:
第一卷积块,用于对输入到所述第一网络模型的第一数据集进行卷积处理,所述第一卷积块由52层卷积堆叠而成;
第一嵌入层,所述第一嵌入层由16倍下采样层构成,用于对所述第一卷积块的输出进行下采样;
第二卷积块,用于对所述第一嵌入层的输出进行卷积处理;
第一全连接层,用于将所述第二卷积块的输出经过所述第一全连接层后输入到一第一网络输出层。
步骤S2,将第一数据集输入所述第一网络模型,根据所述第一交叉熵损失训练所述第一网络模型,得到第一网络模型的网络参数。
具体的,定义所述第一网络模型的第一交叉熵损失函数如下:
其中Yi是第一数据集的标签值,yi是第一网络模型的预测值,第一数据集为已有标注的数据集。
根据上述第一交叉熵损失函数Loss公式,采用随机梯度下降法SGD和动量法,学习动量参数设定为0.9,卷积参数L2正则惩罚系数设定为0.001,学习率为多项式缓慢下降,训练100次后终止训练第一网络模型,得到所述第一网络模型(Teacher Model)的网络参数。
步骤S3,基于Mobilenet网络构建第二网络模型。
具体的,基于Mobilenet网络搭建第二网络模型(第二网络模型也称学生模型Student Model),MobileNet的基本单元是深度级可分离卷积(depthwise separableconvolution)。深度级可分离卷积其实是一种可分解卷积操作(factorizedconvolutions),其可以分解为两个更小的操作:深度卷积depthwise convolution和点卷积pointwise convolution。深度卷积Depthwise convolution和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上(input channels),而深度卷积depthwiseconvolution针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积depthwise convolution是深度depth级别的操作。而点卷积pointwiseconvolution其实就是普通的卷积,只不过点卷积采用1x1的卷积核。对于深度级可分离卷积depthwise separable convolution,其首先是采用深度卷积depthwise convolution对不同输入通道分别进行卷积,然后采用点卷积pointwise convolution将上面的输出再进行结合,这样其实整体效果和一个标准卷积是差不多的,但是会大大减少计算量和模型参数量。
请参阅图3,图3为组卷积模块的结构示意图,在本实施例中,所述第二网络模型的基本结构由组卷积(group convolution)模块构成,组卷积(group convolution)模块用于对输入特征图进行分组,然后每组分别卷积,组卷积模块的结构按计算单元处理顺序为:
组卷积层,用于对输入到所述第二网络模型的数据集进行卷积处理;
第一批归一化层,所述组卷积层的输出经过所述第一批归一化层后输入到一第一激活层;
第一卷积层,所述第一激活层的输出经过所述第一卷积层输入到第二批归一化层;
所述第二批归一化层的输出又作为第二激活层的输入。
请参阅图4,图4为第二网络模型的结构示意图,在本实施例中,第二网络模型(Student Model)基于所述组卷积模块搭建,由19层组卷积模块,一层分类层以及一网络输出层Softmax,所述第二网络模型选取中间16倍下采样层作为嵌入层,所述第二网络模型的结构按计算单元处理顺序为:
第一组卷积块,用于对输入到所述第二网络模型的训练集进行卷积处理,所述第一组卷积块由19层组卷积模块构成;
第二嵌入层,所述第二嵌入层由16倍下采样层构成,用于对所述第一组卷积块的输出进行下采样;
第二组卷积块,用于对所述第二嵌入层的输出进行卷积处理;
第二全连接层,用于将所述第二组卷积块的输出经过所述第二全连接层后输入到一第二网络输出层。
在其中一个实施例中,基于知识蒸馏,在所述第一网络输出层中添加温度参数T,形成一第一网络温度输出层;
以及在所述第二网络输出层中添加温度参数T,形成一第二网络温度输出层,所述第一网络温度输出层以及所述第二网络温度输出层用于控制网络软标签的值。
所述网络软标签的函数为:
当T=1,所述网络软标签退化成原始标签,Zi表示神经网络得到的概率分布。
步骤S4,根据所述第一交叉熵损失,基于知识蒸馏,得到所述第二网络模型的第二交叉熵损失。
具体的,基于知识蒸馏技术,根据所述第一网络模型的第一叉熵损失函数定义,得到第二交叉熵损失函数为:
其中Yi是第二数据集的标签值,yi是第二网络模型的预测值,第二数据集为无标注的数据集。
步骤S5,根据动态知识迁移的策略,基于第一网络模型的预测值以及第一数据集的真实标签值,得到动态权重。
具体的,在实际应用过程中,由于第一网络模型并不是完美的训练指导网络模型,存在错误输出的情况,本实例根据动态知识迁移的策略,在第一网络模型(Teacher Model)预测值异常时,降低第一网络模型的迁移权重,得到动态权重,所述动态权重的函数为:
第一网络模型的预测值为yti,数据集的真实标签为Yi。
步骤S6,根据所述动态权重实时更新所述第一交叉熵损失以及所述第二交叉熵损失,得到综合交叉熵损失。
具体的,训练第二网络模型的损失函数主要由第二交叉熵损失函数、网络软标签的损失函数以及嵌入层的损失函数组成,即综合交叉熵损失包括所述第二交叉熵损失函数、所述网络软标签的损失函数以及嵌入层的损失函数。与第一网络模型的损失函数定义类似,所述综合交叉熵损失的函数为:
Loss=LossCE+weight*(Losssoft+Lossembed);
LossCE=∑Yilog(yi);
Losssoft=∑Ytilog(ysi);
Lossembed=1-featt*feats
featt=Norm(featt)
feats=Norm(feats);
其中,weight为动态权重,LossCE为第二交叉熵损失函数,Losssoft为网络软标签的损失函数,Lossembed为嵌入层的损失函数,featt为第一嵌入层的输出,feats为第二嵌入层的输出,Norm表示对特征层按照特征层均值标准差归一化,第二网络模型的原始标签预测值yi,数据集的真实标签值Yi,第二网络模型的网络软标签的输出值ysi,第一网络模型的网络软标签的输出值Yti。
具体的,所述第二网络模型的原始标签预测值yi,数据集真实标签值Yi,得到所述第二交叉熵损失函数:LossCE=∑Yilog(yi);
所述第二网络模型的网络软标签输出值ysi与所述第一网络模型的网络软标签输出值Yti,得到所述网络软标签的损失函数:Losssoft=∑Ytilog(ysi);
所述第一网络模型的第一嵌入层输出featt,和所述第二网络模型的第二嵌入层输出feats,在训练所述第二网络模型时,期望所述第二网络模型在中间层的特征表征与所述第一网络模型具有类似的效果,定义所述第一网络模型以及所述第二网络模型的嵌入层特征表征相似度定义如下:
featt=Norm(featt)
feats=Norm(feats)
Lossembed=1-featt*feats
其中Norm表示对特征层按照特征层均值标准差归一化,公式如下:
根据所述动态权重实时更新所述第一交叉熵损失、所述第二交叉熵损失以及上述公式的各个损失函数的权重,最终得到所述综合交叉熵损失的函数为:Loss=LossCE+weight*(Losssoft+Lossembed)。
在本实施例中,通过根据动态知识迁移的策略,基于第一网络模型的预测值以及第一数据集的真实标签值,得到动态权重;根据所述动态权重实时更新所述第一交叉熵损失以及所述第二交叉熵损失,得到综合交叉熵损失,以此利用第一网络模型的输出结果动态监督第二网络模型的训练过程,相比于原本直接训练第二网络模型的方法,机器人识别模块在保持计算复杂度的情况下,使得第二网络模型的精度提升,解决了传统的知识蒸馏并没有充分地将大网络的知识迁移到轻量型网络中,轻量型网络的精度尚存在提高空间的问题。
请参阅图5,图5为本申请第二实施例示出的基于知识蒸馏技术优化神经网络模型的方法的另一流程示意图。第二实施例为在上述实施例的基础上增加以下步骤,具体如下:
在步骤6之后还包括以下步骤:
S101,将所述第一数据集和一第二数据集按照数目比1:5混合,得到第二网络模型的训练集。
具体的,所述第一数据集为已有标注的数据集,所述第二数据集为无标注的数据集,通过将第一数据集和所述第二数据集按照1:5的比例混合,得到所述训练集用于作为所述第二网络模型的输入。
S102,将所述训练集输入到所述第二网络模型中,根据训练集的数据类型来选择所述综合交叉熵损失的计算,得到所述第二网络模型的损失类型。
具体的,当在所述第二数据集上,无监督训练时,所述综合交叉熵损失的函数退化成:
Loss=weight*(Losssoft+Lossembed);
其中,数据集的真实标签值为Yi,第一网络模型的原始标签预测值为yti。
其中weight计算中的Yi定义如下:
N为数据集分类的类别数目。
当输入第二网络模型的数据集为第一数据集时,选择所述综合交叉熵损失函数的计算,得到所述第二网络模型的损失类型为:
Loss=LossCE+weight*(Losssoft+Lossembed);
当输入第二网络模型的数据集为第二数据集时,选择所述综合交叉熵损失函数的计算,得到所述第二网络模型的损失类型为:
Loss=weight*(Losssoft+Lossembed)。
S103,根据所述第二网络模型的损失类型,采用随机梯度下降法SGD和动量法,训练所述第二网络模型。
具体的,采用随机梯度下降法SGD和动量法,学习动量参数设定为0.9,卷积参数L2正则惩罚系数设定为0.001,学习率为多项式缓慢下降,来训练所述第二网络模型。
S104,当第二网络模型的第二交叉熵损失与所述第一网络模型的第一交叉熵损失比例小于1%时,停止训练所述第二网络模型,得到所述第二网络模型的网络参数。
请参照图6,图6为基于知识蒸馏技术优化神经网络模型的整体结构示意图,基于知识蒸馏技术优化神经网络模型包括第一网络模型以及第二网路模型,当所述训练集输入到基于知识蒸馏技术优化神经网络模型中时,如果所述训练集为所述第一数据集时,所述训练集则输入到所述第一网络模型中时,所述第一数据集依次通过第一卷积块、第一嵌入层、第二卷积块、第一全连接层,最后一部分通过第一网络输出层softmax输出,另一部分通过所述第一网络温度输出层softmaxT输出,其中所述第一网络输出层softmax的输出作为所述动态权重的输入,所述第一网络温度输出层softmaxT的输出作为所述网络软标签的损失函数的输入。如果所述训练集为所述第二数据集时,所述训练集则输入到所述第二网络模型中时,所述第二数据集依次通过第一组卷积块、第二嵌入层、第二组卷积块、第二全连接层,最后一部分通过第二网络输出层softmax输出,另一部分通过所述第二网络温度输出层softmaxT输出,其中所述第二网络输出层softmax的输出作为所述第二交叉熵损失函数的输入,所述第二网络温度输出层softmaxT的输出作为所述网络软标签的损失函数的输入。其中动态权重的输出分别作为作为所述网络软标签的损失函数的输入以及嵌入层的损失函数的输入,以此降低第一网络模型的迁移权重,从而以此利用第一网络模型的输出结果动态监督第二网络模型的训练过程,相比于原本直接训练第二网络模型的方法,机器人识别模块在保持计算复杂度的情况下,使得第二网络模型的精度提升。最后所述网络软标签的损失函数、嵌入层的损失函数以及所述第二交叉熵损失函数形成所述综合交叉熵损失函数。
在本实施例中,通过所述第一数据集为已有标注的数据集,所述第二数据集为无标注的数据集,通过将第一数据集和所述第二数据集按照1:5的比例混合,得到所述训练集用于作为所述第二网络模型的输入,改进了知识蒸馏技术在大网络因训练集不足而导致训练不佳的情况下的错误监督的问题,本实施例中,利用知识蒸馏技术,结合无监督学习策略,在标注数据有限的情况下,扩充机器人训练样本集的方法,降低了标注成本。同时利用第一网络模型的输出结果动态监督第二网络模型的训练过程,相比于原本直接训练第二网络模型的方法,机器人识别模块在保持计算复杂度的情况下,精度提升3.2%。
图7是本申请实施例示出的电子设备的结构示意图。
参见图7,电子设备400包括存储器410和处理器420。
处理器420可以是中央处理单元(Central Processing Unit,CPU),还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器410可以包括各种类型的存储单元,例如系统内存、只读存储器(ROM),和永久存储装置。其中,ROM可以存储处理器1020或者计算机的其他模块需要的静态数据或者指令。永久存储装置可以是可读写的存储装置。永久存储装置可以是即使计算机断电后也不会失去存储的指令和数据的非易失性存储设备。在一些实施方式中,永久性存储装置采用大容量存储装置(例如磁或光盘、闪存)作为永久存储装置。另外一些实施方式中,永久性存储装置可以是可移除的存储设备(例如软盘、光驱)。系统内存可以是可读写存储设备或者易失性可读写存储设备,例如动态随机访问内存。系统内存可以存储一些或者所有处理器在运行时需要的指令和数据。此外,存储器410可以包括任意计算机可读存储媒介的组合,包括各种类型的半导体存储芯片(DRAM,SRAM,SDRAM,闪存,可编程只读存储器),磁盘和/或光盘也可以采用。在一些实施方式中,存储器410可以包括可读和/或写的可移除的存储设备,例如激光唱片(CD)、只读数字多功能光盘(例如DVD-ROM,双层DVD-ROM)、只读蓝光光盘、超密度光盘、闪存卡(例如SD卡、min SD卡、Micro-SD卡等等)、磁性软盘等等。计算机可读存储媒介不包含载波和通过无线或有线传输的瞬间电子信号。
存储器410上存储有可执行代码,当可执行代码被处理器420处理时,可以使处理器420执行上文述及的方法中的部分或全部。
上文中已经参考附图详细描述了本申请的方案。在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详细描述的部分,可以参见其他实施例的相关描述。本领域技术人员也应该知悉,说明书中所涉及的动作和模块并不一定是本申请所必须的。另外,可以理解,本申请实施例方法中的步骤可以根据实际需要进行顺序调整、合并和删减,本申请实施例装置中的模块可以根据实际需要进行合并、划分和删减。
此外,根据本申请的方法还可以实现为一种计算机程序或计算机程序产品,该计算机程序或计算机程序产品包括用于执行本申请的上述方法中部分或全部步骤的计算机程序代码指令。
或者,本申请还可以实施为一种非暂时性机器可读存储介质(或计算机可读存储介质、或机器可读存储介质),其上存储有可执行代码(或计算机程序、或计算机指令代码),当所述可执行代码(或计算机程序、或计算机指令代码)被电子设备(或电子设备、服务器等)的处理器执行时,使所述处理器执行根据本申请的上述方法的各个步骤的部分或全部。
本领域技术人员还将明白的是,结合这里的申请所描述的各种示例性逻辑块、模块、电路和算法步骤可以被实现为电子硬件、计算机软件或两者的组合。
附图中的流程图和框图显示了根据本申请的多个实施例的系统和方法的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,所述模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现中,方框中所标记的功能也可以以不同于附图中所标记的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或操作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
以上已经描述了本申请的各实施例,上述说明是示例性的,并非穷尽性的,并且也不限于所披露的各实施例。在不偏离所说明的各实施例的范围和精神的情况下,对于本技术领域的普通技术人员来说许多修改和变更都是显而易见的。本文中所用术语的选择,旨在最好地解释各实施例的原理、实际应用或对市场中的技术的改进,或者使本技术领域的其它普通技术人员能理解本文披露的各实施例。
Claims (11)
1.一种基于知识蒸馏技术优化神经网络模型的方法,其特征在于,包括以下步骤:
步骤1,基于Darknet53网络架构构建第一网络模型,得到所述第一网络模型的第一交叉熵损失;
步骤2,将第一数据集输入所述第一网络模型,根据所述第一交叉熵损失训练所述第一网络模型,得到第一网络模型的网络参数;
步骤3,基于Mobilenet网络构建第二网络模型;
步骤4,根据所述第一交叉熵损失,基于知识蒸馏,得到所述第二网络模型的第二交叉熵损失;
步骤5,根据动态知识迁移的策略,基于第一网络模型的预测值以及第一数据集的真实标签值,得到动态权重;
步骤6,根据所述动态权重实时更新所述第一交叉熵损失以及所述第二交叉熵损失,得到综合交叉熵损失。
2.根据权利要求1所述的基于知识蒸馏技术优化神经网络模型的方法,其特征在于,在步骤6之后还包括以下步骤:
S101,将所述第一数据集和一第二数据集按照数目比1:5混合,得到第二网络模型的训练集;
S102,将所述训练集输入到所述第二网络模型中,根据训练集的数据类型来选择所述综合交叉熵损失的计算,得到所述第二网络模型的损失类型;
S103,根据所述第二网络模型的损失类型,采用随机梯度下降法SGD和动量法,训练所述第二网络模型;
S104,当第二网络模型的第二交叉熵损失与所述第一网络模型的第一交叉熵损失比例小于1%时,停止训练所述第二网络模型,得到所述第二网络模型的网络参数。
3.根据权利要求2所述的基于知识蒸馏技术优化神经网络模型的方法,其特征在于,所述第一网络模型的结构按计算单元处理顺序为:
第一卷积块,用于对输入到所述第一网络模型的第一数据集进行卷积处理,所述第一卷积块由52层卷积堆叠而成;
第一嵌入层,所述第一嵌入层由16倍下采样层构成,用于对所述第一卷积块的输出进行下采样;
第二卷积块,用于对所述第一嵌入层的输出进行卷积处理;
第一全连接层,用于将所述第二卷积块的输出经过所述第一全连接层后输入到一第一网络输出层。
4.根据权利要求3所述的基于知识蒸馏技术优化神经网络模型的方法,其特征在于,所述第二网络模型的结构按计算单元处理顺序为:
第一组卷积块,用于对输入到所述第二网络模型的训练集进行卷积处理,所述第一组卷积块由19层组卷积模块构成;
第二嵌入层,所述第二嵌入层由16倍下采样层构成,用于对所述第一组卷积块的输出进行下采样;
第二组卷积块,用于对所述第二嵌入层的输出进行卷积处理;
第二全连接层,用于将所述第二组卷积块的输出经过所述第二全连接层后输入到一第二网络输出层。
5.根据权利要求4所述的基于知识蒸馏技术优化神经网络模型的方法,其特征在于,基于知识蒸馏,在所述第一网络输出层中添加温度参数T,形成一第一网络温度输出层;
以及在所述第二网络输出层中添加温度参数T,形成一第二网络温度输出层,所述第一网络温度输出层以及所述第二网络温度输出层用于控制网络软标签的值。
9.根据权利要求2所述的基于知识蒸馏技术优化神经网络模型的方法,其特征在于:所述第一数据集为已有标注的数据集,所述第二数据集为无标准的数据集。
10.一种电子设备,包括:存储器;一个或多个处理器;一个或多个程序,其中所述一个或多个程序存储在所述存储器中并被配置为由所述一个或多个处理器执行,所述一个或多个程序包括用于执行根据权利要求1-8所述方法中的任一方法的指令。
11.一种存储介质,存储有计算机程序,其特征在于,所述计算机程序被处理器执行时,实现权利要求1-9任一项所述基于知识蒸馏技术优化神经网络模型的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110098053.1A CN112766463A (zh) | 2021-01-25 | 2021-01-25 | 基于知识蒸馏技术优化神经网络模型的方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110098053.1A CN112766463A (zh) | 2021-01-25 | 2021-01-25 | 基于知识蒸馏技术优化神经网络模型的方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112766463A true CN112766463A (zh) | 2021-05-07 |
Family
ID=75707232
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110098053.1A Pending CN112766463A (zh) | 2021-01-25 | 2021-01-25 | 基于知识蒸馏技术优化神经网络模型的方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112766463A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113343898A (zh) * | 2021-06-25 | 2021-09-03 | 江苏大学 | 基于知识蒸馏网络的口罩遮挡人脸识别方法、装置及设备 |
CN113706347A (zh) * | 2021-08-31 | 2021-11-26 | 深圳壹账通智能科技有限公司 | 一种多任务模型蒸馏方法、系统、介质及电子终端 |
CN113762368A (zh) * | 2021-08-27 | 2021-12-07 | 北京市商汤科技开发有限公司 | 数据蒸馏的方法、装置、电子设备和存储介质 |
CN114995131A (zh) * | 2022-05-25 | 2022-09-02 | 福建德尔科技股份有限公司 | 用于电子级三氟甲烷制备的精馏控制系统及其控制方法 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111291836A (zh) * | 2020-03-31 | 2020-06-16 | 中国科学院计算技术研究所 | 一种生成学生网络模型的方法 |
CN111611377A (zh) * | 2020-04-22 | 2020-09-01 | 淮阴工学院 | 基于知识蒸馏的多层神经网络语言模型训练方法与装置 |
CN111738436A (zh) * | 2020-06-28 | 2020-10-02 | 电子科技大学中山学院 | 一种模型蒸馏方法、装置、电子设备及存储介质 |
CN111985523A (zh) * | 2020-06-28 | 2020-11-24 | 合肥工业大学 | 基于知识蒸馏训练的2指数幂深度神经网络量化方法 |
CN112132268A (zh) * | 2020-09-25 | 2020-12-25 | 交叉信息核心技术研究院(西安)有限公司 | 任务牵引的特征蒸馏深度神经网络学习训练方法及系统、可读存储介质 |
CN112183718A (zh) * | 2020-08-31 | 2021-01-05 | 华为技术有限公司 | 一种用于计算设备的深度学习训练方法和装置 |
-
2021
- 2021-01-25 CN CN202110098053.1A patent/CN112766463A/zh active Pending
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111291836A (zh) * | 2020-03-31 | 2020-06-16 | 中国科学院计算技术研究所 | 一种生成学生网络模型的方法 |
CN111611377A (zh) * | 2020-04-22 | 2020-09-01 | 淮阴工学院 | 基于知识蒸馏的多层神经网络语言模型训练方法与装置 |
CN111738436A (zh) * | 2020-06-28 | 2020-10-02 | 电子科技大学中山学院 | 一种模型蒸馏方法、装置、电子设备及存储介质 |
CN111985523A (zh) * | 2020-06-28 | 2020-11-24 | 合肥工业大学 | 基于知识蒸馏训练的2指数幂深度神经网络量化方法 |
CN112183718A (zh) * | 2020-08-31 | 2021-01-05 | 华为技术有限公司 | 一种用于计算设备的深度学习训练方法和装置 |
CN112132268A (zh) * | 2020-09-25 | 2020-12-25 | 交叉信息核心技术研究院(西安)有限公司 | 任务牵引的特征蒸馏深度神经网络学习训练方法及系统、可读存储介质 |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113343898A (zh) * | 2021-06-25 | 2021-09-03 | 江苏大学 | 基于知识蒸馏网络的口罩遮挡人脸识别方法、装置及设备 |
CN113762368A (zh) * | 2021-08-27 | 2021-12-07 | 北京市商汤科技开发有限公司 | 数据蒸馏的方法、装置、电子设备和存储介质 |
WO2023024406A1 (zh) * | 2021-08-27 | 2023-03-02 | 上海商汤智能科技有限公司 | 数据蒸馏的方法、装置、设备、存储介质、计算机程序及产品 |
CN113706347A (zh) * | 2021-08-31 | 2021-11-26 | 深圳壹账通智能科技有限公司 | 一种多任务模型蒸馏方法、系统、介质及电子终端 |
CN114995131A (zh) * | 2022-05-25 | 2022-09-02 | 福建德尔科技股份有限公司 | 用于电子级三氟甲烷制备的精馏控制系统及其控制方法 |
CN114995131B (zh) * | 2022-05-25 | 2023-02-03 | 福建德尔科技股份有限公司 | 用于电子级三氟甲烷制备的精馏控制系统及其控制方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112766463A (zh) | 基于知识蒸馏技术优化神经网络模型的方法 | |
US11836610B2 (en) | Concurrent training of functional subnetworks of a neural network | |
US10891540B2 (en) | Adaptive neural network management system | |
US11995539B2 (en) | Electronic apparatus and method for re-learning trained model | |
US20230127656A1 (en) | Method for managing training data | |
US20090204558A1 (en) | Method for training a learning machine having a deep multi-layered network with labeled and unlabeled training data | |
JP7037605B2 (ja) | データに対するラベル付けの優先順位を決める方法 | |
US20230145919A1 (en) | Method and apparatus for class incremental learning | |
US11790232B2 (en) | Method and apparatus with neural network data input and output control | |
CA3202896A1 (en) | Methods and systems for improved deep-learning models | |
Tambwekar et al. | Estimation and applications of quantiles in deep binary classification | |
CN113609337A (zh) | 图神经网络的预训练方法、训练方法、装置、设备及介质 | |
US11568303B2 (en) | Electronic apparatus and control method thereof | |
US11887002B2 (en) | Method of generating data by using artificial neural network model having encoder-decoder structure | |
US20220188070A1 (en) | Method and apparatus with data processing | |
US20220180244A1 (en) | Inter-Feature Influence in Unlabeled Datasets | |
CN113609745A (zh) | 一种超参数寻优方法、装置及电子设备和存储介质 | |
CN114298197A (zh) | 增量学习方法、装置、电子设备及机器可读存储介质 | |
US20220405599A1 (en) | Automated design of architectures of artificial neural networks | |
US20240144021A1 (en) | Method and apparatus with machine learning model | |
CN113240565B (zh) | 基于量化模型的目标识别方法、装置、设备及存储介质 | |
KR102492277B1 (ko) | 멀티모달 정보를 이용한 질의응답 수행 방법 | |
US20240028902A1 (en) | Learning apparatus and method | |
CN107451662A (zh) | 优化样本向量的方法及装置、计算机设备 | |
González | Augmenting Deep Learning Models using Continual and Meta Learning Strategies |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |