CN113673698A - 适用于bert模型的蒸馏方法、装置、设备及存储介质 - Google Patents
适用于bert模型的蒸馏方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN113673698A CN113673698A CN202110975567.0A CN202110975567A CN113673698A CN 113673698 A CN113673698 A CN 113673698A CN 202110975567 A CN202110975567 A CN 202110975567A CN 113673698 A CN113673698 A CN 113673698A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- integrated
- prediction
- distillation
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000004821 distillation Methods 0.000 title claims abstract description 125
- 238000000034 method Methods 0.000 title claims abstract description 61
- 238000012549 training Methods 0.000 claims abstract description 208
- 238000004364 calculation method Methods 0.000 claims abstract description 53
- 230000010354 integration Effects 0.000 claims description 114
- 238000010276 construction Methods 0.000 claims description 18
- 238000012795 verification Methods 0.000 claims description 15
- 238000002372 labelling Methods 0.000 claims description 14
- 238000004590 computer program Methods 0.000 claims description 11
- 238000010200 validation analysis Methods 0.000 claims description 5
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 238000009792 diffusion process Methods 0.000 abstract description 2
- 238000005516 engineering process Methods 0.000 abstract description 2
- 230000008569 process Effects 0.000 description 9
- ZDXPYRJPNDTMRX-VKHMYHEASA-N L-glutamine Chemical compound OC(=O)[C@@H](N)CCC(N)=O ZDXPYRJPNDTMRX-VKHMYHEASA-N 0.000 description 5
- 238000010586 diagram Methods 0.000 description 3
- 239000006185 dispersion Substances 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 230000001360 synchronised effect Effects 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000002457 bidirectional effect Effects 0.000 description 1
- 230000006870 function Effects 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 238000000844 transformation Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
- G06F16/353—Clustering; Classification into predefined classes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Biophysics (AREA)
- Evolutionary Computation (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Databases & Information Systems (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请涉及人工智能技术领域,揭示了一种适用于BERT模型的蒸馏方法、装置、设备及存储介质,其中方法包括:获取预训练BERT模型和门控线性单元;根据预训练BERT模型和门控线性单元进行模型构建,得到学生模型;获取第一训练集和教师模型,其中,教师模型是基于预训练BERT模型训练得到的集成模型;根据第一训练集和教师模型,对学生模型进行蒸馏训练,得到目标模型。通过门控线性单元减少了BERT模型的维度数,有效地降低了BERT模型的梯度弥散,而且还保留了BERT模型的非线性的能力,从而提高了学生模型的计算速度和准确度,提高了目标模型的实时性。本申请可适用于智慧政务、数字医疗、科技金融等领域。
Description
技术领域
本申请涉及到人工智能技术领域,特别是涉及到一种适用于BERT模型的蒸馏方法、装置、设备及存储介质。
背景技术
BERT(Bidirectional Encoder Representation from Transformers)模型是自然语言处理(NLP)领域的热门模型。因BERT模型中存在大量的参数,导致模型的训练效率较低,训练后的BERT模型应用到生成环境时的实时性较差。通常采用蒸馏技术解决该问题,也就是通过对BERT模型训练好后,采用一个参数较少的学生模型模仿训练好后的BERT模型,然后将完成模仿的学生模型运用到生成环境中,以用于提高实时性。学生模型采用简化结构的BERT模型或非Bert模型的网络结构,简化结构的BERT模型导致学生模型的准确度有限,非Bert模型的网络结构导致学生模型的计算速度和准确度有限。
发明内容
本申请的主要目的为提供一种适用于BERT模型的蒸馏方法、装置、设备及存储介质,旨在解决采用参数较少的学生模型模仿训练好后的BERT模型时,简化结构的BERT模型导致学生模型的准确度有限,非Bert模型的网络结构导致学生模型的计算速度和准确度有限的技术问题。
为了实现上述发明目的,本申请提出一种适用于BERT模型的蒸馏方法,所述方法包括:
获取预训练BERT模型和门控线性单元;
根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型;
获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型;
根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型。
进一步的,所述根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型,包括:
根据所述门控线性单元,对所述预训练BERT模型中的每个隐藏层对应的TransFormer Block单元分别进行替换处理,将替换处理后的所述预训练BERT模型作为所述学生模型,其中,所述学生模型的所述隐藏层与所述预训练BERT模型的所述隐藏层的层数相同。
进一步的,所述获取第一训练集和教师模型之前,还包括:
获取第二训练集、验证集和超参数数据列表;
根据所述第二训练集和所述超参数数据列表中的每个超参数数据,分别对所述预训练BERT模型进行训练,将训练结束的所述预训练BERT模型作为待集成模型;
根据所述验证集,从各个所述待集成模型中获取初始集成模型,所述初始集成模型为各个所述待集成模型中预测准确率最高的模型;
基于贪婪集成方法,根据所述初始集成模型,对各个所述待集成模型进行集成,得到所述教师模型。
进一步的,所述根据所述验证集,从各个所述待集成模型中获取初始集成模型,包括:
分别将所述验证集中的每个第三训练样本输入每个所述待集成模型中进行分类概率预测,得到每个所述第三训练样本在每个所述待集成模型中的第一预测概率值;
根据每个所述第一预测概率值进行标注,得到每个所述第三训练样本在每个所述待集成模型中的目标预测标注值;
根据各个所述目标预测标注值和各个所述第三训练样本各自对应的第三原始标定值,对每个所述待集成模型进行预测准确率计算,得到每个所述待集成模型的待分析预测准确率;
将各个所述待分析预测准确率中的最大所述待分析预测准确率对应的所述待集成模型作为所述初始集成模型。
进一步的,所述基于贪婪集成方法,根据所述初始集成模型,对各个所述待集成模型进行集成,得到所述教师模型,包括:
从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型;
将所述初始集成模型和所述待分析模型进行组合,得到待评估集成模型;
获取所述待评估集成模型对应的各个所述待集成模型各自对应的所述第一预测概率值,得到待分析的预测概率值集;
对所述待分析的预测概率值集中的各个所述第一预测概率值进行加权计算,得到所述待评估集成模型对应的预测概率综合值;
根据每个所述预测概率综合值进行标注,得到每个所述第三训练样本在所述待评估集成模型中的综合预测标注值;
根据各个所述综合预测标注值和各个所述第三训练样本的所述第三原始标定值,对所述待评估集成模型进行预测准确率计算,得到所述待评估集成模型的目标预测准确率;
重复执行所述从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型的步骤,直至完成所述初始集成模型以外的所述待集成模型的获取;
将所述目标预测准确率为最大的所述待评估集成模型作为待处理集成模型;
获取所述初始集成模型的所述目标预测准确率;
当所述初始集成模型的所述目标预测准确率小于所述待处理集成模型的所述目标预测准确率时,将所述待处理集成模型作为所述初始集成模型,重复执行所述从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型的步骤,直至所述初始集成模型的所述目标预测准确率大于或等于所述待处理集成模型的所述目标预测准确率;
将所述初始集成模型作为所述教师模型。
进一步的,所述根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型,包括:
从所述第一训练集中获取第一训练样本作为目标样本;
根据所述目标样本和所述教师模型,对所述学生模型进行蒸馏损失计算,得到所述目标样本的蒸馏损失值;
将所述目标样本输入所述学生模型进行交叉熵损失计算,得到所述目标样本的交叉熵损失值;
根据所述目标样本的所述蒸馏损失值和所述交叉熵损失值,对所述学生模型的网络参数进行更新,更新后的所述学生模型用于下一次进行蒸馏训练;
重复执行所述从所述第一训练集中获取第一训练样本作为目标样本的步骤,直至达到蒸馏训练结束条件,将达到所述蒸馏训练结束条件的所述学生模型作为所述目标模型。
进一步的,所述根据所述目标样本和所述教师模型,对所述学生模型进行蒸馏损失计算,得到所述目标样本的蒸馏损失值,包括:
将所述目标样本输入所述教师模型进行分类概率预测;
获取所述教师模型中的各个模型的各个隐藏层输出的数据,得到多个待加权特征数据;
根据各个所述待加权特征数据,针对所述教师模型的每类所述隐藏层进行加权求和,得到所述教师模型的每类所述隐藏层对应的第一特征数据;
获取所述教师模型输出的分类概率预测的数据,得到第二预测概率值;
将所述目标样本输入所述学生模型进行分类概率预测;
获取所述学生模型的各个所述隐藏层输出的数据,得到多个第二特征数据;
获取所述学生模型输出的分类概率预测的数据,得到第三预测概率值;
根据各个所述第一特征数据和各个所述第二特征数据,针对所述学生模型进行每个所述隐藏层的蒸馏损失计算,得到多个层蒸馏损失;
根据所述第二预测概率值和所述第三预测概率值,对所述学生模型进行蒸馏损失计算,得到综合蒸馏损失;
根据各个所述层蒸馏损失和所述综合蒸馏损失进行求和计算,得到所述目标样本的所述蒸馏损失值。
本申请还提出了一种适用于BERT模型的蒸馏装置,所述装置包括:
第一数据获取模块,用于获取预训练BERT模型和门控线性单元;
学生模型确定模块,用于根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型;
第二数据获取模块,用于获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型;
目标模型确定模块,用于根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型。
本申请还提出了一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现上述任一项所述方法的步骤。
本申请还提出了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任一项所述的方法的步骤。
本申请的适用于BERT模型的蒸馏方法、装置、设备及存储介质,其中方法通过首先根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型,然后获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型,最后根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型,因学生模型是根据所述预训练BERT模型和所述门控线性单元进行模型构建得到的,通过门控线性单元减少了BERT模型的维度数,有效地降低了BERT模型的梯度弥散,而且还保留了BERT模型的非线性的能力,从而提高了学生模型的计算速度和准确度,提高了目标模型的实时性。
附图说明
图1为本申请一实施例的适用于BERT模型的蒸馏方法的流程示意图;
图2为本申请一实施例的适用于BERT模型的蒸馏装置的结构示意框图;
图3为本申请一实施例的计算机设备的结构示意框图。
本申请目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
参照图1,本申请实施例中提供一种适用于BERT模型的蒸馏方法,所述方法包括:
S1:获取预训练BERT模型和门控线性单元;
S2:根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型;
S3:获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型;
S4:根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型。
本实施例通过首先根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型,然后获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型,最后根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型,因学生模型是根据所述预训练BERT模型和所述门控线性单元进行模型构建得到的,通过门控线性单元减少了BERT模型的维度数,有效地降低了BERT模型的梯度弥散,而且还保留了BERT模型的非线性的能力,从而提高了学生模型的计算速度和准确度,提高了目标模型的实时性。
对于S1,可以获取用户输入的预训练BERT模型和门控线性单元,也可以从数据库中获取预训练BERT模型和门控线性单元,还可以从第三方应用系统中获取预训练BERT模型和门控线性单元。
预训练BERT模型,是谷歌预训练好的BERT-BASE模型。预训练BERT模型包括12个隐藏层,每个隐藏层包括一个TransFormer Block单元。TransFormer Block单元依次包括:Attention子层和FNN子层,其中,Attention子层的序列的序列长度是512,Attention子层的维度数为768,预训练BERT模型的FNN子层的序列长度是512,预训练BERT模型的FNN子层的维度数为768。
TransFormer Block,是一种基于Encoder-Decoder(编码-解码)结构的模型。
Attention子层,是采用多头注意力机制的层。
FNN子层,是基于前馈神经网络得到的子层。
门控线性单元,也就是GLU block。门控线性单元依次包括GLUm子层和FNN子层,其中GLUm子层的序列长度是512,GLUm子层的维度数为128,门控线性单元的FNN子层的序列长度是512,门控线性单元的FNN子层的维度数为128。
GLUm子层,用于实现线性门控的层。
对于S2,采用所述门控线性单元对所述预训练BERT模型的隐藏层的TransFormerBlock单元进行替换,将完成替换后的所述预训练BERT模型作为学生模型,从而使学生模型的隐藏层从768维减少到128维度,减少了学生模型的参数量。而门控线性单元的GLUm子层相对TransFormer Block单元的Attention子层,参数量明显减少,从而使学生模型的计算量相对所述预训练BERT模型的计算量明显减少。
对于S3,可以获取用户输入的第一训练集和教师模型,也可以从数据库中获取第一训练集和教师模型,还可以从第三方应用系统中获取第一训练集和教师模型。
其中,通过修改所述预训练BERT模型的超参数,采用第二训练集对所述预训练BERT模型训练出多个待集成模型,然后采用训练集计算每个待集成模型的预测概率值,基于贪婪集成方法和验证集,根据各个待集成模型的预测概率值,对各个待集成模型进行模型集成,将模型集成得到的模型作为所述教师模型。
第一训练集中包括多个第一训练样本,每个第一训练样本包括:第一文本样本数据和第一原始标定值。第一文本样本数据,也就是文本数据。第一原始标定值是一个向量,该向量的每个向量元素的取值有两个,可以是0,也可以是1。
第二训练集中包括多个第二训练样本,每个第二训练样本包括:第二文本样本数据和第二原始标定值。第二文本样本数据,也就是文本数据。第二原始标定值是一个向量,该向量的每个向量元素的取值有两个,可以是0,也可以是1。
验证集中包括多个第三训练样本,每个第三训练样本包括:第三文本样本数据和第三原始标定值。第三文本样本数据,也就是文本数据。第三原始标定值是一个向量,该向量的每个向量元素的取值有两个,可以是0,也可以是1。
对于S4,根据所述第一训练集,对所述教师模型和所述学生模型进行蒸馏,得到所述学生模型的蒸馏损失,将所述第一训练集输入所述学生模型进行进行交叉熵损失计算,得到所述学生模型的交叉熵损失,最后根据所述学生模型的蒸馏损失和所述学生模型的交叉熵损失进行所述学生模型的最终损失计算,根据计算得到的最终损失更新所述学生模型,将蒸馏训练结束的学生模型作为目标模型。
也就是说,目标模型模仿了所述教师模型的分类概率预测的功能和性能,但是目标模型的参数相对所述教师模型和预训练BERT模型少,从而提高了目标模型的计算速度和准确度,提高了目标模型的实时性。
可选的,所述根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型之后,还包括:获取待预测的文本数据;将所述待预测的文本数据输入所述目标模型进行分类概率预测,得到待预测的文本数据对应的分类概率预测结果;获取预设概率阈值;根据预设概率阈值和分类概率预测结果确定待预测的文本数据的目标分类标签。
其中,将分类概率预测结果中大于预设概率阈值的向量元素对应的分类标签作为待预测的文本数据的目标分类标签。可以理解的是,目标分类标签的数量小于或等于分类概率预测结果的向量元素的数量。
可以理解的是,分类概率预测结果的向量元素的数量可以是1个,也可以是多个。分类概率预测结果的向量元素的取值范围为0-1,可以包括0,也可以包括1。
预设概率阈值的取值范围为0-1,不包括0,可以包括1。
可以理解的是,当本申请应用于数字医疗领域时,第一文本样本数据、第二文本样本数据、第三文本样本数据和待预测的文本数据均是数字医疗领域的文本数据,比如,病历文本数据。
在一个实施例中,上述根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型,包括:
S21:根据所述门控线性单元,对所述预训练BERT模型中的每个隐藏层对应的TransFormer Block单元分别进行替换处理,将替换处理后的所述预训练BERT模型作为所述学生模型,其中,所述学生模型的所述隐藏层与所述预训练BERT模型的所述隐藏层的层数相同。
本实施例实现了将所述预训练BERT模型中的每个隐藏层对应的TransFormerBlock单元替换成门控线性单元,从而替换处理后的所述预训练BERT模型的隐藏层的维度数变成为128,从而减少了参数数量,提高了模型的计算速度;又因是每个隐藏层对应的TransFormer Block单元替换成一个门控线性单元,使替换处理后的所述预训练BERT模型的隐藏层相对所述预训练BERT模型的隐藏层的数量不变,提供门控线性单元有效地降低了BERT模型的梯度弥散,而且还保留了BERT模型的非线性的能力,提高了教师模型的准确度。
对于S21,所述预训练BERT模型中的每个隐藏层对应的TransFormer Block单元采用一个所述门控线性单元进行替换。
在一个实施例中,上述获取第一训练集和教师模型之前,还包括:
S31:获取第二训练集、验证集和超参数数据列表;
S32:根据所述第二训练集和所述超参数数据列表中的每个超参数数据,分别对所述预训练BERT模型进行训练,将训练结束的所述预训练BERT模型作为待集成模型;
S33:根据所述验证集,从各个所述待集成模型中获取初始集成模型,所述初始集成模型为各个所述待集成模型中预测准确率最高的模型;
S34:基于贪婪集成方法,根据所述初始集成模型,对各个所述待集成模型进行集成,得到所述教师模型。
本实施例通过先针对每个超参数数据训练一个待集成模型,然后从各个待集成模型中获取初始集成模型,最后根据初始集成模型对待集成模型进行集成,教师模型相对预训练BERT模型提高了的准确度,从而进一步提高了目标模型的准确度。
对于S31,可以获取用户输入的第二训练集、验证集和超参数数据列表,也可以从数据库中获取第二训练集、验证集和超参数数据列表,还可以从第三方应用系统中获取第二训练集、验证集和超参数数据列表。
超参数数据列表包括多个超参数数据。每个超参数数据包括:超参数和超参数值,每个超参数对应一个超参数值。
可选的,超参数数据中的超参数包括:随机数种子和学习率。
可选的,超参数数据列表中的第一个超参数数据为:随机数种子设置为123,学习率设置为0.001,超参数数据列表中的第二个超参数数据为:随机数种子设置为234和学习率设置为0.001,超参数数据列表中的第三个超参数数据为:随机数种子设置为123和学习率设置为0.005,超参数数据列表中的第四个超参数数据为:随机数种子设置为234和学习率设置为0.005。可以理解的是,超参数数据列表的超参数数据的超参数数据还可以包括其他数值,在此不做限定。
对于S32,所述超参数数据列表中的一个超参数数据作为目标超参数数据;将目标超参数数据更新到所述预训练BERT模型,得到待训练模型;采用所述第二训练集的各个第二训练样本对待训练模型进行训练,将训练结束的待训练模型作为待集成模型。
训练结束的待训练模型,是指达到预设模型训练结束条件的待训练模型。
预设模型训练结束条件包括:待训练模型的损失值达到第一收敛条件或待训练模型的迭代次数达到第二收敛条件。
所述第一收敛条件是指相邻两次计算的待训练模型的损失值的大小满足lipschitz条件(利普希茨连续条件)。
所述迭代次数是指所述待训练模型被训练的次数,也就是说,被训练一次,迭代次数增加1。
第二收敛条件是一个具体数值。
对于S33,根据所述验证集,对每个所述待集成模型进行预测准确率计算,然后找出预测准确率最高的待集成模型作为初始集成模型。
对于S34,基于贪婪集成方法,也就是将所述初始集成模型和从各个所述待集成模型中的除所述初始集成模型以外的所述待集成模型中选择的最优模型进行集成,重复执行集成过程,直至初始集成模型的预测准确率不再提高,将预测准确率不再提高的初始集成模型作为所述教师模型。也就是说,教师模型中包括一个或多个超参数不同的待集成模型,从而使教师模型相对预训练BERT模型提高了的准确度,进一步提高了目标模型的准确度。
在一个实施例中,上述根据所述验证集,从各个所述待集成模型中获取初始集成模型,包括:
S331:分别将所述验证集中的每个第三训练样本输入每个所述待集成模型中进行分类概率预测,得到每个所述第三训练样本在每个所述待集成模型中的第一预测概率值;
S332:根据每个所述第一预测概率值进行标注,得到每个所述第三训练样本在每个所述待集成模型中的目标预测标注值;
S333:根据各个所述目标预测标注值和各个所述第三训练样本各自对应的第三原始标定值,对每个所述待集成模型进行预测准确率计算,得到每个所述待集成模型的待分析预测准确率;
S334:将各个所述待分析预测准确率中的最大所述待分析预测准确率对应的所述待集成模型作为所述初始集成模型。
本实施例根据所述验证集,对每个所述待集成模型进行预测准确率计算,然后找出预测准确率最高的待集成模型作为初始集成模型,为构建具有包括一个或多个超参数不同的待集成模型的教师模型提供了基础。
对于S331,分别将所述验证集中的每个第三训练样本输入每个所述待集成模型中进行分类概率预测,所述待集成模型针对每个第三训练样本输出一个预测数据,将每个预测数据作为一个第一预测概率值。也就是说,第一预测概率值的数量等于第三训练样本的数量与所述待集成模型的数量的乘积。
对于S332,获取所述预设概率阈值;将所述第一预测概率值中的大于所述预设概率阈值的每个向量元素标注为1,将所述第一预测概率值中的小于或等于所述预设概率阈值的每个向量元素标注为0,将完成标注的所述第一预测概率值作为目标预测标注值。也就是说,目标预测标注值的数量等于第三训练样本的数量与所述待集成模型的数量的乘积。
可以理解的是,预设概率阈值可以是0.5、0.55、0.6、0.65、0.7,在此举例不做具体限定。
对于S333,根据各个所述目标预测标注值,针对每个所述待集成模型进行划分,得到每个所述待集成模型对应的目标预测标注值集;根据目标预测标注值集和各个所述第三训练样本各自对应的第三原始标定值,对所述待集成模型进行预测准确率计算,将计算得到的预测准确率作为待分析预测准确率。
对于S334,从各个所述待分析预测准确率中找出最大的所述待分析预测准确率,将找出的所述待分析预测准确率对应的所述待集成模型作为所述初始集成模型。
在一个实施例中,上述基于贪婪集成方法,根据所述初始集成模型,对各个所述待集成模型进行集成,得到所述教师模型,包括:
S341:从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型;
S342:将所述初始集成模型和所述待分析模型进行组合,得到待评估集成模型;
S343:获取所述待评估集成模型对应的各个所述待集成模型各自对应的所述第一预测概率值,得到待分析的预测概率值集;
S344:对所述待分析的预测概率值集中的各个所述第一预测概率值进行加权计算,得到所述待评估集成模型对应的预测概率综合值;
S345:根据每个所述预测概率综合值进行标注,得到每个所述第三训练样本在所述待评估集成模型中的综合预测标注值;
S346:根据各个所述综合预测标注值和各个所述第三训练样本的所述第三原始标定值,对所述待评估集成模型进行预测准确率计算,得到所述待评估集成模型的目标预测准确率;
S347:重复执行所述从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型的步骤,直至完成所述初始集成模型以外的所述待集成模型的获取;
S348:将所述目标预测准确率为最大的所述待评估集成模型作为待处理集成模型;
S349:获取所述初始集成模型的所述目标预测准确率;
S3410:当所述初始集成模型的所述目标预测准确率小于所述待处理集成模型的所述目标预测准确率时,将所述待处理集成模型作为所述初始集成模型,重复执行所述从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型的步骤,直至所述初始集成模型的所述目标预测准确率大于或等于所述待处理集成模型的所述目标预测准确率;
S3411:将所述初始集成模型作为所述教师模型。
本实施例基于贪婪集成方法,也就是根据所述初始集成模型和各个所述待集成模型中的除所述初始集成模型以外的所述待集成模型中选择最优模型进行集成,重复执行集成过程,直至初始集成模型的预测准确率不再提高,将预测准确率不再提高的初始集成模型作为所述教师模型,从而使教师模型相对预训练BERT模型提高了的准确度,从而进一步提高了目标模型的准确度。
对于S341,将所述初始集成模型对应的模型从各个所述待集成模型中剔除,然后从剩余的所有所述待集成模型中获取一个所述待集成模型作为待分析模型。
比如,所述初始集成模型包括:(mt1,mt2),首先将所述待集成模型mt1和所述待集成模型mt2从各个所述待集成模型中剔除,剩余的所有所述待集成模型包括mt3、mt3、mt5、mt6,则从mt3、mt3、mt5、mt6中获取一个所述待集成模型作为待分析模型,在此举例不做具体限定。
对于S342,将所述初始集成模型和所述待分析模型进行组合,将组合后的模型作为待评估集成模型。
其中,将所述初始集成模型和所述待分析模型进行组合,也就是将所述初始集成模型中的各个所述初始集成模型的输出结果和所述待分析模型的输出结果进行加权求和,将加权求和得到的数据作为待评估集成模型的最终输出结果。
可以理解的是,将所述初始集成模型中的各个所述初始集成模型的输出结果和所述待分析模型的输出结果进行加权求和,可以是平均值计算,也可以是非平均值计算,在此不做限定。
比如,所述初始集成模型为(mt1,mt2),所述待分析模型是mt4,则待评估集成模型包括:(mt1,mt2,mt4),将所述初始集成模型中的各个所述初始集成模型的输出结果和所述待分析模型的输出结果进行加权求和就是将(mt1,mt2,mt4)中的三个所述初始集成模型的输出结果进行加权求和,在此举例不做具体限定。
对于S343,可以从数据库中获取所述待评估集成模型对应的各个所述待集成模型各自对应的所述第一预测概率值,也可以从第三方应用系统中获取所述待评估集成模型对应的各个所述待集成模型各自对应的所述第一预测概率值,将获取的各个第一预测概率值作为待分析的预测概率值集。
对于S344,对所述待分析的预测概率值集中的各个所述第一预测概率值进行加权计算,将加权计算得到的数据作为所述待评估集成模型对应的预测概率综合值。
可以理解的是,对所述待分析的预测概率值集中的各个所述第一预测概率值进行加权计算,可以是对所述待分析的预测概率值集中的各个所述第一预测概率值进行平均值计算,也可以是对所述待分析的预测概率值集中的各个所述第一预测概率值进行非平均值计算,在此不做限定。
可选的,对所述待分析的预测概率值集中的各个所述第一预测概率值进行加权计算得到的预测概率综合值的计算公式G1为:
其中,W1、W2、Wn是常量,f1是所述待评估集成模型对应的第1个所述待集成模型对应的所述第一预测概率值,f2是所述待评估集成模型对应的第2个所述待集成模型对应的所述第一预测概率值,fn是所述待评估集成模型对应的第n个所述待集成模型对应的所述第一预测概率值。
可选的,对所述待分析的预测概率值集中的各个所述第一预测概率值进行加权计算得到的预测概率综合值的计算公式G2为:
其中,k、q是常量,q的取值范围是0-1(包括1,不包括0),f1是所述待评估集成模型对应的第1个所述待集成模型对应的所述第一预测概率值,f2是所述待评估集成模型对应的第2个所述待集成模型对应的所述第一预测概率值,fn是所述待评估集成模型对应的第n个所述待集成模型对应的所述第一预测概率值。从而使准确度高的所述待集成模型对应的所述第一预测概率值在预测概率综合值中占的权重高,有利于提高教师模型预测的准确度。
对于S345,获取所述预设概率阈值;将所述预测概率综合值中的大于所述预设概率阈值的每个向量元素标注为1,将所述预测概率综合值中的小于或等于所述预设概率阈值的每个向量元素标注为0,将完成标注的所述预测概率综合值作为综合预测标注值。
对于S346,根据各个所述综合预测标注值和各个所述第三训练样本的所述第三原始标定值,对所述待评估集成模型进行预测准确率计算,将计算得到的预测准确率作为所述待评估集成模型的目标预测准确率。
对于S347,重复执行步骤S341至步骤S347,直至完成所述初始集成模型以外的所述待集成模型的获取,当完成所述初始集成模型以外的所述待集成模型的获取时,意味着所述初始集成模型已经分别与所述初始集成模型以外的每个所述待集成模型完成了集成和目标预测准确率的计算。
对于S348,从各个所述目标预测准确率找出最大的所述目标预测准确率,将找出的所述目标预测准确率对应的所述待评估集成模型作为待处理集成模型。
对于S349,可以从数据库中获取所述初始集成模型的所述目标预测准确率,也可以从第三方应用系统中获取所述初始集成模型的所述目标预测准确率。
对于S3410,当所述初始集成模型的所述目标预测准确率小于所述待处理集成模型的所述目标预测准确率时,意味着所述初始集成模型的目标预测准确率还有提升空间,因此将所述待处理集成模型作为所述初始集成模型以用于下一轮的集成,重复执行步骤S341至步骤S3410,直至所述初始集成模型的所述目标预测准确率大于或等于所述待处理集成模型的所述目标预测准确率,当所述初始集成模型的所述目标预测准确率大于或等于所述待处理集成模型的所述目标预测准确率时,意味着所述初始集成模型的目标预测准确率已经没有提升空间。
对于S3411,将目标预测准确率已经没有提升空间的所述初始集成模型作为所述教师模型。
在一个实施例中,上述根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型,包括:
S41:从所述第一训练集中获取第一训练样本作为目标样本;
S42:根据所述目标样本和所述教师模型,对所述学生模型进行蒸馏损失计算,得到所述目标样本的蒸馏损失值;
S43:将所述目标样本输入所述学生模型进行交叉熵损失计算,得到所述目标样本的交叉熵损失值;
S44:根据所述目标样本的所述蒸馏损失值和所述交叉熵损失值,对所述学生模型的网络参数进行更新,更新后的所述学生模型用于下一次进行蒸馏训练;
S45:重复执行所述从所述第一训练集中获取第一训练样本作为目标样本的步骤,直至达到蒸馏训练结束条件,将达到所述蒸馏训练结束条件的所述学生模型作为所述目标模型。
本实施例先根据所述目标样本和所述教师模型,对所述学生模型进行蒸馏损失计算,然后对所述学生模型进行交叉熵损失计算,最后根据所述目标样本的所述蒸馏损失值和所述交叉熵损失值,对所述学生模型的网络参数进行更新,通过对所述学生模型进行交叉熵损失计算提高了所述学生模型对所述教师模型的学习能力;通过根据所述目标样本的所述蒸馏损失值和所述交叉熵损失值,对所述学生模型的网络参数进行更新,加快了蒸馏训练的速度,使所述学生模型能更准确的模仿所述教师模型。
对于S41,从所述第一训练集中获取第一训练样本作为目标样本。
对于S42,根据所述目标样本和所述教师模型,对所述学生模型分别进行每个隐藏层的蒸馏损失计算和最后一个网络层的蒸馏损失计算,根据每个隐藏层的蒸馏损失和最后一个网络层的蒸馏损失得到所述目标样本的蒸馏损失值。
对于S43,将所述目标样本的第三文本样本数据输入所述学生模型进行分类概率预测,根据分类概率预测得到的数据进行交叉熵损失计算,将交叉熵损失计算得到的数据作为所述目标样本的交叉熵损失值。
对于S44,将所述目标样本的所述蒸馏损失值和所述交叉熵损失值进行相加计算,根据相加计算得到的数据对所述学生模型的网络参数进行更新,将更新后的所述学生模型用于下一次进行蒸馏训练,也就是更新后的所述学生模型用于下一次进行所述蒸馏损失值和所述交叉熵损失值的计算。
对于S45,重复执行步骤S41至步骤S44,直至达到蒸馏训练结束条件,将达到所述蒸馏训练结束条件的所述学生模型作为所述目标模型。
蒸馏训练结束条件包括:所述学生模型的总损失值达到第三收敛条件或蒸馏次数达到第四收敛条件。
所述第三收敛条件是指相邻两次计算的所述学生模型的总损失值的大小满足lipschitz条件(利普希茨连续条件)。
所述蒸馏次数是指根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练的次数,也就是说,蒸馏训练一次,蒸馏次数增加1。
第四收敛条件是一个具体数值。
在一个实施例中,上述根据所述目标样本和所述教师模型,对所述学生模型进行蒸馏损失计算,得到所述目标样本的蒸馏损失值,包括:
S421:将所述目标样本输入所述教师模型进行分类概率预测;
S422:获取所述教师模型中的各个模型的各个隐藏层输出的数据,得到多个待加权特征数据;
S423:根据各个所述待加权特征数据,针对所述教师模型的每类所述隐藏层进行加权求和,得到所述教师模型的每类所述隐藏层对应的第一特征数据;
S424:获取所述教师模型输出的分类概率预测的数据,得到第二预测概率值;
S425:将所述目标样本输入所述学生模型进行分类概率预测;
S426:获取所述学生模型的各个所述隐藏层输出的数据,得到多个第二特征数据;
S427:获取所述学生模型输出的分类概率预测的数据,得到第三预测概率值;
S428:根据各个所述第一特征数据和各个所述第二特征数据,针对所述学生模型进行每个所述隐藏层的蒸馏损失计算,得到多个层蒸馏损失;
S429:根据所述第二预测概率值和所述第三预测概率值,对所述学生模型进行蒸馏损失计算,得到综合蒸馏损失;
S4210:根据各个所述层蒸馏损失和所述综合蒸馏损失进行求和计算,得到所述目标样本的所述蒸馏损失值。
本实施例通过先对所述教师模型的各个隐藏层输出的数据和分类概率预测的数据,与学生模型的各个所述隐藏层输出的数据和分类概率预测的数据进行所述目标样本的所述蒸馏损失值,从而有利于加快学生模型的收敛,也有利于提高学生模型模仿教师模型的准确性。
对于S421,将所述目标样本的第三文本样本数据输入所述教师模型进行分类概率预测。
对于S422,获取所述教师模型中的各个模型的各个隐藏层输出的数据,将每个隐藏层输出的数据作为一个待加权特征数据。也就是说,待加权特征数据的数量和所述教师模型的隐藏层的数量相同。
对于S423,根据各个所述待加权特征数据,针对所述教师模型的每类所述隐藏层进行加权求和,也就是将所述教师模型中的各个模型在相同序号的隐藏层的输出结果进行加权求和。可以理解的是,第一特征数据的数量和所述教师模型的隐藏层的类型数量相同。第一特征数据的数量和所述预训练BERT模型的隐藏层的数量相同。
比如,根据各个所述待加权特征数据,针对所述教师模型中的各个模型的第2个的隐藏层的输出的所述待加权特征数据进行加权求和,将加权求和得到的数据作为所述教师模型中的第2个的隐藏层对应的第一特征数据,在此举例不做具体限定。
其中,W1、W2、Wn是常量,t1是所述待评估集成模型对应的第1个所述待集成模型的第i个所述隐藏层对应的所述待加权特征数据,t2是所述待评估集成模型对应的第2个所述待集成模型的第i个所述隐藏层对应的所述待加权特征数据,tn是所述待评估集成模型对应的第n个所述待集成模型的第i个所述隐藏层对应的所述待加权特征数据。
其中,k、q是常量,q的取值范围是0-1(包括1,不包括0),t1是所述待评估集成模型对应的第1个所述待集成模型的第i个所述隐藏层对应的所述待加权特征数据,t2是所述待评估集成模型对应的第2个所述待集成模型的第i个所述隐藏层对应的所述待加权特征数据,tn是所述待评估集成模型对应的第n个所述待集成模型的第i个所述隐藏层对应的所述待加权特征数据。
对于S424,将所述教师模型的各个模型输出的分类概率预测的数据进行加权求和,将加权求和得到的数据作为第二预测概率值。
其中,W1、W2、Wn是常量,h1是所述待评估集成模型对应的第1个所述待集成模型输出的分类概率预测的数据,h2是所述待评估集成模型对应的第2个所述待集成模型输出的分类概率预测的数据,hn是所述待评估集成模型对应的第n个所述待集成模型输出的分类概率预测的数据。
其中,k、q是常量,q的取值范围是0-1(包括1,不包括0),h1是所述待评估集成模型对应的第1个所述待集成模型输出的分类概率预测的数据,h2是所述待评估集成模型对应的第2个所述待集成模型输出的分类概率预测的数据,hn是所述待评估集成模型对应的第n个所述待集成模型输出的分类概率预测的数据。
对于S425,将所述目标样本的第三文本样本数据输入所述学生模型进行分类概率预测。
对于S426,获取所述学生模型的各个隐藏层输出的数据,将所述学生模型的每个隐藏层输出的数据作为一个第二特征数据。也就是说,第二特征数据的数量和所述学生模型的隐藏层的数量相同。第二特征数据的数量和所述预训练BERT模型的隐藏层的数量相同。
对于S427,将所述学生模型输出的分类概率预测的数据作为第三预测概率值。
对于S428,根据各个所述第一特征数据和各个所述第二特征数据,针对所述学生模型进行每个所述隐藏层的蒸馏损失计算,将每个所述隐藏层的蒸馏损失作为一个层蒸馏损失。也就是说,层蒸馏损失的数量与所述学生模型的隐藏层数量相同。
进行每个所述隐藏层的蒸馏损失计算,也就是计算所述第一特征数据对应的向量与所述第二特征数据对应的向量之前的距离。
比如,将第2层对应的所述第一特征数据和所述第二特征数据之间的距离作为第2层对应的层蒸馏损失,在此举例不做具体限定。
对于S429,所述第二预测概率值对应的向量和所述第三预测概率值对应的向量之前的距离,将计算的距离作为综合蒸馏损失。
对于S4210,将各个所述层蒸馏损失和所述综合蒸馏损失进行求和计算,将求和计算得到的数据作为所述目标样本的所述蒸馏损失值。所述蒸馏损失值也就是所述目标样本对所述学生模型的最终损失。
参照图2,本申请还提出了一种适用于BERT模型的蒸馏装置,所述装置包括:
第一数据获取模块100,用于获取预训练BERT模型和门控线性单元;
学生模型确定模块200,用于根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型;
第二数据获取模块300,用于获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型;
目标模型确定模块400,用于根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型。
本实施例通过首先根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型,然后获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型,最后根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型,因学生模型是根据所述预训练BERT模型和所述门控线性单元进行模型构建得到的,通过门控线性单元减少了BERT模型的维度数,有效地降低了BERT模型的梯度弥散,而且还保留了BERT模型的非线性的能力,从而提高了学生模型的计算速度和准确度,提高了目标模型的实时性。
参照图3,本申请实施例中还提供一种计算机设备,该计算机设备可以是服务器,其内部结构可以如图3所示。该计算机设备包括通过系统总线连接的处理器、存储器、网络接口和数据库。其中,该计算机设计的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于储存适用于BERT模型的蒸馏方法等数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种适用于BERT模型的蒸馏方法。所述适用于BERT模型的蒸馏方法,包括:获取预训练BERT模型和门控线性单元;根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型;获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型;根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型。
本实施例通过首先根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型,然后获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型,最后根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型,因学生模型是根据所述预训练BERT模型和所述门控线性单元进行模型构建得到的,通过门控线性单元减少了BERT模型的维度数,有效地降低了BERT模型的梯度弥散,而且还保留了BERT模型的非线性的能力,从而提高了学生模型的计算速度和准确度,提高了目标模型的实时性。
本申请一实施例还提供一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现一种适用于BERT模型的蒸馏方法,包括步骤:获取预训练BERT模型和门控线性单元;根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型;获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型;根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型。
上述执行的适用于BERT模型的蒸馏方法,通过首先根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型,然后获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型,最后根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型,因学生模型是根据所述预训练BERT模型和所述门控线性单元进行模型构建得到的,通过门控线性单元减少了BERT模型的维度数,有效地降低了BERT模型的梯度弥散,而且还保留了BERT模型的非线性的能力,从而提高了学生模型的计算速度和准确度,提高了目标模型的实时性。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的和实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可以包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双速据率SDRAM(SSRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、装置、物品或者方法不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、装置、物品或者方法所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、装置、物品或者方法中还存在另外的相同要素。
以上所述仅为本申请的优选实施例,并非因此限制本申请的专利范围,凡是利用本申请说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本申请的专利保护范围内。
Claims (10)
1.一种适用于BERT模型的蒸馏方法,其特征在于,所述方法包括:
获取预训练BERT模型和门控线性单元;
根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型;
获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型;
根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型。
2.根据权利要求1所述的适用于BERT模型的蒸馏方法,其特征在于,所述根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型包括:
根据所述门控线性单元,对所述预训练BERT模型中的每个隐藏层对应的TransFormerBlock单元分别进行替换处理,将替换处理后的所述预训练BERT模型作为所述学生模型,其中,所述学生模型的所述隐藏层与所述预训练BERT模型的所述隐藏层的层数相同。
3.根据权利要求1所述的适用于BERT模型的蒸馏方法,其特征在于,所述获取第一训练集和教师模型之前,还包括:
获取第二训练集、验证集和超参数数据列表;
根据所述第二训练集和所述超参数数据列表中的每个超参数数据,分别对所述预训练BERT模型进行训练,将训练结束的所述预训练BERT模型作为待集成模型;
根据所述验证集,从各个所述待集成模型中获取初始集成模型,所述初始集成模型为各个所述待集成模型中预测准确率最高的模型;
基于贪婪集成方法,根据所述初始集成模型,对各个所述待集成模型进行集成,得到所述教师模型。
4.根据权利要求3所述的适用于BERT模型的蒸馏方法,其特征在于,所述根据所述验证集,从各个所述待集成模型中获取初始集成模型,包括:
分别将所述验证集中的每个第三训练样本输入每个所述待集成模型中进行分类概率预测,得到每个所述第三训练样本在每个所述待集成模型中的第一预测概率值;
根据每个所述第一预测概率值进行标注,得到每个所述第三训练样本在每个所述待集成模型中的目标预测标注值;
根据各个所述目标预测标注值和各个所述第三训练样本各自对应的第三原始标定值,对每个所述待集成模型进行预测准确率计算,得到每个所述待集成模型的待分析预测准确率;
将各个所述待分析预测准确率中的最大所述待分析预测准确率对应的所述待集成模型作为所述初始集成模型。
5.根据权利要求4所述的适用于BERT模型的蒸馏方法,其特征在于,所述基于贪婪集成方法,根据所述初始集成模型,对各个所述待集成模型进行集成,得到所述教师模型,包括:
从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型;
将所述初始集成模型和所述待分析模型进行组合,得到待评估集成模型;
获取所述待评估集成模型对应的各个所述待集成模型各自对应的所述第一预测概率值,得到待分析的预测概率值集;
对所述待分析的预测概率值集中的各个所述第一预测概率值进行加权计算,得到所述待评估集成模型对应的预测概率综合值;
根据每个所述预测概率综合值进行标注,得到每个所述第三训练样本在所述待评估集成模型中的综合预测标注值;
根据各个所述综合预测标注值和各个所述第三训练样本的所述第三原始标定值,对所述待评估集成模型进行预测准确率计算,得到所述待评估集成模型的目标预测准确率;
重复执行所述从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型的步骤,直至完成所述初始集成模型以外的所述待集成模型的获取;
将所述目标预测准确率为最大的所述待评估集成模型作为待处理集成模型;
获取所述初始集成模型的所述目标预测准确率;
当所述初始集成模型的所述目标预测准确率小于所述待处理集成模型的所述目标预测准确率时,将所述待处理集成模型作为所述初始集成模型,重复执行所述从各个所述待集成模型中获取所述初始集成模型以外的所述待集成模型,作为待分析模型的步骤,直至所述初始集成模型的所述目标预测准确率大于或等于所述待处理集成模型的所述目标预测准确率;
将所述初始集成模型作为所述教师模型。
6.根据权利要求1所述的适用于BERT模型的蒸馏方法,其特征在于,所述根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型,包括:
从所述第一训练集中获取第一训练样本作为目标样本;
根据所述目标样本和所述教师模型,对所述学生模型进行蒸馏损失计算,得到所述目标样本的蒸馏损失值;
将所述目标样本输入所述学生模型进行交叉熵损失计算,得到所述目标样本的交叉熵损失值;
根据所述目标样本的所述蒸馏损失值和所述交叉熵损失值,对所述学生模型的网络参数进行更新,更新后的所述学生模型用于下一次进行蒸馏训练;
重复执行所述从所述第一训练集中获取第一训练样本作为目标样本的步骤,直至达到蒸馏训练结束条件,将达到所述蒸馏训练结束条件的所述学生模型作为所述目标模型。
7.根据权利要求6所述的适用于BERT模型的蒸馏方法,其特征在于,所述根据所述目标样本和所述教师模型,对所述学生模型进行蒸馏损失计算,得到所述目标样本的蒸馏损失值包括:
将所述目标样本输入所述教师模型进行分类概率预测;
获取所述教师模型中的各个模型的各个隐藏层输出的数据,得到多个待加权特征数据;
根据各个所述待加权特征数据,针对所述教师模型的每类所述隐藏层进行加权求和,得到所述教师模型的每类所述隐藏层对应的第一特征数据;
获取所述教师模型输出的分类概率预测的数据,得到第二预测概率值;
将所述目标样本输入所述学生模型进行分类概率预测;
获取所述学生模型的各个所述隐藏层输出的数据,得到多个第二特征数据;
获取所述学生模型输出的分类概率预测的数据,得到第三预测概率值;
根据各个所述第一特征数据和各个所述第二特征数据,针对所述学生模型进行每个所述隐藏层的蒸馏损失计算,得到多个层蒸馏损失;
根据所述第二预测概率值和所述第三预测概率值,对所述学生模型进行蒸馏损失计算,得到综合蒸馏损失;
根据各个所述层蒸馏损失和所述综合蒸馏损失进行求和计算,得到所述目标样本的所述蒸馏损失值。
8.一种适用于BERT模型的蒸馏装置,其特征在于,所述装置包括:
第一数据获取模块,用于获取预训练BERT模型和门控线性单元;
学生模型确定模块,用于根据所述预训练BERT模型和所述门控线性单元进行模型构建,得到学生模型;
第二数据获取模块,用于获取第一训练集和教师模型,其中,所述教师模型是基于所述预训练BERT模型训练得到的集成模型;
目标模型确定模块,用于根据所述第一训练集和所述教师模型,对所述学生模型进行蒸馏训练,得到目标模型。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至7中任一项所述方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的方法的步骤。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110975567.0A CN113673698B (zh) | 2021-08-24 | 2021-08-24 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
PCT/CN2022/072362 WO2023024427A1 (zh) | 2021-08-24 | 2022-01-17 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110975567.0A CN113673698B (zh) | 2021-08-24 | 2021-08-24 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113673698A true CN113673698A (zh) | 2021-11-19 |
CN113673698B CN113673698B (zh) | 2024-05-10 |
Family
ID=78545639
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110975567.0A Active CN113673698B (zh) | 2021-08-24 | 2021-08-24 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN113673698B (zh) |
WO (1) | WO2023024427A1 (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2023024427A1 (zh) * | 2021-08-24 | 2023-03-02 | 平安科技(深圳)有限公司 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
CN116861302A (zh) * | 2023-09-05 | 2023-10-10 | 吉奥时空信息技术股份有限公司 | 一种案件自动分类分拨方法 |
Families Citing this family (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116244417B (zh) * | 2023-03-23 | 2024-05-24 | 上海笑聘网络科技有限公司 | 应用于ai聊天机器人的问答交互数据处理方法及服务器 |
CN116720530A (zh) * | 2023-06-19 | 2023-09-08 | 内蒙古工业大学 | 一种基于预训练模型和对抗训练的蒙汉神经机器翻译方法 |
CN117082004A (zh) * | 2023-08-30 | 2023-11-17 | 湖北省楚天云有限公司 | 一种基于蒸馏表征模型的轻量级加密流量分析方法及系统 |
CN117390520B (zh) * | 2023-12-08 | 2024-04-16 | 惠州市宝惠电子科技有限公司 | 变压器状态监测方法及系统 |
CN117807235B (zh) * | 2024-01-17 | 2024-05-10 | 长春大学 | 一种基于模型内部特征蒸馏的文本分类方法 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112529153A (zh) * | 2020-12-03 | 2021-03-19 | 平安科技(深圳)有限公司 | 基于卷积神经网络的bert模型的微调方法及装置 |
CN112884150A (zh) * | 2021-01-21 | 2021-06-01 | 北京航空航天大学 | 一种预训练模型知识蒸馏的安全性增强方法 |
US20210182662A1 (en) * | 2019-12-17 | 2021-06-17 | Adobe Inc. | Training of neural network based natural language processing models using dense knowledge distillation |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112836762A (zh) * | 2021-02-26 | 2021-05-25 | 平安科技(深圳)有限公司 | 模型蒸馏方法、装置、设备及存储介质 |
CN113673698B (zh) * | 2021-08-24 | 2024-05-10 | 平安科技(深圳)有限公司 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
-
2021
- 2021-08-24 CN CN202110975567.0A patent/CN113673698B/zh active Active
-
2022
- 2022-01-17 WO PCT/CN2022/072362 patent/WO2023024427A1/zh active Application Filing
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210182662A1 (en) * | 2019-12-17 | 2021-06-17 | Adobe Inc. | Training of neural network based natural language processing models using dense knowledge distillation |
CN112529153A (zh) * | 2020-12-03 | 2021-03-19 | 平安科技(深圳)有限公司 | 基于卷积神经网络的bert模型的微调方法及装置 |
CN112884150A (zh) * | 2021-01-21 | 2021-06-01 | 北京航空航天大学 | 一种预训练模型知识蒸馏的安全性增强方法 |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2023024427A1 (zh) * | 2021-08-24 | 2023-03-02 | 平安科技(深圳)有限公司 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
CN116861302A (zh) * | 2023-09-05 | 2023-10-10 | 吉奥时空信息技术股份有限公司 | 一种案件自动分类分拨方法 |
CN116861302B (zh) * | 2023-09-05 | 2024-01-23 | 吉奥时空信息技术股份有限公司 | 一种案件自动分类分拨方法 |
Also Published As
Publication number | Publication date |
---|---|
CN113673698B (zh) | 2024-05-10 |
WO2023024427A1 (zh) | 2023-03-02 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113673698B (zh) | 适用于bert模型的蒸馏方法、装置、设备及存储介质 | |
Mehrkanoon et al. | Approximate solutions to ordinary differential equations using least squares support vector machines | |
CN109598337B (zh) | 基于分解模糊神经网络的二氧化硫浓度预测方法 | |
CN110909926A (zh) | 基于tcn-lstm的太阳能光伏发电预测方法 | |
CN117194637B (zh) | 基于大语言模型的多层级可视化评估报告生成方法、装置 | |
CN109977394B (zh) | 文本模型训练方法、文本分析方法、装置、设备及介质 | |
CN111598213B (zh) | 网络训练方法、数据识别方法、装置、设备和介质 | |
CN110362797B (zh) | 一种研究报告生成方法及相关设备 | |
CN113792682A (zh) | 基于人脸图像的人脸质量评估方法、装置、设备及介质 | |
WO2022178948A1 (zh) | 模型蒸馏方法、装置、设备及存储介质 | |
CN113642707A (zh) | 基于联邦学习的模型训练方法、装置、设备及存储介质 | |
CN112766485A (zh) | 命名实体模型的训练方法、装置、设备及介质 | |
Lataniotis | Data-driven uncertainty quantification for high-dimensional engineering problems | |
CN115496144A (zh) | 配电网运行场景确定方法、装置、计算机设备和存储介质 | |
Bae et al. | Multi-rate vae: Train once, get the full rate-distortion curve | |
CN113268564B (zh) | 相似问题的生成方法、装置、设备及存储介质 | |
Bryson et al. | A generalized multiple criteria data-fitting model with sparsity and entropy with application to growth forecasting | |
CN117093924A (zh) | 基于域适应特征的旋转机械变工况故障诊断方法 | |
CN116484904A (zh) | 一种基于人工智能深度学习的监管数据处理实现方法 | |
CN115238874A (zh) | 一种量化因子的搜索方法、装置、计算机设备及存储介质 | |
Su et al. | Language modeling using tensor trains | |
CN116431758A (zh) | 文本分类方法、装置、电子设备及计算机可读存储介质 | |
CN108073704B (zh) | 一种liwc词表扩展方法 | |
CN112949307A (zh) | 预测语句实体的方法、装置和计算机设备 | |
Caldana et al. | Neural ordinary differential equations for model order reduction of stiff systems |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |