CN114764593A - 一种模型训练方法、模型训练装置及电子设备 - Google Patents

一种模型训练方法、模型训练装置及电子设备 Download PDF

Info

Publication number
CN114764593A
CN114764593A CN202210286944.4A CN202210286944A CN114764593A CN 114764593 A CN114764593 A CN 114764593A CN 202210286944 A CN202210286944 A CN 202210286944A CN 114764593 A CN114764593 A CN 114764593A
Authority
CN
China
Prior art keywords
task
loss
model
training
iteration period
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210286944.4A
Other languages
English (en)
Inventor
王涵柳
庞建新
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ubtech Robotics Corp
Original Assignee
Ubtech Robotics Corp
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ubtech Robotics Corp filed Critical Ubtech Robotics Corp
Priority to CN202210286944.4A priority Critical patent/CN114764593A/zh
Publication of CN114764593A publication Critical patent/CN114764593A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F17/00Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F17/10Complex mathematical operations
    • G06F17/11Complex mathematical operations for solving equations, e.g. nonlinear equations, general mathematical optimization problems
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques
    • G06F18/253Fusion techniques of extracted features

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Mathematical Analysis (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Artificial Intelligence (AREA)
  • Computational Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Mathematical Optimization (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Pure & Applied Mathematics (AREA)
  • Operations Research (AREA)
  • Algebra (AREA)
  • Databases & Information Systems (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本申请公开了一种模型训练方法、模型训练装置、电子设备及计算机可读存储介质。其中,该方法包括:在待训练的多任务模型的训练过程中,分别确定所述待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重;根据所述损失权重构建所述待训练的多任务模型在所述当前的迭代周期的损失函数;基于所述损失函数在所述当前的迭代周期内对所述待训练的多任务模型进行训练。通过本申请方案,可以提升多任务模型的准确率。

Description

一种模型训练方法、模型训练装置及电子设备
技术领域
本申请属于人工智能技术领域,尤其涉及一种模型训练方法、模型训练装置、电子设备及计算机可读存储介质。
背景技术
对运算力较小的处理器来说,其在同时处理多个任务的时,需要开启多个深度学习模型,这可能导致出现处理器的CPU占用过高,处理器过热的情况。为缓解这种情况,提出了多任务模型,处理器可通过运行一个多任务模型来实现多个任务的同时处理。考虑到多任务模型中所涉及的任务数量较多,目前的多任务模型在准确率上的表现仍有待提高。
发明内容
本申请提供了一种模型训练方法、模型训练装置、电子设备及计算机可读存储介质,可以提升多任务模型的准确率。
第一方面,本申请提供了一种模型训练方法,包括:
在待训练的多任务模型的训练过程中,分别确定上述待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重;
根据所述损失权重构建所述待训练的多任务模型在所述当前的迭代周期的损失函数;
基于所述损失函数在所述当前的迭代周期内对所述待训练的多任务模型进行训练。
第二方面,本申请提供了一种模型训练装置,包括:
确定模块,用于在待训练的多任务模型的训练过程中,分别确定上述待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重;
构建模块,用于根据所述损失权重构建所述待训练的多任务模型在所述当前的迭代周期的损失函数;
训练模块,用于基于所述损失函数在所述当前的迭代周期内对所述待训练的多任务模型进行训练。
第三方面,本申请提供了一种电子设备,上述电子设备包括存储器、处理器以及存储在上述存储器中并可在上述处理器上运行的计算机程序,上述处理器执行上述计算机程序时实现如上述第一方面的方法的步骤。
第四方面,本申请提供了一种计算机可读存储介质,上述计算机可读存储介质存储有计算机程序,上述计算机程序被处理器执行时实现如上述第一方面的方法的步骤。
第五方面,本申请提供了一种计算机程序产品,上述计算机程序产品包括计算机程序,上述计算机程序被一个或多个处理器执行时实现如上述第一方面的方法的步骤。
本申请与现有技术相比存在的有益效果是:首先在待训练的多任务模型的训练过程中,分别确定所述待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重,然后根据所述损失权重构建所述待训练的多任务模型在所述当前的迭代周期的损失函数,最后基于所述损失函数在所述当前的迭代周期内对所述待训练的多任务模型进行训练。上述过程为各任务的损失设计了随迭代周期而动态变化的损失权重,使得每个迭代周期内对多任务模型的训练更有针对性,能够一定程度上提升最终训练所得的多任务模型的准确率。
可以理解的是,上述第二方面至第五方面的有益效果可以参见上述第一方面中的相关描述,在此不再赘述。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的模型训练方法的实现流程示意图;
图2是本申请实施例提供的多任务模型的结构示例图;
图3是本申请实施例提供的模型训练装置的结构框图;
图4是本申请实施例提供的电子设备的结构示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
为了说明本申请所提出的技术方案,下面通过具体实施例来进行说明。
下面对本申请实施例所提出的模型训练方法作出说明。请参阅图1,该模型训练方法的实现流程详述如下:
步骤101,在待训练的多任务模型的训练过程中,分别确定待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重。
在构建多任务模型的损失函数时,一般最简单的方式是直接将各个任务的损失值相加,得到多任务模型整体的损失值;也即,多任务模型整体的损失值来源于不同任务的损失值之和。这种损失函数可表达为如下公式(1):
Loss=∑kLk (1)
其中,Lk是多任务模型中第k个任务的损失值;Loss是多任务模型需要回归的总损失值。
由于不同任务的损失值的量级很有可能不一样,基于式(1)的损失函数训练多任务模型时,可能会导致多任务模型的学习被某个任务所主导或学偏。当该多任务模型倾向于去拟合某个任务时,其它任务的效果很可能会受到负面影响,导致效果相对变差。为解决该问题,对多任务模型的损失函数进行了调整。具体地,在调整后的损失函数中,每个任务都被配置了一个固定的损失权重。这种调整后的损失函数可表达为如下公式(2):
Loss=∑kwk*Lk (2)
其中,wk是多任务模型中第k个任务的损失权重,其它参数在前文已有释义,此处不作赘述。需注意的是,通常来说,各任务的损失权重是由开发人员自己设计的,比如,实验中发现任务1的损失值只有0.001,而任务2的损失为1000,数量级显著不对等,则可人为设计任务1的损失权重为1000000,使得任务1和任务2的损失值的数量级相同。
然而,这种损失权重的设置方式也可能存在问题,这是因为不同任务的学习难易程度是不同的,且训练过程中不同任务很可能处于不同的学习阶段,这导致上式(2)中所设置的固定的损失权重在某个阶段可能会限制了任务的学习。
基于此,本申请实施例提出将各任务的损失权重动态化。具体而言,针对任一任务来说,该任务的损失权重会在每个学习周期(也即迭代周期)开始时,基于该任务的收敛速度和学习难易程度进行更新。
在一些实施例中,记当前的迭代周期为第t迭代周期,则对于任一任务m来说,其在该第t迭代周期的损失权重wm(t)可以通过如下方式确定:
首先,根据该任务在第t-1迭代周期的损失值及该任务在第t-2迭代周期的损失值,计算得到该任务在第t-1迭代周期的训练速度,可表达为如下公式(3):
Figure BDA0003560265060000041
其中,Lm(t-1)是任务m在第t-1迭代周期的损失值;Lm(t-2)是任务m在第t-2迭代周期的损失值;rm(t-1)是任务m在第t-1迭代周期的训练速度。可以理解,r的值越小,则表示对应任务的训练速度越快。
然后,根据该任务在t-1迭代周期的训练速度,以及所有任务在t-1迭代周期的训练速度之和,计算得到该任务在第t迭代周期的损失权重,可表达为如下公式(4):
Figure BDA0003560265060000051
其中,K是多任务模型中的任务总数;wm(t)是任务m在第t迭代周期的损失权重;rk(t-1)是多任务模型中第k个任务在t-1迭代周期的训练速度,其它参数在前文已有释义,此处不作赘述。可以理解,上式(4)中,分子是K与任务m在第t-1迭代周期的训练速度的乘积,分母是任务1在第t-1迭代周期的训练速度、任务2在第t-1迭代周期的训练速度直至任务K在第t-1迭代周期的训练速度的和。通过上式(4),表达了任务m的损失值在总损失值的计算中的贡献大小。可以理解,对某一任务来说,该任务的损失值在总损失值的计算中的贡献越大,则该任务的损失权重可以越大;反之,贡献越小,则说明该任务已收敛,该任务的损失权重可以越小。由此,可使得多任务模型的训练更偏向于收敛难以训练的任务。
在一些实施例中,通过上式(3)及(4)可知,对任一任务来说,其在每个迭代周期所采用的损失权重通常是不同的,这会使损失值的变化太灵敏(也即敏感程度过高),导致出现多任务模型不知道怎么学习的问题。基于此,可考虑引入指数移动平均值(ExponentialMoving Average,EMA)来降低损失权重变化的敏感程度。记当前的迭代周期为第t迭代周期,则对于任一任务m来说,其在该第t迭代周期的损失权重wm(t)也可以通过如下方式确定:
首先,根据该任务在第t-1迭代周期的损失值及该任务在第t-2迭代周期的损失值,计算得到该任务在第t-1迭代周期的训练速度,如上式(3),此处不再赘述。
然后,根据该任务在第t-1迭代周期的训练速度、该任务在第t-2迭代周期的优化后的训练速度及预设的计算权重,计算该任务在第t-1迭代周期的优化后的训练速度,可表达为如下公式(5):
vm(t-1)=β*vm(t-2)+(1-β)*rm(t-1) (5)
其中,vm(t-1)是任务m在第t-1迭代周期的优化后的训练速度;β是计算权重;vm(t-2)是任务m在第t-2迭代周期的优化后的训练速度;rm(t-1)是任务m在第t-1迭代周期的训练速度。可以理解,上式(5)采用了EMA的计算思想,认为过去1/(1-β)个时刻之前的数值平均会衰减到1/e的加权比例。基于上式(5),后续所计算出的任务的损失权重可不仅只依赖于任务当前的损失值的减小幅度,还平均了任务之前的损失值的变化幅度,由此降低了权重变化的敏感程度。
最后,根据该任务在t-1迭代周期的优化后的训练速度,以及所有任务在t-1迭代周期的优化后的训练速度之和,计算得到该任务在第t迭代周期的损失权重,可表达为如下公式(6):
Figure BDA0003560265060000061
上式(6)中的各个参数符号在前文已有示例,此处不再赘述。可以理解,上式(6)中,分子是K与任务m在第t-1迭代周期的优化后的训练速度的乘积,分母是任务1在第t-1迭代周期的优化后的训练速度、任务2在第t-1迭代周期的优化后的训练速度直至任务K在第t-1迭代周期的优化后的训练速度的和。
可以理解,基于上式(3)、(5)及(6)所计算出的损失权重与基于上式(3)及(4)所计算出的损失权重相比,采用了指数移动平均值来降低损失权重变化的敏感程度,更好地表达了任务m的损失值在总损失值的计算中的贡献大小。
可以理解,在采用以上这两种损失权重的确定方式计算当前的迭代周期的损失权重时,涉及到了前两个迭代周期的相关数据。因而,t的取值应至少为3。这意味着在多任务模型的训练过程中,前两个迭代周期可以采用固定的损失权重,从第三个迭代周期开始,才采用以上这两种损失权重的确定方式计算当前的迭代周期的损失权重。也即,从第三个迭代周期开始,各任务的损失权重才开始动态变化。
在一些实施例中,为保障计算所得到的各个任务的损失权重均大于0,可通过指数函数再对上式(4)及上式(6)再作优化,具体为将上式(4)所涉及的训练速度作指数处理,以及,对上式(6)所涉及的优化后的训练速度作指数处理。则,优化后的上式(4)可表达为如下公式(7):
Figure BDA0003560265060000071
优化后的上式(6)可表达为如下公式(8):
Figure BDA0003560265060000072
步骤102,根据损失权重构建待训练的多任务模型在当前的迭代周期的损失函数。
在本申请实施例中,记当前的迭代周期为第t迭代周期,则所构建的损失函数可表达为如下公式(9):
Loss=∑kwk(t)*Lk (9)
步骤103,基于该损失函数在当前的迭代周期内对待训练的多任务模型进行训练。
在本申请实施例中,在当前的迭代周期内,各任务所对应的损失权重不再变化,电子设备可根据为该当前的迭代周期所构建的损失函数对待训练的多任务模型进行训练。
可以理解,通过步骤101及102构建了当前迭代周期内所使用的损失函数。该损失函数中,难以学习的任务被赋予了较大的损失权重,而易学习和已收敛的任务被赋予了较小的损失权重,这使得当前的迭代周期内基于该损失函数进行训练时,能够更侧重于训练那些难以学习的任务,使得多任务模型中的各个任务的准确率均能有所增长。
在一些实施例中,在对多任务模型进行训练时,还有一个挑战是数据集的扩充。由于当多任务模型所涉及到的任务较多时,较难获得一个所有任务都有对应真实值的数据集,这就要求数据集中的很多真实值需要人工进行标注,导致人力成本的增加。而在训练过程中,可能出现部分任务已经达到了比较好的准确率,此时,不再需要对该部分任务增加数据集来提升准确率。基于此,本申请实施还提出了损失掩码,通过损失掩码来增加多任务中单个任务或部分任务的准确率。则步骤102可具体表现为:
根据损失权重及损失掩码,构建待训练的多任务模型在当前的迭代周期的损失函数,其中,该损失掩码根据当前所使用的标签数据而确定。
为便于理解,根据损失权重及损失掩码所构建的损失函数可表达为如下公式(10):
Loss=∑kLMk*wk(t)*Lk (10)
其中,LMk为第k个任务在当前的损失掩码中所对应的掩码值。其中,损失掩码值可通过如下过程确定:基于当前所使用的标签数据中的各个值确定目标任务,其中,目标任务指的是:在该标签数据中的值为非真实值的任务;将该目标任务在损失掩码中所对应的掩码值确定为第一预设掩码值,该第一预设掩码值为“0”;将除目标任务外的其它任务在损失掩码中所对应的掩码值确定为第二预设掩码值,该第二预设掩码值为“1”。
仅作为示例,在制作各个样本数据所对应的标签数据时,在不存在任务(回归任务或分类任务)的真实值为“-1”的情况下,可将没有真实值的标签设置为“-1”,由此可生成一个对应多任务模型的的标签数据。将采用此种方式生成的标签数据混入有全部真实值的数据集中,打乱顺序进行训练。在训练过程中,针对这种标签数据(也即存在没有真实值的标签数据)生成对应的损失掩码(loss mask)。
例如,多任务模型为人脸属性检测任务模型,其包括如下任务:人脸关键点检测任务,人脸姿态检测任务,人脸微笑检测任务,人脸年龄检测任务,人脸性别检测任务,人脸口罩检测任务,人脸眼镜检测任务,人脸颜值检测任务及人脸质量检测。
对于某一样本图像来说,该样本图像仅有年龄,性别,口罩和眼镜的真实值,则基于该样本图像所生成的标签数据可以为(-1,-1,-1,2,1,66,3,-1,-1);对应地,基于该样本图像的标签数据而生成的损失掩码为(0,0,0,1,1,1,1,0,0)。也即,在当前以该样本图像作为输入时,只有人脸年龄检测任务,人脸性别检测任务,人脸口罩检测任务及人脸眼镜检测任务在当前的损失掩码中所对应的掩码值为“1”,其它各任务在当前的损失掩码中所对应的掩码值均为“0”。通过公式(10)可知,在当前以该样本图像作为输入时,计算所得的总损失值中,实际只考虑了标签为真实值的任务的损失值,其余那些标签为非真实值的任务的损失值未被考虑;也即,多任务模型只会回传标签为真实值的任务的损失值,不会回传标签为非真实值的任务的损失值。
通过以上过程,能够在不为训练样本标注多个真实值的标签的情况下,实现多任务模型中的单任务或部分任务的准确率增长,节约了标注时间及标注所需的人力成本。
在一些实施例中,多任务模型中,网络提取特征的时候可能受到了部分任务的影响,导致损失掉了一些特征,进而导致出现单个或部分表现不好的任务。例如,多任务模型为前文所提出的人脸属性检测任务模型,其所采用的主干网络是A Practical FacialLandmark Detector(PFLD)网络,该人脸属性检测任务模型在提取输入的图像的特征值时,由于受到一些任务的影响,导致部分图像特征有所损失,进而影响到某一任务或某些任务的准确率。为提升多任务模型中单任务或部分任务的准确率,可对该多任务模型的结构进行改进,改进后的多任务模型的结构如下:
该多任务模型包括N个卷积层及一个全连接层,N为大于2的整数;
其中,第1个卷积层的输入为该多任务模型的输入;
第i个卷积层的输入为第i-1个卷积层所输出的特征,其中,1<i≤N;
全连接层的输入为目标特征,目标特征基于至少两个卷积层所输出的特征融合而得。
其中,特征融合的过程具体为:将所有待融合的特征下采样至相同的大小,然后通过concate的方式将下采样后的待融合的特征贴在一起,再进行卷积,即可得到目标特征。
通过上述改进后的结构,全连接层不再仅对最后一个卷积层所输出的高层特征予以考虑,而是会对基于多层特征融合而得的目标特征予以考虑,以对因多任务耦合而导致特征丧失的情况进行补偿,由此提升表现不好的单任务或部分任务的准确率。
在一些实施例中,由于低层的特征语义信息比较少但是目标位置比较准确,高层的特征语义信息比较丰富但是目标位置比较粗略,因而此处可考虑分别从多任务模型的网络低层、中层及高层提取特征进行融合。则该目标特征可以通过如下过程而得:
获取第一特征,该第一特征为多任务模型中的第1个卷积层的输出。可以理解,该第一特征即为低层特征。
获取第二特征,该第二特征为多任务模型中的第j个卷积层的输出,其中,j为大于1且小于N的预设值。可以理解,该第二特征即为中层特征。仅作为示例,若N为偶数,则j可以为
Figure BDA0003560265060000101
若N为奇数,则j可以为
Figure BDA0003560265060000102
获取第三特征,该第三特征为多任务模型中的第N个卷积层的输出。可以理解,该三特征即为高层特征。
融合第一特征、第二特征及第三特征,即可得到目标特征。
仅作为示例,请参阅图2,图2给出了该多任务模型的结构的示例。
由上可见,在本申请实施例中,无需采用现有技术中的增加数据集的方式,可通过如下几方面增长多任务网络的准确率:第一方面,可为各个任务构建随迭代周期而动态变化的损失权重,使得训练过程中,难以学习的任务被赋予比较大的损失权重,易学习和已经收敛的任务被赋予较小的损失权重,帮助多任务模型侧重训练难以学习的任务;且损失权重的计算过程中结合了指数移动平均的算法,降低了损失权重的变化速率,使损失函数在稳定性上和任务学习的难易程度上做到了平衡,能够减少人为手动调权重超参数的实验数量。第二方面,提出了损失掩码,能够在不为训练样本标注多个真实值的标签的情况下,实现多任务模型中的单任务或部分任务的准确率增长,节约了标注时间及标注所需的人力成本。第三方面,改进了多任务模型的结构,通过结合低层、中层及高层的特征,对因多任务耦合而导致特征丧失的情况进行了补偿,由此提升了表现不好的单任务或部分任务的准确率。
对应于上文所提供的模型训练方法,本申请实施例还提供了一种模型训练装置。如图3所示,该模型训练装置300包括:
确定模块301,用于在待训练的多任务模型的训练过程中,分别确定待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重;
构建模块302,用于根据上述损失权重构建上述待训练的多任务模型在上述当前的迭代周期的损失函数;
训练模块303,用于基于上述损失函数在上述当前的迭代周期内对上述待训练的多任务模型进行训练。
可选地,记上述当前的迭代周期为第t迭代周期,上述确定模块301,包括:
第一计算单元,用于针对每个任务,根据上述任务在第t-1迭代周期的损失值及上述任务在第t-2迭代周期的损失值,计算上述任务在第t-1迭代周期的训练速度;
第二计算单元,用于根据上述任务在第t-1迭代周期的训练速度、上述任务在第t-2迭代周期的优化后的训练速度及预设的计算权重,计算上述任务在第t-1迭代周期的优化后的训练速度;
第三计算单元,用于根据上述任务在第t-1迭代周期的优化后的训练速度,以及所有任务在第t-1迭代周期的优化后的训练速度之和,计算上述任务在第t迭代周期的损失权重。
可选地,上述构建模块302,具体用于根据上述损失权重及损失掩码,构建上述待训练的多任务模型在上述当前的迭代周期的损失函数,其中,所述损失掩码根据计算损失时所使用的标签数据而确定。
可选地,上述损失掩码通过如下过程确定:
基于计算损失时所使用的标签数据中的各个值确定目标任务,其中,所述目标任务为:在所述标签数据中的值为非真实值的任务;
将所述目标任务在所述损失掩码中所对应的掩码值确定为第一预设掩码值;
将除所述目标任务外的其它任务在所述损失掩码中所对应的掩码值确定为第二预设掩码值。
可选地,上述多任务模型包括N个卷积层及一个全连接层,N为大于2的整数;其中,第1个上述卷积层的输入为上述多任务模型的输入;第i个上述卷积层的输入为第i-1个上述卷积层所输出的特征,其中,1<i≤N;上述全连接层的输入为目标特征,上述目标特征基于至少两个上述卷积层所输出的特征融合而得。
可选地,上述目标特征通过融合第一特征、第二特征及第三特征而得,其中,上述第一特征为上述多任务模型中的第1个卷积层的输出,上述第二特征为上述多任务模型中的第j个卷积层的输出,其中,j为大于1且小于N的预设值,上述第三特征为上述多任务模型中的第N个卷积层的输出。
由上可见,在本申请实施例中,无需采用现有技术中的增加数据集的方式,可通过如下几方面增长多任务网络的准确率:第一方面,可为各个任务构建随迭代周期而动态变化的损失权重,使得训练过程中,难以学习的任务被赋予比较大的损失权重,易学习和已经收敛的任务被赋予较小的损失权重,帮助多任务模型侧重训练难以学习的任务;且损失权重的计算过程中结合了指数移动平均的算法,降低了损失权重的变化速率,使损失函数在稳定性上和任务学习的难易程度上做到了平衡,能够减少人为手动调权重超参数的实验数量。第二方面,提出了损失掩码,能够在不为训练样本标注多个真实值的标签的情况下,实现多任务模型中的单任务或部分任务的准确率增长,节约了标注时间及标注所需的人力成本。第三方面,改进了多任务模型的结构,通过结合低层、中层及高层的特征,对因多任务耦合而导致特征丧失的情况进行了补偿,由此提升了表现不好的单任务或部分任务的准确率。
对应于上文所提供的模型训练方法,本申请实施例还提供了一种电子设备。仅作为示例,该电子设备可以是机器人、智能手机、平板电脑、学习机或服务器等类型的设备,此处不作限定。请参阅图4,本申请实施例中的电子设备4包括:存储器401,一个或多个处理器402(图4中仅示出一个)及存储在存储器401上并可在处理器上运行的计算机程序。其中:存储器401用于存储软件程序以及单元,处理器402通过运行存储在存储器401的软件程序以及单元,从而执行各种功能应用以及数据处理,以获取上述预设事件对应的资源。具体地,处理器402通过运行存储在存储器401的上述计算机程序时实现以下步骤:
在待训练的多任务模型的训练过程中,分别确定上述待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重;
根据上述损失权重构建上述待训练的多任务模型在上述当前的迭代周期的损失函数;
基于上述损失函数在上述当前的迭代周期内对上述待训练的多任务模型进行训练。
假设上述为第一种可能的实施方式,则在第一种可能的实施方式作为基础而提供的第二种可能的实施方式中,记上述当前的迭代周期为第t迭代周期,上述分别确定待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重,包括:
针对每个任务,根据上述任务在第t-1迭代周期的损失值及上述任务在第t-2迭代周期的损失值,计算上述任务在第t-1迭代周期的训练速度;
根据上述任务在第t-1迭代周期的训练速度、上述任务在第t-2迭代周期的优化后的训练速度及预设的计算权重,计算上述任务在第t-1迭代周期的优化后的训练速度;
根据上述任务在第t-1迭代周期的优化后的训练速度,以及所有任务在第t-1迭代周期的优化后的训练速度之和,计算上述任务在第t迭代周期的损失权重。
在上述第一种可能的实施方式作为基础而提供的第三种可能的实施方式中,上述根据上述损失权重构建上述待训练的多任务模型在上述当前的迭代周期的损失函数,包括:
根据上述损失权重及损失掩码,构建上述待训练的多任务模型在上述当前的迭代周期的损失函数,其中,所述损失掩码根据计算损失时所使用的标签数据而确定。
在上述第三种可能的实施方式作为基础而提供的第四种可能的实施方式中,上述损失掩码通过如下过程确定:
基于计算损失时所使用的标签数据中的各个值确定目标任务,其中,所述目标任务为:在所述标签数据中的值为非真实值的任务;
将所述目标任务在所述损失掩码中所对应的掩码值确定为第一预设掩码值;
将除所述目标任务外的其它任务在所述损失掩码中所对应的掩码值确定为第二预设掩码值。
在上述第一种可能的实施方式作为基础而提供的第五种可能的实施方式中,上述多任务模型包括N个卷积层及一个全连接层,N为大于2的整数;
其中,第1个上述卷积层的输入为上述多任务模型的输入;
第i个上述卷积层的输入为第i-1个上述卷积层所输出的特征,其中,1<i≤N;
上述全连接层的输入为目标特征,上述目标特征基于至少两个上述卷积层所输出的特征融合而得。
在上述第五种可能的实施方式作为基础而提供的第六种可能的实施方式中,上述目标特征通过以下过程得到:
获取第一特征,上述第一特征为上述多任务模型中的第1个卷积层的输出;
获取第二特征,上述第二特征为上述多任务模型中的第j个卷积层的输出,其中,j为大于1且小于N的预设值;
获取第三特征,上述第三特征为上述多任务模型中的第N个卷积层的输出;
融合上述第一特征、上述第二特征及上述第三特征,得到上述目标特征。
应当理解,在本申请实施例中,所称处理器402可以是中央处理单元(CentralProcessing Unit,CPU),该处理器还可以是其他通用处理器、数字信号处理器(DigitalSignal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器401可以包括只读存储器和随机存取存储器,并向处理器402提供指令和数据。存储器401的一部分或全部还可以包括非易失性随机存取存储器。例如,存储器401还可以存储设备类别的信息。
由上可见,在本申请实施例中,无需采用现有技术中的增加数据集的方式,可通过如下几方面增长多任务网络的准确率:第一方面,可为各个任务构建随迭代周期而动态变化的损失权重,使得训练过程中,难以学习的任务被赋予比较大的损失权重,易学习和已经收敛的任务被赋予较小的损失权重,帮助多任务模型侧重训练难以学习的任务;且损失权重的计算过程中结合了指数移动平均的算法,降低了损失权重的变化速率,使损失函数在稳定性上和任务学习的难易程度上做到了平衡,能够减少人为手动调权重超参数的实验数量。第二方面,提出了损失掩码,能够在不为训练样本标注多个真实值的标签的情况下,实现多任务模型中的单任务或部分任务的准确率增长,节约了标注时间及标注所需的人力成本。第三方面,改进了多任务模型的结构,通过结合低层、中层及高层的特征,对因多任务耦合而导致特征丧失的情况进行了补偿,由此提升了表现不好的单任务或部分任务的准确率。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将上述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者外部设备软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的系统实施例仅仅是示意性的,例如,上述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
上述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
上述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本申请实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关联的硬件来完成,上述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,上述计算机程序包括计算机程序代码,上述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。上述计算机可读存储介质可以包括:能够携带上述计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机可读存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、电载波信号、电信信号以及软件分发介质等。需要说明的是,上述计算机可读存储介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如在某些司法管辖区,根据立法和专利实践,计算机可读存储介质不包括是电载波信号和电信信号。
以上实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。

Claims (10)

1.一种模型训练方法,其特征在于,包括:
在待训练的多任务模型的训练过程中,分别确定所述待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重;
根据所述损失权重构建所述待训练的多任务模型在所述当前的迭代周期的损失函数;
基于所述损失函数在所述当前的迭代周期内对所述待训练的多任务模型进行训练。
2.如权利要求1所述的模型训练方法,其特征在于,记所述当前的迭代周期为第t迭代周期,所述分别确定所述待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重,包括:
针对每个任务,根据所述任务在第t-1迭代周期的损失值及所述任务在第t-2迭代周期的损失值,计算所述任务在第t-1迭代周期的训练速度;
根据所述任务在第t-1迭代周期的训练速度、所述任务在第t-2迭代周期的优化后的训练速度及预设的计算权重,计算所述任务在第t-1迭代周期的优化后的训练速度;
根据所述任务在第t-1迭代周期的优化后的训练速度,以及所有任务在第t-1迭代周期的优化后的训练速度之和,计算所述任务在第t迭代周期的损失权重。
3.如权利要求1所述的模型训练方法,其特征在于,所述根据所述损失权重构建所述待训练的多任务模型在所述当前的迭代周期的损失函数,包括:
根据所述损失权重及损失掩码,构建所述待训练的多任务模型在所述当前的迭代周期的损失函数,其中,所述损失掩码根据计算损失时所使用的标签数据而确定。
4.如权利要求3所述的模型训练方法,其特征在于,所述损失掩码通过如下过程确定:
基于计算损失时所使用的标签数据中的各个值确定目标任务,其中,所述目标任务为:在所述标签数据中的值为非真实值的任务;
将所述目标任务在所述损失掩码中所对应的掩码值确定为第一预设掩码值;
将除所述目标任务外的其它任务在所述损失掩码中所对应的掩码值确定为第二预设掩码值。
5.如权利要求1所述的模型训练方法,其特征在于,所述多任务模型包括N个卷积层及一个全连接层,N为大于2的整数;
其中,第1个所述卷积层的输入为所述多任务模型的输入;
第i个所述卷积层的输入为第i-1个所述卷积层所输出的特征,其中,1<i≤N;
所述全连接层的输入为目标特征,所述目标特征基于至少两个所述卷积层所输出的特征融合而得。
6.如权利要求5所述的模型训练方法,其特征在于,所述目标特征通过以下过程得到:
获取第一特征,所述第一特征为所述多任务模型中的第1个卷积层的输出;
获取第二特征,所述第二特征为所述多任务模型中的第j个卷积层的输出,其中,j为大于1且小于N的预设值;
获取第三特征,所述第三特征为所述多任务模型中的第N个卷积层的输出;
融合所述第一特征、所述第二特征及所述第三特征,得到所述目标特征。
7.一种模型训练装置,其特征在于,包括:
确定模块,用于在待训练的多任务模型的训练过程中,分别确定所述待训练的多任务模型中的每个任务在当前的迭代周期内的损失权重;
构建模块,用于根据所述损失权重构建所述待训练的多任务模型在所述当前的迭代周期的损失函数;
训练模块,用于基于所述损失函数在所述当前的迭代周期内对所述待训练的多任务模型进行训练。
8.如权利要求7所述的模型训练装置,其特征在于,记所述当前的迭代周期为第t迭代周期,所述确定模块,包括:
第一计算单元,用于针对每个任务,根据所述任务在第t-1迭代周期的损失值及所述任务在第t-2迭代周期的损失值,计算所述任务在第t-1迭代周期的训练速度;
第二计算单元,用于根据所述任务在第t-1迭代周期的训练速度、所述任务在第t-2迭代周期的优化后的训练速度及预设的计算权重,计算所述任务在第t-1迭代周期的优化后的训练速度;
第三计算单元,用于根据所述任务在第t-1迭代周期的优化后的训练速度,以及所有任务在第t-1迭代周期的优化后的训练速度之和,计算所述任务在第t迭代周期的损失权重。
9.一种电子设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至6任一项所述的方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至6任一项所述的方法。
CN202210286944.4A 2022-03-23 2022-03-23 一种模型训练方法、模型训练装置及电子设备 Pending CN114764593A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210286944.4A CN114764593A (zh) 2022-03-23 2022-03-23 一种模型训练方法、模型训练装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210286944.4A CN114764593A (zh) 2022-03-23 2022-03-23 一种模型训练方法、模型训练装置及电子设备

Publications (1)

Publication Number Publication Date
CN114764593A true CN114764593A (zh) 2022-07-19

Family

ID=82364878

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210286944.4A Pending CN114764593A (zh) 2022-03-23 2022-03-23 一种模型训练方法、模型训练装置及电子设备

Country Status (1)

Country Link
CN (1) CN114764593A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115984804A (zh) * 2023-03-14 2023-04-18 安徽蔚来智驾科技有限公司 一种基于多任务检测模型的检测方法及车辆

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115984804A (zh) * 2023-03-14 2023-04-18 安徽蔚来智驾科技有限公司 一种基于多任务检测模型的检测方法及车辆

Similar Documents

Publication Publication Date Title
CN108205655B (zh) 一种关键点预测方法、装置、电子设备及存储介质
CN111476309A (zh) 图像处理方法、模型训练方法、装置、设备及可读介质
CN112132847A (zh) 模型训练方法、图像分割方法、装置、电子设备和介质
CN113379627A (zh) 图像增强模型的训练方法和对图像进行增强的方法
CN112561060B (zh) 神经网络训练方法及装置、图像识别方法及装置和设备
CN112633420B (zh) 图像相似度确定及模型训练方法、装置、设备和介质
CN110728319B (zh) 一种图像生成方法、装置以及计算机存储介质
CN115861462B (zh) 图像生成模型的训练方法、装置、电子设备及存储介质
CN117576264B (zh) 图像生成方法、装置、设备及介质
CN110489955B (zh) 应用于电子设备的图像处理、装置、计算设备、介质
CN114170484B (zh) 图片属性预测方法、装置、电子设备和存储介质
CN110097004B (zh) 面部表情识别方法和装置
CN111402113A (zh) 图像处理方法、装置、电子设备及计算机可读介质
CN114764593A (zh) 一种模型训练方法、模型训练装置及电子设备
CN116152938A (zh) 身份识别模型训练和电子资源转移方法、装置及设备
CN113140012A (zh) 图像处理方法、装置、介质及电子设备
CN117830790A (zh) 多任务模型的训练方法、多任务处理方法及装置
CN115457365A (zh) 一种模型的解释方法、装置、电子设备及存储介质
CN114648021A (zh) 问答模型的训练方法、问答方法及装置、设备和存储介质
CN114120423A (zh) 人脸图像检测方法、装置、电子设备和计算机可读介质
CN114021010A (zh) 一种信息推荐模型的训练方法、装置及设备
CN115393914A (zh) 多任务模型训练方法、装置、设备及存储介质
CN112070022A (zh) 人脸图像识别方法、装置、电子设备和计算机可读介质
CN112036418A (zh) 用于提取用户特征的方法和装置
CN112215868A (zh) 基于生成对抗网络的去除手势图像背景的方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination