CN112101573A - 一种模型蒸馏学习方法、文本查询方法及装置 - Google Patents

一种模型蒸馏学习方法、文本查询方法及装置 Download PDF

Info

Publication number
CN112101573A
CN112101573A CN202011275406.2A CN202011275406A CN112101573A CN 112101573 A CN112101573 A CN 112101573A CN 202011275406 A CN202011275406 A CN 202011275406A CN 112101573 A CN112101573 A CN 112101573A
Authority
CN
China
Prior art keywords
model
distillation
query
text
trained
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202011275406.2A
Other languages
English (en)
Other versions
CN112101573B (zh
Inventor
杨均晖
方宽
申站
赵龙
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhizhe Sihai Beijing Technology Co ltd
Original Assignee
Zhizhe Sihai Beijing Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhizhe Sihai Beijing Technology Co ltd filed Critical Zhizhe Sihai Beijing Technology Co ltd
Priority to CN202011275406.2A priority Critical patent/CN112101573B/zh
Publication of CN112101573A publication Critical patent/CN112101573A/zh
Application granted granted Critical
Publication of CN112101573B publication Critical patent/CN112101573B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/33Querying

Abstract

本发明涉及一种模型蒸馏学习方法、文本查询方法及装置,属于自然语言处理技术领域,旨在提高训练模型的精度,从而提高文本查询时的准确性。该方法包括:利用已标注的数据集训练第一模型;将迁移数据集输入至训练好的第一模型和第二模型,分别输出第一相关性分数集和第二相关性分数集;至少部分的根据第一相关性分数集和第二相关性分数集确定蒸馏损失;根据蒸馏损失优化第二模型的参数,得到训练好的第二模型,其中:第一模型和第二模型为不同类型的模型。

Description

一种模型蒸馏学习方法、文本查询方法及装置
技术领域
本发明涉及自然语言处理技术领域,更具体地,涉及一种模型蒸馏学习方法、文本查询方法及装置。
背景技术
随着机器学习的不断发展,自然语言处理技术中神经网络被使用的越来越多,通常情况下将该神经网络模型应用于搜索引擎当中,利用神经网络模型对召回的文本进行打分,从而按照分数高低返回给用户。
然而现有技术中很多模型,比如Bert,其模型较为复杂,参数量大,训练时间长,内存消耗较大,这样导致很难应用至智能手机等移动终端设备上。
为了解决上述的问题,现有技术中提出了一种知识蒸馏的方法,其利用学生模型来直接去学习老师模型,来实现学生模型对老师模型的学习。现有的学生模型和老师模型为同类型模型,而由于受限于模型本身的准确度,使得同类型模型之间蒸馏效果有限,其蒸馏学习得到的模型的准确度较低,从而在文本查询时的准确性也较低。
发明内容
有鉴于此,本发明实施例的目的在于提供一种模型蒸馏学习方法、文本查询方法及装置,旨在维持或提高模型准确性的同时能够压缩模型大小,大量节省计算资源和存储资源,此外在提高模型准确度后,从而能够提高文本查询时的准确性。
本发明实施例的第一方面,提供一种通过蒸馏学习第一模型训练第二模型的方法,所述第一模型和第二模型为不同类型的模型,所述方法包括:利用已标注的数据集训练第一模型;将迁移数据集输入至训练好的第一模型和所述第二模型,分别输出第一相关性分数集和第二相关性分数集;至少部分的根据所述第一相关性分数集和所述第二相关性分数集确定蒸馏损失;根据所述蒸馏损失优化所述第二模型的参数,得到训练好的第二模型。
在一个可能的实施例中,所述至少部分的根据所述第一相关性分数集和所述第二相关性分数集确定蒸馏损失,包括:根据所述第一相关性分数集确定第一得分矩阵,以及根据所述第二相关性分数集确定第二得分矩阵;至少部分的根据所述第一得分矩阵和所述第二得分矩阵确定蒸馏矩阵;根据所述蒸馏矩阵确定蒸馏损失。
在一个可能的实施例中,所述至少部分的根据所述第一得分矩阵和所述第二得分矩阵确定蒸馏矩阵,包括:构建成对铰链损失pairwise hinge loss函数模型;将所述第一得分矩阵和所述第二得分矩阵输入至pairwise hinge loss函数模型,输出蒸馏矩阵。
在一个可能的实施例中,其中,所述第一模型和第二模型的输入内容以及输出内容不同,其中:所述第一模型为交互模型,所述第二模型为表示模型;或者,所述第一模型为表示模型,所述第二模型为交互模型。
在一个可能的实施例中,其中,所述交互模型包括输入层、Transformer层以及输出层,输入为:s1,s2,...sn,输出为:s1向量,s2向量,...,sn向量,其中:si是query和doci合并的语句;所述表示模型包括输入层、Transformer层以及输出层,输入为:query,doc1,...,docn,输出为:query向量,doc1向量,...,docn向量。
本发明实施例的第二方面,提供一种文本查询方法,所述方法包括:获取查询文本和候选文本;将所述查询文本和候选文本输入至预先训练好的文本查询模型中,输出查询文本和候选文本之间的得分;其中,所述文本查询模型为通过蒸馏学习交互模型训练得到的表示模型,且蒸馏过程中蒸馏损失是通过pairwise hinge loss函数模型确定的;根据所述查询文本和候选文本之间的得分输出所述查询文本相匹配的目标候选文本。
本发明实施例的第三方面,一种通过蒸馏学习第一模型训练第二模型的装置,所述第一模型和第二模型为不同类型的模型,所述装置包括:训练模块,被配置为利用已标注的数据集训练第一模型;输入模块,被配置为将迁移数据集输入至训练好的第一模型和所述第二模型;第一输出模块,被配置为输出第一相关性分数集和第二相关性分数集;确定模块,被配置为根据所述第一相关性分数集和所述第二相关性分数集确定蒸馏损失;优化模块,被配置为根据所述蒸馏损失优化所述第二模型的参数,得到训练好的第二模型。
在一个可能的实施例中,所述确定模块被配置为具体用于:根据所述第一相关性分数集确定第一得分矩阵,以及根据所述第二相关性分数集确定第二得分矩阵;至少部分的根据所述第一得分矩阵和所述第二得分矩阵确定蒸馏矩阵;根据所述蒸馏矩阵确定蒸馏损失。
在一个可能的实施例中,所述确定模块被配置为具体还用于:构建成对铰链损失pairwise hinge loss函数模型;将所述第一得分矩阵和所述第二得分矩阵输入至pairwise hinge loss函数模型,输出蒸馏矩阵。
本发明实施例的第四方面,提供一种文本查询装置,所述装置包括:获取模块,被配置为获取查询文本和候选文本;处理模块,被配置为将所述查询文本和候选文本输入至预先训练好的文本查询模型中,输出查询文本和候选文本之间的得分;其中,所述文本查询模型为通过蒸馏学习交互模型训练得到的表示模型,且蒸馏过程中的蒸馏损失是通过pairwise hinge loss函数模型确定的;第二输出模块,根据所述查询文本和候选文本之间的得分输出所述查询文本相匹配的目标候选文本。
本发明实施例的第五方面,提供一种电子设备,包括:存储器、处理器以及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如第一方面所述的方法。
本发明实施例的第六方面,提供一种计算机可读存储介质,所述计算机可读存储介质上存储有可执行指令,该指令被处理器执行时使处理器执行如第一方面所述的方法。
本发明实施例提供的模型蒸馏学习方法、文本查询方法及装置,首先,通过利用已标注的数据集训练第一模型;其次,将迁移数据集输入至训练好的第一模型和第二模型,分别输出第一相关性分数集和第二相关性分数集;然后,至少部分的根据第一相关性分数集和第二相关性分数集确定蒸馏损失;最后,根据蒸馏损失优化第二模型的参数,得到训练好的第二模型,其中:第一模型和第二模型为不同类型的模型。上述的方法通过确定第一模型和第二模型的参数差值作为第二模型的优化参数,能够准确找到第一模型和第二模型的差异,从而能够优化第二模型参数,能够维持或提高模型准确性的同时能够压缩模型大小,大量节省计算资源和存储资源,在提高模型准确度后,从而能够提高文本查询时的准确性。
本发明的其他特征和优点将在随后的说明书阐述,并且,部分地从说明书中变得显而易见,或者通过实施本发明实施例而了解。本发明的目的和其他优点可通过在所写的说明书以及附图中所特别指出的结构来实现和获得。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。通过附图所示,本申请的上述及其它目的、特征和优势将更加清晰。在全部附图中相同的附图标记指示相同的部分。并未刻意按实际尺寸等比例缩放绘制附图,重点在于示出本申请的主旨。
图1示出了本发明实施例提供的一种通过蒸馏学习方式训练学生模型的示意图;
图2示出了本发明实施例提供的交互模型的结构示意图;
图3示出了本发明实施例提供的表示模型的结构示意图;
图4示出了本发明实施例提供的一种通过蒸馏学习第一模型训练第二模型的方法流程图;
图5示出了本发明实施例提供的蒸馏学习模型方式和直接训练模型方式得到的精度对比示意图;
图6示出了本发明实施例提供的一种文本查询方法的流程图;
图7示出了本发明实施例提供的一种通过蒸馏学习第一模型训练第二模型的装置的结构示意图;
图8示出了本发明实施例提供的一种文本查询装置的结构示意图;
图9示出了本发明实施例提供的一种电子设备的结构示意图。
具体实施方式
以下,将参照附图来描述本发明的实施例。但是应该理解,这些描述只是示例性的,而并非要限制本发明的范围。此外,在以下说明中,省略了对公知结构和技术的描述,以避免不必要地混淆本发明的概念。
在此使用的术语仅仅是为了描述具体实施例,而并非意在限制本发明。这里使用的词语“一”、“一个(种)”和“该”等也应包括“多个”、“多种”的意思,除非上下文另外明确指出。此外,在此使用的术语“包括”、“包含”等表明了所述特征、步骤、操作和/或部件的存在,但是并不排除存在或添加一个或多个其他特征、步骤、操作或部件。
在此使用的所有术语(包括技术和科学术语)具有本领域技术人员通常所理解的含义,除非另外定义。应注意,这里使用的术语应解释为具有与本说明书的上下文相一致的含义,而不应以理想化或过于刻板的方式来解释。
目前,利用知识蒸馏的思想来进行训练模型的技术中,通常都是基于同类型的模型之间进行学习,比如,用表示模型的学生模型来学习老师模型的参数,或者用交互模型的学生模型来学习老师模型的参数,这里的学生模型和老师模型为同一类型的模型(同为表示模型或交互模型)。这种方式仅能够实现同类型模型之间的蒸馏学习,而对于跨不同类型间的蒸馏学习,目前暂未有相关人员进行研究。为此,本发明实施例提供一种跨不同类型模型间的蒸馏学习方法,用以解决现有技术中在面对不同类型的模型时无法进行蒸馏学习的技术难题。
本发明实施例提供的模型蒸馏学习方法、文本查询方法及装置,首先,通过利用已标注的数据集训练第一模型;其次,将迁移数据集输入至训练好的第一模型和第二模型,分别输出第一相关性分数集和第二相关性分数集;然后,根据第一相关性分数集和第二相关性分数集确定蒸馏损失;最后,根据蒸馏损失优化第二模型的参数,得到训练好的第二模型,其中:第一模型和第二模型为不同类型的模型。上述的方法通过确定第一模型和第二模型的参数差值作为第二模型的优化参数,能够准确找到第一模型和第二模型的差异,从而能够优化第二模型参数,能够维持或提高模型准确性的同时能够压缩模型大小,大量节省计算资源和存储资源,在提高模型准确度后,从而能够提高文本查询时的准确性。以下将结合附图1-6对本发明实施的技术内容进行详细说明,具体可以参考下文。
如图1所示为本发明实施例提供的一种通过蒸馏学习方式训练学生模型的示意图。通过图1可以得知本发明的发明思想,其具体内容如下:通过将迁移数据分别输入至Student和Teacher中,得到这两个模型的输出结果,根据这两个模型的输出结果间的差异,通过pairwise hinge loss来计算蒸馏损失,根据蒸馏损失来优化学生模型,从而得到训练好的学生模型。需要说明的是,这里的Student和Teacher采用不同类型的模型,例如,Student为表示模型,Teacher为交互模型。基于图1所示的发明思想,下文将详细描述具体的实现过程。
如图2所示,为本发明实施例提供的交互模型的结构示意图。可以看出该交互模型可以为BERT模型,包括输入层、Transformer层以及输出层,其中,Transformer层包括多个Transformer结构,输入层对应的输入内容为:s1,s2,...sn,该si用于表示query和doci合并的语句,输出层对应的输出内容为:s1向量,s2向量,...,sn向量。通过该BERT模型能够将输入文本转化为向量表现形式,然后经过池化和全连接层后,计算得到query与doc间的得分。
如图3所示,为本发明实施例提供的表示模型的结构示意图。可以看出该表示模型可以为BERT模型,包括输入层、Transformer层以及输出层,其中,Transformer层包括多个Transformer结构,输入层对应的输入内容为:query,doc1,...,docn,该query,doc1,...,docn用于表示查询文本和多个候选文本,输出层对应的输出内容为:query向量,doc1向量,...,docn向量。通过该BERT模型能够将输入文本转化为向量表现形式,根据输出层的向量来计算query向量与doc向量之间的相关性,从而得到query与doc间的得分。
基于上述的图2和图3的内容,可以看出两个模型的输入内容和输出内容不同,因此可以将该不同的输入和输出的模型称为不同类型的模型。
如图4所示,为本发明实施例提供的一种通过蒸馏学习第一模型训练第二模型的方法流程图。这里第一模型和第二模型为不同类型的模型。比如,第一模型为图2所示的交互模型,第二模型为图3所示的表示模型;或者,第一模型为图3所示的表示模型,第二模型为图2所示的交互模型;下面以第一模型为交互模型,第二模型为表示模型为例,进行详细的描述本发明的详细技术内容。该方法包括:
401、利用已标注的数据集训练第一模型。
示例性的,可以在网站的语料上预先训练一个基于RoBERTa-large的交互模型,并在已标注的数据集上进行相关性任务微调,这里的已标注数据集为人工标注的数据集,通过人工标注的数据集来训练该交互模型使得该模型的精度较高,从而利用该交互模型输出的相关性分数较为准确,从而提高学生模型的准确度。
402、将迁移数据集输入至训练好的第一模型和第二模型,分别输出第一相关性分数集和第二相关性分数集。
示例性的,这里可以从数据库随机选取近百万的无标注数据,将该无标注数据作为迁移数据,该迁移数据集包括近百万的无标注数据。从数据库中直接获取迁移数据集,无需进行人工标注,可以节省时间和人力成本。
假设有n个文档,通过将上述的迁移数据集输入至训练好的老师模型(交互模型)和学生模型(表示模型)后,得到的Teacher给出的相关性分数集为t1,...,tn,Student给出相关性分数集为s1,...,sn。
403、至少部分的根据第一相关性分数集和所述第二相关性分数集确定蒸馏损失。
示例性的,上述的步骤403可以通过以下内容实现:
403a、根据第一相关性分数集确定第一得分矩阵,以及根据第二相关性分数集确定第二得分矩阵。
403b、至少部分的根据第一得分矩阵和第二得分矩阵确定蒸馏矩阵。
403c、根据蒸馏矩阵确定蒸馏损失。
基于上述的n个文档,Teacher给出相关性分数集t1,...,tn后,令a_{ij}=sign(ti-tj),Student给出相关性分数集s1,...,sn后,令b_{ij}=si-sj,其中:a_{ij}表示Teacher预测的第i个文档是否比第j个文档分数高所形成的矩阵,sign代表符号位,当ti>tj时,该a_{ij}为1,当ti<tj时,该a_{ij}为-1,当ti=tj时,该a_{ij}为0;b_{ij}表示Student预测的第i个文档与第j个文档分数差值所形成的矩阵。这里以举例方式进行说明第一得分矩阵和第二得分矩阵,具体如下:
例如,t1=3,t2=2,t3=1,经过a_{ij}=sign(ti-tj)后,得到的第一得分矩阵为:
Figure 121956DEST_PATH_IMAGE001
例如,s1=1,s2=3,s3=2,经过b_{ij}=si-sj后,得到的第二得分矩阵为:
Figure 901693DEST_PATH_IMAGE002
示例性的,上述的步骤403b可以通过以下内容实现:
403b1、构建pairwise hinge loss函数模型。
403b2、将第一得分矩阵和所述第二得分矩阵输入至pairwise hinge loss函数模型,输出蒸馏矩阵。
示例性的,本发明实施例中提供的pairwise hinge loss函数模型如下:
p_{ij}=max(0,margin-a_{ij}*b_{ij})*abs(a_{ij}),其中margin为设定值,可以根据实际情况进行调整,通常情况下,可以设置为1; p_{ij}为蒸馏矩阵,a_{ij}为第一得分矩阵,b_{ij}为第二得分矩阵,abs代表绝对值。通过将上述的b_{ij}和b_{ij}代入至pairwise hinge loss函数模型中,就可以得到蒸馏矩阵p_{ij}。
进一步可选的,在得到蒸馏矩阵后p_{ij},可以计算该蒸馏矩阵p_{ij}的平均值,将该平均值作为蒸馏损失,即loss=ave(p_{ij}),loss用于表示蒸馏损失。
404、根据蒸馏损失优化第二模型的参数,得到训练好的第二模型。
示例性的,这里的第二模型为表示模型,即Student。根据蒸馏损失对Student进行优化,直到获得训练好的Student,这样获得的Student精度更高,在文本查询时的准确性也较高,并且实现了在表示模型和交互模型间进行蒸馏学习,提高了表示模型训练的精度、节省表示模型训练的时间,以及降低表示模型训练的复杂度。
在进行实验的过程中,发明人发现,采用上述的蒸馏方式训练得到的模型,相较于现有技术中直接进行训练表示模型,蒸馏交互模型能够产出预测能力更强的表示模型。其中,蒸馏得到的128维的表示模型的精度超过了直接训练得到的未经压缩的768维表示模型。
如图5所示,为本发明实施例提供的蒸馏学习模型方式和直接训练模型方式得到的精度对比示意图。其中,最后1行是通过直接训练的方式得到768维的表示模型,而第1行到第7行的分别是通过蒸馏方式得到的8维、16维、32维、64维、128维、256维度、384维度以及768维度的结果。
从搜索结果的前五名(对应图5中的valid_ndcg@5)来看,直接训练方式得到的768维对应的精度为0.865016,而蒸馏方式128维对应的精度为0.870348,其超过了直接训练表示模型方式的精度;从搜索结果的前十名(对应图5中的valid_ndcg@10)来看,直接训练方式得到的精度为0.893695,而蒸馏方式64维对应的精度为0.893961,其超过了直接训练表示模型方式的精度。因此,本发明在压缩表示模型维度的同时,超出了直接训练得到的未经压缩的表示模型的精度,并且能够实现跨不同类型的模型蒸馏。
需要说明的是,当第一模型为表示模型,第二模型为交互模型,通过蒸馏学习表示模型来训练交互模型的内容和图4所描述的相关内容类似,相应的,老师模型为表示模型,学生模型为交互模型,同样可以实现跨不同类型的蒸馏学习,具体可以参见上文的具体内容,这里就不再赘述。
如图6所示,为本发明实施例提供的一种文本查询方法的流程图。该方法包括:
601、获取查询文本和候选文本。
示例性的,上述的查询文本可以是一个或多个词语、一个句子等,上述的候选文本也可以是一个句子、一段话或者一篇文章。上述的候选文本可以是一个,也可以是多个。
602、将查询文本和候选文本输入至预先训练好的文本查询模型中,输出查询文本和候选文本之间的得分。
其中,上述的文本查询模型为通过蒸馏学习交互模型训练得到的表示模型,且蒸馏过程中的蒸馏损失是通过pairwise hinge loss函数模型确定的,其表示模型的训练过程可以参照上文图1-图5的对应内容。
603、根据查询文本和候选文本之间的得分输出查询文本相匹配的目标候选文本。
其中,可以设定一个阈值,当候选文本只有一个,可以通过阈值判断该候选文本是否为查询文本的相关文本,当候选文本包括多个,则将大于阈值的候选文本列表作为目标候选文本。
本发明实施例采用pairwise hinge loss作为损失函数,能够使得学生模型学习到更多老师模型中的知识,从而获得的文本查询模型能够准确地从候选文本中获得与查询文本相匹配的目标文本。
需要说明的是,本发明实施例提供的模型蒸馏方法所获得的表示除了能够应用在文本检索这一场景外,还可以用于机器翻译、文本分类等应用场景,其不同的应用场景,选择的训练样本不同即可。
下面将基于图4对应的通过蒸馏学习第一模型训练第二模型的方法的实施例中的相关描述对本发明实施例提供的一种通过蒸馏学习第一模型训练第二模型的装置进行介绍。以下实施例中与上述实施例相关的技术术语、概念等的说明可以参照上述的实施例。
如图7所示,为本发明实施例提供的一种通过蒸馏学习第一模型训练第二模型的装置的结构示意图。该装置7包括:训练模块701,输入模块702、第一输出模块703、确定模块704以及优化模块705,其中:
训练模块701,被配置为利用已标注的数据集训练第一模型;输入模块702,被配置为将迁移数据集输入至训练好的第一模型和第二模型;第一输出模块703,被配置为输出第一相关性分数集和第二相关性分数集;确定模块704,被配置为根据第一相关性分数集和第二相关性分数集确定蒸馏损失;优化模块705,被配置为根据蒸馏损失优化第二模型的参数,得到训练好的第二模型。
作为一种优选的实施方式,确定模块704被配置为具体用于:根据第一相关性分数集确定第一得分矩阵,以及根据第二相关性分数集确定第二得分矩阵;至少部分的根据第一得分矩阵和第二得分矩阵确定蒸馏矩阵;根据蒸馏矩阵确定蒸馏损失。
作为一种优选的实施方式,确定模块704被配置为具体还用于:构建成对铰链损失pairwise hinge loss函数模型;将第一得分矩阵和第二得分矩阵输入至pairwise hingeloss函数模型,输出蒸馏矩阵。
作为一种优选的实施方式,其中,第一模型和第二模型的输入数据类型以及输出类型不同,第一模型为交互模型,第二模型为表示模型,交互模型和表示模型的内容具体参见图2和图3对应的内容。
作为一种优选的实施方式,其中,交互模型包括输入层、Transformer层以及输出层,输入为:s1,s2,...sn,输出为:s1向量,s2向量,...,sn向量,其中:si是query和doci合并的语句;表示模型包括输入层、Transformer层以及输出层,输入为:query,doc1,...,docn,输出为:query向量,doc1向量,...,docn向量。
下面将基于图6对应的文本查询方法的实施例中的相关描述对本发明实施例提供的一种文本查询装置进行介绍。以下实施例中与上述实施例相关的技术术语、概念等的说明可以参照上述的实施例。
如图8所示,为本发明实施例提供的一种文本查询装置的结构示意图,装置包括:获取模块801、处理模块802以及第二输出模块803,其中:
获取模块801,被配置为获取查询文本和候选文本;处理模块802,被配置为将查询文本和候选文本输入至预先训练好的文本查询模型中,输出查询文本和候选文本之间的得分;其中,文本查询模型为通过蒸馏学习交互模型训练得到的表示模型,且蒸馏过程中的蒸馏损失是通过pairwise hinge loss函数模型确定的;第二输出模块803,根据查询文本和候选文本之间的得分输出查询文本相匹配的目标候选文本。
本发明实施例提供的模型蒸馏学习装置以及文本查询装置,首先,通过利用已标注的数据集训练第一模型;其次,将迁移数据集输入至训练好的第一模型和第二模型,分别输出第一相关性分数集和第二相关性分数集;然后,至少部分的根据第一相关性分数集和第二相关性分数集确定蒸馏损失;最后,根据蒸馏损失优化第二模型的参数,得到训练好的第二模型,其中:第一模型和第二模型为不同类型的模型。上述的方法通过确定第一模型和第二模型的参数差值作为第二模型的优化参数,能够准确找到第一模型和第二模型的差异,从而能够优化第二模型参数,使得训练后的第二模型精度较高,从而在文本查询时的准确性也较高。此外,本方案能够实现不同类型模型间的蒸馏,有利于提高模型训练的精度、节省模型训练的时间,以及降低模型训练的复杂度。
如图9所示,为本发明实施例提供的一种电子设备的结构示意图,该电子设备900包括中央处理单元(CPU)901,其可以根据存储在只读存储器(ROM)902中的程序或者从存储部分908加载到随机访问存储器(RAM)903中的程序而执行如图7所示的各种适当的动作和处理。在RAM 903中,还存储有电子设备900操作所需的各种程序和数据。CPU 901、ROM 902以及RAM 903通过总线904彼此相连。输入/输出(I/O)接口905也连接至总线904。
以下部件连接至I/O接口905:包括键盘、鼠标等的输入部分906;包括诸如阴极射线管(CRT)、液晶显示器(LCD)等以及扬声器等的输出部分907;包括硬盘等的存储部分908;以及包括诸如LAN卡、调制解调器等的网络接口卡的通信部分909。通信部分909经由诸如因特网的网络执行通信处理。驱动器910也根据需要连接至I/O接口905。可拆卸介质911,诸如磁盘、光盘、磁光盘、半导体存储器等等,根据需要安装在驱动器910上,以便于从其上读出的计算机程序根据需要被安装入存储部分908。
通过以上的实施方式的描述,所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,仅以上述各功能模块的划分举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将装置的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。上述描述的系统,装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
以上,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (10)

1.一种通过蒸馏学习第一模型训练第二模型的方法,其特征在于,所述第一模型和第二模型为不同类型的模型,所述方法包括:
利用已标注的数据集训练第一模型;
将迁移数据集输入至训练好的第一模型和所述第二模型,分别输出第一相关性分数集和第二相关性分数集;
至少部分的根据所述第一相关性分数集和所述第二相关性分数集确定蒸馏损失;
根据所述蒸馏损失优化所述第二模型的参数,得到训练好的第二模型。
2.根据权利要求1所述的方法,其特征在于,所述至少部分的根据所述第一相关性分数集和所述第二相关性分数集确定蒸馏损失,包括:
根据所述第一相关性分数集确定第一得分矩阵,以及根据所述第二相关性分数集确定第二得分矩阵;
至少部分的根据所述第一得分矩阵和所述第二得分矩阵确定蒸馏矩阵;
根据所述蒸馏矩阵确定蒸馏损失。
3.根据权利要求2所述的方法,其特征在于,所述至少部分的根据所述第一得分矩阵和所述第二得分矩阵确定蒸馏矩阵,包括:
构建成对铰链损失pairwise hinge loss函数模型;
将所述第一得分矩阵和所述第二得分矩阵输入至pairwise hinge loss函数模型,输出蒸馏矩阵。
4.根据权利要求1所述的方法,其特征在于,所述第一模型和第二模型的输入内容以及输出内容不同,其中:所述第一模型为交互模型,所述第二模型为表示模型;或者,所述第一模型为表示模型,所述第二模型为交互模型。
5.根据权利要求4所述的方法,其特征在于,所述交互模型包括输入层、Transformer层以及输出层,输入为:s1,s2,...sn,输出为:s1向量,s2向量,...,sn向量,其中:si是query和doci合并的语句;
所述表示模型包括输入层、Transformer层以及输出层,输入为:query,doc1,...,docn,输出为:query向量,doc1向量,...,docn向量。
6.一种文本查询方法,其特征在于,所述方法包括:
获取查询文本和候选文本;
将所述查询文本和候选文本输入至预先训练好的文本查询模型中,输出查询文本和候选文本之间的得分;其中,所述文本查询模型为通过蒸馏学习交互模型训练得到的表示模型,且蒸馏过程中的蒸馏损失是通过pairwise hinge loss函数模型确定的;
根据所述查询文本和候选文本之间的得分输出所述查询文本相匹配的目标候选文本。
7.一种通过蒸馏学习第一模型训练第二模型的装置,其特征在于,所述第一模型和第二模型为不同类型的模型,所述装置包括:
训练模块,被配置为利用已标注的数据集训练第一模型;
输入模块,被配置为将迁移数据集输入至训练好的第一模型和所述第二模型;
第一输出模块,被配置为输出第一相关性分数集和第二相关性分数集;
确定模块,被配置为至少部分的根据所述第一相关性分数集和所述第二相关性分数集确定蒸馏损失;
优化模块,被配置为根据所述蒸馏损失优化所述第二模型的参数,得到训练好的第二模型。
8.一种文本查询装置,其特征在于,所述装置包括:
获取模块,被配置为获取查询文本和候选文本;
处理模块,被配置为将所述查询文本和候选文本输入至预先训练好的文本查询模型中,输出查询文本和候选文本之间的得分;其中,所述文本查询模型为通过蒸馏学习交互模型训练得到的表示模型,且蒸馏过程中的蒸馏损失是通过pairwise hinge loss函数模型确定的;
第二输出模块,根据所述查询文本和候选文本之间的得分输出所述查询文本相匹配的目标候选文本。
9.一种电子设备,包括:存储器、处理器以及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1-6任一项所述的方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有可执行指令,该指令被处理器执行时使处理器执行如权利要求1-6任一项所述的方法。
CN202011275406.2A 2020-11-16 2020-11-16 一种模型蒸馏学习方法、文本查询方法及装置 Active CN112101573B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011275406.2A CN112101573B (zh) 2020-11-16 2020-11-16 一种模型蒸馏学习方法、文本查询方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011275406.2A CN112101573B (zh) 2020-11-16 2020-11-16 一种模型蒸馏学习方法、文本查询方法及装置

Publications (2)

Publication Number Publication Date
CN112101573A true CN112101573A (zh) 2020-12-18
CN112101573B CN112101573B (zh) 2021-04-30

Family

ID=73785536

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011275406.2A Active CN112101573B (zh) 2020-11-16 2020-11-16 一种模型蒸馏学习方法、文本查询方法及装置

Country Status (1)

Country Link
CN (1) CN112101573B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113312455A (zh) * 2021-06-23 2021-08-27 北京鼎泰智源科技有限公司 一种基于知识蒸馏的合同智能审核方法及装置

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170235848A1 (en) * 2012-08-29 2017-08-17 Dennis Van Dusen System and method for fuzzy concept mapping, voting ontology crowd sourcing, and technology prediction
CN109344897A (zh) * 2018-09-29 2019-02-15 中山大学 一种基于图片蒸馏的通用物体检测框架及其实现方法
CN111553479A (zh) * 2020-05-13 2020-08-18 鼎富智能科技有限公司 一种模型蒸馏方法、文本检索方法及装置
CN111581929A (zh) * 2020-04-22 2020-08-25 腾讯科技(深圳)有限公司 基于表格的文本生成方法及相关装置

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170235848A1 (en) * 2012-08-29 2017-08-17 Dennis Van Dusen System and method for fuzzy concept mapping, voting ontology crowd sourcing, and technology prediction
CN109344897A (zh) * 2018-09-29 2019-02-15 中山大学 一种基于图片蒸馏的通用物体检测框架及其实现方法
CN111581929A (zh) * 2020-04-22 2020-08-25 腾讯科技(深圳)有限公司 基于表格的文本生成方法及相关装置
CN111553479A (zh) * 2020-05-13 2020-08-18 鼎富智能科技有限公司 一种模型蒸馏方法、文本检索方法及装置

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113312455A (zh) * 2021-06-23 2021-08-27 北京鼎泰智源科技有限公司 一种基于知识蒸馏的合同智能审核方法及装置

Also Published As

Publication number Publication date
CN112101573B (zh) 2021-04-30

Similar Documents

Publication Publication Date Title
CN109740126B (zh) 文本匹配方法、装置及存储介质、计算机设备
CN111553479B (zh) 一种模型蒸馏方法、文本检索方法及装置
WO2021135455A1 (zh) 语义召回方法、装置、计算机设备及存储介质
CN112287069B (zh) 基于语音语义的信息检索方法、装置及计算机设备
CN112131401B (zh) 一种概念知识图谱构建方法和装置
US11238050B2 (en) Method and apparatus for determining response for user input data, and medium
CN113297360B (zh) 基于弱监督学习和联合学习机制的法律问答方法及设备
CN111552773A (zh) 一种阅读理解任务中是否类问题关键句寻找方法及系统
CN111274822A (zh) 语义匹配方法、装置、设备及存储介质
CN110909539A (zh) 语料库的词语生成方法、系统、计算机设备和存储介质
CN110852069A (zh) 一种文本相关性评分方法及系统
CN115827819A (zh) 一种智能问答处理方法、装置、电子设备及存储介质
CN112115252A (zh) 智能辅助写作处理方法、装置、电子设备及存储介质
CN112347758A (zh) 文本摘要的生成方法、装置、终端设备及存储介质
CN111666376A (zh) 一种基于段落边界扫描预测与词移距离聚类匹配的答案生成方法及装置
CN112686053A (zh) 一种数据增强方法、装置、计算机设备及存储介质
CN116541493A (zh) 基于意图识别的交互应答方法、装置、设备、存储介质
CN116881425A (zh) 一种通用型文档问答实现方法、系统、设备及存储介质
CN111506596A (zh) 信息检索方法、装置、计算机设备和存储介质
CN112101573B (zh) 一种模型蒸馏学习方法、文本查询方法及装置
CN112800205B (zh) 基于语义变化流形分析获取问答相关段落的方法、装置
CN110597960A (zh) 一种个性化在线课程与职业双向推荐方法及系统
CN113822040A (zh) 一种主观题阅卷评分方法、装置、计算机设备及存储介质
CN113505786A (zh) 试题拍照评判方法、装置及电子设备
CN113204679B (zh) 一种代码查询模型的生成方法和计算机设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant