CN115062769A - 基于知识蒸馏的模型训练方法、装置、设备及存储介质 - Google Patents
基于知识蒸馏的模型训练方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN115062769A CN115062769A CN202210816261.5A CN202210816261A CN115062769A CN 115062769 A CN115062769 A CN 115062769A CN 202210816261 A CN202210816261 A CN 202210816261A CN 115062769 A CN115062769 A CN 115062769A
- Authority
- CN
- China
- Prior art keywords
- model
- loss function
- similarity
- target
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Feedback Control In General (AREA)
Abstract
本发明适用于人工智能技术领域,尤其涉及一种基于知识蒸馏的模型训练方法、装置、设备及存储介质,该方法通过获取满足目标条件的第一模型和不满足目标条件的第二模型,根据第一模型的输出构建的优化损失函数,得到更新后的第二模型,计算并确定目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数,对第二模型进行训练,得到满足目标条件的第二模型。将第一模型与第二模型中网络层之间的表征相似度作为第二模型中损失函数的一部分,充分利用了网络中间层的信息,增加了第二模型学习第一模型的能力和范围,从而提高第二模型训练时的稳定性和收敛性。
Description
技术领域
本发明涉及人工智能领域,尤其涉及一种基于知识蒸馏的模型训练方法、装置、设备及存储介质。
背景技术
目前,深度学习神经网络已成功应用于各种计算机视觉应用,如图像分类、对象检测和语义分割,大型的深度学习模型训练必须从非常大的、高度冗余的数据集中训练得到,但是数据集的数据量较大时模型训练需要占据大量的时间和存储空间,因此,为了缩短训练时间和减少资源占据,使用知识蒸馏方法对大型深度学习网络进行压缩得到了广泛运用,知识蒸馏方法对教师网络与学生网络的匹配度要求较高,而当前的知识蒸馏方法只会对训练样本集的标签进行优化,无法应对教师网络与学生网络的匹配程度不高的情况,导致的训练过程不稳定、不收敛的问题。因此,如何改进知识蒸馏的训练过程,以提高学生网络训练过程的稳定性、收敛性成为亟待解决的问题。
发明内容
基于此,有必要针对上述技术问题,提供一种基于知识蒸馏的模型训练方法、装置、设备及存储介质,以解决训练过程中不稳定、不收敛的问题。
本申请实施例的第一方面提供了一种基于知识蒸馏的模型训练方法,所述方法包括:
获取满足目标条件的第一模型和不满足目标条件的第二模型,第一模型包括M个网络层,第二模型包括N个网络层,N、M均为大于零的整数;
根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型;
计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;
根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数;
使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
本申请实施例的第二方面提供了一种基于知识蒸馏的模型训练装置,所述装置包括:
获取模型模块,用于获取满足目标条件的第一模型和不满足所述目标条件的第二模型,所述第一模型包括M个网络层,所述第二模型包括N个网络层,N、M均为大于零的整数;
更新模块,用于根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型;
目标相似度确定模块,用于计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;
目标损失函数确定模块,用于根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数;
训练模块,用于使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
第三方面,本发明实施例提供一种计算机设备,所述计算机设备包括处理器、存储器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如第一方面所述的基于知识蒸馏的模型训练方法。
第四方面,本发明实施例提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如第一方面所述的基于知识蒸馏的模型训练方法。
本发明与现有技术相比存在的有益效果是:
本发明通过获取满足目标条件的第一模型和不满足目标条件的第二模型,第一模型包括M个网络层,第二模型包括N个网络层,N、M均为大于零的整数,根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型,计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度,根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数,使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。将第一模型与第二模型中网络层之间的表征相似度作为第二模型中损失函数的一部分,充分利用了网络中间层的信息,增加了第二模型学习第一模型的能力和范围,从而提高第二模型训练时的稳定性和收敛性。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本发明一实施例提供的一种基于知识蒸馏的模型训练方法的一应用环境示意图;
图2是本发明一实施例提供的一种基于知识蒸馏的模型训练方法的流程示意图;
图3是本发明一实施例提供的一种基于知识蒸馏的模型训练装置的结构示意图;
图4是本发明一实施例提供的一种计算机设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应当理解,当在本发明说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本发明说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
另外,在本发明说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
在本发明说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本发明的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调。
本发明实施例可以基于人工智能技术对相关的数据进行获取和处理。其中,人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。
人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、机器人技术、生物识别技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
应理解,以下实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
为了说明本发明的技术方案,下面通过具体实施例来进行说明。
本发明一实施例提供的一种基于知识蒸馏的模型训练方法,可应用在如图1的应用环境中,其中,客户端与服务端进行通信。其中,客户端包括但不限于掌上电脑、桌上型计算机、笔记本电脑、超级移动个人计算机(ultra-mobile personal computer,UMPC)、上网本、个人数字助理(personal digital assistant,PDA)等计算机设备。服务端可以用独立的服务器或者是多个服务器组成的服务器集群来实现。
参见图2,是本发明一实施例提供的一种基于知识蒸馏的模型训练方法的流程示意图,上述基于知识蒸馏的模型训练可以应用于图1中的服务端,上述服务端连接相应的客户端,为客户端提供模型训练服务。如图2所示,该基于知识蒸馏的模型训练方法可以包括以下步骤。
S201:获取满足目标条件的第一模型和不满足目标条件的第二模型。
在步骤S201中,满足目标条件的第一模型和不满足目标条件的第二模型为深度学习卷积神经网络模型,通过使用训练集对第一模型进行训练得到满足目标条件的第一模型,其中,第一模型包括M个网络层,第二模型包括N个网络层,N、M均为大于零的整数。满足目标条件的第一模型为训练好的一个深度学习模型,不满足目标条件的第二模型为未训练完成的深度学习模型,第二模型是为了学习第一模型中的参数,节省训练时间。本实施例中,第一模型为知识蒸馏中教师模型,第二模型为知识蒸馏中的学生模型。获取满足条件的第一模型时,可以首先对训练集中的样本数据进行人工标注,然后利用标注的样本数据对第一模型进行训练,该训练过程可以理解为预训练过程。在训练的过程中,可以利用第一模型的损失函数计算第一模型的实际输出与标注结果之间的损失值,根据损失值进行反向传播训练,至损失收敛,得到满足条件的第一模型。例如,上述第一模型用于声纹识别,满足条件的第一模型实际输出可以为声纹特征向量。
需要说明的是,第一模型和第二模型可以是同类型的神经网络模型,即第一模型和第二模型具有相同的网络层结构,第一模型和第二模型也可以是不同类型的神经网络模型,即第一模型和第二模型的网络层结构有差异。但是,第一模型的网络层层数小于第二模型的网络层层数,也就是,第一模型的模型大小和模型参数量小于第二模型。由于第二模型的模型参数量较多,在实际应用时,输出预测结果需要耗费较多时间。为了提高预测结果的输出速度和节约计算机资源,可以对大模型进行知识蒸馏得到轻量级的小模型,从而在实际应用时,基于知识蒸馏的模型训练将复杂、学习能力强的第一网络模型已经学习到的特征表示知识蒸馏出来,传递给参数量小、学习能力弱的第二模型。
需要说明的是,本申请实施例中,上述训练集中的样本数据可以为是对录音数据进行预处理后得到的数据。例如,可以将声纹标注完成的40小时的客服录音数据通过加噪声、加快语速和增加数据扰动等方式进行数据扩增,获得数据集,并按照训练集和测试集为8:2的比例进行数据划分,划分同时充分考虑说话人信息,做到训练集与测试集的说话人语音分开。读取训练集中的录音文件形成数据标签(data-label)的特征数据组合。该特征数据组合可以理解为训练集中的样本数据。
S202:根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型。
在步骤S202中,第一模型与第二模型具有相似的结构的网络模型,将第一模型中的输出结果指导第二模型的训练,简化第二模型的训练过程。其中根据第一模型的输出构建优化的损失函数,将优化的损失函数作为第二模型训练过程中的损失函数,得到更新后的第二模型。
本实施例中,通过使用复杂且性能和泛化能力较高的网络的输出和真实的标签数据来训练简单网络,使简单网络的性能得到提升。假设有单个或者多个网络复杂且性能良好的模型,记为第一模型,有一个网络层少且学习能力低的模型,记为第二模型。使用第一模型训练学到的知识作为第二模型训练的目标,得到更新后的第二模型,更新后的第二模型的性能接近第一模型的性能,但是和第一模型相比,第二模型的参数量少,训练时间比第一模型短,从而相当于实现大模型的压缩和加速,提升小模型的精度。
例如,第一模型可以基于残差网络ResNet34构建;第二模型可以基于ResNet10构建。由于较大的网络往往面临着深度学习网络模型大而冗余,识别速度难以满足实时性要求的问题,而小网络模型非常容易会因为参数量较小,模型特征表示能力不足,导致模型准确性能低下,带来的问题就是虽然满足线上应用的实时性要求,但却无法满足准确性要求。因此,通过第一模型对的知识对第二模型进行训练,可以使的大的网络模型对小的网络模型起到正向作用,使得第二模型能够获得较优的拟合参数,进而提升第二模型的准确度。
需要说明的是,由于第一模型的网络层层数大于第二模型的网络层层数,因此,可以采用隔层蒸馏方法进行知识蒸馏,确定第二网络模型的各个网络层与第一模型的网络层的对应关系,让第二模型的网络层学习拟合第一模型对应的网络层,即将第一网络模型对应的网络层经过知识蒸馏压缩为第二模型的网络层。由于第一模型的网络层层数大于第二模型的网络层层数,因此第一模型的网络层是间隔对应第二模型的网络层。例如,当第一模型包括24层网络层、第二模型包括12层网络层时,可以是第二模型的第1层对应第一模型的第2层,第二模型的第2层对应第一模型的第4层,第二模型的第3层对应第一模型的第6层,第二模型的第4层对应第一模型的第8层,以此类推。
可选地,根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型,包括:
将带有原始标签的第一训练样本输入至第一模型中,输出第一训练样本对应的新标签数据,得到第二训练样本;
利用第一训练样本与第二训练样本,分别对第二模型进行训练,得到第一知识蒸馏损失函数与第二知识蒸馏损失函数;
通过第一知识蒸馏损失函数与第二知识蒸馏损失函数,构建优化损失函数;
根据优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型。
本实施例中,将原始数据的真实标签记为Hard-target,将第一模型的输出概率记为soft-target,由于第一模型的参数多,得到的表征信息比较多,而真实的标签Hard-target所含的表征信息很少,例如,在二分类任务中,假设第一模型的softmax输出为[0.995,0.005],负例输出携带0.005的样本信息,可以显示两类之间的相似性,而真实的标签只有[1,0],负例为0不包含有用信息。引入超参数温度T,对soft-target进行平滑操作,当T趋向于0时,结果中最大的值会接近1,另一个为0,当T越大时,两个输出结果分布越平缓,两个概率值的差距越小,使得保留的相似信息越多。将第一模型的输出概率作为第二模型的学习知识,为此减少卷积神经网络的参数量和计算量,避免巨大的计算开销。
需要说明的是,也可以使用第一模型的输出作为知识训练第二模型,而是使用第一模型的中间层特征输入到第二模型中间层的特征中,这种方法允许第二模型的网络层比第一模型的网络层多,但是中间的神经元应比第一模型的神经元少。
可选地,根据所述第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型,包括:
将带有原始标签的第一训练样本输入至第一模型中,以所述第一模型输出的新标签更新所述第一训练样本对应的原始标签,得到第二训练样本;
利用第一训练样本与第二训练样本,分别对第二模型进行训练,得到第一知识蒸馏损失函数与第二知识蒸馏损失函数;
通过第一知识蒸馏损失函数与第二知识蒸馏损失函数,构建优化损失函数。
本实施例中,根据第一知识蒸馏损失函数与第二知识蒸馏损失函数构建优化损失函数时,对第一知识蒸馏损失函数与第二知识蒸馏损失函数随机设置不同的参数,通过梯度下降算法,得到目标参数,为了使得第二模型可以学到更多来自第一模型的指导信息,设置参数时,对第一知识蒸馏损失函数设置的参数一般大于第二知识蒸馏损失函数对应的参数。
需要说明的是,在梯度下降过程中,需要选取一个合适的步长,一个合适的步长能够让迭代次数大大降低,甚至可以更方便收敛到全局最优解,如果在梯度下降过程中我们将步长选取过大,有可能会使梯度直接跳过局部最小值或者直接呈发散,如果选取步长过小,那么可能会大大降低收敛速度,会在收敛过程中耗费过多的不必要时间。
S203:计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,确定目标表征相似度。
在步骤S203中,使用中心核对齐算法计算网络层之间的相似度,通过相似度判断第二模型的学习拟合第一模型的程度。当存在对应关系的两个网络层各自输出的特征矩阵较为相似时,可以表明该存在对应关系的两个网络层之间的参数较为相似,也就是,第二模型的网络层较为成功地学习拟合了第一模型对应的网络层。第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度。
本实施例中,通过线性CKA(centered kernel alignment)算法计算不同网络层之间的相似性,线性CKA算法能够确定基于不同随机初始化和不同宽度训练的神经网络的隐藏层之间的对应关系。当第一模型与第二模型中的而网络层数不相同时,第一模型与第二模型中各个网络层之间不是一一对应的关系,需要根据第一模型与第二模型不同网络层之间的相似度,找到对应关系,该相似度为目标相似度。
可选地,计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度,包括:
分别获取第一模型中M个网络层与更新后的第二模型中N个网络层中每个网络层的特征矩阵;
计算所述第一模型中M个网络层的特征矩阵分别与所述更新后的第二模型中N个网络层的特征矩阵的表征相似度,得到所述第一模型中每个网络层对应的表征相似度序列;
从第一模型中每个网络层对应的表征相似度序列中确定目标表征相似度。
本实施例中,使用相同的训练集对第一模型和更新后的第二模型进行训练,将训练集中数据分别输入至第一模型和更新后的第二模型中,获取第二模型的各个网络层输出的特征矩阵,和第一模型的各个网络层输出的特征矩阵,计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,得到第一模型中每个网络层对应的表征相似度序列,从第一模型中每个网络层对应的表征相似度序列中通过预设选取条件,确定目标相似度。
需要说明的是,当计算表征相似度时,也可以隔层计算,例如,当第二模型的第1层对应第一模型的第3层,第2层对应第6层,第3层对应第9层,第4层对应第12层时,可以计算第二模型的第1层输出的特征矩阵和第一模型的第3层输出的特征矩阵之间的相似度,计算第二模型的第2层输出的特征矩阵和第一模型的第6层输出的特征矩阵之间的相似度,计算第二模型的第3层输出的特征矩阵和第一模型的第9层输出的特征矩阵之间的相似度,计算第二模型的第4层输出的特征矩阵和第一模型的第12层输出的特征矩阵之间的相似度,计算得到的相似度为目标相似度。
可选地,从第一模型中每个网络层对应的表征相似度序列中通过预设选取条件,确定目标相似度,包括:
从第一模型中每个网络层对应的表征相似度序列中获取表征相似度最大值,将表征相似度最大值作为目标相似度。
本实施例中,当表征相似度最大时,认为第二模型中网络层从对应第一模型中网络层学习到的知识越多,可以提取到与第一模型相似的特征。
S204:根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数。
在步骤S204中,相似度损失函数为第二模型知识蒸馏后学到的网络层输出特征与第一模型中对应网络层输出特征的差异值,根据相似度损失函数以及优化损失函数计算得到目标损失函数,目标损失函数为第一模型与第二模型之间网络层的输出特征差异与整个模型输出特征之间的差异之和。
本实施例中,根据第一模型与第二模型之间的相似度构建相似度损失函数,相似度损失函数为第二模型知识蒸馏后学到的网络层输出特征与第一模型中对应网络层输出特征的差异值,当相似度越大时,认为第一模型与第二模型之间网络层之间的差异越小,当相似度越小时,认为第一模型与第二模型之间网络层之间的差异越大,基于此,构建相似度损失函数。利用相似度损失函数与优化损失函数加权求和作为目标损失函数。不同的损失函数可帮助神经网络模型学习到不同的知识,目标损失函数在优化损失函数的基础上,加入了相似度损失函数,扩大了模型的学习范围。
可选地,根据目标相似度,构建相似度损失函数,包括:
根据目标相似度,计算第二模型中每个网络层的损失值;
基于第二模型中每个网络层的损失值,构建相似度损失函数。
本实施例中,基于第二模型中每层网络的损失构建相似度损失函数,第一模型与第二模型中每层网络层中的表征相似度的值取值范围为(0,1),当表征相似度的大小越接近1时,认为第一模型与第二模型中对应网络层越相近,当表征相似度的大小约接近0时,认为第一模型与第二模型中对应网络层差距越大。所以根据目标相似度,得到第二模型中每层网络在第一模型中的对应网络层,将目标相似度与1的差值,作为第二模型中网络层的损失值,则第二模型中的相似度损失函数为第二模型中每层网络层损失值的和。
S205:使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
在步骤S205中,使用训练集对第二模型进行训练时,可以进行有监督的训练,设置对应的阈值,当训练结果与监督标签的差值小于阈值时,认为目标损失函数收敛,得到满足目标条件的第二模型。
本实施例中,使用训练集对第二模型进行训练时,针对第二模型中目标损失函数,得到的训练结果与训练集中的标签信息通过目标损失函数计算损失值,判断损失值是否符合预设条件,当未符合预设条件时,根据损失值对第二模型进行反向传播更新,得到更新模型参数的第二模型,再次基于训练集对更新模型参数的第二模型进行训练,直至损失值符合预设条件,得到已得到满足目标条件的第二模型。
可选地,使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型,包括:
根据训练集中的正负样本,构建样本对;样本对至少包括一个正样本与一个负样本;
基于样本对,对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
本实施例中,使用训练集对第二模型进行训练时,为了加快收敛速度,可以使用样本对进行训练,每个样本对选定相同的正样本和负样本,由于样本反传梯度的大小由每一个正负样本对的差累计决定,这导致了一旦样本对中的样本数量过多时,在特征空间中找到满足所有样本对训练分界面会变得很困难,训练的收敛性将会随之变差,与此同时,每次重复计算两两样本间的距离差作为反传梯度会导致训练的冗余和训练时间的加长。由于不同样本对中的某一对正负样本的大小比较可能会出现多次,过多的增加样本对中的正负样本的数量对训练的帮助收益甚微,所以每个样本对中可以选择一个正样本和一个负样本。使用选定的样本对,对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
本发明通过获取满足目标条件的第一模型和不满足目标条件的第二模型,第一模型包括M个网络层,第二模型包括N个网络层,N、M均为大于零的整数;根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型,计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度,根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数,使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。将第一模型与第二模型中网络层之间的表征相似度作为第二模型中损失函数的一部分,充分利用了网络中间层的信息,增加了第二模型学习第一模型的能力和范围,对第二模型进行训练时,可以提高第二模型的稳定性和收敛性。
请参阅图3,图3是本发明实施例提供的一种基于知识蒸馏的模型训练装置的结构示意图。本实施例中该终端包括的各单元用于执行图2对应的实施例中的各步骤。具体请参阅图2以及图2所对应的实施例中的相关描述。为了便于说明,仅示出了与本实施例相关的部分。
参见图3,模型训练装置30包括:
获取模型模块31,用于获取满足目标条件的第一模型和不满足目标条件的第二模型,第一模型包括M个网络层,第二模型包括N个网络层,N、M均为大于零的整数;
更新模块32,用于根据第一模型的输出构建的优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型;
目标相似度确定模块33,用于计算第一模型中M个网络层分别与更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;
目标损失函数确定模块34,用于根据目标相似度,构建相似度损失函数,并将相似度损失函数与优化损失函数的和作为目标损失函数;
训练模块35,用于使用训练集对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
可选的是,上述更新模块32包括:
第二训练样本获取单元,用于将带有原始标签的第一训练样本输入至第一模型中,以所述第一模型输出的新标签更新所述第一训练样本对应的原始标签,得到第二训练样本;
知识蒸馏损失函数获取单元,用于利用第一训练样本与第二训练样本,分别对第二模型进行训练,得到第一知识蒸馏损失函数与第二知识蒸馏损失函数;
优化损失函数获取单元,用于通过第一知识蒸馏损失函数与第二知识蒸馏损失函数,构建优化损失函数;
更新后第二模型获取单元,用于根据优化损失函数,更新第二模型的初始损失函数,得到更新后的第二模型。
可选的是,上述优化损失函数获取单元包括:
初始损失函数获取子单元,用于对第一知识蒸馏损失函数与第二知识蒸馏损失函数设置不同的初始参数,得到初始蒸馏损失函数;
构建子单元,用于使用梯度下降算法对所述初始蒸馏损失函数进行参数更新,得到目标参数,使用所述目标参数更新初始蒸馏损失函数,得到优化损失函数。
可选的是,上述目标相似度确定模块33包括:
特征矩阵获取单元,用于分别获取第一模型中M个网络层与更新后的第二模型中N个网络层中每个网络层的特征矩阵;
表征相似度序列获取单元,用于通
计算所述第一模型中M个网络层的特征矩阵分别与所述更新后的第二模型中N个网络层的特征矩阵的表征相似度,得到所述第一模型中每个网络层对应的表征相似度序列;
目标相似度获取单元,用于从第一模型中每个网络层对应的表征相似度序列中通过预设选取条件,确定目标相似度。
可选的是,上述目标相似度获取单元包括:
目标相似度确定子单元,用于从第一模型中每个网络层对应的表征相似度序列中获取表征相似度最大值,将表征相似度最大值作为目标相似度。
可选的是,上述目标损失函数确定模块34包括:
每个网络层的损失值确定单元,用于根据目标相似度,计算第二模型中每个网络层的损失值;
相似度损失函数构建单元,用于基于第二模型中每个网络层的损失值,构建相似度损失函数。
可选的是,上述训练模块35包括:
样本对构建单元,用于根据训练集中的正负样本,构建样本对;样本对至少包括一个正样本与一个负样本;
满足目标条件的第二模型获取单元,用于基于样本对,对第二模型进行训练,直至目标损失函数收敛,得到满足目标条件的第二模型。
图4是本发明实施例提供的一种计算机设备的结构示意图。如图4所示,该实施例的计算机设备包括:至少一个处理器(图4中仅示出一个)、存储器以及存储在存储器中并可在至少一个处理器上运行的计算机程序,处理器执行计算机程序时实现上述任意各个基于知识蒸馏的模型训练方法实施例中的步骤。
该计算机设备可包括,但不仅限于,处理器、存储器。本领域技术人员可以理解,图4仅仅是计算机设备的举例,并不构成对计算机设备的限定,计算机设备可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如还可以包括网络接口、显示屏和输入装置等。
所称处理器可以是CPU,该处理器还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific IntegratedCircuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器包括可读存储介质、内存储器等,其中,内存储器可以是计算机设备的内存,内存储器为可读存储介质中的操作系统和计算机可读指令的运行提供环境。可读存储介质可以是计算机设备的硬盘,在另一些实施例中也可以是计算机设备的外部存储设备,例如,计算机设备上配备的插接式硬盘、智能存储卡(Smart Media Card,SMC)、安全数字(Secure Digital,SD)卡、闪存卡(Flash Card)等。进一步地,存储器还可以既包括计算机设备的内部存储单元也包括外部存储设备。存储器用于存储操作系统、应用程序、引导装载程序(BootLoader)、数据以及其他程序等,该其他程序如计算机程序的程序代码等。存储器还可以用于暂时地存储已经输出或者将要输出的数据。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本发明的保护范围。上述装置中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明实现上述实施例方法中的全部或部分流程,可以通过计算机程序来指令相关的硬件来完成,计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述方法实施例的步骤。其中,计算机程序包括计算机程序代码,计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。计算机可读介质至少可以包括:能够携带计算机程序代码的任何实体或装置、记录介质、计算机存储器、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、电载波信号、电信信号以及软件分发介质。例如U盘、移动硬盘、磁碟或者光盘等。在某些司法管辖区,根据立法和专利实践,计算机可读介质不可以是电载波信号和电信信号。
本发明实现上述实施例方法中的全部或部分流程,也可以通过一种计算机程序产品来完成,当计算机程序产品在计算机设备上运行时,使得计算机设备执行时实现可实现上述方法实施例中的步骤。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
在本发明所提供的实施例中,应该理解到,所揭露的装置/计算机设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/计算机设备实施例仅仅是示意性的,例如,模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于知识蒸馏的模型训练方法,其特征在于,包括:
获取满足目标条件的第一模型和不满足所述目标条件的第二模型,所述第一模型包括M个网络层,所述第二模型包括N个网络层,N、M均为大于零的整数;
根据所述第一模型的输出构建的优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型;
计算所述第一模型中M个网络层分别与所述更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;
根据所述目标相似度,构建相似度损失函数,并将所述相似度损失函数与所述优化损失函数的和作为目标损失函数;
使用训练集对所述第二模型进行训练,直至所述目标损失函数收敛,得到满足所述目标条件的第二模型。
2.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述根据所述第一模型的输出构建的优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型,包括:
将带有原始标签的第一训练样本输入至所述第一模型中,以所述第一模型输出的新标签更新所述第一训练样本对应的原始标签,得到第二训练样本;
利用所述第一训练样本与所述第二训练样本,分别对第二模型进行训练,得到第一知识蒸馏损失函数与第二知识蒸馏损失函数;
通过所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数,构建优化损失函数;
根据所述优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型。
3.如权利要求2所述的基于知识蒸馏的模型训练方法,其特征在于,所述通过所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数,构建优化损失函数,包括:
对所述第一知识蒸馏损失函数与所述第二知识蒸馏损失函数设置不同的初始参数,得到初始蒸馏损失函数;
使用梯度下降算法对所述初始蒸馏损失函数进行参数更新,得到目标参数,使用所述目标参数更新初始蒸馏损失函数,得到优化损失函数。
4.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述计算所述第一模型中M个网络层分别与所述更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度,包括:
分别获取所述第一模型中M个网络层与所述更新后的第二模型中N个网络层中每个网络层的特征矩阵;
计算所述第一模型中M个网络层的特征矩阵分别与所述更新后的第二模型中N个网络层的特征矩阵的表征相似度,得到所述第一模型中每个网络层对应的表征相似度序列;
从所述第一模型中每个网络层对应的表征相似度序列中通过预设选取条件,确定目标相似度。
5.如权利要求4所述的基于知识蒸馏的模型训练方法,其特征在于,所述从所述第一模型中每个网络层对应的表征相似度序列中通过预设选取条件,确定目标相似度,包括:
从所述第一模型中每个网络层对应的表征相似度序列中获取表征相似度最大值,将所述表征相似度最大值作为目标相似度。
6.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述根据所述目标相似度,构建相似度损失函数,包括:
根据所述目标相似度,计算所述第二模型中每个网络层的损失值;
基于所述第二模型中每个网络层的损失值,构建相似度损失函数。
7.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述使用训练集对所述第二模型进行训练,直至所述目标损失函数收敛,得到满足所述目标条件的第二模型,包括:
根据所述训练集中的正负样本,构建样本对;所述样本对至少包括一个正样本与一个负样本;
基于所述样本对,对所述第二模型进行训练,直至所述目标损失函数收敛,得到满足所述目标条件的第二模型。
8.一种基于知识蒸馏的模型训练装置,其特征在于,所述装置包括:
获取模型模块,用于获取满足目标条件的第一模型和不满足所述目标条件的第二模型,所述第一模型包括M个网络层,所述第二模型包括N个网络层,N、M均为大于零的整数;
更新模块,用于根据所述第一模型的输出构建的优化损失函数,更新所述第二模型的初始损失函数,得到更新后的第二模型;
目标相似度确定模块,用于计算所述第一模型中M个网络层分别与所述更新后的第二模型中N个网络层的表征相似度,通过预设选取条件,确定目标相似度;
目标损失函数确定模块,用于根据所述目标相似度,构建相似度损失函数,并将所述相似度损失函数与所述优化损失函数的和作为目标损失函数;
训练模块,用于使用训练集对所述第二模型进行训练,直至所述目标损失函数收敛,得到满足所述目标条件的第二模型。
9.一种计算机设备,其特征在于,所述计算机设备包括处理器、存储器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述的基于知识蒸馏的模型训练方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的基于知识蒸馏的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210816261.5A CN115062769A (zh) | 2022-07-12 | 2022-07-12 | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210816261.5A CN115062769A (zh) | 2022-07-12 | 2022-07-12 | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115062769A true CN115062769A (zh) | 2022-09-16 |
Family
ID=83205853
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210816261.5A Pending CN115062769A (zh) | 2022-07-12 | 2022-07-12 | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115062769A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116485505A (zh) * | 2023-06-25 | 2023-07-25 | 杭州金智塔科技有限公司 | 基于用户表现公平性训练推荐模型的方法及装置 |
-
2022
- 2022-07-12 CN CN202210816261.5A patent/CN115062769A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116485505A (zh) * | 2023-06-25 | 2023-07-25 | 杭州金智塔科技有限公司 | 基于用户表现公平性训练推荐模型的方法及装置 |
CN116485505B (zh) * | 2023-06-25 | 2023-09-19 | 杭州金智塔科技有限公司 | 基于用户表现公平性训练推荐模型的方法及装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
EP3711000B1 (en) | Regularized neural network architecture search | |
US11544536B2 (en) | Hybrid neural architecture search | |
CN111858859A (zh) | 自动问答处理方法、装置、计算机设备及存储介质 | |
CN111523640B (zh) | 神经网络模型的训练方法和装置 | |
CN111667056B (zh) | 用于搜索模型结构的方法和装置 | |
CN110598869B (zh) | 基于序列模型的分类方法、装置、电子设备 | |
US20190228297A1 (en) | Artificial Intelligence Modelling Engine | |
CN111783873A (zh) | 基于增量朴素贝叶斯模型的用户画像方法及装置 | |
CN111858878A (zh) | 从自然语言文本中自动提取答案的方法、系统及存储介质 | |
CN113239702A (zh) | 意图识别方法、装置、电子设备 | |
CN115312033A (zh) | 基于人工智能的语音情感识别方法、装置、设备及介质 | |
CN111461353A (zh) | 一种模型训练的方法和系统 | |
CN115062769A (zh) | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 | |
CN112307048B (zh) | 语义匹配模型训练方法、匹配方法、装置、设备及存储介质 | |
CN113870863A (zh) | 声纹识别方法及装置、存储介质及电子设备 | |
CN115358374A (zh) | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 | |
CN116401522A (zh) | 一种金融服务动态化推荐方法和装置 | |
CN113361621B (zh) | 用于训练模型的方法和装置 | |
CN115687934A (zh) | 意图识别方法、装置、计算机设备及存储介质 | |
CN114861671A (zh) | 模型训练方法、装置、计算机设备及存储介质 | |
CN114358284A (zh) | 一种基于类别信息对神经网络分步训练的方法、装置、介质 | |
CN112949313A (zh) | 信息处理模型训练方法、装置、设备及存储介质 | |
CN111813941A (zh) | 结合rpa和ai的文本分类方法、装置、设备及介质 | |
CN116431757B (zh) | 基于主动学习的文本关系抽取方法、电子设备及存储介质 | |
CN116912920B (zh) | 表情识别方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |