CN114511042A - 一种模型的训练方法、装置、存储介质及电子装置 - Google Patents
一种模型的训练方法、装置、存储介质及电子装置 Download PDFInfo
- Publication number
- CN114511042A CN114511042A CN202210353017.XA CN202210353017A CN114511042A CN 114511042 A CN114511042 A CN 114511042A CN 202210353017 A CN202210353017 A CN 202210353017A CN 114511042 A CN114511042 A CN 114511042A
- Authority
- CN
- China
- Prior art keywords
- target
- determining
- model
- loss value
- training data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明实施例提供了一种模型的训练方法、装置、存储介质及电子装置,其中,该方法包括:利用训练完成的目标老师模型从N个目标维度识别训练数据,确定训练数据的第一特征图以及训练数据在每个目标维度的第一识别结果;利用初始学生模型从N个目标维度识别训练数据,确定训练数据的第二特征图以及训练数据在每个目标维度的第二识别结果;基于第一识别结果、第二识别结果、第一特征图以及第二特征图确定初始学生模型的目标损失值;在目标损失值不满足预定条件的情况下,更新初始学生模型的网络参数,直到目标损失值满足预定条件为止,得到目标网络模型。通过本发明,达到一个目标网络模型可用于执行不同的任务的效果,提高了训练模型的效率。
Description
技术领域
本发明实施例涉及计算机领域,具体而言,涉及一种模型的训练方法、装置、存储介质及电子装置。
背景技术
知识蒸馏是一种常用的压缩技术,操作过程相对简单,同时可以获得较好的性能。知识蒸馏采用老师模型-学生模型框架,将复杂且精度较高的模型作为老师模型,简单轻量的小网络作为学生模型,老师模型学习能力较强,训练过程中将老师模型的知识迁移给学习能力较弱的学生模型,用来增加学生模型的学习能力和泛化能力,目的是让轻量的学生模型学习到和老师模型相近的精度,最终部署上线的就是这个轻量的学生模型。
常见的知识蒸馏方式主要用于单任务分类,即每个模型仅用于执行单一的动作,在由多个任务需要执行时,需要训练出多个模型,每个模型分别用于执行一个任务。
由此可知,相关技术中存在模型执行的任务单一的问题。
针对相关技术中存在的上述问题,目前尚未提出有效的解决方案。
发明内容
本发明实施例提供了一种模型的训练方法、装置、存储介质及电子装置,以至少解决相关技术中存在的模型执行的任务单一的问题。
根据本发明的一个实施例,提供了一种模型的训练方法,包括:利用训练完成的目标老师模型从N个目标维度识别训练数据,确定所述训练数据的第一特征图以及所述训练数据在每个所述目标维度的第一识别结果,其中,所述目标老师模型中包括所述N个第一子模型,一个所述第一子模型用于从一个所述目标维度识别所述训练数据;利用初始学生模型从所述N个所述目标维度识别所述训练数据,确定所述训练数据的第二特征图以及所述训练数据在每个所述目标维度的第二识别结果,其中,所述初始学生模型为经过初始训练后得到的网络模型,所述初始学生模型中包括所述N个第二子模型,一个所述第二子模型用于从一个所述目标维度识别所述训练数据;基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值;在所述目标损失值不满足预定条件的情况下,更新所述初始学生模型的网络参数,直到所述目标损失值满足所述预定条件为止,得到目标网络模型。
根据本发明的另一个实施例,提供了一种模型的训练装置,包括:第一识别模块,用于利用训练完成的目标老师模型从N个目标维度识别训练数据,确定所述训练数据的第一特征图以及所述训练数据在每个所述目标维度的第一识别结果,其中,所述目标老师模型中包括所述N个第一子模型,一个所述第一子模型用于从一个所述目标维度识别所述训练数据;第二识别模块,用于利用初始学生模型从所述N个所述目标维度识别所述训练数据,确定所述训练数据的第二特征图以及所述训练数据在每个所述目标维度的第二识别结果,其中,所述初始学生模型为经过初始训练后得到的网络模型,所述初始学生模型中包括所述N个第二子模型,一个所述第二子模型用于从一个所述目标维度识别所述训练数据;确定模块,用于基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值;训练模块,用于在所述目标损失值不满足预定条件的情况下,更新所述初始学生模型的网络参数,直到所述目标损失值满足所述预定条件为止,得到目标网络模型。
根据本发明的又一个实施例,还提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,其中,所述计算机程序被处理器执行时实现上述任一项中所述的方法的步骤。
根据本发明的又一个实施例,还提供了一种电子装置,包括存储器和处理器,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行上述任一项方法实施例中的步骤。
通过本发明,利用训练完成的目标老师模型从N个目标维度识别训练数据,以确定训练数据的第一特征图,以及训练数据在每个目标维度的第一识别结果,利用初始学生模型从N个目标维度识别训练数据,以确定训练数据的第二特征图以及训练数据在每个目标维度的第二识别结果。根据第一识别结果、第二识别结果、第一特征图以及第二特征图确定初始学生模型的目标损失值,在目标损失值不满足预定条件的情况下,更新初始学生模型的网络参数,直到目标损失值满足预定条件为止,得到目标学生模型。其中,目标老师模型中包括N个第一子模型,一个第一子模型用于从一个目标维度识别训练数据,初始学生模型中包括N个第二子模型,第一第二子模型用于从一个目标维度识别训练数据。由于目标老师模型和初始学生模型均能从不同的维度识别训练数据,因此,训练得到的目标网络模型可以从N个不同的目标维度识别数据,实现了一个目标网络模型可以用于执行不同的任务。因此,可以解决相关技术中存在的模型执行的任务单一的问题,达到一个目标网络模型可用于执行不同的任务的效果,提高了训练模型的效率。
附图说明
图1是本发明实施例的一种模型的训练方法的移动终端的硬件结构框图;
图2是根据本发明实施例的模型的训练方法的流程图;
图3是根据本发明示例性实施例的目标老师模型以及初始学生模型的网络架构示意图;
图4是根据本发明示例性实施例的确定第一损失值的过程示意图;
图5是根据本发明实施例的模型的训练装置的结构框图。
具体实施方式
下文中将参考附图并结合实施例来详细说明本发明的实施例。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。
目前基于卷积神经网络的识别任务已经成为视觉领域研究的主流方向,在实际应用中,由于部署模型的硬件成本较为昂贵,对GPU的性能要求较高,往往希望模型部署时占用较低的内存同时产生较低的时延,因此模型部署时对网络轻量化的需求越来强烈。一般精度较好的模型都是些参数量多的大网络,甚至是多个模型集成得到的,这种模型推理速度慢,对资源部署要求较高,而且很难直接部署到服务中,模型压缩成为了一个重要的步骤。
主流的模型压缩技术主要有以下几种;(1)结构优化,通过优化网络结构的设计去减少模型的冗余和计算量,例如网络模块(block)层面的改进,像深度可分离卷积、分组卷积等结构,在保证网络性能的同时,减少了参数量和计算量。(2)剪枝技术,在预训练好的大型模型的基础上,设计对网络参数的评价准则,以此为根据删除“冗余”参数。(3)量化技术,用较低位宽表示典型的32位浮点型网络参数,网络参数包括权重、激活值、梯度和误差等等。(4)知识蒸馏,知识蒸馏通过将老师网络的知识迁移到学生网络中,使学生网络达到与老师网络相似的性能,同时又能起到模型压缩的目的。
知识蒸馏是一种常用的压缩技术,相对于其他几种压缩技术而言,操作过程相对简单,同时可以获得较好的性能。知识蒸馏采用老师模型-学生模型框架,将复杂且精度较高的模型作为老师模型,简单轻量的小网络作为学生模型,老师模型学习能力较强,训练过程中将老师模型的知识迁移给学习能力较弱的学生模型,用来增加学生模型的学习能力和泛化能力,目的是让轻量的学生模型学习到和老师模型相近的精度,最终部署上线的就是这个轻量的学生模型。
常见的知识蒸馏方式主要用于单任务分类,多任务相对于单任务蒸馏方式的难点:(1)不同任务间的损失平衡问题,不同任务的图片对应解决的问题也不同,因此样本学习的难易程度不同,训练过程中的损失也会随之产生差异,因此学习过程要自适应的调整每个任务间的损失,防止网络过度学习某个任务,而忽略别的任务,造成任务间的性能差距较大。(2)多个任务知识蒸馏过程学习相对单任务学习起来较为困难,由于任务多了之后,需要学习的知识也增多了,因此对知识蒸馏框架的学习能力要求较高。
针对相关技术中存在的上述问题,提出以下实施例:
本申请实施例中所提供的方法实施例可以在移动终端、计算机终端或者类似的运算装置中执行。以运行在移动终端上为例,图1是本发明实施例的一种模型的训练方法的移动终端的硬件结构框图。如图1所示,移动终端可以包括一个或多个(图1中仅示出一个)处理器102(处理器102可以包括但不限于微处理器MCU或可编程逻辑器件FPGA等的处理装置)和用于存储数据的存储器104,其中,上述移动终端还可以包括用于通信功能的传输设备106以及输入输出设备108。本领域普通技术人员可以理解,图1所示的结构仅为示意,其并不对上述移动终端的结构造成限定。例如,移动终端还可包括比图1中所示更多或者更少的组件,或者具有与图1所示不同的配置。
存储器104可用于存储计算机程序,例如,应用软件的软件程序以及模块,如本发明实施例中的模型的训练方法对应的计算机程序,处理器102通过运行存储在存储器104内的计算机程序,从而执行各种功能应用以及数据处理,即实现上述的方法。存储器104可包括高速随机存储器,还可包括非易失性存储器,如一个或者多个磁性存储装置、闪存、或者其他非易失性固态存储器。在一些实例中,存储器104可进一步包括相对于处理器102远程设置的存储器,这些远程存储器可以通过网络连接至移动终端。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
传输设备106用于经由一个网络接收或者发送数据。上述的网络具体实例可包括移动终端的通信供应商提供的无线网络。在一个实例中,传输设备106包括一个网络适配器(Network Interface Controller,简称为NIC),其可通过基站与其他网络设备相连从而可与互联网进行通讯。在一个实例中,传输设备106可以为射频(Radio Frequency,简称为RF)模块,其用于通过无线方式与互联网进行通讯。
在本实施例中提供了一种模型的训练方法,图2是根据本发明实施例的模型的训练方法的流程图,如图2所示,该流程包括如下步骤:
步骤S202,利用训练完成的目标老师模型从N个目标维度识别训练数据,确定所述训练数据的第一特征图以及所述训练数据在每个所述目标维度的第一识别结果,其中,所述目标老师模型中包括所述N个第一子模型,一个所述第一子模型用于从一个所述目标维度识别所述训练数据;
步骤S204,利用初始学生模型从所述N个所述目标维度识别所述训练数据,确定所述训练数据的第二特征图以及所述训练数据在每个所述目标维度的第二识别结果,其中,所述初始学生模型为经过初始训练后得到的网络模型,所述初始学生模型中包括所述N个第二子模型,一个所述第二子模型用于从一个所述目标维度识别所述训练数据;
步骤S206,基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值;
步骤S208,在所述目标损失值不满足预定条件的情况下,更新所述初始学生模型的网络参数,直到所述目标损失值满足所述预定条件为止,得到目标网络模型。
在上述实施例中,可以对初始老师模型进行训练,得到的收敛的目标老师模型。目标老师模型可以从不同的目标维度识别训练数据,一个目标维度可以认为是一种任务。例如,当训练数据为图像,图像中包括机动车时,目标维度可以包括车牌、车型、车头朝向等。即,当训练数据为图像,图像中包括机动车时,目标老师模型可以从多个目标维度识别图像,得到第一识别结果,第一识别结果中包括车牌信息、车型信息、车头朝向信息等。其中,目标老师模型中包括N个第一子模型,每个第一子模型用于从一个目标维度识别训练数据,不同第一子模型对应的目标维度不同,N为大于1的正整数,如2、3等。需要说明的是,上述N的取值仅是一种示例性说明,N还可以取4、6、8等。其中,当N为3时,目标老师模型以及初始学生模型的网络架构示意图可参见附图3,如图3所示,该模型架构包括3个分支,即3个第一子模型或第二子模型,分别用于执行不同的任务,得到不同任务的预测值,即第一识别结果或第二识别结果。
在上述实施例中,初始学生模型可以是利用训练数据训练到收敛的模型。初始学生模型中包括的第二子模型的数量与目标老师模型中包括的第一子模型的数量相同。
在上述实施例中,可以获取不同目标维度对应的训练数据集,不同目标维度间的训练数据集的均衡,可以采用随机采样策略扩充训练数据集,将不同目标维度的训练数据集扩充至同一级别。例如,可以在N个目标维度中选取训练数据量最大的目标维度,将这个最大的数据量作为其余维度需要扩增的目标数据量,对于其余需要扩增数据的任务可以按如下操作:基于本身的数据集进行随机的重复采样直至数据量和目标数据量相等。在得到每个目标维度对应的训练数据集后,可以将多维度训练数据集输入至参数量较多且复杂的老师网络训练至收敛,获得精度高的目标老师模型。同时,可以将同样的多维度训练数据集输入至参数量较少、且较为简单的学生网络训练至收敛,获得精度稍低的初始学生模型。其中,参数是指网络的卷积核的权重和偏置、全连接层的权重和偏置、BatchNorm层的两个可学习变量等。参数量较少是指网络中卷积核的层数较少,网络的模块结构较为简单,则参数就会较少。
在上述实施例中,在得到目标老师模型以及初始学生模型后,可以利用目标老师模型以及初始学生模型识别训练数据,得到目标老师模型输出的第一特征图以及第一识别结果,初始学生模型输出的第二特征图以及第二识别结果。并根据第一特征图、第二特征图、第一识别结果以及第二识别结果确定初始学生模型的目标损失值,在目标损失值不满足预定条件的情况下,更新初始学生模型的网络参数。再次将训练数据集中包括的训练数据输入至目标老师模型以及更新了网络参数的初始学生模型中,得到更新了网络参数的初始学生模型的目标损失值,在目标损失值不满足预定条件的情况下,再次更新初始学生模型的网络参数。直到基于更新了网络参数的初始学生模型的目标损失值满足预定条件,将最终的包括目标老师模型以及初始老师模型的模型确定为目标网络模型。
在上述实施例中,可以通过训练次数确定目标损失值是否满足预定条件,例如,可以设定预定条件为训练预定次数,若当前目标损失值对应的训练次数小于预定次数,则认为目标损失值不满足预定条件。若当前目标损失值对应的训练次数大于或等于预定次数,则认为目标损失值满足预定条件。
在上述实施例中,还可以通过损失值阈值确定目标损失值是否满足预定条件。例如,当目标损失值小于损失值阈值的情况下,确定目标损失值满足预定条件,当目标损失值大于或等于损失值阈值的情况下,确定目标损失值不满足预定条件。
在上述实施例中,预定条件可以既包括预定次数还包括损失值阈值,当目标损失值满足二者之一时,则认为目标损失值满足预定条件。
可选地,上述步骤的执行主体可以是后台处理器,或者其他的具备类似处理能力的设备,还可以是至少集成数据处理设备的机器,其中,数据处理设备可以包括计算机、手机等终端,但不限于此。
通过本发明,利用训练完成的目标老师模型从N个目标维度识别训练数据,以确定训练数据的第一特征图,以及训练数据在每个目标维度的第一识别结果,利用初始学生模型从N个目标维度识别训练数据,以确定训练数据的第二特征图以及训练数据在每个目标维度的第二识别结果。根据第一识别结果、第二识别结果、第一特征图以及第二特征图确定初始学生模型的目标损失值,在目标损失值不满足预定条件的情况下,更新初始学生模型的网络参数,直到目标损失值满足预定条件为止,得到目标学生模型。其中,目标老师模型中包括N个第一子模型,一个第一子模型用于从一个目标维度识别训练数据,初始学生模型中包括N个第二子模型,第一第二子模型用于从一个目标维度识别训练数据。由于目标老师模型和初始学生模型均能从不同的维度识别训练数据,因此,训练得到的目标网络模型可以从N个不同的目标维度识别数据,实现了一个目标网络模型可以用于执行不同的任务。因此,可以解决相关技术中存在的模型执行的任务单一的问题,达到一个目标网络模型可用于执行不同的任务的效果,提高了训练模型的效率。
在一个示例性实施例中,基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值包括:基于所述第一识别结果以及所述第二识别结果确定第一损失值;基于所述第一特征图以及所述第二特征图确定第二损失值;基于所述第一损失值以及所述第二损失值确定目标损失值。在本实施例中,在确定目标损失值时,可以分别确定第一识别结果和第二识别结果之间的第一损失值,确定第一特征图与第二特征图之间的第二损失值,根据第一损失值和第二损失值确定目标损失值。
在上述实施例中,蒸馏过程中,可以将目标老师模型的N个第一子模型的输出作为软标签,用于代替硬标签和初始学生模型做KL散度损失( Kullback–Leiblerdivergence),以确定第一损失值。其中,软标签即为第一识别结果。硬标签为训练数据中包括的标签信息, 即硬标签为预先为训练数据分配的标签信息。确定第一损失值的过程示意图可参见附图4。在确定第二损失值时,可以通过均方误差MSE的方式确定。
在一个示例性实施例中,基于所述第一识别结果以及所述第二识别结果确定第一损失值包括:确定所述第一识别结果中包括的每个所述目标维度的第一子识别结果以及所述第二识别结果中包括的每个所述目标维度的第二子识别结果;基于每个所述目标维度的所述第一子识别结果以及所述第二子识别结果确定第一子损失值;确定每个所述目标维度对应的目标权重;基于每个所述第一子损失值以及每个所述目标权重确定所述第一损失值。在本实施例中,在确定第一损失值,可以分别确定每个目标维度对应的第一子识别结果以及第二子识别结果,确定第一子识别结果和第二子识别结果之间的第一子损失值,以得到N个第一子损失值。还可以确定每个目标维度对应的目标权重,根据N个第一子损失值以及N个目标权重确定目标损失值。
在上述实施例中,可以确定第一子损失值与其对应的目标权重的乘积,以得到N个乘积,将N个乘积的和确定为第一损失值。还可以根据每个第一子损失值以及每个目标权重利用其他方式确定第一损失值。
在一个示例性实施例中,确定每个所述目标维度对应的目标权重包括:在首次确定所述第一损失值的情况下,将预先确定的初始权重确定为每个所述目标维度的所述目标权重;在非首次确定所述第一损失值,确定所述N个所述第一子损失值中包括的最大子损失值,按照第一预定方式增大所述最大子损失值对应的目标维度的第一当前权重,减小其他子损失值对应的目标维度的第二当前权重,以得到每个所述目标维度的所述目标权重,其中,所述其他子损失值为所述N个所述第一子损失值中包括的除所述最大子损失值之外的损失值。在本实施例中,为了平衡3个任务间损失的差异,可以采用自适应权重来加权不同任务的损失,根据当前损失的数值,即第一子损失值自动设置其加权的权值。在首次确定第一损失值时,可以预先确定的初始权重确定为目标维度的目标权重,如,目标权重可以为1/N。当在确定出第一损失值,进而得到目标损失值,在目标损失值不满足预定条件的情况下,更新初始学生模型的网络参数。在更新了网络参数后,再次确定第一损失值时,可以根据上次确定的每个第一子损失值调整对应的权重。调整目标权重的过程可以是增大最大子损失值对应的第一当前权重,减小其他子损失值对应的第二当前权重。
在一个示例性实施例中,基于每个所述第一子损失值以及每个所述目标权重确定所述第一损失值包括:确定每个所述目标权重与第一常数的乘积的倒数,得到N个第一倒数;确定每个所述第一子损失值与所述第一子损失值对应的所述第一倒数的第一乘积,以得到所述N个第一乘积;确定所述N个所述目标权重的乘积的N次方根;确定以第二常数为底的所述N次方根的对数;将所述N个所述第一乘积以及所述对数的第一和值确定为所述第一损失值。在本实施例中,可以确定目标权重与第一常数的乘积的倒数,得到N个第一倒数,其中,第一常数可以为N。确定每个第一子损失值与其对应的第一倒数的第一乘积,确定N个目标权重的乘积的N次方根,确定以第二常数为底的N次方根的对数,将N个第一乘积以及对数的第一和值确定为第一损失值。其中,第二常数可以为2、e、10等。
在上述实施例中,当N为3时,第一损失值可以表示为,其中,为第一个任务的KL损失,即第一子损失值,为第二个任务的KL损失,为第三个任务的KL损失, 、 、 定义为网络中可学习的目标权重,均初始化为1,跟随网络的训练而自适应的变化为3个任务的总的KL损失,即第一损失值。由于网络训练过程中希望前三项越来越小,则相对应的 、 、 变量就会越来大,则第四项就会越来越大,与前三项产生对抗,以保证变量处于合理的范围内。若某个任务的KL损失变大,则相对应的也会随之增大,表明该任务学习的权重变大。
在上述实施例中,在网络中定义3个可学习变量代表3个任务的损失权重,网络学习过程中会最小化公式1的KL损失函数,相应的会更新3个可学习变量的值,这就自适应的调整了相应的损失权重
在一个示例性实施例中,基于所述第一特征图以及所述第二特征图确定第二损失值包括:确定所述第一特征图与所述第二特征图之间的均方误差;将所述均方误差确定为所述第二损失值。在本实施例中,在确定第二损失值时,可以根据第一特征图与第二特征图之间的均方误差确定,将均方误差确定为第二损失值。其中,第一特征图可以是目标老师模型的backbone的最后一层feature map,第二特征图可以使初始学生模型的backbone的最后一层feature map。
在上述实施例中,取老师网络backbone的最后一层feature map和学生网络backbone的最后一层feature map进行维度压缩转化为二维矩阵,再进行维度对齐,获得相同维度的两个feature map,求两者的均方误差(MSE, Mean Square Error),将二者的均方误差确定为第二损失值。其中,求均方误差的步骤如下:若feature map的尺寸为四维向量B*C*H*W,B代表输入数据的批量大小,C代表feature map的通道数,H和W是feature map的尺寸大小。由于初始学生模型和目标老师模型在输入数据时的数据批量大小B是相等的,令目标老师模型的feature map为B*C1*H1*W1,维度压缩为二维向量B*(C1*H1*W1),初始学生模型的尺寸为B*C2*H2*W2,维度压缩为二维向量B*(C2*H2*W2),接下来进行维度对齐操作,目标老师模型的B*(C1*H1*W1)矩阵乘上自身B*(C1*H1*W1)的转置矩阵,获得B*B的二维向量,学生网络也进行同样操作,获得B*B的二维向量,两者求均方误差。
在一个示例性实施例中,基于所述第一损失值以及所述第二损失值确定目标损失值包括:确定所述第一损失值对应的第一权重以及所述第二损失值对应的第二权重;确定所述第一损失值与所述第一权重的第二乘积;确定所述第二损失值与所述第二权重的第三乘积;将所述第二乘积与所述第三乘积的第二和值确定为所述目标损失值。在本实施例中,可以将多任务自适应加权后的KL散度损失,即第一损失值与MSE损失,即第二损失值求和作为最终的损失,即目标损失值来监督整个蒸馏过程,增强学生网络的学习能力。其中,公式可表示目标损失值。为N个任务KL损失函数,即第一损失值,为feature map的均方误差损失,即第二损失值, 、 为固定的加权参数,可以在训练过程中根据损失的差异自定义设定,如取0.6、取2000。其中,监督是指训练的反向传播过程中损失函数对网络中的参数求梯度,利用该梯度更新网络中的各个参数,从而起到引导网络学习的方向。
在一个示例性实施例中,在利用初始学生模型从所述N个所述目标维度识别所述训练数据之前,所述方法还包括:利用训练数据集训练初始模型,以得到所述初始学生模型,其中,所述训练数据集中包括所述N个子训练数据集,所述初始模型中包括所述N个初始第二子模型,一个所述子训练数据集用于训练一个所述初始第二子模型,以得到所述第二子模型。在本实施例中,初始学生模型可以是预先利用训练数据集进行过训练的模型。在初始学生模型训练收敛后,再次利用目标老师模型输出的第一特征图以及第一识别结果进行二次训练,提高了初始学生模型的准确率。进一步提高了目标网络模型的准确率。
在一个示例性实施例中,在得到目标网络模型之后,所述方法还包括:利用所述目标网络模型从所述N个所述目标维度识别目标图像;输出所述目标图像在所述N个所述目标维度的目标识别结果。在本实施例中,在得到目标网络模型后,可以将目标网络模型应用到各个领域,在不同的领域,目标网络模型的各个子模型所执行的任务,即目标维度不同。例如,当应用到交通领域时,目标维度可以为车型、车牌、车身颜色、车头朝向等。当应用到人脸识别领域时,目标维度可以为性别、人脸特征、面部属性、对象的标识信息,如身份证号等。
在前述实施例中,采用一个网络进行N个任务的蒸馏学习,在输入网络前对N个任务的数据集进行类别均衡,保证数据规模在同一个级别上,将N个训练数据集输入同一个网络(backbone)中,训练过程中N个任务共享同一个backbone的参数,输出为N个任务的预测值。老师网络,即目标老师模型和学生网络,即初始学生模型均采用这种多任务框架,蒸馏过程中,老师网络输出软标签用于监督学生网络的学习,采用自适应权重用于平衡不同任务间的损失,为了增强知识蒸馏的性能,提取老师backbone的最后一层特征图(featuremap)和学生backbone的最后一层feature map求损失,作为蒸馏损失的一部分,增强了网络的学习能力。由于训练过程中能够自适应的平衡不同任务的蒸馏损失,防止模型过度学习某个任务而忽略别的任务,同时利用最后一层的feature map求MSE损失函数来增加蒸馏的性能,比简单的使用软标签损失函数学到的知识更多,学习能力更强。
相对于现有的单任务分类技术而言,这种多任务蒸馏方式将多个任务集成到一个模型中训练,简化了重复的训练步骤,同时获得了与单任务分类相当的性能。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到根据上述实施例的方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。
在本实施例中还提供了一种模型的训练装置,该装置用于实现上述实施例及优选实施方式,已经进行过说明的不再赘述。如以下所使用的,术语“模块”可以实现预定功能的软件和/或硬件的组合。尽管以下实施例所描述的装置较佳地以软件来实现,但是硬件,或者软件和硬件的组合的实现也是可能并被构想的。
图5是根据本发明实施例的模型的训练装置的结构框图,如图5所示,该装置包括:
第一识别模块52,用于利用训练完成的目标老师模型从N个目标维度识别训练数据,确定所述训练数据的第一特征图以及所述训练数据在每个所述目标维度的第一识别结果,其中,所述目标老师模型中包括所述N个第一子模型,一个所述第一子模型用于从一个所述目标维度识别所述训练数据;
第二识别模块54,用于利用初始学生模型从所述N个所述目标维度识别所述训练数据,确定所述训练数据的第二特征图以及所述训练数据在每个所述目标维度的第二识别结果,其中,所述初始学生模型为经过初始训练后得到的网络模型,所述初始学生模型中包括所述N个第二子模型,一个所述第二子模型用于从一个所述目标维度识别所述训练数据;
确定模块56,用于基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值;
训练模块58,用于在所述目标损失值不满足预定条件的情况下,更新所述初始学生模型的网络参数,直到所述目标损失值满足所述预定条件为止,得到目标网络模型。
在一个示例性实施例中,确定模块56可以通过如下方式实现基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值:基于所述第一识别结果以及所述第二识别结果确定第一损失值;基于所述第一特征图以及所述第二特征图确定第二损失值;基于所述第一损失值以及所述第二损失值确定目标损失值。
在一个示例性实施例中,确定模块56可以通过如下方式实现基于所述第一识别结果以及所述第二识别结果确定第一损失值:确定所述第一识别结果中包括的每个所述目标维度的第一子识别结果以及所述第二识别结果中包括的每个所述目标维度的第二子识别结果;基于每个所述目标维度的所述第一子识别结果以及所述第二子识别结果确定第一子损失值;确定每个所述目标维度对应的目标权重;基于每个所述第一子损失值以及每个所述目标权重确定所述第一损失值。
在一个示例性实施例中,确定模块56可以通过如下方式实现确定每个所述目标维度对应的目标权重:在首次确定所述第一损失值的情况下,将预先确定的初始权重确定为每个所述目标维度的所述目标权重;在非首次确定所述第一损失值,确定所述N个所述第一子损失值中包括的最大子损失值,按照第一预定方式增大所述最大子损失值对应的目标维度的第一当前权重,减小其他子损失值对应的目标维度的第二当前权重,以得到每个所述目标维度的所述目标权重,其中,所述其他子损失值为所述N个所述第一子损失值中包括的除所述最大子损失值之外的损失值。
在一个示例性实施例中,确定模块56可以通过如下方式实现基于每个所述第一子损失值以及每个所述目标权重确定所述第一损失值:确定每个所述目标权重与第一常数的乘积的倒数,得到N个第一倒数;确定每个所述第一子损失值与所述第一子损失值对应的所述第一倒数的第一乘积,以得到所述N个第一乘积;确定所述N个所述目标权重的乘积的N次方根;确定以第二常数为底的所述N次方根的对数;将所述N个所述第一乘积以及所述对数的第一和值确定为所述第一损失值。
在一个示例性实施例中,确定模块56可以通过如下方式实现基于所述第一特征图以及所述第二特征图确定第二损失值:确定所述第一特征图与所述第二特征图之间的均方误差;将所述均方误差确定为所述第二损失值。
在一个示例性实施例中,确定模块56可以通过如下方式实现基于所述第一损失值以及所述第二损失值确定目标损失值:确定所述第一损失值对应的第一权重以及所述第二损失值对应的第二权重;确定所述第一损失值与所述第一权重的第二乘积;确定所述第二损失值与所述第二权重的第三乘积;将所述第二乘积与所述第三乘积的第二和值确定为所述目标损失值。
在一个示例性实施例中,所述装置可以用于在利用初始学生模型从所述N个所述目标维度识别所述训练数据之前:利用训练数据集训练初始模型,以得到所述初始学生模型,其中,所述训练数据集中包括所述N个子训练数据集,所述初始模型中包括所述N个初始第二子模型,一个所述子训练数据集用于训练一个所述初始第二子模型,以得到所述第二子模型。
在一个示例性实施例中,所述装置还可以用于在得到目标网络模型之后,利用所述目标网络模型从所述N个所述目标维度识别目标图像;输出所述目标图像在所述N个所述目标维度的目标识别结果。
需要说明的是,上述各个模块是可以通过软件或硬件来实现的,对于后者,可以通过以下方式实现,但不限于此:上述模块均位于同一处理器中;或者,上述各个模块以任意组合的形式分别位于不同的处理器中。
本发明的实施例还提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,其中,所述计算机程序被处理器执行时实现上述任一项中所述的方法的步骤。
在一个示例性实施例中,上述计算机可读存储介质可以包括但不限于:U盘、只读存储器(Read-Only Memory,简称为ROM)、随机存取存储器(Random Access Memory,简称为RAM)、移动硬盘、磁碟或者光盘等各种可以存储计算机程序的介质。
本发明的实施例还提供了一种电子装置,包括存储器和处理器,该存储器中存储有计算机程序,该处理器被设置为运行计算机程序以执行上述任一项方法实施例中的步骤。
在一个示例性实施例中,上述电子装置还可以包括传输设备以及输入输出设备,其中,该传输设备和上述处理器连接,该输入输出设备和上述处理器连接。
本实施例中的具体示例可以参考上述实施例及示例性实施方式中所描述的示例,本实施例在此不再赘述。
显然,本领域的技术人员应该明白,上述的本发明的各模块或各步骤可以用通用的计算装置来实现,它们可以集中在单个的计算装置上,或者分布在多个计算装置所组成的网络上,它们可以用计算装置可执行的程序代码来实现,从而,可以将它们存储在存储装置中由计算装置来执行,并且在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤,或者将它们分别制作成各个集成电路模块,或者将它们中的多个模块或步骤制作成单个集成电路模块来实现。这样,本发明不限制于任何特定的硬件和软件结合。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (12)
1.一种模型的训练方法,其特征在于,包括:
利用训练完成的目标老师模型从N个目标维度识别训练数据,确定所述训练数据的第一特征图以及所述训练数据在每个所述目标维度的第一识别结果,其中,所述目标老师模型中包括所述N个第一子模型,一个所述第一子模型用于从一个所述目标维度识别所述训练数据;
利用初始学生模型从所述N个所述目标维度识别所述训练数据,确定所述训练数据的第二特征图以及所述训练数据在每个所述目标维度的第二识别结果,其中,所述初始学生模型为经过初始训练后得到的网络模型,所述初始学生模型中包括所述N个第二子模型,一个所述第二子模型用于从一个所述目标维度识别所述训练数据;
基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值;
在所述目标损失值不满足预定条件的情况下,更新所述初始学生模型的网络参数,直到所述目标损失值满足所述预定条件为止,得到目标网络模型。
2.根据权利要求1所述的方法,其特征在于,基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值包括:
基于所述第一识别结果以及所述第二识别结果确定第一损失值;
基于所述第一特征图以及所述第二特征图确定第二损失值;
基于所述第一损失值以及所述第二损失值确定目标损失值。
3.根据权利要求2所述的方法,其特征在于,基于所述第一识别结果以及所述第二识别结果确定第一损失值包括:
确定所述第一识别结果中包括的每个所述目标维度的第一子识别结果以及所述第二识别结果中包括的每个所述目标维度的第二子识别结果;
基于每个所述目标维度的所述第一子识别结果以及所述第二子识别结果确定第一子损失值;
确定每个所述目标维度对应的目标权重;
基于每个所述第一子损失值以及每个所述目标权重确定所述第一损失值。
4.根据权利要求3所述的方法,其特征在于,确定每个所述目标维度对应的目标权重包括:
在首次确定所述第一损失值的情况下,将预先确定的初始权重确定为每个所述目标维度的所述目标权重;
在非首次确定所述第一损失值,确定所述N个所述第一子损失值中包括的最大子损失值,按照第一预定方式增大所述最大子损失值对应的目标维度的第一当前权重,减小其他子损失值对应的目标维度的第二当前权重,以得到每个所述目标维度的所述目标权重,其中,所述其他子损失值为所述N个所述第一子损失值中包括的除所述最大子损失值之外的损失值。
5.根据权利要求3所述的方法,其特征在于,基于每个所述第一子损失值以及每个所述目标权重确定所述第一损失值包括:
确定每个所述目标权重与第一常数的乘积的倒数,得到N个第一倒数;
确定每个所述第一子损失值与所述第一子损失值对应的所述第一倒数的第一乘积,以得到所述N个第一乘积;
确定所述N个所述目标权重的乘积的N次方根;
确定以第二常数为底的所述N次方根的对数;
将所述N个所述第一乘积以及所述对数的第一和值确定为所述第一损失值。
6.根据权利要求2所述的方法,其特征在于,基于所述第一特征图以及所述第二特征图确定第二损失值包括:
确定所述第一特征图与所述第二特征图之间的均方误差;
将所述均方误差确定为所述第二损失值。
7.根据权利要求2所述的方法,其特征在于,基于所述第一损失值以及所述第二损失值确定目标损失值包括:
确定所述第一损失值对应的第一权重以及所述第二损失值对应的第二权重;
确定所述第一损失值与所述第一权重的第二乘积;
确定所述第二损失值与所述第二权重的第三乘积;
将所述第二乘积与所述第三乘积的第二和值确定为所述目标损失值。
8.根据权利要求1所述的方法,其特征在于,在利用初始学生模型从所述N个所述目标维度识别所述训练数据之前,所述方法还包括:
利用训练数据集训练初始模型,以得到所述初始学生模型,其中,所述训练数据集中包括所述N个子训练数据集,所述初始模型中包括所述N个初始第二子模型,一个所述子训练数据集用于训练一个所述初始第二子模型,以得到所述第二子模型。
9.根据权利要求1所述的方法,其特征在于,在得到目标网络模型之后,所述方法还包括:
利用所述目标网络模型从所述N个所述目标维度识别目标图像;
输出所述目标图像在所述N个所述目标维度的目标识别结果。
10.一种模型的训练装置,其特征在于,包括:
第一识别模块,用于利用训练完成的目标老师模型从N个目标维度识别训练数据,确定所述训练数据的第一特征图以及所述训练数据在每个所述目标维度的第一识别结果,其中,所述目标老师模型中包括所述N个第一子模型,一个所述第一子模型用于从一个所述目标维度识别所述训练数据;
第二识别模块,用于利用初始学生模型从所述N个所述目标维度识别所述训练数据,确定所述训练数据的第二特征图以及所述训练数据在每个所述目标维度的第二识别结果,其中,所述初始学生模型为经过初始训练后得到的网络模型,所述初始学生模型中包括所述N个第二子模型,一个所述第二子模型用于从一个所述目标维度识别所述训练数据;
确定模块,用于基于所述第一识别结果、所述第二识别结果、所述第一特征图以及所述第二特征图确定所述初始学生模型的目标损失值;
训练模块,用于在所述目标损失值不满足预定条件的情况下,更新所述初始学生模型的网络参数,直到所述目标损失值满足所述预定条件为止,得到目标网络模型。
11.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有计算机程序,其中,所述计算机程序被处理器执行时实现所述权利要求1至9任一项中所述的方法的步骤。
12.一种电子装置,包括存储器和处理器,其特征在于,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行所述权利要求1至9任一项中所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210353017.XA CN114511042A (zh) | 2022-04-06 | 2022-04-06 | 一种模型的训练方法、装置、存储介质及电子装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210353017.XA CN114511042A (zh) | 2022-04-06 | 2022-04-06 | 一种模型的训练方法、装置、存储介质及电子装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114511042A true CN114511042A (zh) | 2022-05-17 |
Family
ID=81554727
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210353017.XA Pending CN114511042A (zh) | 2022-04-06 | 2022-04-06 | 一种模型的训练方法、装置、存储介质及电子装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114511042A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114724011A (zh) * | 2022-05-25 | 2022-07-08 | 北京闪马智建科技有限公司 | 一种行为的确定方法、装置、存储介质及电子装置 |
CN114821247A (zh) * | 2022-06-30 | 2022-07-29 | 杭州闪马智擎科技有限公司 | 一种模型的训练方法、装置、存储介质及电子装置 |
CN114998570A (zh) * | 2022-07-19 | 2022-09-02 | 上海闪马智能科技有限公司 | 一种对象检测框的确定方法、装置、存储介质及电子装置 |
-
2022
- 2022-04-06 CN CN202210353017.XA patent/CN114511042A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114724011A (zh) * | 2022-05-25 | 2022-07-08 | 北京闪马智建科技有限公司 | 一种行为的确定方法、装置、存储介质及电子装置 |
CN114821247A (zh) * | 2022-06-30 | 2022-07-29 | 杭州闪马智擎科技有限公司 | 一种模型的训练方法、装置、存储介质及电子装置 |
CN114998570A (zh) * | 2022-07-19 | 2022-09-02 | 上海闪马智能科技有限公司 | 一种对象检测框的确定方法、装置、存储介质及电子装置 |
CN114998570B (zh) * | 2022-07-19 | 2023-03-28 | 上海闪马智能科技有限公司 | 一种对象检测框的确定方法、装置、存储介质及电子装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110175671B (zh) | 神经网络的构建方法、图像处理方法及装置 | |
CN110633745B (zh) | 一种基于人工智能的图像分类训练方法、装置及存储介质 | |
CN114511042A (zh) | 一种模型的训练方法、装置、存储介质及电子装置 | |
CN109840531A (zh) | 训练多标签分类模型的方法和装置 | |
CN111797983A (zh) | 一种神经网络构建方法以及装置 | |
CN111079780A (zh) | 空间图卷积网络的训练方法、电子设备及存储介质 | |
CN108510058B (zh) | 神经网络中的权重存储方法以及基于该方法的处理器 | |
CN113570029A (zh) | 获取神经网络模型的方法、图像处理方法及装置 | |
US20210312295A1 (en) | Information processing method, information processing device, and information processing program | |
CN113627545B (zh) | 一种基于同构多教师指导知识蒸馏的图像分类方法及系统 | |
CN110874626B (zh) | 一种量化方法及装置 | |
CN111931901A (zh) | 一种神经网络构建方法以及装置 | |
CN113592041B (zh) | 图像处理方法、装置、设备、存储介质及计算机程序产品 | |
US20220083843A1 (en) | System and method for balancing sparsity in weights for accelerating deep neural networks | |
CN108229536A (zh) | 分类预测模型的优化方法、装置及终端设备 | |
CN114091554A (zh) | 一种训练集处理方法和装置 | |
CN116245142A (zh) | 用于深度神经网络的混合精度量化的系统和方法 | |
CN115018039A (zh) | 一种神经网络蒸馏方法、目标检测方法以及装置 | |
CN109325530A (zh) | 基于少量无标签数据的深度卷积神经网络的压缩方法 | |
US20220335293A1 (en) | Method of optimizing neural network model that is pre-trained, method of providing a graphical user interface related to optimizing neural network model, and neural network model processing system performing the same | |
CN114511083A (zh) | 一种模型的训练方法、装置、存储介质及电子装置 | |
CN111343602A (zh) | 基于进化算法的联合布局与任务调度优化方法 | |
CN116348881A (zh) | 联合混合模型 | |
CN114445692B (zh) | 图像识别模型构建方法、装置、计算机设备及存储介质 | |
WO2022127603A1 (zh) | 一种模型处理方法及相关装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |