CN111027428A - 一种多任务模型的训练方法、装置及电子设备 - Google Patents

一种多任务模型的训练方法、装置及电子设备 Download PDF

Info

Publication number
CN111027428A
CN111027428A CN201911205175.5A CN201911205175A CN111027428A CN 111027428 A CN111027428 A CN 111027428A CN 201911205175 A CN201911205175 A CN 201911205175A CN 111027428 A CN111027428 A CN 111027428A
Authority
CN
China
Prior art keywords
task
loss value
type
loss
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN201911205175.5A
Other languages
English (en)
Other versions
CN111027428B (zh
Inventor
刘思阳
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing QIYI Century Science and Technology Co Ltd
Original Assignee
Beijing QIYI Century Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing QIYI Century Science and Technology Co Ltd filed Critical Beijing QIYI Century Science and Technology Co Ltd
Priority to CN201911205175.5A priority Critical patent/CN111027428B/zh
Publication of CN111027428A publication Critical patent/CN111027428A/zh
Application granted granted Critical
Publication of CN111027428B publication Critical patent/CN111027428B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V40/00Recognition of biometric, human-related or animal-related patterns in image or video data
    • G06V40/10Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Biophysics (AREA)
  • Evolutionary Computation (AREA)
  • Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Biomedical Technology (AREA)
  • Human Computer Interaction (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Multimedia (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明实施例提供了一种多任务模型的训练方法、装置及电子设备。该方法包括:将样本数据输入训练中的多任务模型,得到样本数据的针对每类任务的预测内容;利用样本数据的针对每类任务的预测内容,计算每类任务的损失值;将每类任务的损失值代入当前的总损失计算函数,得到多任务模型的损失值;当所得到的损失值为预设的期望损失值时,结束训练;否则,调整多任务模型的网络参数和每类任务对应的权重参数,利用调整后的权重参数重构总损失计算函数,并返回将样本数据输入训练中的多任务模型的步骤。通过本方案,可以实现在保证多任务模型的精准度的同时,降低对计算资源的浪费的目的。

Description

一种多任务模型的训练方法、装置及电子设备
技术领域
本发明涉及计算机技术领域,特别是涉及一种多任务模型的训练方法及装置。
背景技术
所谓多任务模型为能够同时输出多类任务的处理结果的神经网络模型。例如:针对图像分析领域而言,通过一模型可以同时输出图像的第一类特征信息和第二类特征信息,此时,该模型属于多任务模型,该多任务模型所针对的多任务包括关于第一类特征信息的识别任务和关于第二类特征信息的识别任务。并且,由于多任务模型针对多类任务,那么,针对多任务模型而言,在训练模型时,每类任务均存在损失。
现有技术中,在训练多任务模型的过程中,计算多任务模型的损失值时,利用预先设定的一组权重参数,对每类任务的损失值进行加权求和,得到多任务模型的损失值;进而,当利用多任务模型的损失值判断出多任务模型未收敛时,调整多任务模型的网络参数,并继续对多任务模型进行训练。并且,考虑到模型训练时所利用权重参数为人工设定,因此,为了保证训练完成的多任务模型的精准度较高,通常人工设定多组权重参数,利用每一组权重参数,分别对多任务模型进行训练,并从所训练得到的多个多任务模型中选择收敛效果最好的模型,作为最终的多任务模型。
可见,由于现有技术中针对每一组权重参数均执行完整的模型训练过程,因此,无疑存在浪费计算资源的问题。
发明内容
本发明实施例的目的在于提供一种多任务模型的训练方法、装置及电子设备,以实现在保证多任务模型的精准度的同时,降低对计算资源的浪费的目的。
具体技术方案如下:
第一方面,本发明实施例提供了一种多任务模型的训练方法,所述方法包括:
将样本数据输入训练中的多任务模型,得到所述样本数据的针对每类任务的预测内容;
利用所述样本数据的针对每类任务的预测内容,计算每类任务的损失值;
将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值;其中,所述总损失值计算函数用于:对每类任务的损失值的加权损失以及每类任务对应的权重参数的修正值进行求和,每一权重参数的修正值与该权重参数呈负相关关系;
当所得到的损失值为预设的期望损失值时,结束训练;否则,调整所述多任务模型的网络参数和每类任务对应的权重参数,利用调整后的权重参数重构所述总损失计算函数,并返回所述将样本数据输入训练中的多任务模型的步骤。
可选地,所述总损失计算函数包括:
Figure BDA0002296768830000021
其中,Ltotal为所述多任务模型的损失值,Li为所述多任务模型所针对的任务i的损失值,αi为所述任务i对应的权重参数,n为所述多任务模型所针对任务的总数量,f(αi)为用于求取αi的修正值的函数,且f(αi)为函数值与αi呈负相关关系的函数。
可选地,所述f(αi)包括
Figure BDA0002296768830000022
其中,r为预设的底数。
可选地,所述将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值之前,所述方法还包括:
对每类任务的损失值进行归一化处理;
所述将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值,包括:
将归一化处理后的每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值。
可选地,所述对每类任务的损失值进行归一化处理,包括:
针对每类任务,计算该类任务在产生损失时能够产生的最大损失值;
针对每类任务,将该类任务的损失值除以该类任务的最大损失值,得到归一化处理后的该类任务的损失值。
第二方面,本发明实施例还提供了一种多任务模型的训练装置,所述装置包括:
预测内容确定模块,用于将样本数据输入训练中的多任务模型,得到所述样本数据的针对每类任务的预测内容;
第一损失值计算模块,用于利用所述样本数据的针对每类任务的预测内容,计算每类任务的损失值;
第二损失值计算模块,用于将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值;其中,所述总损失值计算函数用于:对每类任务的损失值的加权损失以及每类任务对应的权重参数的修正值进行求和,每一权重参数的修正值与该权重参数呈负相关关系;
损失值分析模块,用于当所得到的损失值为预设的期望损失值时,结束训练;否则,调整所述多任务模型的网络参数和每类任务对应的权重参数,利用调整后的权重参数重构所述总损失计算函数,并触发所述预测内容确定模块。
可选地,所述总损失计算函数包括:
Figure BDA0002296768830000031
其中,Ltotal为所述多任务模型的损失值,Li为所述多任务模型所针对的任务i的损失值,αi为所述任务i对应的权重参数,n为所述多任务模型所针对任务的总数量,f(αi)为用于求取αi的修正值的函数,且f(αi)为函数值与αi呈负相关关系的函数。
可选地,所述f(αi)包括
Figure BDA0002296768830000032
其中,r为预设的底数。
可选地,所述装置还包括:
归一化模块,用于在所述第二损失值计算模块将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值之前,对每类任务的损失值进行归一化处理;
所述第二损失值计算模块,具体用于:
将归一化处理后的每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值。
可选地,所述归一化模块,具体用于:
针对每类任务,计算该类任务在产生损失时能够产生的最大损失值;将该类任务的损失值除以该类任务的最大损失值,得到归一化处理后的该类任务的损失值。
第三方面,本发明实施例还提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现本发明实施例所提供的多任务模型的训练方法的步骤。
第四方面,本发明实施例还提供一种计算机可读存储介质,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现本发明实施例所提供的多任务模型的训练方法的步骤。
本发明实施例所提供的方案中,在多任务模型训练时,各类任务的权重参数和多任务模型的网络参数均作为可自动学习的参数,使得在一次完整的模型训练过程中可以通过自动调整权重参数来获得最优的权重参数;同时,总损失计算函数考虑每类任务对应的权重参数的修正值,这样多任务模型的损失值的计算过程成为动态优化过程,为多任务模型达到损失值为期望损失值的收敛效果提供实现基础。因此,通过本方案,可以实现在保证多任务模型的精准度的同时,降低对计算资源的浪费的目的。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍。
图1为本发明实施例所提供的一种多任务模型的训练方法的流程图;
图2为本发明实施例所提供的一种多任务模型的训练方法的另一流程图;
图3为本发明实施例所提供的一种多任务模型的训练方法的另一流程图;
图4(a)为多任务模型所针对的任务为两类任务时,本发明实施例中多任务模型的训练过程的原理图;
图4(b)为多任务模型所针对的任务为热度图任务和位移图任务时,本发明实施例中多任务模型的训练过程的原理图;
图5为本发明实施例所提供的一种多任务模型的训练装置的结构示意图;
图6为本发明实施例所提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行描述。
为了实现在保证多任务模型的精准度的同时,降低对计算资源的浪费的目的,本发明实施例提供了一种多任务模型的训练方法、装置及电子设备。
其中,本发明实施例所提供的一种多任务模型的训练方法的执行主体可以为一种多任务模型的训练装置。并且,该多任务模型的训练装置可以应用于电子设备中,该电子设备可以为服务器,也可以为终端设备。
并且,任一个能够同时输出多类任务的处理结果的神经网络模型,均可以作为本发明实施例所涉及的多任务模型,并采用本发明实施例所提供的多任务模型的训练方法进行训练,以在保证多任务模型的精度的同时,降低对计算资源的浪费。
具体的,针对于图像分析领域而言,能够同时输出图像的多类特征信息的神经网络模型,可以作为本发明实施例所涉及的多任务模型;而针对文本分析领域而言,能够同时输出文本的多类特征信息的神经网络模型,也可以作为本发明实施例所涉及的多任务模型。示例性的,图像的多类特征信息可以包括:人体关键点信息和人像蒙版;或者,图像中人体关键点的第一类特征信息和图像中人体关键点的第二类特征信息,等等。示例性的,文本的多类特征信息可以包括:文本的情感类别信息和文本的关键字,等等。
下面首先结合附图,对本发明实施例所提供的一种多任务模型的训练方法进行介绍。
如图1所示,本发明实施例所提供的一种多任务模型的训练方法,可以包括如下步骤:
S101,将样本数据输入训练中的多任务模型,得到该样本数据的针对每类任务的预测内容;
S102,利用该样本数据的针对每类任务的预测内容,计算每类任务的损失值;
S103,将每类任务的损失值代入当前的总损失计算函数,得到该多任务模型的损失值;其中,该总损失值计算函数用于:对每类任务的损失值的加权损失以及每类任务对应的权重参数的修正值进行求和,每一权重参数的修正值与该权重参数呈负相关关系;
S104,判断所得到的损失值是否为预设的期望损失值,如果是,结束训练,得到训练完成的模型;否则,执行S105;
S105,调整该多任务模型的网络参数和每类任务对应的权重参数,利用调整后的权重参数重构该总损失计算函数,并返回S101。
其中,样本数据为用于训练该多任务模型的数据,多任务模型所针对的任务不同,则样本数据不同。示例性的,多任务模型所针对的任务包括图像的多类特征信息的识别任务,则样本数据为样本图像;而多任务模型所针对的任务包括文本的多类特征信息的识别任务,则样本数据为文本。
另外,该样本数据的针对每类任务的预测内容,即为样本数据输入多任务模型后,经过多任务模型处理所得到针对每类任务的处理结果。并且,多任务模型的具体模型结构,可以根据实际情况设定,本发明实施例对此不作限定。
可以理解的是,利用该样本数据的针对每类任务的预测内容,计算每类任务的损失值的处理思路,具体可以为:根据该样本数据的针对每类任务的预测内容与该样本数据的针对每类任务的标注内容的差异,来计算每类任务的损失值。任一种能够计算每类任务的损失值的具体实现方式,均可以应用于本发明实施例。
并且,本实施例中,为了避免人工多次设定权重参数所导致的计算资源浪费的问题,将每类任务对应的权重参数作为模型训练过程中的可自学习的参数,即在一次完整的训练过程中,每类任务对应的权重参数不再是固定值,而是可更新的值。而为了在权重参数作为可自学习的参数的前提下,保证多任务模型在较高精度下进行有效收敛,本实施例,设置了新的总损失计算函数。其中,该新的损失计算函数不但考虑了每类任务的损失值的加权和,同时,考虑每类任务对应的权重参数的修正值且该修正值与权重参数呈负相关关系,这样使得多任务模型的损失值的计算过程成为动态优化过程,为多任务模型达到损失值为期望损失值的收敛效果提供实现基础。
其中,所谓的修正值与权重参数呈负相关关系具体指:权重参数增大时,权重参数的修正值减小,而权重参数减小时,权重参数的修正值增大。并且,基于权重参数与权重参数的修正值的关系可知,通过函数表征时,求取权重参数的修正值的函数为以权重参数为自变量的函数。
需要说明的是,总损失计算函数中,各个任务对应的权重参数均通过人工设定有初始值,这样,在第一次计算多任务模型的损失值时,该总损失计算函数中的各个权重参数为初始值,而后续每次计算多任务模型的损失值时,该总损失计算函数中的各个权重参数为自学习得到的值。
另外,多任务模型的训练装置在调整该多任务模型的网络参数和每类任务对应的权重参数所采用的调整方式,可以为存在多种,例如:梯度调节方式或反向传播方式,当然并不局限于此。任一种在模型训练过程中能够调整网络参数的方式,均可以作为本发明实施例中关于该多任务模型的网络参数和每类任务对应的权重参数所利用的调整方式。
其中,该多任务模型的网络参数为模型训练过程所需学习的关于模型本身的参数,例如:该多任务模型的网络参数可以包括卷积核权重,全连接层权重等等。
本发明实施例所提供的方案中,在多任务模型训练时,各类任务的权重参数和多任务模型的网络参数均作为可自动学习的参数,使得在一次完整的模型训练过程中可以通过自动调整权重参数来获得最优的权重参数;同时,总损失计算函数考虑每类任务对应的权重参数的修正值,这样多任务模型的损失值的计算过程成为动态优化过程,为多任务模型达到损失值为期望损失值的收敛效果提供实现基础。因此,通过本方案,可以实现在保证多任务模型的精准度的同时,降低对计算资源的浪费的目的。
可选地,在一种实现方式中,所述总损失计算函数包括:
Figure BDA0002296768830000081
其中,Ltotal为所述多任务模型的损失值,Li为所述多任务模型所针对的任务i的损失值,αi为所述任务i对应的权重参数,n为所述多任务模型所针对任务的总数量,f(αi)为用于求取αi的修正值的函数,且f(αi)为函数值与αi呈负相关关系的函数。
需要强调的是,f(αi)为函数值与αi呈负相关关系的函数的前提下,本发明实施例对于f(αi)的具体函数形式不做限定。示例性的,所述f(αi)可以包括
Figure BDA0002296768830000082
其中,r为预设的底数,此时,
Figure BDA0002296768830000083
其中,r的具体取值,可以根据实际情况设定,例如:r=10,或者,r=5,等等。
可选地,为了降低计算量,在上述实施例的S103之前,即所述将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值之前,如图2所示,所述方法还包括:
S1030,对每类任务的损失值进行归一化处理;
相应的,S103的步骤,可以包括:
S1031,将归一化处理后的每类任务的损失值代入当前的总损失计算函数,得到该多任务模型的损失值。
为了方便理解方案,图4(a)示出了多任务模型所针对的任务为两类任务时,本发明实施例中多任务模型的训练过程的原理图。
示例性的,对每类任务的损失值进行归一化处理所利用的归一化分母,可以为预先根据经验设定的值,当然,也可以是根据实际损失所确定的值。
基于根据实际损失所确定的值的方式,相应的,所述对每类任务的损失值进行归一化处理,可以包括:
针对每类任务,计算该类任务在产生损失时能够产生的最大损失值;将该类任务的损失值除以该类任务的最大损失值,得到归一化处理后的该类任务的损失值。
并且,在存在归一化过程的方案中,示例性的,所述总损失计算函数可以为:
Figure BDA0002296768830000091
其中,Ltotal为该多任务模型的损失值,Li为该多任务模型所针对的任务i的损失值,αi为该任务i对应的权重参数,n为该多任务模型所针对任务的总数量,r为预设的底数。
下面结合多任务模型所针对的多任务包括图像中人体关键点的热度图的识别任务和图像中人体关键点的位移图的识别任务为例,对本发明实施例所提供的一种多任务模型的训练方法进行介绍。为了描述方便,下述将图像中人体关键点的热度图的识别任务,简称为热度图任务,而将图像中人体关键点的位移图的识别任务,简称为位移图任务。
其中,人体关键点的热度图为:人体关键点可能存在的位置的概率分布图。而人体关键点的位移图包括人体关键点的x轴方向的位移图和y轴方向的位移图,其中,人体关键点的x轴方向的位移图中每个点用于表征:x轴方向上,该点所在位置相对于目标点所在位置的偏移距离,人体关键点的y轴方向的位移图中每个点用于表征:y轴方向上,该点所在位置相对于目标点所在位置的偏移距离,目标点为人体关键点在位移图中的映射点。
并且,通过人体关键点的热度图和位移图,可以利用预定的计算方式,确定出人体关键点的位置坐标。由于通过人体关键点的热度图和位移图确定人体关键点的位置坐标不为本发明的发明点,因此,本发明实施例对于预定的计算方式不做限定。
如图3所示,本发明实施例所提供的一种多任务模型的训练方法,可以包括如下步骤:
S301,将样本图像输入训练中的多任务模型,得到该样本图像中各个人体关键点的预测热度图和预测位移图;
其中,样本图像中每一人体关键点对应一个预测热度图,同时,对应一个x轴方向的预测位移图和y轴方向的预测位移图。并且,对于多任务模型的具体结构,本发明实施例不做限定。
S302,利用所得到的各个预测热度图和各个预测位移图,计算热度图任务的损失值和位移图任务的损失值;
S303,对热度图任务的损失值进行归一化处理,得到归一化后的热度图任务的损失值,并将位移图任务的损失值进行归一化处理,得到归一化后的位移图任务的损失值;
S304,将归一化后的热度图任务的损失值以及归一化后的位移图任务的损失值,代入当前的总损失计算函数,得到该多任务模型的损失值;
S305,判断所得到的损失值是否为预设的期望损失值,如果是,结束训练,得到训练完成的多任务模型;否则,执行S306;
S306,调整该多任务模型的网络参数和每类任务对应的权重参数,利用调整后的权重参数重构该总损失计算函数,并返回S301。
针对步骤S302而言,示例性的,利用所得到的各个预测热度图和各个预测位移图,计算热度图任务的损失值和位移图任务的损失值,具体为:
针对每一预测热度图,基于该预测热度图与同一人体关键点的真值热度图的差异,计算该预测热度图的损失值;
针对每一x轴方向的预测位移图,基于该预测位移图与同一人体关键点的x轴方向的真值位移图的差异,计算x轴方向的预测位移图的损失值;
针对每一y轴方向的预测位移图,基于该预测位移图与同一人体关键点的y轴方向的真值位移图的差异,计算y轴方向的预测位移图的损失值;
对每一预测热度图的损失值进行求和,得到热度图任务的损失值;
对每一x轴方向的预测位移图的损失值以及每一y轴方向的预测位移图的损失值进行求和,得到位移图任务的损失值。
针对步骤S303而言,示例性的,对热度图任务的损失值进行归一化处理,得到归一化后的热度图任务的损失值,可以包括:
计算热度图任务在产生损失时能够产生的最大损失值;
将热度图任务的损失值除以热度图任务的最大损失值,得到归一化处理后的热度图任务的损失值。
其中,计算热度图任务在产生损失时能够产生的最大损失值的方式可以包括:
利用第一计算公式,计算热度图任务在产生损失时能够产生的最大损失值;
其中,第一计算公式为:
Figure BDA0002296768830000111
Figure BDA0002296768830000112
为热度图任务在产生损失时能够产生的最大损失值,n为预测热度图的个数,wh×hh为预测热度图的尺寸,
Figure BDA0002296768830000113
为n个预测热度图中的中心点位的最大值,
Figure BDA0002296768830000114
为n个预测热度图中的边缘点位的最小值。
任一预测热度图的中心点位为取值最大的像素点;任一预测热度图的边缘点位为除取值最大的像素点以外的像素点。
针对步骤S303而言,示例性的,对位移图任务的损失值进行归一化处理,得到归一化后的位移图任务的损失值,可以包括:
计算位移图任务在产生损失时能够产生的最大损失值;
将位移图任务的损失值除以位移图任务的最大损失值,得到归一化处理后的位移图任务的损失值。
其中,计算位移图任务在产生损失时能够产生的最大损失值的方式可以包括:
Figure BDA0002296768830000121
Figure BDA0002296768830000122
Figure BDA0002296768830000123
Figure BDA0002296768830000124
为位移图任务在产生损失时能够产生的最大损失值,
Figure BDA0002296768830000125
为位移图任务在产生损失时能够产生的x轴方向的最大损失值,
Figure BDA0002296768830000126
为位移图任务在产生损失时能够产生的y轴方向的最大损失值;n为每一方向上预测位移图的个数,wh×hh为预测位移图的尺寸。
针对步骤S304而言,示例性的,总损失计算函数可以为:
Figure BDA0002296768830000127
其中,Ltotal为多任务模型的损失值,α1为热度图任务对应的权重参数,Lh为热度图任务的损失值,α2为位移图任务对应的权重参数,Lo为位移图任务的损失值。
针对步骤S305而言,在多任务模型未收敛时,不但可以调整该多任务模型的网络参数,而且,可以调整热度图任务对应的权重参数以及位移图任务对应的权重参数,利用调整后的权重参数重构该总损失计算函数。需要说明的是,总损失计算函数中,热度图任务对应的权重参数和位移图任务对应的权重参数均通过人工设定有初始值,这样,在第一次计算多任务模型的损失值时,该总损失计算函数中的各个权重参数为初始值,而后续每次计算多任务模型的损失值时,该总损失计算函数中的各个权重参数为自学习得到的值。
为了方便理解方案,图4(b)示出了多任务模型所针对的任务为热度图任务和位移图任务时,本发明实施例中中多任务模型的训练过程的原理图。
本发明实施例所提供的方案中,在多任务模型训练时,热度图任务和位移图任务的权重参数,以及多任务模型的网络参数均作为可自动学习的参数,使得在一次完整的模型训练过程中可以通过自动调整权重参数来获得最优的权重参数权重参数;同时,总损失计算函数考虑每类任务对应的权重参数的修正值,这样多任务模型的损失值的计算过程成为动态优化过程,为多任务模型达到损失值为期望损失值的收敛效果提供实现基础。因此,通过本方案,可以实现在保证多任务模型的精准度的同时,降低对计算资源的浪费的目的。
相应于上述方法实施例,本发明实施例还提供了一种多任务模型的训练装置。如图5所示,所述装置可以包括:
预测内容确定模块510,用于将样本数据输入训练中的多任务模型,得到所述样本数据的针对每类任务的预测内容;
第一损失值计算模块520,用于利用所述样本数据的针对每类任务的预测内容,计算每类任务的损失值;
第二损失值计算模块530,用于将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值;其中,所述总损失值计算函数用于:对每类任务的损失值的加权损失以及每类任务对应的权重参数的修正值进行求和,每一权重参数的修正值与该权重参数呈负相关关系;
损失值分析模块540,用于当所得到的损失值为预设的期望损失值时,结束训练;否则,调整所述多任务模型的网络参数和每类任务对应的权重参数,利用调整后的权重参数重构所述总损失计算函数,并触发预测内容确定模块510。
本发明实施例所提供的方案中,在多任务模型训练时,各类任务的权重参数和多任务模型的网络参数均作为可自动学习的参数,使得在一次完整的模型训练过程中可以通过自动调整权重参数来获得最优的权重参数;同时,总损失计算函数考虑每类任务对应的权重参数的修正值,这样多任务模型的损失值的计算过程成为动态优化过程,为多任务模型达到损失值为期望损失值的收敛效果提供实现基础。因此,通过本方案,可以实现在保证多任务模型的精准度的同时,降低对计算资源的浪费的目的。
可选地,所述总损失计算函数包括:
Figure BDA0002296768830000141
其中,Ltotal为所述多任务模型的损失值,Li为所述多任务模型所针对的任务i的损失值,αi为所述任务i对应的权重参数,n为所述多任务模型所针对任务的总数量,f(αi)为用于求取αi的修正值的函数,且f(αi)为函数值与αi呈负相关关系的函数。
可选地,所述f(αi)包括
Figure BDA0002296768830000142
其中,r为预设的底数。
可选地,所述装置还包括:
归一化模块,用于在所述第二损失值计算模块将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值之前,对每类任务的损失值进行归一化处理;
所述第二损失值计算模块,具体用于:
将归一化处理后的每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值。
可选地,所述归一化模块,具体用于:
针对每类任务,计算该类任务在产生损失时能够产生的最大损失值;将该类任务的损失值除以该类任务的最大损失值,得到归一化处理后的该类任务的损失值。
本发明实施例还提供了一种电子设备,如图6所示,包括处理601、通信接口602、存储器603和通信总线604,其中,处理器601,通信接口602,存储器603通过通信总线604完成相互间的通信,
存储器603,用于存放计算机程序;
处理601,用于执行存储器603上所存放的程序时,实现本发明实施例所提供的多任务模型的训练方法的步骤。
上述终端提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,简称PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,简称EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口用于上述终端与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,简称RAM),也可以包括非易失性存储器(non-volatile memory),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,简称CPU)、网络处理器(Network Processor,简称NP)等;还可以是数字信号处理器(Digital Signal Processing,简称DSP)、专用集成电路(Application SpecificIntegrated Circuit,简称ASIC)、现场可编程门阵列(Field-Programmable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
在本发明提供的又一实施例中,还提供了一种计算机可读存储介质,该计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述实施例中任一所述的多任务模型的训练方法。
在本发明提供的又一实施例中,还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述实施例中任一所述的多任务模型的训练方法。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本发明实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘Solid State Disk(SSD))等。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置、设备、存储介质实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本发明的较佳实施例而已,并非用于限定本发明的保护范围。凡在本发明的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本发明的保护范围内。

Claims (12)

1.一种多任务模型的训练方法,其特征在于,所述方法包括:
将样本数据输入训练中的多任务模型,得到所述样本数据的针对每类任务的预测内容;
利用所述样本数据的针对每类任务的预测内容,计算每类任务的损失值;
将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值;其中,所述总损失值计算函数用于:对每类任务的损失值的加权损失以及每类任务对应的权重参数的修正值进行求和,每一权重参数的修正值与该权重参数呈负相关关系;
当所得到的损失值为预设的期望损失值时,结束训练;否则,调整所述多任务模型的网络参数和每类任务对应的权重参数,利用调整后的权重参数重构所述总损失计算函数,并返回所述将样本数据输入训练中的多任务模型的步骤。
2.根据权利要求1所述的方法,其特征在于,所述总损失计算函数包括:
Figure FDA0002296768820000011
其中,Ltotal为所述多任务模型的损失值,Li为所述多任务模型所针对的任务i的损失值,αi为所述任务i对应的权重参数,n为所述多任务模型所针对任务的总数量,f(αi)为用于求取αi的修正值的函数,且f(αi)为函数值与αi呈负相关关系的函数。
3.根据权利要求2所述的方法,其特征在于,所述f(αi)包括
Figure FDA0002296768820000012
其中,r为预设的底数。
4.根据权利要求1-3任一项所述的方法,其特征在于,所述将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值之前,所述方法还包括:
对每类任务的损失值进行归一化处理;
所述将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值,包括:
将归一化处理后的每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值。
5.根据权利要求4所述的方法,其特征在于,所述对每类任务的损失值进行归一化处理,包括:
针对每类任务,计算该类任务在产生损失时能够产生的最大损失值;
针对每类任务,将该类任务的损失值除以该类任务的最大损失值,得到归一化处理后的该类任务的损失值。
6.一种多任务模型的训练装置,其特征在于,所述装置包括:
预测内容确定模块,用于将样本数据输入训练中的多任务模型,得到所述样本数据的针对每类任务的预测内容;
第一损失值计算模块,用于利用所述样本数据的针对每类任务的预测内容,计算每类任务的损失值;
第二损失值计算模块,用于将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值;其中,所述总损失值计算函数用于:对每类任务的损失值的加权损失以及每类任务对应的权重参数的修正值进行求和,每一权重参数的修正值与该权重参数呈负相关关系;
损失值分析模块,用于当所得到的损失值为预设的期望损失值时,结束训练;否则,调整所述多任务模型的网络参数和每类任务对应的权重参数,利用调整后的权重参数重构所述总损失计算函数,并触发所述预测内容确定模块。
7.根据权利要求6所述的装置,其特征在于,所述总损失计算函数包括:
Figure FDA0002296768820000021
其中,Ltotal为所述多任务模型的损失值,Li为所述多任务模型所针对的任务i的损失值,αi为所述任务i对应的权重参数,n为所述多任务模型所针对任务的总数量,f(αi)为用于求取αi的修正值的函数,且f(αi)为函数值与αi呈负相关关系的函数。
8.根据权利要求7所述的装置,其特征在于,所述f(αi)包括
Figure FDA0002296768820000031
其中,r为预设的底数。
9.根据权利要求6-8任一项所述的装置,其特征在于,所述装置还包括:
归一化模块,用于在所述第二损失值计算模块将每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值之前,对每类任务的损失值进行归一化处理;
所述第二损失值计算模块,具体用于:
将归一化处理后的每类任务的损失值代入当前的总损失计算函数,得到所述多任务模型的损失值。
10.根据权利要求9所述的装置,其特征在于,所述归一化模块,具体用于:
针对每类任务,计算该类任务在产生损失时能够产生的最大损失值;将该类任务的损失值除以该类任务的最大损失值,得到归一化处理后的该类任务的损失值。
11.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现权利要求1-5任一所述的方法步骤。
12.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现权利要求1-5任一所述的方法步骤。
CN201911205175.5A 2019-11-29 2019-11-29 一种多任务模型的训练方法、装置及电子设备 Active CN111027428B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201911205175.5A CN111027428B (zh) 2019-11-29 2019-11-29 一种多任务模型的训练方法、装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201911205175.5A CN111027428B (zh) 2019-11-29 2019-11-29 一种多任务模型的训练方法、装置及电子设备

Publications (2)

Publication Number Publication Date
CN111027428A true CN111027428A (zh) 2020-04-17
CN111027428B CN111027428B (zh) 2024-03-08

Family

ID=70207377

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201911205175.5A Active CN111027428B (zh) 2019-11-29 2019-11-29 一种多任务模型的训练方法、装置及电子设备

Country Status (1)

Country Link
CN (1) CN111027428B (zh)

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112541124A (zh) * 2020-12-24 2021-03-23 北京百度网讯科技有限公司 生成多任务模型的方法、装置、设备、介质及程序产品
CN113435528A (zh) * 2021-07-06 2021-09-24 北京有竹居网络技术有限公司 对象分类的方法、装置、可读介质和电子设备
CN113516239A (zh) * 2021-04-16 2021-10-19 Oppo广东移动通信有限公司 模型训练方法、装置、存储介质及电子设备
CN114882464A (zh) * 2022-05-31 2022-08-09 小米汽车科技有限公司 多任务模型训练方法、多任务处理方法、装置及车辆
CN114913371A (zh) * 2022-05-10 2022-08-16 平安科技(深圳)有限公司 多任务学习模型训练方法、装置、电子设备及存储介质
CN115081630A (zh) * 2022-08-24 2022-09-20 北京百度网讯科技有限公司 多任务模型的训练方法、信息推荐方法、装置和设备

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106503669A (zh) * 2016-11-02 2017-03-15 重庆中科云丛科技有限公司 一种基于多任务深度学习网络的训练、识别方法及系统
CN109086660A (zh) * 2018-06-14 2018-12-25 深圳市博威创盛科技有限公司 多任务学习深度网络的训练方法、设备及存储介质

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106503669A (zh) * 2016-11-02 2017-03-15 重庆中科云丛科技有限公司 一种基于多任务深度学习网络的训练、识别方法及系统
CN109086660A (zh) * 2018-06-14 2018-12-25 深圳市博威创盛科技有限公司 多任务学习深度网络的训练方法、设备及存储介质

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112541124A (zh) * 2020-12-24 2021-03-23 北京百度网讯科技有限公司 生成多任务模型的方法、装置、设备、介质及程序产品
CN112541124B (zh) * 2020-12-24 2024-01-12 北京百度网讯科技有限公司 生成多任务模型的方法、装置、设备、介质及程序产品
CN113516239A (zh) * 2021-04-16 2021-10-19 Oppo广东移动通信有限公司 模型训练方法、装置、存储介质及电子设备
CN113435528A (zh) * 2021-07-06 2021-09-24 北京有竹居网络技术有限公司 对象分类的方法、装置、可读介质和电子设备
CN113435528B (zh) * 2021-07-06 2024-02-02 北京有竹居网络技术有限公司 对象分类的方法、装置、可读介质和电子设备
CN114913371A (zh) * 2022-05-10 2022-08-16 平安科技(深圳)有限公司 多任务学习模型训练方法、装置、电子设备及存储介质
CN114882464A (zh) * 2022-05-31 2022-08-09 小米汽车科技有限公司 多任务模型训练方法、多任务处理方法、装置及车辆
CN115081630A (zh) * 2022-08-24 2022-09-20 北京百度网讯科技有限公司 多任务模型的训练方法、信息推荐方法、装置和设备

Also Published As

Publication number Publication date
CN111027428B (zh) 2024-03-08

Similar Documents

Publication Publication Date Title
CN111027428B (zh) 一种多任务模型的训练方法、装置及电子设备
CN109784391B (zh) 基于多模型的样本标注方法及装置
CN110175278B (zh) 网络爬虫的检测方法及装置
CN109598414B (zh) 风险评估模型训练、风险评估方法、装置及电子设备
CN110909663B (zh) 一种人体关键点识别方法、装置及电子设备
JP2020523649A (ja) 処理対象のトランザクションに関するリスクを識別する方法、装置、及び電子機器
WO2022027913A1 (zh) 目标检测模型生成方法、装置、设备及存储介质
CN111027412B (zh) 一种人体关键点识别方法、装置及电子设备
CN110969100B (zh) 一种人体关键点识别方法、装置及电子设备
JP2018530093A (ja) クレジット点数モデルトレーニング方法、クレジット点数計算方法、装置及びサーバー
CN111340245A (zh) 一种模型训练方法及系统
CN113283388B (zh) 活体人脸检测模型的训练方法、装置、设备及存储介质
WO2021174814A1 (zh) 众包任务的答案验证方法、装置、计算机设备及存储介质
CN110941824B (zh) 一种基于对抗样本增强模型抗攻击能力的方法和系统
CN112434717B (zh) 一种模型训练方法及装置
CN117014507A (zh) 任务卸载模型的训练方法、任务的卸载方法及装置
CN116416052A (zh) 一种针对特定用户的授信方法及装置、电子设备及存储介质
CN113033542B (zh) 一种文本识别模型的生成方法以及装置
CN111046380B (zh) 一种基于对抗样本增强模型抗攻击能力的方法和系统
CN112926608A (zh) 一种图像分类方法、装置、电子设备及存储介质
CN113066486B (zh) 数据识别方法、装置、电子设备和计算机可读存储介质
CN117743568B (zh) 基于资源流量和置信度融合的内容生成方法和系统
CN117725231B (zh) 基于语义证据提示和置信度的内容生成方法和系统
CN112560709B (zh) 一种基于辅助学习的瞳孔检测方法及系统
CN110308905B (zh) 一种页面组件匹配方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant