CN111553479B - 一种模型蒸馏方法、文本检索方法及装置 - Google Patents
一种模型蒸馏方法、文本检索方法及装置 Download PDFInfo
- Publication number
- CN111553479B CN111553479B CN202010405217.6A CN202010405217A CN111553479B CN 111553479 B CN111553479 B CN 111553479B CN 202010405217 A CN202010405217 A CN 202010405217A CN 111553479 B CN111553479 B CN 111553479B
- Authority
- CN
- China
- Prior art keywords
- layer
- model
- output
- matrix
- distillation loss
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000004821 distillation Methods 0.000 title claims abstract description 187
- 238000000034 method Methods 0.000 title claims abstract description 66
- 238000012549 training Methods 0.000 claims abstract description 52
- 239000011159 matrix material Substances 0.000 claims description 116
- 238000006243 chemical reaction Methods 0.000 claims description 27
- 230000009466 transformation Effects 0.000 claims description 20
- 230000006870 function Effects 0.000 claims description 19
- 230000007704 transition Effects 0.000 claims description 14
- 238000004364 calculation method Methods 0.000 claims description 13
- 239000002689 soil Substances 0.000 claims description 11
- 238000013528 artificial neural network Methods 0.000 claims description 7
- 238000004891 communication Methods 0.000 claims description 6
- 230000001131 transforming effect Effects 0.000 description 12
- 238000010586 diagram Methods 0.000 description 10
- 230000008569 process Effects 0.000 description 7
- 238000004422 calculation algorithm Methods 0.000 description 6
- 238000012545 processing Methods 0.000 description 5
- 238000005457 optimization Methods 0.000 description 4
- 230000009471 action Effects 0.000 description 3
- 230000006835 compression Effects 0.000 description 3
- 238000007906 compression Methods 0.000 description 3
- 238000004590 computer program Methods 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 230000008901 benefit Effects 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 238000013140 knowledge distillation Methods 0.000 description 2
- 230000007246 mechanism Effects 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- NAWXUBYGYWOOIX-SFHVURJKSA-N (2s)-2-[[4-[2-(2,4-diaminoquinazolin-6-yl)ethyl]benzoyl]amino]-4-methylidenepentanedioic acid Chemical compound C1=CC2=NC(N)=NC(N)=C2C=C1CCC1=CC=C(C(=O)N[C@@H](CC(=C)C(O)=O)C(O)=O)C=C1 NAWXUBYGYWOOIX-SFHVURJKSA-N 0.000 description 1
- 230000004913 activation Effects 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000012827 research and development Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
- Machine Translation (AREA)
Abstract
本申请提供一种模型蒸馏方法、文本检索方法及装置。方法包括:获取老师模型和学生模型,将训练样本分别输入老师模型和学生模型中;利用EMD计算老师模型的第一transformer层的输出与学生模型的第二transformer层的输出的第一蒸馏损失,并分别计算老师模型的第一embedding层的输出与学生模型的第二embedding层的输出的第二蒸馏损失,以及老师模型的第一prediction层的输出与学生模型的第二prediction层的输出的第三蒸馏损失;根据第一蒸馏损失、第二蒸馏损失和第三蒸馏损失对学生模型中的参数进行优化获得训练好的学生模型。本申请能够保证学生模型从老师模型中学习到更多的知识。
Description
技术领域
本申请涉及自然语言处理技术领域,具体而言,涉及一种模型蒸馏方法、文本检索方法及装置。
背景技术
随着深度学习的发展,自然语言处理中深度神经网络的使用越来越多,然而很多模型,如Bert,xlnet,都存在模型复杂,参数量大,训练时间长,内存消耗大,推理时间长的问题,很难直接应用于图形处理器(Graphics Processing Unit,GPU)及智能手机等应用资源受限的设备上。
为了解决这些问题,研究者们提出了许多模型压缩的方法,以减少受训神经网络的冗余,对于nlp未来的发展而言非常有价值。知识蒸馏就是一种对模型压缩的方法,该方法是通过老师模型中的某一层与学生模型的某一层进行对应,通过学习将学生模型的某一层对老师模型的对应层的距离尽量缩小,以实现学生模型对老师模型的学习。这种模型压缩方法学生模型能够学习到老师模型中的知识有限,因此,训练获得的学生模型在文本检索时准确性不高。
发明内容
本申请实施例的目的在于提供一种模型蒸馏方法、文本检索方法及装置,用以解决现有技术获得的学生模型在文本检索时准确性不高的问题。
第一方面,本申请实施例提供一种模型蒸馏方法,包括:获取老师模型和学生模型,其中,所述老师模型包括第一向量embedding层、第一转换器transformer层和第一预测prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数;获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;其中,所述老师模型为经过预先训练获得的;利用搬土距离EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得训练好的学生模型。
本申请实施例中,发明人在研发过程中发现Bert模型中的transformer层对模型的贡献最大,包含的信息最丰富,学生模型在该层的学习能力也最为重要,因此,基于EMD对老师模型进行蒸馏,能够保证学生模型学习到更多的老师模型的知识。
进一步地,所述利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,包括:获取所述第一transformer层中各层分别输出的第一注意力attention矩阵以及所述第二transformer层中各层分别输出的第二attention矩阵;根据所述第一attention矩阵和所述第二attention矩阵计算第一EMD距离;获取所述第一transformer层中各层分别输出的第一全连接前馈神经网络FFN隐层矩阵和所述第二transformer层中各层分别输出的第二FFN隐层矩阵;根据所述第一FFN隐层矩阵和所述第二FFN隐层矩阵计算第二EMD距离;根据所述第一EMD距离和所述第二EMD距离获得所述第一蒸馏损失。
本申请实施例通过分别计算第一transformer层中各层与第二transformer层中各层的之间的EMD距离,能够将第一transformer层中的知识尽可能多的转移给学生模型的第二transformer层。
进一步地,所述根据所述第一attention矩阵和所述第二attention矩阵计算第一EMD距离,包括:根据计算获得所述第一EMD距离;其中,Lattn为所述第一EMD距离,AT为所述第一attention矩阵,AS为所述第二attention矩阵,为第一attention矩阵与第二attention矩阵之间的均方误差,且为第i层第一transformer层的第一attention矩阵,/>为第j层第二transformer层的第二attention矩阵,fij为从第i层第一transformer层迁移到第j层第二transformer层的知识量,M为第一transformer层的层数,N为第二transformer层的层数。
本申请实施例通过分别计算第一transformer层中各层对应的第一attention矩阵与第二transformer层中各层的第二attention矩阵之间的EMD距离,为获得第一transformer层与第二transformer层之间的EMD距离提供数据依据。
进一步地,所述根据所述第一FFN隐层矩阵和所述第二FFN隐层矩阵计算第二EMD距离,包括:根据计算获得第二EMD距离;其中,Lffn为所述第二EMD距离,HT为第一transformer层的第一FFN隐层矩阵,HS为第二transformer层的第二FFN隐层矩阵,/>为第一FFN隐层矩阵与第二FFN隐层矩阵的均方误差,且为第j层第二transformer层的第二FFN隐层矩阵,Wh为第一预设转换矩阵,/>为第i层第一transformer层的第一FFN隐层矩阵。fij为从第i层第一transformer层迁移到第j层第二transformer层的知识量,M为第一transformer层的层数,N为第二transformer层的层数。
本申请实施例通过分别计算第一transformer层中各层对应的第一FFN隐层矩阵与第二transformer层中各层的第二FFN隐层矩阵之间的EMD距离,为获得第一transformer层与第二transformer层之间的EMD距离提供数据依据。
进一步地,所述计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,包括:根据Le=MSE(ESWe,ET)计算获得所述第二蒸馏损失;其中,Le为所述第二蒸馏损失,ES为学生模型的embedding层输出的向量矩阵,We为第二预设转换矩阵,ET为老师模型的embedding层输出的向量矩阵。
本申请实施例通过均方误差能够准确地反映老师模型的embedding层与学生模型的embedding层之间的距离。
进一步地,所述计算所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失,包括:根据Lp=αLph+(1-α)T2Lps计算获得所述第三蒸馏损失;其中,Lp为所述第三蒸馏损失,α为每个损失的权重,T为温度,Lph为学生模型的输出与真实标签之间的交叉熵损失,Lps为学生模型的输出与老师模型的输出之间的交叉熵损失,且Lph=-Y·softmax(zS),Lps=-softmax(zT/T)·log_softmax(zS/T),Y为真实标签,zT为老师模型的输出,zS为学生模型的输出。
本申请实施例通过软硬目标线性结合的方式对老师模型的prediction层进行蒸馏,老师模型的prediction层输出的软信息中包括了更多的信息量,能够使学生学习到更多的知识。
进一步地,所述根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,包括:根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失输出获得所述学生模型的目标函数;其中,所述目标函数为Lmodel=∑i∈{e,t,p}λiLi(Si,Ti),e表示embedding层,t表示transformer层,p表示prediction层,λi表示相应的权重,Li表示相应层的蒸馏损失,Si表示学生模型,Ti表示老师模型;根据所述目标函数对所述学生模型中的参数进行优化。
本申请实施例通过EMD距离的算法计算第一transformer层与第二transformer层之间的第一蒸馏损失,通过均方误差的方式计算第一embedding层与第二embedding层的第二蒸馏损失,利用软硬目标结合的方式计算第一prediction层与第二prediction层的第三蒸馏损失,最终可以获得整个模型对应的准确的蒸馏损失,从而利用整体模型的蒸馏损失对学生模型进行训练,获得的学生模型能够从老师模型中学到更多的知识。
第二方面,本申请实施例提供一种文本检索方法,包括:获取查询文本和候选文本;将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中,获得所述文本检索模型输出的所述查询文本与所述候选文本的匹配率;其中,所述文本检索模型为Bert模型,且所述文本检索模型中的transformer层为基于搬土距离EMD对老师模型中的transformer层进行模型蒸馏后获得;确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本。
本申请实施例通过利用EMD算法对老师模型中的transformer层进行模型蒸馏,能够使得学生模型学习到更多老师模型中的知识,从而获得的文本检索模型能够准确地从候选文本中获得与查询文件相匹配的目标文本。
进一步地,在将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中之前,所述方法还包括:
获取老师模型和学生模型,其中,所述老师模型包括第一embedding层、第一transformer层和第一prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数;
获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;
利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;
根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得所述文本检索模型。
本申请实施例中,由于Bert模型中的transformer层对模型的贡献最大,包含的信息最丰富,学生模型在该层的学习能力也最为重要,因此,基于EMD对老师模型进行蒸馏,能够保证学生模型学习到更多的老师模型的知识。
进一步地,所述确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本,包括:若最大匹配率大于预设值,则将匹配率最大的候选文本作为与所述查询文本相匹配的目标候选文本;或将匹配率大于预设值的候选文本作为与所述查询文本相匹配的目标候选文本。
本申请实施例中通过预设要求并根据预设要求选取目标候选文本,能够保证选取与查询文本更加匹配的候选文本。
第三方面,本申请实施例提供一种模型蒸馏装置,包括:初始模型获取模块,用于获取老师模型和学生模型,其中,所述老师模型包括第一embedding层、第一transformer层和第一prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数;样本获取模块,用于获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;其中,所述老师模型为经过预先训练获得的;损失计算模块,用于利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;优化模块,用于根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得训练好的学生模型。
第四方面,本申请实施例提供一种文本检索装置,包括:文本获取模块,用于获取查询文本和候选文本;检索模块,用于将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中,获得所述文本检索模型输出的所述查询文本与所述候选文本的匹配率;其中,所述文本检索模型为Bert模型,且所述文本检索模型中的transformer层为基于搬土距离EMD对老师模型中的transformer层进行模型蒸馏后获得;目标确定模块,用于确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本。
第五方面,本申请实施例提供一种电子设备,包括:处理器、存储器和总线,其中,所述处理器和所述存储器通过所述总线完成相互间的通信;所述存储器存储有可被所述处理器执行的程序指令,所述处理器调用所述程序指令能够执行第一方面或第二方面的方法。
第六方面,本申请实施例提供一种非暂态计算机可读存储介质,包括:所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令使所述计算机执行第一方面或第二方面的方法。
本申请的其他特征和优点将在随后的说明书阐述,并且,部分地从说明书中变得显而易见,或者通过实施本申请实施例了解。本申请的目的和其他优点可通过在所写的说明书、权利要求书、以及附图中所特别指出的结构来实现和获得。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1为本申请实施例提供的一种模型蒸馏方法流程示意图;
图2为本申请实施例提供的基于EMD对Bert模型蒸馏框架图;
图3为本申请实施例提供的transformer层结构示意图;
图4为本申请实施例提供的一种文本检索方法流程示意图;
图5为本申请实施例提供的模型蒸馏装置结构示意图;
图6为本申请实施例提供的文本检索装置结构示意图;
图7为本申请实施例提供的电子设备实体结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行描述。
应理解,本申请实施例提供的模型蒸馏方法及文本检索方法可以应用于终端设备(也可以称为电子设备)以及服务器;其中,终端设备具体可以为智能手机、平板电脑、计算机、个人数字助理(Personal Digital Assitant,PDA)等;服务器具体可以为应用服务器,也可以为Web服务器。
为了便于理解本申请实施例提供的技术方案,下面以终端设备作为执行主体威力,对本申请实施例提供的模型蒸馏方法及文本检索方法的应用场景进行介绍。
图1为本申请实施例提供的一种模型蒸馏方法流程示意图,如图1所示,该方法包括:
步骤101:获取老师模型和学生模型,其中,所述老师模型包括第一embedding层、第一transformer层和第一prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数。
在具体的实施过程中,老师模型和学生模型均为Bert模型,老师模型即为已经训练好的网络层数较多的Bert模型,学生模型相对于老师模型来说,其网络层数较少,且需要根据老师模型进行学习,以获得训练好的学生模型。
Bert模型可以分为向量(embedding)层、转换器(transformer)层和预测(prediction)层,每种层是知识的不同表示形式,因此我们对每种层采取不同的学习方式,即每种层的目标函数不同。其中,transformer层对模型贡献最大,包含的信息最丰富,学生模型在该层的学习能力也最为重要,为了使小模型中的信息量能最大化保存,对transformer层提出了基于EMD的蒸馏方式。
图2为本申请实施例提供的基于EMD对Bert模型蒸馏框架图,如图2所示。老师模型和学生模型均由embedding层、transformer层和prediction层组成,为了便于区分,将老师模型中的各层分别命名为第一embedding层、第一transformer层和第一prediction层,将学生模型中的各层命名为第二embedding层、第二transformer层和第二prediction层。其中老师模型的第一transformer层有N层,学生模型的第二transformer层是M层(M<N),老师模型的隐层维度为d,学生模型的隐层维度为d'(d'<d)。
步骤102:获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中,其中,所述老师模型为经过预先训练获得的。
在具体的实施过程中,训练样本可以根据学生模型的具体应用场景进行选择,例如:若训练好的学生模型用于文本检索,那么训练样本包括查询文本以及至少一个候选文本,其标签即为至少一个候选文本中哪些是与查询文本相匹配的文本,哪些是不匹配的。又如:若训练好的学生模型用于文本分类,那么训练样本包括文本数据以及该文本数据对应的类别。将训练样本输入老师模型和学生模型中,根据各层的输出计算老师模型和学生模型之间的距离。老师模型的训练过程本申请实施例不作具体限定。
步骤103:利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失。
在具体的实施过程中,为了使学生模型可以从老师模型中学到尽可能丰富的知识,本申请实施例对每种层都进行了学习,由于每种层的知识表示方式不同,所以每种层都是用不同的学习方式,最终将每种层的知识整合到学生模型。
在将训练样本分别输入老师模型和学生模型中后,通过EMD算法计算第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,第一蒸馏损失用于表征老师模型的第一transformer层与学生模型的第二transformer层之间的差距;通过均方误差计算第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,第二蒸馏损失用于表征老师模型的第一embedding层与学生模型的第二embedding层之间的差距;通过软硬目标结合的方式计算第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失,第三蒸馏损失用于表征第一prediction层与第二prediction层之间的差距。这里软目标指的是老师模型的输出(soft target),硬目标指的是真实目标(hardtarget),在计算第三蒸馏损失的时候综合考虑了学生模型与老师模型的输出之间的损失和学生模型与真实目标之间的损失。
步骤104:根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得训练好的学生模型。
在具体的实施过程中,在获得第一蒸馏损失、第二蒸馏损失和第三蒸馏损失之后,可以获得老师模型与学生模型的蒸馏损失,并根据该蒸馏损失对学生模型中的参数进行优化,直至蒸馏损失小于预设值,或者训练的次数满足预设次数,从而获得训练好的学生模型。
本申请实施例中,由于Bert模型中的transformer层对模型的贡献最大,包含的信息最丰富,学生模型在该层的学习能力也最为重要,因此,基于EMD对老师模型进行蒸馏,能够保证学生模型学习到更多的老师模型的知识。
在上述实施例的基础上,所述利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,包括:
获取所述第一transformer层中各层分别输出的第一注意力(attention)矩阵以及所述第二transformer层中各层分别输出的第二attention矩阵;
根据所述第一attention矩阵和所述第二attention矩阵计算第一EMD距离;
获取所述第一transformer层中各层分别输出的第一全连接前馈神经网络FFN隐层矩阵和所述第二transformer层中各层分别输出的第二FFN隐层矩阵;
根据所述第一FFN隐层矩阵和所述第二FFN隐层矩阵计算第二EMD距离;
根据所述第一EMD距离和所述第二EMD距离获得所述第一蒸馏损失。
在具体的实施过程中,transformer层是Bert模型中的重要组成部分,通过自注意力机制可以捕获长距离依赖关系,一个标准的transformer主要包括两部分:多头注意力机制(Multi-Head Attention,MHA)和全连接前馈神经网络(FFN)。EMD是使用线性规划计算两个分布之间最优距离的方法,可以使知识的蒸馏更加合理。
图3为本申请实施例提供的transformer层结构示意图,如图3所示。训练样本经过embedding层处理后输出,并将embedding层的处理结果输入到transformer层,transformer层中的MHA对输入进行处理获得attention矩阵,然后经过正则化后输入FFN中,由FFN进行处理。
下面分别对MHA和FFN进行介绍。
一、MHA
attention矩阵包括查询矩阵(querys)、键矩阵(keys)、键值矩阵(values),并且attention矩阵如公式(1)所示:
其中,A为注意力矩阵,可以捕获丰富的语言知识,Q为矩阵querys,K为矩阵keys,V为矩阵values,Attention(Q,K,V)为A的一种新的表示,比如在输入attention之前,句子Z的表示是embedding(Z),然后经过公式(1)运算,就可以得到一个新的表示embedding'(Z),新的表示中因为使用了注意力矩阵,所以是包含了上下文信息的一个表示,此处使用Attention(Q,K,V)来表示。
在MHA中,将Q、K、V分为h个部分,分别计算attention矩阵,然后再将h个attention矩阵拼接起来,可以使模型关注不同方面的信息,可以理解的是,h的具体取值可以预先设定。其拼接公式如公式(2)所示:
Multihead(Q,K,V)=concat(A1,A2,...,Ah)W (2)
Ai为第i个attention矩阵,其中,attention矩阵通过公式(1)计算获得,W为随机变量。
二、FFN
FFN包含两个线性变换和一个ReLU激活函数,第二FFN隐层矩阵通过公式(3)计算获得:
FFN(x)=max(0,xW1+b1)W2+b2 (3)
其中,x为输入到FFN中的矩阵,W1和W2为线性变换中的两个变量,b1和b2为偏置;并且W1、W2、b1和b2均为学生模型中的参数,并且这些参数的具体值要经过训练获得。
下面将介绍EMD的原理,EMD本身是一个线性规划问题,指把p位置的m个坑的土,用最小的代价搬到q位置的n个坑中,dij是pi到qj两个坑的距离,fij是从pi搬到qj的土量,w是每个坑中的总土量,则工作量就是要最小化的目标。线性规划求解出fij后,再用fij对工作量作个归一化,就得到了EMD。其中,EMD的计算公式如公式(4)所示:
最小化工作量求解fij的数学形式如公式(5)所示:
为了使老师模型中的知识最大化转移到学生模型中,本申请实施例将上述EMD距离作为学生模型的第二transformer层的第一蒸馏损失。由于transformer层中MHA和FFN两部分都具有重要意义,所以两部分的蒸馏都需要考虑,即第一蒸馏损失通过公式(6)计算获得:
Lt=Lattn+Lffn (6)
其中,Lt为第一蒸馏损失,Lattn为第一EMD距离,Lffn为第二EMD距离。
transformer层中每层的知识可以表示不同的含义,如词法知识、句法知识、语义知识等,所以可以将transformer层的知识总量看作1,并假设老师模型中的第一transformer层的第i层的知识量为老师模型中的第二transformer层的第j层的知识量为/>因此应保证每个模型中所有的transformer层知识量之和为1。因此,可以将老师模型中第一transformer层每一层的知识量初始化为1/m,将第二transformer层每一层的知识量初始化为1/n,在之后的训练中对每层知识量进行更新学习,并且在每次更新完成之后,再用softmax对其进行归一化,保证知识量总和为1。应当说明的是,本申请实施例中的知识量可以相当于上述每个坑中的土量。
在根据上面获得attention矩阵的方法计算得到老师模型对应的第一attention矩阵以及学生模型对应的第二attention矩阵之后,可以计算第一attention矩阵与第二attention矩阵之间的EMD距离。具体计算方法如下:
根据公式(7)计算获得所述第一EMD距离:
其中,Lattn为所述第一EMD距离,AT为所述第二attention矩阵,可以通过公式(2)计算获得,AS为所述第一attention矩阵,同样可以通过公式(2)计算获得,为第一attention矩阵与第二attention矩阵之间的均方误差,且/>为第i层第一transformer层的第一attention矩阵,/>为第j层第二transformer层的第二attention矩阵,fij为从第i层第一transformer层迁移到第j层第二transformer层的知识量,M为第一transformer层的层数,N为第二transformer层的层数大于,M大于N。
同理,老师模型中第一transformer层的第i层对应的第一FFN隐层矩阵与学生模型中第二transformer层的第j层对应的第二FFN隐层矩阵/>之间的第二EMD距离的计算方法如下:
根据公式(8)计算获得第二EMD距离:
其中,Lffn为所述第二EMD距离,HT为第一transformer层的第一FFN隐层矩阵,HS为第二transformer层的第二FFN隐层矩阵,为第一FFN隐层矩阵与第二FFN隐层矩阵的均方误差,且/>为第j层第二transformer层的第二FFN隐层矩阵,Wh为第一预设转换矩阵,/>为第i层第一transformer层的第一FFN隐层矩阵。fij为从第i层第一transformer层迁移到第j层第二transformer层的知识量,M为第一transformer层的层数,N为第二transformer层的层数,M大于N。
应当说明的是,公式(7)和公式(8)中的fij可以根据相应的第一EMD距离或第二EMD通过公式(5)求解获得。
在根据公式(7)和公式(8)分别计算获得第一EMD距离和第二EMD距离后,可以根据公式(6)计算获得第一蒸馏损失。
在上述实施例的基础上,所述计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,包括:
根据公式(9)计算获得所述第二蒸馏损失:
Le=MSE(ESWe,ET) (9)
其中,Le为所述第二蒸馏损失,ES为学生模型的embedding层输出的向量矩阵,ET为老师模型的embedding层输出的向量矩阵,We为第二预设转换矩阵,其作用是将学生模型输出的向量矩阵和老师模型输出的向量矩阵映射定同一维度。
在上述实施例的基础上,所述计算所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失,包括:
根据公式(10)计算获得所述第三蒸馏损失:
Lp=αLph+(1-α)T2Lps (10)
其中,Lp为所述第三蒸馏损失,α为每个损失的权重,T为温度,用于softmax函数中,当其趋向于0时,softmax函数输出将收敛为一个one-hot向量,趋向于无穷时,softmax的输出则更软,即输出的信息量更加丰富;Lph为学生模型的输出与真实标签之间的交叉熵损失,Lps为学生模型的输出与老师模型的输出之间的交叉熵损失,且Lph=-Y·softmax(zS),Lps=-softmax(zT/T)·log_softmax(zS/T),Y为真实标签,即训练样本中的标签即真实目标,zT为老师模型的输出,zS为学生模型的输出。
本申请实施例中,由于每种层的知识表示方式不同,采用不同的学习方式,分别计算第一transformer层中各层与第二transformer层中各层的之间的EMD距离,能够将第一transformer层中的知识尽可能多的转移给学生模型的第二transformer层,从而能够从老师模型中学习到更多的知识。
在上述实施例的基础上,所述根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,包括:
根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失输出获得所述学生模型的目标函数;其中,所述目标函数为Lmodel=∑i∈{e,t,p}λiLi(Si,Ti),e表示embedding层,t表示transformer层,p表示prediction层,λi表示相应的权重,Li表示相应层的蒸馏损失,当i为t时,Li表示第一蒸馏损失;当i为e时,Li表示第二蒸馏损失;当i为p时,Li表示第三蒸馏损失;Si表示学生模型,Ti表示老师模型;
根据所述目标函数对所述学生模型中的参数进行优化,直到获得训练好的学生模型。
本申请实施例通过EMD距离的算法计算第一transformer层与第二transformer层之间的第一蒸馏损失,通过均方误差的方式计算第一embedding层与第二embedding层的第二蒸馏损失,利用软硬目标结合的方式计算第一prediction层与第二prediction层的第三蒸馏损失,最终可以获得整个模型对应的准确的蒸馏损失,从而利用整体模型的蒸馏损失对学生模型进行训练,获得的学生模型能够从老师模型中学到更多的知识。
图4为本申请实施例提供的一种文本检索方法流程示意图,如图4所示,该方法包括:
步骤401:获取查询文本和候选文本;可以理解的是,候选文本可以是一个,也可以有多个,当候选文本为一个时,本申请实施例的目的是通过训练好的学生模型判断查询文本与候选文本是否匹配。当候选文本为多个时,本申请实施例的目的是从多个候选文本中检索出与查询文本相匹配的文本。
步骤402:将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中,获得所述文本检索模型输出的所述查询文本与所述候选文本的匹配率;其中,所述文本检索模型为Bert模型,且所述文本检索模型中的transformer层为基于搬土距离EMD对老师模型中的transformer层进行模型蒸馏后获得。
其中,查询文本可以是一个句子、一个或多个词语,同样的,一个候选文本也可以是一个句子、一段话或者一篇文章。在将查询文本和候选文本输入到文本检索模型中后,文本检索模型可以获得每一个候选文本与查询文本之间的匹配率。
应当说明的是,本申请实施例中的文本检索模型可以通过上述实施例中的模型蒸馏方法对老师模型进行蒸馏获得。因此,文本检索模型的获得方式参见上述实施例,本申请实施例对此不再赘述。
步骤403:确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本。
其中,预设要求可以为匹配率大于预设值,若候选文本只有一个,则判断该候选文本与查询文本的匹配率是否大于预设值,若大于,则说明该候选文本就是与查询文本匹配的目标候选文本。若候选文本有多个,则判断每个候选文本分别与查询文本的匹配率是否大于预设值,将匹配率大于预设值的候选文本作为目标候选文本。
也可以是匹配率最大且匹配率大于预设值,若候选文本只有一个,那么认为该候选文本与查询文本的匹配率为最大匹配率,若该最大匹配率大于预设值,则说明该候选文本为目标候选文本,若该最大匹配率不大于预设值,则说明该候选文本不是目标候选文本。若候选文本中有多个,那么从多个候选文本分别与查询文本之间的匹配率中选择匹配率最大的那个,并且判断最大匹配率是否大于预设值,若大于,则将最大匹配率对应的候选文本作为目标候选文本。可以理解的是,若最大的匹配率有两个及以上,且最大匹配率大于预设值,那么最大的匹配率对应的候选文本均为目标候选文本。
本申请实施例通过利用EMD算法对老师模型中的transformer层进行模型蒸馏,能够使得学生模型学习到更多老师模型中的知识,从而获得的文本检索模型能够准确地从候选文本中获得与查询文件相匹配的目标文本。
应当说明的是,通过上述实施例提供的模型蒸馏方法获得的学生模型除了能够应用在文本检索这一场景外,还可以用于机器翻译、文本分类、文本纠错等。不同的应用场景,其选择的训练样本不同。
图5为本申请实施例提供的模型蒸馏装置结构示意图,该装置可以是电子设备上的模块、程序段或代码。应理解,该装置与上述图1方法实施例对应,能够执行图1方法实施例涉及的各个步骤,该装置具体的功能可以参见上文中的描述,为避免重复,此处适当省略详细描述。该装置包括:初始模型获取模块501、样本获取模块502、损失计算模块503和优化模块504,其中:
初始模型获取模块501用于获取老师模型和学生模型,其中,所述老师模型包括第一embedding层、第一transformer层和第一prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数;样本获取模块502用于获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;其中,所述老师模型为经过预先训练获得的;损失计算模块503用于利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;优化模块504用于根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得训练好的学生模型。
在上述实施例的基础上,损失计算模块503具体用于:
获取所述第一transformer层中各层分别输出的第一attention矩阵以及所述第二transformer层中各层分别输出的第二attention矩阵;
根据所述第一attention矩阵和所述第二attention矩阵计算第一EMD距离;
获取所述第一transformer层中各层分别输出的第一全连接前馈神经网络FFN隐层矩阵和所述第二transformer层中各层分别输出的第二FFN隐层矩阵;
根据所述第一FFN隐层矩阵和所述第二FFN隐层矩阵计算第二EMD距离;
根据所述第一EMD距离和所述第二EMD距离获得所述第一蒸馏损失。
在上述实施例的基础上,损失计算模块503具体用于:
根据计算获得所述第一EMD距离;
其中,Lattn为所述第一EMD距离,AT为所述第一attention矩阵,AS为所述第二attention矩阵,为第一attention矩阵与第二attention矩阵之间的均方误差,且为第i层第一transformer层的第一attention矩阵,/>为第j层第二transformer层的第二attention矩阵,fij为从第i层第一transformer层迁移到第j层第二transformer层的知识量,M为第一transformer层的层数,N为第二transformer层的层数。
在上述实施例的基础上,损失计算模块503具体用于:
根据计算获得第二EMD距离;
其中,Lffn为所述第二EMD距离,HT为第一transformer层的第一FFN隐层矩阵,HS为第二transformer层的第二FFN隐层矩阵,为第一FFN隐层矩阵与第二FFN隐层矩阵的均方误差,且/>为第j层第二transformer层的第二FFN隐层矩阵,Wh为第一预设转换矩阵,/>为第i层第一transformer层的第一FFN隐层矩阵。fij为从第i层第一transformer层迁移到第j层第二transformer层的知识量,M为第一transformer层的层数,N为第二transformer层的层数。
在上述实施例的基础上,损失计算模块503具体用于:
根据Le=MSE(ESWe,ET)计算获得所述第二蒸馏损失;
其中,Le为所述第二蒸馏损失,ES为学生模型的embedding层输出的向量矩阵,We为第二预设转换矩阵,ET为老师模型的embedding层输出的向量矩阵。
在上述实施例的基础上,损失计算模块503具体用于:
根据Lp=αLph+(1-α)T2Lps计算获得所述第三蒸馏损失;
其中,Lp为所述第三蒸馏损失,α为每个损失的权重,T为温度,Lph为学生模型的输出与真实标签之间的交叉熵损失,Lps为学生模型的输出与老师模型的输出之间的交叉熵损失,且Lph=-Y·softmax(zS),Lps=-softmax(zT/T)·log_softmax(zS/T),Y为真实标签,zT为老师模型的输出,zS为学生模型的输出。
在上述实施例的基础上,优化模块504具体用于:
根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失输出获得所述学生模型的目标函数;其中,所述目标函数为Lmodel=∑i∈{e,t,p}λiLi(Si,Ti),e表示embedding层,t表示transformer层,p表示prediction层,λi表示相应的权重,Li表示相应层的蒸馏损失,Si表示学生模型,Ti表示老师模型;
根据所述目标函数对所述学生模型中的参数进行优化。
综上所述,本申请实施例中,由于Bert模型中的transformer层对模型的贡献最大,包含的信息最丰富,学生模型在该层的学习能力也最为重要,因此,基于EMD对老师模型进行蒸馏,能够保证学生模型学习到更多的老师模型的知识。
图6为本申请实施例提供的文本检索装置结构示意图,该装置可以是电子设备上的模块、程序段或代码。应理解,该装置与上述图4方法实施例对应,能够执行图4方法实施例涉及的各个步骤,该装置具体的功能可以参见上文中的描述,为避免重复,此处适当省略详细描述。该装置包括:文本获取模块601、检索模块602和目标确定模块603,其中:
文本获取模块601用于获取查询文本和候选文本;检索模块602用于将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中,获得所述文本检索模型输出的所述查询文本与所述候选文本的匹配率;其中,所述文本检索模型为Bert模型,且所述文本检索模型中的transformer层为基于搬土距离EMD对老师模型中的transformer层进行模型蒸馏后获得;目标确定模块603用于确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本。
在上述实施例的基础上,该装置还包括模型训练模块,用于:
获取老师模型和学生模型,其中,所述老师模型包括第一embedding层、第一transformer层和第一prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数;
获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;
利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;
根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得所述文本检索模型。
在上述实施例的基础上,目标确定模块603具体用于:
若最大匹配率大于预设值,则将匹配率最大的候选文本作为与所述查询文本相匹配的目标候选文本;或
将匹配率大于预设值的候选文本作为与所述查询文本相匹配的目标候选文本。
综上所述,本申请实施例通过利用EMD算法对老师模型中的transformer层进行模型蒸馏,能够使得学生模型学习到更多老师模型中的知识,从而获得的文本检索模型能够准确地从候选文本中获得与查询文件相匹配的目标文本。
图7为本申请实施例提供的电子设备实体结构示意图,如图7所示,所述电子设备,包括:处理器(processor)701、存储器(memory)702和总线703;其中,
所述处理器701和存储器702通过所述总线703完成相互间的通信;
所述处理器701用于调用所述存储器702中的程序指令,以执行上述各方法实施例所提供的方法,例如包括:获取老师模型和学生模型,其中,所述老师模型包括第一embedding层、第一transformer层和第一prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数;获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得训练好的学生模型。或
获取查询文本和候选文本;将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中,获得所述文本检索模型输出的所述查询文本与所述候选文本的匹配率;其中,所述文本检索模型为Bert模型,且所述文本检索模型中的transformer层为基于搬土距离EMD对老师模型中的transformer层进行模型蒸馏后获得;确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本。
处理器701可以是一种集成电路芯片,具有信号处理能力。上述处理器701可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(NetworkProcessor,NP)等;还可以是数字信号处理器(DSP)、专用集成电路(ASIC)、现成可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。其可以实现或者执行本申请实施例中公开的各种方法、步骤及逻辑框图。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器702可以包括但不限于随机存取存储器(Random Access Memory,RAM),只读存储器(Read Only Memory,ROM),可编程只读存储器(Programmable Read-OnlyMemory,PROM),可擦除只读存储器(Erasable Programmable Read-Only Memory,EPROM),电可擦除只读存储器(Electrically Erasable Programmable Read-Only Memory,EEPROM)等。
本实施例公开一种计算机程序产品,所述计算机程序产品包括存储在非暂态计算机可读存储介质上的计算机程序,所述计算机程序包括程序指令,当所述程序指令被计算机执行时,计算机能够执行上述各方法实施例所提供的方法,例如包括:获取老师模型和学生模型,其中,所述老师模型包括第一embedding层、第一transformer层和第一prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数;获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得训练好的学生模型。或
获取查询文本和候选文本;将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中,获得所述文本检索模型输出的所述查询文本与所述候选文本的匹配率;其中,所述文本检索模型为Bert模型,且所述文本检索模型中的transformer层为基于搬土距离EMD对老师模型中的transformer层进行模型蒸馏后获得;确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本。
本实施例提供一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令使所述计算机执行上述各方法实施例所提供的方法,例如包括:获取老师模型和学生模型,其中,所述老师模型包括第一embedding层、第一transformer层和第一prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;所述第一transformer层的层数大于所述第二transformer层的层数;获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得训练好的学生模型。或
获取查询文本和候选文本;将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中,获得所述文本检索模型输出的所述查询文本与所述候选文本的匹配率;其中,所述文本检索模型为Bert模型,且所述文本检索模型中的transformer层为基于搬土距离EMD对老师模型中的transformer层进行模型蒸馏后获得;确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本。
在本申请所提供的实施例中,应该理解到,所揭露装置和方法,可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,又例如,多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些通信接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
另外,作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
再者,在本申请各个实施例中的各功能模块可以集成在一起形成一个独立的部分,也可以是各个模块单独存在,也可以两个或两个以上模块集成形成一个独立的部分。
在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。
以上所述仅为本申请的实施例而已,并不用于限制本申请的保护范围,对于本领域的技术人员来说,本申请可以有各种更改和变化。凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (10)
1.一种模型蒸馏方法,其特征在于,包括:
获取老师模型和学生模型,其中,所述老师模型包括第一向量embedding层、第一转换器transformer层和第一预测prediction层,所述学生模型包括第二embedding层、第二transformer层和第二prediction层;第一transformer层的层数大于第二transformer层的层数;
获取训练样本,并将所述训练样本分别输入所述老师模型和所述学生模型中;其中,所述老师模型为经过预先训练获得的;所述训练样本包括查询文本以及至少一个候选文本,所述至少一个候选文本的标签表示所述至少一个候选文本是否与所述查询文本匹配;
利用搬土距离EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,并分别计算第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;
根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,获得训练好的学生模型,所述训练好的学生模型用于文本检索。
2.根据权利要求1所述的方法,其特征在于,所述利用EMD计算所述第一transformer层的输出与第二transformer层的输出之间的第一蒸馏损失,包括:
获取所述第一transformer层中各层分别输出的第一注意力attention矩阵以及所述第二transformer层中各层分别输出的第二attention矩阵;
根据第一attention矩阵和所述第二attention矩阵计算第一EMD距离;
获取所述第一transformer层中各层分别输出的第一全连接前馈神经网络FFN隐层矩阵和所述第二transformer层中各层分别输出的第二FFN隐层矩阵;
根据第一FFN隐层矩阵和所述第二FFN隐层矩阵计算第二EMD距离;
根据所述第一EMD距离和所述第二EMD距离获得所述第一蒸馏损失;
所述第一蒸馏损失的计算方法包括:
Lt=Lattn+Lffn
其中,Lt为第一蒸馏损失,Lattn为第一EMD距离,Lffn为第二EMD距离。
3.根据权利要求2所述的方法,其特征在于,所述根据第一attention矩阵和所述第二attention矩阵计算第一EMD距离,包括:
根据计算获得所述第一EMD距离;
其中,Lattn为所述第一EMD距离,AT为所述第一attention矩阵,AS为所述第二attention矩阵,为第一attention矩阵与第二attention矩阵之间的均方误差,且 为第i层第一transformer层的第一attention矩阵,/>为第j层第二transformer层的第二attention矩阵,fij为从第i层第一transformer层迁移到第j层第二transformer层的知识量,M为第一transformer层的层数,N为第二transformer层的层数。
4.根据权利要求2所述的方法,其特征在于,所述根据所述第一FFN隐层矩阵和所述第二FFN隐层矩阵计算第二EMD距离,包括:
根据计算获得第二EMD距离;
其中,Lffn为所述第二EMD距离,HT为第一transformer层的第一FFN隐层矩阵,HS为第二transformer层的第二FFN隐层矩阵,为第一FFN隐层矩阵与第二FFN隐层矩阵的均方误差,且/>为第j层第二transformer层的第二FFN隐层矩阵,Wh为第一预设转换矩阵,/>为第i层第一transformer层的第一FFN隐层矩阵;fij为从第i层第一transformer层迁移到第j层第二transformer层的知识量,M为第一transformer层的层数,N为第二transformer层的层数。
5.根据权利要求1所述的方法,其特征在于,所述计算所述第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,包括:
根据Le=MSE(ESWe,ET)计算获得所述第二蒸馏损失;
其中,Le为所述第二蒸馏损失,ES为学生模型的embedding层输出的向量矩阵,We为第二预设转换矩阵,ET为老师模型的embedding层输出的向量矩阵。
6.根据权利要求1所述的方法,其特征在于,所述计算所述第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失,包括:
根据Lp=αLph+(1-α)Τ2Lps计算获得所述第三蒸馏损失;
其中,Lp为所述第三蒸馏损失,α为每个损失的权重,Τ为温度,Lph为学生模型的输出与真实标签之间的交叉熵损失,Lps为学生模型的输出与老师模型的输出之间的交叉熵损失,且Lph=-Y·softmax(zS),Lps=-softmax(zT/T)·log_softmax(zS/T),Y为真实标签,zT为老师模型的输出,zS为学生模型的输出。
7.根据权利要求1-6任一项所述的方法,其特征在于,所述根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行优化,包括:
根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失输出获得所述学生模型的目标函数;其中,所述目标函数为Lmodel=∑i∈{e,t,p}λiLi(Si,Ti),e表示embedding层,t表示transformer层,p表示prediction层,λi表示相应的权重,Li表示相应层的蒸馏损失,Si表示学生模型,Ti表示老师模型;
根据所述目标函数对所述学生模型中的参数进行优化。
8.一种文本检索方法,其特征在于,包括:
获取查询文本和候选文本;
将所述查询文本和所述候选文本输入预先训练获得的文本检索模型中,获得所述文本检索模型输出的所述查询文本与所述候选文本的匹配率;其中,所述文本检索模型为Bert模型,且所述文本检索模型中的转换器transformer层为采用如权利要求1~7中任一项所述模型蒸馏方法中基于搬土距离EMD对老师模型中的transformer层进行模型蒸馏后获得;
确定满足预设要求的匹配率对应的候选文本为与所述查询文本相匹配的目标候选文本。
9.一种电子设备,其特征在于,包括:处理器、存储器和总线,其中,
所述处理器和所述存储器通过所述总线完成相互间的通信;
所述存储器存储有可被所述处理器执行的程序指令,所述处理器调用所述程序指令能够执行如权利要求1-8任一项所述的方法。
10.一种非暂态计算机可读存储介质,其特征在于,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令被计算机运行时,使所述计算机执行如权利要求1-8任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010405217.6A CN111553479B (zh) | 2020-05-13 | 2020-05-13 | 一种模型蒸馏方法、文本检索方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010405217.6A CN111553479B (zh) | 2020-05-13 | 2020-05-13 | 一种模型蒸馏方法、文本检索方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111553479A CN111553479A (zh) | 2020-08-18 |
CN111553479B true CN111553479B (zh) | 2023-11-03 |
Family
ID=72008143
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010405217.6A Active CN111553479B (zh) | 2020-05-13 | 2020-05-13 | 一种模型蒸馏方法、文本检索方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111553479B (zh) |
Families Citing this family (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111967941B (zh) * | 2020-08-20 | 2024-01-05 | 中国科学院深圳先进技术研究院 | 一种构建序列推荐模型的方法和序列推荐方法 |
CN111898707B (zh) * | 2020-08-24 | 2024-06-21 | 鼎富智能科技有限公司 | 文本分类方法、电子设备及存储介质 |
CN112101032B (zh) * | 2020-08-31 | 2024-09-24 | 广州探迹科技有限公司 | 一种基于自蒸馏的命名实体识别与纠错方法 |
CN112101484B (zh) * | 2020-11-10 | 2021-02-12 | 中国科学院自动化研究所 | 基于知识巩固的增量事件识别方法、系统、装置 |
CN112507209B (zh) * | 2020-11-10 | 2022-07-05 | 中国科学院深圳先进技术研究院 | 一种基于陆地移动距离进行知识蒸馏的序列推荐方法 |
CN112101573B (zh) * | 2020-11-16 | 2021-04-30 | 智者四海(北京)技术有限公司 | 一种模型蒸馏学习方法、文本查询方法及装置 |
CN112464760A (zh) * | 2020-11-16 | 2021-03-09 | 北京明略软件系统有限公司 | 一种目标识别模型的训练方法和装置 |
CN112418291B (zh) * | 2020-11-17 | 2024-07-26 | 平安科技(深圳)有限公司 | 一种应用于bert模型的蒸馏方法、装置、设备及存储介质 |
CN112733550B (zh) * | 2020-12-31 | 2023-07-25 | 科大讯飞股份有限公司 | 基于知识蒸馏的语言模型训练方法、文本分类方法及装置 |
CN113449610A (zh) * | 2021-06-08 | 2021-09-28 | 杭州格像科技有限公司 | 一种基于知识蒸馏和注意力机制的手势识别方法和系统 |
CN113850362A (zh) * | 2021-08-20 | 2021-12-28 | 华为技术有限公司 | 一种模型蒸馏方法及相关设备 |
CN115329063B (zh) * | 2022-10-18 | 2023-01-24 | 江西电信信息产业有限公司 | 一种用户的意图识别方法及系统 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018169708A1 (en) * | 2017-03-17 | 2018-09-20 | Nec Laboratories America, Inc. | Learning efficient object detection models with knowledge distillation |
CN109637546A (zh) * | 2018-12-29 | 2019-04-16 | 苏州思必驰信息科技有限公司 | 知识蒸馏方法和装置 |
CN109711544A (zh) * | 2018-12-04 | 2019-05-03 | 北京市商汤科技开发有限公司 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
CN110188358A (zh) * | 2019-05-31 | 2019-08-30 | 北京神州泰岳软件股份有限公司 | 自然语言处理模型的训练方法及装置 |
CN110287494A (zh) * | 2019-07-01 | 2019-09-27 | 济南浪潮高新科技投资发展有限公司 | 一种基于深度学习bert算法的短文本相似匹配的方法 |
CN111062489A (zh) * | 2019-12-11 | 2020-04-24 | 北京知道智慧信息技术有限公司 | 一种基于知识蒸馏的多语言模型压缩方法、装置 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20090119095A1 (en) * | 2007-11-05 | 2009-05-07 | Enhanced Medical Decisions. Inc. | Machine Learning Systems and Methods for Improved Natural Language Processing |
-
2020
- 2020-05-13 CN CN202010405217.6A patent/CN111553479B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018169708A1 (en) * | 2017-03-17 | 2018-09-20 | Nec Laboratories America, Inc. | Learning efficient object detection models with knowledge distillation |
CN109711544A (zh) * | 2018-12-04 | 2019-05-03 | 北京市商汤科技开发有限公司 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
CN109637546A (zh) * | 2018-12-29 | 2019-04-16 | 苏州思必驰信息科技有限公司 | 知识蒸馏方法和装置 |
CN110188358A (zh) * | 2019-05-31 | 2019-08-30 | 北京神州泰岳软件股份有限公司 | 自然语言处理模型的训练方法及装置 |
CN110287494A (zh) * | 2019-07-01 | 2019-09-27 | 济南浪潮高新科技投资发展有限公司 | 一种基于深度学习bert算法的短文本相似匹配的方法 |
CN111062489A (zh) * | 2019-12-11 | 2020-04-24 | 北京知道智慧信息技术有限公司 | 一种基于知识蒸馏的多语言模型压缩方法、装置 |
Non-Patent Citations (3)
Title |
---|
Ofir Press.et al."Improving Transformer Models by Reordering their Sublayers".《Arxiv》.2020,第1-10页. * |
Xiaoqi Jiao.et al."TINYBERT: DISTILLING BERT FOR NATURAL LANGUAGE UNDERSTANDING".《Arxiv》.2019,第1-13页. * |
岳一峰等."一种基于BERT的自动文本摘要模型构建方法".计算机与现代化.2020,(第01期),全文. * |
Also Published As
Publication number | Publication date |
---|---|
CN111553479A (zh) | 2020-08-18 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111553479B (zh) | 一种模型蒸馏方法、文本检索方法及装置 | |
US11113479B2 (en) | Utilizing a gated self-attention memory network model for predicting a candidate answer match to a query | |
CN109033068B (zh) | 基于注意力机制的用于阅读理解的方法、装置和电子设备 | |
CN108647233B (zh) | 一种用于问答系统的答案排序方法 | |
CN109376222B (zh) | 问答匹配度计算方法、问答自动匹配方法及装置 | |
CN111554268A (zh) | 基于语言模型的语言识别方法、文本分类方法和装置 | |
CN107590127B (zh) | 一种题库知识点自动标注方法及系统 | |
CN118349673A (zh) | 文本处理模型的训练方法、文本处理方法及装置 | |
WO2023197613A1 (zh) | 一种小样本微调方法、系统及相关装置 | |
CN106649514A (zh) | 用于受人启发的简单问答(hisqa)的系统和方法 | |
CN112287089B (zh) | 用于自动问答系统的分类模型训练、自动问答方法及装置 | |
CN112015868A (zh) | 基于知识图谱补全的问答方法 | |
WO2023137911A1 (zh) | 基于小样本语料的意图分类方法、装置及计算机设备 | |
CN111078847A (zh) | 电力用户意图识别方法、装置、计算机设备和存储介质 | |
CN113626589A (zh) | 一种基于混合注意力机制的多标签文本分类方法 | |
CN113204633B (zh) | 一种语义匹配蒸馏方法及装置 | |
CN112906398B (zh) | 句子语义匹配方法、系统、存储介质和电子设备 | |
CN115827819A (zh) | 一种智能问答处理方法、装置、电子设备及存储介质 | |
CN114282528A (zh) | 一种关键词提取方法、装置、设备及存储介质 | |
WO2023134085A1 (zh) | 问题答案的预测方法、预测装置、电子设备、存储介质 | |
CN116136870A (zh) | 基于增强实体表示的智能社交对话方法、对话系统 | |
CN116341651A (zh) | 实体识别模型训练方法、装置、电子设备及存储介质 | |
CN113204679B (zh) | 一种代码查询模型的生成方法和计算机设备 | |
CN114398893A (zh) | 一种基于对比学习的临床数据处理模型的训练方法及装置 | |
CN117668157A (zh) | 基于知识图谱的检索增强方法、装置、设备及介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |