CN111353541A - 一种多任务模型的训练方法 - Google Patents

一种多任务模型的训练方法 Download PDF

Info

Publication number
CN111353541A
CN111353541A CN202010138967.1A CN202010138967A CN111353541A CN 111353541 A CN111353541 A CN 111353541A CN 202010138967 A CN202010138967 A CN 202010138967A CN 111353541 A CN111353541 A CN 111353541A
Authority
CN
China
Prior art keywords
loss function
model
data
sample data
task
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010138967.1A
Other languages
English (en)
Inventor
张奎
陈清梁
王超
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Xinzailing Technology Co ltd
Original Assignee
Zhejiang Xinzailing Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Xinzailing Technology Co ltd filed Critical Zhejiang Xinzailing Technology Co ltd
Priority to CN202010138967.1A priority Critical patent/CN111353541A/zh
Publication of CN111353541A publication Critical patent/CN111353541A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V40/00Recognition of biometric, human-related or animal-related patterns in image or video data
    • G06V40/10Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Human Computer Interaction (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及一种多任务模型的训练方法,包括:S1.抽取多个数据集中的样本数据,其中,每个所述数据集中的所述样本数据均为单一任务的属性;S2.采用抽取的所述样本数据对神经网络模型进行训练,得到多任务模型,其中,所述神经网络模型为用于分类的神经网络模型;S3.基于所述多任务模型的输出结果获取所述多任务模型的总损失函数。本发明提出的多任务模型训练方法,对于数据集要求较低,允许每张图像仅包含一个任务的标签(即单一属性),对于特征接近的任务,比如人体属性等,能够在一个模型上进行合并训练。具有合并周期短,且保证精度不变的情况下,极大程度上提升计算效率。

Description

一种多任务模型的训练方法
技术领域
本发明涉及计算机机器学习技术领域,尤其涉及一种多任务模型的训练方法。
背景技术
深度学习作为机器学习领域中一个新的研究方向,目前已经在图像识别,语音识别,自然语言处理等相关领域都取得很多成果。但是由于深度学习模型计算复杂,效率低,而一般生产环境都有明确的性能指标,还有空间要求,比如内存等资源有限。如果对于一些相近的任务,比如人体属性中的性别、服饰类别等进行估计,往往都各自使用一个模型,无疑增加了计算量和资源占用。
近几年,多任务模型发展迅速。但是由于这些任务的训练数据往往都是独立的,每张图像不可能具有所有的属性标签,而数据标注的成本十分巨大,对于数据补标注不太现实。
例如,中国专利申请号CN201710603212,名称为“针对人体属性分类的自适应权重调整的模型训练方法”的方案,虽然也是一种多任务的模型训练方法,但是要求每张图像必须包含每个任务的标注信息,包括人脸属性和人体属性。主要创新点为引入了一个基于验证误差大小及变化趋势从而更新相应任务权重的算法,在训练过程中自适应动态地调整每个任务的相应权重值。从而取得较好的性能。但是也存在一定的缺点:
1)每张图像必须具有完整的人脸和人体属性,如果一开始这两个任务是分开标注的,需要补全标签,当数据集规模较大时成本较高;
2)对于人脸和人体,实际场景中,存在人体不完整,或者只有人体看不到人脸的情况,也就是说并不是每张图像都能有完整的标签。
发明内容
本发明的目的在于提供一种多任务模型的训练方法,对于数据集中的样本不需要完整的标签,资源占用低,计算效率高。
为实现上述发明目的,本发明提供一种多任务模型的训练方法,包括:
S1.抽取多个数据集中的样本数据,其中,每个所述数据集中的所述样本数据均为单一任务的属性;
S2.采用抽取的所述样本数据对神经网络模型进行训练,得到多任务模型,其中,所述神经网络模型为用于分类的神经网络模型;
S3.基于所述多任务模型的输出结果获取所述多任务模型的总损失函数。
根据本发明的一个方面,还包括:
S4.重复执行步骤S1-S3,根据所述多任务模型的输出结果计算所述总损失函数,并根据所述总损失函数对所述多任务模型的模型参数进行优化。
根据本发明的一个方面,在步骤S1中,以预设规则分别随机排列各所述数据集中的所述样本数据后抽取各所述数据集中的所述样本数据。
根据本发明的一个方面,步骤S1中,抽取多个数据集中的样本数据的步骤中,按照各所述数据集之间的比例抽取。
根据本发明的一个方面,步骤S1中,每个所述数据集之间的数据重叠率小于10%。
根据本发明的一个方面,步骤S3中,所述多任务模型的输出结果中包括与所述数据集相对应的子任务损失函数。
根据本发明的一个方面,步骤S3中,基于所述多任务模型的输出结果获取所述多任务模型的总损失函数的步骤中,所述总损失函数表示为:
L=wgLg+wcLc+woLo+…
其中,L表示总损失函数,Lg、Lc、Lo分别表示各子任务损失函数,wg、wc、wo分别表示所述子任务损失函数的权重。
根据本发明的一个方面,步骤S4中,重复执行步骤S1-S3,根据所述多任务模型的输出结果计算所述总损失函数,并根据所述总损失函数对所述多任务模型的模型参数进行优化的步骤中,包括:
S41.重复执行步骤S1-S3,获取多个所述多任务模型的输出结果;
S42.根据多个所述输出结果分别获取各子任务损失函数;
S43.根据各所述子任务损失函数对所述子任务损失函数的权重进行优化并更新所述总损失函数中的子任务损失函数的权重。
根据本发明的一种方案,本发明提出的多任务模型训练方法,对于数据集要求较低,允许每张图像仅包含一个任务的标签(即单一属性),对于特征接近的任务,比如人体属性等,能够在一个模型上进行合并训练。具有合并周期短,且保证精度不变的情况下,极大程度上提升计算效率。
附图说明
图1示意性表示根据本发明的一种实施方式的多任务模型的训练方法流程框图;
图2示意性表示根据本发明的一种实施方式的系统整体框图;
图3示意性表示根据本发明的一种实施方式的多数据集样本数据输入流程图;
图4示意性表示根据本发明的一种实施方式的神经网络模型的结构图;
图5示意性表示根据本发明的一种实施方式的神经网络模型的输出结果结构图。
具体实施方式
为了更清楚地说明本发明实施方式或现有技术中的技术方案,下面将对实施方式中所需要使用的附图作简单地介绍。显而易见地,下面描述中的附图仅仅是本发明的一些实施方式,对于本领域普通技术人员而言,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
在针对本发明的实施方式进行描述时,术语“纵向”、“横向”、“上”、“下”、“前”、“后”、“左”、“右”、“竖直”、“水平”、“顶”、“底”“内”、“外”所表达的方位或位置关系是基于相关附图所示的方位或位置关系,其仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此上述术语不能理解为对本发明的限制。
下面结合附图和具体实施方式对本发明作详细地描述,实施方式不能在此一一赘述,但本发明的实施方式并不因此限定于以下实施方式。
如图1所示,根据本发明的一种实施方式,本发明的一种多任务模型的训练方法,包括:
S1.抽取多个数据集中的样本数据,其中,每个数据集中的样本数据均为单一任务的属性;
S2.采用抽取的样本数据对神经网络模型进行训练,得到多任务模型,其中,神经网络模型为用于分类的神经网络模型;
S3.基于多任务模型的输出结果获取多任务模型的总损失函数。
根据本发明的一种实施方式,本发明的一种多任务模型的训练方法,还包括:
S4.重复执行步骤S1-S3,根据所述多任务模型的输出结果计算所述总损失函数,并根据所述总损失函数对所述多任务模型的模型参数进行优化。
根据本发明的一种实施方式,在步骤S1中,以预设规则分别随机排列各数据集中的样本数据后抽取各数据集中的样本数据。
根据本发明的一种实施方式,步骤S1中,抽取多个数据集中的样本数据的步骤中,按照各数据集所包含的样本数据量之间的比例抽取。例如,一个batch取32张图像,数据集1有100张样本数据,数据集2有300张样本数据,则数据集比例1:3,那这个batch的32张图像从数据集1中取8张,数据集2中取24张。通过抽取多个数据集中的样本数据实现后续训练过程的每个任务需要的数据输入。
通过上述设置,保证了数据集中存在多种类别时依然能够被充分训练,避免了遗漏对保证多任务模型的训练结果精度有利。
根据本发明的一种实施方式,步骤S1中,每个数据集之间的数据重叠率小于10%,即数据集中图像(即样本数据)包含所有属性的图像(即样本数据)占总图像(即样本数据)数量的比例不足10%。
根据本发明的一种实施方式,神经网络模型为用于分类的神经网络模型,其可以为vgg,resnet等任何可以用于分类的神经网络模型。在本实施方式中,该神经网络模型的全连接FC层输出为多个输出单元,每个单元与数据集中的细分类别相对应,例如,全连接FC层具有10个输出单元,则其中第一个单元用于表示性别,第2-6个单元用于表示服饰风格(即表示服饰属性的数据集中标注了5种风格),第7-10个单元用于表示身体朝向(即表示人体朝向属性的数据集中标注了前后左右4个方向)。
根据本发明的一种实施方式,步骤S3中,多任务模型的输出结果中包括与数据集相对应的子任务损失函数。例如,数据集共有表示性别属性、服饰属性、人体朝向属性三种,则子任务损失函数也相对应的生成三种。
根据本发明的一种实施方式,步骤S3中,基于多任务模型的输出结果获取多任务模型的总损失函数的步骤中,总损失函数表示为:
L=wgLg+wcLc+woLo+…
其中,L表示总损失函数,Lg、Lc、Lo分别表示各子任务损失函数,wg、wc、wo分别表示子任务损失函数的权重。
根据本发明的一种实施方式,步骤S4中,在重复执行步骤S1时,以预设规则分别随机排列各数据中的样本数据后抽取各数据集中的样本数据。
根据本发明的一种实施方式,步骤S4中,重复执行步骤S1-S3,根据多任务模型的输出结果对总损失函数进行优化的步骤中,包括:
S41.重复执行步骤S1-S3,获取多个多任务模型的输出结果;
S42.根据多个输出结果分别获取各子任务损失函数;
S43.根据各子任务损失函数的均值对子任务损失函数的权重进行优化并更新总损失函数中的子任务损失函数的权重。
通过上述设置,通过不断的对总损失函数的优化更新使得本发明的多任务模型的判断准确率能够保证在最佳状态,提高了判断效率。
如下以3个数据集为例对本发明的方法做进一步说明,在本实施例中,3个数据集分别表示性别属性、服饰属性和人体朝向属性。
如图2所示,根据本发明的一种实施方式,用于本发明的多任务模型的训练方法的系统整体包括数据集、待训练神经网络模型和损失函数。
如图3所示,根据本发明的一种实施方式,在执行步骤S1的过程中,包括以下步骤:
S11.获取三个数据集中的标注列表;
S12.通过预设的规则分别随机排列各数据集中的样本数据,使每个数据集中的标注列表的顺序被打乱;
S13.按照数据集之间的比例分别从3个列表中抽取样本数据。
如图3所示,根据本发明的一种实施方式,步骤S4中,在重复执行步骤S1的过程中,步骤S1中还包括:
S14.判断列表中数据是否全部抽取完,若是,则重新执行步骤S11-S13,否则重新执行步骤S13。
在此需要指出的是,在本方案中,多任务模型在训练时,是按照一轮接着一轮的方式进行,图3中在判断是否取将列表(即数据集)中数据全部取完的步骤中,若已取完则表示数据集中所有数据已被取了一遍,即一轮训练完成,进入下一轮训练。
根据本发明的一种实施方式,在步骤S13中抽取的样本数据被输入至神经网络模型进行训练。在本实施方式中,以神经网络模型resnet18为例(参见图4),在本方案中,其中全连接FC层输出为10个单元,即N=10,第一个单元用于表示性别,第2-6个单元用于表示服饰风格(本方案数据集中标注了5种风格),第7-10个单元用于表示身体朝向(前后左右)。需要指出的是,损失函数是指,用于模型训练中的损失函数,即计算模型估计值与真值之间的度量公式。
如图5所示,全连接FC层的输出单元可分为三个部分,分别与数据集相对应,)其中,Lg为性别子任务损失函数,其对应输出为FC的第一个单元,使用sigmoid激活函数,由于其是二分类,损失函数使用二值交叉熵损失函数;Lc为服饰子任务损失函数,其对应FC的第2-6个输出单元,由于其是单标签多类别,使用softmax激活,使用交叉熵损失函数;Lo为人体朝向子任务损失函数,其对应FC的第7-10个单元,由于其是单标签多类别,使用softmax激活,使用交叉熵损失函数。
根据本发明的一种实施方式,步骤S3中输出的总损失函数表示为:
L=wgLg+wcLc+woLo
其中wg、wc、wo为3个子任务损失函数的权重,初始值分别为0.3、0.35、0.35。
根据本发明的一种实施方式,步骤S4中,重复步骤S1-S3后实现每一轮训练,进而在每一轮训练完成后,根据该轮三个子任务每个子任务损失函数的均值Lgm,Lcm,Lom调整总损失函数中的权重值,即新的权重wg,wc,wo=softmax(Lgm,Lcm,Lom)。
实际测试中,使用本发明的方法训练的模型,其精度相对于原有模型的损失几乎可以忽略不记,但是对于一张图像,计算这3个不同的属性,计算耗时仅为原有3个模型分别估计的40%左右,GPU资源占用也降低了一半以上。
上述内容仅为本发明的具体方案的例子,对于其中未详尽描述的设备和结构,应当理解为采取本领域已有的通用设备及通用方法来予以实施。
以上所述仅为本发明的一个方案而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (8)

1.一种多任务模型的训练方法,包括:
S1.抽取多个数据集中的样本数据,其中,每个所述数据集中的所述样本数据均为单一任务的属性;
S2.采用抽取的所述样本数据对神经网络模型进行训练,得到多任务模型,其中,所述神经网络模型为用于分类的神经网络模型;
S3.基于所述多任务模型的输出结果获取所述多任务模型的总损失函数。
2.根据权利要求1所述的训练方法,其特征在于,还包括:
S4.重复执行步骤S1-S3,根据所述多任务模型的输出结果计算所述总损失函数,并根据所述总损失函数对所述多任务模型的模型参数进行优化。
3.根据权利要求2所述的训练方法,其特征在于,在步骤S1中,以预设规则分别随机排列各所述数据集中的所述样本数据后抽取各所述数据集中的所述样本数据。
4.根据权利要求1至3所述的训练方法,其特征在于,步骤S1中,抽取多个数据集中的样本数据的步骤中,按照各所述数据集所包含的所述样本数据量之间的比例抽取。
5.根据权利要求4所述的训练方法,其特征在于,步骤S1中,每个所述数据集之间的数据重叠率小于10%。
6.根据权利要求5所述的训练方法,其特征在于,步骤S3中,所述多任务模型的输出结果中包括与所述数据集相对应的子任务损失函数。
7.根据权利要求6所述的训练方法,其特征在于,步骤S3中,基于所述多任务模型的输出结果获取所述多任务模型的总损失函数的步骤中,所述总损失函数表示为:
L=wgLg+wcLc+woLo+…
其中,L表示总损失函数,Lg、Lc、Lo分别表示各子任务损失函数,wg、wc、wo分别表示所述子任务损失函数的权重。
8.根据权利要求7所述的训练方法,其特征在于,步骤S4中,重复执行步骤S1-S3,根据所述多任务模型的输出结果计算所述总损失函数,并根据所述总损失函数对所述多任务模型的模型参数进行优化的步骤中,包括:
S41.重复执行步骤S1-S3,获取多个所述多任务模型的输出结果;
S42.根据多个所述输出结果分别获取各子任务损失函数;
S43.根据各所述子任务损失函数的均值对所述子任务损失函数的权重进行优化并更新所述总损失函数中的子任务损失函数的权重。
CN202010138967.1A 2020-03-03 2020-03-03 一种多任务模型的训练方法 Pending CN111353541A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010138967.1A CN111353541A (zh) 2020-03-03 2020-03-03 一种多任务模型的训练方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010138967.1A CN111353541A (zh) 2020-03-03 2020-03-03 一种多任务模型的训练方法

Publications (1)

Publication Number Publication Date
CN111353541A true CN111353541A (zh) 2020-06-30

Family

ID=71197249

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010138967.1A Pending CN111353541A (zh) 2020-03-03 2020-03-03 一种多任务模型的训练方法

Country Status (1)

Country Link
CN (1) CN111353541A (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111737640A (zh) * 2020-08-17 2020-10-02 深圳江行联加智能科技有限公司 水位预测方法、装置及计算机可读存储介质
CN113516239A (zh) * 2021-04-16 2021-10-19 Oppo广东移动通信有限公司 模型训练方法、装置、存储介质及电子设备
CN114898180A (zh) * 2022-05-12 2022-08-12 深圳市慧鲤科技有限公司 多任务神经网络的训练方法、多任务处理方法及装置

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109815826A (zh) * 2018-12-28 2019-05-28 新大陆数字技术股份有限公司 人脸属性模型的生成方法及装置
CN110188673A (zh) * 2019-05-29 2019-08-30 京东方科技集团股份有限公司 表情识别方法和装置

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109815826A (zh) * 2018-12-28 2019-05-28 新大陆数字技术股份有限公司 人脸属性模型的生成方法及装置
CN110188673A (zh) * 2019-05-29 2019-08-30 京东方科技集团股份有限公司 表情识别方法和装置

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111737640A (zh) * 2020-08-17 2020-10-02 深圳江行联加智能科技有限公司 水位预测方法、装置及计算机可读存储介质
CN111737640B (zh) * 2020-08-17 2021-08-27 深圳江行联加智能科技有限公司 水位预测方法、装置及计算机可读存储介质
CN113516239A (zh) * 2021-04-16 2021-10-19 Oppo广东移动通信有限公司 模型训练方法、装置、存储介质及电子设备
CN114898180A (zh) * 2022-05-12 2022-08-12 深圳市慧鲤科技有限公司 多任务神经网络的训练方法、多任务处理方法及装置

Similar Documents

Publication Publication Date Title
CN111639710B (zh) 图像识别模型训练方法、装置、设备以及存储介质
CN113326764B (zh) 训练图像识别模型和图像识别的方法和装置
CN109271521B (zh) 一种文本分类方法及装置
CN110659725B (zh) 神经网络模型的压缩与加速方法、数据处理方法及装置
CN109471945B (zh) 基于深度学习的医疗文本分类方法、装置及存储介质
CN112613581B (zh) 一种图像识别方法、系统、计算机设备和存储介质
CN109993102B (zh) 相似人脸检索方法、装置及存储介质
CN111353541A (zh) 一种多任务模型的训练方法
CN110852439A (zh) 神经网络模型的压缩与加速方法、数据处理方法及装置
CN112016450B (zh) 机器学习模型的训练方法、装置和电子设备
CN112464865A (zh) 一种基于像素和几何混合特征的人脸表情识别方法
JP7403909B2 (ja) 系列マイニングモデルの訓練装置の動作方法、系列データの処理装置の動作方法、系列マイニングモデルの訓練装置、系列データの処理装置、コンピュータ機器、及びコンピュータプログラム
CN110738102A (zh) 一种人脸识别方法及系统
WO2020260862A1 (en) Facial behaviour analysis
CN113128671B (zh) 一种基于多模态机器学习的服务需求动态预测方法及系统
CN110110724A (zh) 基于指数型挤压函数驱动胶囊神经网络的文本验证码识别方法
US20240185025A1 (en) Flexible Parameter Sharing for Multi-Task Learning
CN114266897A (zh) 痘痘类别的预测方法、装置、电子设备及存储介质
TWI824485B (zh) 最佳化神經網路模型的方法
Terziyan et al. Causality-aware convolutional neural networks for advanced image classification and generation
CN116542321B (zh) 基于扩散模型的图像生成模型压缩和加速方法及系统
CN113569955A (zh) 一种模型训练方法、用户画像生成方法、装置及设备
WO2021059527A1 (ja) 学習装置、学習方法、及び、記録媒体
CN115082840A (zh) 基于数据组合和通道相关性的动作视频分类方法和装置
CN112560712A (zh) 基于时间增强图卷积网络的行为识别方法、装置及介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication
RJ01 Rejection of invention patent application after publication

Application publication date: 20200630