CN114386604A - 基于多教师模型的模型蒸馏方法、装置、设备及存储介质 - Google Patents

基于多教师模型的模型蒸馏方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN114386604A
CN114386604A CN202210044224.7A CN202210044224A CN114386604A CN 114386604 A CN114386604 A CN 114386604A CN 202210044224 A CN202210044224 A CN 202210044224A CN 114386604 A CN114386604 A CN 114386604A
Authority
CN
China
Prior art keywords
model
teacher
student
soft
deviation value
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210044224.7A
Other languages
English (en)
Inventor
王健宗
李泽远
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202210044224.7A priority Critical patent/CN114386604A/zh
Publication of CN114386604A publication Critical patent/CN114386604A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Artificial Intelligence (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Probability & Statistics with Applications (AREA)
  • Image Analysis (AREA)

Abstract

本申请涉及人工智能领域,尤其涉及基于多教师模型的模型蒸馏方法,所述方法包括:获取训练样本数据及对应的硬标签,通过多个教师和第一学生模型对训练样本数据进行识别,得到多个第一软标签和第二软标签;根据硬标签、多个第一软标签和第二软标签,对第一学生模型进行知识蒸馏学习,生成第二学生模型并通过第二学生模型得到第三软标签;根据第一软标签和第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,基于更新后的教师模型选择策略重新确定对应的教师模型,并根据重新确定的教师模型对第一学生模型进行知识蒸馏学习,直至第一学生模型收敛,得到目标学生模型。由此能够使蒸馏得到的学生模型表现达到最佳,提高用户体验。

Description

基于多教师模型的模型蒸馏方法、装置、设备及存储介质
技术领域
本申请涉及人工智能领域,尤其涉及基于多教师模型的模型蒸馏方法、基于多教师模型的模型蒸馏装置、计算机设备及存储介质。
背景技术
现有的模型参数量巨大,这对于技术人员在精调和线上部署上带来了巨大的挑战,例如BERT-base模型拥有1.1亿参数、BERT-large模型拥有3.4亿参数,海量的参数使得这些模型在微调和部署时速度慢,计算成本大,对实时的应用造成了极大的延迟和容量限制,因此模型压缩意义重大。
作为模型压缩三大方法之一的模型蒸馏,在学术界和工业界得到广泛的认可和应用,越来越多的蒸馏方法得以提出与应用。目前常用的蒸馏方法是基于一个教师模型和一个学生模型的框架进行的,将单个复杂的教师模型学习到的知识蒸馏到简单的学生模型中,在一定程度上保持推理速度的同时有效地提升了学生模型推理的准确率,但是采用单一的复杂教师模型对学生模型准确率的提升并不是必要的。
发明内容
本申请提供了一种基于多教师模型的模型蒸馏方法、基于多教师模型的模型蒸馏装置、计算机设备及存储介质,旨在解决现有的基于教师模型蒸馏得到的学生模型表现不佳的问题。
为实现上述目的,本申请提供一种基于多教师模型的模型蒸馏方法,所述方法包括:
获取训练样本数据和所述训练样本数据对应的硬标签,通过多个教师模型对所述训练样本数据进行识别,得到多个第一软标签,以及通过第一学生模型对所述训练样本数据进行识别,得到第二软标签;
根据所述硬标签、所述多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第二学生模型;其中,所述第一学生模型的模型参数和所述第二学生模型的模型参数不同;
通过所述第二学生模型对所述训练样本数据进行识别,得到第三软标签;
根据所述第一软标签和所述第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,所述教师模型选择策略用于选择教师模型;
基于所述更新后的教师模型选择策略重新确定对应的教师模型,并根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,直至所述第一学生模型收敛,得到目标学生模型。
为实现上述目的,本申请还提供一种基于多教师模型的模型蒸馏装置,所述基于多教师模型的模型蒸馏装置包括:
第一标签生成模块,用于获取训练样本数据和所述训练样本数据对应的硬标签,通过多个教师模型对所述训练样本数据进行识别,得到多个第一软标签,以及通过第一学生模型对所述训练样本数据进行识别,得到第二软标签;
模型生成模块,用于根据所述硬标签、所述多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第二学生模型;其中,所述第一学生模型的模型参数和所述第二学生模型的模型参数不同;
第二标签生成模块,用于通过所述第二学生模型对所述训练样本数据进行识别,得到第三软标签;
策略更新模块,用于根据所述第一软标签和所述第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,所述教师模型选择策略用于选择教师模型;
模型确定模块,用于基于所述更新后的教师模型选择策略重新确定对应的教师模型,并根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,直至所述第一学生模型收敛,得到目标学生模型。
此外,为实现上述目的,本申请还提供一种计算机设备,所述计算机设备包括存储器和处理器;所述存储器,用于存储计算机程序;所述处理器,用于执行所述的计算机程序并在执行所述的计算机程序时实现本申请实施例提供的任一项所述的基于多教师模型的模型蒸馏方法。
此外,为实现上述目的,本申请还提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时使所述处理器实现本申请实施例提供的任一项所述的基于多教师模型的模型蒸馏方法。
本申请实施例公开的基于多教师模型的模型蒸馏方法、基于多教师模型的模型蒸馏装置、设备及存储介质,通过蒸馏得到的学生模型表现对教师模型进行筛选并动态地选择教师模型,由此能够使多个教师模型将有效的知识蒸馏给学生模型,以使蒸馏得到的学生模型表现达到最佳,提高了用户体验。
附图说明
为了更清楚地说明本申请实施例技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的一种基于多教师模型的模型蒸馏方法的场景示意图;
图2是本申请实施例提供的一种基于多教师模型的模型蒸馏方法的流程示意图;
图3是本申请一实施例提供的一种基于多教师模型的模型蒸馏装置的示意性框图;
图4是本申请一实施例提供的一种计算机设备的示意性框图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
附图中所示的流程图仅是示例说明,不是必须包括所有的内容和操作/步骤,也不是必须按所描述的顺序执行。例如,有的操作/步骤还可以分解、组合或部分合并,因此实际执行的顺序有可能根据实际情况改变。另外,虽然在装置示意图中进行了功能模块的划分,但是在某些情况下,可以以不同于装置示意图中的模块划分。
在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
目前常用的蒸馏方法是基于一个教师模型和一个学生模型的框架进行的,将单个复杂的教师模型学习到的知识蒸馏到简单的学生模型中,在一定程度上保持推理速度的同时有效地提升了学生模型推理的准确率。但是采用单一的复杂教师模型对学生模型准确率的提升并不是必要的,选择更合适的性能较弱的教师模型进行蒸馏比选择更强的教师模型取得的效果可能更好。
基于这种思想,基于多教师模型的蒸馏方法应运而生。但是,目前的多教师模型蒸馏方法大都采用固定权重的方式分配各教师对学生施加的影响,无法动态充分地将有效的知识蒸馏给学生模型,这样可能会导致蒸馏得到的学生模型表现并没有达到最佳。
为解决上述问题,本申请提供了一种基于多教师模型的模型蒸馏方法,可以应用在服务器中,当然也可以应用于终端设备上,可以动态地选择教师模型并为多个教师模型分配最合适的权重比例,由此能够使多个教师模型将有效的知识蒸馏给学生模型,以使蒸馏得到的学生模型表现达到最佳,提高了用户体验。
其中,所述终端设备可以包括诸如手机、平板电脑、个人数字助理(PersonalDigital Assistant,PDA)等固定终端。服务器例如可以为单独的服务器或服务器集群。但为了便于理解,以下实施例将以应用于服务器的基于多教师模型的模型蒸馏方法进行详细介绍。
示例性的,本申请实施例提出的基于多教师模型的模型蒸馏方法可以应用在问答(判断一个问题与一个答案是否匹配)、语句匹配(两句话是否表达同一个意思)等应用场景中。通过不断地动态选择教师模型并确定多个教师模型对应的权重比例,以使蒸馏得到的学生模型在问答上、语句匹配上更为快速以及准确。
下面结合附图,对本申请的一些实施方式作详细说明。在不冲突的情况下,下述的实施例及实施例中的特征可以相互组合。
如图1所示,本申请实施例提供的基于多教师模型的模型蒸馏方法,可以应用于如图1所示的应用环境中。该应用环境中包含有终端设备110和服务器120,其中,终端设备110可以通过网络与服务器120进行通信。具体地,服务器120对多个教师模型进行多次知识蒸馏学习,得到目标学生模型,并将目标学生模型发送给终端设备110,以使用户通过终端设备110使用该目标学生模型。其中,服务器120可以是独立的服务器,也可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。终端设备110可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表等,但并不局限于此。终端以及服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。
请参阅图2,图2是本申请实施例提供的一种基于多教师模型的模型蒸馏方法的示意流程图。其中,该基于多教师模型的模型蒸馏方法可以应用于服务器中,由此能够使多个教师模型将有效的知识蒸馏给学生模型,以使蒸馏得到的学生模型表现达到最佳。
如图2所示,该基于多教师模型的模型蒸馏方法包括步骤S101至步骤S105。
S101、获取训练样本数据和所述训练样本数据对应的硬标签,通过多个教师模型对所述训练样本数据进行识别,得到多个第一软标签,以及通过第一学生模型对所述训练样本数据进行识别,得到第二软标签。
其中,所述训练样本数据为用于训练学生模型参数的样本数据,具体可以为不同的句子对或不同的图片等等。所述教师模型是一种复杂但推理性能优越的模型,所述学生模型为精简且低复杂度的模型。
本申请实施例可以基于人工智能技术对相关的数据进行获取和处理。其中,人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。
人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、机器人技术、生物识别技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
具体地,对所述训练样本数据进行识别是通过建立了一个以softmax为损失函数的神经网络进行识别。其中,所述训练样本数据对应的硬标签为所述训练样本数据对应的真实标签,在经过以softmax为损失函数的神经网络进行识别后,其表示为0或1,0表征该训练样本数据不属于某一分类,1表征该训练样本数据属于某一分类。
具体地,可以通过全部的教师模型均对所述训练样本数据进行识别,得到多个第一软标签,可以通过第一学生模型对所述训练样本数据进行识别,得到第二软标签。其中,所述第一软标签为所选择的教师模型对于训练样本数据属于某一类别对应的概率,其数值介于0-1之间。所述第一学生模型为未经过知识蒸馏的学生模型,所述第二软标签为第一学生模型对于训练样本数据属于某一类别对应的概率。
需要说明的是,每个训练样本数据对应有硬标签,每个教师模型对于每个训练样本数据对应有第一软标签,第一学生模型对于每个训练样本数据对应有第二软标签。
示例性的,假设有一个关于汽车的分类任务,需要对训练样本数据进行分类,比如在A品牌汽车,B品牌汽车,自行车之中确定,且该训练样本数据为B品牌汽车,则对应的硬标签为[0,1,0]。若通过教师模型对该训练样本数据进行识别,得到对应的第一软标签可能为[0.09,0.9,0.01]。若通过第一学生模型对该训练样本数据进行识别,得到对应的第二软标签可能为[0.4,0.5,0.1]。
S102、根据所述硬标签、所述多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第二学生模型;其中,所述第一学生模型的模型参数和所述第二学生模型的模型参数不同。
其中,知识蒸馏是指用准确度较高但结构复杂的老师模型指导训练准确率较低但结构简单的学生模型的方法,所述模型参数可以包括超参数、模型层数、模型的参数数量等等。
示例性的,可以通过所述硬标签、所述多个第一软标签和所述第二软标签对所述第一学生模型进行知识蒸馏学习,初始化所述第一学生模型的模型参数,生成第二学生模型。
在一些实施例中,根据所述多个第一软标签和所述第二软标签确定第一偏差值;根据所述硬标签和所述第二软标签确定第二偏差值;根据所述第一偏差值和所述第二偏差值确定第三偏差值,并根据所述第三偏差值对所述第一学生模型的模型参数进行初始化处理,生成第二学生模型。其中,所述第一学生模型的模型参数和所述第二学生模型的模型参数不同,所述第一偏差值为多个第一软标签和第二软标签的平均损失值,所述第二偏差值为所述硬标签和所述第二软标签的损失值,所述第三偏差值为根据所述第一偏差值和所述第二偏差值调整得到的综合损失值。由此可以通过多个教师模型对学生模型进行知识蒸馏,从而初始化学生模型的模型参数。
具体地,根据所述多个第一软标签和所述第二软标签确定第一偏差值构建的损失函数为:
Figure BDA0003471491300000071
其中,k为所选教师模型的总数;yi,k,c为第k个教师模型预测的样本xi属于第c类的概率,即第k个教师模型对应的第一软标签;ps(yi=c|xi;θs)为第一学生模型预测的样本xj属于第c类的概率,即第一学生模型对应的第二软标签,其中θs为学生模型的模型参数。
具体地,根据所述硬标签和所述第二软标签确定第二偏差值构建的损失函教为:
Figure BDA0003471491300000072
其中,N[yi=c]为预测的样本xi对应的硬标签。
在一些实施例中,基于反向梯度传播算法,确定所述第一偏差值和所述第二偏差值对应的权重比例;根据所述第一偏差值和所述第二偏差值以及对应的权重比例确定第三偏差值。其中,所述反向梯度传播算法是一种用来训练人工神经网络的方法,该方法对网络中所有权重计算损失函数的梯度,这个梯度会反馈给最优化方法,用来更新权值以最小化损失函数。由此可以通过调整权重比例以平衡两类损失函数,从而使综合损失值最小化,达到初始化第一学生模型的模型参数的效果,生成第二学生模型。
具体地,根据所述第一偏差值和所述第二偏差值确定第三偏差值构建的损失函数为:
lKD=αlDL+(1-α)lCE
其中,lKD为第三偏差值,lDL为第一偏差值,lCE为第二偏差值,α为平衡两类损失函数的超参数,通过调整α的值可以调整蒸馏过程中学生模型对软硬标签的关注度,从而使综合损失值最小化,达到初始化第一学生模型的模型参数的效果,生成第二学生模型。
当神经网络仅采用硬标签的时候(即α=0时),损失了原始数据的信息,降低了模型对于数据的拟合难度,使得模型变得更加容易拟合,容易产生过拟合,导致模型的泛化能力下降。而采用软标签的时候(即α≠0时),模型需要学习更多的知识,比如学习两个接近的概率之间的相似性和差异性,从而增强模型的泛化能力。
通过最小化综合损失值,从而使第一学生模型能够平均地从所有教师中提取到知识,达到初始化第一学生模型的模型参数的效果,生成第二学生模型。
S103、通过所述第二学生模型对所述训练样本数据进行识别,得到第三软标签。
其中,所述第三软标签为第二学生模型对于训练样本数据属于某一类别对应的概率。
具体地,可以通过第二学生模型对所述训练样本数据进行识别,得到第三软标签。由此可以快速了解第二学生模型的识别准确率,能够快速确定第二学生模型的学习蒸馏情况,为后续的教师模型选择策略的更新作准备。
S104、根据所述第一软标签和所述第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,所述教师模型选择策略用于选择教师模型。
其中,所述教师模型选择策略用于选择教师模型对应的策略。
在一些实施例中,根据所述硬标签和所述第三软标签确定第四偏差值;根据所述第一软标签和所述第四偏差值生成状态向量参数;根据每个所述教师模型对应的状态向量参数确定更新策略;基于所述更新策略对所述教师模型选择策略进行更新,得到更新后的教师模型选择策略。由此可以根据教师模型和更新了模型参数的学生模型的表现对教师模型选择策略进行更新。
其中,所述第四偏差值为所述硬标签和所述第三软标签的损失值,所述状态向量参数包括第一软标签和第四偏差值,用于表示更新迭代过程中每次迭代对应的状态,具体地,所述状态向量参数可以使用词向量表示,词向量是自然语言领域中普遍的说法,实际上就是词的向量表示,由于计算机无法直接计算没有进过处理的词,因此采用词向量的形式表示。所述更新策略用于对教师模型选择策略进行更斯,具体可以包括筛选后的教师模型等。
具体地,根据所述硬标签和所述第三软标签确定第四偏差值构建的损失函数为:
Figure BDA0003471491300000091
生成状态向量参数之后,根据每个所述教师模型对应的状态向量参数确定更新策略;基于所述更新策略对所述教师模型选择策略进行更新,得到更新后的教师模型选择策略,以便重新选择对应的教师模型。
在一些实施例中,基于阈值函数,根据每个所述教师模型对应的状态向量参数计算得到所述教师模型选择策略中各个教师模型的分值;根据所述各个教师模型的分数对所述教师模型选择策略中的教师模型进行筛选,得到筛选后的教师模型;基于所述筛选后的教师模型确定更新策略。其中,所述阈值函数为可训练教师模型选择策略参数θ的sigmoid函数。Sigmoid函数是一个S型函数,用于将变量映射到[0,1]之间,所述分值为[0,1]之间的数值。
其中,具体可以用公式表示为:
π(sj,aj)=ajσ(AF(sj)+b)+(1-aj)(1-σ(AF(sj)+b))
其中,π(sj,aj)用于表征选或不选第j个教师模型,具体地,当π(sj,aj)=1时,选择第j个教师模型;当π(sj,aj)=0时,不选择第j个教师模型。sj为第j个状态,aj为0或者1的常数,具体根据所述状态向量参数的分布决定;σ(AF(sj)+b)为可训练教师模型选择策略参数θ的sigmoid函数,用于计算所述教师模型选择策略中各个教师模型的分值,其中F(sj)为状态向量参数,A和B可以为任意参数,具体通过实验或经验确定。
具体地,利用sigmoid函数将每个所述教师模型对应的状态向量参数映射得到对应的分值,确定所述分值是否超过预设的分值阈值;若所述分值超过预设的分值阈值,则将所述分值对应的教师模型筛选出来,得到筛选后的教师模型;若所述分值未超过预设的分值阈值,则不将所述分值对应的教师模型筛选出来。其中,所述预设的分值阈值可以为任意数值,在此不做具体限定。
示例性的,若预设的分值阈值为0.7,A模型对应的分值为0.8,B模型对应的分值为0.6,C模型对应的分值为0.75,则筛选后得到的教师模型为A模型和C模型。
在一些实施例中,通过测试集对所述第二学生模型进行测试,得到所述第二学生模型的准确率;根据所述第二学生模型的准确率确定对应的反馈策略;基于所述反馈策略和所述更新策略对所述教师模型选择策略进行更新,得到更新后的教师模型选择策略。
所述反馈策略包括三种,可以用公式表示为:
Figure BDA0003471491300000101
其中,γ为模型超参数,表示两部分的关注程度,具体由用户自行确定,accD为所述第二学生模型的准确率。
具体地,可以根据实际情况选择对应的反馈公式从而得到对应的反馈策略,还可以根据第二学生模型的准确率确定对应的反馈策略,比如当第二学生模型的准确率小于预设第一准确率阈值时,选择-lCE这个反馈策略;当第二学生模型的准确率不小于预设第一准确率阈值且不大于预设第二准确率阈值时,选择-lCE-lDL这个反馈策略;当第二学生模型的准确率大于预设第一准确率阈值时,选择γ*(-lCE-lDL)+(1-γ)*accD这个反馈策略。其中,所述第一准确率阈值小于第二准确率阈值,具体数值在此不做具体限定。
具体地,基于所述反馈策略和所述更新策略对所述教师模型选择策略进行更新可以用公式表示为:
Figure BDA0003471491300000102
其中,θ为教师模型选择策略参数,更新教师模型选择策略实际上就是更新教师模型选择策略参数;lr是学习率,所述学习率可以是预先设置的,也可以根据训练过程进行调整;
Figure BDA0003471491300000103
是向量微分算符,可以用于指代梯度算符。
具体地,在第一次迭代时,根据对应的反馈策略和更新策略,并通过预先设置的学习率相点乘,由此可以对教师模型选择策略参数进行更新,从而实现对教师模型选择策略的更新。在多次迭代时,将每一次迭代得到的反馈策略进行累加,将每个筛选得到的教师模型求梯度并累加,最后通过预先设置的学习率相点乘,由此可以在多次迭代后对教师模型选择策略参数进行动态更新,从而实现对教师模型选择策略的动态更新。
S105、基于所述更新后的教师模型选择策略重新确定对应的教师模型,并根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,直至所述第一学生模型收敛,得到目标学生模型。
其中,所述目标学生模型为经过多次迭代更新教师模型选择策略后,通过教师模型选择策略对应的教师模型蒸馏得到的表现最好的学生模型。所述收敛就是说指蒸馏得到的模型不再产生大的波动。
在一些实施例中,根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,生成第三学生模型;获取所述第三学生模型对应的第三偏差值,确定所述第三偏差值是否满足收敛条件;若所述第三偏差值满足收敛条件,则将所述第三学生模型作为目标学生模型。
具体地,通过所述重新确定的教师模型对所述训练样本数据进行识别,重新得到多个第一软标签;根据所述硬标签、重新得到的多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第三学生模型;其中,所述第三学生模型与第一学生模型和第二学生模型的模型参数不同。
同样的,根据所述多个第一软标签和所述第二软标签确定第一偏差值;根据所述硬标签和所述第二软标签确定第二偏差值;根据所述第一偏差值和所述第二偏差值确定第三偏差值,并根据所述第三偏差值对所述第一学生模型的模型参数进行初始化处理,生成第三学生模型。由于教师模型是经过重新确定的,因此每次迭代生成的学生模型对应的第三偏差值均不相同。
确定所述第三偏差值是否满足收敛条件具体可以通过确定所述第三偏差值是否小于预设偏差值来确定;若所述第三偏差值小于预设偏差值,则确定所述第三偏差值满足收敛条件,即第三学生模型能够作为目标学生模型;若所述第三偏差值不小于预设偏差值,则需要重新更新教师模型选择策略,从而调整蒸馏得到的学生模型的模型参数。其中,所述预设偏差值可以为任意数值,在此不做具体限定,由此可以控制综合损失值在一个很小的范围内,使蒸馏得到的学生模型不再有发生突变的情况。
具体地,获取所述第三学生模型对应的第三偏差值,确定所述第三偏差值是否满足收敛条件;若所述第三偏差值满足收敛条件,则将所述第三学生模型作为目标学生模型;若所述第三偏差值不满足收敛条件,则确定所述第三偏差值的损失程度,根据该损失程度对第三学生模型中的参数进行调整。
示例性的,比如损失程度越大,则对预设学生模型中的参数的调整越大;损失程度越小,则对预设学生模型中的参数的调整越小。这样,基于损失值对预设的学生模型进行调整,可以实现在学生模型的错误程度越大时,进行更大程度的调整,进而提高学生模型的收敛速度,提高训练效率,同时,也使得对学生模型的调整操作更加精准,进而提高学生模型训练的精度。
请参阅图3,图3是本申请一实施例提供的一种基于多教师模型的模型蒸馏装置的示意性框图,该基于多教师模型的模型蒸馏装置可以配置于服务器中,用于执行前述的基于多教师模型的模型蒸馏方法。
如图3所示,该基于多教师模型的模型蒸馏装置200包括:第一标签生成模块201、模型生成模块202、第二标签生成模块203、策略更新模块204和模型确定模块205。
第一标签生成模块201,用于获取训练样本数据和所述训练样本数据对应的硬标签,通过多个教师模型对所述训练样本数据进行识别,得到多个第一软标签,以及通过第一学生模型对所述训练样本数据进行识别,得到第二软标签;
模型生成模块202,用于根据所述硬标签、所述多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第二学生模型;其中,所述第一学生模型的模型参数和所述第二学生模型的模型参数不同;
第二标签生成模块203,用于通过所述第二学生模型对所述训练样本数据进行识别,得到第三软标签;
策略更新模块204,用于根据所述第一软标签和所述第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,所述教师模型选择策略用于选择教师模型;
模型确定模块205,用于基于所述更新后的教师模型选择策略重新确定对应的教师模型,并根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,直至所述第一学生模型收敛,得到目标学生模型。
需要说明的是,所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,上述描述的装置和各模块、单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
本申请的方法、装置可用于众多通用或专用的计算系统环境或配置中。例如:个人计算机、服务器计算机、手持设备或便携式设备、平板型设备、多处理器系统、基于微处理器的系统、机顶盒、可编程的消费终端设备、网络PC、小型计算机、大型计算机、包括以上任何系统或设备的分布式计算环境等等。
示例性的,上述的方法、装置可以实现为一种计算机程序的形式,该计算机程序可以在如图4所示的计算机设备上运行。
请参阅图4,图4是本申请实施例提供的一种计算机设备的示意图。该计算机设备可以是服务器。
如图4所示,该计算机设备包括通过系统总线连接的处理器、存储器和网络接口,其中,存储器可以包括非易失性存储介质和内存储器。
非易失性存储介质可存储操作系统和计算机程序。该计算机程序包括程序指令,该程序指令被执行时,可使得处理器执行任意一种基于多教师模型的模型蒸馏方法。
处理器用于提供计算和控制能力,支撑整个计算机设备的运行。
内存储器为非易失性存储介质中的计算机程序的运行提供环境,该计算机程序被处理器执行时,可使得处理器执行任意一种基于多教师模型的模型蒸馏方法。
该网络接口用于进行网络通信,如发送分配的任务等。本领域技术人员可以理解,该计算机设备的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
应当理解的是,处理器可以是中央处理单元(Central Processing Unit,CPU),该处理器还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。其中,通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
其中,在一些实施方式中,所述处理器用于运行存储在存储器中的计算机程序,以实现如下步骤:获取训练样本数据和所述训练样本数据对应的硬标签,通过多个教师模型对所述训练样本数据进行识别,得到多个第一软标签,以及通过第一学生模型对所述训练样本数据进行识别,得到第二软标签;
根据所述硬标签、所述多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第二学生模型;其中,所述第一学生模型的模型参数和所述第二学生模型的模型参数不同;通过所述第二学生模型对所述训练样本数据进行识别,得到第三软标签;根据所述第一软标签和所述第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,所述教师模型选择策略用于选择教师模型;基于所述更新后的教师模型选择策略重新确定对应的教师模型,并根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,直至所述第一学生模型收敛,得到目标学生模型。
在一些实施例中,所述处理器还用于:根据所述多个第一软标签和所述第二软标签确定第一偏差值;根据所述硬标签和所述第二软标签确定第二偏差值;根据所述第一偏差值和所述第二偏差值确定第三偏差值,并根据所述第三偏差值对所述第一学生模型的模型参数进行初始化处理,生成第二学生模型。
在一些实施例中,所述处理器还用于:基于反向梯度传播算法,确定所述第一偏差值和所述第二偏差值对应的权重比例;根据所述第一偏差值和所述第二偏差值以及对应的权重比例确定第三偏差值。
在一些实施例中,所述处理器还用于:根据所述硬标签和所述第三软标签确定第四偏差值;根据多个所述第一软标签和所述第四偏差值生成每个所述教师模型对应的状态向量参数;根据每个所述教师模型对应的状态向量参数确定更新策略;基于所述更新策略对所述教师模型选择策略进行更新,得到更新后的教师模型选择策略。
在一些实施例中,所述处理器还用于:基于阈值函数,根据每个所述教师模型对应的状态向量参数计算得到所述教师模型选择策略中各个教师模型的分值;根据所述各个教师模型的分值对所述教师模型选择策略中的教师模型进行筛选,得到筛选后的教师模型;基于所述筛选后的教师模型确定更新策略。
在一些实施例中,所述处理器还用于:通过测试集对所述第二学生模型进行测试,得到所述第二学生模型的准确率;根据所述第二学生模型的准确率确定对应的反馈策略;基于所述反馈策略和所述更新策略对所述教师模型选择策略进行更新,得到更新后的教师模型选择策略。
在一些实施例中,所述处理器还用于:根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,生成第三学生模型;获取所述第三学生模型对应的第三偏差值,确定所述第三偏差值是否满足收敛条件;若所述第三偏差值满足收敛条件,则将所述第三学生模型作为目标学生模型。
本申请实施例还提供一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序,所述计算机程序中包括程序指令,所述程序指令被执行时实现本申请实施例提供的任一种基于多教师模型的模型蒸馏方法。
其中,所述计算机可读存储介质可以是前述实施例所述的计算机设备的内部存储单元,例如所述计算机设备的硬盘或内存。所述计算机可读存储介质也可以是所述计算机设备的外部存储设备,例如所述计算机设备上配备的插接式硬盘,智能存储卡(SmartMedia Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。
进一步地,所述计算机可读存储介质可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序等;存储数据区可存储根据区块链节点的使用所创建的数据等。
本发明所指区块链语言模型的存储、点对点传输、共识机制、加密算法等计算机技术的新型应用模式。区块链(Blockchain),本质上是一个去中心化的数据库,是一串使用密码学方法相关联产生的数据块,每一个数据块中包含了一批次网络交易的信息,用于验证其信息的有效性(防伪)和生成下一个区块。区块链可以包括区块链底层平台、平台产品服务层以及应用服务层等。
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以权利要求的保护范围为准。

Claims (10)

1.一种基于多教师模型的模型蒸馏方法,其特征在于,所述方法包括:
获取训练样本数据和所述训练样本数据对应的硬标签,通过多个教师模型对所述训练样本数据进行识别,得到多个第一软标签,以及通过第一学生模型对所述训练样本数据进行识别,得到第二软标签;
根据所述硬标签、所述多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第二学生模型;其中,所述第一学生模型的模型参数和所述第二学生模型的模型参数不同;
通过所述第二学生模型对所述训练样本数据进行识别,得到第三软标签;
根据所述第一软标签和所述第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,所述教师模型选择策略用于选择教师模型;
基于所述更新后的教师模型选择策略重新确定对应的教师模型,并根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,直至所述第一学生模型收敛,得到目标学生模型。
2.根据权利要求1所述的方法,其特征在于,所述根据所述硬标签、所述多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第二学生模型,包括:
根据所述多个第一软标签和所述第二软标签确定第一偏差值;
根据所述硬标签和所述第二软标签确定第二偏差值;
根据所述第一偏差值和所述第二偏差值确定第三偏差值,并根据所述第三偏差值对所述第一学生模型的模型参数进行初始化处理,生成第二学生模型。
3.根据权利要求2所述的方法,其特征在于,所述根据所述第一偏差值和所述第二偏差值确定第三偏差值,包括:
基于反向梯度传播算法,确定所述第一偏差值和所述第二偏差值对应的权重比例;
根据所述第一偏差值和所述第二偏差值以及对应的权重比例确定第三偏差值。
4.根据权利要求1所述的方法,其特征在于,所述根据所述第一软标签和所述第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,包括:
根据所述硬标签和所述第三软标签确定第四偏差值;
根据多个所述第一软标签和所述第四偏差值生成每个所述教师模型对应的状态向量参数;
根据每个所述教师模型对应的状态向量参数确定更新策略;
基于所述更新策略对所述教师模型选择策略进行更新,得到更新后的教师模型选择策略。
5.根据权利要求4所述的方法,其特征在于,所述根据每个所述教师模型对应的状态向量参数确定更新策略,包括:
基于阈值函数,根据每个所述教师模型对应的状态向量参数计算得到所述教师模型选择策略中各个教师模型的分值;
根据所述各个教师模型的分值对所述教师模型选择策略中的教师模型进行筛选,得到筛选后的教师模型;
基于所述筛选后的教师模型确定更新策略。
6.根据权利要求4所述的方法,其特征在于,所述根据每个所述教师模型对应的状态向量参数确定更新策略之后,所述方法还包括:
通过测试集对所述第二学生模型进行测试,得到所述第二学生模型的准确率;
根据所述第二学生模型的准确率确定对应的反馈策略;
基于所述反馈策略和所述更新策略对所述教师模型选择策略进行更新,得到更新后的教师模型选择策略。
7.根据权利要求1所述的方法,其特征在于,所述根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,直至所述第一学生模型收敛,得到目标学生模型,包括:
根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,生成第三学生模型;
获取所述第三学生模型对应的第三偏差值,确定所述第三偏差值是否满足收敛条件;
若所述第三偏差值满足收敛条件,则将所述第三学生模型作为目标学生模型。
8.一种基于多教师模型的模型蒸馏装置,其特征在于,包括:
第一标签生成模块,用于获取训练样本数据和所述训练样本数据对应的硬标签,通过多个教师模型对所述训练样本数据进行识别,得到多个第一软标签,以及通过第一学生模型对所述训练样本数据进行识别,得到第二软标签;
模型生成模块,用于根据所述硬标签、所述多个第一软标签和所述第二软标签,对所述第一学生模型进行知识蒸馏学习,生成第二学生模型;其中,所述第一学生模型的模型参数和所述第二学生模型的模型参数不同;
第二标签生成模块,用于通过所述第二学生模型对所述训练样本数据进行识别,得到第三软标签;
策略更新模块,用于根据所述第一软标签和所述第三软标签更新教师模型选择策略,得到更新后的教师模型选择策略,所述教师模型选择策略用于选择教师模型;
模型确定模块,用于基于所述更新后的教师模型选择策略重新确定对应的教师模型,并根据所述重新确定的教师模型对所述第一学生模型进行知识蒸馏学习,直至所述第一学生模型收敛,得到目标学生模型。
9.一种计算机设备,其特征在于,所述计算机设备包括存储器和处理器;
所述存储器,用于存储计算机程序;
所述处理器,用于执行所述的计算机程序并在执行所述的计算机程序时实现:
如权利要求1-7任一项所述的基于多教师模型的模型蒸馏方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时使所述处理器实现如权利要求1至7中任一项所述的基于多教师模型的模型蒸馏方法。
CN202210044224.7A 2022-01-14 2022-01-14 基于多教师模型的模型蒸馏方法、装置、设备及存储介质 Pending CN114386604A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210044224.7A CN114386604A (zh) 2022-01-14 2022-01-14 基于多教师模型的模型蒸馏方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210044224.7A CN114386604A (zh) 2022-01-14 2022-01-14 基于多教师模型的模型蒸馏方法、装置、设备及存储介质

Publications (1)

Publication Number Publication Date
CN114386604A true CN114386604A (zh) 2022-04-22

Family

ID=81202099

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210044224.7A Pending CN114386604A (zh) 2022-01-14 2022-01-14 基于多教师模型的模型蒸馏方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN114386604A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117391900A (zh) * 2023-11-23 2024-01-12 重庆第二师范学院 基于大数据分析的学习效率检测系统及方法
WO2024087468A1 (zh) * 2022-10-25 2024-05-02 京东城市(北京)数字科技有限公司 类别预测模型的训练方法、预测方法、设备和存储介质

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2024087468A1 (zh) * 2022-10-25 2024-05-02 京东城市(北京)数字科技有限公司 类别预测模型的训练方法、预测方法、设备和存储介质
CN117391900A (zh) * 2023-11-23 2024-01-12 重庆第二师范学院 基于大数据分析的学习效率检测系统及方法
CN117391900B (zh) * 2023-11-23 2024-05-24 重庆第二师范学院 基于大数据分析的学习效率检测系统及方法

Similar Documents

Publication Publication Date Title
Swathi et al. An optimal deep learning-based LSTM for stock price prediction using twitter sentiment analysis
CN111859960B (zh) 基于知识蒸馏的语义匹配方法、装置、计算机设备和介质
WO2023065545A1 (zh) 风险预测方法、装置、设备及存储介质
Zhang et al. MOOCRC: A highly accurate resource recommendation model for use in MOOC environments
CN114386604A (zh) 基于多教师模型的模型蒸馏方法、装置、设备及存储介质
CN112380344A (zh) 文本分类的方法、话题生成的方法、装置、设备及介质
CN113706151A (zh) 一种数据处理方法、装置、计算机设备及存储介质
CN112269875B (zh) 文本分类方法、装置、电子设备及存储介质
CN113821527A (zh) 哈希码的生成方法、装置、计算机设备及存储介质
CN112785005A (zh) 多目标任务的辅助决策方法、装置、计算机设备及介质
US11568263B2 (en) Techniques to perform global attribution mappings to provide insights in neural networks
CN114880449B (zh) 智能问答的答复生成方法、装置、电子设备及存储介质
CN113761375A (zh) 基于神经网络的消息推荐方法、装置、设备及存储介质
Hao et al. Sentiment recognition and analysis method of official document text based on BERT–SVM model
Yousefnezhad et al. A new selection strategy for selective cluster ensemble based on diversity and independency
CN115713386A (zh) 一种多源信息融合的商品推荐方法及系统
Stein et al. Applying QNLP to sentiment analysis in finance
Das A new technique for classification method with imbalanced training data
CN116340458A (zh) 一种相似司法案例匹配方法、装置及设备
CN113268563B (zh) 基于图神经网络的语义召回方法、装置、设备及介质
Vo et al. Development of a fake news detection tool for Vietnamese based on deep learning techniques
CN114529063A (zh) 一种基于机器学习的金融领域数据预测方法、设备及介质
Xiaohui An adaptive genetic algorithm-based background elimination model for English text
Karpov et al. Elimination of negative circuits in certain neural network structures to achieve stable solutions
Liu et al. Novel Uncertainty Quantification through Perturbation-Assisted Sample Synthesis

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination