CN114723989A - 多任务学习方法、装置及电子设备 - Google Patents
多任务学习方法、装置及电子设备 Download PDFInfo
- Publication number
- CN114723989A CN114723989A CN202210307497.6A CN202210307497A CN114723989A CN 114723989 A CN114723989 A CN 114723989A CN 202210307497 A CN202210307497 A CN 202210307497A CN 114723989 A CN114723989 A CN 114723989A
- Authority
- CN
- China
- Prior art keywords
- task
- data
- data set
- balance factor
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Probability & Statistics with Applications (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Medical Informatics (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例涉及图像处理技术领域,公开了一种多任务学习方法、装置及电子设备,该方法通过获取至少两个数据集,获取任务平衡因子和数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量,利用任务平衡因子和数据集平衡因子来确定每一任务在每一数据集中对应标签的采样率,以对每一数据集进行采样,得到每一任务用于训练的样本数据,进而进行多任务训练,得到多任务训练结果,本申请实施例能够解决在进行多任务单阶段训练时,存在多个数据集中不同任务的标签数量不同或者比例失衡或者数据分布不同而导致的标签不平衡的问题,从而平衡多任务学习中的不同任务的性能。
Description
技术领域
本申请实施例涉及图像处理技术领域,尤其涉及一种多任务学习方法、装置及电子设备。
背景技术
在自然语言处理、语音识别以及计算机视觉等多个领域,多任务学习(Multi-TaskLearning,MLT),通过利用不同学习任务的数据,同时学习多个子任务的共享信息,并从学习到的共享特征中预测多个目标来提高学习效率和预测准确性,进而获得更好的鲁棒性以及泛化能力,进而提升每个任务的性能并减少过拟合的风险。
一般的机器学习模型都是针对单一的特定任务,比如手写体数字识别、物体检测等。不同任务的模型都是在各自的数据集上单独学习得到的。而多任务用单一数据集要求该数据集有多个任务一一对应的标签,这样的数据集的标注成本高,且难以适应所有的多任务;另外这样的数据集容易存在数据量不足,导致过拟合的问题。因此,多任务学习通常采用多个数据集来补充不同任务的标签数量,并适应不同的多任务模型。
为了提高多个任务的性能,通常采用多阶段训练的方式来进行多任务学习,但是,这种训练方式的训练过程比较繁琐,导致训练效率不足,因此,为了提高训练效率,目前,通常采用单阶段的方式进行多任务学习,但是,如果在多个数据集上完成单阶段多任务训练,可能会存在以下情形:
(1)不同任务带标签的数据集的数据量不同,例如:有的任务包含丰富的带标签的图像数据,比如:目标检测和分割任务;而有的任务缺少带标签的图像数据,比如:行人属性分类任务。
(2)不同数据集中不同任务的标签数量不同或者比例失衡,有的数据集不包含所有任务的标签,例如:有的数据集有大量的深度图,但是没有分割标签,有的数据集有少量一一对应的深度图和分割标签。
(3)不同数据集的数据分布不同,例如:室内的数据集和室外的数据集的数据分布不同,比如:带深度图的一一对应的数据集一般是室内,使得室内的数据集的深度图标签的数量比较多,而室外的数据集的深度图标签的数量比较少。
由上所述,可以看出,如果直接将相关任务的数据集的图像数据都整合到一起进行训练,容易导致单阶段的多任务训练中不同任务的训练样本数量失衡,导致有的任务性能表现好而有的任务表现较差,从而使得任务学习不平衡。
目前的技术方案至少存在以下技术问题:
对于单阶段的多任务训练而言,直接输入多个数据集的图像数据存在不同数据集中不同任务的标签数量不同或者比例失衡或者数据分布不同的问题,导致不同任务的标签不平衡,即不同任务的训练样本数量失衡,导致有的任务性能表现好而有的任务表现较差,从而使得任务学习不平衡。
发明内容
本申请实施例提供一种多任务学习方法、装置及电子设备,以解决在进行多任务单阶段训练时,直接输入多个数据集的图像数据存在不同数据集中不同任务之间标签数量不同或者比例失衡或者数据分布不同而导致的标签不平衡的技术问题,使得不同任务之间标签平衡,从而平衡多任务学习中的不同任务的性能。
第一方面,本申请实施例提供一种多任务学习方法,该方法包括:
获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;
基于标签平衡采样机制,对至少两个数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;
其中,标签平衡采样机制包括:
获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量;
根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;
根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
在一些实施例中,根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率,包括:
假设一共有K个任务,n个数据集,则每一数据集中的每一任务对应的标签的采样率为:其中,为第k个任务对应的任务平衡因子,δi,k为第k个任务在第i个数据集中对应标签的数据集平衡因子,i∈{1,2…n},k∈{1,2…K}。
在一些实施例中,根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据,包括:
在至少两个数据集中,对每一任务在每一数据集中的对应标签,采用随机数生成算法生成随机数;
若随机数小于采样率,则将该标签加入到第一标签集合;
在遍历所有数据集中的所有标签之后,生成每一任务对应的第二标签集合,将第二标签集合中所有标签对应的数据确定为每一任务用于训练的样本数据。
在一些实施例中,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果,包括:
将采样后的所有数据作为训练的样本数据输入到第一模型中进行单阶段的多任务训练;
将每一任务对应的用于训练的样本数据的预测结果加入到预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合;
组合每一任务的预测结果集合,将组合后的预测结果集合确定为多任务预测结果。
在一些实施例中,第一模型包括至少两个任务分支,每一任务分支一一对应一个任务,将每一任务对应的用于训练的样本数据的预测结果加入到预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合,包括:
通过第一模型中的每一任务分支对与其对应的任务的每一样本数据进行预测,得到每一样本数据对应的预测结果,并将每一样本数据对应的预测结果加入到每一任务的预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合。
在一些实施例中,至少两个任务分支共用同一个主干网络,主干网络用于每一任务分支对每一样本进行预测,以得到每一样本对应的预测结果。
在一些实施例中,方法还包括:
根据每一任务对应的任务平衡因子,确定每一任务对应的标签数量,具体包括:
假设一共有K个任务,n个数据集,其中,n个数据集中的图像数量分别为{N1,N2…Nn},则每一任务对应的标签数量为:
在一些实施例中,方法还包括:
根据每一任务对应的任务平衡因子以及每一任务在每一数据集中的标签的数据集平衡因子,确定每一任务在每一数据集中的标签数量,具体包括:
假设一共有K个任务,n个数据集,其中,n个数据集中的图像数量分别为{N1,N2…Nn},则每一任务在每一数据集中的标签数量为:
第二方面,本申请实施例提供一种多任务学习装置,该装置包括:
数据集获取模块,用于获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;
预测结果确定模块,用于基于标签平衡采样机制,对至少两个数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;
其中,预测结果确定模块,包括:
平衡因子获取单元,用于获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量;
采样率确定单元,用于根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;
训练样本确定单元,用于根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
第三方面,本申请实施例提供一种电子设备,包括:
存储器以及一个或多个处理器,一个或多个处理器用于执行存储在存储器中的一个或多个计算机程序,一个或多个处理器在执行一个或多个计算机程序时,使得电子设备实现如第一方面的多任务学习方法。
第四方面,本申请实施例提供一种计算机可读存储介质,计算机可读存储介质存储有计算机程序,计算机程序包括程序指令,程序指令当被处理器执行时使处理器执行如第一方面的多任务学习方法。
本申请实施例的有益效果:区别于现有技术的情况,本申请实施例提供的一种多任务学习方法,包括:获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;基于标签平衡采样机制,对至少两个数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;其中,标签平衡采样机制包括:获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量;根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
通过获取至少两个数据集,获取任务平衡因子和数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量,利用任务平衡因子和数据集平衡因子来确定每一任务在每一数据集中对应标签的采样率,以对每一数据集进行采样,得到每一任务用于训练的样本数据,进而进行多任务训练,得到多任务训练结果,本申请实施例能够解决在进行多任务单阶段训练时,直接输入多个数据集的图像数据存在不同数据集中不同任务之间标签数量不同或者比例失衡或者数据分布不同而导致的标签不平衡的技术问题,使得不同任务之间标签平衡,从而平衡多任务学习中的不同任务的性能。
附图说明
一个或多个实施例通过与之对应的附图中的图片进行示例性说明,这些示例性说明并不构成对实施例的限定,附图中具有相同参考数字标号的元件表示为类似的元件,除非有特别申明,附图中的图不构成比例限制。
图1是本申请实施例提供的一种多任务学习方法的应用环境示意图;
图2是本申请实施例提供的一种多任务学习的示意图;
图3是本申请实施例提供的一种多任务学习方法的流程示意图;
图4是本申请实施例提供的一种多任务学习的框架示意图;
图5是本申请实施例提供的另一种多任务学习的流程示意图;
图6是图5中的步骤S504的细化流程图;
图7是图5中的步骤S505的细化流程图;
图8是本申请实施例提供的一种多任务学习的标签平衡采样机制的流程示意图;
图9是本申请实施例提供的一种多任务学习装置的结构示意图;
图10是本申请实施例提供的一种电子设备的硬件结构示意图。
具体实施方式
下面结合具体实施例对本申请进行详细说明。以下实施例将有助于本领域的技术人员进一步理解本申请,但不以任何形式限制本申请。应当指出的是,对本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进。这些都属于本申请的保护范围。
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本申请,并不用于限定本申请。
需要说明的是,如果不冲突,本申请实施例中的各个特征可以相互结合,均在本申请的保护范围之内。另外,虽然在装置示意图中进行了功能模块划分,在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于装置中的模块划分,或流程图中的顺序执行所示出或描述的步骤。此外,本文所采用的“第一”、“第二”、“第三”等字样并不对数据和执行次序进行限定,仅是对功能和作用基本相同的相同项或相似项进行区分。
除非另有定义,本说明书所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本说明书中在本申请的说明书中所使用的术语只是为了描述具体的实施方式的目的,不是用于限制本申请。本说明书所使用的术语“和/或”包括一个或多个相关的所列项目的任意的和所有的组合。
此外,下面所描述的本申请各个实施方式中所涉及到的技术特征只要彼此之间未构成冲突就可以相互组合。
在对本申请进行详细说明之前,对本申请实施例中涉及的名词和术语进行说明,本申请实施例中涉及的名词和术语适用于如下的解释:
(1)神经网络,也简称为神经网络(NNs)或称作连接模型(Connection Model),它是一种模仿动物神经网络行为特征,进行分布式并行信息处理的算法数学模型。神经网络依靠系统的复杂程度,通过调整内部大量节点之间相互连接的关系,从而达到处理信息的目的。具体的,神经网络可以是由神经单元组成的,具体可以理解为具有输入层、隐含层、输出层的神经网络,一般来说第一层是输入层,最后一层是输出层,中间的层数都是隐含层。其中,具有很多层隐含层的神经网络则称为深度神经网络(deep neural network,DNN)。神经网络中的每一层的工作可以用数学表达式y=a(W·x+b)来描述,从物理层面,神经网络中的每一层的工作可以理解为通过五种对输入空间(输入向量的集合)的操作,完成输入空间到输出空间的变换(即矩阵的行空间到列空间),这五种操作包括:1、升维/降维;2、放大/缩小;3、旋转;4、平移;5、“弯曲”。其中1、2、3的操作由“W·x”完成,4的操作由“+b”完成,5的操作则由“a()”来实现这里之所以用“空间”二字来表述是因为被分类的对象并不是单个事物,而是一类事物,空间是指这类事物所有个体的集合,其中,W是神经网络各层的权重矩阵,该矩阵中的每一个值表示该层的一个神经元的权重值。该矩阵W决定着上文的输入空间到输出空间的空间变换,即神经网络每一层的W控制着如何变换空间。训练神经网络的目的,也就是最终得到训练好的神经网络的所有层的权重矩阵。因此,神经网络的训练过程本质上就是学习控制空间变换的方式,更具体的就是学习权重矩阵。
需要注意的是,在本申请实施例中,多任务学习所采用的模型,本质都是神经网络。神经网络中的常用组件有卷积层、池化层、归一化层和反向卷积层等,通过组装神经网络中的这些常用组件,设计得到模型,当确定模型参数(各层的权重矩阵)使得模型误差满足预设条件或调整模型参数的数量达到预设阈值时,模型收敛。
其中,卷积层配置有多个卷积核、每个卷积核设置有对应的步长,以对图像进行卷积运算。卷积运算的目的是提取输入图像的不同特征,第一层卷积层可能只能提取一些低级的特征如边缘、线条和角等层级,更深的卷积层能从低级特征中迭代提取更复杂的特征。
反向卷积层用于将一个低维度的空间映射到高维度,同时保持他们之间的连接关系/模式(这里的连接关系即是指卷积时候的连接关系)。反向卷积层配置有多个卷积核、每个卷积核设置有对应的步长,以对图像进行反卷积运算。一般,用于设计神经网络的框架库(例如PyTorch库)中内置有upsumple()函数,通过调用该upsumple()函数可以实现低维度到高维度的空间映射。
池化层(pooling)是模仿人的视觉系统可以对数据进行降维或用更高层次的特征表示图像。池化层的常见操作包括最大值池化、均值池化、随机池化、中值池化和组合池化等。通常来说,神经网络的卷积层之间都会周期性插入池化层以实现降维。
归一化层用于对中间层的所有神经元进行归一化运算,以防止梯度爆炸和梯度消失。
(2)损失函数,指的是将随机事件或其有关随机变量的取值映射为非负实数以表示该随机事件的“风险”或“损失”的函数。损失函数是一个非负实数函数,用来量化模型预测的预测标签和真实标签之间的差异。在应用中,损失函数通常作为学习准则与优化问题相联系,即通过最小化损失函数求解和评估模型。例如在统计学和机器学习中被用于模型的参数估计(parametric estimation)。在训练神经网络的过程中,因为希望神经网络的输出尽可能的接近真正想要预测的值,可以通过比较当前网络的预测值和真正想要的目标值,再根据两者之间的差异情况来更新每一层神经网络的权重矩阵(然,在第一次更新之前通常会有初始化的过程,即为神经网络中的各层预先配置参数),比如,如果网络的预测值高了,就调整权重矩阵让它预测低一些,不断的调整,直到神经网络能够预测出真正想要的目标值。因此,就需要预先定义“如何比较预测值和目标值之间的差异”,这便是损失函数(loss function)或目标函数(objective function),它们是用于衡量预测值和目标值的差异的重要方程。其中,以损失函数举例,损失函数的输出值(loss)越高表示差异越大,那么神经网络的训练就变成了尽可能缩小输出值(loss)的过程。
下面结合说明书附图具体阐述本申请的技术方案。
请参阅图1,图1是本申请实施例提供的一种多任务学习方法的应用环境示意图;
如图1所示,该应用环境100包括:电子设备101和服务器102,该电子设备101和服务器102通过有线或无线通信方式进行通信。
其中,电子设备101可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表等,但并不局限于此。电子设备101中可以设有客户端,该客户端可以是视频客户端、浏览器客户端、线上购物客户端、即时通信客户端等,本申请对客户端的类型不加以限定。
电子设备101以及服务器102可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。电子设备101可以获取至少两个数据集,并对数据集中的图像数据进行预测,以得到多任务预测结果,其中,该数据集中的图像数据可以是电子设备101的存储器中存储的图像或者接收其他设备发送的图像,例如:服务器102发送的图像。
或者,电子设备101可以接收服务器102发送的至少两个数据集,并存储至少两个数据集,基于至少两个数据集进行多任务训练,以得到多任务预测结果。用户可以对电子设备中存储的图像进行浏览,将图像发送到服务器102,由服务器102对至少两个数据集进行多任务训练,以得到多任务预测结果,并将多任务预测结果102发送到电子设备102。可以理解的是,电子设备可以通过图像采集器件来获取图像数据,其中,该图像采集器件可以内置于电子设备101中,还可以外接于电子设备101,本申请对此不加以限定。
可以理解的是,电子设备101可以泛指多个电子设备中的一个,本申请实施例仅以电子设备101来举例说明。本领域技术人员可以知晓,上述电子设备的数量可以更多或更少。比如上述电子设备可以仅为一个,或者上述电子设备为几十个或几百个,或者更多数量,本申请实施例对电子设备的数量和设备类型不加以限定。
其中,服务器102可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。
服务器102以及电子设备101可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。服务器102可以维护有图像数据库,用于存储多个数据集,每一数据集包括多个图像。服务器102可以接收电子设备101发送的图像和/或多任务训练指令,并根据电子设备102发送的图像,生成至少两个数据集,根据多任务训练指令,对至少两个数据集进行多任务训练,以得到多任务预测结果,并将多任务预测结果发送给电子设备101。
可以理解的是,上述服务器102的数量可以更多或更少,本申请实施例对此不加以限定。当然,服务器102还可以包括其他功能服务器,以便提供更全面且多样化的服务。
请参阅图2,图2是本申请实施例提供的一种多任务学习的示意图;
如图2所示,多任务学习(MTL,Multi Task Learning),在多任务学习网络中,多个输入的特征被共同利用,得到共享参数,其中,共享参数可以用于每一个任务的训练,通过多个共享参数来得到每一任务对应的任务特定参数,例如:任务1、任务2、任务3和任务4分别利用共享参数进行任务学习,并且,每一个任务都对应特定的任务特定参数,用于进行对应任务的训练。可以理解的是,多个共享参数后向传播并行地作用于4个输出。由于4个输出共享底部的隐层,这些隐层中用于某个任务的特征表示也可以被其他任务利用,促使多个任务共同学习。多个任务并行训练并共享不同任务已学到的特征表示,是多任务学习的核心思想。
可以理解的是,多任务学习是一种归纳迁移方法,充分利用隐含在多个相关任务训练信号中的特定领域信息。在后向传播过程中,多任务学习允许共享隐层中专用于某个任务的特征被其他任务使用;多任务学习将可以学习到可适用于几个不同任务的特征,这样的特征在单任务学习网络中往往不容易学到。归纳迁移的目标是利用额外的信息来源来提高当前任务的学习性能,包括提高泛化准确率、学习速度和学习的模型的可理解性。提供更强的归纳偏向是迁移提高泛化能力的一种方法,可以在固定的训练集上产生更好的泛化能力,或者减少达到同等性能水平所需要的训练样本数量。归纳偏向会导致一个归纳学习器更偏好一些假设,多任务学习正是利用隐含在相关任务训练信号中的信息作为一个归纳偏向来提高泛化能力。
本申请实施例提供一种多任务学习方法,以平衡各个任务的标签数量,从而平衡多任务学习中的不同任务的性能。
具体的,请参阅图3,图3是本申请实施例提供的一种多任务学习方法的流程示意图;
其中,该多任务学习方法,应用于上述的电子设备,具体的,该多任务学习方法的执行主体为该电子设备的一个或多个处理器。
如图3所示,该多任务学习方法,包括:
步骤S301:获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;
具体的,电子设备通过获取至少两个数据集,其中,每一数据集均包括多个图像数据,即目标图像,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于N,N为任务的数量,例如:若任务的数量N为10,即有N个不同的任务需要进行训练,则每一数据集中标签种类的数量均不大于10种,比如:某一数据集中包括10种标签,而某一数据集中包括9种标签。为了避免其他不用于训练的标签造成的干扰,本申请实施例假定每一数据集中的标签均用于进行任务训练。
可以理解的是,一组样本构成的集合称为数据集(Data Set),标签(Label)指的是图像中需要预测的值,标签可以是离散值,也可以是连续值,例如:标签为年龄、是否是人类等。
请再参阅图4,图4是本申请实施例提供的一种多任务学习的框架示意图;
如图4所示,该多任务学习的框架包括:输入、模型、输出,其中,输入为至少两个数据集,例如:数据集1、数据集2,每一数据集均为图像数据集,包括多个图像数据,模型为深度学习模型,输出为图像数据的预测结果,例如:检测结果和/或分类结果。
具体的,至少两个数据集被输入到模型,即第一模型,该第一模型包括主干网络,该第一模型用于训练至少两个任务,分别对应至少两个任务分支,例如:该第一模型用于训练检测任务和属性分类任务,该检测任务对应检测分支,检测分支用于输出图像数据的检测结果该检测结果包括图像数据中的目标对象的检测框以及目标对象的分数;该属性分类任务对应属性分类分支,属性分类分支用于输出图像数据的分类结果,该分类结果中的目标对象的属性类别的分类分数;其中,与检测分支输出的检测结果对应的标签为:检测标签,即检测框的左上角坐标以及检测框的宽度和高度,该检测标签为4维变量,比如:4维变量表示为(x,y,w,h),x为检测框的左上角的横坐标,y为左上角的纵坐标,w为检测框的宽度,h为检测框的高度,与属性分类分支输出的分类结果对应的标签为类别标签,即类别ID,例如:[婴儿,儿童,成人,未知]4个类别对应的ID分别是[0,1,2,3],如果图像数据中的检测框是婴儿,则此时类别标签为0。
其中,属性分类分支对图像数据中的每一个像素都预测其分类分数,而属性分类分支输出的分类分数为检测分支输出的目标对象的检测框的中心像素对应的分类分数。具体的,检测分支输出左上角和右下角的坐标,通过几何运算即可求取目标对象的检测框的中心,可以理解的是,目标对象的检测框为矩形框,目标对象的检测框的中心即为矩形框的中心。
步骤S302:基于标签平衡采样机制,对至少两个数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果。
具体的,标签平衡采样机制包括:
获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量;根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
可以理解的是,标签平衡采样机制用于对电子设备获取到的至少两个数据集进行采样,使得输入第一模型中的图像数据中的每一任务的标签平衡。为了进一步说明标签平衡采样机制的工作原理,请再参阅图5,图5是本申请实施例提供的另一种多任务学习的流程示意图;
如图5所示,该多任务学习的流程,包括:
步骤S501:获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;
步骤S502:获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子;
具体的,任务平衡因子用于调节不同任务之间的标签数量,例如:在全部的数据集的标签中,第一任务对应的标签数量与第二任务对应的标签数量不同,因此,通过任务平衡因子来平衡第一任务和第二任务之间的标签数量,比如:全部的数据集中第一任务对应的标签数量为10000,而第二任务对应的标签数量为5000,此时,设置第一任务对应的任务平衡因子为0.1,第二任务对应的任务平衡因子为0.2,使得在采样过程中第一任务和第二任务对应的标签数量均为1000。
具体的,数据集平衡因子用于调节每一任务在不同数据集中的标签数量,例如:通过任务平衡因子确定在采样过程中第一任务对应的标签数量为1000,此时,数据集平衡因子用于将1000个标签分配到每一个数据集中,可以理解的是,每一数据集对应的标签数量不同,因此,每一数据集需要采样的标签数量也不同。假设每一数据集中均不包含噪声数据或噪声数据的比例相同,则此时每一数据集对应的数据集平衡因子相同,即每一数据集对应的数据集平衡因子均为1,其中,噪声数据指的是训练数据中存在标签错误(NoisyLabels)的图像数据。
在本申请实施例中,若不同数据集中包含的噪声数据的比例不同,此时不同数据集对应的数据集平衡因子不同,例如:第一数据集对应的噪声数据的比例为X,则第一数据集对应的数据集平衡因子为:1-X,其中,X∈[0,1]。可以理解的是,若每一数据集对应的数据集平衡因子为1-X,则此时对每一任务进行采样的标签数量小于由任务平衡因子确定的标签数量,此时,为了填补剩余的标签数量,可以从噪声数据的比例最小的若干个数据集中进行采样,例如:从噪声数据的比例最小的一个数据集中采样剩余的标签数量,或者,按照噪声数据的比例从小到大的顺序,对至少两个数据集进行排序,获取噪声数据的比例按照从小到大的顺序的前三名的数据集,并确定剩余的标签数量,从前三名的数据集中按照等比例或不等比例的方式进行采样,以确定每一数据集对应的标签数量,其中,噪声数据的比例越小的数据集,其采样的比例越大。在确定每一数据集对应的标签数量之后,根据每一数据集对应的标签数量,调整每一数据集对应的数据集平衡因子,例如:若某一任务对应的标签在所有数据集中的标签数量为1000,该任务对应的任务平衡因子为0.1,由此确定该任务在采样过程中对应的标签数量为1000,若某一数据集中该标签对应的标签数量为1000,而该任务在该数据集中采样的标签数量为120,则此时数据集平衡因子为1.2,即某一任务在某一数据集中采样的标签数量=某一数据集某一任务对应的标签数量*任务平衡因子*数据集平衡因子,比如:120=1000*0.1*1.2。
步骤S503:根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;
具体的,根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率,包括:
假设一共有K个任务,n个数据集,则每一数据集中的每一任务对应的标签的采样率为:其中,为第k个任务对应的任务平衡因子,δi,k为第k个任务在第i个数据集中对应标签的数据集平衡因子,i∈{1,2…n},k∈{1,2…K}。
步骤S504:根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据;
具体的,请再参阅图6,图6是图5中的步骤S504的细化流程图;
如图6所示,该步骤S504:根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据,包括:
步骤S5041:在至少两个数据集中,对每一任务在每一数据集中的对应标签,采用随机数生成算法生成随机数;
具体的,在至少两个数据集中,对于每一个任务在每一数据集中的每一个标签,采用随机数生成算法来生成一个随机数,其中,随机数的取值范围为[0,1]。在本申请实施例中,随机数的生成方式由随机数生成算法完成,其中,该随机数生成算法包括生成[0,1]之间均匀分布的随机数算法,例如:线性同余法(Linear Congruence Generator,LCG)、混合同余法等算法。
步骤S5042:若随机数小于采样率,则将该标签加入到第一标签集合;
具体的,随机数的取值范围为[0,1],采样率的取值范围为[0,1]。当随机数小于采样率时,则将该标签加入到第一标签集合。可以理解的是,当标签的数量越多时,采样的标签的数量将越接近于采样率对应的数量,即采样的标签的数量≈标签的总数量*采样率。
步骤S5043:在遍历所有数据集中的所有标签之后,生成每一任务对应的第二标签集合,将第二标签集合中所有标签对应的数据确定为每一任务用于训练的样本数据。
具体的,在遍历每一数据集中的所有标签,再遍历所有的数据集,由此遍历所有数据集中的所有标签,最终得到的第一标签集合作为第二标签集合,将第二标签集合中所有标签对应的图像数据作为每一任务用于训练的样本数据。
通过遍历所有数据集中的所有标签,以确定每一任务用于训练的样本数据,本申请实施例能够更好地确定每一任务对应的样本,有利于多任务的训练。
步骤S505:在对所有的数据集进行采样之后,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果。
具体的,请再参阅图7,图7是图5中的步骤S505的细化流程图;
如图7所示,该步骤S505:将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果,包括:
步骤S5051:将采样后的所有数据作为训练的样本数据输入到第一模型中进行单阶段的多任务训练;
在本申请实施例中,第一模型包括至少两个任务分支,每一任务分支一一对应一个任务,其中,至少两个任务分支共用一个主干网络。
请再参阅图4,如图4所示的多任务学习的框架为例,该框架包括两个任务分支,分别为检测分支和属性分类分支。
在多任务训练中,同时训练检测分支和属性分类分支需要一一对应的检测框和属性分类标签。然而,不同任务的训练数据量不同,导致多任务训练不平衡。例如:对于检测分支而言,分类为人的检测框数量比较多,但是这些数据集缺失属性的标签,因此无法训练多任务;对于属性分类分支而言,存在数据集同时包含检测框和属性类别的标签,但是数据量少,容易导致过拟合。因此,本申请实施例提出通过标签平衡采样机制来调节不同数据集、不同任务的标签比例,进而平衡多任务的训练。
如图4所示,以年龄阶段分类为例子,包括两个任务分支,分别为检测分支和年龄阶段分类分支,其中,检测分支对应检测任务,年龄阶段分类分支对应年龄阶段分类任务,例如:本申请实施例包括4个年龄阶段,分别为婴儿、儿童、成人以及年龄未知。假设数据集1只包含检测框标签的数据集,一共包含N1张图片,数据集2同时包含检测框标签和年龄阶段分类标签的数据集,一共包含N2张图片。
在数据集1和数据集2组成的混合数据集中,一共有N1+N2张图片用来训练检测分支,N2张图片训练年龄阶段分类分支。若N1>>N2,那么年龄阶段分类分支的标签数量将会远小于检测分支,导致年龄阶段分类分支过拟合。因此,本申请通过任务平衡因子表示不同任务之间的任务平衡因子,该任务平衡因子同时为检测分支以及年龄阶段分类分支的采样因子,其中,任务平衡因子的取值范围为[0,1]。
并且,在数据集1和数据集2组成的混合数据集中,同一个分支的数据来源于不同的数据集,且因为不同数据集的来源不同,为了提升模型的泛化能力需要针对不同分支调节数据集之间的标签比例。例如对于检测分支而言,希望尽可能检测到类别为人的框,并且不仅限于婴儿和儿童这些人类数据。本申请通过数据集平衡因子β1和β2表示数据集1和数据集2在检测标签上的比例,通过数据集平衡因子δ1和δ2表示数据集A和数据集B在年龄属性标签上的比例,为了实现各个标签的平衡,优选地,设置数据集平衡因子β1、β2、δ1和δ2均为1。
如图4所示,数据集1、数据集2中被采样的样本数据混合在一起输入到模型中,通过前向传播,模型的检测分支输出检测框以及对应类别的分数,属性分类分支则输出属性类别的分数;在每一次迭代中,标签平衡采样机制通过采样的方法确定该次迭代的批次数据中哪些标签能够计算损失函数,并且其对应的损失函数能够反向传导;用采样后的标签计算损失函数并反向传导更新参数。例如:模型的检测分支输出的对应类别的分数包括是否是人类的softmax分数,属性类别的分数包括每一个年龄阶段的softmax分数。可以理解的是,方向传导和样本的数量无关,即假设一个批次数据采样前是32个样本,采样后是24个样本,最终对需要反向传导的样本的损失函数求和,并计算梯度,因此样本数目不会影响反向传导。
通过采样确定每一任务的训练样本,从而使得任务的样本数量不同,假设采样前的预测结果集合和标签集合的样本数量为32,则采样后的预测结果集合和标签集合的样本数量小于等于32,例如:24个样本。
可以理解的是,每一任务均对应一个损失函数,每一任务均基于损失函数进行迭代训练,以更新每一任务的参数。而多任务学习的联合目标函数为所有任务损失函数的线性加权。其中,权重可以根据不同任务的重要程度来赋值,也可以根据任务的难易程度来赋值。优选地,本申请实施例中的所有任务设置相同的权重。
在实际的训练过程中,通过迭代训练,计算梯度,并更新每一任务的参数,例如:设置第一次数阈值,判断迭代次数是否大于第一次数阈值。具体的,本申请实施例采用随机梯度下降算法(stochastic gradient descent,SGD)来更新每一任务的参数,其中,迭代次数设置为24次,初始学习率设置为0.001,学习率更新策略为余弦退火的方式,最小学习率为e-5。
可以理解的是,余弦退火(Cosine annealing)可以通过余弦函数来降低学习率,余弦函数中随着x的增加余弦值首先缓慢下降,然后加速下降,再次缓慢下降。在余弦退火的方式下,下降模式能和学习率配合,以产生较好的更新效果。
可以理解的是,本申请实施例中的任务分支的数量不进行限定,可以包括两个以及两个以上的任务分支。
优选地,本申请实施例中的所有任务分支共用同一个主干网络,主干网络用于每一任务分支对每一样本进行预测,以得到每一样本对应的预测结果,以实现单阶段的多任务训练,提高训练效率。在本申请实施例中,主干网络即为特征提取器,用于提取特征,其中,主干网络包括神经网络,例如:ResNet,MobileNet、VGG等神经网络。
可以理解的是,多个任务共享同一个主干网络,该主干网络输出的特征为共享特征。用共享参数对应的共享网络输出的特征,来预测不同的任务的输出结果,也就是说,多个任务共享同一个特征提取器,并同时输出不同任务的结果,有利于实现多任务训练。
步骤S5052:将每一任务对应的用于训练的样本数据的预测结果加入到预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合;
具体的,将每一任务对应的用于训练的样本数据的预测结果加入到预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合,包括:
通过第一模型中的每一任务分支对与其对应的任务的每一样本数据进行预测,得到每一样本数据对应的预测结果,并将每一样本数据对应的预测结果加入到每一任务的预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合。
由于每一任务均对应一组用于训练的样本数据,而每一样本数据都对应一个预测结果,在得到每一样本数据对应的预测结果之后,将该预测结果加入到该任务对应的预测结果集合,在遍历该任务对应的所有的样本数据之后,从而得到最终的预测结果集合,最终的预测结果集合即为该任务对应的预测结果集合,以此类推,对下一个任务进行处理,得到至少两个任务中的每一任务的预测结果集合。
步骤S5053:组合每一任务的预测结果集合,将组合后的预测结果集合确定为多任务预测结果。
具体的,在遍历全部的样本数据之后,得到每一任务对应的预测结果集合,将每一任务对应的预测结果集合进行组合,将组合后的预测结果集合确定为多任务预测结果。
相比在单任务上训练的模型对在该任务上缺失标签的数据进行推理,并把生成的结果作为伪标签从而增加该任务的训练数据量,最终将在伪标签上预训练的多任务模型在各任务标签一一对应的数据上进行微调的方式而言,本申请实施例基于每一任务对应的训练样本,进行单阶段的多任务训练,由于采用的是单阶段的多任务训练,因此,在一个阶段能够完成模型训练,而不需要经历先粗略训练再微调的过程,因此,本申请能够提高多任务训练的效率。
请再参阅图8,图8是本申请实施例提供的一种多任务学习的标签平衡采样机制的流程示意图;
其中,该标签平衡采样机制通过采样的方式,对训练数据的标签进行抽取,以不同的采样率调节不同数据集、不同任务的标签量。
如图8所示,该多任务学习的标签平衡采样机制的流程,包括:
步骤S801:获取某一任务在某一数据集中的对应标签;
具体的,假设一共有K个任务,n个数据集,其中,n个数据集中的图像数量分别为{N1,N2…Nn},为第k个任务对应的任务平衡因子,Ni,k为第k个任务在第i个数据集中的标签数量,i∈{1,2…n},k∈{1,2…K}。
步骤S802:生成随机数;
具体的,假设随机数为c,其中,c的取值范围为[0,1]。
步骤S803:随机数是否小于任务平衡因子*数据集平衡因子;
步骤S804:是否遍历全部的数据集;
具体的,判断是否遍历全部的数据集,即当前的数据集是否为最后一个数据集,若是,则进入步骤S805;若否,则进入步骤S807;
步骤S805:是否遍历全部的任务;
具体的,判断是否遍历全部的任务,即当前的任务是否为最后一个任务,若是,则进入步骤S806;若否,则进入步骤S808;
步骤S806:生成最终的标签集合以及最终的预测结果集合;
具体的,若确定遍历全部的数据集和全部的任务,则将当前的标签集合作为最终的标签集合,将当前的预测结果集合作为最终的预测结果集合。
步骤S807:进入下一数据集;
步骤S808:进入下一任务;
步骤S809:将该标签加入对应的标签集合,以及,将该标签对应的图像数据的预测结果加入预测结果集合;
可以理解的是,通过采样率的方式进行采样,可能会出现采样误差,导致采样的样本数据的数量与预期的数据的数量出现误差,因此,在本申请实施例中,通过确定每一任务对应的标签数量来减少采样误差,具体的,根据每一任务对应的任务平衡因子,确定每一任务对应的标签数量,具体包括:
假设一共有K个任务,n个数据集,其中,n个数据集中的图像数量分别为{N1,N2…Nn},则每一任务对应的标签数量为:
例如:假设一共有2个任务,10个数据集,其中,每一个数据集中与第一个任务相关的标签数量均为1000,则10个数据集中与第一个任务相关的标签数量为10000,每一个数据集中与第二个任务相关的标签数量均为500,则与第二个任务相关的标签数量为5000,此时,为了平衡第一个任务和第二个任务之间的标签数量,因此,设置第一个任务对应的任务平衡因子为0.1,第二个任务对应的任务平衡因子为0.2,此时,可以确定第一个任务对应的标签数量为10000*0.1=1000,并且,确定第二个任务对应的标签数量为5000*0.2=1000。
通过确定每一个任务对应的标签数量,可以基于确定的标签数量来确定每一个任务对应的样本数据的数量,使得每一任务输入到第一模型中的用于训练的样本数据的数量相等,由于更好地平衡了不同任务的样本数据的数量,从而更好地平衡了多任务学习中的不同任务的性能。
可以理解的是,每一个任务对应的标签分布于不同的数据集中,为了更好地实现标签平衡,进一步地,该多任务学习方法通过确定每一任务在每一数据集中的标签数量,具体的,根据每一任务对应的任务平衡因子以及每一任务在每一数据集中的标签的数据集平衡因子,确定每一任务在每一数据集中的标签数量,具体包括:
假设一共有K个任务,n个数据集,其中,n个数据集中的图像数量分别为{N1,N2…Nn},则每一任务在每一数据集中的标签数量为:
例如:假设一共有2个任务,10个数据集,其中,每一个数据集中与第一个任务相关的标签数量均为1000,则10个数据集中与第一个任务相关的标签数量为10000,每一个数据集中与第二个任务相关的标签数量均为500,则与第二个任务相关的标签数量为5000,此时,为了平衡第一个任务和第二个任务之间的标签数量,因此,设置第一个任务对应的任务平衡因子为0.1,第二个任务对应的任务平衡因子为0.2,此时,可以确定第一个任务对应的标签数量为10000*0.1=1000,并且,确定第二个任务对应的标签数量为5000*0.2=1000,使得两个任务对应的标签数量相等。
此时,需要将第一个任务对应的1000个标签分配至每一数据集中,例如:第一个任务在第一个数据集中的标签的数据集平衡因子为1,此时,采样率=任务平衡因子*数据集平衡因子=0.1*1=0.1,则此时第一个任务在第一个数据集中的标签数量为1000*0.1=100,同理,第一个任务在第二个数据集中的标签的数据集平衡因子为1,在第二个数据集中的标签数量也为1000*0.1=100,…,在第十个数据集中的标签数量也为1000*0.1=100,以此可以得到第一个任务用于训练的样本数据。第二个任务对应的数据集的标签数量的确定与第一个任务类似,在此不再赘述。
通过确定每一个任务在每一个数据集中的标签数量,从而减少采样过程中的采样误差,并且,更好地将每一个任务对应的标签对应到每一个数据集,实现更好的标签分布,有利于更好地实现多任务学习过程中的标签平衡,从而平衡多任务学习中的不同任务的性能。
需要说明的是,本申请实施例中可以包括两个或超过两个的任务,在此不进行限定,均属于本申请实施例的保护范围。
在本申请实施例中,通过提供一种多任务学习方法,包括:获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;基于标签平衡采样机制,对至少两个数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;其中,标签平衡采样机制包括:获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量;根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
通过获取至少两个数据集,获取任务平衡因子和数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量,利用任务平衡因子和数据集平衡因子来确定每一任务在每一数据集中对应标签的采样率,以对每一数据集进行采样,得到每一任务用于训练的样本数据,进而进行多任务训练,得到多任务训练结果,本申请实施例能够解决在进行多任务单阶段训练时,直接输入多个数据集的图像数据存在不同数据集中不同任务之间标签数量不同或者比例失衡或者数据分布不同而导致的标签不平衡的技术问题,使得不同任务之间标签平衡,从而平衡多任务学习中的不同任务的性能。
请参阅图9,图9是本申请实施例提供的一种多任务学习装置的结构示意图;
其中,该多任务学习装置,应用于电子设备,具体的,该多任务学习装置应用于电子设备的一个或多个处理器。
如图9所示,该多任务学习装置90,包括:
数据集获取模块901,用于获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;
预测结果确定模块902,用于基于标签平衡采样机制,对至少两个数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;
其中,预测结果确定模块902,包括:
平衡因子获取单元9021,用于获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量;
采样率确定单元9022,用于根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;
训练样本确定单元9023,用于根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
在本申请实施例中,多任务学习装置亦可以由硬件器件搭建成的,例如,多任务学习装置可以由一个或两个以上的芯片搭建而成,各个芯片可以互相协调工作,以完成上述各个实施例所阐述的多任务学习方法。再例如,多任务学习装置还可以由各类逻辑器件搭建而成,诸如由通用处理器、数字信号处理器(DSP)、专用集成电路(ASIC)、现场可编程门阵列(FPGA)、单片机、ARM(Acorn RISC Machine)或其它可编程逻辑器件、分立门或晶体管逻辑、分立的硬件组件或者这些部件的任何组合而搭建成。
本申请实施例中的多任务学习装置可以是装置,也可以是终端中的部件、集成电路、或芯片。该装置可以是移动电子设备,也可以为非移动电子设备。示例性的,移动电子设备可以为手机、平板电脑、笔记本电脑、掌上电脑、车载电子设备、可穿戴设备、超级移动个人计算机(ultra-mobile personal computer,UMPC)、上网本或者个人数字助理(personaldigital assistant,PDA)等,非移动电子设备可以为服务器、网络附属存储器(NetworkAttached Storage,NAS)、个人计算机(personal computer,PC)、电视机(television,TV)、柜员机或者自助机等,本申请实施例不作具体限定。
本申请实施例中的多任务学习装置可以为具有操作系统的装置。该操作系统可以为安卓(Android)操作系统,可以为ios操作系统,还可以为其他可能的操作系统,本申请实施例不作具体限定。
本申请实施例提供的多任务学习装置能够实现图3实现的各个过程,为避免重复,这里不再赘述。
需要说明的是,上述多任务学习装置可执行本申请上述实施例所提供的多任务学习方法,具备执行方法相应的功能模块和有益效果。未在多任务学习装置实施例中详尽描述的技术细节,可参见本申请实施例所提供的多任务学习方法。
在本申请实施例中,通过提供一种多任务学习装置,包括:数据集获取模块,用于获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;预测结果确定模块,用于基于标签平衡采样机制,对至少两个数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;其中,预测结果确定模块,包括:平衡因子获取单元,用于获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量;采样率确定单元,用于根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;训练样本确定单元,用于根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
通过获取至少两个数据集,获取任务平衡因子和数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量,利用任务平衡因子和数据集平衡因子来确定每一任务在每一数据集中对应标签的采样率,以对每一数据集进行采样,得到每一任务用于训练的样本数据,进而进行多任务训练,得到多任务训练结果,本申请实施例能够解决在进行多任务单阶段训练时,直接输入多个数据集的图像数据存在不同数据集中不同任务之间标签数量不同或者比例失衡或者数据分布不同而导致的标签不平衡的技术问题,使得不同任务之间标签平衡,从而平衡多任务学习中的不同任务的性能。
本申请实施例还提供了一种电子设备,请参阅图10,图10是本申请实施例提供的一种电子设备的硬件结构示意图;
如图10所示,该电子设备10包括通信连接的至少一个处理器11和存储器12(图10中以总线连接、一个处理器为例)。
其中,处理器11用于提供计算和控制能力,以控制电子设备10执行相应任务,例如,控制电子设备10执行上述任一方法实施例中的多任务学习方法,包括:获取至少两个数据集,至少两个数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;基于标签平衡采样机制,对至少两个数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;其中,标签平衡采样机制包括:获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量;根据任务平衡因子以及数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;根据采样率,对至少两个数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
通过获取至少两个数据集,获取任务平衡因子和数据集平衡因子,其中,任务平衡因子用于调节不同任务之间的标签数量,数据集平衡因子用于调节每一任务在不同数据集中的标签数量,利用任务平衡因子和数据集平衡因子来确定每一任务在每一数据集中对应标签的采样率,以对每一数据集进行采样,得到每一任务用于训练的样本数据,进而进行多任务训练,得到多任务训练结果,本申请实施例能够解决在进行多任务单阶段训练时,直接输入多个数据集的图像数据存在不同数据集中不同任务之间标签数量不同或者比例失衡或者数据分布不同而导致的标签不平衡的技术问题,使得不同任务之间标签平衡,从而平衡多任务学习中的不同任务的性能。
处理器11可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(Network Processor,NP)、硬件芯片或者其任意组合;还可以是数字信号处理器(Digital Signal Processing,DSP)、专用集成电路(Application SpecificIntegrated Circuit,ASIC)、可编程逻辑器件(programmable logic device,PLD)或其组合。上述PLD可以是复杂可编程逻辑器件(complex programmable logic device,CPLD),现场可编程逻辑门阵列(field-programmable gate array,FPGA),通用阵列逻辑(genericarray logic,GAL)或其任意组合。
存储器12作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序、非暂态性计算机可执行程序以及模块,如本申请实施例中的多任务学习方法对应的程序指令/模块。处理器11通过运行存储在存储器12中的非暂态软件程序、指令以及模块,可以实现下述任一方法实施例中的多任务学习方法。具体地,存储器12可以包括易失性存储器(volatile memory,VM),例如随机存取存储器(random access memory,RAM);存储器12也可以包括非易失性存储器(non-volatile memory,NVM),例如只读存储器(read-onlymemory,ROM),快闪存储器(flash memory),硬盘(hard disk drive,HDD)或固态硬盘(solid-state drive,SSD)或其他非暂态固态存储器件;存储器12还可以包括上述种类的存储器的组合。
在本申请实施例中,存储器12还可以包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
在本申请实施例中,电子设备10还可以具有有线或无线网络接口、键盘以及输入输出接口等部件,以便进行输入输出,电子设备10还可以包括其他用于实现设备功能的部件,在此不做赘述。
本申请实施例还提供了一种计算机可读存储介质,例如包括程序代码的存储器,上述程序代码可由处理器执行以完成上述实施例中的多任务学习方法。例如,该计算机可读存储介质可以是只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random AccessMemory,RAM)、只读光盘(Compact Disc Read-Only Memory,CDROM)、磁带、软盘和光数据存储设备等。
本申请实施例还提供了一种计算机程序产品,该计算机程序产品包括一条或多条程序代码,该程序代码存储在计算机可读存储介质中。电子设备的处理器从计算机可读存储介质读取该程序代码,处理器执行该程序代码,以完成上述实施例中提供的多任务学习方法的方法步骤。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来程序代码相关的硬件完成,该程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
通过以上的实施方式的描述,本领域普通技术人员可以清楚地了解到各实施方式可借助软件加通用硬件平台的方式来实现,当然也可以通过硬件。本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程是可以通过计算机程序来指令相关的硬件来完成,程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)或随机存储记忆体(Random Access Memory,RAM)等。
最后应说明的是:以上实施例仅用以说明本申请的技术方案,而非对其限制;在本申请的思路下,以上实施例或者不同实施例中的技术特征之间也可以进行组合,步骤可以以任意顺序实现,并存在如上述的本申请的不同方面的许多其它变化,为了简明,它们没有在细节中提供;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的范围。
Claims (10)
1.一种多任务学习方法,其特征在于,所述方法包括:
获取至少两个数据集,至少两个所述数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;
基于标签平衡采样机制,对至少两个所述数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;
其中,所述标签平衡采样机制包括:
获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,所述任务平衡因子用于调节不同任务之间的标签数量,所述数据集平衡因子用于调节每一任务在不同数据集中的标签数量;
根据所述任务平衡因子以及所述数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;
根据所述采样率,对至少两个所述数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
3.根据权利要求1或2所述的方法,其特征在于,所述根据所述采样率,对至少两个所述数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据,包括:
在至少两个所述数据集中,对每一任务在每一数据集中的对应标签,采用随机数生成算法生成随机数;
若所述随机数小于所述采样率,则将该标签加入到第一标签集合;
在遍历所有数据集中的所有标签之后,生成每一任务对应的第二标签集合,将所述第二标签集合中所有标签对应的数据确定为每一任务用于训练的样本数据。
4.根据权利要求1所述的方法,其特征在于,所述将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果,包括:
将采样后的所有数据作为训练的样本数据输入到第一模型中进行单阶段的多任务训练;
将每一任务对应的用于训练的样本数据的预测结果加入到预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合;
组合每一任务的预测结果集合,将组合后的预测结果集合确定为多任务预测结果。
5.根据权利要求4所述的方法,其特征在于,所述第一模型包括至少两个任务分支,每一任务分支一一对应一个任务,所述将每一任务对应的用于训练的样本数据的预测结果加入到预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合,包括:
通过第一模型中的每一任务分支对与其对应的任务的每一样本数据进行预测,得到每一样本数据对应的预测结果,并将每一样本数据对应的预测结果加入到每一任务的预测结果集合,以同时得到至少两个任务中的每一任务的预测结果集合。
6.根据权利要求5所述的方法,其特征在于,至少两个任务分支共用同一个主干网络,所述主干网络用于每一任务分支对每一样本进行预测,以得到每一样本对应的预测结果。
9.一种多任务学习装置,其特征在于,所述装置包括:
数据集获取模块,用于获取至少两个数据集,至少两个所述数据集用于至少两个任务的训练,其中,每一任务对应一种标签,每一数据集包含的标签种类的数量不大于任务的数量;
预测结果确定模块,用于基于标签平衡采样机制,对至少两个所述数据集进行采样,将采样后的所有数据作为训练的样本数据输入到第一模型中进行多任务训练,以得到多任务预测结果;
其中,所述预测结果确定模块,包括:
平衡因子获取单元,用于获取每一任务对应的任务平衡因子以及每一任务在每一数据集中对应标签的数据集平衡因子,其中,所述任务平衡因子用于调节不同任务之间的标签数量,所述数据集平衡因子用于调节每一任务在不同数据集中的标签数量;
采样率确定单元,用于根据所述任务平衡因子以及所述数据集平衡因子,确定每一任务在每一数据集中对应标签的采样率;
训练样本确定单元,用于根据所述采样率,对至少两个所述数据集中的每一数据集进行采样,得到每一任务用于训练的样本数据。
10.一种电子设备,其特征在于,包括:
存储器以及一个或多个处理器,所述一个或多个处理器用于执行存储在所述存储器中的一个或多个计算机程序,所述一个或多个处理器在执行所述一个或多个计算机程序时,使得所述电子设备实现如权利要求1-8任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210307497.6A CN114723989A (zh) | 2022-03-25 | 2022-03-25 | 多任务学习方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210307497.6A CN114723989A (zh) | 2022-03-25 | 2022-03-25 | 多任务学习方法、装置及电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114723989A true CN114723989A (zh) | 2022-07-08 |
Family
ID=82239301
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210307497.6A Pending CN114723989A (zh) | 2022-03-25 | 2022-03-25 | 多任务学习方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114723989A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115690544A (zh) * | 2022-11-11 | 2023-02-03 | 北京百度网讯科技有限公司 | 多任务学习方法及装置、电子设备和介质 |
GB2625073A (en) * | 2022-12-02 | 2024-06-12 | Sony Interactive Entertainment Inc | System and method for training a machine learning model |
-
2022
- 2022-03-25 CN CN202210307497.6A patent/CN114723989A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115690544A (zh) * | 2022-11-11 | 2023-02-03 | 北京百度网讯科技有限公司 | 多任务学习方法及装置、电子设备和介质 |
CN115690544B (zh) * | 2022-11-11 | 2024-03-01 | 北京百度网讯科技有限公司 | 多任务学习方法及装置、电子设备和介质 |
GB2625073A (en) * | 2022-12-02 | 2024-06-12 | Sony Interactive Entertainment Inc | System and method for training a machine learning model |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109583501B (zh) | 图片分类、分类识别模型的生成方法、装置、设备及介质 | |
CN109902706B (zh) | 推荐方法及装置 | |
WO2022083536A1 (zh) | 一种神经网络构建方法以及装置 | |
CN111279362B (zh) | 胶囊神经网络 | |
US11151417B2 (en) | Method of and system for generating training images for instance segmentation machine learning algorithm | |
WO2022042713A1 (zh) | 一种用于计算设备的深度学习训练方法和装置 | |
JP2021528796A (ja) | 活性スパース化を用いたニューラルネットワーク加速・埋め込み圧縮システム及び方法 | |
WO2022068623A1 (zh) | 一种模型训练方法及相关设备 | |
WO2022001805A1 (zh) | 一种神经网络蒸馏方法及装置 | |
US20230095606A1 (en) | Method for training classifier, and data processing method, system, and device | |
CN111352965B (zh) | 序列挖掘模型的训练方法、序列数据的处理方法及设备 | |
CN113705769A (zh) | 一种神经网络训练方法以及装置 | |
WO2021218470A1 (zh) | 一种神经网络优化方法以及装置 | |
WO2022111617A1 (zh) | 一种模型训练方法及装置 | |
CN114723989A (zh) | 多任务学习方法、装置及电子设备 | |
CN114997412A (zh) | 一种推荐方法、训练方法以及装置 | |
CN116664719B (zh) | 一种图像重绘模型训练方法、图像重绘方法及装置 | |
CN113392210A (zh) | 文本分类方法、装置、电子设备及存储介质 | |
WO2021042857A1 (zh) | 图像分割模型的处理方法和处理装置 | |
CN111368656A (zh) | 一种视频内容描述方法和视频内容描述装置 | |
CN114266897A (zh) | 痘痘类别的预测方法、装置、电子设备及存储介质 | |
WO2022156475A1 (zh) | 神经网络模型的训练方法、数据处理方法及装置 | |
CN111738403A (zh) | 一种神经网络的优化方法及相关设备 | |
WO2023185925A1 (zh) | 一种数据处理方法及相关装置 | |
CN115238909A (zh) | 一种基于联邦学习的数据价值评估方法及其相关设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |