CN115358374A - 基于知识蒸馏的模型训练方法、装置、设备及存储介质 - Google Patents
基于知识蒸馏的模型训练方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN115358374A CN115358374A CN202211004873.0A CN202211004873A CN115358374A CN 115358374 A CN115358374 A CN 115358374A CN 202211004873 A CN202211004873 A CN 202211004873A CN 115358374 A CN115358374 A CN 115358374A
- Authority
- CN
- China
- Prior art keywords
- model
- initial
- loss function
- label probability
- probability set
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
- G06N5/022—Knowledge engineering; Knowledge acquisition
Abstract
本发明涉及人工智能技术领域,提供一种基于知识蒸馏的模型训练方法、装置、设备及存储介质,通过获取满足目标条件的第一模型和不满足目标条件的N个第二初始模型,根据训练样本集,得到N个第二标签概率集,对第二标签集进行处理,得到第二初始模型对应的第三标签概率集,根据基预先得到的第一标签概率集,以及第三标签概率集,得到N个第二初始模型对应的平均损失函数,根据个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据目标损失函数,得到N个满足目标条件的第二模型,在对第二模型进行训练时,将第二模型中的N个模型进行信息融合,使每个第二模型得到的信息更加全面,从而提高第二模型的性能。
Description
技术领域
本发明涉及人工智能技术领域,尤其涉及一种基于知识蒸馏的模型训练方法、装置、设备及存储介质。
背景技术
随着人工智能识别的发展,普遍采用模型进行数据处理、图像识别等。通常地,对于不同应用场景有定制化模型需求的时候,技术人员选择的模型训练方式大致有两种:一、使用通用数据集训练好的通用模型修改最后输出层的类别数量,然后使用自己的数据集对模型参数进行重新调整;二、自己设计结构简单的模型,使用自己的数据集从头训练模型参数。其中,前一种方法使用已训练好的模型参数继续训练,能够加快训练收敛,也能保证模型精度,但是模型较大参数众多,需要较长的训练时间,第二种方法可以定制结构简单参数较少的小模型,但是模型参数需要从头训练,势必会减慢收敛速度,也不能保证模型的精度,而且训练数据集较小的情况下,还容易造成模型过拟合。
目前,通过知识蒸馏的方法,将教师模型中学习到的知识传递到学生模型中,从而使学生模型具有教师模型的泛化能力,但当前的知识蒸馏方法一般将教师模型中学习到的知识传递给一个学生模型,由于学生模型的参数有限,无法完全学习到教师模型中的知识,使生成的学生模型性能较差,因此,如何改进知识蒸馏的训练过程,以提高学生模型的性能成为亟待解决的问题。
发明内容
基于此,有必要针对上述技术问题,提供一种基于知识蒸馏的模型训练方法、装置、设备及存储介质,以解决训练过程中模型性能较低的问题。
本申请实施例的第一方面提供了一种基于知识蒸馏的模型训练方法,所述方法包括:
获取满足目标条件的第一模型和不满足目标条件的N个第二模型,其中,N为大于1的整数;
初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至所述第二初始模型中,输出第二标签概率集,得到N个第二标签概率集,所述第二初始模型与第二标签概率集一一对应;
对所述第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出所述第二初始模型对应的第三标签概率集;
根据基于所述第一模型与所述训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数;
基于所述平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据所述目标损失函数,对对应的第二初始模型进行训练,调整每个第二初始模型对应的初始参数,得到N个满足目标条件的第二模型。
本申请实施例的第二方面提供了一种基于知识蒸馏的模型训练装置,所述装置包括:
获取模块,用于获取满足目标条件的第一模型和不满足目标条件的N个第二模型,其中,N为大于1的整数;
第二标签概率集确定模块,用于初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至所述第二初始模型中,输出第二标签概率集,得到N个第二标签概率集,所述第二初始模型与第二标签概率集一一对应;
第三标签概率集确定模块,用于对所述第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出所述第二初始模型对应的第三标签概率集;
平均损失函数构建模块,用于根据基于所述第一模型与所述训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数;
第二模型确定模块,用于基于所述平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据所述目标损失函数,对对应的第二初始模型进行训练,调整每个初始模型对应的初始参数,得到N个满足目标条件的第二模型。
第三方面,本发明实施例提供一种计算机设备,所述计算机设备包括处理器、存储器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如第一方面所述的基于知识蒸馏的模型训练方法。
第四方面,本发明实施例提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如第一方面所述的基于知识蒸馏的模型训练方法。
本发明与现有技术相比存在的有益效果是:
获取满足目标条件的第一模型和不满足目标条件的N个第二模型,其中,N为大于1的整数,初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至第二初始模型中,输出第二标签概率集,得到N个第二标签概率集,第二初始模型与第二标签概率集一一对应,对第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出第二初始模型对应的第三标签概率集,根据基于第一模型与训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数,基于平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据目标损失函数,对对应的第二初始模型进行训练,调整每个第二初始模型对应的初始参数,得到N个满足目标条件的第二模型,在对第二模型进行训练时,将第二模型中的N个模型进行信息融合,使每个第二模型得到的信息更加全面,从而提高第二模型的性能。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本发明一实施例提供的一种基于知识蒸馏的模型训练方法的一应用环境示意图;
图2是本发明一实施例提供的一种基于知识蒸馏的模型训练方法的流程示意图;
图3是本发明一实施例提供的一种基于知识蒸馏的模型训练方法的流程示意图;
图4是本发明一实施例提供的一种基于知识蒸馏的模型训练装置的结构示意图;
图5是本发明一实施例提供的一种计算机设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应当理解,当在本发明说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本发明说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
另外,在本发明说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
在本发明说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本发明的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调。
本发明实施例可以基于人工智能技术对相关的数据进行获取和处理。其中,人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。
人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、机器人技术、生物识别技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
应理解,以下实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
为了说明本发明的技术方案,下面通过具体实施例来进行说明。
本发明一实施例提供的一种基于知识蒸馏的模型训练方法,可应用在如图1的应用环境中,其中,客户端与服务端进行通信。其中,客户端包括但不限于掌上电脑、桌上型计算机、笔记本电脑、超级移动个人计算机(ultra-mobile personal computer,UMPC)、上网本、个人数字助理(personal digital assistant,PDA)等计算机设备。服务端可以用独立的服务器或者是多个服务器组成的服务器集群来实现。
参见图2,是本发明一实施例提供的一种基于知识蒸馏的模型训练方法的流程示意图,上述基于知识蒸馏的模型训练可以应用于图1中的服务端,上述服务端连接相应的客户端,为客户端提供模型训练服务。如图2所示,该基于知识蒸馏的模型训练方法可以包括以下步骤。
S201:获取满足目标条件的第一模型和不满足目标条件的N个第二模型。
在步骤S201中,满足目标条件的第一模型和不满足目标条件的N个第二模型为深度学习卷积神经网络模型,通过使用训练集对第一模型进行训练得到满足目标条件的第一模型,N为大于零的整数。满足目标条件的第一模型为训练好的一个深度学习模型,不满足目标条件的第二模型为未训练完成的深度学习模型,第二模型是为了学习第一模型中的参数,节省训练时间。
本实施例中,第一模型为知识蒸馏中教师模型,第二模型为知识蒸馏中的学生模型。获取满足条件的第一模型时,可以首先对训练集中的样本数据进行人工标注,然后利用标注的样本数据对第一模型进行训练,该训练过程可以理解为预训练过程。在训练的过程中,可以利用第一模型的损失函数计算第一模型的实际输出与标注结果之间的损失值,根据损失值进行反向传播训练,至损失收敛,得到满足条件的第一模型。例如,上述第一模型用于声纹识别,满足条件的第一模型实际输出可以为声纹特征向量。
需要说明的是,第一模型和第二模型可以是同类型的神经网络模型,即第一模型和第二模型具有相同的网络层结构,第一模型和第二模型也可以是不同类型的神经网络模型,即第一模型和第二模型的网络层结构有差异。但是,第一模型的网络层层数小于第二模型的网络层层数,也就是,第一模型的模型大小和模型参数量小于第二模型。由于第二模型的模型参数量较多,在实际应用时,输出预测结果需要耗费较多时间。为了提高预测结果的输出速度和节约计算机资源,可以对大模型进行知识蒸馏得到轻量级的小模型,从而在实际应用时,基于知识蒸馏的模型训练将复杂、学习能力强的第一网络模型已经学习到的特征表示知识蒸馏出来,传递给参数量小、学习能力弱的第二模型。
例如,第一模型可以基于残差网络ResNet34构建;第二模型可以基于ResNet10构建。由于较大的网络往往面临着深度学习网络模型大而冗余,识别速度难以满足实时性要求的问题,而小网络模型非常容易会因为参数量较小,模型特征表示能力不足,导致模型准确性能低下,带来的问题就是虽然满足线上应用的实时性要求,但却无法满足准确性要求。因此,通过第一模型对的知识对第二模型进行训练,可以使的大的网络模型对小的网络模型起到正向作用,使得第二模型能够获得较优的拟合参数,进而提升第二模型的准确度。
S202:初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至第二初始模型中,输出第二标签概率集,得到N个第二标签概率集。
在步骤S202中,对N个第二模型分别设置初化参数,得到N个第二初始模型,其中,每个第二模型中的初始化参数不同,每个第二模型中的初始化参数的数量根据对应模型的卷积核大小,输入通道大小,输出通道大小以及偏置项大小决定,例如,卷积神经网络的卷积核大小为k×k,输入通道为i,输出通道为o,偏置项为y,则每进行一次卷积操作,需要的参数数量为:k×k×i×o+y。根据得到的初始化第二模型,将训练样本集输入至第二初始模型中,输出第二标签概率集,得到N个第二标签概率集。
本实施例中,对不同的第二模型设置不同的初始化参数时,可以根据不同的策略进行设置,以便于第二模型在进行信息交互时,可以使不同的第二模型可以从第一模型中学习到更全面的知识。例如,当激活函数为饱和性激活函数tanh,使用Xavier初始化方式对第二模型中的参数进行初始化,Xavier初始化方法有利于加快收敛、减小过拟合。当激活函数为ReLU及其变种激活函数,使用Kaiming初始化方法。LeCun初始化,它适用于sigmoid激活函数的神经网络。其主要设计思想为假设网络的输入为高斯分布,在激活向前传播过程中,通过控制初始化时权值参数采样分布的期望和方差使每层神经元激活值的期望为0和方差为1。
需要说明的是,初始化的目的是要避免出现梯度爆炸或者梯度消失,更高的要求就是同时保证激活向前传播和梯度向后传播稳定,也就是控制向前传播时神经元激活信号和向后传播时误差信号的均值和方差稳定,所以选择不同方法进行初始化时,尽可能使控制向前传播时神经元激活信号和向后传播时误差信号的均值和方差稳定。
本实施例中,对每个第二模型进行初始化后,得到第二初始模型,将训练样本集输入至第二初始模型中,输出第二标签概率集,得到N个第二标签概率集。第二标签概率集中包括每个第二初始模型对应的第二标签概率。对于相同的训练集,输入至不同的第二初始模型中,得到不同的标签概率值。
可选地,对每个第二模型设置不同的初始参数,得到N个第二初始模型,包括:
通过预设算法,计算得到每个第二模型中的参数权重;
根据参数权重,为每个第二模型赋值一个初始参数,得到N个第二初始模型。
本实施例中,第二模型是一种神经网络模型,神经网络模型中包括多层神经元,每一个神经元代表徐那脸样本中的一个特征,可以根据训练样本中特征的维度以及每个特征中的重要度,计算每个神经元中的参数权重,例如,当训练样本中的特征正向量的维度为5,则对应第二模型有5个神经元,计算每个神经元的参数权重时,根据每个神经元对应的特征确定参数权重,例如,当训练样本的特征向量为5个为度的特征时,若第三维度的的特征对训练样本的结果影响较大,则可以从随机会获取到的5个数值中选取数值最大的作为第三维度特征对应的神经元的参数权重。依次为每个第二模型参数进行赋值,得到N个第二初始模型。
需要说明的是,其中参数数值可以根据随机函数获得,当对每个神经元进行参数赋值时,若每个神经元中对应的特征对训练结果的影响相同,将对每个神经元赋予相同的参数值。
S203:对第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出第二初始模型对应的第三标签概率集。
在步骤S203中,对第二初始模型对应的第二标签概率集进行预处理,使第二标签概率值中的数值可以处于相似的尺度上,有利于加快梯度下降算法的运算。将预处理结果输入至归一化层,输出第二初始模型对应的第三标签概率集。
本实施例中,对第二初始模型对应的第二标签概率集进行预处理,其中,预处理是将第二标签概率值记性适当的缩放,使第二标签概率值中的数值可以处于相似的尺度上,缩放时,设置对应的缩放参数,设置缩放参数时,可以依据第二标签概率集中数值的大小进行设置,进行缩放处理后,将处理后的标签概率值输入至归一化层,输出第二初始模型对应的第三标签概率集。第三标签概率集中的数值将在同一数量级中,对不同的样本,使第三标签概率集中可以保存训练样本中的不同特征,以防止标签概率较小时,对应的训练样本的特征被忽略掉。
可选地,对第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出第二初始模型对应的第三标签概率集,包括:
设置预处理参数,对第二初始模型对应的第二标签概率集进行蒸馏处理,得到蒸馏标签概率集;
将蒸馏标签概率集输入至归一化层,输出第二初始模型对应的第三标签概率集。
本实施例中,设置预处理参数,对第二初始模型对应的第二标签概率集进行蒸馏处理,得到蒸馏标签概率集,例如,可以根据第二标签概率集中的最大值与最小值之间的差值进行设置,获取第二标签概率集中的最大值与最小值的数量级,根据数量级的差异设置不同的参数,例如,当第二标签概率集中的最大值为20时,最小值为1时,可以设置参数为10,使第二标签概率值中的数值可以处于相似的尺度上,得到蒸馏标签概率集。
需要说明的是,当设置预处理参数时,也可以根据第二标签概率集中处于同一数量级最多数值的数量级进行确定,当第二标签概率集的数值中同一数量级最多的是两位数,则可以将预处理参数设置为10,当第二标签概率集的数值中同一数量级最多的是三位数,则可以将预处理参数设置为100等。得到对应的蒸馏标签概率集,将蒸馏标签概率集输入至归一化层,输出第二初始模型对应的第三标签概率集,其中,归一化层中有预先设置的归一化函数,归一化函数可以选择现行归一化或者范数归一化等。
S204:根据基于第一模型与训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数。
在步骤S204中,基于第一模型与训练样本集得到的第一标签概率集,其中第一标签概率集为基于软化处理得到的标签概率集,根据第一标签概率集预第三标签概率集构建每个第二初始模型对应的损失函数,对每个第二初始模型对应的损失函数进行求和,并进行均值计算,得到N个第二初始模型对应的平均损失函数。
本实施例中,将训练样本输入至第一模型中,得到对应的第一标签概率集,根据第一标签概率集与每个第二初始模型对应的第三概率集构建每个第二初始模型对应的损失函数,构建每个第二初始模型对应的损失函数时,可以根据第一标签概率集与第二初始模型对应的第三概率集之间出的差值,构建每个第二初始模型对应的损失函数,根据每个第二初始模型对应的损失函数,对N个损失函数的参数求均值,得到平均损失函数。
可选地,根据基于第一模型与训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,包括:
将训练样本集分别输入至第一模型中,输出训练样本集对应的第一标签概率集;
基于的第一标签概率集与每个第二初始模型对应的第三标签概率集,通过相对熵算法,得到每个第二初始模型对应的损失函数。
本实施例中,将训练样本集分别输入至第一模型中,输出训练样本集对应的第一标签概率集,基于的第一标签概率集与每个第二初始模型对应的第三标签概率集,通过相对熵算法,得到每个第二初始模型对应的损失函数,相对熵算法实现了每个预测值与真实值之间匹配程度的量化,其中,将第一标签概率及中的标签作为真实值,将每个第二初始模型对应的第三标签集中的标签作为预测值,相对熵的取值应该是在0到无穷大之间,因此可以作为两种分布间距离的度量,可以认为是一种对概率分布间距离的度量,特别是当两个分布的差异很小的时候,可以认为是一个局部分布间的距离,相对熵越小,真实值与预测值之间的匹配效果就越好,通过相对熵算法,得到每个第二初始模型对应的损失函数。
S205:基于平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据目标损失函数,对对应的第二初始模型进行训练,调整每个第二初始模型对应的初始参数,得到N个满足目标条件的第二模型。
在步骤S205中,每个第二初始模型对应的目标损失函数为平均损失函数与每个第二初始模型对应的损失函数之和,使用训练集对每个第二模型进行训练时,可以进行有监督的训练,设置对应的阈值,当训练结果与监督标签的差值小于阈值时,认为目标损失函数收敛,得到N个满足目标条件的第二模型。
本实施例中,使用训练样本集对每个初始第二模型进行训练时,针对每个初始第二模型中目标损失函数,得到的训练结果与训练样本集中的标签通过目标损失函数计算损失值,判断损失值是否符合预设条件,当未符合预设条件时,根据损失值对每个初始第二模型进行反向传播更新,得到更新模型参数的第二初始模型,再次基于训练集对更新模型参数的第二初始模型进行训练,直至损失值符合预设条件,依次得到N个已满足目标条件的第二模型。
可选地,基于平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,包括:
根据训练样本集与每个第二标签概率集,通过预设算法,构建第二初始模型对应的损失函数;
根据平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数。
本实施例中,将训练样本集输入至每个第二初始模型中,输出训练样本集对应的预测值,根据训练样本集与预测值,通过交叉熵算法,构建每个第二初始函数的损失函数,将平均损失函数与每个第二初始模型对应的损失函数进行求和,得到每个第二初始模型的目标损失函数。
需要说明的是,输出的训练样本对应的预测值为缩放参数为1的标签概率值,即每个第二初始模型得到的标签概率值为输出的真实标签概率值。
可选地,根据平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,包括:
获取平均损失函数与每个第二初始模型对应的损失函数的占比系数;
根据占比系数对平均损失函数与每个第二初始模型对应的损失函数进行加权求和,得到每个第二初始模型对应的目标损失函数。
本实施例中,获取平均损失函数与每个第二初始模型对应的损失函数的占比系数,获取时,从不同的占比系数的组合中获取相应的平均损失函数与每个第二初始模型对应的损失函数的占比系数,例如,在对第二初始模型进行训练时,训练前期,第一模型起到的贡献较多,平均损失函数中包含第一模型中的损失函数,所以,平均损失函数中的占比系数较大,获取占比系数时,需要从不同的占比系数的组合中获取一组组合,将数值较大的作为平均损失函数的占比系数,将数值较小的作为第二初始模型对应损失函数的占比系数,根据占比系数,对平均损失函数与每个第二初始模型对应的损失函数进行加权求和,得到每个第二初始模型对应的目标损失函数。在训练后期,第二初始模型起到的贡献较多,所以,第二初始模型对应的损失函数的占比系数较大,获取占比系数时,需要从不同的占比系数的组合中获取一组组合,将数值较大的作为第二初始模型对应的损失函数的占比系数,将数值较小的作为平均损失函数占比系数,根据占比系数,对平均损失函数与每个第二初始模型对应的损失函数进行加权求和,得到每个第二初始模型对应的目标损失函数。
获取满足目标条件的第一模型和不满足目标条件的N个第二模型,其中,N为大于1的整数,初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至第二初始模型中,输出第二标签概率集,得到N个第二标签概率集,第二初始模型与第二标签概率集一一对应,对第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出第二初始模型对应的第三标签概率集,根据基于第一模型与训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数,基于平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据目标损失函数,对对应的第二初始模型进行训练,调整每个第二初始模型对应的初始参数,得到N个满足目标条件的第二模型,在对第二模型进行训练时,将第二模型中的N个模型进行信息融合,使每个第二模型得到的信息更加全面,从而提高第二模型的性能。
参见图3,是本发明一实施例提供的一种基于知识蒸馏的模型训练方法的流程示意图,如图3,该基于知识蒸馏的模型训练方法可以包括以下步骤:
S301:基于训练样本集对预先构建的第一模型进行训练处理,调整第一模型中的参数,得到满足目标条件的第一模型。
本实施例中,使用训练样本集对预先构建的第一模型进行训练时,为了加快收敛速度,可以使用样本对进行训练,每个样本对选定相同的正样本和负样本,由于样本反传梯度的大小由每一个正负样本对的差累计决定,这导致了一旦样本对中的样本数量过多时,在特征空间中找到满足所有样本对训练分界面会变得很困难,训练的收敛性将会随之变差,与此同时,每次重复计算两两样本间的距离差作为反传梯度会导致训练的冗余和训练时间的加长。由于不同样本对中的某一对正负样本的大小比较可能会出现多次,过多的增加样本对中的正负样本的数量对训练的帮助收益甚微,所以每个样本对中可以选择一个正样本和一个负样本。使用选定的样本对,对预先构建的第一模型进行训练,调整第一模型中的参数,得到满足目标条件的第一模型。
获得满足目标条件的第一模型,目的是将第一模型中的信息传递给N个不满足目标条件的第二模型,加快第二模型的训练。
S302:获取满足目标条件的第一模型和不满足目标条件的N个第二模型,其中,N为大于1的整数;
S303:初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至第二初始模型中,输出第二标签概率集,得到N个第二标签概率集,第二初始模型与第二标签概率集一一对应;
S304:对第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出第二初始模型对应的第三标签概率集;
S305:根据基于第一模型与训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数;
S306:基于平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据目标损失函数,对对应的第二初始模型进行训练,调整每个第二初始模型对应的初始参数,得到N个满足目标条件的第二模型。
其中,上述步骤S302至步骤S306与上述步骤S201至步骤S205的内容相同,可参考上述步骤S201至步骤S205的描述,在此不再赘述。
请参阅图4,图4是本发明实施例提供的一种基于知识蒸馏的模型训练装置的结构示意图。本实施例中该终端包括的各单元用于执行图2至图3对应的实施例中的各步骤。具体请参阅图2至图3以及图2至图3所对应的实施例中的相关描述。为了便于说明,仅示出了与本实施例相关的部分。参见图4,训练装置40包括:获取模块41,第二标签概率集确定模块42,第三标签概率集确定模块43,平均损失函数构建模块44,第二模型确定模块45。
获取模块41,用于获取满足目标条件的第一模型和不满足目标条件的N个第二模型,其中,N为大于1的整数。
第二标签概率集确定模块42,用于初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至第二初始模型中,输出第二标签概率集,得到N个第二标签概率集,第二初始模型与第二标签概率集一一对应。
第三标签概率集确定模块43,用于对第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出第二初始模型对应的第三标签概率集。
平均损失函数构建模块44,用于根据基于第一模型与训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数。
第二模型确定模块45,用于基于平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据目标损失函数,对对应的第二初始模型进行训练,调整每个初始模型对应的初始参数,得到N个满足目标条件的第二模型。
可选的是,上述第二标签概率集确定模块42包括:
计算单元,用于通过预设算法,计算得到每个第二模型中的参数权重。
赋值单元,用于根据参数权重,为每个第二模型赋值一个初始参数,得到N个第二初始模型。
可选的是,上述第三标签概率集确定模块43包括:
预处理单元,用于设置预处理参数,对第二初始模型对应的第二标签概率集进行蒸馏处理,得到蒸馏标签概率集。
输出单元,用于将蒸馏标签概率集输入至归一化层,输出第二初始模型对应的第三标签概率集。
可选的是,上述平均损失函数构建模块44包括:
第一标签概率集确定单元,用于将训练样本集分别输入至第一模型中,输出训练样本集对应的第一标签概率集。
第二初始模型对应的损失函数确定单元,用于基于的第一标签概率集与每个第二初始模型对应的第三标签概率集,通过相对熵算法,得到每个第二初始模型对应的损失函数。
可选的是,上述第二模型确定模块45包括:
构建单元,用于根据训练样本集与每个第二标签概率集,通过预设算法,构建第二初始模型对应的损失函数。
目标损失函数确定单元,用于根据平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数。
可选的是,上述目标损失函数确定单元包括:
获取子单元,用于获取平均损失函数与每个第二初始模型对应的损失函数的占比系数。
加权求和子单元,用于根据占比系数对平均损失函数与每个第二初始模型对应的损失函数进行加权求和,得到每个第二初始模型对应的目标损失函数。
可选的是,上述训练装置还包括:
训练模块,用于基于训练样本集对预先构建的第一模型进行训练处理,调整第一模型中的参数,得到满足目标条件的第一模型。
需要说明的是,上述单元之间的信息交互、执行过程等内容,由于与本发明方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。
图5是本发明实施例提供的一种计算机设备的结构示意图。如图5所示,该实施例的计算机设备包括:至少一个处理器(图5中仅示出一个)、存储器以及存储在存储器中并可在至少一个处理器上运行的计算机程序,处理器执行计算机程序时实现上述任意各个基于知识蒸馏的模型训练方法实施例中的步骤。
该计算机设备可包括,但不仅限于,处理器、存储器。本领域技术人员可以理解,图5仅仅是计算机设备的举例,并不构成对计算机设备的限定,计算机设备可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如还可以包括网络接口、显示屏和输入装置等。
所称处理器可以是CPU,该处理器还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific IntegratedCircuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器包括可读存储介质、内存储器等,其中,内存储器可以是计算机设备的内存,内存储器为可读存储介质中的操作系统和计算机可读指令的运行提供环境。可读存储介质可以是计算机设备的硬盘,在另一些实施例中也可以是计算机设备的外部存储设备,例如,计算机设备上配备的插接式硬盘、智能存储卡(Smart Media Card,SMC)、安全数字(Secure Digital,SD)卡、闪存卡(Flash Card)等。进一步地,存储器还可以既包括计算机设备的内部存储单元也包括外部存储设备。存储器用于存储操作系统、应用程序、引导装载程序(BootLoader)、数据以及其他程序等,该其他程序如计算机程序的程序代码等。存储器还可以用于暂时地存储已经输出或者将要输出的数据。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本发明的保护范围。上述装置中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明实现上述实施例方法中的全部或部分流程,可以通过计算机程序来指令相关的硬件来完成,计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述方法实施例的步骤。其中,计算机程序包括计算机程序代码,计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。计算机可读介质至少可以包括:能够携带计算机程序代码的任何实体或装置、记录介质、计算机存储器、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、电载波信号、电信信号以及软件分发介质。例如U盘、移动硬盘、磁碟或者光盘等。在某些司法管辖区,根据立法和专利实践,计算机可读介质不可以是电载波信号和电信信号。
本发明实现上述实施例方法中的全部或部分流程,也可以通过一种计算机程序产品来完成,当计算机程序产品在计算机设备上运行时,使得计算机设备执行时实现可实现上述方法实施例中的步骤。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
在本发明所提供的实施例中,应该理解到,所揭露的装置/计算机设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/计算机设备实施例仅仅是示意性的,例如,模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于知识蒸馏的模型训练方法,其特征在于,所述训练方法包括:
获取满足目标条件的第一模型和不满足目标条件的N个第二模型,其中,N为大于1的整数;
初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至所述第二初始模型中,输出第二标签概率集,得到N个第二标签概率集,所述第二初始模型与第二标签概率集一一对应;
对所述第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出所述第二初始模型对应的第三标签概率集;
根据基于所述第一模型与所述训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数;
基于所述平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据所述目标损失函数,对对应的第二初始模型进行训练,调整每个第二初始模型对应的初始参数,得到N个满足目标条件的第二模型。
2.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,包括:
通过预设算法,计算得到每个第二模型中的参数权重;
根据所述参数权重,为每个第二模型赋值一个初始参数,得到N个第二初始模型。
3.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述对所述第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出所述第二初始模型对应的第三标签概率集,包括:
设置预处理参数,对所述第二初始模型对应的第二标签概率集进行蒸馏处理,得到蒸馏标签概率集;
将所述蒸馏标签概率集输入至归一化层,输出所述第二初始模型对应的第三标签概率集。
4.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述根据基于所述第一模型与所述训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,包括:
将所述训练样本集分别输入至所述第一模型中,输出所述训练样本集对应的第一标签概率集;
基于所述的第一标签概率集与每个第二初始模型对应的第三标签概率集,通过相对熵算法,得到每个第二初始模型对应的损失函数。
5.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述基于所述平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,包括:
根据所述训练样本集与每个第二标签概率集,通过预设算法,构建第二初始模型对应的损失函数;
根据所述平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数。
6.如权利要求5所述的基于知识蒸馏的模型训练方法,其特征在于,所述根据所述平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,包括:
获取所述平均损失函数与每个第二初始模型对应的损失函数的占比系数;
根据所述占比系数对所述平均损失函数与每个第二初始模型对应的损失函数进行加权求和,得到每个第二初始模型对应的目标损失函数。
7.如权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述获取满足目标条件的第一模型之前,还包括:
基于所述训练样本集对预先构建的第一模型进行训练处理,调整所述第一模型中的参数,得到满足目标条件的第一模型。
8.一种基于知识蒸馏的模型训练装置,其特征在于,所述装置包括:
获取模块,用于获取满足目标条件的第一模型和不满足目标条件的N个第二模型,其中,N为大于1的整数;
第二标签概率集确定模块,用于初始化得到N个不同的初始参数,为每个第二模型赋值一个初始参数,得到N个第二初始模型,将训练样本集输入至所述第二初始模型中,输出第二标签概率集,得到N个第二标签概率集,所述第二初始模型与第二标签概率集一一对应;
第三标签概率集确定模块,用于对所述第二初始模型对应的第二标签概率集进行预处理,并将预处理结果输入至归一化层,输出所述第二初始模型对应的第三标签概率集;
平均损失函数构建模块,用于根据基于所述第一模型与所述训练样本集得到的第一标签概率集,以及每个第二初始模型对应的第三标签概率集,得到每个第二初始模型对应的损失函数,并构建N个第二初始模型对应的平均损失函数;
第二模型确定模块,用于基于所述平均损失函数与每个第二初始模型对应的损失函数,得到每个第二初始模型对应的目标损失函数,根据所述目标损失函数,对对应的第二初始模型进行训练,调整每个初始模型对应的初始参数,得到N个满足目标条件的第二模型。
9.一种计算机设备,其特征在于,所述计算机设备包括处理器、存储器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述的基于知识蒸馏的模型训练方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的基于知识蒸馏的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211004873.0A CN115358374A (zh) | 2022-08-22 | 2022-08-22 | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211004873.0A CN115358374A (zh) | 2022-08-22 | 2022-08-22 | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115358374A true CN115358374A (zh) | 2022-11-18 |
Family
ID=84003557
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211004873.0A Pending CN115358374A (zh) | 2022-08-22 | 2022-08-22 | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115358374A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117236409A (zh) * | 2023-11-16 | 2023-12-15 | 中电科大数据研究院有限公司 | 基于大模型的小模型训练方法、装置、系统和存储介质 |
-
2022
- 2022-08-22 CN CN202211004873.0A patent/CN115358374A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117236409A (zh) * | 2023-11-16 | 2023-12-15 | 中电科大数据研究院有限公司 | 基于大模型的小模型训练方法、装置、系统和存储介质 |
CN117236409B (zh) * | 2023-11-16 | 2024-02-27 | 中电科大数据研究院有限公司 | 基于大模型的小模型训练方法、装置、系统和存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US20210004663A1 (en) | Neural network device and method of quantizing parameters of neural network | |
US11295208B2 (en) | Robust gradient weight compression schemes for deep learning applications | |
US20200302271A1 (en) | Quantization-aware neural architecture search | |
US11604960B2 (en) | Differential bit width neural architecture search | |
WO2019157251A1 (en) | Neural network compression | |
CN115082920B (zh) | 深度学习模型的训练方法、图像处理方法和装置 | |
CN111989696A (zh) | 具有顺序学习任务的域中的可扩展持续学习的神经网络 | |
CN111523640A (zh) | 神经网络模型的训练方法和装置 | |
CN110781686B (zh) | 一种语句相似度计算方法、装置及计算机设备 | |
EP3649582A1 (en) | System and method for automatic building of learning machines using learning machines | |
Dai et al. | Hybrid deep model for human behavior understanding on industrial internet of video things | |
KR20200063970A (ko) | 신경망 재구성 방법 및 장치 | |
CN111008689B (zh) | 使用softmax近似来减少神经网络推理时间 | |
CN113239702A (zh) | 意图识别方法、装置、电子设备 | |
EP3803580B1 (en) | Efficient incident management in large scale computer systems | |
CN111667069A (zh) | 预训练模型压缩方法、装置和电子设备 | |
CN115358374A (zh) | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 | |
CN116827685B (zh) | 基于深度强化学习的微服务系统动态防御策略方法 | |
CN115062769A (zh) | 基于知识蒸馏的模型训练方法、装置、设备及存储介质 | |
CN116739154A (zh) | 一种故障预测方法及其相关设备 | |
Huang et al. | Flow of renyi information in deep neural networks | |
CN113361621B (zh) | 用于训练模型的方法和装置 | |
CN110852361B (zh) | 基于改进深度神经网络的图像分类方法、装置与电子设备 | |
CN114792097A (zh) | 预训练模型提示向量的确定方法、装置及电子设备 | |
CN111815658A (zh) | 一种图像识别方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |