CN112101526A - 基于知识蒸馏的模型训练方法及装置 - Google Patents

基于知识蒸馏的模型训练方法及装置 Download PDF

Info

Publication number
CN112101526A
CN112101526A CN202010965719.4A CN202010965719A CN112101526A CN 112101526 A CN112101526 A CN 112101526A CN 202010965719 A CN202010965719 A CN 202010965719A CN 112101526 A CN112101526 A CN 112101526A
Authority
CN
China
Prior art keywords
model
data set
training
loss function
data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010965719.4A
Other languages
English (en)
Inventor
宿绍勋
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
BOE Technology Group Co Ltd
Original Assignee
BOE Technology Group Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by BOE Technology Group Co Ltd filed Critical BOE Technology Group Co Ltd
Priority to CN202010965719.4A priority Critical patent/CN112101526A/zh
Publication of CN112101526A publication Critical patent/CN112101526A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Feedback Control In General (AREA)

Abstract

本发明公开了一种基于知识蒸馏的模型训练方法及装置,涉及知识蒸馏技术领域,主要目的在于提高student模型的预测精度。本发明主要的技术方案为:利用第一数据集训练第一模型;基于数据增强技术对所述第一数据集进行扩展,得到第二数据集;利用所述第一模型、第一数据集与第二数据集训练第二模型,确定所述第二模型的损失函数。本发明用于训练高精度的student模型。

Description

基于知识蒸馏的模型训练方法及装置
技术领域
本发明涉及知识蒸馏技术领域,尤其涉及一种基于知识蒸馏的模型训练方法及装置。
背景技术
知识蒸馏的概念首次提出于文章《Distilling the Knowledge in a NeuralNetwork》中,通过引入教师网络用以诱导学生网络的训练,实现知识迁移。因此,知识蒸馏是将一个网络的知识转移到另一个网络,两个网络可以是同构或者异构。做法是先训练一个teacher模型,然后使用这个teacher模型的输出和数据的真实标签去训练student模型。知识蒸馏,可以用来将网络从大网络转化成一个小网络,并保留接近于大网络的性能,以此解决模型在边缘段的部署硬件不足的问题。
但是,在知识蒸馏的过程中,由于student模型的结构更为简单,使得该student模型的精度只能是尽可能接近teacher模型的精度,这就使得在实际应用过程中,基于知识蒸馏得到的student模型的精度无法满足应用需求。
发明内容
鉴于上述问题,本发明提出了一种基于知识蒸馏的模型训练方法及装置,主要目的在于提高student模型的预测精度。
为达到上述目的,本发明主要提供如下技术方案:
第一方面,本发明提供一种基于知识蒸馏的模型训练方法,包括:
利用第一数据集训练第一模型;
基于数据增强技术对所述第一数据集进行扩展,得到第二数据集;
利用所述第一模型、第一数据集与第二数据集训练第二模型,确定所述第二模型的损失函数。
优选的,利用所述第一模型、第一数据集与第二数据集训练第二模型,确定所述第二模型的损失函数,包括:
利用第一模型的逻辑输出训练所述第二模型,得到第一损失函数;
利用第一数据集与第二数据集训练所述第二模型,得到第二损失函数;
根据所述第一损失函数与第二损失函数确定所述第二模型的损失函数。
优选的,利用第一模型的逻辑输出训练所述第二模型,得到第一损失函数,包括:
将同一样本数据分别输入所述第一模型与第二模型,得到第一逻辑输出与第二逻辑输出;
基于所述第一逻辑输出与第二逻辑输出的均方误差确定为所述第一损失函数。
优选的,利用第一数据集与第二数据集训练所述第二模型,得到第二损失函数,包括:
利用第二模型对所述第一数据集与第二数据集中的数据进行预测,确定所述数据的预测标注信息;
根据所述数据的预测标注信息与所述数据携带的标注信息的交叉熵,确定所述第二损失函数。
优选的,根据所述第一损失函数与第二损失函数确定所述第二模型的损失函数,包括:
获取所述第一损失函数与第二损失函数的占比系数;
根据所述占比系数对所述第一损失函数与第二损失函数进行加权求和,确定所述第二模型的损失函数。
优选的,在利用第一数据集训练第一模型之前,所述方法还包括:
构建第一模型,所述第一模型的模型结构中含有第二模型的模型结构。
优选的,所述第二模型采用TextCNN模型,所述构建第一模型包括:
将BERT模型与TextCNN模型组合构建所述第一模型,其中,BERT模型的输出作为所述TextCNN模型的输入,所述第一模型的输入为所述BERT模型的输入,所述第一模型的输出为所述TextCNN模型的输出。
优选的,所述第一数据集中的数据为携带标注信息的语句,所述基于数据增强技术对所述第一数据集进行扩展,得到第二数据集包括:
利用BERT模型处理所述第一数据集中的语句,生成与所述语句含义相似的扩展语句;
在所述扩展语句上对应标注所述语句携带的标注信息,得到所述第二数据集。
优选的,所述方法还包括:
利用所述第二模型对目标数据集中的数据进行预测。
第二方面,本发明提供一种基于知识蒸馏的模型训练装置,所述装置包括:
训练单元,用于利用第一数据集训练第一模型;
扩展单元,用于基于数据增强技术对所述第一数据集进行扩展,得到第二数据集;
确定单元,用于利用所述训练单元得到的第一模型、所述第一数据集与所述扩展单元得到的第二数据集训练第二模型,确定所述第二模型的损失函数。
优选的,所述确定单元包括:
第一训练模块,用于利用第一模型的逻辑输出训练所述第二模型,得到第一损失函数;
第二训练模块,用于利用第一数据集与第二数据集训练所述第二模型,得到第二损失函数;
确定模块,用于根据所述第一训练模块得到的第一损失函数与所述第二训练模块得到的第二损失函数确定所述第二模型的损失函数。
优选的,所述第一训练模块具体用于:
将同一样本数据分别输入所述第一模型与第二模型,得到第一逻辑输出与第二逻辑输出;
基于所述第一逻辑输出与第二逻辑输出的均方误差确定为所述第一损失函数。
优选的,所述第二训练模块具体用于:
利用第二模型对所述第一数据集与第二数据集中的数据进行预测,确定所述数据的预测标注信息;
根据所述数据的预测标注信息与所述数据携带的标注信息的交叉熵,确定所述第二损失函数。
优选的,所述确定模块具体用于:
获取所述第一损失函数与第二损失函数的占比系数;
根据所述占比系数对所述第一损失函数与第二损失函数进行加权求和,确定所述第二模型的损失函数。
优选的,所述装置还包括:
构建单元,用于在训练单元利用第一数据集训练第一模型之前,构建第一模型,所述第一模型的模型结构中含有第二模型的模型结构。
优选的,所述第二模型采用TextCNN模型,所述构建单元具体用于,将BERT模型与TextCNN模型组合构建所述第一模型,其中,BERT模型的输出作为所述TextCNN模型的输入,所述第一模型的输入为所述BERT模型的输入,所述第一模型的输出为所述TextCNN模型的输出。
优选的,所述扩展单元包括:
生成模块,用于利用BERT模型处理所述第一数据集中的语句,生成与所述语句含义相似的扩展语句;
标注模块,用于在所述生成模块得到的扩展语句上对应标注所述语句携带的标注信息,得到所述第二数据集。
优选的,所述装置还包括:
预测单元,用于利用所述第二模型对目标数据集中的数据进行预测。
另一方面,本发明还提供一种处理器,所述处理器用于运行程序,其中,所述程序运行时执行上述第一方面的基于知识蒸馏的模型训练方法。
另一方面,本发明还提供一种存储介质,所述存储介质用于存储计算机程序,其中,所述计算机程序运行时控制所述存储介质所在设备执行上述第一方面的基于知识蒸馏的模型训练方法。
借由上述技术方案,本发明提供的一种基于知识蒸馏的模型训练方法及装置,将第一模型作为teacher模型,将第二模型作为student模型,在知识蒸馏过程中,通过对训练第一模型的第一数据集进行扩展,并结合扩展得到的第二数据集对第二模型进行训练,确定第二模型的损失函数,让第二模型通过更多数据的训练,达到提升模型预测精度的目的。相对于现有的知识蒸馏方式,本发明是将所扩展得到的第二数据集仅用于训练第二模型,而没有将该第二数据集用于训练第一模型,使得在确定第二函数的损失函数时,是利用由第一数据集训练的第一模型与第二数据集共同训练得到的,而两者之间的差异可以避免第二模型在学习第一模型时由于数据近似而导致的过拟合问题,同时,所扩展出的第二数据集也为第二模型提供了大量的训练样本,也使得第二模型的预测精度更加近似于第一模型。
上述说明仅是本发明技术方案的概述,为了能够更清楚了解本发明的技术手段,而可依照说明书的内容予以实施,并且为了让本发明的上述和其它目的、特征和优点能够更明显易懂,以下特举本发明的具体实施方式。
附图说明
通过阅读下文优选实施方式的详细描述,各种其他的优点和益处对于本领域普通技术人员将变得清楚明了。附图仅用于示出优选实施方式的目的,而并不认为是对本发明的限制。而且在整个附图中,用相同的参考符号表示相同的部件。在附图中:
图1示出了本发明实施例提出的一种基于知识蒸馏的模型训练方法的流程图;
图2示出了本发明实施例提出的另一种基于知识蒸馏的模型训练方法的流程图;
图3示出了本发明实施例中第一模型的模型结构;
图4示出了本发明实施例提出的基于知识蒸馏训练第二模型的流程框图;
图5示出了本发明实施例提出的一种基于知识蒸馏的模型训练装置的结构示意图;
图6示出了本发明实施例提出的另一种基于知识蒸馏的模型训练装置的结构示意图。
具体实施方式
下面将参照附图更详细地描述本发明的示例性实施例。虽然附图中显示了本发明的示例性实施例,然而应当理解,可以以各种形式实现本发明而不应被这里阐述的实施例所限制。相反,提供这些实施例是为了能够更透彻地理解本发明,并且能够将本发明的范围完整的传达给本领域的技术人员。
本发明实施例提供了一种基于知识蒸馏的模型训练方法,该方法相对于现有的知识蒸馏,可以使得student模型在知识蒸馏过程中在避免过拟合的情况下,提高模型的训练效果,让student模型的精度更加接近或达到teacher模型的效果。其具体执行步骤如图1所示,包括:
101、利用第一数据集训练第一模型。
本步骤是将经过训练的第一模型作为知识蒸馏过程中的teacher模型,而第一数据集为携带有标注信息的训练样本集。
其中,本发明实施例对于该第一模型的结构不做具体限定,可以根据业务需求而选择所需的模型。
102、基于数据增强技术对第一数据集进行扩展,得到第二数据集。
本步骤中,数据增强的目的是要得到与第一数据集中数据样本相近似的样本,以获取足够数量的数据样本。
具体的,本步骤可以是基于第一数据集中的数据样本进行修改,得到相似的数据样本,也可以是通过其他渠道获取数据样本,再从这些数据样本中找出与第一数据集中的数据样本相近似的,对此本实施例不做限定。
本实施例中,经过扩展得到的第二数据集中的数据样本不参与对第一模型的训练,仅用于对第二模型的训练,以此增加第一模型与第二模型训练的差异性。
103、利用第一模型、第一数据集与第二数据集训练第二模型,确定第二模型的损失函数。
本步骤中的第二模型是指知识蒸馏过程中的student模型,现有的知识蒸馏过程中,训练student模型的过程也是用第一数据集中的硬标签(即,数据样本所携带的标签或标注信息)来训练teacher模型,再利用该teacher模型中得到的软标签(即,经过softmax层的输出)结合第一数据集中的硬标签共同确定student模型的损失函数。而本发明实施例中区别于现有知识蒸馏过程的主要特征在于,利用第二数据集来训练第二模型,提高了第二模型的样本基数,同时,由于第一模型并未使用第二数据集训练,这就使得其得到的软标签的分布将会存在明显差异,而这种差异可以防止知识蒸馏过程中第二模型相对第一模型的过拟合。
基于上述图1的实现方式可以看出,本发明实施例所提出的基于知识蒸馏的模型训练方法,是利用数据增强技术对第一数据集进行扩展,并通过对第一数据集与第二数据集的差异化应用,防止第二模型在知识蒸馏过程中出现过拟合导致的精度下降问题。而通过增加用于训练第二模型的样本数量,也可以进一步提高第二模型与第一模型的相似程度,从而提高第二模型的预测精度。
进一步的,本发明的优先实施例是在上述图1的基础上,对各个步骤进行优化,以确保第二模型的预测精度通过知识蒸馏后能够更加接近于第一模型的预测精度。
具体的,本发明实施例以文本分类模型为例详细说明所提出的知识蒸馏过程,其中,第二模型(student模型)采用TextCNN模型,其具体步骤如图2所示,包括:
201、构建第一模型。
本实施例中的第一模型的模型结构需要根据第二模型的模型结构而确定。即,在第一模型的模型结构中要含有第二模型的模型结构。
而两个模型的模型结构相互结合需要根据具体的模型结构确定,其目的上两个模型在结构上相类似,可以让知识蒸馏得到更好的效果。
具体到本实施例中,由于第二模型为TextCNN模型,第一模型是以BERT模型为基础构建的模型,该第一模型的结构如图3所示,图中E1,E2,…,En为输入的字向量,Trm为BERT模型中的transformer结构,以上两个部分代表了BERT模型,将BERT模型最后一层softmax的逻辑输出传给TextCNN,由其输出结果T1,T2,…,Tn。可见,本实施例中的第一模型是将BERT模型与TextCNN模型进行串联,将BERT模型的输出作为TextCNN模型的输入,使得第一模型的输入为BERT模型的输入,第一模型的输出为TextCNN模型的输出。
202、利用第一数据集训练第一模型。
本实施例中,第一数据集中的数据样本为携带有标注信息的语句,其中,在文本分类应用的场景中,常见的标注信息如情感标识、意图标识等,对此本实施例不具体限定。
203、基于数据增强技术对第一数据集进行扩展,得到第二数据集。
具体的,本实施例中利用BERT模型处理所述第一数据集中的语句,生成与该语句含义相似的扩展语句。其中,BERT模型通过有监督训练,实现对输入语句的增强扩展,即利用Seq2Seq生成任务得到与输入语句相近似的扩展语句。一条语句所对应生成的扩展语句数量可根据需求进行自定义设置。
在生成扩展语句后,在扩展语句上对应标注输入语句所携带的标注信息,并将带有标注信息的扩展语句加入第二数据集,如此便可得到大量带有标注信息的新训练样本。
204、利用第一模型的逻辑输出训练所述第二模型,得到第一损失函数。
具体的,将同一样本数据分别输入第一模型与第二模型,得到第一逻辑输出与第二逻辑输出。其中,该样本数据取自第一数据集。
之后,基于第一逻辑输出与第二逻辑输出的均方误差确定为第一损失函数。
该第一损失函数可以表示为:
Figure BDA0002682204350000081
其中,LDS表示第一损失函数,z(T)、z(S)表示第一模型(teacher模型)和第二模型(student模型)的逻辑输出。
205、利用第一数据集与第二数据集训练第二模型,得到第二损失函数。
具体的,利用第二模型对第一数据集与第二数据集中的数据进行预测,确定该数据的预测标注信息,即,采用数据集中的数据样本对第二模型进行训练,得到第二模型的预测结果。
之后,根据数据的预测标注信息与该数据携带的标注信息的交叉熵,确定第二损失函数。即,计算数据样本的实际标注信息与其经过第二模型预测的预测结果的交叉熵,利用该交叉熵求第二损失函数。
该第二损失函数可以表示为:
Figure BDA0002682204350000091
其中,LCE表示第二损失函数,ti表示第i个数据样本实际携带的标注信息,yi表示第i个数据样本经过第二模型的预测结果,s表示第二模型(student模型),N表示数据样本的数量。
需要说明的是,在实际应用中,主要应用第二数据集中的数据样本来训练第二模型,确定第二损失函数。此外,本步骤与步骤204之间没有逻辑上的先后顺序关系。
206、根据第一损失函数与第二损失函数确定第二模型的损失函数。
具体的,获取所述第一损失函数与第二损失函数的占比系数。在本实施例中,由于损失函数是以第一损失函数与第二损失函数所构成,因此,确定其中一个占比系数(α),就可以确定另一个占比系数(即,1-α)。
之后,根据占比系数对第一损失函数与第二损失函数进行加权求和,确定第二模型的损失函数。
该第二模型的损失函数结合上述说明可具体表示为:
L=α·LCE+(1-α)LDS
综合上述说明,本发明实施例提出的基于知识蒸馏的模型训练过程可以通过图4表示,图4示出了基于知识蒸馏训练第二模型的流程框图,根据该训练过程,可以确定第二模型的损失函数,进而应用该第二模型对目标数据集中的数据进行预测,该目标数据集中的数据为将第二模型部署到设备端后的输入数据。
进一步的,作为对上述图1、2所示方法实施例的实现,本发明实施例提供了一种基于知识蒸馏的模型训练装置,该装置用于提高student模型的预测精度。该装置的实施例与前述方法实施例对应,为便于阅读,本实施例不再对前述方法实施例中的细节内容进行逐一赘述,但应当明确,本实施例中的装置能够对应实现前述方法实施例中的全部内容。具体如图5所示,该装置包括:
训练单元31,用于利用第一数据集训练第一模型;
扩展单元32,用于基于数据增强技术对所述第一数据集进行扩展,得到第二数据集;
确定单元33,用于利用所述训练单元31得到的第一模型、所述第一数据集与所述扩展单元32得到的第二数据集训练第二模型,确定所述第二模型的损失函数。
进一步的,如图6所示,所述确定单元33包括:
第一训练模块331,用于利用第一模型的逻辑输出训练所述第二模型,得到第一损失函数;
第二训练模块332,用于利用第一数据集与第二数据集训练所述第二模型,得到第二损失函数;
确定模块333,用于根据所述第一训练模块331得到的第一损失函数与所述第二训练模块332得到的第二损失函数确定所述第二模型的损失函数。
进一步的,所述第一训练模块331具体用于:
将同一样本数据分别输入所述第一模型与第二模型,得到第一逻辑输出与第二逻辑输出;
基于所述第一逻辑输出与第二逻辑输出的均方误差确定为所述第一损失函数。
进一步的,所述第二训练模块332具体用于:
利用第二模型对所述第一数据集与第二数据集中的数据进行预测,确定所述数据的预测标注信息;
根据所述数据的预测标注信息与所述数据携带的标注信息的交叉熵,确定所述第二损失函数。
进一步的,所述确定模块333具体用于:
获取所述第一损失函数与第二损失函数的占比系数;
根据所述占比系数对所述第一损失函数与第二损失函数进行加权求和,确定所述第二模型的损失函数。
进一步的,如图6所示,所述装置还包括:
构建单元34,用于在训练单元31利用第一数据集训练第一模型之前,构建第一模型,所述第一模型的模型结构中含有第二模型的模型结构。
进一步的,所述第二模型采用TextCNN模型,所述构建单元34具体用于,将BERT模型与TextCNN模型组合构建所述第一模型,其中,BERT模型的输出作为所述TextCNN模型的输入,所述第一模型的输入为所述BERT模型的输入,所述第一模型的输出为所述TextCNN模型的输出。
进一步的,如图6所示,所述扩展单元32包括:
生成模块321,用于利用BERT模型处理所述第一数据集中的语句,生成与所述语句含义相似的扩展语句;
标注模块322,用于在所述生成模块321得到的扩展语句上对应标注所述语句携带的标注信息,得到所述第二数据集。
进一步的,如图6所示,所述装置还包括:
预测单元35,用于利用所述第二模型对目标数据集中的数据进行预测。
进一步的,本发明实施例还提供一种处理器,所述处理器用于运行程序,其中,所述程序运行时执行上述图1-2中所述的基于知识蒸馏的模型训练方法。
进一步的,本发明实施例还提供一种存储介质,所述存储介质用于存储计算机程序,其中,所述计算机程序运行时控制所述存储介质所在设备执行上述图1-2中所述的基于知识蒸馏的模型训练方法。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述的部分,可以参见其他实施例的相关描述。
可以理解的是,上述方法及装置中的相关特征可以相互参考。另外,上述实施例中的“第一”、“第二”等是用于区分各实施例,而并不代表各实施例的优劣。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统,装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在此提供的算法和显示不与任何特定计算机、虚拟系统或者其它设备固有相关。各种通用系统也可以与基于在此的示教一起使用。根据上面的描述,构造这类系统所要求的结构是显而易见的。此外,本发明也不针对任何特定编程语言。应当明白,可以利用各种编程语言实现在此描述的本发明的内容,并且上面对特定语言所做的描述是为了披露本发明的最佳实施方式。
此外,存储器可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM),存储器包括至少一个存储芯片。
本领域内的技术人员应明白,本申请的实施例可提供为方法、系统、或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本申请是参照根据本申请实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
在一个典型的配置中,计算设备包括一个或多个处理器(CPU)、输入/输出接口、网络接口和内存。
存储器可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM)。存储器是计算机可读介质的示例。
计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括暂存电脑可读媒体(transitory media),如调制的数据信号和载波。
还需要说明的是,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、商品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、商品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括要素的过程、方法、商品或者设备中还存在另外的相同要素。
本领域技术人员应明白,本申请的实施例可提供为方法、系统或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
以上仅为本申请的实施例而已,并不用于限制本申请。对于本领域技术人员来说,本申请可以有各种更改和变化。凡在本申请的精神和原理之内所作的任何修改、等同替换、改进等,均应包含在本申请的权利要求范围之内。

Claims (10)

1.一种基于知识蒸馏的模型训练方法,所述方法包括:
利用第一数据集训练第一模型;
基于数据增强技术对所述第一数据集进行扩展,得到第二数据集;
利用所述第一模型、第一数据集与第二数据集训练第二模型,确定所述第二模型的损失函数。
2.根据权利要求1所述的方法,其特征在于,利用所述第一模型、第一数据集与第二数据集训练第二模型,确定所述第二模型的损失函数,包括:
利用第一模型的逻辑输出训练所述第二模型,得到第一损失函数;
利用第一数据集与第二数据集训练所述第二模型,得到第二损失函数;
根据所述第一损失函数与第二损失函数确定所述第二模型的损失函数。
3.根据权利要求2所述的方法,其特征在于,利用第一模型的逻辑输出训练所述第二模型,得到第一损失函数,包括:
将同一样本数据分别输入所述第一模型与第二模型,得到第一逻辑输出与第二逻辑输出;
基于所述第一逻辑输出与第二逻辑输出的均方误差确定为所述第一损失函数。
4.根据权利要求2所述的方法,其特征在于,利用第一数据集与第二数据集训练所述第二模型,得到第二损失函数,包括:
利用第二模型对所述第一数据集与第二数据集中的数据进行预测,确定所述数据的预测标注信息;
根据所述数据的预测标注信息与所述数据携带的标注信息的交叉熵,确定所述第二损失函数。
5.根据权利要求2所述的方法,其特征在于,根据所述第一损失函数与第二损失函数确定所述第二模型的损失函数,包括:
获取所述第一损失函数与第二损失函数的占比系数;
根据所述占比系数对所述第一损失函数与第二损失函数进行加权求和,确定所述第二模型的损失函数。
6.根据权利要求1所述的方法,其特征在于,在利用第一数据集训练第一模型之前,所述方法还包括:
构建第一模型,所述第一模型的模型结构中含有第二模型的模型结构。
7.根据权利要求6所述的方法,其特征在于,所述第二模型采用TextCNN模型,所述构建第一模型包括:
将BERT模型与TextCNN模型组合构建所述第一模型,其中,BERT模型的输出作为所述TextCNN模型的输入,所述第一模型的输入为所述BERT模型的输入,所述第一模型的输出为所述TextCNN模型的输出。
8.一种基于知识蒸馏的模型训练装置,所述装置包括:
训练单元,用于利用第一数据集训练第一模型;
扩展单元,用于基于数据增强技术对所述第一数据集进行扩展,得到第二数据集;
确定单元,用于利用所述训练单元得到的第一模型、所述第一数据集与所述扩展单元得到的第二数据集训练第二模型,确定所述第二模型的损失函数。
9.一种处理器,其特征在于,所述处理器用于运行程序,其中,所述程序运行时执行权利要求1-7中任意一项所述的基于知识蒸馏的模型训练方法。
10.一种存储介质,其特征在于,所述存储介质用于存储计算机程序,其中,所述计算机程序运行时控制所述存储介质所在设备执行权利要求1-7中任意一项所述的基于知识蒸馏的模型训练方法。
CN202010965719.4A 2020-09-15 2020-09-15 基于知识蒸馏的模型训练方法及装置 Pending CN112101526A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010965719.4A CN112101526A (zh) 2020-09-15 2020-09-15 基于知识蒸馏的模型训练方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010965719.4A CN112101526A (zh) 2020-09-15 2020-09-15 基于知识蒸馏的模型训练方法及装置

Publications (1)

Publication Number Publication Date
CN112101526A true CN112101526A (zh) 2020-12-18

Family

ID=73758590

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010965719.4A Pending CN112101526A (zh) 2020-09-15 2020-09-15 基于知识蒸馏的模型训练方法及装置

Country Status (1)

Country Link
CN (1) CN112101526A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112686046A (zh) * 2021-01-06 2021-04-20 上海明略人工智能(集团)有限公司 模型训练方法、装置、设备及计算机可读介质
CN112988975A (zh) * 2021-04-09 2021-06-18 北京语言大学 一种基于albert和知识蒸馏的观点挖掘方法
CN113204633A (zh) * 2021-06-01 2021-08-03 吉林大学 一种语义匹配蒸馏方法及装置
CN114663714A (zh) * 2022-05-23 2022-06-24 阿里巴巴(中国)有限公司 图像分类、地物分类方法和装置
CN115309849A (zh) * 2022-06-27 2022-11-08 北京邮电大学 一种基于知识蒸馏的特征提取方法、装置及数据分类方法

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112686046A (zh) * 2021-01-06 2021-04-20 上海明略人工智能(集团)有限公司 模型训练方法、装置、设备及计算机可读介质
CN112988975A (zh) * 2021-04-09 2021-06-18 北京语言大学 一种基于albert和知识蒸馏的观点挖掘方法
CN113204633A (zh) * 2021-06-01 2021-08-03 吉林大学 一种语义匹配蒸馏方法及装置
CN113204633B (zh) * 2021-06-01 2022-12-30 吉林大学 一种语义匹配蒸馏方法及装置
CN114663714A (zh) * 2022-05-23 2022-06-24 阿里巴巴(中国)有限公司 图像分类、地物分类方法和装置
CN115309849A (zh) * 2022-06-27 2022-11-08 北京邮电大学 一种基于知识蒸馏的特征提取方法、装置及数据分类方法

Similar Documents

Publication Publication Date Title
CN112101526A (zh) 基于知识蒸馏的模型训练方法及装置
CN110287477B (zh) 实体情感分析方法及相关装置
CN111738016B (zh) 多意图识别方法及相关设备
Firdaus et al. A deep multi-task model for dialogue act classification, intent detection and slot filling
CN111783993A (zh) 智能标注方法、装置、智能平台及存储介质
CN113837370B (zh) 用于训练基于对比学习的模型的方法和装置
CN112711660A (zh) 文本分类样本的构建方法和文本分类模型的训练方法
CN111078881B (zh) 细粒度情感分析方法、系统、电子设备和存储介质
US20210150270A1 (en) Mathematical function defined natural language annotation
CN116579345B (zh) 命名实体识别模型的训练方法、命名实体识别方法及装置
CN111062204B (zh) 基于机器学习的文本标点符号使用错误的识别方法和装置
CN114548102A (zh) 实体文本的序列标注方法、装置及计算机可读存储介质
CN110851600A (zh) 基于深度学习的文本数据处理方法及装置
CN108460453B (zh) 一种用于ctc训练的数据处理方法、装置及系统
CN111126066B (zh) 基于神经网络的中文修辞手法的确定方法和装置
CN113536790A (zh) 基于自然语言处理的模型训练方法及装置
Kumar et al. Domain adaptation based technique for image emotion recognition using image captions
Singh et al. Building Machine Learning System with Deep Neural Network for Text Processing
CN112541341A (zh) 一种文本事件元素提取方法
Zagagy et al. ACKEM: automatic classification, using KNN based ensemble modeling
Modran et al. Learning Methods Based on Artificial Intelligence in Educating Engineers for the New Jobs of the 5 th Industrial Revolution
CN112579768A (zh) 一种情感分类模型训练方法、文本情感分类方法及装置
CN111860508A (zh) 图像样本选择方法及相关设备
CN117574878B (zh) 用于混合领域的成分句法分析方法、装置及介质
CN113849592B (zh) 文本情感分类方法及装置、电子设备、存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination