CN112784999A - 基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备 - Google Patents

基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备 Download PDF

Info

Publication number
CN112784999A
CN112784999A CN202110121769.9A CN202110121769A CN112784999A CN 112784999 A CN112784999 A CN 112784999A CN 202110121769 A CN202110121769 A CN 202110121769A CN 112784999 A CN112784999 A CN 112784999A
Authority
CN
China
Prior art keywords
loss
model
attention
simple model
layer
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110121769.9A
Other languages
English (en)
Inventor
黄明飞
姚宏贵
梁维斌
王昊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Open Intelligent Machine Shanghai Co ltd
Original Assignee
Open Intelligent Machine Shanghai Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Open Intelligent Machine Shanghai Co ltd filed Critical Open Intelligent Machine Shanghai Co ltd
Priority to CN202110121769.9A priority Critical patent/CN112784999A/zh
Publication of CN112784999A publication Critical patent/CN112784999A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明提供了一种基于注意力机制的mobilenet‑v1知识蒸馏方法、存储器及终端设备,其中,包括:分别选择复杂模型WRN‑50‑8以及简单模型mobilenet‑v对应的特定中间层,用以进行注意图的知识转移;处理得到复杂模型和简单模型的中间层所对应的注意力图之间的损失,记为损失值一;处理获得复杂模型和简单模型的Logit层之间的KL散度;处理获得简单模型的交叉熵损失,记为损失值二;根据损失值一、KL散度及损失值二处理得到总损失;损失值一、RL散度、损失值二以及总损失用以简单模型的参数的计算。其技术方案的有益效果在于,与现有其他蒸馏方式相比,大幅提高mobilenet‑v1学生网络的识别精度和准确率,并可以将其部署在算力有限的设备。

Description

基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终 端设备
技术领域
本发明涉及深度学习模型压缩技术领域,尤其涉及基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备。
背景技术
知识蒸馏是将复杂模型(教师网络)中的暗知识(dark knowledge)迁移到简单模型(学生网络)中去,一般来说,复杂模型具有强大的能力和表现,而简单模型则更为紧凑。通过知识蒸馏,希望简单模型能尽可能逼近亦或是超过复杂模型,从而用更少的复杂度来获得类似的预测效果。(Geoffrey Hinton,Oriol Vinyals,Jeff Dean.“Distilling theKnowledge in a Neural Network”In NIPS,2014)首次提出了知识蒸馏的概念,通过引入教师网络的软目标(soft targets)以诱导学生网络的训练。近些年来出现了许多知识蒸馏的方法,而不同的方法对于网络中需要转移的暗知识定义也各不相同。(SergeyZagoruyko,Nikos Komodakis.“PAYING MORE ATTENTION TO ATTENTION:IMPROVING THEPERFORMANCE OF CONVOLUTIONAL NEURAL NETWORKS VIA ATTENTION TRANSFER”In ICLR,2017)首次提出利用注意力机制对WRN(Wide ResNet)网络进行蒸馏。
由于WRN网络结构依然很大,不适合部署在计算能力有限的设备(比如移动终端)。
发明内容
针对现有的在无法在计算能力有限的设备上部署WRN网络结存在的问题。现提供一种方便对对简单模型进行蒸馏以适应有限算力的端侧设备的基于注意力机制的mobilenet-v1知识蒸馏方法。
具体包括以下:
一种基于注意力机制的mobilenet-v1知识蒸馏方法,其中,包括:
分别选择复杂模型WRN-50-8以及简单模型mobilenet-v(MobileNets基于一种流线型结构使用深度可分离卷积来构造轻型权重深度神经网络。)对应的特定中间层,用以进行注意图的知识转移;
处理得到所述复杂模型和所述简单模型的中间层所对应的注意力图之间的损失,记为损失值一;
处理获得所述复杂模型和所述简单模型的Logit层之间的KL散度;
处理获得所述简单模型的交叉熵损失,记为损失值二;
根据所述损失值一、所述KL散度及所述损失值二处理得到总损失;
所述损失值一、所述RL散度、所述损失值二以及所述总损失用以所述简单模型的参数的计算。
优选的所述进行注意图的知识转移方法包括:
从所述复杂模型的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图一;
从所述简单模型的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图二;
将所述中间特征图一的知识转移给所述中间特征图二。
上述技术方案中,注意力图知识转移将教师网络中间层特征图经过计算获得教师网络中间层注意力图,再将与其对应的学生网络中间层特征图经过同样的计算过程得到学生网络中间层注意力图。
优选的,处理得到所述简单模型或所述复杂模型的中间层对应的注意力图的方法如下式所示:
设张量A∈RC*H*W为所述的简单模型或复杂模型的某个中间层特征图,即特征图A有C个通道,每个通道为H*W的二维矩阵,则注意力图按照如下公式计算:
Figure BDA0002922314740000021
其中,注意力图计算结果Q∈RH*W,A(i,:,:)表示第i个通道的H*W二维矩阵。
优选的,处理得到所述复杂模型和所述简单模型的中间层所对应的注意力图之间的损失的方法如下式所示:
Figure BDA0002922314740000031
其中,
Figure BDA0002922314740000032
表示复杂模型WRN-50-8的第j个注意力图,
Figure BDA0002922314740000033
表示对应的简单模型mobilenet-v1的第j个注意力图,||X||表示计算矩阵X的L2正则。
优选的,计算所述KL散度的方法包括:
所述复杂模型的logit层,是WRN-50-8网络的fc层的输出lT∈R1*1*10;所述的简单模型的logit层,是mobilenet-v1网络的fc层的输出lS∈R1*1*10
计算所述简单模型和所述复杂模型logit层之间的KL散度,如下式所示:
Figure BDA0002922314740000034
其中,lT[i]表示复杂模型fc层的输出lT的第i个值;
lS[i]表示简单模型fc层的输出lS的第i个值;T表示温度参数,这里取值为4。
优选的,处理获得所述简单模型的交叉熵损失的方法包括,将简单模型softmax层的输出与训练数据的真值标签计算交叉熵损失Lce
优选的,计算所述总损失的方法如下式所示:
ltotal=α*Lkl+(1-α)*Lce+β*LAT
其中,参数α取值为0.9,参数β取值为1000,Lkl表示KL散度,Lce表示交叉熵损失,LAT表示注意力图之间的损失。
还包括一种非易失性存储器,其中存储有软件,其中,所述软件用以实现权利要上述的基于注意力机制的mobilenet-v1知识蒸馏方法。
还包括一种终端设备,包括一个或多个处理器和与其耦合的一个或多个存储器,其中,所述一个或多个存储器用于存储计算机程序代码,所述计算机程序代码包括计算机指令;
所述一个或多个处理器用于执行所述计算机指令并实现上述的基于注意力机制的mobilenet-v1知识蒸馏方法。
上述技术方案具有如下优点或有益效果:与现有其他蒸馏方式相比,大幅提高mobilenet-v1学生网络的识别精度和准确率,并可以将其部署在算力有限的设备。
附图说明
图1是本发明中的一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例的流程示意图;
图2是本发明中的一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例中,关于复杂模型即教师网络WRN-50-8的结构示意图;
图3是本发明中的一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例中,关于简单模型即学生网络mobilenet-v的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动的前提下所获得的所有其他实施例,都属于本发明保护的范围。
需要说明的是,在不冲突的情况下,本发明中的实施例及实施例中的特征可以相互组合。
下面结合附图和具体实施例对本发明作进一步说明,但不作为本发明的限定。
具体包括以下内容:
一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例,其中,包括:
分别选择复杂模型即教师网络WRN-50-8以及简单模型即学生网络mobilenet-v对应的特定中间层,用以进行注意图的知识转移;
处理得到复杂模型和简单模型的中间层所对应的注意力图之间的损失,记为损失值一;
处理获得复杂模型和简单模型的Logit层之间的KL散度;
处理获得简单模型的交叉熵损失,记为损失值二;
根据损失值一、KL散度及损失值二处理得到总损失;
损失值一、RL散度、损失值二以及总损失用以简单模型的参数的计算。
具体步骤如下:
如图1所示,一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例,其中,包括:
S1、分别选择复杂模型WRN-50-8以及简单模型mobilenet-v对应的特定中间层,用以进行注意图的知识转移;
S2、处理得到复杂模型和简单模型的中间层所对应的注意力图之间的损失,记为损失值一;
S3、处理获得复杂模型和简单模型的Logit层之间的KL散度;
S4、处理获得简单模型的交叉熵损失,记为损失值二;
S5、根据损失值一、KL散度及损失值二处理得到总损失;
S6、损失值一、RL散度、损失值二以及总损失用以简单模型的参数的计算。
上述技术方案中,复杂模型即教师网络的结构如图2所示,简单模型即学生网络的结构如图3所示。
在一种较优的实施方式中,进行注意图的知识转移方法包括:
从复杂模型(教师网络WRN-50-8)的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图一,如图2所示,假设我们选择conv_2层、conv_3层、conv_4层输出作为计算注意力图的中间层特征图;
从简单模型(学生网络mobilenet-v)的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图二,如图2所示,假设我们选择group_0层、group_1层、group_2层输出作为计算注意力图的中间层特征图;
将中间特征图一的知识转移给中间特征图二。
在一种较优的实施方式中,处理得到简单模型(学生网络)或复杂模型(教师网络)的中间层对应的注意力图的方法如下式所示:
设张量A∈RC*H*W为的简单模型(学生网络)或复杂模型(教师网络)的某个中间层特征图,即特征图A有C个通道,每个通道为H*W的二维矩阵,则注意力图按照如下公式计算:
Figure BDA0002922314740000061
其中,注意力图计算结果Q∈RH*W,A(i,:,:)表示第i个通道的H*W二维矩阵。
在一种较优的实施方式中,处理得到复杂模型(教师网络)和简单模型(学生网络)的中间层所对应的注意力图之间的损失的方法如下式所示:
Figure BDA0002922314740000062
其中,
Figure BDA0002922314740000063
表示复杂模型(教师网络)WRN-50-8的第j个注意力图,
Figure BDA0002922314740000064
表示对应的简单模型(学生网络)mobilenet-v1的第j个注意力图,||X||表示计算矩阵X的L2正则。
在一种较优的实施方式中,计算KL散度的方法包括:
复杂模型(教师网络)的logit层,是WRN-50-8网络的fc层的输出lT∈R1*1*10;的简单模型(学生网络)的logit层,是mobilenet-v1网络的fc层的输出lS∈R1*1*10
计算简单模型(学生网络)和复杂模型(教师网络)logit层之间的KL散度,如下式所示:
Figure BDA0002922314740000065
其中,lT[i]表示复杂模型(教师网络)fc层的输出lT的第i个值;
lS[i]表示简单模型(学生网络)fc层的输出lS的第i个值;T表示温度参数,这里取值为4。
在一种较优的实施方式中,处理获得简单模型(学生网络)的交叉熵损失的方法包括,将简单模型(学生网络)softmax层的输出与训练数据的真值标签计算交叉熵损失Lce
在一种较优的实施方式中,计算总损失的方法如下式所示:
ltotal=α*Lkl+(1-α)*Lce+β*LAT
其中,参数α取值为0,9,参数β取值为1000,Lkl表示KL散度,Lce表示交叉熵损失,LAT表示注意力图之间的损失。
本发明的技术方案中还包括一种非易失性存储器,其中存储有软件,其中,软件用以实现权利要上述的基于注意力机制的mobilenet-v1知识蒸馏方法。
本发明的技术方案中还包括一种终端设备,包括一个或多个处理器和与其耦合的一个或多个存储器,其中,一个或多个存储器用于存储计算机程序代码,计算机程序代码包括计算机指令;
一个或多个处理器用于执行计算机指令并实现上述的基于注意力机制的mobilenet-v1知识蒸馏方法。
以上仅为本发明较佳的实施例,并非因此限制本发明的实施方式及保护范围,对于本领域技术人员而言,应当能够意识到凡运用本发明说明书及图示内容所作出的等同替换和显而易见的变化所得到的方案,均应当包含在本发明的保护范围内。

Claims (9)

1.一种基于注意力机制的mobilenet-v1知识蒸馏方法,其特征在于,包括:
分别选择复杂模型以及简单模型对应的特定中间层,用以进行注意图的知识转移;
分别处理得到所述复杂模型和所述简单模型的所述特定中间层所对应的注意力图之间的损失并记为第一损失值,根据所述第一损失值对所述简单模型中的所述特定中间层进行更新;
分别处理获得所述复杂模型和所述简单模型的Logit层的KL散度;
处理获得所述简单模型的交叉熵损失,记为损失值二;
根据所述损失值一、所述KL散度及所述损失值二处理得到总损失;
所述损失值一、所述RL散度、所述损失值二以及所述总损失用以所述简单模型的参数的计算。
2.根据权利1所述的方法,其特征在于,所述进行注意图的知识转移方法包括:
从所述复杂模型的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图一;
从所述简单模型的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图二;
将所述中间特征图一的知识转移给所述中间特征图二。
3.根据权利1所述的方法,其特征在于,处理得到所述简单模型或所述复杂模型的中间层对应的注意力图的方法如下式所示:
设张量A∈RC*H*W为所述的简单模型或复杂模型的某个中间层特征图,即特征图A有C个通道,每个通道为H*W的二维矩阵,则注意力图按照如下公式计算:
Figure FDA0002922314730000011
其中,注意力图计算结果Q∈RH*W,A(i,:,:)表示第i个通道的H*W二维矩阵。
4.根据权利1所述的方法,其特征在于,处理得到所述复杂模型和所述简单模型的中间层所对应的注意力图之间的损失的方法如下式所示:
Figure FDA0002922314730000021
其中,
Figure FDA0002922314730000022
表示复杂模型WRN-50-8的第j个注意力图,
Figure FDA0002922314730000023
表示对应的简单模型mobilenet-v1的第j个注意力图,||X||表示计算矩阵X的L2正则。
5.根据权利1所述的方法,其特征在于,计算所述KL散度的方法包括:
所述复杂模型的logit层,是WRN-50-8网络的fc层的输出lT∈R1*1*10;所述的简单模型的logit层,是mobilenet-v1网络的fc层的输出lS∈R1*1*10
计算所述简单模型和所述复杂模型logit层之间的KL散度,如下式所示:
Figure FDA0002922314730000024
其中,lT[i]表示复杂模型fc层的输出lT的第i个值;
lS[i]表示简单模型fc层的输出lS的第i个值;T表示温度参数,这里取值为4。
6.根据权利1所述的方法,其特征在于,处理获得所述简单模型的交叉熵损失的方法包括,将简单模型softmax层的输出与训练数据的真值标签计算交叉熵损失Lce
7.根据权利1所述的方法,其特征在于,计算所述总损失的方法如下式所示:
ltotal=α*Lkl+(1-α)*Lce+β*LAT
其中,参数α取值为0.9,参数β取值为1000,Lkl表示KL散度,Lce表示交叉熵损失,LAT表示注意力图之间的损失。
8.一种非易失性存储器,其中存储有软件,其特征在于,所述软件用以实现权利要求1-7中任一所述的基于注意力机制的mobilenet-v1知识蒸馏方法。
9.一种终端设备,包括一个或多个处理器和与其耦合的一个或多个存储器,其特征在于,所述一个或多个存储器用于存储计算机程序代码,所述计算机程序代码包括计算机指令;
所述一个或多个处理器用于执行所述计算机指令并实现权利要求1-7中任一所述的基于注意力机制的mobilenet-v1知识蒸馏方法。
CN202110121769.9A 2021-01-28 2021-01-28 基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备 Pending CN112784999A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110121769.9A CN112784999A (zh) 2021-01-28 2021-01-28 基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110121769.9A CN112784999A (zh) 2021-01-28 2021-01-28 基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备

Publications (1)

Publication Number Publication Date
CN112784999A true CN112784999A (zh) 2021-05-11

Family

ID=75759587

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110121769.9A Pending CN112784999A (zh) 2021-01-28 2021-01-28 基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备

Country Status (1)

Country Link
CN (1) CN112784999A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113807215A (zh) * 2021-08-31 2021-12-17 贵州大学 一种结合改进注意力机制和知识蒸馏的茶叶嫩芽分级方法
CN116385794A (zh) * 2023-04-11 2023-07-04 河海大学 基于注意力流转移互蒸馏的机器人巡检缺陷分类方法及装置

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107247989A (zh) * 2017-06-15 2017-10-13 北京图森未来科技有限公司 一种神经网络训练方法及装置
CN111062489A (zh) * 2019-12-11 2020-04-24 北京知道智慧信息技术有限公司 一种基于知识蒸馏的多语言模型压缩方法、装置
CN111126599A (zh) * 2019-12-20 2020-05-08 复旦大学 一种基于迁移学习的神经网络权重初始化方法
CN111554268A (zh) * 2020-07-13 2020-08-18 腾讯科技(深圳)有限公司 基于语言模型的语言识别方法、文本分类方法和装置
US20200302295A1 (en) * 2019-03-22 2020-09-24 Royal Bank Of Canada System and method for knowledge distillation between neural networks
CN111950302A (zh) * 2020-08-20 2020-11-17 上海携旅信息技术有限公司 基于知识蒸馏的机器翻译模型训练方法、装置、设备及介质

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107247989A (zh) * 2017-06-15 2017-10-13 北京图森未来科技有限公司 一种神经网络训练方法及装置
US20200302295A1 (en) * 2019-03-22 2020-09-24 Royal Bank Of Canada System and method for knowledge distillation between neural networks
CN111062489A (zh) * 2019-12-11 2020-04-24 北京知道智慧信息技术有限公司 一种基于知识蒸馏的多语言模型压缩方法、装置
CN111126599A (zh) * 2019-12-20 2020-05-08 复旦大学 一种基于迁移学习的神经网络权重初始化方法
CN111554268A (zh) * 2020-07-13 2020-08-18 腾讯科技(深圳)有限公司 基于语言模型的语言识别方法、文本分类方法和装置
CN111950302A (zh) * 2020-08-20 2020-11-17 上海携旅信息技术有限公司 基于知识蒸馏的机器翻译模型训练方法、装置、设备及介质

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113807215A (zh) * 2021-08-31 2021-12-17 贵州大学 一种结合改进注意力机制和知识蒸馏的茶叶嫩芽分级方法
CN113807215B (zh) * 2021-08-31 2022-05-13 贵州大学 一种结合改进注意力机制和知识蒸馏的茶叶嫩芽分级方法
CN116385794A (zh) * 2023-04-11 2023-07-04 河海大学 基于注意力流转移互蒸馏的机器人巡检缺陷分类方法及装置
CN116385794B (zh) * 2023-04-11 2024-04-05 河海大学 基于注意力流转移互蒸馏的机器人巡检缺陷分类方法及装置

Similar Documents

Publication Publication Date Title
US20230196117A1 (en) Training method for semi-supervised learning model, image processing method, and device
CN111797589B (zh) 一种文本处理网络、神经网络训练的方法以及相关设备
CN110866190A (zh) 训练用于表征知识图谱的图神经网络模型的方法及装置
CN116415654A (zh) 一种数据处理方法及相关设备
CN112733768B (zh) 基于双向特征语言模型的自然场景文本识别方法及装置
CN113486665B (zh) 隐私保护文本命名实体识别方法、装置、设备及存储介质
CN111260919B (zh) 交通流量预测方法
CN113190688B (zh) 基于逻辑推理和图卷积的复杂网络链接预测方法及系统
CN112199532B (zh) 一种基于哈希编码和图注意力机制的零样本图像检索方法及装置
CN111797970B (zh) 训练神经网络的方法和装置
CN113392359A (zh) 多目标预测方法、装置、设备及存储介质
CN111241306B (zh) 一种基于知识图谱和指针网络的路径规划方法
CN113627545B (zh) 一种基于同构多教师指导知识蒸馏的图像分类方法及系统
CN113065013B (zh) 图像标注模型训练和图像标注方法、系统、设备及介质
CN112784999A (zh) 基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备
CN115390164B (zh) 一种雷达回波外推预报方法及系统
CN113257361B (zh) 自适应蛋白质预测框架的实现方法、装置及设备
CN115170565B (zh) 基于自动神经网络架构搜索的图像欺诈检测方法及装置
CN115017178A (zh) 数据到文本生成模型的训练方法和装置
CN116992151A (zh) 一种基于双塔图卷积神经网络的在线课程推荐方法
CN113989566A (zh) 一种图像分类方法、装置、计算机设备和存储介质
CN116975686A (zh) 训练学生模型的方法、行为预测方法和装置
WO2023143570A1 (zh) 一种连接关系预测方法及相关设备
CN114386527B (zh) 一种用于域自适应目标检测的类别正则化方法及系统
CN113255899B (zh) 一种通道自关联的知识蒸馏方法与系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination