CN111368997B - 神经网络模型的训练方法及装置 - Google Patents
神经网络模型的训练方法及装置 Download PDFInfo
- Publication number
- CN111368997B CN111368997B CN202010143596.6A CN202010143596A CN111368997B CN 111368997 B CN111368997 B CN 111368997B CN 202010143596 A CN202010143596 A CN 202010143596A CN 111368997 B CN111368997 B CN 111368997B
- Authority
- CN
- China
- Prior art keywords
- model
- loss
- probability distribution
- calibration sample
- neural network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F17/00—Digital computing or data processing equipment or methods, specially adapted for specific functions
- G06F17/10—Complex mathematical operations
- G06F17/18—Complex mathematical operations for evaluating statistical data, e.g. average values, frequency distributions, probability functions, regression analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Evolutionary Computation (AREA)
- Mathematical Analysis (AREA)
- Health & Medical Sciences (AREA)
- Pure & Applied Mathematics (AREA)
- Computational Linguistics (AREA)
- Computational Mathematics (AREA)
- Mathematical Optimization (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Operations Research (AREA)
- Probability & Statistics with Applications (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Algebra (AREA)
- Databases & Information Systems (AREA)
- Image Analysis (AREA)
- Feedback Control In General (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本说明书实施例提供一种神经网络模型的训练方法及装置,在训练方法中,基于在上一周期训练后的神经网络模型,分别确定在当前周期待训练的第一模型,以及用于辅助训练第一模型的第二模型。从样本集合中选取当前标定样本,并基于其执行以下步骤:将当前标定样本输入第一模型,得到第一概率分布。基于第一概率分布,确定当前标定样本的预测标签。将当前标定样本输入第二模型,得到第二概率分布。基于标定标签和预测标签,确定第一预测损失。基于第一概率分布和第二概率分布,确定第二预测损失。结合第一预测损失和第二预测损失,调整第一模型的参数。在全部样本选取完之后,将最后一次调整参数后的第一模型作为在当前周期训练后的神经网络模型。
Description
技术领域
本说明书一个或多个实施例涉及计算机技术领域,尤其涉及一种神经网络模型的训练方法及装置。
背景技术
随着人工智能的普遍流行,神经网络模型越来越受到关注。对于神经网络模型,通常需要先对其进行训练,之后,利用训练后的神经网络模型进行业务处理,如,进行基于图像识别的业务处理、基于用户分类的业务处理、基于音频识别的业务处理以及基于文本分析的业务处理等等。
在传统的神经网络模型的训练过程中,通常会先对样本集合中的样本进行划分,如划分为若干份。之后分批次读取各份样本,并且在每次读取到一份样本之后,基于梯度下降方法,训练神经网络模型。然而,由于样本集合中的样本的质量存在差异性,因此,当质量较差的样本排列在后时,可能会出现模型的预测准确率不稳定的情况,也即会出现训练震荡的问题,这就使得训练得到的模型准确性较差。
因此,需要提供一种更可靠地神经网络模型的训练方法。
发明内容
本说明书一个或多个实施例描述了一种神经网络模型的训练方法及装置,可以使训练的神经网络模型更准确。
第一方面,提供了一种神经网络模型的训练方法,包括:
基于在上一周期训练后的神经网络模型,分别确定在当前周期待训练的第一模型,以及用于辅助训练所述第一模型的第二模型;
从样本集合中选取一批标定样本作为当前标定样本,基于当前标定样本以及所述第二模型,对所述第一模型进行训练,该训练步骤包括:
将当前标定样本输入所述第一模型,通过所述第一模型的输出得到当前标定样本对应的第一概率分布;
基于所述第一概率分布,确定对应于当前标定样本的预测标签;
将当前标定样本输入所述第二模型,通过所述第二模型的输出得到当前标定样本对应的第二概率分布;
基于当前标定样本的标定标签和所述预测标签,确定第一预测损失;
基于所述第一概率分布和所述第二概率分布,确定第二预测损失;
以最小化所述第一预测损失和所述第二预测损失为目标,调整所述第一模型的模型参数;
在基于所述样本集合中的各样本执行所述训练步骤之后,将最后一次调整模型参数后的所述第一模型作为在当前周期训练后的神经网络模型。
第二方面,提供了一种神经网络模型的训练装置,包括:
确定单元,用于基于在上一周期训练后的神经网络模型,分别确定在当前周期待训练的第一模型,以及用于辅助训练所述第一模型的第二模型;
训练单元,用于从样本集合中选取一批标定样本作为当前标定样本,基于当前标定样本以及所述第二模型,对所述第一模型进行训练;
所述训练单元具体包括:
输入子单元,用于将当前标定样本输入所述第一模型,通过所述第一模型的输出得到当前标定样本对应的第一概率分布;
确定子单元,用于基于所述第一概率分布,确定对应于当前标定样本的预测标签;
所述输入子单元,还用于将当前标定样本输入所述第二模型,通过所述第二模型的输出得到当前标定样本对应的第二概率分布;
所述确定子单元,还用于基于当前标定样本的标定标签和所述预测标签,确定第一预测损失;基于所述第一概率分布和所述第二概率分布,确定第二预测损失;
调整子单元,用于以最小化所述第一预测损失和所述第二预测损失为目标,调整所述第一模型的模型参数;
所述确定单元,还用于在基于所述样本集合中的各样本执行所述训练步骤之后,将最后一次调整模型参数后的所述第一模型作为在当前周期训练后的神经网络模型。
第三方面,提供了一种计算机存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行第一方面的方法。
第四方面,提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现第一方面的方法。
本说明书一个或多个实施例提供的神经网络模型的训练方法及装置,在神经网络模型的迭代训练过程中,在每个迭代周期,可以使用上一迭代周期学习到的知识,指导该迭代周期的模型训练过程,由此可以使得模型训练过程更加平稳,进而可以使训练的神经网络模型更准确。
附图说明
为了更清楚地说明本说明书实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本说明书的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1为本说明书提供的神经网络模型的训练方法应用场景示意图;
图2为本说明书一个实施例提供的神经网络模型的训练方法流程图;
图3为本说明书一个实施例提供的神经网络模型的训练装置示意图。
具体实施方式
下面结合附图,对本说明书提供的方案进行描述。
在描述本说明书提供的方案之前,先对本方案的发明构思作以下说明。
为解决在传统的神经网络模型的训练过程中训练震荡的问题,本申请的发明人考虑在神经网络模型的迭代训练过程中,在每个迭代周期(简称周期),可以使用上一周期学习到的知识,指导该周期的模型训练过程,由此来确保模型训练过程的稳定性。
具体地,在每个周期,可以先基于在上一周期训练后的神经网络模型,确定在当前周期待训练的第一模型(也称学生模型),以及用于辅助或者指导训练第一模型的第二模型(也称老师模型)。应理解,在每个周期开始时,第一模型和第二模型是相同的,即该两者均是由初始的神经网络模型在经过过去若干周期的训练之后得到。
之后,从样本集合中选取一批标定样本作为当前标定样本,基于当前标定样本以及第二模型,对第一模型进行训练。其中,第一模型的具体训练过程可以为:将当前标定样本输入第一模型,通过第一模型的输出得到当前标定样本对应的第一概率分布。基于第一概率分布,确定对应于当前标定样本的预测标签。将当前标定样本输入第二模型,通过第二模型的输出得到当前标定样本对应的第二概率分布。基于当前标定样本的标定标签和预测标签,确定第一预测损失。基于第一概率分布和第二概率分布,确定第二预测损失。以最小化第一预测损失和第二预测损失为目标,调整第一模型的模型参数。至此,第一模型在当前周期的一个训练步骤结束。
需要说明的是,在神经网络模型的每个训练步骤中,通过引入老师模型,可以确保每一步的训练不会出现大的震荡。
应理解,由于在上述一个训练步骤结束之后,第一模型的模型参数发生了变化,而第二模型的模型参数没有发生变化,因此,第一模型已经不同于第二模型。
此外,在每个周期,第一模型的训练步骤是循环执行的,其结束条件为:样本集合中的样本全部选取完成。在样本集合中的样本全部选取完成之后,神经网络模型训练过程的一个周期结束。在该一个周期结束之后,可以将最后一次调整参数后的第一模型作为在当前周期训练后的神经网络模型。从而,在进入下一周期之后,可以基于其确定在下一周期待训练的第一模型,以及用于在下一周期辅助训练第一模型的第二模型,并执行下一周期的第一模型的训练的步骤;依次类推,直至到达最后一个周期,神经网络模型的训练过程结束。
由以上的发明构思可以看出,在神经网络模型的每个周期,都会先确定一个老师模型,由于这个老师模型是基于在上一周期训练后的神经网络模型确定,从而其可以看作是在上一周期学习到的知识。之后,通过该老师模型来指导当前周期的学习模型的训练过程(如,缩小老师模型的预测结果与学习模型的预测结果之间的差异),从而可以避免传统技术中,由于样本的差异性,而导致的训练震荡的问题。
以上就是本说明书提供的发明构思,基于该发明构思就可以得到本方案,以下对本方案进行详细阐述。
图1为本说明书提供的神经网络模型的训练方法应用场景示意图。图1中,业务处理系统可以基于预先训练的神经网络模型进行相应的业务处理。这里的业务处理可以包括但不限于基于图像识别的业务处理、基于用户分类的业务处理、基于音频识别的业务处理以及基于文本分析的业务处理等等。
图1中的神经网络模型可以经过多个周期训练得到,其中,在每个周期,都会先确定一个老师模型,由于这个老师模型是基于在上一周期训练后的神经网络模型确定,从而其可以看作是在上一周期学习到的知识。之后,通过该老师模型来指导当前周期的学习模型的训练过程。由此,可以确保模型训练过程的稳定性。
图2为本说明书一个实施例提供的神经网络模型的训练方法流程图。所述方法的执行主体可以为具有处理能力的设备:服务器或者系统或者装置。如图2所示,该方法可以包括:
步骤202,基于在上一周期训练后的神经网络模型,分别确定在当前周期待训练的第一模型,以及用于辅助训练第一模型的第二模型。
这里的神经网络模型可以包括但不限于语言表征模型Bidirectional(EncoderRepresentations from Transformers,BERT)、循环神经网络(Recurrent NeuralNetworks,RNN)以及卷积神经网络(Convolutional Neural Networks,CNN)等。
此外,上述第一模型也可以称为学生模型,上述第二模型也可以称为老师模型。这里的老师模型可以看作是在上一周期的训练过程中所学习到的知识,其用于辅助或者指导当前周期的学生模型的训练过程。
在一个例子中,可以直接将上一周期训练后的神经网络模型,分别作为第一模型和第二模型。在另一个例子中,也可以先对上一周期训练后的神经网络模型的模型参数作一些简单调整,如,明显的错误更正。之后,将调整后的上一周期的神经网络模型,分别作为第一模型和第二模型。
应理解,在进入每个周期之后,都会先确定一个老师模型和一个学生模型。且在初始时,该老师模型和学生模型是相同的。
步骤204,从样本集合中选取一批标定样本作为当前标定样本,基于当前标定样本以及第二模型,对第一模型进行训练。
上述标定样本是指具有真实标签的样本。具体地,训练后的神经网络模型用于进行基于图像识别的业务处理,则所选取的标定样本可以为图片。若训练后的神经网络模型用于进行基于用户分类的业务处理,则所选取的标定样本可以为用户。若训练后的神经网络模型用于进行基于音频识别的业务处理,则所选取的标定样本可以为音频。若训练后的神经网络模型用于进行基于文本分析的业务处理,则所选取的标定样本可以为文本。
上述对第一模型进行训练的步骤具体可以包括:
步骤2042,将当前标定样本输入第一模型,通过第一模型的输出得到当前标定样本对应的第一概率分布。
需要说明的是,由于当前标定样本是基于选取的一批标定样本确定的,因此,当前标定样本的个数可以为多个。从而,在步骤2042中,可以是通过第一模型的输出得到每个标定样本对应的第一概率分布。
具体地,在将多个当前标定样本输入第一模型之后,针对当前标定样本中的每个标定样本,第一模型可以输出该标定样本对应于各预定标签的多个概率值。其中,每个概率值表征该标定样本归属于对应预定标签的可能性。以训练后的神经网络模型用于进行图像识别的业务处理为例,上述各预定标签可以是指不同的图像类别。可以理解的是,每个标定样本对应于各预定标签的多个概率值,可以构成该标定样本对应的第一概率分布。
步骤2044,基于第一概率分布,确定对应于当前标定样本的预测标签。
应理解,在当前标定样本的个数为多个时,这里是指确定对应于每个标定样本的预测标签。具体地,可以通过如下两种方式确定对应于每个标定样本的预测标签。
第一种方式,从每个标定样本对应的第一概率分布中进行采样,并基于采样结果确定对应于该标定样本的预测标签。这里的采样结果可以是指采样到的概率值。需要说明的是,由于对于每个标定样本,与其对应的第一概率分布中每个概率值与一个预定标签相对应,从而在从第一概率分布中采样某个概率值之后,可以将该概率值对应的预定标签作为上述标定样本的预测标签。
第二种方式,从每个标定样本对应的第一概率分布中选择最大概率值,并基于最大概率值确定对应于该标定样本的预测标签。如,可以将最大概率值对应的预定标签作为该标定样本的预测标签。
步骤2046,将当前标定样本输入第二模型,通过第二模型的输出得到当前标定样本对应的第二概率分布。
同样地,这里可以是通过第二模型的输出得到每个标定样本对应的第二概率分布。这里的第二概率分布的定义可参见上述第一概率分布的定义,本说明书在此不复赘述。
可以理解的是,在将样本集合中的第一批样本分别输入第一模型和第二模型时,通过该两个模型分别输出的第一概率分布和第二概率分布是相同的。之后,随着第一模型的模型参数的调整,通过该两个模型预测的概率分布会逐渐形成差异。
需要说明的是,在本说明书中,第二模型输出的概率分布用于辅助训练第一模型,以确保每一步的训练不会出现大的震荡。
步骤2048,基于当前标定样本的标定标签和预测标签,确定第一预测损失,基于第一概率分布和第二概率分布,确定第二预测损失。
在一个例子中,可以基于如下的损失函数,确定第一预测损失。
应理解,上述公式1仅为本说明书给出的确定第一预测损失的一种示例,在实际应用中,还可以在公式1中加入正则项等,本说明书对此不作限定。
需要说明的是,在神经网络模型训练的过程中,通过不断减小上述第一预测损失,可使得神经网络模型能够正确输出预测标签。
上述基于第一概率分布和第二概率分布,确定第二预测损失的步骤具体可以为:计算第一概率分布与所述第二概率分布的差异度。这里的差异度包括以下任一种:KL散度、交叉熵以及JS散度。将计算得到的差异度作为第二预测损失。
以上述差异度为KL散度为例来说,可以基于如下的损失函数,确定第二预测损失。
应理解,上述公式2仅为本说明书给出的确定第二预测损失的一种示例,在实际应用中,还可以在公式2中加入正则项等,本说明书对此不作限定。
需要说明的是,在神经网络模型训练的过程中,通过不断减小上述第二预测损失,可使得标定样本的第一概率分布接近于第二概率分布,由此可以确保模型训练过程的稳定性。
步骤2050,以最小化第一预测损失和第二预测损失为目标,调整第一模型的模型参数。
在一个例子中,可以基于第一预测损失和第二预测损失各自对应的预定权重,对第一预测损失和第二预测损失进行加权求和,得到综合损失。该综合损失与第一预测损失和第二预测损失正相关。基于综合损失,调整第一模型的模型参数。至此,完成了第一模型的模型参数的一次调整。在完成该一次调整之后,第一模型与第二模型之间就会形成差异。
应理解,在实际应用中,上述步骤2042-步骤2050是循环执行的,也即第一模型的模型参数可以进行多次调整,且循环执行的终止条件是:样本集合中的样本全部选取完成。在样本集合中的样本全部选取完成之后,神经网络模型训练过程的一个周期结束。
步骤206,在基于样本集合中的各样本执行上述训练的步骤之后,将最后一次调整模型参数后的第一模型作为在当前周期训练后的神经网络模型。
之后,在进入下一周期之后,可以基于在当前周期训练后的神经网络模型,确定在下一周期待训练的第一模型,以及用于在下一周期辅助训练第一模型的第二模型,并执行下一周期的第一模型的训练的步骤;依次类推,直至到达最后一个周期,神经网络模型的训练过程结束。
应理解,上述在当前周期训练后的神经网络模型即为在当前周期训练后的学生模型。在一个例子中,如果直接将在当前周期训练后的神经网络模型作为在下一周期的学生模型和老师模型,那么在当前周期训练后的学习模型即为在下一周期待训练的学生模型,同时也为在下一周期的老师模型。
最后,需要强调的是,本说明书上文所称的“训练后的神经网络模型”是指在全部的迭代周期结束后,在最后一个迭代周期训练后的神经网络模型。
在本说明书中,训练后的神经网络模型可以用于进行业务处理。该业务处理可以包括以下任一种:基于图像识别的业务处理(如,人脸识别、目标检测等)、基于用户分类的业务处理(如,用户人群划分、用户服务定制等)、基于音频识别的业务处理(如,语音识别、声纹分析等)以及基于文本分析的业务处理(如,语音分析、意图识别等)等。
综合以上,本说明书提供的方案,在神经网络模型的每个周期,都会先确定一个老师模型,由于这个老师模型是基于在上一周期训练后的神经网络模型确定,从而其可以看作是在上一周期学习到的知识。之后,通过该老师模型来指导当前周期的学习模型的训练过程(如,缩小老师模型的预测结果与学习模型的预测结果之间的差异),从而可以避免传统技术中,由于样本的差异性,而导致的训练震荡的问题,进而可以提升训练的神经网络模型的准确性。
与上述神经网络模型的训练方法对应地,本说明书一个实施例还提供的一种神经网络模型的训练装置,如图3所示,该装置可以包括:
确定单元302,用于基于在上一周期训练后的神经网络模型,分别确定在当前周期待训练的第一模型,以及用于辅助训练第一模型的第二模型。
上述神经网络模型可以包括以下任一种:卷积神经网络CNN、循环神经网络RNN以及语言表征模型BERT。
训练单元304,用于从样本集合中选取一批标定样本作为当前标定样本,基于当前标定样本以及第二模型,对第一模型进行训练。
训练单元304具体可以包括:
输入子单元3042,用于将当前标定样本输入第一模型,通过第一模型的输出得到当前标定样本对应的第一概率分布。
确定子单元3044,用于基于第一概率分布,确定对应于当前标定样本的预测标签。
确定子单元3044具体可以用于:
从第一概率分布中进行采样,并基于采样结果确定对应于当前标定样本的预测标签。或者,
从第一概率分布中选择最大概率值,并基于最大概率值确定对应于当前标定样本的预测标签。
输入子单元3042,还用于将当前标定样本输入第二模型,通过第二模型的输出得到当前标定样本对应的第二概率分布。
确定子单元3044,还用于基于当前标定样本的标定标签和预测标签,确定第一预测损失。基于第一概率分布和第二概率分布,确定第二预测损失。
确定子单元3044还具体可以用于:
计算第一概率分布与第二概率分布的差异度,其中,该差异度可以包括以下任一种:KL散度、交叉熵以及JS散度。
将计算得到的差异度作为第二预测损失。
调整子单元3046,用于以最小化第一预测损失和第二预测损失为目标,调整第一模型的模型参数。
调整子单元3046具体可以用于:
基于第一预测损失和第二预测损失各自对应的预定权重,对第一预测损失和第二预测损失进行加权求和,得到综合损失。该综合损失与第一预测损失和第二预测损失正相关。
基于综合损失,调整第一模型的模型参数。
确定单元302,还用于在基于样本集合中的各样本执行上述训练步骤之后,将最后一次调整模型参数后的第一模型作为在当前周期训练后的神经网络模型。
在本说明书中,训练后的神经网络模型可以用于进行业务处理。该业务处理可以包括以下任一种:基于图像识别的业务处理、基于用户分类的业务处理、基于音频识别的业务处理以及基于文本分析的业务处理。相应地,上述标定样本可以包括以下任一种:图片、用户、音频以及文本。
本说明书上述实施例装置的各功能模块的功能,可以通过上述方法实施例的各步骤来实现,因此,本说明书一个实施例提供的装置的具体工作过程,在此不复赘述。
本说明书一个实施例提供的神经网络模型的训练装置,可以提升训练的神经网络模型的准确性。
另一方面,本说明书的实施例提供了一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行图2所示的方法。
另一方面,本说明书的实施例提供一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现图2所示的方法。
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于设备实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
结合本说明书公开内容所描述的方法或者算法的步骤可以硬件的方式来实现,也可以是由处理器执行软件指令的方式来实现。软件指令可以由相应的软件模块组成,软件模块可以被存放于RAM存储器、闪存、ROM存储器、EPROM存储器、EEPROM存储器、寄存器、硬盘、移动硬盘、CD-ROM或者本领域熟知的任何其它形式的存储介质中。一种示例性的存储介质耦合至处理器,从而使处理器能够从该存储介质读取信息,且可向该存储介质写入信息。当然,存储介质也可以是处理器的组成部分。处理器和存储介质可以位于ASIC中。另外,该ASIC可以位于服务器中。当然,处理器和存储介质也可以作为分立组件存在于服务器中。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本发明所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。计算机可读介质包括计算机存储介质和通信介质,其中通信介质包括便于从一个地方向另一个地方传送计算机程序的任何介质。存储介质可以是通用或专用计算机能够存取的任何可用介质。
上述对本说明书特定实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作或步骤可以按照不同于实施例中的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序才能实现期望的结果。在某些实施方式中,多任务处理和并行处理也是可以的或者可能是有利的。
以上所述的具体实施方式,对本说明书的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本说明书的具体实施方式而已,并不用于限定本说明书的保护范围,凡在本说明书的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本说明书的保护范围之内。
Claims (12)
1.一种神经网络模型的训练方法,包括:
基于在上一周期训练后的神经网络模型,分别确定在当前周期待训练的第一模型,以及用于辅助训练所述第一模型的第二模型;
从样本集合中选取一批图片作为当前标定样本,基于当前标定样本以及所述第二模型,对所述第一模型进行训练,训练步骤包括:
将当前标定样本输入所述第一模型,通过所述第一模型的输出得到当前标定样本对应的第一概率分布;其中的每个概率值表征当前标定样本归属于对应图像类别的可能性;
基于所述第一概率分布,确定对应于当前标定样本的预测标签;
将当前标定样本输入所述第二模型,通过所述第二模型的输出得到当前标定样本对应的第二概率分布;
基于当前标定样本的标定标签和所述预测标签,确定第一预测损失;基于所述第一概率分布和所述第二概率分布,确定第二预测损失;
以最小化所述第一预测损失和所述第二预测损失为目标,调整所述第一模型的模型参数;
在基于所述样本集合中的各样本执行所述训练步骤之后,将最后一次调整模型参数后的所述第一模型作为在当前周期训练后的神经网络模型;
其中,在最后一个周期训练后的神经网络模型用于进行基于图像识别的业务处理。
2.根据权利要求1所述的方法,所述基于所述第一概率分布,确定对应于当前标定样本的预测标签,包括:
从所述第一概率分布中进行采样,并基于采样结果确定对应于当前标定样本的预测标签;或者,
从所述第一概率分布中选择最大概率值,并基于所述最大概率值确定对应于当前标定样本的预测标签。
3.根据权利要求1所述的方法,所述基于所述第一概率分布和所述第二概率分布,确定第二预测损失,包括:
计算所述第一概率分布与所述第二概率分布的差异度;其中,所述差异度包括以下任一种:KL散度、交叉熵以及JS散度;
将所述差异度作为所述第二预测损失。
4.根据权利要求1所述的方法,所述以最小化所述第一预测损失和所述第二预测损失为目标,调整所述第一模型的模型参数,包括:
基于所述第一预测损失和所述第二预测损失各自对应的预定权重,对所述第一预测损失和所述第二预测损失进行加权求和,得到综合损失;所述综合损失与所述第一预测损失和所述第二预测损失正相关;
基于所述综合损失,调整所述第一模型的模型参数。
5.根据权利要求1所述的方法,所述神经网络模型包括以下任一种:卷积神经网络CNN、循环神经网络RNN以及语言表征模型BERT。
6.一种神经网络模型的训练装置,包括:
确定单元,用于基于在上一周期训练后的神经网络模型,分别确定在当前周期待训练的第一模型,以及用于辅助训练所述第一模型的第二模型;
训练单元,用于从样本集合中选取一批图片作为当前标定样本,基于当前标定样本以及所述第二模型,对所述第一模型进行训练;
所述训练单元具体包括:
输入子单元,用于将当前标定样本输入所述第一模型,通过所述第一模型的输出得到当前标定样本对应的第一概率分布;其中的每个概率值表征当前标定样本归属于对应图像类别的可能性;
确定子单元,用于基于所述第一概率分布,确定对应于当前标定样本的预测标签;
所述输入子单元,还用于将当前标定样本输入所述第二模型,通过所述第二模型的输出得到当前标定样本对应的第二概率分布;
所述确定子单元,还用于基于当前标定样本的标定标签和所述预测标签,确定第一预测损失;基于所述第一概率分布和所述第二概率分布,确定第二预测损失;
调整子单元,用于以最小化所述第一预测损失和所述第二预测损失为目标,调整所述第一模型的模型参数;
所述确定单元,还用于在基于所述样本集合中的各样本执行训练步骤之后,将最后一次调整模型参数后的所述第一模型作为在当前周期训练后的神经网络模型;
其中,在最后一个周期训练后的神经网络模型用于进行基于图像识别的业务处理。
7.根据权利要求6所述的装置,所述确定子单元具体用于:
从所述第一概率分布中进行采样,并基于采样结果确定对应于当前标定样本的预测标签;或者,
从所述第一概率分布中选择最大概率值,并基于所述最大概率值确定对应于当前标定样本的预测标签。
8.根据权利要求6所述的装置,所述确定子单元还具体用于:
计算所述第一概率分布与所述第二概率分布的差异度;其中,所述差异度包括以下任一种:KL散度、交叉熵以及JS散度;
将所述差异度作为所述第二预测损失。
9.根据权利要求6所述的装置,所述调整子单元具体用于:
基于所述第一预测损失和所述第二预测损失各自对应的预定权重,对所述第一预测损失和所述第二预测损失进行加权求和,得到综合损失;所述综合损失与所述第一预测损失和所述第二预测损失正相关;
基于所述综合损失,调整所述第一模型的模型参数。
10.根据权利要求6所述的装置,所述神经网络模型包括以下任一种:卷积神经网络CNN、循环神经网络RNN以及语言表征模型BERT。
11.一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行权利要求1-5中任一项所述的方法。
12.一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-5中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010143596.6A CN111368997B (zh) | 2020-03-04 | 2020-03-04 | 神经网络模型的训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010143596.6A CN111368997B (zh) | 2020-03-04 | 2020-03-04 | 神经网络模型的训练方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111368997A CN111368997A (zh) | 2020-07-03 |
CN111368997B true CN111368997B (zh) | 2022-09-06 |
Family
ID=71212498
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010143596.6A Active CN111368997B (zh) | 2020-03-04 | 2020-03-04 | 神经网络模型的训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111368997B (zh) |
Families Citing this family (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112115997B (zh) * | 2020-09-11 | 2022-12-02 | 苏州浪潮智能科技有限公司 | 一种物体识别模型的训练方法、系统及装置 |
CN112560437B (zh) * | 2020-12-25 | 2024-02-06 | 北京百度网讯科技有限公司 | 文本通顺度的确定方法、目标模型的训练方法及装置 |
CN112883193A (zh) * | 2021-02-25 | 2021-06-01 | 中国平安人寿保险股份有限公司 | 一种文本分类模型的训练方法、装置、设备以及可读介质 |
CN113222139A (zh) * | 2021-04-27 | 2021-08-06 | 商汤集团有限公司 | 神经网络训练方法和装置、设备,及计算机存储介质 |
CN113178189B (zh) * | 2021-04-27 | 2023-10-27 | 科大讯飞股份有限公司 | 一种信息分类方法及装置、信息分类模型训练方法及装置 |
CN113313314A (zh) * | 2021-06-11 | 2021-08-27 | 北京沃东天骏信息技术有限公司 | 模型训练方法、装置、设备及存储介质 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018051841A1 (ja) * | 2016-09-16 | 2018-03-22 | 日本電信電話株式会社 | モデル学習装置、その方法、及びプログラム |
CN108805258A (zh) * | 2018-05-23 | 2018-11-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及其装置、计算机服务器 |
CN109376615A (zh) * | 2018-09-29 | 2019-02-22 | 苏州科达科技股份有限公司 | 用于提升深度学习网络预测性能的方法、装置及存储介质 |
CN109670572A (zh) * | 2017-10-16 | 2019-04-23 | 优酷网络技术(北京)有限公司 | 神经网络预测方法及装置 |
CN109754089A (zh) * | 2018-12-04 | 2019-05-14 | 浙江大华技术股份有限公司 | 一种模型训练系统及方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US10755172B2 (en) * | 2016-06-22 | 2020-08-25 | Massachusetts Institute Of Technology | Secure training of multi-party deep neural network |
-
2020
- 2020-03-04 CN CN202010143596.6A patent/CN111368997B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018051841A1 (ja) * | 2016-09-16 | 2018-03-22 | 日本電信電話株式会社 | モデル学習装置、その方法、及びプログラム |
CN109670572A (zh) * | 2017-10-16 | 2019-04-23 | 优酷网络技术(北京)有限公司 | 神经网络预测方法及装置 |
CN108805258A (zh) * | 2018-05-23 | 2018-11-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及其装置、计算机服务器 |
CN109376615A (zh) * | 2018-09-29 | 2019-02-22 | 苏州科达科技股份有限公司 | 用于提升深度学习网络预测性能的方法、装置及存储介质 |
CN109754089A (zh) * | 2018-12-04 | 2019-05-14 | 浙江大华技术股份有限公司 | 一种模型训练系统及方法 |
Non-Patent Citations (2)
Title |
---|
Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation;Linfeng Zhang等;《arXiv》;20190517;第1-10页 * |
基于深度特征蒸馏的人脸识别;葛仕明等;《北京交通大学学报》;20171231;第41卷(第6期);第27-33页 * |
Also Published As
Publication number | Publication date |
---|---|
CN111368997A (zh) | 2020-07-03 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111368997B (zh) | 神经网络模型的训练方法及装置 | |
CN110705717B (zh) | 计算机执行的机器学习模型的训练方法、装置及设备 | |
CN110546656B (zh) | 前馈生成式神经网络 | |
US20210019599A1 (en) | Adaptive neural architecture search | |
CN110275939B (zh) | 对话生成模型的确定方法及装置、存储介质、电子设备 | |
US20220092416A1 (en) | Neural architecture search through a graph search space | |
US20200104687A1 (en) | Hybrid neural architecture search | |
EP3948850B1 (en) | System and method for end-to-end speech recognition with triggered attention | |
US10580432B2 (en) | Speech recognition using connectionist temporal classification | |
US11475225B2 (en) | Method, system, electronic device and storage medium for clarification question generation | |
CN111382573A (zh) | 用于答案质量评估的方法、装置、设备和存储介质 | |
CN111813954B (zh) | 文本语句中两实体的关系确定方法、装置和电子设备 | |
CN111191722B (zh) | 通过计算机训练预测模型的方法及装置 | |
CN110717027B (zh) | 多轮智能问答方法、系统以及控制器和介质 | |
CN111738017A (zh) | 一种意图识别方法、装置、设备及存储介质 | |
CN117539977A (zh) | 一种语言模型的训练方法及装置 | |
US20100296728A1 (en) | Discrimination Apparatus, Method of Discrimination, and Computer Program | |
CN112214592A (zh) | 一种回复对话评分模型训练方法、对话回复方法及其装置 | |
CN113555005B (zh) | 模型训练、置信度确定方法及装置、电子设备、存储介质 | |
CN113192530B (zh) | 模型训练、嘴部动作参数获取方法、装置、设备及介质 | |
CN115577797A (zh) | 一种基于本地噪声感知的联邦学习优化方法及系统 | |
CN111275780B (zh) | 人物图像的生成方法及装置 | |
CN113674745A (zh) | 语音识别方法及装置 | |
EP3619654A1 (en) | Continuous parametrizations of neural network layer weights | |
CN111581911B (zh) | 实时文本自动添加标点的方法、模型构建方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |