CN114816719B - 多任务模型的训练方法及装置 - Google Patents
多任务模型的训练方法及装置 Download PDFInfo
- Publication number
- CN114816719B CN114816719B CN202210717146.2A CN202210717146A CN114816719B CN 114816719 B CN114816719 B CN 114816719B CN 202210717146 A CN202210717146 A CN 202210717146A CN 114816719 B CN114816719 B CN 114816719B
- Authority
- CN
- China
- Prior art keywords
- training
- task
- model
- tasks
- parameter
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F9/00—Arrangements for program control, e.g. control units
- G06F9/06—Arrangements for program control, e.g. control units using stored programs, i.e. using an internal store of processing equipment to receive or retain programs
- G06F9/46—Multiprogramming arrangements
- G06F9/48—Program initiating; Program switching, e.g. by interrupt
- G06F9/4806—Task transfer initiation or dispatching
- G06F9/4843—Task transfer initiation or dispatching by program, e.g. task dispatcher, supervisor, operating system
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Software Systems (AREA)
- Theoretical Computer Science (AREA)
- General Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Data Mining & Analysis (AREA)
- Computing Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Feedback Control In General (AREA)
Abstract
本公开关于一种多任务模型的训练方法及装置,可用于自动驾驶领域。其中,该方法包括:确定多任务模型对应的多个任务,并获取任务的第一数据集和第一训练参数,多任务模型包括主干网络和任务对应的分支网络;根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成;获取多任务模型的第二数据集和第二训练参数;根据第二数据集和第二训练参数对多任务模型进行训练。本公开根据任务对应的第一数据集和第一训练参数依次对单个任务进行训练,再根据整个模型对应的第二数据集和第二训练参数对整个模型进行训练,以应对不同任务之间的训练差异,解决任务间训练差异过大导致训练奔溃的问题。
Description
技术领域
本公开涉及自动驾驶技术领域,尤其涉及一种多任务模型的训练方法及装置。
背景技术
目前,随着人工智能的不断发展,对智能化的要求越来越高,在具体的应用场景中,一项落地的人工智能服务往往包含多种不同类型的处理任务,以通过各个处理任务共同协作从而更加精确、高效地完成服务内容。例如,在自动驾驶场景中,精确地感知周围环境是实现自动驾驶至关重要的一环,往往通过检测或识别道路中各种类型的元素:车道线、红绿灯、标牌、路面标识和可行驶区域等,其中每一种元素的检测或识别可以作为一个感知任务,通过多任务模型,使得各个任务共享一个主干网络,共用从主干网络得到的特征图,基于这些特征图再进行不同任务各自的处理过程,从而提升整体效率。
发明内容
本公开提供一种多任务模型的训练方法、装置、电子设备及计算机可读存储介质,以至少解决如何应对不同任务之间的训练差异,在任务之间训练冲突过大的情况下实现多任务模型训练的问题。本公开的技术方案如下:
根据本公开实施例的第一方面,提供一种多任务模型的训练方法,包括:确定所述多任务模型对应的多个任务,并获取所述任务的第一数据集和第一训练参数,所述多任务模型包括主干网络和所述任务对应的分支网络;根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成;获取所述多任务模型的第二数据集和第二训练参数;根据所述第二数据集和所述第二训练参数对所述多任务模型进行训练。
在本公开的一个实施例中,所述根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成,包括:基于预设的任务训练顺序,根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成。
在本公开的一个实施例中,所述第一训练参数中包括:所述任务对应的第一优化器的类型和第一超参数,所述第一超参数中包括:所述任务对应的第一学习率、第一批量大小和第一训练轮次;所述第二训练参数中包括:所述多任务模型对应的第二优化器的类型和第二超参数,所述第二超参数中包括:所述多任务模型对应的第二学习率、第二批量大小和第二训练轮次。
在本公开的一个实施例中,所述根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,包括:响应于当前任务为所述任务训练顺序中的第一个任务,固定除了所述当前任务之外的其他任务对应的第一模型参数,所述第一模型参数为所述任务对应的分支网络的参数;以及根据所述当前任务的所述第一数据集和所述第一训练参数,对所述当前任务进行循环迭代训练,并在每次迭代训练结束后,更新所述当前任务对应的所述第一模型参数和第二模型参数,所述第二模型参数为所述主干网络的参数。
在本公开的一个实施例中,所述根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,还包括:响应于所述当前任务为除了所述第一个任务之外的其他任务,固定除了所述当前任务之外的其他任务对应的所述第一模型参数和所述第二模型参数;以及根据所述当前任务的所述第一数据集和所述第一训练参数,对所述当前任务进行循环迭代训练,并在每次迭代训练结束后,更新所述当前任务对应的所述第一模型参数。
在本公开的一个实施例中,所述获取所述多任务模型的第二数据集和第二训练参数,包括:将所述多个任务的多个所述第一数据集的集合,确定为所述第二数据集;以及将所述多个任务的多个所述第一学习率的总和,确定为所述第二学习率;以及根据所述第一数据集的大小,确定所述第二批量大小;以及将所述多个任务的多个所述第一训练轮次中的最小值,确定为所述第二训练轮次。
在本公开的一个实施例中,所述根据所述第二数据集和所述第二训练参数对所述多任务模型进行训练,包括:将所述任务的所述第一数据集作为所述多任务模型的单批训练数据;基于所述多任务模型的多个所述单批训练数据和所述第二训练参数,对所述多任务模型进行所述第二训练轮次的循环迭代训练。
在本公开的一个实施例中,基于所述多任务模型的多个所述单批训练数据和所述第二训练参数,对所述多任务模型进行一个轮次的循环迭代训练,包括:基于所述任务训练顺序,获取当前序数对应的所述单批训练数据;根据所述单批训练数据对所述多任务模型进行一次迭代训练,并在本次迭代训练结束后,更新所述多任务模型的模型参数;根据下一序数对应的所述单批训练数据对参数更新后的所述多任务模型进行下一次迭代训练,直至完成本轮次的循环迭代训练。
根据本公开实施例的第二方面,提供一种多任务模型的训练装置,包括:第一获取模块,被配置为执行确定所述多任务模型对应的多个任务,并获取所述任务的第一数据集和第一训练参数,所述多任务模型包括主干网络和所述任务对应的分支网络;第一训练模块,被配置为执行根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成;第二获取模块,被配置为执行获取所述多任务模型的第二数据集和第二训练参数;第二训练模块,被配置为执行根据所述第二数据集和所述第二训练参数对所述多任务模型进行训练。
在本公开的一个实施例中,所述第一训练模块,还被配置为执行:基于预设的任务训练顺序,根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成。
在本公开的一个实施例中,所述第一训练参数中包括:所述任务对应的第一优化器的类型和第一超参数,所述第一超参数中包括:所述任务对应的第一学习率、第一批量大小和第一训练轮次;所述第二训练参数中包括:所述多任务模型对应的第二优化器的类型和第二超参数,所述第二超参数中包括:所述多任务模型对应的第二学习率、第二批量大小和第二训练轮次。
在本公开的一个实施例中,所述第一训练模块,还被配置为执行:响应于当前任务为所述任务训练顺序中的第一个任务,固定除了所述当前任务之外的其他任务对应的第一模型参数,所述第一模型参数为所述任务对应的分支网络的参数;以及根据所述当前任务的所述第一数据集和所述第一训练参数,对所述当前任务进行循环迭代训练,并在每次迭代训练结束后,更新所述当前任务对应的所述第一模型参数和第二模型参数,所述第二模型参数为所述主干网络的参数。
在本公开的一个实施例中,所述第一训练模块,还被配置为执行:响应于所述当前任务为除了所述第一个任务之外的其他任务,固定除了所述当前任务之外的其他任务对应的所述第一模型参数和所述第二模型参数;以及根据所述当前任务的所述第一数据集和所述第一训练参数,对所述当前任务进行循环迭代训练,并在每次迭代训练结束后,更新所述当前任务对应的所述第一模型参数。
在本公开的一个实施例中,所述第二获取模块,还被配置为执行:将所述多个任务的多个所述第一数据集的集合,确定为所述第二数据集;以及将所述多个任务的多个所述第一学习率的总和,确定为所述第二学习率;以及根据所述第一数据集的大小,确定所述第二批量大小;以及将所述多个任务的多个所述第一训练轮次中的最小值,确定为所述第二训练轮次。
在本公开的一个实施例中,所述第二训练模块,还被配置为执行:将所述任务的所述第一数据集作为所述多任务模型的单批训练数据;基于所述多任务模型的多个所述单批训练数据和所述第二训练参数,对所述多任务模型进行所述第二训练轮次的循环迭代训练。
在本公开的一个实施例中,所述第二训练模块,还被配置为执行:基于所述任务训练顺序,获取当前序数对应的所述单批训练数据;根据所述单批训练数据对所述多任务模型进行一次迭代训练,并在本次迭代训练结束后,更新所述多任务模型的模型参数;根据下一序数对应的所述单批训练数据对参数更新后的所述多任务模型进行下一次迭代训练,直至完成本轮次的循环迭代训练。
根据本公开实施例的第三方面,提供一种电子设备,包括:处理器;用于存储所述处理器的可执行指令的存储器;其中,所述处理器被配置为执行所述指令,以实现如本公开实施例第一方面所述的多任务模型的训练方法。
根据本公开实施例的第四方面,提供一种计算机可读存储介质,当所述计算机可读存储介质中的指令由电子设备的处理器执行时,使得电子设备能够执行如本公开实施例第一方面所述的多任务模型的训练方法。
本公开的实施例提供的技术方案至少带来以下有益效果:确定多任务模型对应的多个任务,并获取任务的第一数据集和第一训练参数,多任务模型包括主干网络和任务对应的分支网络;根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成;获取多任务模型的第二数据集和第二训练参数;根据第二数据集和第二训练参数对多任务模型进行训练。本公开实施例根据任务对应的第一数据集和第一训练参数依次对单个任务进行训练,再根据整个模型对应的第二数据集和第二训练参数对整个模型进行训练,以应对不同任务之间的训练差异,解决任务间训练差异过大导致训练奔溃的问题。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理,并不构成对本公开的不当限定。
图1是根据一示例性实施例示出的一种多任务模型的训练方法的流程图。
图2是根据另一示例性实施例示出的一种多任务模型的训练方法的流程图。
图3是根据另一示例性实施例示出的一种多任务模型的训练方法的流程图。
图4是根据另一示例性实施例示出的一种多任务模型的训练方法的流程图。
图5是根据另一示例性实施例示出的一种多任务模型的训练方法的流程图。
图6是根据另一示例性实施例示出的一种多任务模型的训练方法的流程图。
图7是根据一示例性实施例示出的一种多任务模型的训练装置的框图。
图8是根据一示例性实施例示出的一种电子设备的框图。
具体实施方式
为了使本领域普通人员更好地理解本公开的技术方案,下面将结合附图,对本公开实施例中的技术方案进行清楚、完整地描述。
需要说明的是,本公开的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本公开的实施例能够以除了在这里图示或描述的那些以外的顺序实施。以下示例性实施例中所描述的实施方式并不代表与本公开相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本公开的一些方面相一致的装置和方法的例子。
相关技术中,由于不同的任务往往会存在数据标注差异,损失函数差异,数据比例差异等不可预测的问题,这就给多任务训练带来了很大的挑战,在训练时任务之间训练冲突过大常常导致训练直接奔溃,难以完成多任务模型的训练。
图1是根据一示例性实施例示出的一种多任务模型的训练方法的流程图,如图1所示,本公开实施例的多任务模型的训练方法,可以包括以下步骤:
S101,确定多任务模型对应的多个任务,并获取任务的第一数据集和第一训练参数,多任务模型包括主干网络和任务对应的分支网络。
需要说明的是,本公开实施例的多任务模型的训练方法的执行主体为本公开实施例提供的多任务模型的训练装置,该装置可以设置在电子设备中,例如手机、车载终端和台式电脑等,以执行本公开实施例的多任务模型的训练方法。
在本公开实施例中,多任务模型包括主干网络和分支网络,每个分支网络对应一个任务,因此多任务模型能够实现不同任务的处理过程。
例如,主干网络可以理解为模型中的backbone,用来做特征提取的网络,代表整个多任务模型的一部分,一般是用于前端提取图片信息,生成特征图(feature map)供后面的网络使用;分支网络可以理解为模型中的head,head位于backbone之后,是用于利用之前提取的特征获取输出内容的网络,head利用这些特征,做出预测。
在多任务模型训练之前,需要确定多任务模型对应的多个任务,每个任务可以对应有不同的数据集和训练参数,其中,训练参数可以包括但不限于该任务对应的优化器的类型和超参数。因此我们需要获取每个任务的数据集(即该任务的第一数据集)和训练参数(即该任务的第一训练参数),其中第一训练参数中可以包括但不限于任务对应的第一优化器的类型和第一超参数等,第一超参数中可以包括但不限于:任务对应的第一学习率、第一批量大小和第一训练轮次等。
例如,多任务模型对应有M个任务,每个任务有对应的第一数据集为,以及该任务训练的训练轮次(即上述该任务的第一训练轮次),每个任务的head有不同的学习率(即上述该任务的第一学习率)和优化器(即上述该任务的第一优化器)和批大小(batch size,即上述该任务的第一批量大小)。
S102,根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成。
在本公开实施例中,根据单个任务对应的第一数据集和第一训练参数,对单个任务进行训练,即对任务对应的分支网络(head)进行训练,在该任务训练完成后,根据下一个任务对应的第一数据集和第一训练参数对下一个任务对应的分支网络进行训练,直至多任务模型对应的多个任务全部训练完成。
S103,获取多任务模型的第二数据集和第二训练参数。
在本公开实施例中,在单个任务依次训练完成后,对整个多任务模型进行训练,该多任务模型拥有对应的数据集(即第二数据集)和训练参数(即第二训练参数),其中第二训练参数中可以包括但不限于多任务模型对应的第二优化器的类型和第二超参数,第二超参数中可以包括但不限于:多任务模型对应的第二学习率、第二批量大小和第二训练轮次。
需要说明的是,上述数据集、训练参数、优化器、超参数、学习率、批量大小以及训练轮次对应的“第一”和“第二”只是为了区分单个任务对应的数据和参数以及整个多任务模型对应的数据和参数。
S104,根据第二数据集和第二训练参数对多任务模型进行训练。
在本公开实施例中,根据多任务模型的第二数据集和第二训练参数对整个多任务模型进行训练。
本公开的实施例提供的多任务模型的训练方法,确定多任务模型对应的多个任务,并获取任务的第一数据集和第一训练参数,多任务模型包括主干网络和任务对应的分支网络;根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成;获取多任务模型的第二数据集和第二训练参数;根据第二数据集和第二训练参数对多任务模型进行训练。本公开实施例根据任务对应的第一数据集和第一训练参数依次对单个任务进行训练,再根据整个模型对应的第二数据集和第二训练参数对整个模型进行训练,以应对不同任务之间的训练差异,解决任务间训练差异过大导致训练奔溃的问题
在上述实施例的基础上,可以在进行单个任务的训练之前,对多个任务的训练顺序进行排序,例如按照任务的复杂度进行排序,或者人工指定训练顺序,以此来得到预设的任务训练顺序。基于预设的任务训练顺序,根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成。
在一些实施例中,可以将每个任务的任务标识,按照预设的训练顺序,存放在训练队列中,根据训练队列弹出的任务标识,对对应的任务进行训练,在当前任务训练完成后,从训练队列中弹出下一个任务标识,以此完成对全部任务的训练过程,其中在对单个任务开始训练之前,还可以包括对任务对应的数据集和优化器等进行初始化的过程。
在对单个任务和整个多任务模型进行训练时,需要在训练过程中对对应的模型参数(包括主干网络的参数和分支网络的参数)进行更新,以得到训练好的多任务模型,因此,本公开实施例还包括在多任务模型的训练过程中对模型参数进行更新的方法。
在上述实施例的基础上,如图2所示,上述步骤“根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练”可以包括以下步骤:
S201,响应于当前任务为任务训练顺序中的第一个任务,固定除了当前任务之外的其他任务对应的第一模型参数,第一模型参数为任务对应的分支网络的参数。
在本公开实施例中,基于上述预设的任务训练顺序,判断当前任务是否为任务训练顺序中的第一个任务,若当前任务为第一个任务,则固定除了该任务之外的其他任务对应的第一模型参数,即固定其他分支网络的参数,在对该任务对应的分支网络进行训练时,对于除了该任务对应的分支网络之外的其他分支网络不进行参数更新。简单理解为冻结除了该任务之外的其他任务的head,不冻结该任务的head和主干网络backbone。
S202,根据当前任务的第一数据集和第一训练参数,对当前任务进行循环迭代训练,并在每次迭代训练结束后,更新当前任务对应的第一模型参数和第二模型参数,第二模型参数为主干网络的参数。
在本公开实施例中,在对除了当前任务之外的其他任务对应的第一模型参数进行固定的情况下,基于当前任务的第一数据集和第一训练参数对当前任务对应的分支网络进行训练,此外还包括对主干网络进行训练。
在一些实施例中,根据第一训练参数中的第一批量大小,将第一数据集划分为多个批次的训练数据,利用1个批次的训练数据进行一次迭代训练,直至将整个第一数据集的数据全部训练完成后结束一个轮次的训练过程。以此根据第一训练参数中的第一训练轮次,完成多轮迭代训练。其中,在每一次迭代训练结束后,更新当前任务对应的第一模型参数和主干网络的参数(即第二模型参数),基于参数更新后的多任务模型继续进行下一次迭代训练。以此完成当前任务对应的训练过程。
在上述实施例的基础上,若当前任务不为上述第一个任务,如图3所示,上述步骤“根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练”还包括以下步骤:
S301,响应于当前任务为除了第一个任务之外的其他任务,固定除了当前任务之外的其他任务对应的第一模型参数和第二模型参数。
在本公开实施例中,基于预设的任务训练顺序,判断当前任务是否为训练顺序中的第一个任务,若当前任务为除了第一个任务之外的其他任务,则固定除了当前任务之外的其他任务对应的第一模型参数和主干网络对应的第二模型参数。
S302,根据当前任务的第一数据集和第一训练参数,对当前任务进行循环迭代训练,并在每次迭代训练结束后,更新当前任务对应的第一模型参数。
在本公开实施例中,根据当前任务的第一数据集和第一训练参数,对当前任务进行循环迭代训练,具体训练过程与上述第一个任务的训练过程相同,此处不再赘述,不同的是,在对该当前任务进行循环迭代训练时,在每次迭代训练结束后,只对当前任务对应的第一模型参数进行更新。
由此可知,本公开实施例在对第一个任务进行训练时,固定除了第一个任务之外的其他任务对应的分支网络的参数,在每次迭代训练后更新该任务对应的分支网络的参数和主干网络的参数;而在对除了第一个任务之外的其他任务进行训练时,需要固定主干网络的参数和除了该任务之外的其他任务对应的分支网络的参数,在每次迭代训练后,更新该任务对应的分支网络的参数。以此完成对多个任务的单独训练过程。
在上述实施例的基础上,如图4所示,上述步骤S103中“获取多任务模型的第二数据集和第二训练参数”,可以包括以下步骤:
S401,将多个任务的多个第一数据集的集合,确定为第二数据集。
在本公开实施例中,对全部任务的第一数据集求并集,即将每个第一数据集放入一个集合中,将该集合作为整个多任务模型对应的第二数据集。
S402,将多个任务的多个第一学习率的总和,确定为第二学习率。
S403,根据第一数据集的大小,确定第二批量大小。
在本公开实施例中,将每一个第一数据集作为多任务模型整体训练的单批训练数据,即每个batch为一个任务的数据集,将第一数据集的大小确定为batch对应的第二批量大小。由于第一数据集的大小可能不同,因此对多任务模型进行整体训练时,该多任务模型对应的第二批量大小的值为多个。在一些实施例中,可以不对第二批量大小这一超参数进行限定,只需将第一数据集作为一个单批训练数据即可。
S404,将多个任务的多个第一训练轮次中的最小值,确定为第二训练轮次。
在上述实施例的基础上,如图5所示,上述步骤S104中“根据第二数据集和第二训练参数对多任务模型进行训练”,可以包括以下步骤:
S501,将任务的第一数据集作为多任务模型的单批训练数据。
本公开实施例中,由于不同类型任务的数据集可能差异很大,无法存在一个batch(即单批训练数据)里面进行一次batch的前向传播,因此在一个batch里面只有一个任务的数据集。
S502,基于多任务模型的多个单批训练数据和第二训练参数,对多任务模型进行第二训练轮次的循环迭代训练。
本公开实施例中,基于多任务模型的多个单批训练数据、第二训练参数中的第二优化器的类型和第二学习率,对整个多任务模型进行一个轮次的循环迭代训练,直至完成第二训练轮次的循环迭代训练。
在上述实施例的基础上,如图6所示,本公开实施例还公开了基于多任务模型的多个单批训练数据和第二训练参数,对多任务模型进行一个轮次的循环迭代训练的过程,可以包括以下步骤:
S601,基于任务训练顺序,获取当前序数对应的单批训练数据。
在本公开实施例中,根据上述单个任务训练时的任务训练顺序,获取当前序数对应的单批训练数据,其中,序数可以理解为表示事物次第的数目,如顺序中的第一、第二、第三……。即根据上述实施例中的任务训练顺序依次获取任务的第一数据集。
S602,根据单批训练数据对多任务模型进行一次迭代训练,并在本次迭代训练结束后,更新多任务模型的模型参数。
例如,以当前序数对应的单批训练数据,以第二学习率,以Adam优化算法为第二优化器,对多任务模型进行一次迭代训练,在本次迭代训练结束后,更新整个多任务模型的模型参数,其中包括主干网络的参数和各个分支网络的参数。
S603,根据下一序数对应的单批训练数据对参数更新后的多任务模型进行下一次迭代训练,直至完成本轮次的循环迭代训练。
在本次参数更新完成后,继续获取下一序数对应的单批训练数据,以下一序数对应的单批训练数据对参数更新后的多任务模型进行下一次迭代训练,并在迭代训练结束后更新模型参数,直至利用全部的单批训练数据完成本轮次的循环迭代训练。
由此,本公开实施例根据任务对应的数据集和训练数据依次对单个任务进行训练,再通过整个模型的数据集和训练数据对整个模型进行训练,以应对不同任务之间的训练差异,解决任务间梯度回传方向和大小差异过大导致训练奔溃的问题,实现渐进式训练过程,该方法对多任务模型没有限制,通用性较好。
图7是根据一示例性实施例示出的一种多任务模型的训练装置的框图。如图7所示,本公开实施例的多任务模型的训练装置700,包括:第一获取模块701、第一训练模块702、第二获取模块703和第二训练模块704。
第一获取模块701,被配置为执行确定多任务模型对应的多个任务,并获取任务的第一数据集和第一训练参数,多任务模型包括主干网络和任务对应的分支网络。
第一训练模块702,被配置为执行根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成。
第二获取模块703,被配置为执行获取多任务模型的第二数据集和第二训练参数。
第二训练模块704,被配置为执行根据第二数据集和第二训练参数对多任务模型进行训练。
在本公开的一个实施例中,第一训练模块702,还被配置为执行:基于预设的任务训练顺序,根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成。
在本公开的一个实施例中,第一训练参数中包括:任务对应的第一优化器的类型和第一超参数,第一超参数中包括:任务对应的第一学习率、第一批量大小和第一训练轮次;第二训练参数中包括:多任务模型对应的第二优化器的类型和第二超参数,第二超参数中包括:多任务模型对应的第二学习率、第二批量大小和第二训练轮次。
在本公开的一个实施例中,第一训练模块702,还被配置为执行:响应于当前任务为任务训练顺序中的第一个任务,固定除了当前任务之外的其他任务对应的第一模型参数,第一模型参数为任务对应的分支网络的参数;以及根据当前任务的第一数据集和第一训练参数,对当前任务进行循环迭代训练,并在每次迭代训练结束后,更新当前任务对应的第一模型参数和第二模型参数,第二模型参数为主干网络的参数。
在本公开的一个实施例中,第一训练模块702,还被配置为执行:响应于当前任务为除了第一个任务之外的其他任务,固定除了当前任务之外的其他任务对应的第一模型参数和第二模型参数;以及根据当前任务的第一数据集和第一训练参数,对当前任务进行循环迭代训练,并在每次迭代训练结束后,更新当前任务对应的第一模型参数。
在本公开的一个实施例中,第二获取模块703,还被配置为执行:将多个任务的多个第一数据集的集合,确定为第二数据集;以及将多个任务的多个第一学习率的总和,确定为第二学习率;以及根据第一数据集的大小,确定第二批量大小;以及将多个任务的多个第一训练轮次中的最小值,确定为第二训练轮次。
在本公开的一个实施例中,第二训练模块704,还被配置为执行:将任务的第一数据集作为多任务模型的单批训练数据;基于多任务模型的多个单批训练数据和第二训练参数,对多任务模型进行第二训练轮次的循环迭代训练。
在本公开的一个实施例中,第二训练模块704,还被配置为执行:基于任务训练顺序,获取当前序数对应的单批训练数据;根据单批训练数据对多任务模型进行一次迭代训练,并在本次迭代训练结束后,更新多任务模型的模型参数;根据下一序数对应的单批训练数据对参数更新后的多任务模型进行下一次迭代训练,直至完成本轮次的循环迭代训练。
关于上述实施例中的装置,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
本公开的实施例提供的多任务模型的训练装置,确定多任务模型对应的多个任务,并获取任务的第一数据集和第一训练参数,多任务模型包括主干网络和任务对应的分支网络;根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成;获取多任务模型的第二数据集和第二训练参数;根据第二数据集和第二训练参数对多任务模型进行训练。本公开实施例根据任务对应的第一数据集和第一训练参数依次对单个任务进行训练,再根据整个模型对应的第二数据集和第二训练参数对整个模型进行训练,以应对不同任务之间的训练差异,解决任务间训练差异过大导致训练奔溃的问题。
图8是根据一示例性实施例示出的一种电子设备的框图。
如图8所示,上述电子设备800包括:
存储器801及处理器802,连接不同组件(包括存储器801和处理器802)的总线803,存储器801存储有计算机程序,当处理器802执行程序时实现本公开实施例的多任务模型的训练方法。
总线803表示几类总线结构中的一种或多种,包括存储器总线或者存储器控制器,外围总线,图形加速端口,处理器或者使用多种总线结构中的任意总线结构的局域总线。举例来说,这些体系结构包括但不限于工业标准体系结构(ISA)总线,微通道体系结构(MAC)总线,增强型ISA总线、视频电子标准协会(VESA)局域总线以及外围组件互连(PCI)总线。
电子设备800典型地包括多种电子设备可读介质。这些介质可以是任何能够被电子设备800访问的可用介质,包括易失性和非易失性介质,可移动的和不可移动的介质。
存储器801还可以包括易失性存储器形式的计算机系统可读介质,例如随机存取存储器(RAM)804和/或高速缓存存储器805。电子设备800可以进一步包括其它可移动/不可移动的、易失性/非易失性计算机系统存储介质。仅作为举例,存储系统806可以用于读写不可移动的、非易失性磁介质(图8未显示,通常称为“硬盘驱动器”)。尽管图8中未示出,可以提供用于对可移动非易失性磁盘(例如“软盘”)读写的磁盘驱动器,以及对可移动非易失性光盘(例如CD-ROM, DVD-ROM或者其它光介质)读写的光盘驱动器。在这些情况下,每个驱动器可以通过一个或者多个数据介质接口与总线803相连。存储器801可以包括至少一个程序产品,该程序产品具有一组(例如至少一个)程序模块,这些程序模块被配置以执行本公开各实施例的功能。
具有一组(至少一个)程序模块807的程序/实用工具808,可以存储在例如存储器801中,这样的程序模块807包括——但不限于——操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。程序模块807通常执行本公开所描述的实施例中的功能和/或方法。
电子设备800也可以与一个或多个外部设备809(例如键盘、指向设备、显示器810等)通信,还可与一个或者多个使得用户能与该电子设备800交互的设备通信,和/或与使得该电子设备800能与一个或多个其它计算设备进行通信的任何设备(例如网卡,调制解调器等等)通信。这种通信可以通过输入/输出接口812进行。并且,电子设备800还可以通过网络适配器813与一个或者多个网络(例如局域网(LAN),广域网(WAN)和/或公共网络,例如因特网)通信。如图8所示,网络适配器813通过总线803与电子设备800的其它模块通信。应当明白,尽管图中未示出,可以结合电子设备800使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元、外部磁盘驱动阵列、RAID系统、磁带驱动器以及数据备份存储系统等。
处理器802通过运行存储在存储器801中的程序,从而执行各种功能应用以及数据处理。
需要说明的是,本实施例的电子设备的实施过程和技术原理参见前述对本公开实施例的多任务模型的训练方法的解释说明,此处不再赘述。
本公开实施例提供的电子设备,确定多任务模型对应的多个任务,并获取任务的第一数据集和第一训练参数,多任务模型包括主干网络和任务对应的分支网络;根据第一数据集和第一训练参数,依次对任务对应的分支网络进行训练,直至多个任务全部训练完成;获取多任务模型的第二数据集和第二训练参数;根据第二数据集和第二训练参数对多任务模型进行训练。本公开实施例根据任务对应的第一数据集和第一训练参数依次对单个任务进行训练,再根据整个模型对应的第二数据集和第二训练参数对整个模型进行训练,以应对不同任务之间的训练差异,解决任务间训练差异过大导致训练奔溃的问题。
为了实现上述实施例,本公开还提出一种计算机可读存储介质。
其中,该计算机可读存储介质中的指令由电子设备的处理器执行时,使得电子设备能够执行如前的多任务模型的训练方法。可选的,计算机可读存储介质可以是ROM、随机存取存储器(RAM)、CD-ROM、磁带、软盘和光数据存储设备等。
本领域技术人员在考虑说明书及实践这里公开的发明后,将容易想到本公开的其它实施方案。本公开旨在涵盖本公开的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本公开的一般性原理并包括本公开未公开的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本公开的真正范围和精神由下面的权利要求指出。
应当理解的是,本公开并不局限于上面已经描述并在附图中示出的精确结构,并且可以在不脱离其范围进行各种修改和改变。本公开的范围仅由所附的权利要求来限制。
Claims (14)
1.一种多任务模型的训练方法,所述多任务模型运用于自动驾驶场景中检测或识别道路中各种类型的元素,其特征在于,所述各种类型的元素中,每一种元素的检测或识别作为一个感知任务,包括:
确定所述多任务模型对应的多个任务,并获取所述任务的第一数据集和第一训练参数,所述多任务模型包括主干网络和所述任务对应的分支网络,所述多任务模型对应的多个任务共享所述主干网络,共用从所述主干网络得到的特征图;所述第一训练参数中包括:所述任务对应的第一优化器的类型和第一超参数,所述第一超参数中包括:所述任务对应的第一学习率、第一批量大小和第一训练轮次;
根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成;
获取所述多任务模型的第二数据集和第二训练参数,所述第二训练参数中包括:所述多任务模型对应的第二优化器的类型和第二超参数,所述第二超参数中包括:所述多任务模型对应的第二学习率、第二批量大小和第二训练轮次;
根据所述第二数据集和所述第二训练参数对所述多任务模型进行训练;
所述获取所述多任务模型的第二数据集和第二训练参数,包括:
将所述多任务模型的全部任务的第一数据集求并集,确定为所述第二数据集;以及
将所述多任务模型的全部任务的第一学习率的总和,确定为所述第二学习率;以及
根据所述第一数据集的大小,确定所述第二批量大小;以及
将所述多个任务的多个所述第一训练轮次中的最小值,确定为所述第二训练轮次。
2.根据权利要求1所述的训练方法,其特征在于,所述根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成,包括:
基于预设的任务训练顺序,根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成。
3.根据权利要求2 所述的训练方法,其特征在于,所述根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,包括:
响应于当前任务为所述任务训练顺序中的第一个任务,固定除了所述当前任务之外的其他任务对应的第一模型参数,所述第一模型参数为所述任务对应的分支网络的参数;以及
根据所述当前任务的所述第一数据集和所述第一训练参数,对所述当前任务进行循环迭代训练,并在每次迭代训练结束后,更新所述当前任务对应的所述第一模型参数和第二模型参数,所述第二模型参数为所述主干网络的参数。
4.根据权利要求3所述的训练方法,其特征在于,所述根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,还包括:
响应于所述当前任务为除了所述第一个任务之外的其他任务,固定除了所述当前任务之外的其他任务对应的所述第一模型参数和所述第二模型参数;以及
根据所述当前任务的所述第一数据集和所述第一训练参数,对所述当前任务进行循环迭代训练,并在每次迭代训练结束后,更新所述当前任务对应的所述第一模型参数。
5.根据权利要求2所述的训练方法,其特征在于,所述根据所述第二数据集和所述第二训练参数对所述多任务模型进行训练,包括:
将所述任务的所述第一数据集作为所述多任务模型的单批训练数据;
基于所述多任务模型的多个所述单批训练数据和所述第二训练参数,对所述多任务模型进行所述第二训练轮次的循环迭代训练。
6.根据权利要求5所述的训练方法,其特征在于,基于所述多任务模型的多个所述单批训练数据和所述第二训练参数,对所述多任务模型进行一个轮次的循环迭代训练,包括:
基于所述任务训练顺序,获取当前序数对应的所述单批训练数据;
根据所述单批训练数据对所述多任务模型进行一次迭代训练,并在本次迭代训练结束后,更新所述多任务模型的模型参数;
根据下一序数对应的所述单批训练数据对参数更新后的所述多任务模型进行下一次迭代训练,直至完成本轮次的循环迭代训练。
7.一种多任务模型的训练装置,所述多任务模型运用于自动驾驶场景中检测或识别道路中各种类型的元素,其特征在于,所述各种类型的元素中,每一种元素的检测或识别作为一个感知任务,包括:
第一获取模块,被配置为执行确定所述多任务模型对应的多个任务,并获取所述任务的第一数据集和第一训练参数,所述多任务模型包括主干网络和所述任务对应的分支网络,所述多任务模型对应的多个任务共享所述主干网络,共用从所述主干网络得到的特征图;所述第一训练参数中包括:所述任务对应的第一优化器的类型和第一超参数,所述第一超参数中包括:所述任务对应的第一学习率、第一批量大小和第一训练轮次;
第一训练模块,被配置为执行根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成;
第二获取模块,被配置为执行获取所述多任务模型的第二数据集和第二训练参数,所述第二训练参数中包括:所述多任务模型对应的第二优化器的类型和第二超参数,所述第二超参数中包括:所述多任务模型对应的第二学习率、第二批量大小和第二训练轮次;
第二训练模块,被配置为执行根据所述第二数据集和所述第二训练参数对所述多任务模型进行训练;
所述第二获取模块,还被配置为执行:
将所述多任务模型的全部任务的第一数据集求并集,确定为所述第二数据集;以及
将所述多任务模型的全部任务的第一学习率的总和,确定为所述第二学习率;以及
根据所述第一数据集的大小,确定所述第二批量大小;以及
将所述多个任务的多个所述第一训练轮次中的最小值,确定为所述第二训练轮次。
8.根据权利要求7所述的训练装置,其特征在于,所述第一训练模块,还被配置为执行:
基于预设的任务训练顺序,根据所述第一数据集和所述第一训练参数,依次对所述任务对应的分支网络进行训练,直至多个所述任务全部训练完成。
9.根据权利要求8所述的训练装置,其特征在于,所述第一训练模块,还被配置为执行:
响应于当前任务为所述任务训练顺序中的第一个任务,固定除了所述当前任务之外的其他任务对应的第一模型参数,所述第一模型参数为所述任务对应的分支网络的参数;以及
根据所述当前任务的所述第一数据集和所述第一训练参数,对所述当前任务进行循环迭代训练,并在每次迭代训练结束后,更新所述当前任务对应的所述第一模型参数和第二模型参数,所述第二模型参数为所述主干网络的参数。
10.根据权利要求9所述的训练装置,其特征在于,所述第一训练模块,还被配置为执行:
响应于所述当前任务为除了所述第一个任务之外的其他任务,固定除了所述当前任务之外的其他任务对应的所述第一模型参数和所述第二模型参数;以及
根据所述当前任务的所述第一数据集和所述第一训练参数,对所述当前任务进行循环迭代训练,并在每次迭代训练结束后,更新所述当前任务对应的所述第一模型参数。
11.根据权利要求8所述的训练装置,其特征在于,所述第二训练模块,还被配置为执行:
将所述任务的所述第一数据集作为所述多任务模型的单批训练数据;
基于所述多任务模型的多个所述单批训练数据和所述第二训练参数,对所述多任务模型进行所述第二训练轮次的循环迭代训练。
12.根据权利要求11所述的训练装置,其特征在于,所述第二训练模块,还被配置为执行:
基于所述任务训练顺序,获取当前序数对应的所述单批训练数据;
根据所述单批训练数据对所述多任务模型进行一次迭代训练,并在本次迭代训练结束后,更新所述多任务模型的模型参数;
根据下一序数对应的所述单批训练数据对参数更新后的所述多任务模型进行下一次迭代训练,直至完成本轮次的循环迭代训练。
13.一种电子设备,其特征在于,包括:
处理器;
用于存储所述处理器的可执行指令的存储器;
其中,所述处理器被配置为执行所述指令,以实现如权利要求1-6中任一项所述的方法。
14.一种计算机可读存储介质,其特征在于,当所述计算机可读存储介质中的指令由电子设备的处理器执行时,使得电子设备能够执行如权利要求1-6中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210717146.2A CN114816719B (zh) | 2022-06-23 | 2022-06-23 | 多任务模型的训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210717146.2A CN114816719B (zh) | 2022-06-23 | 2022-06-23 | 多任务模型的训练方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114816719A CN114816719A (zh) | 2022-07-29 |
CN114816719B true CN114816719B (zh) | 2022-09-30 |
Family
ID=82520368
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210717146.2A Active CN114816719B (zh) | 2022-06-23 | 2022-06-23 | 多任务模型的训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114816719B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116385949B (zh) * | 2023-03-23 | 2023-09-08 | 广州里工实业有限公司 | 一种移动机器人的区域检测方法、系统、装置及介质 |
Citations (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109598184A (zh) * | 2017-09-30 | 2019-04-09 | 北京图森未来科技有限公司 | 一种多分割任务的处理方法和装置 |
CN110136828A (zh) * | 2019-05-16 | 2019-08-16 | 杭州健培科技有限公司 | 一种基于深度学习实现医学影像多任务辅助诊断的方法 |
CN111310574A (zh) * | 2020-01-17 | 2020-06-19 | 清华大学 | 一种车载视觉实时多目标多任务联合感知方法和装置 |
CN111539351A (zh) * | 2020-04-27 | 2020-08-14 | 广东电网有限责任公司广州供电局 | 一种多任务级联的人脸选帧比对方法 |
CN112380923A (zh) * | 2020-10-26 | 2021-02-19 | 天津大学 | 基于多任务的智能自主视觉导航与目标检测方法 |
CN113435571A (zh) * | 2021-05-12 | 2021-09-24 | 上海微亿智造科技有限公司 | 实现多任务并行的深度网络训练方法和系统 |
CN113704388A (zh) * | 2021-03-05 | 2021-11-26 | 腾讯科技(深圳)有限公司 | 多任务预训练模型的训练方法、装置、电子设备和介质 |
CN113705662A (zh) * | 2021-08-26 | 2021-11-26 | 中国银联股份有限公司 | 一种协同训练方法、装置及计算机可读存储介质 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11370423B2 (en) * | 2018-06-15 | 2022-06-28 | Uatc, Llc | Multi-task machine-learned models for object intention determination in autonomous driving |
-
2022
- 2022-06-23 CN CN202210717146.2A patent/CN114816719B/zh active Active
Patent Citations (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109598184A (zh) * | 2017-09-30 | 2019-04-09 | 北京图森未来科技有限公司 | 一种多分割任务的处理方法和装置 |
CN110136828A (zh) * | 2019-05-16 | 2019-08-16 | 杭州健培科技有限公司 | 一种基于深度学习实现医学影像多任务辅助诊断的方法 |
CN111310574A (zh) * | 2020-01-17 | 2020-06-19 | 清华大学 | 一种车载视觉实时多目标多任务联合感知方法和装置 |
CN111539351A (zh) * | 2020-04-27 | 2020-08-14 | 广东电网有限责任公司广州供电局 | 一种多任务级联的人脸选帧比对方法 |
CN112380923A (zh) * | 2020-10-26 | 2021-02-19 | 天津大学 | 基于多任务的智能自主视觉导航与目标检测方法 |
CN113704388A (zh) * | 2021-03-05 | 2021-11-26 | 腾讯科技(深圳)有限公司 | 多任务预训练模型的训练方法、装置、电子设备和介质 |
CN113435571A (zh) * | 2021-05-12 | 2021-09-24 | 上海微亿智造科技有限公司 | 实现多任务并行的深度网络训练方法和系统 |
CN113705662A (zh) * | 2021-08-26 | 2021-11-26 | 中国银联股份有限公司 | 一种协同训练方法、装置及计算机可读存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN114816719A (zh) | 2022-07-29 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110334689B (zh) | 视频分类方法和装置 | |
CN109241141B (zh) | 深度学习的训练数据处理方法和装置 | |
CN109447156B (zh) | 用于生成模型的方法和装置 | |
CN110647920A (zh) | 机器学习中的迁移学习方法及装置、设备与可读介质 | |
CN112668608B (zh) | 一种图像识别方法、装置、电子设备及存储介质 | |
CN114816719B (zh) | 多任务模型的训练方法及装置 | |
CN111124920A (zh) | 设备性能测试方法、装置及电子设备 | |
CN112306447A (zh) | 一种界面导航方法、装置、终端和存储介质 | |
CN114332590B (zh) | 联合感知模型训练、联合感知方法、装置、设备和介质 | |
CN113657411A (zh) | 神经网络模型的训练方法、图像特征提取方法及相关装置 | |
CN112416301A (zh) | 深度学习模型开发方法及装置、计算机可读存储介质 | |
CN113255819B (zh) | 用于识别信息的方法和装置 | |
US20230367972A1 (en) | Method and apparatus for processing model data, electronic device, and computer readable medium | |
EP4198815A1 (en) | Neural network training method and apparatus for image retrieval, and electronic device | |
CN112287144B (zh) | 图片检索方法、设备及存储介质 | |
CN109857838B (zh) | 用于生成信息的方法和装置 | |
CN111124862A (zh) | 智能设备性能测试方法、装置及智能设备 | |
CN115310582A (zh) | 用于训练神经网络模型的方法和装置 | |
CN112308074A (zh) | 用于生成缩略图的方法和装置 | |
CN112308205A (zh) | 基于预训练模型的模型改进方法及装置 | |
CN111049988A (zh) | 移动设备的亲密度预测方法、系统、设备及存储介质 | |
CN113515465B (zh) | 基于区块链技术的软件兼容性测试方法及系统 | |
US12008803B2 (en) | Neural network training method and apparatus for image retrieval, and electronic device | |
CN117149339B (zh) | 基于人工智能的用户界面关系识别方法及相关装置 | |
CN114677691B (zh) | 文本识别方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |