CN114611692A - 模型训练方法、电子设备以及存储介质 - Google Patents
模型训练方法、电子设备以及存储介质 Download PDFInfo
- Publication number
- CN114611692A CN114611692A CN202210242894.XA CN202210242894A CN114611692A CN 114611692 A CN114611692 A CN 114611692A CN 202210242894 A CN202210242894 A CN 202210242894A CN 114611692 A CN114611692 A CN 114611692A
- Authority
- CN
- China
- Prior art keywords
- model
- target
- network layer
- parameter
- sample set
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例公开了一种模型训练方法、电子设备以及存储介质,包括:获取包括目标样本集;基于所述目标样本集对预训练的参考模型进行调参,得到调参后模型;确定调参过程中,所述参考模型的预设参数对于所述目标样本集的参考权重;根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型,该方案可以提高蒸馏得到的学生模型(即目标模型)的预估能力。
Description
技术领域
本申请涉及计算机技术领域,具体涉及一种模型训练方法、电子设备以及存储介质。
背景技术
随着人工智能(AI,Artificial Intelligence)技术研究和进步,AI技术在多个领域展开研究和应用,比如,会在移动前端设备如智能相机、无人机以及机器人等计算资源有限的设备上部署卷积神经网络,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别、跟踪和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。
为了便于模型的部署以及面向移动端的推广,通常采用知识蒸馏的方式,将复杂、学习能力强的教师模型学到的特征表示“知识”蒸馏出来,传递给参数量小、学习能力弱的学生模型,然而,由于教师模型不同的基础模块层的参数不一样,也就是不同的层学到的信息不一样,导致蒸馏得到的学生模型的预估能力较低。
发明内容
本申请实施例提供一种模型训练方法、电子设备以及存储介质,可以提高蒸馏得到的学生模型的预估能力。
本申请实施例提供了一种模型训练方法,包括:
获取包括目标样本集;
基于所述目标样本集对预训练的参考模型进行调参,得到调参后模型;
确定调参过程中,所述参考模型的预设参数对于所述目标样本集的参考权重;
根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。
可选的,在一些实施例中,所述根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型,包括:
计算所述调参后模型中各网络层对应的参考权重之和;
根据计算结果在所述调参后模型选择参考网络层;
将选择的参考网络层蒸馏至预设基础模型中,得到目标模型。
可选的,在一些实施例中,所述根据计算结果在所述调参后模型选择参考网络层,包括:
根据计算结果对所述调参后面模型的网络层进行排序;
在排序后的网络层中选择目标序号的网络层为参考网络层。
可选的,在一些实施例中,所述在排序后的网络层中选择目标序号的网络层为参考网络层,包括:
识别预设基础模型中网络层对应的层数;
根据所述层数,在排序后的网络层中选择目标序号的网络层为参考网络层。
可选的,在一些实施例中,所述将选择的参考网络层蒸馏至预设基础模型中,得到目标模型,包括:
根据所述参考网络层之间的顺序以及所述基础模型中目标网络层之间的顺序,构建所述参考网络层与目标网络层之间的映射关系;
基于所述映射关系,将所述目标网络层的参数更新为所述参考网络层的参数。
可选的,在一些实施例中,所述确定调参过程中,所述参考模型的预设参数对于所述目标样本集的参考权重,包括:
提取所述目标样本集中每个目标样本的样本标签;
获取调参过程中,所述参考模型针对所述目标样本的预测标签;
基于所述样本标签和预测标签,确定所述参考模型的预设参数对于所述目标样本集的参考权重。
可选的,在一些实施例中,所述基于所述样本标签和预测标签,确定所述参考模型的预设参数对于所述目标样本集的参考权重,包括:
获取预设权重计算公式;
根据所述权重计算公式、样本标签和预测标签,确定所述参考模型的预设参数对于所述目标样本集的参考权重。
可选的,在一些实施例中,所述基于所述目标样本集对预训练的参考模型进行调参,得到调参后模型,包括:
基于所述目标样本集对预训练的参考模型中全连接层对应的参数、激活函数对应的参数和/或卷积层对应的参数进行调整,得到调参后模型。
相应的,本申请还提供一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,其中,所述处理器执行所述程序时如上任一所述方法的步骤。
本申请还提供一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上任一所述方法的步骤。
本申请实施例在获取包括目标样本集后,基于所述目标样本集对预训练的参考模型进行调参,得到调参后模型,然后,确定调参过程中,所述参考模型的预设参数对于所述目标样本集的参考权重,最后,根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。本申请提供的模型训练的方案,可以根据预设参数对于目标样本集的参考权重,自适应地对预设基础模型进行知识蒸馏,从而得到目标模型,由此,可以提高蒸馏得到的学生模型(即目标模型)的预估能力。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的模型训练方法的流程示意图;
图2是本申请实施例提供的模型训练方法中构建映射关系的示意图;
图3是本申请实施例提供的模型训练装置的结构示意图;
图4是本申请实施例提供的电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
本申请实施例提供一种模型训练方法、装置、电子设备和存储介质。
其中,该模型训练装置具体可以集成在服务器或者终端中,服务器可以包括一个独立运行的服务器或者分布式服务器,也可以包括由多个服务器组成的服务器集群,终端可以包括手机、平板电脑或个人计算机(PC,Personal Computer)。
以下分别进行详细说明。需说明的是,以下实施例的描述顺序不作为对实施例优先顺序的限定。
一种模型训练方法,包括:获取包括目标样本集;基于目标样本集对预训练的参考模型进行调参,得到调参后模型;确定调参过程中,参考模型的预设参数对于目标样本集的参考权重;根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。
请参阅图1,图1为本申请实施例提供的模型训练方法的流程示意图。该模型训练方法的具体流程可以如下:
101、获取包括目标样本集。
其中,目标样本集可以由负样本组成,还可以是一部分正样本,另一部分是负样本,具体根据实际情况进行选择。
以目标样本为图像样本为例,可以采用邻域局部典型区域标注法对采集到的多个图像样本进行标注,得到多个标注了区域类型特征的图像样本。
其中,获取目标样本集的途径可以有多种,比如,可以从互联网、指定数据库和/或网页浏览记录中进行获取,具体可以根据实际应用的需求而定;同理,标注方式也可以根据实际应用的需求进行选择,比如,可以在工程师的指点下,由标注审核人员进行人工标注,或者,也可以通过训练标注模型来实现自动标注,等等,在此不作赘述。
102、基于目标样本集对预训练的参考模型进行调参,得到调参后模型。
本申请的模型可以理解为神经网络(NeuralNetworks,NN),神经网络是由大量的、简单的处理单元(称为神经元)广泛地互相连接而形成的复杂网络系统,它反映了人脑功能的许多基本特征,是一个高度复杂的非线性动力学习系统。可选地,参考模型可以为BERT模型、卷积神经网络、循环神经网络、长短期记忆人工神经网络、门控循环神经网络、前馈神经网络或生成对抗网络等等,具体根据实际需求进行选择。
以BERT模型为例,BERT的全称是Bidirectional Encoder Representations fromTransformer,也就是基于Transformer的双向编码器表征,同时,BERT模型也是一种预训练的语言模型,它的特点之一就是所有层都联合上下文语境进行预训练。训练方法是通过预测随机隐藏(Mask)的一部分输入符号(token)或者对输入的下一个句子进行分类,判断下一个句子是否真的属于给定语料里真实的跟随句子。
在本申请中,可以根据目标样本集中目标样本的真实标签以及参考模型预估目标样本的预估标签,对BERT模型进行调参,具体的,可以对BERT模型的Transformer的全连接层的参数和/或激活函数对应的参数进行调整;而对于卷积神经网络模型而言,还可以对其卷积层的参数进行调整,比如,调整卷积核的大小等等,即,步骤“基于目标样本集对预训练的参考模型进行调参,得到调参后模型”,具体可以包括:基于目标样本集对预训练的参考模型中全连接层对应的参数、激活函数对应的参数和/或卷积层对应的参数进行调整,得到调参后模型。
103、确定调参过程中,参考模型的预设参数对于目标样本集的参考权重。
在调参过程中,可以确定参考模型的预设参数对于目标样本集的参考权重,即,哪项参数的调整对于参考模型的预估能力影响最大、以及哪项项参数的调整对于参考模型的预估能力影响最小。
进一步的,在调参过程中,实际上可以理解为调整模型中网络层的参数,即,确定参考模型中网络层的参数对于目标样本集的参考权重,具体的,可以根据参考模型预测每个目标样本的结果以及目标样本自身的标签,对参考模型中网络层的参数进行调整,由此确定参考模型中网络层的参数对于目标样本集的参考权重,即,步骤“确定调参过程中,参考模型的预设参数对于目标样本集的参考权重”,具体可以包括:
(11)提取目标样本集中每个目标样本的样本标签;
(12)获取调参过程中,参考模型针对目标样本的预测标签;
(13)基于样本标签和预测标签,确定参考模型的预设参数对于目标样本集的参考权重。
其中,该参考模型可以根据实际应用的需求进行设定,例如,该参考检测模型可以包括四个卷积层和一个全连接层。
卷积层:主要用于对输入的图像(比如训练样本或需要识别的图像)进行特征提取,其中,卷积核大小可以根据实际应用而定,比如,从第一层卷积层至第四层卷积层的卷积核大小依次可以为(7,7),(5,5),(3,3),(3,3);可选的,为了降低计算的复杂度,提高计算效率,在本实施例中,这四层卷积层的卷积核大小可以都设置为(3,3),激活函数均采用“relu(线性整流函数,Rectified LinearUnit)”,而padding(padding,指属性定义元素边框与元素内容之间的空间)方式均设置为“same”,“same”填充方式可以简单理解为以0填充边缘,左边(上边)补0的个数和右边(下边)补0的个数一样或少一个。可选的,为了进一步减少计算量,还可以在第二至第四层卷积层中的所有层或任意1~2层进行下采样(pooling)操作,该下采样操作与卷积的操作基本相同,只不过下采样的卷积核为只取对应位置的最大值(max pooling)或平均值(average pooling)等,为了描述方便,在本发明实施例中,将均以在第二层卷积层和第三次卷积层中进行下采样操作,且该下采样操作具体为maxpooling为例进行说明。
需说明的是,为了描述方便,在本发明实施例中,将激活函数所在层和下采样层(也称为池化层)均归入卷积层中,应当理解的是,也可以认为该结构包括卷积层、激活函数所在层、下采样层(即池化层)和全连接层,当然,还可以包括用于输入数据的输入层和用于输出数据的输出层,在此不再赘述。
全连接层:可以将学到的特征映射到样本标记空间,其在整个卷积神经网络中主要起到“分类器”的作用,全连接层的每一个结点都与上一层(如卷积层中的下采样层)输出的所有结点相连,其中,全连接层的一个结点即称为全连接层中的一个神经元,全连接层中神经元的数量可以根据实际应用的需求而定,比如,在该孪生神经网络模型的上半分支网络和下半分支网络中,全连接层的神经元数量可以均设置为512个,或者,也可以均设置为128个,等等。与卷积层类似,可选的,在全连接层中,也可以通过加入激活函数来加入非线性因素,比如,可以加入激活函数sigmoid(S型函数)。
例如,参考模型执行的为图形识别任务,目标样本A的样本标签为“猫”,目标样本B的样本标签为“狗”,参考模型在初次预测目标样本A和目标样本B的预测标签分别为“狗”和“猫”,随后,对参考模型的网络层a以及网络层b分别进行参数调整,并利用调整一次的参考模型再次预测目标样本A和目标样本B,得到的预测标签均为和“猫”,再然后,将参考模型的网络层a以及网络层b进行参数初始化,紧接着,对参考模型的网络层a以及网络层c,并返回执行预测标签的步骤,直至预测结果与标注结果一致为止,即预测标签与样本标签相同,在该过程中,可以确定哪个网络层对于哪个样本的权重更高、以及哪个网络层对于哪个样本的权重更低。
可选的,在一些实施例中,还可以获取预设的权重计算公式,并通过该权重计算公式,确定参考模型的预设参数对于目标样本集的参考权重,即,步骤“基于样本标签和预测标签,确定参考模型的预设参数对于目标样本集的参考权重”,具体可以包括:
(21)获取预设权重计算公式;
(22)根据权重计算公式、样本标签和预测标签,确定参考模型的预设参数对于目标样本集的参考权重。
比如,在调参过程中可以利用信息度量指标来估计每个参数对于任务(即目标样本)的参考权重,假设参数之间相互独立,那么可以采用费希尔信息矩阵(FisherInformationMatrix,FIM),其中,FIM的计算公式如下:
其中,Fi(w)表示参考权重,xj表示本次数据集的第j个样本,yj表示本次数据集的第j个样本的标签,w表示模型的参数,表示模型的第i个参数,D表示目标样本集,p(yj|xj;w)表示样本xj经过模型预测后是标签yj的概率,根据该计算公式,则可确定参考模型的预设参数对于目标样本集的参考权重。
104、根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。
知识蒸馏的本质是将复杂模型(即教师模型)中的“知识”迁移到简单模型(即学生模型)中,通过知识蒸馏的方式,使得简单模型更逼近于复杂模型的,从而使用更少的复杂度来获得类似的预测效果。
本申请提供的模型训练方法,在知识蒸馏之前,确定参考模型的预设参数对于目标样本集的参考权重,因此,在一些实施例中,可以将参考权重较高的网络层蒸馏至预设基础模型中,以得到目标模型。
比如,参考模型包括20层网络层,预设基础模型包括3层网络层,在一些实施例中,可以将参考权重最高的三层蒸馏至预设基础模型中,从而得到目标模型。
需要说明的是,参考权重越高,说明该网络层在预估目标样本的信息量越多,参考权重约定说明包含当前的预估任务和目标样本集的差异化信息越少,但是同时也包含了更多的深层语义信息,如:和预估任务相关的信息,受到不同样本集的影响较少,因此,为了避免后续得到目标模型出现过拟合的情况,在一些实施例中,可以在调参后模型中选择相应的网络层,从而进行后续的知识蒸馏,即,可选的,在一些实施例中,步骤“根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型”,包括:
(31)计算调参后模型中各网络层对应的参考权重之和;
(32)根据计算结果在调参后模型选择参考网络层;
(33)将选择的参考网络层蒸馏至预设基础模型中,得到目标模型。
比如,计算每个网络层对应的参考权重之和,具体的,网络层A的第一参数对应的参考权重为0.5,网络层A的第二参数对应的参考权重为0.3,网络层A的第一参数对应的参考权重为0.7,那么网络层A的参考权重之和为1.5;网络层B的第一参数对应的参考权重为0.9,网络层B的第二参数对应的参考权重为0.2,网络层B的第一参数对应的参考权重为0.7,那么网络层B的参考权重之和为1.8;网络层C的第一参数对应的参考权重为0.2,网络层C的第二参数对应的参考权重为0.3,网络层C的第一参数对应的参考权重为0.2,那么网络层C的参考权重之和为0.7;网络层D的第一参数对应的参考权重为0.5,网络层D的第二参数对应的参考权重为0.2,网络层D的第一参数对应的参考权重为0.6,那么网络层D的参考权重之和为1.1;网络层E的第一参数对应的参考权重为0.4,网络层E的第二参数对应的参考权重为0.3,网络层E的第一参数对应的参考权重为0.6,那么网络层E的参考权重之和为1.3;随后,根据实际需求选择相应的网络层进行知识蒸馏。
可选的,在一些实施例中,根据上述权重之和可以对调参后模型的网络层进行排序,排序结果为:网络层C-网络层D-网络层E-网络层B-网络层A,然后,根据该排序结果,选择目标序号的网络层为参考网络层,即,步骤“根据计算结果在调参后模型选择参考网络层”,具体可以包括:
(41)根据计算结果对调参后面模型的网络层进行排序;
(42)在排序后的网络层中选择目标序号的网络层为参考网络层。
比如,可以选择序号首位和末位的网络层作为参考网络层,当然,在一些实施例中,还可以根据预设基础模型中网络层的层数进行选择,即,步骤“在排序后的网络层中选择目标序号的网络层为参考网络层”,具体可以包括:
(51)识别预设基础模型中网络层对应的层数;
(52)根据层数,在排序后的网络层中选择目标序号的网络层为参考网络层。
比如,若预设基础模型中网络层对应的层数M为奇数时,则可以取排序最靠前M-(M/2)层和排序最后的(M/2)为参考网络层;若预设基础模型中网络层对应的层数M为偶数时,则可以取排序的第1层网络层和倒数第n层网络层作为参考网络层,其中,n为正整数,当然,选取参考网络层的方式还可以是其他,具体可以根据实际情况进行选择,在此不再赘述。
可以理解的是,在选择参考网络层后,需要建立调参后模型与预设基础模型之间网络层的对应关系,以便于后续知识蒸馏,即,可选的,在一些实施例中,步骤“将选择的参考网络层蒸馏至预设基础模型中,得到目标模型”,具体可以包括:
(61)根据参考网络层之间的顺序以及基础模型中目标网络层之间的顺序,构建参考网络层与目标网络层之间的映射关系;
(62)基于映射关系,将目标网络层的参数更新为参考网络层的参数。
比如,请参阅图2,基础模型中为3层结构的模型,其包括顺序排列的目标网络层a、目标网络层b以及目标网络层c,选择的参考网络层为网络层A、网络层B以及网络层C,其在参考模型中的结构顺序为网络层B-网络层A-网络层C,那么则可以建立目标网络层a与网络层B之间的映射关系、建立目标网络层b与网络层A之间的映射关系以及建立目标网络层c与网络层C之间的映射关系,最后,根据该建立的映射关系,将目标网络层的参数更新为参考网络层的参数。
本申请实施例在获取包括目标样本集后,基于目标样本集对预训练的参考模型进行调参,得到调参后模型,然后,确定调参过程中,参考模型的预设参数对于目标样本集的参考权重,最后,根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。本申请提供的模型训练的方案,可以根据预设参数对于目标样本集的参考权重,自适应地对预设基础模型进行知识蒸馏,从而得到目标模型,由此,可以提高蒸馏得到的学生模型(即目标模型)的预估能力。
为便于更好的实施本申请实施例的模型训练方法,本申请实施例还提供一种基于上述模型训练装置(简称训练装置)。其中名词的含义与上述模型训练方法中相同,具体实现细节可以参考方法实施例中的说明。
请参阅图3,图3为本申请实施例提供的模型训练装置的结构示意图,其中该训练装置可以包括获取模块201、调参模块202、确定模块203以及蒸馏模块204,具体可以如下:
获取模块201,用于获取包括目标样本集。
其中,目标样本集可以由负样本组成,还可以是一部分正样本,另一部分是负样本,具体根据实际情况进行选择。
获取模块201获取目标样本集的途径可以有多种,比如,可以从互联网、指定数据库和/或网页浏览记录中进行获取,具体可以根据实际应用的需求而定;同理,标注方式也可以根据实际应用的需求进行选择,比如,可以在工程师的指点下,由标注审核人员进行人工标注,或者,也可以通过训练标注模型来实现自动标注,等等,在此不作赘述。
调参模块202,用于基于目标样本集对预训练的参考模型进行调参,得到调参后模型。
比如,调参模块202可以根据目标样本集中目标样本的真实标签以及参考模型预估目标样本的预估标签,对参考模型进行调参,具体的,可以对参考模型中网络层的参数和/或激活函数对应的参数进行调整,还可以对其卷积层的参数进行调整,比如,调整卷积核的大小等等,即,可选的,在一些实施例中,调参模块202具体可用于:基于目标样本集对预训练的参考模型中全连接层对应的参数、激活函数对应的参数和/或卷积层对应的参数进行调整,得到调参后模型。
确定模块203,用于确定调参过程中,参考模型的预设参数对于目标样本集的参考权重。
在调参过程中,实际上可以理解为调整模型中网络层的参数,即,确定参考模型中网络层的参数对于目标样本集的参考权重,具体的,可以根据参考模型预测每个目标样本的结果以及目标样本自身的标签,对参考模型中网络层的参数进行调整,由此确定参考模型中网络层的参数对于目标样本集的参考权重。
可选的,在一些实施例中,确定模块203具体可以包括:
提取单元,用于提取目标样本集中每个目标样本的样本标签;
获取单元,用于获取调参过程中,参考模型针对目标样本的预测标签;
确定单元,用于基于样本标签和预测标签,确定参考模型的预设参数对于目标样本集的参考权重。
可选的,在一些实施例中,确定单元具体可以用于:获取预设权重计算公式;根据权重计算公式、样本标签和预测标签,确定参考模型的预设参数对于目标样本集的参考权重。
蒸馏模块204,用于根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。
例如,具体的,蒸馏模块204在知识蒸馏之前,确定参考模型的预设参数对于目标样本集的参考权重,因此,在一些实施例中,可以将参考权重较高的网络层蒸馏至预设基础模型中,以得到目标模型。
可选的,在一些实施例中,蒸馏模块204具体可以包括:
计算单元,用于计算调参后模型中各网络层对应的参考权重之和;
选择单元,用于根据计算结果在调参后模型选择参考网络层;
蒸馏单元,用于将选择的参考网络层蒸馏至预设基础模型中,得到目标模型。
可选的,在一些实施例中,选择单元具体可以包括:
排序子单元,用于根据计算结果对调参后面模型的网络层进行排序;
选择子单元,用于在排序后的网络层中选择目标序号的网络层为参考网络层。
可选的,在一些实施例中,选择子单元具体可以用于:识别预设基础模型中网络层对应的层数;根据层数,在排序后的网络层中选择目标序号的网络层为参考网络层。
可选的,在一些实施例中,蒸馏单元具体可以用于:根据参考网络层之间的顺序以及基础模型中目标网络层之间的顺序,构建参考网络层与目标网络层之间的映射关系;基于映射关系,将目标网络层的参数更新为参考网络层的参数。
本申请实施例的获取模块201在获取包括目标样本集后,调参模块202基于目标样本集对预训练的参考模型进行调参,得到调参后模型,然后,确定模块203确定调参过程中,参考模型的预设参数对于目标样本集的参考权重,最后,蒸馏模块204根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。本申请提供的模型训练的方案,可以根据预设参数对于目标样本集的参考权重,自适应地对预设基础模型进行知识蒸馏,从而得到目标模型,由此,可以提高蒸馏得到的学生模型(即目标模型)的预估能力。
此外,本申请实施例还提供一种电子设备,如图4所示,其示出了本申请实施例所涉及的电子设备的结构示意图,具体来讲:
该电子设备可以包括一个或者一个以上处理核心的处理器301、一个或一个以上计算机可读存储介质的存储器302、电源303和输入单元304等部件。本领域技术人员可以理解,图4中示出的电子设备结构并不构成对电子设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。其中:
处理器301是该电子设备的控制中心,利用各种接口和线路连接整个电子设备的各个部分,通过运行或执行存储在存储器302内的软件程序和/或模块,以及调用存储在存储器302内的数据,执行电子设备的各种功能和处理数据,从而对电子设备进行整体监控。可选的,处理器301可包括一个或多个处理核心;优选的,处理器301可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器301中。
存储器302可用于存储软件程序以及模块,处理器301通过运行存储在存储器302的软件程序以及模块,从而执行各种功能应用以及模型训练。存储器302可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据电子设备的使用所创建的数据等。此外,存储器302可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。相应地,存储器302还可以包括存储器控制器,以提供处理器301对存储器302的访问。
电子设备还包括给各个部件供电的电源303,优选的,电源303可以通过电源管理系统与处理器301逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。电源303还可以包括一个或一个以上的直流或交流电源、再充电系统、电源故障检测电路、电源转换器或者逆变器、电源状态指示器等任意组件。
该电子设备还可包括输入单元304,该输入单元304可用于接收输入的数字或字符信息,以及产生与用户设置以及功能控制有关的键盘、鼠标、操作杆、光学或者轨迹球信号输入。
尽管未示出,电子设备还可以包括显示单元等,在此不再赘述。具体在本实施例中,电子设备中的处理器301会按照如下的指令,将一个或一个以上的应用程序的进程对应的可执行文件加载到存储器302中,并由处理器301来运行存储在存储器302中的应用程序,从而实现各种功能,如下:
获取包括目标样本集;基于目标样本集对预训练的参考模型进行调参,得到调参后模型;确定调参过程中,参考模型的预设参数对于目标样本集的参考权重;根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。
以上各个操作的具体实施可参见前面的实施例,在此不再赘述。
本申请实施例在获取包括目标样本集后,基于目标样本集对预训练的参考模型进行调参,得到调参后模型,然后,确定调参过程中,参考模型的预设参数对于目标样本集的参考权重,最后,根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。本申请提供的模型训练的方案,可以根据预设参数对于目标样本集的参考权重,自适应地对预设基础模型进行知识蒸馏,从而得到目标模型,由此,可以提高蒸馏得到的学生模型(即目标模型)的预估能力。
本领域普通技术人员可以理解,上述实施例的各种方法中的全部或部分步骤可以通过指令来完成,或通过指令控制相关的硬件来完成,该指令可以存储于一计算机可读存储介质中,并由处理器进行加载和执行。
为此,本申请实施例提供一种存储介质,其中存储有多条指令,该指令能够被处理器进行加载,以执行本申请实施例所提供的任一种模型训练方法中的步骤。例如,该指令可以执行如下步骤:
获取包括目标样本集;基于目标样本集对预训练的参考模型进行调参,得到调参后模型;确定调参过程中,参考模型的预设参数对于目标样本集的参考权重;根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。
以上各个操作的具体实施可参见前面的实施例,在此不再赘述。
其中,该存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,RandomAccess Memory)、磁盘或光盘等。
由于该存储介质中所存储的指令,可以执行本申请实施例所提供的任一种模型训练方法中的步骤,因此,可以实现本申请实施例所提供的任一种模型训练方法所能实现的有益效果,详见前面的实施例,在此不再赘述。
以上对本申请实施例所提供的一种模型训练方法、装置、电子设备以及存储介质进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。
Claims (10)
1.一种模型训练方法,其特征在于,包括:
获取包括目标样本集;
基于所述目标样本集对预训练的参考模型进行调参,得到调参后模型;
确定调参过程中,所述参考模型的预设参数对于所述目标样本集的参考权重;
根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型。
2.根据权利要求1所述的方法,其特征在于,所述根据确定的参考权重对预设基础模型进行知识蒸馏,得到目标模型,包括:
计算所述调参后模型中各网络层对应的参考权重之和;
根据计算结果在所述调参后模型选择参考网络层;
将选择的参考网络层蒸馏至预设基础模型中,得到目标模型。
3.根据权利要求2所述的方法,其特征在于,所述根据计算结果在所述调参后模型选择参考网络层,包括:
根据计算结果对所述调参后面模型的网络层进行排序;
在排序后的网络层中选择目标序号的网络层为参考网络层。
4.根据权利要求3所述的方法,其特征在于,所述在排序后的网络层中选择目标序号的网络层为参考网络层,包括:
识别预设基础模型中网络层对应的层数;
根据所述层数,在排序后的网络层中选择目标序号的网络层为参考网络层。
5.根据权利要求2所述的方法,其特征在于,所述将选择的参考网络层蒸馏至预设基础模型中,得到目标模型,包括:
根据所述参考网络层之间的顺序以及所述基础模型中目标网络层之间的顺序,构建所述参考网络层与目标网络层之间的映射关系;
基于所述映射关系,将所述目标网络层的参数更新为所述参考网络层的参数。
6.根据权利要求1至5任一项所述的方法,其特征在于,所述确定调参过程中,所述参考模型的预设参数对于所述目标样本集的参考权重,包括:
提取所述目标样本集中每个目标样本的样本标签;
获取调参过程中,所述参考模型针对所述目标样本的预测标签;
基于所述样本标签和预测标签,确定所述参考模型的预设参数对于所述目标样本集的参考权重。
7.根据权利要求6所述的方法,其特征在于,所述基于所述样本标签和预测标签,确定所述参考模型的预设参数对于所述目标样本集的参考权重,包括:
获取预设权重计算公式;
根据所述权重计算公式、样本标签和预测标签,确定所述参考模型的预设参数对于所述目标样本集的参考权重。
8.根据权利要求1至5任一项所述的方法,其特征在于,所述基于所述目标样本集对预训练的参考模型进行调参,得到调参后模型,包括:
基于所述目标样本集对预训练的参考模型中全连接层对应的参数、激活函数对应的参数和/或卷积层对应的参数进行调整,得到调参后模型。
9.一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,其中,所述处理器执行所述程序时实现如权利要求1-8任一项所述模型训练方法的步骤。
10.一种计算机可读存储介质,其特征在于,其上存储有计算机程序,其中,所述计算机程序被处理器执行时实现如权利要求1-8任一项所述模型训练方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210242894.XA CN114611692A (zh) | 2022-03-11 | 2022-03-11 | 模型训练方法、电子设备以及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210242894.XA CN114611692A (zh) | 2022-03-11 | 2022-03-11 | 模型训练方法、电子设备以及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114611692A true CN114611692A (zh) | 2022-06-10 |
Family
ID=81863814
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210242894.XA Pending CN114611692A (zh) | 2022-03-11 | 2022-03-11 | 模型训练方法、电子设备以及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114611692A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116030323A (zh) * | 2023-03-27 | 2023-04-28 | 阿里巴巴(中国)有限公司 | 图像处理方法以及装置 |
CN116226678A (zh) * | 2023-05-10 | 2023-06-06 | 腾讯科技(深圳)有限公司 | 模型处理方法、装置、设备及存储介质 |
-
2022
- 2022-03-11 CN CN202210242894.XA patent/CN114611692A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116030323A (zh) * | 2023-03-27 | 2023-04-28 | 阿里巴巴(中国)有限公司 | 图像处理方法以及装置 |
CN116030323B (zh) * | 2023-03-27 | 2023-08-29 | 阿里巴巴(中国)有限公司 | 图像处理方法以及装置 |
CN116226678A (zh) * | 2023-05-10 | 2023-06-06 | 腾讯科技(深圳)有限公司 | 模型处理方法、装置、设备及存储介质 |
CN116226678B (zh) * | 2023-05-10 | 2023-07-21 | 腾讯科技(深圳)有限公司 | 模型处理方法、装置、设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US20210042580A1 (en) | Model training method and apparatus for image recognition, network device, and storage medium | |
CN111079833B (zh) | 图像识别方法、装置以及计算机可读存储介质 | |
CN116415654A (zh) | 一种数据处理方法及相关设备 | |
CN112052948B (zh) | 一种网络模型压缩方法、装置、存储介质和电子设备 | |
Elhamifar et al. | Self-supervised multi-task procedure learning from instructional videos | |
CN114611692A (zh) | 模型训练方法、电子设备以及存储介质 | |
CN111708823B (zh) | 异常社交账号识别方法、装置、计算机设备和存储介质 | |
CN112329948A (zh) | 一种多智能体策略预测方法及装置 | |
CN113590876A (zh) | 一种视频标签设置方法、装置、计算机设备及存储介质 | |
CN112418302A (zh) | 一种任务预测方法及装置 | |
CN113609337A (zh) | 图神经网络的预训练方法、训练方法、装置、设备及介质 | |
CN111046655B (zh) | 一种数据处理方法、装置及计算机可读存储介质 | |
CN112560639A (zh) | 人脸关键点数目转换方法、系统、电子设备及存储介质 | |
CN111522926A (zh) | 文本匹配方法、装置、服务器和存储介质 | |
CN115168720A (zh) | 内容交互预测方法以及相关设备 | |
CN113449840A (zh) | 神经网络训练方法及装置、图像分类的方法及装置 | |
CN113591509A (zh) | 车道线检测模型的训练方法、图像处理方法及装置 | |
CN112633425B (zh) | 图像分类方法和装置 | |
CN115878750A (zh) | 信息处理方法、装置、设备及计算机可读存储介质 | |
WO2023170067A1 (en) | Processing network inputs using partitioned attention | |
Thiodorus et al. | Convolutional neural network with transfer learning for classification of food types in tray box images | |
CN115168722A (zh) | 内容交互预测方法以及相关设备 | |
CN114648762A (zh) | 语义分割方法、装置、电子设备和计算机可读存储介质 | |
CN117010480A (zh) | 模型训练方法、装置、设备、存储介质及程序产品 | |
CN110826726B (zh) | 目标处理方法、目标处理装置、目标处理设备及介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |