CN115374278A - 文本处理模型蒸馏方法、装置、计算机设备及介质 - Google Patents
文本处理模型蒸馏方法、装置、计算机设备及介质 Download PDFInfo
- Publication number
- CN115374278A CN115374278A CN202210948994.4A CN202210948994A CN115374278A CN 115374278 A CN115374278 A CN 115374278A CN 202210948994 A CN202210948994 A CN 202210948994A CN 115374278 A CN115374278 A CN 115374278A
- Authority
- CN
- China
- Prior art keywords
- label
- model
- original data
- preset
- distillation
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Abstract
本发明公开了一种文本处理模型蒸馏方法,该方法包括:获取原始数据以及原始数据对应的原始标签,对第一预设模型进行训练,得到第一预测模型;通过第一预测模型对原始数据进行预测,得到预测标签,并将所有预测标签分为目标标签以及未达标标签;根据未达标标签和未达标标签对应的原始数据对第二预设模型进行训练,得到第二预测模型,并对未达标标签进行优化处理,得到优化标签;通过原始数据、原始标签、目标标签和优化标签对第三预设模型进行蒸馏学习,得到文本处理模型。本发明通过一次蒸馏将第一预测模型和第二预测模型的预测标签蒸馏到文本处理模型中,进而提升了文本处理模型预测的准确性,以及提高了文本处理模型蒸馏的效率。
Description
技术领域
本发明涉及预测模型技术领域,尤其涉及一种文本处理模型蒸馏方法、装置、计算机设备及介质。
背景技术
随着科学技术的发展,自然语言处理技术也逐渐应用在不同的领域当中。例如,关键词抽取,实体识别或者短语抽取等技术。这些技术往往需要通过训练模型的方法进行实现。例如训练文本处理模型对文本进行关键词抽取等。
现有技术中,往往在通过标注标签的文本数据对文本处理模型进行训练。针对预测效果较差的文本数据,往往需要对文本处理模型进行参数调整。如此,会导致调整后的文本处理模型无法保持对前一轮预测效果较好的文本数据的预测能力。进而导致训练得到的文本处理模型的文本处理准确率较低。
发明内容
本发明实施例提供一种文本处理模型蒸馏方法、装置、计算机设备及介质,以解决现有技术中模型预测的准确性较低和模型训练效率低的问题。
一种文本处理模型蒸馏方法,包括:
获取原始数据以及原始数据对应的原始标签,根据所述原始数据和所述原始标签对第一预设模型进行训练,得到第一预测模型;
通过所述第一预测模型对所述原始数据进行预测,得到预测标签,并将所有所述预测标签分为目标标签以及未达标标签;
根据所述未达标标签和所述未达标标签对应的所述原始数据对第二预设模型进行训练,得到第二预测模型,并通过所述第二预测模型对所述未达标标签进行优化处理,得到优化标签;
通过所述原始数据、所述原始标签、所述目标标签和所述优化标签对第三预设模型进行蒸馏学习,得到文本处理模型。
一种文本处理模型蒸馏装置,包括:
预测模块,用于获取原始数据以及原始数据对应的原始标签,根据所述原始数据和所述原始标签对第一预设模型进行训练,得到第一预测模型;
分类模块,用于通过所述第一预测模型对所述原始数据进行预测,得到预测标签,并将所有所述预测标签分为目标标签以及未达标标签;
优化模块,用于根据所述未达标标签和所述未达标标签对应的所述原始数据对第二预设模型进行训练,得到第二预测模型,并通过所述第二预测模型对所述未达标标签进行优化处理,得到优化标签;
蒸馏模块,用于通过所述原始数据、所述原始标签、所述目标标签和所述优化标签对第三预设模型进行蒸馏学习,得到文本处理模型。
一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述文本处理模型蒸馏方法。
一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述文本处理模型蒸馏方法。
本发明提供一种文本处理模型蒸馏方法、装置、计算机设备及存储介质,该方法通过训练得到的第一预测模型对原始数据进行预测,从而使得第一预测模型学习到预测效果较好的标签(也即目标标签)中的特征。将第一预测模型预测效果较差的标签(也即未达标标签)交由第二预测模型进行学习。如此,即可通过两个不同的模型学习到原始数据不同的数据特征,提高了模型训练的效率和准确率。进而通过将第一预测模型预测效果好的目标标签,以及第二预测模型预测效果好的优化标签蒸馏至第三预设模型中,使得第三预设模型可以学习到第一预测模型的预测优点以及第二预测模型的预测优点,从而提高了文本处理模型的训练效率和准确率。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本发明一实施例中文本处理模型蒸馏方法的应用环境示意图;
图2是本发明一实施例中文本处理模型蒸馏方法的流程图;
图3是本发明一实施例中文本处理模型蒸馏方法中步骤S20的流程图;
图4是本发明一实施例中文本处理模型蒸馏方法中步骤S40的流程图;
图5是本发明一实施例中文本处理模型蒸馏装置的原理框图;
图6是本发明一实施例中计算机设备的示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例提供的文本处理模型蒸馏方法,该文本处理模型蒸馏方法可应用如图1所示的应用环境中。具体地,该文本处理模型蒸馏方法应用在文本处理模型蒸馏装置中,该文本处理模型蒸馏装置包括如图1所示的客户端和服务器,客户端与服务器通过网络进行通信,用于解决现有技术中文本处理模型准确性较低和处理效率低的问题。其中,该服务器可以是独立的服务器,也可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。客户端又称为用户端,是指与服务器相对应,为客户提供本地服务的程序。客户端可安装在但不限于各种个人计算机、笔记本电脑、智能手机、平板电脑和便携式可穿戴设备上。
在一实施例中,如图2所示,提供一种文本处理模型蒸馏方法,以该方法应用在图1中的服务器为例进行说明,包括如下步骤:
S10:获取原始数据以及原始数据对应的原始标签,根据所述原始数据和所述原始标签对第一预设模型进行训练,得到第一预测模型。
可理解地,原始数据可以通过爬虫技术从不同的网站上采集得到,亦或者从不同的数据库中采集得到。在本实施例中,该原始数据为文本数据(文本数据可以为中文文本、英文文本或者同时包含中文和英文的文本)。原始标签作为原始数据的表征,在不同应用场景下该原始标签表征的含义不同。示例性地,在关键词抽取的应用场景下,该原始标签即表征了原始数据中的关键词。此时,即可以通过人工标注的方式或者对原始数据进行关键词识别的方式,从原始数据中抽取出关键词作为该原始标签。在实体识别的应用场景下,该原始标签可以为原始数据中某一个字体的实体含义。此时,即可以通过人工标注的方式或者对原始数据进行实体识别的方式,确定原始数据中不同的字词对应的实体含义作为该原始标签。
进一步地,第一预设模型为基于Bert模型构建的模型,用于对原始数据进行标签预测的。第一预设模型可以为Bert-Seq2Seq模型,该模型包括多个encoder层,可以对不同长度文本进行识别。第一预测模型为通过原始数据对第一预设模型训练得到的。
具体地,从服务器的数据库中调取原始数据和原始数据对应的原始标签,并将原始数据和原始数据对应的原始标签输入到第一预设模型中,通过原始数据对第一预设模型进行训练。也即通过原始标签和第一预设模型的模型预测结果对第一预设模型中的初始参数进行调整,使得调整初始参数后的第一预设模型的标签预测结果不断的向原始标签靠拢。当调整初始参数后的第一预设模型对原始数据预测的预测标签的预测值达到收敛条件时,则结束训练,并将收敛后的第一预设模型确定为第一预测模型。
S20:通过所述第一预测模型对所述原始数据进行预测,得到预测标签,并将所有所述预测标签分为目标标签以及未达标标签。
可理解地,预测标签为第一预测模型对原始数据进行标签预测的结果。目标标签为大于或等于预设标签阈值的预测标签。未达标标签为小于预设标签阈值的预测标签。预设标签阈值为用于判断原始标签和预测标签之间的是否相似的。预设标签阈值可以为F1分数,可以将预设标签阈值设置为0.9,F1分数为统计学中用来衡量模型精确度的一种指标。也即通过计算预测标签的F1分数,将预测标签的F1分数和预设标签阈值进行比较。预设标签阈值还可以为欧式距离或余弦相似度,可以预设标签阈值设置为0.95,也即通过计算原始标签和预测标签之间欧式距离或余弦相似度,并将计算结果和预设标签阈值进行比较。
具体地,在得到第一预测模型之后,将原始数据输入到第一预测模型中,通过第一预测模型中的嵌入层对原始数据进行向量转换,即将原始数据转换为向量,得到原始数据对应的嵌入向量。通过注意力层中的多组三个权值矩阵WQ,WK,WV对嵌入向量进行计算,得到嵌入向量对应的Query向量,Keys向量和Values向量。使用点积法计算嵌入向量之间的相关性得分,即用Q中每一个嵌入向量与K中每一个嵌入向量计算点积。对嵌入向量之间的相关性得分进行归一化处理,即通过softmax函数,将嵌入向量之间的得分转换成[0,1]之间的概率分布。根据嵌入向量之间的概率分布,然后乘上对应的Values值,得到矩阵。将得到的多个矩阵进行拼接,并通过第一残差连接层对拼接后的矩阵进行处理,避免在模型训练中发生退化问题。再通过第一归一层对处理后的矩阵进行归一化处理。然后通过ReLU函数将归一化后的矩阵进行激活,并通过第二残差连接层和第二归一层对激活后的矩阵进行归一化处理,即可得到预测标签。
进一步地,对同一原始数据对应的原始标签和预测标签进行损失计算,可以通过损失函数对原始标签和预测标签之间的差异进行计算,直接得到预测标签对应的预测值。也可以通过先计算原始标签和预测标签之间的欧氏距离或余弦相似度,从而基于该欧式距离或者余弦相似度确定预测标签对应的预测值。如此即可根据上述方式确定各预测标签对应的预测值,并根据预测值和预设标签阈值将所有预测标签分为目标标签和未达标标签。其中,损失函数可以是CTC损失函数,还可以是Focal Loss损失函数等等。
S30:根据所述未达标标签和所述未达标标签对应的所述原始数据对第二预设模型进行训练,得到第二预测模型,并通过所述第二预测模型对所述未达标标签进行优化处理,得到优化标签。
可理解地,第二预设模型为基于bert模型构建的模型,用于对未达标标签对应的原始数据进行标签预测,第二预设模型可以为Bert-Seq2Seq模型,也可以为Bert-Dense模型。进一步地,第一预设模型和第二预设模型可以为相同的模型,也可以为不同的模型。当两个模型相同时均可以为Bert模型,当两个模型不同时第一预设模型可以为Bert-Seq2Seq模型,第二预设模型可以为Bert-Dense模型。在本实施例中,第一预设模型和第二预设模型优选为不同的模型。由于在训练好的第一预设模型(也即第一预测模型)预测效果较差的未达标标签对应的原始数据再次输入至相同的模型(也即第二预设模型)中进行训练时,该第二预设模型仍无法学习到未达标标签对应的原始数据的其它特征。从而导致第二预设模型学习效果较差,进而影响后续步骤中对第三预设模型的蒸馏学习。第二预测模型为通过未达标标签对应的原始数据对第二预设模型进行训练得到的。优化标签为第二预测模型对未达标标签进行优化得到的标签。
具体地,在得到未达标标签之后,获取未达标标签对应的原始数据,并将未达标标签对应的原始数据输入到第二预设模型中,对第二预设模型进行训练。也即通过原始标签和第二预设模型的模型预测结果对第二预设模型中的初始参数进行调整,使得调整初始参数后的第二预设模型的标签预测结果不断的向原始标签靠拢。当调整初始参数后的第二预设模型对原始数据的预测的预测标签的预测值达到收敛条件时,则结束训练,并将收敛后的第二预设模型确定为第二预测模型。通过第二预测模型重新对未达标标签对应的原始数据进行预测,将第二预测模型预测的预测标签替换掉未达标标签,并将第二预测模型预测的预测标签确定为优化标签,如此即可得到优化标签。
S40:通过所述原始数据、所述原始标签、所述目标标签和所述优化标签对第三预设模型进行蒸馏学习,得到文本处理模型。
可理解地,第三预设模型可以为TextCNN模型,用于对第一预测模型和第二预测模型进行蒸馏学习。文本处理模型为通过蒸馏学习具有第一预测模型和第二预测模型的预测能力的模型。
具体地,在得到优化标签之后,通过原始数据和原始标签对第三预设模型进行蒸馏学习,并通过蒸馏后的第三预设模型对原始数据进行预测,得到模型预测结果。将目标标签和优化标签蒸馏到蒸馏后的第三预设模型中,并将同一原始数据对应的目标标签或优化标签和第三预设模型的模型预测结果进行比较,确定第三预设模型的损失值。根据第三预设模型的损失值对蒸馏学习后的第三预设模型的初始参数进行调整,当第三预设模型的损失值达到收敛条件时,将蒸馏学习后的第三预设模型确定为文本处理模型。
在本发明实施例中,该方法通过训练得到的第一预测模型对原始数据进行预测,从而使得第一预测模型学习到预测效果较好的标签(也即目标标签)中的特征。将第一预测模型预测效果较差的标签(也即未达标标签)交由第二预测模型进行学习。如此,即可通过两个不同的模型学习到原始数据不同的数据特征,提高了模型训练的效率和准确率。进而通过将第一预测模型预测效果好的目标标签,以及第二预测模型预测效果好的优化标签蒸馏至第三预设模型中,使得第三预设模型可以学习到第一预测模型的预测优点以及第二预测模型的预测优点,从而提高了文本处理模型的训练效率和准确率。
在一实施例中,步骤S10中,也即根据所述原始数据和所述原始标签对第一预设模型进行训练,得到第一预测模型,包括:
S101,将所述原始数据输入至所述第一预设模型中,通过所述第一预设模型对所述原始数据进行预测,得到第二训练标签。
可理解地,第二训练标签为第一预设模型对原始数据进行预测得到的模型预测结果。
具体地,在获取原始数据和原始标签之后。将原始数据输入至第一预设模型中,通过第一预设模型中的嵌入层对原始数据进行转换,得到原始数据对应的嵌入向量。通过注意力层对嵌入向量进行计算,得到与嵌入向量相对应的矩阵。将多个矩阵进行拼接,并通过第一残差连接层对拼接后的矩阵进行处理,避免在模型训练中发生退化问题。再通过第一归一层对处理后的矩阵进行归一化处理。然后通过ReLU函数将归一化后的矩阵进行激活,并通过第二残差连接层和第二归一层对激活后矩阵进行归一化处理,即可得到第二训练标签。具体过程与上述步骤S20相同,在此不再赘述。其中,第一预测模型预测的准确性远大于第一预设模型预测的准确性。本实施例中只是列举第一预设模型中的一层encoder层,第一预设模型中包括多层encoder层,通过第一预设模型中的所有encoder层,即可得到第一预设模型预测的第二训练标签。
S102,根据所述第二训练标签和所述原始标签,确定第三损失值。
可理解地,第三损失值为第一预设模型的损失值,也即对第一预设模型训练中生成的损失。
具体地,在得到第二训练标签之后,对第二训练标签和原始标签之间的差异进行计算,可以通过CTC损失函数或Focal Loss损失函数计算第二训练标签和原始标签之间的差异,从而确定第一预设模型的损失值,如此即可得到第三损失值。也可以先计算第二训练标签和原始标签之间的欧氏距离或余弦相似度,从而基于该欧式距离或者余弦相似度确定第三损失值。
S103,根据所述第三损失值对所述第一预设模型进行优化处理,得到所述第一预测模型。
具体地,在得到第三损失值之后,根据第三损失值对第一预设模型中每个层的初始参数进行优化处理。通过优化后的第一预设模型对原始数据进行预测,得到对应的第二训练标签,通过CTC损失函数对新的第二训练标签和同一原始数据对应的目标标签或优化标签进行损失计算,得到新的第三损失值。并判断新的第三损失值是否符合收敛条件,当新的第三损失值符合收敛条件时,将优化后的第一预设模型确定为第一预测模型。当新的第三损失值不符合收敛条件时,根据新的第三损失值对第一预设模型中每个层的初始参数重新进行优化处理。如此,直至第一预设模型的损失值符合收敛条件时,将符合收敛条件的第一预设模型记录为第一预测模型。
本发明实施例通过第一预设模型对原始数据进行预测,得到第二训练标签,并根据第二训练标签和原始标签,实现了对第三损失值的确定。通过第三损失值对第一预设模型进行优化处理,直至第三损失值符合收敛条件,实现了对第一预测模型的确定,进而提高了对第一预测模型预测的准确性。
在一实施例中,如图3所示,步骤S20中,所述将所有所述预测标签分为目标标签以及未达标标签,包括:
S201:基于同一所述原始数据对应的所述原始标签和所述预测标签,确定所述预测标签对应的预测值。
可理解地,预测值为用于表征原始标签和预测标签之间的相似度。
具体地,在得到目标标签以及未达标标签之后,对同一原始数据对应的原始标签和预测标签进行获取,通过CTC损失函数或Focal Loss损失函数对同一原始数据对应的原始标签和预测标签之间的差异进行计算,并根据计算的结果确定为预测标签对应的预测值。如此通过上述方式即可得到各预测标签对应的预测值。
S202:获取预设标签阈值,并将所述预测值和所述预设标签阈值进行比较。
S203:将大于或等于所述预设标签阈值的所述预测值对应的所述预测标签确定为所述目标标签,将小于所述预设标签阈值的所述预测值对应的所述预测标签确定为所述未达标标签。
可以理解地,预设标签阈值为用于判断原始标签和预测标签之间相似的。目标标签为大于或等于预设标签阈值的预测值对应的预测标签。未达标标签为小于预设标签阈值的预测值对应的预测标签。
具体地,在得到预测标签对应的预测值之后,从服务器中调取预设标签阈值或者从第三方平台获取预设标签阈值,并将预测标签对应的预测值和预设标签阈值进行比较。当预测标签对应的预测值大于或等于预设标签阈值时,将大于或等于预设标签阈值的预测值对应的预测标签确定为目标标签。当预测标签对应的预测值小于预设标签阈值时,将小于所述预设标签阈值的预测值对应的预测标签确定为未达标标签。如此通过上述方式即可得到所有目标标签和所有未达标标签。
本发明实施例通过同一原始数据对应的原始标签和预测标签,实现了确定预测标签对应的预测值。通过预设标签阈值和所有预测值进行比较,实现了对目标标签以及未达标标签的确定,方便了后续对第二预设模型的训练。
在一实施例中,步骤S30中,即根据所述未达标标签和所述未达标标签对应的原始数据对第二预设模型进行训练,得到第二预测模型,包括:
S301,将所述未达标标签对应的所述原始数据输入至所述第二预设模型中,通过所述第二预设模型对所述未达标标签对应的所述原始数据进行预测,得到第三训练标签。
可理解地,第三训练标签为第二预设模型对未达标标签对应的原始数据进行预测得到的模型预测结果。
具体地,在得到目标标签以及未达标标签之后,通过未达标标签获取与未达标标签相对应的原始数据,并将未达标标签对应的原始数据输入到第二预设模型中,通过第二预设模型对未达标标签对应的原始数据进行预测,也即当第二预设模型为Bert-Dense模型时,通过第二预设模型中的嵌入层对未达标标签对应的原始数据进行向量转换,得到未达标标签对应的原始数据所对应的嵌入向量。具体过程与上述步骤S20相同,在此不再赘述,只说明其不同的部分。也即在通过ReLU函数将归一化后的矩阵进行激活之后,将激活后的矩阵输入到全连接层,全连接层中的隐藏层通过不同的权重对所有的激活后的矩阵进行计算处理。将处理结果通过全连接层中的输出层输入到第二残差连接层中,并通过第二残差连接层和第二归一层对处理结果进行预测,即可得到第三训练标签。
S302,根据同一所述原始数据对应的所述第三训练标签和所述未达标标签,确定第四损失值。
可理解地,第四损失值为第二预设模型的损失值,也即对第二预设模型训练中生成的损失。
具体地,在得到第三训练标签之后,对第三训练标签和原始标签之间的差异进行计算,通过CTC损失函数计算第三训练标签和原始标签之间的差异,从而确定第二预设模型的损失值,如此即可得到第四损失值。也可以先计算第三训练标签和原始标签之间的欧氏距离或余弦相似度,从而基于该欧式距离或者余弦相似度确定第四损失值。
S303,通过所述第四损失值对所述第二预设模型进行优化处理,得到所述第二预测模型。
具体地,在得到第四损失值之后,根据第四损失值对第二预设模型中每个层的初始参数进行优化处理。通过优化后的第二预设模型对未达标标签对应的原始数据进行预测,得到对应的第三训练标签,通过CTC损失函数对新的第三训练标签和同一原始数据对应的目标标签或优化标签进行损失计算,得到新的第四损失值。并判断新的第四损失值是否符合收敛条件,当新的第四损失值符合收敛条件时,将优化后的第二预设模型确定为第二预测模型。当新的第四损失值不符合收敛条件时,根据新的第四损失值对第二预设模型中每个层的初始参数重新进行优化。如此,直至第二预设模型的损失值符合收敛条件时,将符合收敛条件的第二预设模型记录为第二预测模型。
本发明实施例通过未达标标签对应的原始数据对第二预设模型进行训练,减少了参数量,提高了模型训练效率。通过第二预设模型对未达标标签对应的原始数据进行预测,得到第三训练标签,并根据第三训练标签和原始标签,实现了对第二损失值的确定。通过第二损失值对第二预设模型进行优化处理,直至第二损失值符合收敛条件,确定第二预测模型,提高了对第二预测模型预测的准确性。
在一实施例中,如图4所示,步骤S40中,即通过所述原始数据、所述原始标签、所述目标标签和所述优化标签对第三预设模型进行蒸馏学习,得到文本处理模型,包括:
S401,将所述原始数据和所述原始标签输入至所述第三预设模型中,通过所述原始数据和所述原始标签对所述第三预设模型进行蒸馏学习,得到蒸馏模型。
S402,通过所述蒸馏模型对所述原始数据进行预测,得到第一训练标签。
可理解地,第三预设模型为TextCNN模型,包括嵌入层、卷积层、池化层和全连接层。蒸馏模型为通过原始数据和原始标签对第三预设模型进行蒸馏学习得到的。第一训练标签为蒸馏模型对原始数据进行预测得到的。
具体地,在得到优化标签之后,将原始数据和原始标签输入至第三预设模型中,通过原始数据和原始标签对第三预设模型进行蒸馏学习,使得第三预设模型具有简单的预测能力,并将蒸馏学习后的第三预设模型确定为蒸馏模型。根据蒸馏模型对原始数据进行预测,也即先通过蒸馏模型中的嵌入层将原始数据进行向量化,得到原始数据对应的嵌入向量。再通过卷积层对嵌入向量进行一维卷积处理,得到卷积特征向量。然后通过池化层对卷积特征向量进行最大池化处理,也即将不同长度的卷积特征向量变成固定长度的向量,并拼接到为池化文本向量。最后通过全连接层对池化文本向量进行处理,避免过度拟合,并对处理后的池化文本向量进行预测,即可得到第一训练标签。
S403,将所述目标标签和所述优化标签蒸馏到所述蒸馏模型,并根据所述目标标签、所述优化标签以及所述第一训练标签,确定第一损失值。
S404,根据所述第一损失值对所述蒸馏模型进行优化处理,并确定所述第一损失值是否符合收敛条件,当所述第一损失值达到所述收敛条件时,将所述蒸馏模型确定为所述文本处理模型。
可理解地,第一损失值为第三预设模型的损失值,也即对第三预设模型训练中生成的损失。
具体地,在得到第一训练标签之后,通过蒸馏的方式将目标标签和优化标签蒸馏到蒸馏模型中,也即将第一预测模型和第二预测模型的预测优点蒸馏到蒸馏模型中,也即使得蒸馏模型具有第一预测模型和第二预测模型的预测能力。然后通过同一原始数据对应的第一训练标签以及目标标签或优化标签进行损失计算,也即通过CTC损失函数对第一训练标签和目标标签或优化标签之间的差异进行计算,得到第一损失值。并根据第一损失值对蒸馏模型中每个层的初始参数进行优化调整,得到优化蒸馏模型。然后通过优化蒸馏模型对原始数据进行预测,得到新的第一训练标签,并通过CTC损失函数对新的第一训练标签和同一原始数据对应的目标标签或优化标签进行损失计算,得到新的第一损失值。并判断新的第一损失值是否符合收敛条件,当新的第一损失值符合收敛条件时,将新的蒸馏模型确定为文本处理模型。
本发明实施例通过将目标标签和优化标签蒸馏到第三预设模型中,从而实现了将第一预测模型和第二预测模型的模型结构蒸馏到第三预设模型。并判断第三预设模型的第一损失值是否达到收敛条件,从而实现了对文本处理模型的确定。进而提高了文本处理模型对原始数据预测的准确性,提高了文本处理模型蒸馏的效率。
在一实施例中,步骤S404中,即根据所述第一损失值对所述蒸馏模型进行优化处理,并确定所述第一损失值是否符合收敛条件之后,包括:
S4041,若所述第一损失值未达到所述收敛条件,对所述蒸馏模型的初始参数进行调整,得到目标蒸馏模型。
S4042,通过所述目标蒸馏模型对所述原始数据进行预测,得到蒸馏标签;根据所述蒸馏标签和所述原始标签,确定第二损失值。
可理解地,目标蒸馏模型为通过第一损失值对蒸馏模型的初始参数进行调整得到的。第二损失值为目标蒸馏模型的损失值,也即对蒸馏模型训练中生成的损失。
具体地,当第一损失值不符合收敛条件时,根据第一损失值对蒸馏模型中每个层的初始参数重新进行调整,并将调整初始参数后的蒸馏模型确定为目标蒸馏模型。通过目标蒸馏模型对原始数据进行预测,得到原始数据对应的蒸馏标签。可以通过CTC损失函数对蒸馏标签和同一原始数据对应的目标标签或优化标签进行损失计算,得到第二损失值。
S4043,在所述第二损失值未达到所述收敛条件,迭代更新所述目标蒸馏模型中的初始参数,直至所述第二损失值达到所述收敛条件时,将收敛之后的所述目标蒸馏模型记录为所述文本处理模型。
可以理解地,该收敛条件可以为第二损失值小于设定阈值的条件,也即在第二损失值小于设定阈值时,停止训练;收敛条件还可以为第二损失值经过了500次计算后值为很小且不会再下降的条件,也即第二损失值经过500次计算后值很小且不会下降时,停止训练。
具体地,确定第二损失值之后,在第二损失值未达到预设的收敛条件时,根据第二损失值调整目标蒸馏模型的初始参数,并将原始数据重新输入至调整初始参数后的目标蒸馏模型中,得到与调整初始参数的目标蒸馏模型相对应的第二损失值。以在该第二损失值达到预设的收敛条件时,将收敛之后的目标蒸馏模型记录为文本处理模型。并在该第二损失值未达到预设的收敛条件时,根据该第二损失值再次调整目标蒸馏模型的初始参数,使得再次调整初始参数的目标蒸馏模型输出的结果可以不断向准确地结果靠拢,让模型预测的准确率越来越高。直至第二损失值达到预设的收敛条件时,并将收敛之后的目标蒸馏模型记录为文本处理模型。
本发明实施例通过在第一损失值未达到收敛条件时,对蒸馏模型的初始参数进行调整,从而实现了对目标蒸馏模型的获取。通过目标蒸馏模型对原始数据进行预测,得到第二损失值,并在第二损失值达到收敛条件时,实现了对文本处理模型的确定,进而提高了文本处理模型预测的准确性,以及提高了文本处理模型蒸馏效率。
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
在一实施例中,提供一种文本处理模型蒸馏装置,该文本处理模型蒸馏装置与上述实施例中文本处理模型蒸馏方法一一对应。如图5所示,该文本处理模型蒸馏装置包括预测模块11、分类模块12、优化模块13和蒸馏模块14。各功能模块详细说明如下:
预测模块11,用于获取原始数据以及原始数据对应的原始标签,根据所述原始数据和所述原始标签对第一预设模型进行训练,得到第一预测模型;
分类模块12,用于通过所述第一预测模型对所述原始数据进行预测,得到预测标签,并将所有所述预测标签分为目标标签以及未达标标签;
优化模块13,用于根据所述未达标标签和所述未达标标签对应的所述原始数据对第二预设模型进行训练,得到第二预测模型,并通过所述第二预测模型对所述未达标标签进行优化处理,得到优化标签;
蒸馏模块14,用于通过所述原始数据、所述原始标签、所述目标标签和所述优化标签对第三预设模型进行蒸馏学习,得到文本处理模型。
在一实施例中,所述预测模块11包括:
第二标签预测单元,用于将所述原始数据输入至所述第一预设模型中,通过所述第一预设模型对所述原始数据进行预测,得到第二训练标签;
第三损失值单元,用于根据同一所述原始数据对应的所述第二训练标签和所述原始标签,确定第三损失值;
第一预测模型单元,用于根据所述第三损失值对所述第一预设模型进行优化处理,得到所述第一预测模型。
在一实施例中,所述分类模块12包括:
确定单元,用于基于同一所述原始数据对应的所述原始标签和所述预测标签,确定所述预测标签对应的预测值;
比较单元,用于获取预设标签阈值,并将所述预测值和所述预设标签阈值进行比较;
结果单元,用于将大于或等于所述预设标签阈值的所述预测值对应的所述预测标签确定为目标标签,将小于所述预设标签阈值的所述预测值对应的所述预测标签确定为未达标标签。
在一实施例中,所述优化模块13包括:
第三标签预测单元,用于将所述未达标标签对应的所述原始数据输入至所述第二预设模型中,通过所述第二预设模型对所述未达标标签对应的所述原始数据进行预测,得到第三训练标签;
第四损失值单元,用于根据同一所述原始数据对应的所述第三训练标签和所述未达标标签,确定第四损失值;
第二预测模型单元,用于通过所述第四损失值对所述第二预设模型进行优化处理,得到所述第二预测模型。
在一实施例中,所述蒸馏模块14包括:
蒸馏学习单元,用于将所述原始数据和所述原始标签输入至所述第三预设模型中,通过所述原始数据和所述原始标签对所述第三预设模型进行蒸馏学习,得到蒸馏模型;
第一标签预测单元,用于通过所述蒸馏模型对所述原始数据进行预测,得到第一训练标签;
第一损失值单元,用于将所述目标标签和所述优化标签蒸馏到所述蒸馏模型,并根据所述目标标签、所述优化标签以及所述第一训练标签,确定第一损失值;
模型确定单元,用于根据所述第一损失值对所述蒸馏模型进行优化处理,并确定所述第一损失值是否符合收敛条件,当所述第一损失值达到所述收敛条件时,将所述蒸馏模型确定为所述文本处理模型。
在一实施例中,所述文本处理模型单元还包括:
参数调整单元,用于若所述第一损失值未达到所述收敛条件,对所述蒸馏模型的初始参数进行调整,得到目标蒸馏模型;
第二损失值单元,用于通过所述目标蒸馏模型对所述原始数据进行预测,得到蒸馏标签;根据所述蒸馏标签和所述原始标签,确定第二损失值;
模型收敛单元,用于在所述第二损失值未达到所述收敛条件,迭代更新所述目标蒸馏模型中的初始参数,直至所述第二损失值达到所述收敛条件时,将收敛之后的所述目标蒸馏模型记录为所述文本处理模型。
关于文本处理模型蒸馏装置的具体限定可以参见上文中对于文本处理模型蒸馏方法的限定,在此不再赘述。上述文本处理模型蒸馏装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图6所示。该计算机设备包括通过系统总线连接的处理器、存储器、网络接口和数据库。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储上述实施例中文本处理模型蒸馏方法所用到的数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种文本处理模型蒸馏方法。
在一个实施例中,提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现上述实施例中文本处理模型蒸馏方法。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现上述实施例中文本处理模型蒸馏方法。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。
以上所述实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。
Claims (10)
1.一种文本处理模型蒸馏方法,其特征在于,包括:
获取原始数据以及原始数据对应的原始标签,根据所述原始数据和所述原始标签对第一预设模型进行训练,得到第一预测模型;
通过所述第一预测模型对所述原始数据进行预测,得到预测标签,并将所有所述预测标签分为目标标签以及未达标标签;
根据所述未达标标签和所述未达标标签对应的所述原始数据对第二预设模型进行训练,得到第二预测模型,并通过所述第二预测模型对所述未达标标签进行优化处理,得到优化标签;
通过所述原始数据、所述原始标签、所述目标标签和所述优化标签对第三预设模型进行蒸馏学习,得到文本处理模型。
2.如权利要求1所述的文本处理模型蒸馏方法,其特征在于,所述将所有所述预测标签分为目标标签以及未达标标签,包括:
基于同一所述原始数据对应的所述原始标签和所述预测标签,确定所述预测标签对应的预测值;
获取预设标签阈值,并将所述预测值和所述预设标签阈值进行比较;
将大于或等于所述预设标签阈值的所述预测值对应的所述预测标签确定为所述目标标签,将小于所述预设标签阈值的所述预测值对应的所述预测标签确定为所述未达标标签。
3.如权利要求1所述的文本处理模型蒸馏方法,其特征在于,所述通过所述原始数据、所述原始标签、所述目标标签和所述优化标签对第三预设模型进行蒸馏学习,得到文本处理模型,包括:
将所述原始数据和所述原始标签输入至所述第三预设模型中,通过所述原始数据和所述原始标签对所述第三预设模型进行蒸馏学习,得到蒸馏模型;
通过所述蒸馏模型对所述原始数据进行预测,得到第一训练标签;
将所述目标标签和所述优化标签蒸馏到所述蒸馏模型,并根据所述目标标签、所述优化标签以及所述第一训练标签,确定第一损失值;
根据所述第一损失值对所述蒸馏模型进行优化处理,并确定所述第一损失值是否符合收敛条件,当所述第一损失值达到所述收敛条件时,将所述蒸馏模型确定为所述文本处理模型。
4.如权利要求3所述的文本处理模型蒸馏方法,其特征在于,所述根据所述第一损失值对所述蒸馏模型进行优化处理,并确定所述第一损失值是否符合收敛条件之后,包括:
若所述第一损失值未达到所述收敛条件,对所述蒸馏模型的初始参数进行调整,得到目标蒸馏模型;
通过所述目标蒸馏模型对所述原始数据进行预测,得到蒸馏标签;根据所述蒸馏标签和所述原始标签,确定第二损失值;
在所述第二损失值未达到所述收敛条件,迭代更新所述目标蒸馏模型中的初始参数,直至所述第二损失值达到所述收敛条件时,将收敛之后的所述目标蒸馏模型记录为所述文本处理模型。
5.如权利要求1所述的文本处理模型蒸馏方法,其特征在于,所述根据所述原始数据和所述原始标签对第一预设模型进行训练,得到第一预测模型,包括:
将所述原始数据输入至所述第一预设模型中,通过所述第一预设模型对所述原始数据进行预测,得到第二训练标签;
根据同一所述原始数据对应的所述第二训练标签和所述原始标签,确定第三损失值;
根据所述第三损失值对所述第一预设模型进行优化处理,得到所述第一预测模型。
6.如权利要求1所述的文本处理模型蒸馏方法,其特征在于,所述根据所述未达标标签和所述未达标标签对应的原始数据对第二预设模型进行训练,得到第二预测模型,包括:
将所述未达标标签对应的所述原始数据输入至所述第二预设模型中,通过所述第二预设模型对所述未达标标签对应的所述原始数据进行预测,得到第三训练标签;
根据同一所述原始数据对应的所述第三训练标签和所述未达标标签,确定第四损失值;
通过所述第四损失值对所述第二预设模型进行优化处理,得到所述第二预测模型。
7.一种文本处理模型蒸馏装置,其特征在于,包括:
预测模块,用于获取原始数据以及原始数据对应的原始标签,根据所述原始数据和所述原始标签对第一预设模型进行训练,得到第一预测模型;
分类模块,用于通过所述第一预测模型对所述原始数据进行预测,得到预测标签,并将所有所述预测标签分为目标标签以及未达标标签;
优化模块,用于根据所述未达标标签和所述未达标标签对应的所述原始数据对第二预设模型进行训练,得到第二预测模型,并通过所述第二预测模型对所述未达标标签进行优化处理,得到优化标签;
蒸馏模块,用于通过所述原始数据、所述原始标签、所述目标标签和所述优化标签对第三预设模型进行蒸馏学习,得到文本处理模型。
8.如权利要求7所述的文本处理模型蒸馏装置,其特征在于,所述分类模块包括:
确定单元,用于基于同一所述原始数据对应的所述原始标签和所述预测标签,确定所述预测标签对应的预测值;
比较单元,用于获取预设标签阈值,并将所述预测值和所述预设标签阈值进行比较;
结果单元,用于将大于或等于所述预设标签阈值的所述预测值对应的所述预测标签确定为目标标签,将小于所述预设标签阈值的所述预测值对应的所述预测标签确定为未达标标签。
9.一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至6任一项所述文本处理模型蒸馏方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至6任一项所述文本处理模型蒸馏方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210948994.4A CN115374278A (zh) | 2022-08-09 | 2022-08-09 | 文本处理模型蒸馏方法、装置、计算机设备及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210948994.4A CN115374278A (zh) | 2022-08-09 | 2022-08-09 | 文本处理模型蒸馏方法、装置、计算机设备及介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115374278A true CN115374278A (zh) | 2022-11-22 |
Family
ID=84063340
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210948994.4A Pending CN115374278A (zh) | 2022-08-09 | 2022-08-09 | 文本处理模型蒸馏方法、装置、计算机设备及介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115374278A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116340552A (zh) * | 2023-01-06 | 2023-06-27 | 北京达佳互联信息技术有限公司 | 一种标签排序方法、装置、设备及存储介质 |
-
2022
- 2022-08-09 CN CN202210948994.4A patent/CN115374278A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116340552A (zh) * | 2023-01-06 | 2023-06-27 | 北京达佳互联信息技术有限公司 | 一种标签排序方法、装置、设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111428021B (zh) | 基于机器学习的文本处理方法、装置、计算机设备及介质 | |
CN110598206A (zh) | 文本语义识别方法、装置、计算机设备和存储介质 | |
CN109829629B (zh) | 风险分析报告的生成方法、装置、计算机设备和存储介质 | |
CN110569500A (zh) | 文本语义识别方法、装置、计算机设备和存储介质 | |
CN109063217B (zh) | 电力营销系统中的工单分类方法、装置及其相关设备 | |
CN112926654B (zh) | 预标注模型训练、证件预标注方法、装置、设备及介质 | |
CN111553479A (zh) | 一种模型蒸馏方法、文本检索方法及装置 | |
CN114528844A (zh) | 意图识别方法、装置、计算机设备及存储介质 | |
CN110598210B (zh) | 实体识别模型训练、实体识别方法、装置、设备及介质 | |
CN112231224A (zh) | 基于人工智能的业务系统测试方法、装置、设备和介质 | |
CN110362798B (zh) | 裁决信息检索分析方法、装置、计算机设备和存储介质 | |
CN112699923A (zh) | 文档分类预测方法、装置、计算机设备及存储介质 | |
CN111985228A (zh) | 文本关键词提取方法、装置、计算机设备和存储介质 | |
CN115495553A (zh) | 查询文本排序方法、装置、计算机设备及存储介质 | |
WO2020052183A1 (zh) | 商标侵权的识别方法、装置、计算机设备和存储介质 | |
CN111583911B (zh) | 基于标签平滑的语音识别方法、装置、终端及介质 | |
CN111611383A (zh) | 用户意图的识别方法、装置、计算机设备及存储介质 | |
CN110377618B (zh) | 裁决结果分析方法、装置、计算机设备和存储介质 | |
CN115374278A (zh) | 文本处理模型蒸馏方法、装置、计算机设备及介质 | |
CN111859916A (zh) | 古诗关键词提取、诗句生成方法、装置、设备及介质 | |
CN117093682A (zh) | 意图识别方法、装置、计算机设备及存储介质 | |
CN115169334A (zh) | 意图识别模型训练方法、装置、计算机设备及存储介质 | |
CN113627514A (zh) | 知识图谱的数据处理方法、装置、电子设备和存储介质 | |
CN115840817A (zh) | 基于对比学习的信息聚类处理方法、装置和计算机设备 | |
CN113468322A (zh) | 关键词识别模型的训练、提取方法、装置、设备及介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |