CN112836762A - 模型蒸馏方法、装置、设备及存储介质 - Google Patents
模型蒸馏方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN112836762A CN112836762A CN202110220512.9A CN202110220512A CN112836762A CN 112836762 A CN112836762 A CN 112836762A CN 202110220512 A CN202110220512 A CN 202110220512A CN 112836762 A CN112836762 A CN 112836762A
- Authority
- CN
- China
- Prior art keywords
- distillation
- model
- student model
- training
- loss value
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000004821 distillation Methods 0.000 title claims abstract description 329
- 238000000034 method Methods 0.000 title claims abstract description 63
- 230000006870 function Effects 0.000 claims description 42
- 238000004364 calculation method Methods 0.000 claims description 40
- 238000004590 computer program Methods 0.000 claims description 11
- 238000002372 labelling Methods 0.000 abstract description 7
- 238000013473 artificial intelligence Methods 0.000 abstract description 3
- 230000008569 process Effects 0.000 description 6
- 238000010586 diagram Methods 0.000 description 3
- 230000001360 synchronised effect Effects 0.000 description 2
- 230000006872 improvement Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
- Other Investigation Or Analysis Of Materials By Electrical Means (AREA)
Abstract
本申请涉及人工智能技术领域,揭示了一种模型蒸馏方法、装置、设备及存储介质,其中方法包括:获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,预训练模型是基于Bert网络训练得到的模型;采用未标注的训练样本和学生模型对预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;采用未标注的训练样本和第一次蒸馏后的学生模型对预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;采用带标注的训练样本对第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。从而通过三次蒸馏提升了蒸馏后得到的模型的准确率,减少了对训练样本标注的需求,降低了蒸馏的成本。
Description
技术领域
本申请涉及到人工智能技术领域,特别是涉及到一种模型蒸馏方法、装置、设备及存储介质。
背景技术
目前预训练模型具有较强的编码能力以及泛化能力,而且在使用预训练模型时,处理下游任务能极大的减少标注数据的使用量,因此在各个领域都发挥着巨大的作用。因为预训练模型通常参数量较大,导致无法在线使用。
现有技术通过把一个参数量大的预训练模型蒸馏到一个参数量小的模型,以用于实现参数量的减少以及推理速度的提升。但是,目前的蒸馏方法蒸馏后的小模型和原始模型存在准确率上的差距,很多甚至达到10个点左右的差距。同时,目前很多蒸馏方案都需要大量的标注数据,这极大的提升了蒸馏的成本。
发明内容
本申请的主要目的为提供一种模型蒸馏方法、装置、设备及存储介质,旨在解决现有技术的蒸馏方法蒸馏后的小模型和原始模型存在准确率上的差距,很多蒸馏方案都需要大量的标注数据,这极大的提升了蒸馏的成本的技术问题。
为了实现上述发明目的,本申请提出一种模型蒸馏方法,所述方法包括:
获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;
采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;
采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;
采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。
进一步的,所述采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型的步骤,包括:
将所述未标注的训练样本输入所述预训练模型进行评分预测,获取所述预训练模型的评分预测层输出的第一预测评分;
将所述未标注的训练样本输入所述学生模型的进行评分预测,得到第二预测评分;
将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值,根据所述第一损失值更新所述学生模型的所有参数,将更新参数后的所述学生模型用于下一次计算所述第二预测评分;
重复执行上述方法步骤直至所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件,将所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件的所述学生模型,确定为所述第一次蒸馏后的学生模型。
进一步的,所述将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值的步骤,包括:
将所述第一预测评分、所述第二预测评分输入KL散度损失函数进行计算,得到所述第一损失值。
进一步的,所述采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型的步骤,包括:
将所述未标注的训练样本输入所述预训练模型进行概率预测,获取所述预训练模型的概率预测层输出的第一预测概率;
将所述未标注的训练样本输入所述第一次蒸馏后的学生模型进行概率预测,得到第二预测概率;
将所述第一预测概率、所述第二预测概率输入第二损失函数进行计算,得到第二损失值,根据所述第二损失值按第一预设参数分层更新规则更新所述第一次蒸馏后的学生模型的参数,将更新参数后的所述第一次蒸馏后的学生模型用于下一次计算所述第二预测概率;
重复执行上述方法步骤直至所述第二损失值达到第三收敛条件或迭代次数达到第四收敛条件,将所述第二损失值达到第三收敛条件或迭代次数达到第四收敛条件的所述第一次蒸馏后的学生模型,确定为所述第二次蒸馏后的学生模型。
进一步的,所述将所述第一预测概率、所述第二预测概率输入第二损失函数进行计算,得到第二损失值,根据所述第二损失值按第一预设参数分层更新规则更新所述第一次蒸馏后的学生模型的参数的步骤,包括:
将所述第一预测概率、所述第二预测概率输入MSE损失函数进行计算,得到所述第二损失值;
当所述第二损失值中的Dense层参数未达到第一Dense层收敛条件时,根据所述第二损失值中的Dense层参数更新所述第一次蒸馏后的学生模型的Dense层的参数,否则,当所述第二损失值中的BiLSTM层参数未达到第一BiLSTM层收敛条件时,根据所述第二损失值中的BiLSTM层参数更新所述第一次蒸馏后的学生模型的BiLSTM层的参数,否则,根据所述第二损失值中的Embedding层参数更新所述第一次蒸馏后的学生模型的Embedding层的参数。
进一步的,所述采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型的步骤,包括:
将所述带标注的训练样本输入所述第二次蒸馏后的学生模型进行概率预测,得到第三预测概率;
将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行计算,得到第三损失值,根据所述第三损失值按第二预设参数分层更新规则更新所述第二次蒸馏后的学生模型的参数,将更新参数后的所述第二次蒸馏后的学生模型用于下一次计算所述第三预测概率;
重复执行上述方法步骤直至所述第三损失值达到第五收敛条件或迭代次数达到第六收敛条件,将所述第三损失值达到第五收敛条件或迭代次数达到第六收敛条件的所述第二次蒸馏后的学生模型,确定为所述训练好的学生模型。
进一步的,所述将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行计算,得到第三损失值,根据所述第三损失值按第二预设参数分层更新规则更新所述第二次蒸馏后的学生模型的参数的步骤,包括:
将所述第三预测概率、所述带标注的训练样本的样本标定值输入交叉熵损失函数进行计算,得到所述第三损失值;
当所述第三损失值中的Dense层参数未达到第二Dense层收敛条件时,根据所述第三损失值中的Dense层参数更新所述第二次蒸馏后的学生模型的Dense层的参数,否则,当所述第三损失值中的BiLSTM层参数未达到第二BiLSTM层收敛条件时,根据所述第三损失值中的BiLSTM层参数更新所述第二次蒸馏后的学生模型的BiLSTM层的参数,否则,根据所述第三损失值中的Embedding层参数更新所述第二次蒸馏后的学生模型的Embedding层的参数。
本申请还提出了一种模型蒸馏装置,所述装置包括:
数据获取模块,用于获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;
第一阶段蒸馏模块,用于采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;
第二阶段蒸馏模块,用于采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;
第三阶段蒸馏模块,用于采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。
本申请还提出了一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现上述任一项所述方法的步骤。
本申请还提出了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任一项所述的方法的步骤。
本申请的模型蒸馏方法、装置、设备及存储介质,通过采用未标注的训练样本和学生模型对预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型,采用未标注的训练样本和第一次蒸馏后的学生模型对预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型,采用带标注的训练样本对第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型,从而通过三次蒸馏提升了蒸馏后得到的模型的准确率;因为在第一次和第二次蒸馏采用的是未标注的训练样本,从而减少了对训练样本标注的需求,降低了蒸馏的成本。
附图说明
图1为本申请一实施例的模型蒸馏方法的流程示意图;
图2为本申请一实施例的模型蒸馏装置的结构示意框图;
图3为本申请一实施例的计算机设备的结构示意框图。
本申请目的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
为了解决现有技术的蒸馏方法蒸馏后的小模型和原始模型存在准确率上的差距,很多蒸馏方案都需要大量的标注数据,这极大的提升了蒸馏的成本的技术问题,本申请提出了一种模型蒸馏方法,所述方法应用于人工智能技术领域。所述模型蒸馏方法通过第一次采用整体蒸馏学习、第二次采用分层蒸馏学习、第三次分层蒸馏学习,通过三次蒸馏提升了蒸馏后得到的模型的准确率;而且在第一次和第二次蒸馏采用的是未标注的训练样本,从而减少了对训练样本标注的需求,降低了蒸馏的成本。
参照图1,本申请实施例中提供一种模型蒸馏方法,所述方法包括:
S1:获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;
S2:采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;
S3:采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;
S4:采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。
本实施例通过采用未标注的训练样本和学生模型对预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型,采用未标注的训练样本和第一次蒸馏后的学生模型对预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型,采用带标注的训练样本对第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型,从而通过三次蒸馏提升了蒸馏后得到的模型的准确率;因为在第一次和第二次蒸馏采用的是未标注的训练样本,从而减少了对训练样本标注的需求,降低了蒸馏的成本。
对于S1,可以从数据库中获取预训练模型,也可以是用户输入的预训练模型,还可以是第三方应用系统发送的预训练模型。
可以从数据库中获取学生模型,也可以是用户输入的学生模型,还可以是第三方应用系统发送的学生模型。
可以从数据库中获取多个带标注的训练样本,也可以是用户输入的多个带标注的训练样本,还可以是第三方应用系统发送的多个带标注的训练样本。
可以从数据库中获取多个未标注的训练样本,也可以是用户输入的多个未标注的训练样本,还可以是第三方应用系统发送的多个未标注的训练样本。
所述学生模型包括:Embedding层、BiLSTM层、Dense层。Embedding层输入数据到BiLSTM层,BiLSTM层输出数据到Dense层。Embedding层是嵌入层。BiLSTM层的输出为每一个标签的预测评分。Dense层是全连接层,输出预测概率。
带标注的训练样本包括:样本数据、样本标定值,样本标定值是对样本数据的标定结果。
未标注的训练样本包括:样本数据。
可选的,所述多个带标注的训练样本中带标注的训练样本的数量小于所述多个未标注的训练样本中未标注的训练样本的数量。
对于S2,采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,也就是对所述预训练模型的所有参数进行更新,将训练后的学生模型作为第一次蒸馏后的学生模型。
对于S3,采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,也就是对第一次蒸馏后的学生模型的参数进行分层更新,将训练后的第一次蒸馏后的学生模型作为第二次蒸馏后的学生模型。从而避免了现有技术的蒸馏方式的灾难性遗忘的现象,避免了在第二次蒸馏就忘记了第一次蒸馏的内容的现象。
对于S4,采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,也就是对第二次蒸馏后的学生模型的参数进行分层更新,将训练后的第二次蒸馏后的学生模型作为训练好的学生模型。从而避免了现有技术的蒸馏方式的灾难性遗忘的现象,避免了在第三次蒸馏就忘记了第二次蒸馏的内容的现象。
在一个实施例中,上述采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型的步骤,包括:
S21:将所述未标注的训练样本输入所述预训练模型进行评分预测,获取所述预训练模型的评分预测层输出的第一预测评分;
S22:将所述未标注的训练样本输入所述学生模型的进行评分预测,得到第二预测评分;
S23:将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值,根据所述第一损失值更新所述学生模型的所有参数,将更新参数后的所述学生模型用于下一次计算所述第二预测评分;
S24:重复执行上述方法步骤直至所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件,将所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件的所述学生模型,确定为所述第一次蒸馏后的学生模型。
本实施例实现了根据未标注的训练样本预测得到预测评分计算损失值对所述学生模型的所有参数进行更新,实现了整体蒸馏学习所述预训练模型学习到的知识。
对于S21,将所述未标注的训练样本的样本数据输入所述预训练模型进行预测,将所述预训练模型的评分预测层输出的评分作为第一预测评分。
对于S22,将所述未标注的训练样本的样本数据输入所述学生模型进行预测,将所述学生模型的BiLSTM层输出的评分作为第二预测评分。
对于S23,将所述第一预测评分、所述第二预测评分输入第一损失函数进行损失值计算,将计算得到的损失值作为第一损失值。
根据所述第一损失值更新所述学生模型的所有参数的方法可以从现有技术中选择,在此不做赘述。
对于S24,所述第一收敛条件是指相邻两次计算的第一损失值的大小满足lipschitz条件(利普希茨连续条件)。
所述迭代次数达到第二收敛条件是指所述学生模型被用于计算所述第二预测评分的次数,也就是说,计算一次,迭代次数增加1。
可以理解的是,当所述第一损失值未达到第一收敛条件并且迭代次数未达到第二收敛条件时,从所述多个未标注的训练样本中获取新的未标注的训练样本,根据获取的未标注的训练样本执行步骤S21至步骤S24。
在一个实施例中,上述将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值的步骤,包括:
将所述第一预测评分、所述第二预测评分输入KL散度损失函数进行计算,得到所述第一损失值。
KL散度损失函数,又称为K-L散度损失函数。
KL散度损失函数KL(p||q)的计算公式为:
其中,x是所述未标注的训练样本的样本数据,p(x)是第一预测评分,q(x)是第二预测评分,log()是对数函数。
在一个实施例中,上述采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型的步骤,包括:
S31:将所述未标注的训练样本输入所述预训练模型进行概率预测,获取所述预训练模型的概率预测层输出的第一预测概率;
S32:将所述未标注的训练样本输入所述第一次蒸馏后的学生模型进行概率预测,得到第二预测概率;
S33:将所述第一预测概率、所述第二预测概率输入第二损失函数进行计算,得到第二损失值,根据所述第二损失值按第一预设参数分层更新规则更新所述第一次蒸馏后的学生模型的参数,将更新参数后的所述第一次蒸馏后的学生模型用于下一次计算所述第二预测概率;
S34:重复执行上述方法步骤直至所述第二损失值达到第三收敛条件或迭代次数达到第四收敛条件,将所述第二损失值达到第三收敛条件或迭代次数达到第四收敛条件的所述第一次蒸馏后的学生模型,确定为所述第二次蒸馏后的学生模型。
本实施例实现了根据未标注的训练样本预测得到预测概率计算损失值对第一次蒸馏后的学生模型的参数进行分层更新,从而避免了现有技术的蒸馏方式的灾难性遗忘的现象,避免了在第二次蒸馏就忘记了第一次蒸馏的内容的现象。
对于S31,将所述未标注的训练样本的样本数据输入所述预训练模型进行概率预测,将所述预训练模型的概率预测层输出的概率作为第一预测概率。
对于S32,将所述未标注的训练样本的样本数据输入所述第一次蒸馏后学生模型进行概率预测,将所述第一次蒸馏后学生模型的Dense层输出的概率作为第二预测概率。
对于S33,将所述第一预测概率、所述第二预测概率输入第二损失函数进行损失值计算,将计算得到的损失值作为第二损失值。
根据所述第二损失值每次只更新所述第一次蒸馏后学生模型的一层(也就是Embedding层、BiLSTM层、Dense层)的参数。
对于S34,所述第三收敛条件是指相邻两次计算的第三损失值的大小满足lipschitz条件(利普希茨连续条件)。
所述迭代次数达到第四收敛条件是指所述学生模型被用于计算所述第四预测概率的次数,也就是说,计算一次,迭代次数增加1。
可以理解的是,当所述第二损失值未达到第三收敛条件并且迭代次数未达到第四收敛条件时,从所述多个未标注的训练样本中获取新的未标注的训练样本,根据获取的未标注的训练样本执行步骤S31至步骤S34。
在一个实施例中,上述将所述第一预测概率、所述第二预测概率输入第二损失函数进行计算,得到第二损失值,根据所述第二损失值按第一预设参数分层更新规则更新所述第一次蒸馏后的学生模型的参数的步骤,包括:
S331:将所述第一预测概率、所述第二预测概率输入MSE损失函数进行计算,得到所述第二损失值;
S332:当所述第二损失值中的Dense层参数未达到第一Dense层收敛条件时,根据所述第二损失值中的Dense层参数更新所述第一次蒸馏后的学生模型的Dense层的参数,否则,当所述第二损失值中的BiLSTM层参数未达到第一BiLSTM层收敛条件时,根据所述第二损失值中的BiLSTM层参数更新所述第一次蒸馏后的学生模型的BiLSTM层的参数,否则,根据所述第二损失值中的Embedding层参数更新所述第一次蒸馏后的学生模型的Embedding层的参数。
本实施例实现了根据未标注的训练样本预测得到预测概率计算损失值对第一次蒸馏后的学生模型的参数进行分层更新,从而避免了现有技术的蒸馏方式的灾难性遗忘的现象,避免了在第二次蒸馏就忘记了第一次蒸馏的内容的现象。
对于S331,MSE损失函数公式MSE(p,q)如下:
其中,pt是第一预测概率,qt是第二预测概率。
对于S332,第一Dense层收敛条件、第一BiLSTM层收敛条件可以根据训练需求设置,在此不做具体限定。
在一个实施例中,上述采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型的步骤,包括:
S41:将所述带标注的训练样本输入所述第二次蒸馏后的学生模型进行概率预测,得到第三预测概率;
S42:将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行计算,得到第三损失值,根据所述第三损失值按第二预设参数分层更新规则更新所述第二次蒸馏后的学生模型的参数,将更新参数后的所述第二次蒸馏后的学生模型用于下一次计算所述第三预测概率;
S43:重复执行上述方法步骤直至所述第三损失值达到第五收敛条件或迭代次数达到第六收敛条件,将所述第三损失值达到第五收敛条件或迭代次数达到第六收敛条件的所述第二次蒸馏后的学生模型,确定为所述训练好的学生模型。
本实施例实现了根据带标注的训练样本预测得到的预测概率计算损失值对第二次蒸馏后的学生模型的参数进行分层更新,从而避免了现有技术的蒸馏方式的灾难性遗忘的现象,避免了在第三次蒸馏就忘记了第二次蒸馏的内容的现象。
对于S41,将所述带标注的训练样本的样本数据输入所述第二次蒸馏后的学生模型进行概率预测,将所述第二次蒸馏后的学生模型的Dense层输出的概率作为第三预测概率。
对于S42,将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行损失值计算,将计算得到的损失值作为第三损失值。
根据所述第三损失值每次只更新所述第二次蒸馏后学生模型的一层(也就是Embedding层、BiLSTM层、Dense层)的参数。
对于S43,所述第五收敛条件是指相邻两次计算的第三损失值的大小满足lipschitz条件(利普希茨连续条件)。
所述迭代次数达到第六收敛条件是指所述第二次蒸馏后的学生模型被用于计算所述第三预测概率的次数,也就是说,计算一次,迭代次数增加1。
可以理解的是,当所述第三损失值未达到第五收敛条件并且迭代次数未达到第六收敛条件时,从所述多个带标注的训练样本中获取新的带标注的训练样本,根据获取的带标注的训练样本执行步骤S41至步骤S43。
在一个实施例中,上述所述将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行计算,得到第三损失值,根据所述第三损失值按第二预设参数分层更新规则更新所述第二次蒸馏后的学生模型的参数的步骤,包括:
S421:将所述第三预测概率、所述带标注的训练样本的样本标定值输入交叉熵损失函数进行计算,得到所述第三损失值;
S422:当所述第三损失值中的Dense层参数未达到第二Dense层收敛条件时,根据所述第三损失值中的Dense层参数更新所述第二次蒸馏后的学生模型的Dense层的参数,否则,当所述第三损失值中的BiLSTM层参数未达到第二BiLSTM层收敛条件时,根据所述第三损失值中的BiLSTM层参数更新所述第二次蒸馏后的学生模型的BiLSTM层的参数,否则,根据所述第三损失值中的Embedding层参数更新所述第二次蒸馏后的学生模型的Embedding层的参数。
本实施例根据带标注的训练样本预测得到的预测概率计算损失值对第二次蒸馏后的学生模型的参数进行分层更新,从而避免了现有技术的蒸馏方式的灾难性遗忘的现象,避免了在第三次蒸馏就忘记了第二次蒸馏的内容的现象。
对于S421,交叉熵损失函数CE的计算公式如下:
其中,yc是所述带标注的训练样本的样本标定值,pc是第三预测概率。
对于S422,第二Dense层收敛条件、第二BiLSTM层收敛条件可以根据训练需求设置,在此不做具体限定。
参照图2,本申请还提出了一种模型蒸馏装置,所述装置包括:
数据获取模块100,用于获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;
第一阶段蒸馏模块200,用于采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;
第二阶段蒸馏模块300,用于采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;
第三阶段蒸馏模块400,用于采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。
本实施例通过采用未标注的训练样本和学生模型对预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型,采用未标注的训练样本和第一次蒸馏后的学生模型对预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型,采用带标注的训练样本对第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型,从而通过三次蒸馏提升了蒸馏后得到的模型的准确率;因为在第一次和第二次蒸馏采用的是未标注的训练样本,从而减少了对训练样本标注的需求,降低了蒸馏的成本。
在一个实施例中,所述第一阶段蒸馏模块200,包括:预训练模型评分预测子模块、学生模型评分预测子模块、第一阶段蒸馏训练子模块;
所述预训练模型评分预测子模块,用于将所述未标注的训练样本输入所述预训练模型进行评分预测,获取所述预训练模型的评分预测层输出的第一预测评分;
所述学生模型评分预测子模块,用于将所述未标注的训练样本输入所述学生模型的进行评分预测,得到第二预测评分;
所述第一阶段蒸馏训练子模块,用于将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值,根据所述第一损失值更新所述学生模型的所有参数,将更新参数后的所述学生模型用于下一次计算所述第二预测评分,重复执行上述方法步骤直至所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件,将所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件的所述学生模型,确定为所述第一次蒸馏后的学生模型。
在一个实施例中,所述第一阶段蒸馏训练子模块包括:第一损失值计算单元;
所述第一损失值计算单元,用于将所述第一预测评分、所述第二预测评分输入KL散度损失函数进行计算,得到所述第一损失值。
在一个实施例中,所述第二阶段蒸馏模块300包括:预训练模型概率预测子模块、第一次蒸馏后的学生模型概率预测子模块、第二阶段蒸馏训练子模块;
所述预训练模型概率预测子模块,用于将所述未标注的训练样本输入所述预训练模型进行概率预测,获取所述预训练模型的概率预测层输出的第一预测概率;
所述第一次蒸馏后的学生模型概率预测子模块,用于将所述未标注的训练样本输入所述第一次蒸馏后的学生模型进行概率预测,得到第二预测概率;
所述第二阶段蒸馏训练子模块,用于将所述第一预测概率、所述第二预测概率输入第二损失函数进行计算,得到第二损失值,根据所述第二损失值按第一预设参数分层更新规则更新所述第一次蒸馏后的学生模型的参数,将更新参数后的所述第一次蒸馏后的学生模型用于下一次计算所述第二预测概率,重复执行上述方法步骤直至所述第二损失值达到第三收敛条件或迭代次数达到第四收敛条件,将所述第二损失值达到第三收敛条件或迭代次数达到第四收敛条件的所述第一次蒸馏后的学生模型,确定为所述第二次蒸馏后的学生模型。
在一个实施例中,所述第二阶段蒸馏训练子模块包括:第二损失值计算单元、第一参数更新单元;
所述第二损失值计算单元,用于将所述第一预测概率、所述第二预测概率输入MSE损失函数进行计算,得到所述第二损失值;
所述第一参数更新单元,用于当所述第二损失值中的Dense层参数未达到第一Dense层收敛条件时,根据所述第二损失值中的Dense层参数更新所述第一次蒸馏后的学生模型的Dense层的参数,否则,当所述第二损失值中的BiLSTM层参数未达到第一BiLSTM层收敛条件时,根据所述第二损失值中的BiLSTM层参数更新所述第一次蒸馏后的学生模型的BiLSTM层的参数,否则,根据所述第二损失值中的Embedding层参数更新所述第一次蒸馏后的学生模型的Embedding层的参数。
在一个实施例中,所述第三阶段蒸馏模块400包括:第二次蒸馏后的学生模型概率预测子模块、第三阶段蒸馏训练子模块;
所述第二次蒸馏后的学生模型概率预测子模块,用于将所述带标注的训练样本输入所述第二次蒸馏后的学生模型进行概率预测,得到第三预测概率;
所述第三阶段蒸馏训练子模块,用于将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行计算,得到第三损失值,根据所述第三损失值按第二预设参数分层更新规则更新所述第二次蒸馏后的学生模型的参数,将更新参数后的所述第二次蒸馏后的学生模型用于下一次计算所述第三预测概率,重复执行上述方法步骤直至所述第三损失值达到第五收敛条件或迭代次数达到第六收敛条件,将所述第三损失值达到第五收敛条件或迭代次数达到第六收敛条件的所述第二次蒸馏后的学生模型,确定为所述训练好的学生模型。
在一个实施例中,所述第三阶段蒸馏训练子模块包括:第三损失值计算单元、第二参数更新单元;
所述第三损失值计算单元,用于将所述第三预测概率、所述带标注的训练样本的样本标定值输入交叉熵损失函数进行计算,得到所述第三损失值;
所述第二参数更新单元,用于当所述第三损失值中的Dense层参数未达到第二Dense层收敛条件时,根据所述第三损失值中的Dense层参数更新所述第二次蒸馏后的学生模型的Dense层的参数,否则,当所述第三损失值中的BiLSTM层参数未达到第二BiLSTM层收敛条件时,根据所述第三损失值中的BiLSTM层参数更新所述第二次蒸馏后的学生模型的BiLSTM层的参数,否则,根据所述第三损失值中的Embedding层参数更新所述第二次蒸馏后的学生模型的Embedding层的参数。
参照图3,本申请实施例中还提供一种计算机设备,该计算机设备可以是服务器,其内部结构可以如图3所示。该计算机设备包括通过系统总线连接的处理器、存储器、网络接口和数据库。其中,该计算机设计的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于储存模型蒸馏方法等数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种模型蒸馏方法。所述模型蒸馏方法,包括:获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。
本实施例通过采用未标注的训练样本和学生模型对预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型,采用未标注的训练样本和第一次蒸馏后的学生模型对预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型,采用带标注的训练样本对第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型,从而通过三次蒸馏提升了蒸馏后得到的模型的准确率;因为在第一次和第二次蒸馏采用的是未标注的训练样本,从而减少了对训练样本标注的需求,降低了蒸馏的成本。
本申请一实施例还提供一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现一种模型蒸馏方法,包括步骤:获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。
上述执行的模型蒸馏方法,通过采用未标注的训练样本和学生模型对预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型,采用未标注的训练样本和第一次蒸馏后的学生模型对预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型,采用带标注的训练样本对第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型,从而通过三次蒸馏提升了蒸馏后得到的模型的准确率;因为在第一次和第二次蒸馏采用的是未标注的训练样本,从而减少了对训练样本标注的需求,降低了蒸馏的成本。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的和实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可以包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双速据率SDRAM(SSRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、装置、物品或者方法不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、装置、物品或者方法所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、装置、物品或者方法中还存在另外的相同要素。
以上所述仅为本申请的优选实施例,并非因此限制本申请的专利范围,凡是利用本申请说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本申请的专利保护范围内。
Claims (10)
1.一种模型蒸馏方法,其特征在于,所述方法包括:
获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;
采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;
采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;
采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。
2.根据权利要求1所述的模型蒸馏方法,其特征在于,所述采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型的步骤,包括:
将所述未标注的训练样本输入所述预训练模型进行评分预测,获取所述预训练模型的评分预测层输出的第一预测评分;
将所述未标注的训练样本输入所述学生模型的进行评分预测,得到第二预测评分;
将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值,根据所述第一损失值更新所述学生模型的所有参数,将更新参数后的所述学生模型用于下一次计算所述第二预测评分;
重复执行上述方法步骤直至所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件,将所述第一损失值达到第一收敛条件或迭代次数达到第二收敛条件的所述学生模型,确定为所述第一次蒸馏后的学生模型。
3.根据权利要求2所述的模型蒸馏方法,其特征在于,所述将所述第一预测评分、所述第二预测评分输入第一损失函数进行计算,得到第一损失值的步骤,包括:
将所述第一预测评分、所述第二预测评分输入KL散度损失函数进行计算,得到所述第一损失值。
4.根据权利要求1所述的模型蒸馏方法,其特征在于,所述采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型的步骤,包括:
将所述未标注的训练样本输入所述预训练模型进行概率预测,获取所述预训练模型的概率预测层输出的第一预测概率;
将所述未标注的训练样本输入所述第一次蒸馏后的学生模型进行概率预测,得到第二预测概率;
将所述第一预测概率、所述第二预测概率输入第二损失函数进行计算,得到第二损失值,根据所述第二损失值按第一预设参数分层更新规则更新所述第一次蒸馏后的学生模型的参数,将更新参数后的所述第一次蒸馏后的学生模型用于下一次计算所述第二预测概率;
重复执行上述方法步骤直至所述第二损失值达到第三收敛条件或迭代次数达到第四收敛条件,将所述第二损失值达到第三收敛条件或迭代次数达到第四收敛条件的所述第一次蒸馏后的学生模型,确定为所述第二次蒸馏后的学生模型。
5.根据权利要求4所述的模型蒸馏方法,其特征在于,所述将所述第一预测概率、所述第二预测概率输入第二损失函数进行计算,得到第二损失值,根据所述第二损失值按第一预设参数分层更新规则更新所述第一次蒸馏后的学生模型的参数的步骤,包括:
将所述第一预测概率、所述第二预测概率输入MSE损失函数进行计算,得到所述第二损失值;
当所述第二损失值中的Dense层参数未达到第一Dense层收敛条件时,根据所述第二损失值中的Dense层参数更新所述第一次蒸馏后的学生模型的Dense层的参数,否则,当所述第二损失值中的BiLSTM层参数未达到第一BiLSTM层收敛条件时,根据所述第二损失值中的BiLSTM层参数更新所述第一次蒸馏后的学生模型的BiLSTM层的参数,否则,根据所述第二损失值中的Embedding层参数更新所述第一次蒸馏后的学生模型的Embedding层的参数。
6.根据权利要求1所述的模型蒸馏方法,其特征在于,所述采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型的步骤,包括:
将所述带标注的训练样本输入所述第二次蒸馏后的学生模型进行概率预测,得到第三预测概率;
将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行计算,得到第三损失值,根据所述第三损失值按第二预设参数分层更新规则更新所述第二次蒸馏后的学生模型的参数,将更新参数后的所述第二次蒸馏后的学生模型用于下一次计算所述第三预测概率;
重复执行上述方法步骤直至所述第三损失值达到第五收敛条件或迭代次数达到第六收敛条件,将所述第三损失值达到第五收敛条件或迭代次数达到第六收敛条件的所述第二次蒸馏后的学生模型,确定为所述训练好的学生模型。
7.根据权利要求6所述的模型蒸馏方法,其特征在于,所述将所述第三预测概率、所述带标注的训练样本的样本标定值输入第三损失函数进行计算,得到第三损失值,根据所述第三损失值按第二预设参数分层更新规则更新所述第二次蒸馏后的学生模型的参数的步骤,包括:
将所述第三预测概率、所述带标注的训练样本的样本标定值输入交叉熵损失函数进行计算,得到所述第三损失值;
当所述第三损失值中的Dense层参数未达到第二Dense层收敛条件时,根据所述第三损失值中的Dense层参数更新所述第二次蒸馏后的学生模型的Dense层的参数,否则,当所述第三损失值中的BiLSTM层参数未达到第二BiLSTM层收敛条件时,根据所述第三损失值中的BiLSTM层参数更新所述第二次蒸馏后的学生模型的BiLSTM层的参数,否则,根据所述第三损失值中的Embedding层参数更新所述第二次蒸馏后的学生模型的Embedding层的参数。
8.一种模型蒸馏装置,其特征在于,所述装置包括:
数据获取模块,用于获取预训练模型、学生模型、多个带标注的训练样本、多个未标注的训练样本,所述预训练模型是基于Bert网络训练得到的模型;
第一阶段蒸馏模块,用于采用所述未标注的训练样本和所述学生模型对所述预训练模型进行整体蒸馏学习,得到第一次蒸馏后的学生模型;
第二阶段蒸馏模块,用于采用所述未标注的训练样本和所述第一次蒸馏后的学生模型对所述预训练模型进行分层蒸馏学习,得到第二次蒸馏后的学生模型;
第三阶段蒸馏模块,用于采用所述带标注的训练样本对所述第二次蒸馏后的学生模型进行分层蒸馏学习,得到训练好的学生模型。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至7中任一项所述方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的方法的步骤。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110220512.9A CN112836762A (zh) | 2021-02-26 | 2021-02-26 | 模型蒸馏方法、装置、设备及存储介质 |
PCT/CN2021/084539 WO2022178948A1 (zh) | 2021-02-26 | 2021-03-31 | 模型蒸馏方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110220512.9A CN112836762A (zh) | 2021-02-26 | 2021-02-26 | 模型蒸馏方法、装置、设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112836762A true CN112836762A (zh) | 2021-05-25 |
Family
ID=75933941
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110220512.9A Pending CN112836762A (zh) | 2021-02-26 | 2021-02-26 | 模型蒸馏方法、装置、设备及存储介质 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN112836762A (zh) |
WO (1) | WO2022178948A1 (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113177616A (zh) * | 2021-06-29 | 2021-07-27 | 腾讯科技(深圳)有限公司 | 图像分类方法、装置、设备及存储介质 |
WO2023024427A1 (zh) * | 2021-08-24 | 2023-03-02 | 平安科技(深圳)有限公司 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
CN115861847A (zh) * | 2023-02-24 | 2023-03-28 | 耕宇牧星(北京)空间科技有限公司 | 可见光遥感图像目标智能辅助标注方法 |
Family Cites Families (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11080558B2 (en) * | 2019-03-21 | 2021-08-03 | International Business Machines Corporation | System and method of incremental learning for object detection |
CN110059740A (zh) * | 2019-04-12 | 2019-07-26 | 杭州电子科技大学 | 一种针对嵌入式移动端的深度学习语义分割模型压缩方法 |
CN112257860A (zh) * | 2019-07-02 | 2021-01-22 | 微软技术许可有限责任公司 | 基于模型压缩的模型生成 |
CN110852426B (zh) * | 2019-11-19 | 2023-03-24 | 成都晓多科技有限公司 | 基于知识蒸馏的预训练模型集成加速方法及装置 |
CN111242297A (zh) * | 2019-12-19 | 2020-06-05 | 北京迈格威科技有限公司 | 基于知识蒸馏的模型训练方法、图像处理方法及装置 |
-
2021
- 2021-02-26 CN CN202110220512.9A patent/CN112836762A/zh active Pending
- 2021-03-31 WO PCT/CN2021/084539 patent/WO2022178948A1/zh active Application Filing
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113177616A (zh) * | 2021-06-29 | 2021-07-27 | 腾讯科技(深圳)有限公司 | 图像分类方法、装置、设备及存储介质 |
WO2023024427A1 (zh) * | 2021-08-24 | 2023-03-02 | 平安科技(深圳)有限公司 | 适用于bert模型的蒸馏方法、装置、设备及存储介质 |
CN115861847A (zh) * | 2023-02-24 | 2023-03-28 | 耕宇牧星(北京)空间科技有限公司 | 可见光遥感图像目标智能辅助标注方法 |
Also Published As
Publication number | Publication date |
---|---|
WO2022178948A1 (zh) | 2022-09-01 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112836762A (zh) | 模型蒸馏方法、装置、设备及存储介质 | |
CN113673698B (zh) | 适用于bert模型的蒸馏方法、装置、设备及存储介质 | |
US20150254554A1 (en) | Information processing device and learning method | |
Hanema et al. | Stabilizing tube-based model predictive control: Terminal set and cost construction for LPV systems | |
CN113792682B (zh) | 基于人脸图像的人脸质量评估方法、装置、设备及介质 | |
CN111523686B (zh) | 一种模型联合训练的方法和系统 | |
CN112613312B (zh) | 实体命名识别模型的训练方法、装置、设备及存储介质 | |
CN112732892B (zh) | 课程推荐方法、装置、设备及存储介质 | |
CN112348362A (zh) | 岗位候选人的确定方法、装置、设备及介质 | |
CN112365385B (zh) | 基于自注意力的知识蒸馏方法、装置和计算机设备 | |
CN112733911A (zh) | 实体识别模型的训练方法、装置、设备和存储介质 | |
CN113270103A (zh) | 基于语义增强的智能语音对话方法、装置、设备及介质 | |
CN113326379A (zh) | 文本分类预测方法、装置、设备及存储介质 | |
CN114860915A (zh) | 一种模型提示学习方法、装置、电子设备及存储介质 | |
CN114416984A (zh) | 基于人工智能的文本分类方法、装置、设备及存储介质 | |
EP3895080A1 (en) | Regularization of recurrent machine-learned architectures | |
CN115186062A (zh) | 多模态预测方法、装置、设备及存储介质 | |
CN113268564B (zh) | 相似问题的生成方法、装置、设备及存储介质 | |
KR20220098698A (ko) | 잠재인자에 기반한 협업 필터링을 사용하여 사용자의 정답확률을 예측하는 학습 컨텐츠 추천 시스템 및 그것의 동작방법 | |
Xu et al. | Optimal regulation of uncertain dynamic systems using adaptive dynamic programming | |
CN112766485A (zh) | 命名实体模型的训练方法、装置、设备及介质 | |
CN113642984A (zh) | 基于人工智能的员工考勤方法、装置、设备及存储介质 | |
CN117668157A (zh) | 基于知识图谱的检索增强方法、装置、设备及介质 | |
Bryson et al. | A generalized multiple criteria data-fitting model with sparsity and entropy with application to growth forecasting | |
CN116629362A (zh) | 一种基于路径搜索的可解释时间图推理方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
REG | Reference to a national code |
Ref country code: HK Ref legal event code: DE Ref document number: 40046364 Country of ref document: HK |
|
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20210525 |
|
RJ01 | Rejection of invention patent application after publication |