CN113326940A - 基于多重知识迁移的知识蒸馏方法、装置、设备及介质 - Google Patents

基于多重知识迁移的知识蒸馏方法、装置、设备及介质 Download PDF

Info

Publication number
CN113326940A
CN113326940A CN202110712121.9A CN202110712121A CN113326940A CN 113326940 A CN113326940 A CN 113326940A CN 202110712121 A CN202110712121 A CN 202110712121A CN 113326940 A CN113326940 A CN 113326940A
Authority
CN
China
Prior art keywords
network
loss function
training
collaborative
knowledge
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110712121.9A
Other languages
English (en)
Inventor
苟建平
孙立媛
柯佳
夏书银
陈潇君
欧卫华
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Jiangsu University
Original Assignee
Jiangsu University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Jiangsu University filed Critical Jiangsu University
Priority to CN202110712121.9A priority Critical patent/CN113326940A/zh
Publication of CN113326940A publication Critical patent/CN113326940A/zh
Priority to CN202210535574.3A priority patent/CN114742224A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/02Knowledge representation; Symbolic representation
    • G06N5/022Knowledge engineering; Knowledge acquisition
    • G06N5/025Extracting rules from data
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Computational Linguistics (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • General Health & Medical Sciences (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及知识蒸馏技术领域,公开了一种基于多重知识迁移的知识蒸馏方法、装置、设备及介质,所述方法包括:构建两个网络组,网络组包括未训练的协同网络和完成预训练的教师网络;其中,通过对与协同网络相同的预训练网络进行预训练获取教师网络;将训练数据输入两个网络组中获取各个协同网络和教师网络的输出结果,训练数据还包括对应的真实标签数据;基于与协同网络同组的教师网络的输出结果、两个协同网络的输出结果、两个协同网络迁移样本间的关系确定蒸馏损失函数;基于蒸馏损失函数对协同网络进行迭代训练。本发明可在节省网络模型训练时间成本的基础上,进一步提升网络模型的性能。

Description

基于多重知识迁移的知识蒸馏方法、装置、设备及介质
技术领域
本发明涉及知识蒸馏技术领域,具体是指基于多重知识迁移的知识蒸馏方法、装置、设备及介质。
背景技术
近年来,深度神经网络在计算机视觉、自然语言处理等诸多应用领域取得了突破性的进展。深度学习模型的显著性能依赖于设计更深或者更宽的网络结构,其中包含许多的层和大量的参数。然而,如此庞大的网络几乎不可能部署在计算资源和存储空间十分有限的平台上,例如移动设备和嵌入式系统。为了解决上述问题,人们提出了模型压缩技术。
目前提出的模型压缩技术大体可分为:低秩近似,网络剪枝,网络量化,以及知识蒸馏。低秩近似将网络权值矩阵看做满秩矩阵,因此可以用多个低秩矩阵来近似该矩阵。然而这种做法实现起来并不容易,不仅涉及计算成本高昂的分解操作,而且需要做大量的重新训练来使模型达到收敛。网络剪枝将权值矩阵中相对不重要的权值提出,然后重新对网络进行微调。然而这种做法会导致网络连接不规整,需要通过稀疏表达来减少内存占用,进而导致在前向传播时不适合并行计算。网络量化通过牺牲精度来降低每个权值需要的空间。这种做法在很多情况下,会很大程度的降低模型的表达能力。
知识蒸馏作为一种高效的模型压缩技术,近年来在深度学习中受到了广泛的关注。不同于上述方案,知识蒸馏不需要人为的改变网络结构,其成功的关键是把知识从一个大型、复杂的教师网络传递给一个小型、精简的学生网络。然而,现有的知识蒸馏方法大多只考虑通过特定的特征提取策略,即从实例特征或实例关系中获取某一类知识。由于教师网络和学生固定的网络结构,学生网络通过经常使用的离线蒸馏方式学到的知识是有限的。不仅如此,训练一个复杂而繁琐的教师网络需要大量的数据,但是在很多现实情况下(例如在医疗领域、军工领域等),由于种种因素大量数据难以获得。
发明内容
基于以上技术问题,为进一步提升知识蒸馏获得的学生网络的性能,本发明提供了一种基于多重知识迁移的知识蒸馏方法、装置、设备及介质,其协同网络之间互相通过在线蒸馏的协作式学习同时接收来自对方网络基于响应的知识和基于关系的知识,并且通过自蒸馏的自我学习方式来进一步提升自身性能,具体包括以下技术方案:
一种基于多重知识迁移的知识蒸馏方法,包括:
构建两个网络组,网络组包括未训练的协同网络和完成预训练的教师网络;其中,通过对与协同网络相同的预训练网络进行预训练获取教师网络;
将训练数据输入两个网络组中获取各个协同网络和教师网络的输出结果,训练数据还包括对应的真实标签数据;
基于与协同网络同组的教师网络的输出结果、两个协同网络的输出结果、两个协同网络迁移样本间的关系确定蒸馏损失函数;
基于蒸馏损失函数对协同网络进行迭代训练。
一种基于多重知识迁移的知识蒸馏装置,包括:
网络构建模块,网络构建模块用于构建两个网络组,网络组包括未训练的协同网络和完成预训练的教师网络;其中,通过对与协同网络相同的预训练网络进行预训练获取教师网络;
数据处理模块,数据处理模块用于将训练数据输入两个网络组中获取各个协同网络和教师网络的输出结果,训练数据还包括对应的真实标签数据;
损失函数确定模块,损失函数确定模块用于基于与协同网络同组的教师网络的输出结果、两个协同网络的输出结果、两个协同网络迁移样本间的关系确定蒸馏损失函数;
网络训练模块,网络训练模块用于基于蒸馏损失函数对协同网络进行迭代训练。
一种计算机设备,包括存储器和处理器,存储器中存储有计算机程序,处理器执行计算机程序时实现上述基于多重知识迁移的知识蒸馏方法的步骤。
一种计算机可读存储介质,其特征在于,计算机可读存储介质上存储有计算机程序,计算机程序被处理器执行时实现上述基于多重知识迁移的知识蒸馏方法的步骤。
与现有技术相比,本发明的有益效果是:
本发明使用两个协同网络(网络结构可以不同),有效的利用了多种知识,并在一个统一的框架下进行知识迁移。在训练时,协同网络之间互相通过在线蒸馏的协作式学习同时接收来自对方网络基于响应的知识和基于关系的知识,并且通过自蒸馏的自我学习方式来进一步提升自身性能。
现有的发明通常将教师网络中存在的某一种知识通过某一种学习方式迁移给协同网络,这种做法不仅需要提前训练好大型教师网络,而且使学生网络性能的提升十分有限。而本发明在训练的过程中,将网络中不同的知识通过不同的方式提取出来,然后迁移给另一个协同网络,在节省时间成本的同时使协同网络获得更高的性能。且本发明不需要预训练大型的教师网络,只对当前需要训练的预训练网络进行预训练,并在后续学习过程中用作教师网络即可,因此节省了时间成本。
附图说明
本申请将以示例性实施例的方式进一步说明,这些示例性实施例将通过附图进行详细描述,其中:
图1为基于多重知识迁移的知识蒸馏方法流程示意图。
图2为通过对与协同网络相同的预训练网络进行预训练获取教师网络流程示意图。
图3为基于多重知识迁移的知识蒸馏装置基本框架示意图。
图4为基于多重知识迁移的知识蒸馏方法训练获得的两个协同网络分别与其他知识蒸馏方法性能对比示意图。
具体实施方式
为使本公开实施例的目的、技术方案和优点更加清楚,下面将结合本公开实施例的附图,对本公开实施例的技术方案进行清楚、完整地描述。显然,所描述的实施例是本公开的一部分实施例,而不是全部的实施例。基于所描述的本公开的实施例,本领域普通技术人员在无需创造性劳动的前提下所获得的所有其他实施例,都属于本公开保护的范围。
本申请的目的在于提供一种基于多重知识迁移的知识蒸馏方法、装置、设备及介质,所述方法包括:构建两个网络组,网络组包括未训练的协同网络和完成预训练的教师网络;其中,通过对与协同网络相同的预训练网络进行预训练获取教师网络;将训练数据输入两个网络组中获取各个协同网络和教师网络的输出结果,训练数据还包括对应的真实标签数据;基于与协同网络同组的教师网络的输出结果、两个协同网络的输出结果、两个协同网络迁移样本间的关系确定蒸馏损失函数;基于蒸馏损失函数对协同网络进行迭代训练。
本申请实施例可用于以下应用场景,包括但是不限于,计算机视觉应用领域的各种场景,例如人脸识别、图像分类、目标检测、语义分割等,或者是部署到边缘设备上(例如移动电话、可穿戴设备、计算节点等)的基于神经网络模型的处理系统,或者用于语音信号处理、自然语言处理、推荐系统的应用场景,或者是由于有限资源和时延要求需要对神经网络模型进行压缩的应用场景。
仅仅出于说明性目的,本申请实施例可用于手机端物体检测的应用场景。该应用场景需要解决的技术问题是:当用户使用手机拍照时,需要自动抓取人脸、动物等目标,从而帮助手机自动对焦、美化等,因此需要一个体积小、运行快的用于目标检测的卷积神经网络模型,进而给用户带来更好的用户体验并提升手机产品品质。
仅仅出于说明性目的,本申请实施例还可用于自动驾驶场景分割的应用场景。该应用场景需要解决的技术问题是:自动驾驶车辆的摄像头捕捉到道路画面后需要对画面进行分割,从中分出路面、路基、车辆、行人等不同物体,从而保持车辆行驶在正确的区域。因此需要能够快速实时对画面进行正确解读和语义分割的卷积神经网络模型。
仅仅出于说明性目的,本申请实施例还可用于入口闸机人脸验证的应用场景。该应用场景需要解决的技术问题是:在高铁、机场等入口的闸机上,乘客进行人脸认证时,摄像头会拍摄人脸图像并使用卷积神经网络抽取特征,然后和存储在系统中的身份证件的图像特征进行相似度计算;如果相似度高就验证成功。其中,通过卷积神经网络抽取特征是最耗时的,因此需要能够快速进行人脸验证和特征提取的高效的卷积神经网络模型。
仅仅出于说明性目的,本申请实施例还可用于翻译机同声传译的应用场景。该应用场景需要解决的技术问题是:在语音识别和机器翻译问题上,必须达到实时语音识别并进行翻译,因此需要高效的卷积神经网络模型。
本申请实施例可以依据具体应用环境进行调整和改进,此处不做具体限定。
为了使本技术领域的人员更好地理解本申请方案,下面将结合本申请实施例中的附图,对本申请的实施例进行描述。
参阅图1,在本实施方式中,基于多重知识迁移的知识蒸馏方法,包括:
S101,构建两个网络组,网络组包括未训练的协同网络和完成预训练的教师网络;其中,通过对与协同网络相同的预训练网络进行预训练获取教师网络;
其中,协同网络相当于传统知识蒸馏方法中的学生网络;
其中,两个网络组中的协同网络其结构可以相同或者不相同。
S102,将训练数据输入两个网络组中获取各个协同网络和教师网络的输出结果,训练数据还包括对应的真实标签数据;
其中,训练数据是训练中所用的输入数据。优选的,训练数据可以根据教师网络和学生网络的输入层的输入格式进行预处理,得到规则化的训练数据;
其中,真实标签数据其获取方式可以采取人工标签的方式获得,也可以从现有的数据集中获取训练数据和真实标签数据。
其中,训练数据和真实标签数据的具体内容与教师网络和学生网络的具体应用场景相关,例如:在对象分类的应用场景中,训练数据可以是预先选取的样本对象的特征数据,真实标签数据可以是样本对象的分类标签;在图像分类的应用场景中,训练数据可以是样本图片,真实标签数据可以是样本图片的分类标签;
其中,令训练数据X={x1,x2,...,xn}是来自m个类别的n个样本,每个样本对应的真实标签数据标记为y={y1,y2,...,ym}。
S103,基于与协同网络同组的教师网络的输出结果、两个协同网络的输出结果、两个协同网络迁移样本间的关系确定蒸馏损失函数;
其中,对于第k个网络Nk输出的Logits标记为
Figure BDA0003133332610000051
其经过Softmax函数的输出为σi(zk(x),T),T为温度参数。在Logits与Softmax均是属于网络模型输出层中的内容,具体的,Logits表示未归一化的概率,即各个特征的加权之和,Logits经过Softmax函数后变为归一化的概率值。
S104,基于蒸馏损失函数对协同网络进行迭代训练。
其中,蒸馏损失函数用于更新优化协同网络的参数,在协同网络训练过程中的每次迭代中,通过最小化损失函数或以其他方式调整蒸馏损失函数的值,对应更新协同网络的参数,通过对协同网络进行多次迭代训练,以逐步协同网络的参数值趋向于拟合,该训练过程即监督学习的过程。
基于上述内容,在本实施例中,使用两个协同网络(网络结构可以不同),有效的利用了多种知识,并在一个统一的框架下进行知识迁移。在训练时,协同网络之间互相通过在线蒸馏的协作式学习同时接收来自对方网络基于响应的知识和基于关系的知识,并且通过自蒸馏的自我学习方式来进一步提升自身性能。
现有的发明通常将教师网络中存在的某一种知识通过某一种学习方式迁移给协同网络,这种做法不仅需要提前训练好大型教师网络,而且使学生网络性能的提升十分有限。而本发明在训练的过程中,将网络中不同的知识通过不同的方式提取出来,然后迁移给另一个协同网络,在节省时间成本的同时使协同网络获得更高的性能。且本申请不需要预训练大型的教师网络,只对当前需要训练的预训练网络进行预训练,并在后续学习过程中用作教师网络即可,因此节省了时间成本。
参阅图2,在一些实施例中,通过对与协同网络相同的预训练网络进行预训练获取教师网络包括:
S201,将训练数据输入预训练网络中获取预训练网络的输出结果;
S202,基于预训练网络的输出结果和真实标签数据确定交叉熵损失函数;
S203,基于交叉熵损失函数对预训练网络进行迭代训练获取教师网络。
在一些实施例中,基于与协同网络同组的教师网络输出结果、两个协同网络输出结果、两个协同网络迁移样本间关系确定蒸馏损失函数包括:
基于与协同网络同组的教师网络输出结果确定的第一损失函数;
基于两个协同网络输出结果确定的第二损失函数;
基于两个协同网络迁移样本间关系确定的第三损失函数。
在本实施例中,第一损失函数用于定义在网络进行自学习时,将知识从完成预训练的网络迁移至自身的损失函数;第二损失函数用于定义基于响应的知识;第三损失函数用于定于基于关系的知识。
其中,为了在训练过程中给当前在训练的协同网络提供正确的知识指导(教师网络是预训练好的,因此我们认为它输出的知识是相对正确的知识),第一损失函数将知识从预训练好的教师模型中提取出来,然后迁移至当前协同网络。第二损失函数迁移的是基于响应的知识,第三损失迁移的是基于关系的知识,两者的目的是使协同网络互相之间尽可能充分的学习来自对方网络多方面的知识,以共同提升性能。
具体的,第一损失函数为:
Figure BDA0003133332610000061
其中,LSD表示第一损失函数,X={x1,x2,...,xn}表示训练数据,训练数据X对应的真实标签数据标记为Y={y1,y2,...,ym},
Figure BDA0003133332610000062
表示教师网络t经过Logits函数的输出结果,σi(zt,T)表示教师网络在温度参数为T的条件下经过Softmax函数的输出结果,T表示温度参数。
相应的,
Figure BDA0003133332610000063
表示协同网络k经过Logits函数的输出结果,σi(zk,T)表示教师网络在温度参数为T的条件下经过Softmax函数的输出结果。
具体的,第二损失函数为:
Figure BDA0003133332610000064
其中,LKL表示第二损失函数,pk表示协同网络k的输出结果,pk′表示协同网络k′的输出结果,X={x1,x2,...,xn}是来自m个类别的n个样本,
Figure BDA0003133332610000065
表示协同网络k经过Logits函数的输出结果,15i(zk,1)表示协同网络k在温度参数为T=1的条件下经过Softmax函数的输出结果。
具体的,第三损失函数包括距离损失函数和角度损失函数,其中:
具体的,第三损失函数具体为:
LRD=LDD1LAD
其中,LRD表示第三损失函数,LDD表示距离损失函数,LAD表示角度损失函数,β1表示权重系数;
具体的,距离损失函数为:
Figure BDA0003133332610000066
其中,xu表示训练数据中第u个样本,
Figure BDA0003133332610000071
表示样本xu在协同网络k中间层输出的特征,
Figure BDA0003133332610000072
表示样本xu在协同网络k′中间层输出的特征;
其中,
Figure BDA0003133332610000073
表示在协同网络k、k′中样本之间的距离,具体为:
Figure BDA0003133332610000074
Figure BDA0003133332610000075
其中,
Figure BDA0003133332610000076
表示归一化常数,χn表示一个batch中n个不同样本的组合个数,χ2={(xu,xv)|u≠v}。R(·)表示Huber损失函数,其定义如下:
Figure BDA0003133332610000077
具体的,角度损失函数为:
Figure BDA0003133332610000078
其中,χ3={(xu,xv,xw)|u≠v≠w};
其中,
Figure BDA0003133332610000079
表示在协同网络k、k′中样本之间的角度,具体为:
Figure BDA00031333326100000710
Figure BDA00031333326100000711
其中,
Figure BDA00031333326100000712
Figure BDA00031333326100000713
Figure BDA00031333326100000714
表示记号。
将第一损失函数LSD、第二损失函数LKL和第三损失函数LRD加权求和获取所述蒸馏损失函数,具体为:
L=αLSD+βLKL+γLRD
其中,L表示蒸馏损失函数,α、β、γ为超参数表示权重系数。
下面将结合实验数据对本发明的基于多重知识迁移的知识蒸馏方法做进一步说明:
参阅图4,采用CIFAR-100数据集作为训练数据将其输入本申请的基于多重知识迁移的知识蒸馏方法与其他知识蒸馏方案中,最终获得图4中的基于多重知识迁移的知识蒸馏方法与其他知识蒸馏方法性能对比示意图,从图4中(a)、(b)两个对比结果可以看出,神经网络在经过本发明所提方法训练后,本申请获得的协同网络其准确性最高,性能相较其他方法有明显提升。
具体的,图4中CTSL-MKT表示本申请的基于多重知识迁移的知识蒸馏方法;
具体的,图4中(a)、(b)两个对比结果图中位于最上方的为本申请的基于多重知识迁移的知识蒸馏方法获取的协同网络的准确性结果;
具体的,图4中其他知识蒸馏方法包括DML、RKD和Tf-KD。
参与图3,在一些实施例中,本申请还公开了一种基于多重知识迁移的知识蒸馏装置,包括:
网络构建模块,网络构建模块用于构建两个网络组,网络组包括未训练的协同网络和完成预训练的教师网络;其中,通过对与协同网络相同的预训练网络进行预训练获取教师网络;
数据处理模块,数据处理模块用于将训练数据输入两个网络组中获取各个协同网络和教师网络的输出结果,训练数据还包括对应的真实标签数据;
损失函数确定模块,损失函数确定模块用于基于与协同网络同组的教师网络的输出结果、两个协同网络的输出结果、两个协同网络迁移样本间的关系确定蒸馏损失函数;
网络训练模块,网络训练模块用于基于蒸馏损失函数对协同网络进行迭代训练。
在一些实施例中,本申请还公开了一种计算机设备,其特征在于,包括存储器和处理器,存储器中存储有计算机程序,处理器执行计算机程序时实现上述基于多重知识迁移的知识蒸馏方法的步骤。
其中,所述计算机设备可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述计算机设备可以与用户通过键盘、鼠标、遥控器、触摸板或声控设备等方式进行人机交互。
所述存储器至少包括一种类型的可读存储介质,所述可读存储介质包括闪存、硬盘、多媒体卡、卡型存储器(例如,SD或D界面显示存储器等)、随机访问存储器(RAM)、静态随机访问存储器(SRAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、可编程只读存储器(PROM)、磁性存储器、磁盘、光盘等。在一些实施例中,所述存储器可以是所述计算机设备的内部存储单元,例如该计算机设备的硬盘或内存。在另一些实施例中,所述存储器也可以是所述计算机设备的外部存储设备,例如该计算机设备上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。当然,所述存储器还可以既包括所述计算机设备的内部存储单元也包括其外部存储设备。本实施例中,所述存储器常用于存储安装于所述计算机设备的操作系统和各类应用软件,例如基于多重知识迁移的知识蒸馏方法的程序代码等。此外,所述存储器还可以用于暂时地存储已经输出或者将要输出的各类数据。
所述处理器在一些实施例中可以是中央处理器(Central Processing Unit,CPU)、控制器、微控制器、微处理器、或其他数据处理芯片。该处理器通常用于控制所述计算机设备的总体操作。本实施例中,所述处理器用于运行所述存储器中存储的程序代码或者处理数据,例如运行所述基于多重知识迁移的知识蒸馏方法的程序代码。
在一些实施例中,本申请还公开了一种计算机可读存储介质,其特征在于,计算机可读存储介质上存储有计算机程序,计算机程序被处理器执行时实现上述基于多重知识迁移的知识蒸馏方法的步骤。
其中,所述计算机可读存储介质存储有界面显示程序,所述界面显示程序可被至少一个处理器执行,以使所述至少一个处理器执行如上述的基于多重知识迁移的知识蒸馏方法的程序代码的步骤。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器或者网络设备等)执行本申请各个实施例所述的方法。
如上即为本发明的实施例。上述实施例以及实施例中的具体参数仅是为了清楚表述发明的验证过程,并非用以限制本发明的专利保护范围,本发明的专利保护范围仍然以其权利要求书为准,凡是运用本发明的说明书及附图内容所作的等同结构变化,同理均应包含在本发明的保护范围内。

Claims (9)

1.基于多重知识迁移的知识蒸馏方法,其特征在于,包括:
构建两个网络组,所述网络组包括未训练的协同网络和完成预训练的教师网络;其中,通过对与所述协同网络相同的预训练网络进行预训练获取所述教师网络;
将训练数据输入两个所述网络组中获取各个协同网络和教师网络的输出结果,所述训练数据还包括对应的真实标签数据;
基于与所述协同网络同组的所述教师网络的输出结果、两个所述协同网络的输出结果、两个所述协同网络迁移样本间的关系确定蒸馏损失函数;
基于所述蒸馏损失函数对所述协同网络进行迭代训练。
2.根据权利要求1所述的基于多重知识迁移的知识蒸馏方法,其特征在于,通过对与所述协同网络相同的预训练网络进行预训练获取所述教师网络包括:
将所述训练数据输入所述预训练网络中获取所述预训练网络的输出结果;
基于所述预训练网络的输出结果和所述真实标签数据确定交叉熵损失函数;
基于所述交叉熵损失函数对所述预训练网络进行迭代训练获取所述教师网络。
3.根据权利要求1所述的基于多重知识迁移的知识蒸馏方法,其特征在于,基于与所述协同网络同组的所述教师网络输出结果、两个所述协同网络输出结果、两个所述协同网络迁移样本间关系确定蒸馏损失函数包括:
基于与所述协同网络同组的所述教师网络输出结果确定的第一损失函数;
基于两个所述协同网络输出结果确定的第二损失函数;
基于两个所述协同网络迁移样本间关系确定的第三损失函数;
将所述第一损失函数、所述第二损失函数和所述第三损失函数加权求和获取所述蒸馏损失函数。
4.根据权利要求3所述的基于多重知识迁移的知识蒸馏方法,其特征在于,所述第一损失函数为:
Figure FDA0003133332600000011
其中,LSD表示第一损失函数,X={x1,x2,...,xn}表示训练数据,训练数据X对应的真实标签数据标记为Y={y1,y2,...,ym},
Figure FDA0003133332600000012
表示教师网络t经过Logits函数的输出结果,σi(zt,T)表示教师网络在温度参数为T的条件下经过Softmax函数的输出结果,T表示温度参数。
5.根据权利要求3所述的基于多重知识迁移的知识蒸馏方法,其特征在于,所述第二损失函数为:
Figure FDA0003133332600000021
其中,LKL表示第二损失函数,pk表示协同网络k的输出结果,pk′表示协同网络k′的输出结果,X={x1,x2,...,xn}是来自m个类别的n个样本,
Figure FDA0003133332600000022
表示协同网络k经过Logits函数的输出结果,σi(zk,1)表示协同网络k在温度参数为T=1的条件下经过Softmax函数的输出结果。
6.根据权利要求3所述的基于多重知识迁移的知识蒸馏方法,其特征在于,所述第三损失函数包括距离损失函数和角度损失函数,其中:
所述第三损失函数具体为:
LRD=LDD+ηLAD
其中,LRD表示第三损失函数,LDD表示距离损失函数,LAD表示角度损失函数,η为超参数表示权重系数;
所述距离损失函数为:
Figure FDA0003133332600000023
其中,xu表示训练数据中第u个样本,
Figure FDA0003133332600000024
表示样本xu在协同网络k中间层输出的特征,
Figure FDA0003133332600000025
表示样本xu在协同网络k′中间层输出的特征;
其中,
Figure FDA0003133332600000026
表示在协同网络k、k′中样本之间的距离,具体为:
Figure FDA0003133332600000027
Figure FDA0003133332600000028
其中,
Figure FDA0003133332600000029
表示归一化常数,χn表示一个batch中n个不同样本的组合个数,χ2={(xu,xv)|u≠v}。R(·)表示Huber损失函数,其定义如下:
Figure FDA00031333326000000210
所述角度损失函数为:
Figure FDA00031333326000000211
其中,χ3={(xu,xv,xw)|u≠v≠w};
其中,
Figure FDA0003133332600000031
表示在协同网络k、k′中样本之间的角度,具体为:
Figure FDA0003133332600000032
Figure FDA0003133332600000033
其中,
Figure FDA0003133332600000034
Figure FDA0003133332600000035
表示记号。
7.一种基于多重知识迁移的知识蒸馏装置,其特征在于,包括:
网络构建模块,所述网络构建模块用于构建两个网络组,所述网络组包括未训练的协同网络和完成预训练的教师网络;其中,通过对与所述协同网络相同的预训练网络进行预训练获取所述教师网络;
数据处理模块,所述数据处理模块用于将训练数据输入两个所述网络组中获取各个协同网络和教师网络的输出结果,所述训练数据还包括对应的真实标签数据;
损失函数确定模块,所述损失函数确定模块用于基于与所述协同网络同组的所述教师网络的输出结果、两个所述协同网络的输出结果、两个所述协同网络迁移样本间的关系确定蒸馏损失函数;
网络训练模块,所述网络训练模块用于基于所述蒸馏损失函数对所述协同网络进行迭代训练。
8.一种计算机设备,其特征在于,包括存储器和处理器,所述存储器中存储有计算机程序,所述处理器执行所述计算机程序时实现如权利要求1至5任一项所述的基于多重知识迁移的知识蒸馏方法的步骤。
9.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至5中任一项所述的基于多重知识迁移的知识蒸馏方法的步骤。
CN202110712121.9A 2021-06-25 2021-06-25 基于多重知识迁移的知识蒸馏方法、装置、设备及介质 Pending CN113326940A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202110712121.9A CN113326940A (zh) 2021-06-25 2021-06-25 基于多重知识迁移的知识蒸馏方法、装置、设备及介质
CN202210535574.3A CN114742224A (zh) 2021-06-25 2022-05-17 行人重识别方法、装置、计算机设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110712121.9A CN113326940A (zh) 2021-06-25 2021-06-25 基于多重知识迁移的知识蒸馏方法、装置、设备及介质

Publications (1)

Publication Number Publication Date
CN113326940A true CN113326940A (zh) 2021-08-31

Family

ID=77424821

Family Applications (2)

Application Number Title Priority Date Filing Date
CN202110712121.9A Pending CN113326940A (zh) 2021-06-25 2021-06-25 基于多重知识迁移的知识蒸馏方法、装置、设备及介质
CN202210535574.3A Pending CN114742224A (zh) 2021-06-25 2022-05-17 行人重识别方法、装置、计算机设备及存储介质

Family Applications After (1)

Application Number Title Priority Date Filing Date
CN202210535574.3A Pending CN114742224A (zh) 2021-06-25 2022-05-17 行人重识别方法、装置、计算机设备及存储介质

Country Status (1)

Country Link
CN (2) CN113326940A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113635310A (zh) * 2021-10-18 2021-11-12 中国科学院自动化研究所 模型迁移方法、装置

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117372785B (zh) * 2023-12-04 2024-03-26 吉林大学 一种基于特征簇中心压缩的图像分类方法
CN117612214B (zh) * 2024-01-23 2024-04-12 南京航空航天大学 一种基于知识蒸馏的行人搜索模型压缩方法

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110674880A (zh) * 2019-09-27 2020-01-10 北京迈格威科技有限公司 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN112508169A (zh) * 2020-11-13 2021-03-16 华为技术有限公司 知识蒸馏方法和系统

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110674880A (zh) * 2019-09-27 2020-01-10 北京迈格威科技有限公司 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN112508169A (zh) * 2020-11-13 2021-03-16 华为技术有限公司 知识蒸馏方法和系统

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113635310A (zh) * 2021-10-18 2021-11-12 中国科学院自动化研究所 模型迁移方法、装置
CN113635310B (zh) * 2021-10-18 2022-01-11 中国科学院自动化研究所 模型迁移方法、装置

Also Published As

Publication number Publication date
CN114742224A (zh) 2022-07-12

Similar Documents

Publication Publication Date Title
CN111444340B (zh) 文本分类方法、装置、设备及存储介质
CN111079532B (zh) 一种基于文本自编码器的视频内容描述方法
CN113344206A (zh) 融合通道与关系特征学习的知识蒸馏方法、装置及设备
WO2021159714A1 (zh) 一种数据处理方法及相关设备
CN111275046B (zh) 一种字符图像识别方法、装置、电子设备及存储介质
CN112685565A (zh) 基于多模态信息融合的文本分类方法、及其相关设备
CN113326940A (zh) 基于多重知识迁移的知识蒸馏方法、装置、设备及介质
CN112418292B (zh) 一种图像质量评价的方法、装置、计算机设备及存储介质
CN111275107A (zh) 一种基于迁移学习的多标签场景图像分类方法及装置
CN110569359B (zh) 识别模型的训练及应用方法、装置、计算设备及存储介质
CN113326941A (zh) 基于多层多注意力迁移的知识蒸馏方法、装置及设备
CN110162766B (zh) 词向量更新方法和装置
CN109214001A (zh) 一种中文语义匹配系统及方法
CN110210468B (zh) 一种基于卷积神经网络特征融合迁移的文字识别方法
WO2023134082A1 (zh) 图像描述语句生成模块的训练方法及装置、电子设备
CN110502610A (zh) 基于文本语义相似度的智能语音签名方法、装置及介质
CN114596566B (zh) 文本识别方法及相关装置
WO2022001232A1 (zh) 一种问答数据增强方法、装置、计算机设备及存储介质
WO2024041479A1 (zh) 一种数据处理方法及其装置
CN113011568B (zh) 一种模型的训练方法、数据处理方法及设备
CN113343898B (zh) 基于知识蒸馏网络的口罩遮挡人脸识别方法、装置及设备
CN112749556B (zh) 多语言模型的训练方法和装置、存储介质和电子设备
CN111488732A (zh) 一种变形关键词检测方法、系统及相关设备
CN113761868A (zh) 文本处理方法、装置、电子设备及可读存储介质
CN114282059A (zh) 视频检索的方法、装置、设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
WD01 Invention patent application deemed withdrawn after publication

Application publication date: 20210831

WD01 Invention patent application deemed withdrawn after publication