CN113516239A - 模型训练方法、装置、存储介质及电子设备 - Google Patents
模型训练方法、装置、存储介质及电子设备 Download PDFInfo
- Publication number
- CN113516239A CN113516239A CN202110412115.1A CN202110412115A CN113516239A CN 113516239 A CN113516239 A CN 113516239A CN 202110412115 A CN202110412115 A CN 202110412115A CN 113516239 A CN113516239 A CN 113516239A
- Authority
- CN
- China
- Prior art keywords
- model
- task
- data
- data sets
- initial
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Withdrawn
Links
- 238000012549 training Methods 0.000 title claims abstract description 70
- 238000000034 method Methods 0.000 title claims abstract description 56
- 230000006870 function Effects 0.000 claims description 28
- 238000012545 processing Methods 0.000 claims description 19
- 230000015654 memory Effects 0.000 claims description 15
- 238000004590 computer program Methods 0.000 claims description 12
- 238000004364 calculation method Methods 0.000 abstract description 6
- 241001465754 Metazoa Species 0.000 description 10
- 230000008569 process Effects 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 230000003287 optical effect Effects 0.000 description 4
- 238000013473 artificial intelligence Methods 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 2
- 230000001186 cumulative effect Effects 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 230000006978 adaptation Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 239000000126 substance Substances 0.000 description 1
- 238000010200 validation analysis Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Image Analysis (AREA)
Abstract
本申请公开一种模型训练方法、装置、存储介质及电子设备。其中,获取多个数据集;将多个数据集逐个输入初始多任务模型;其中,所述初始多任务模型包含模型参数,所述模型参数包含共享参数和任务参数;其中,所述共享参数为所述初始多任务模型中多个任务共有的模型参数,所述任务参数为所述初始多任务模型中多个任务中每个任务独有的模型参数;基于所述初始多任务模型的输出结果调整所述模型参数,得到训练后的多任务模型。本方法能够避免针对同一训练样本进行重复计算的问题,提高了训练效率,节省了计算资源。
Description
技术领域
本公开涉及人工智能技术领域,特别是涉及一种模型训练方法、装置、存储介质及电子设备。
背景技术
深度学习在图像识别,语音识别,自然语言处理等相关领域都取得很多成果,但是由于深度学习模型计算复杂,效率低,如果对于一些相近的任务,往往都各自使用一个模型,无疑增加了计算量和资源占用。
发明内容
根据本公开的一个方面,提供以下技术方案:
一种模型训练方法,包括:
获取多个数据集;
将多个数据集逐个输入初始多任务模型;其中,所述初始多任务模型包含模型参数,所述模型参数包含共享参数和任务参数;其中,所述共享参数为所述初始多任务模型中多个任务共有的模型参数,所述任务参数为所述初始多任务模型中多个任务中每个任务独有的模型参数;
基于所述初始多任务模型的输出结果调整所述模型参数,得到训练后的多任务模型。
进一步地,根据所述初始多任务模型的输出结果计算损失函数的总损失值,并根据所述损失值对所述多任务模型的模型参数进行调整。
进一步地,若每个数据集的任务标签数量为一个,则将该任务标签对应的任务的损失值作为该数据集的损失值,其中,根据每个数据集的损失值计算得到损失函数的总损失值。
进一步地,每个数据集包含一个或多个任务标签,若所述数据集的任务标签数量为多个,则将多个任务标签对应的多个任务的损失值之和作为该数据集的损失值,其中,根据每个数据集的损失值计算得到损失函数的总损失值。
进一步地,所述总损失值为多个数据集对应的所有任务的损失值之和。
进一步地,多个数据集中的每个数据集对应的损失函数乘以c,其中,c为每个数据集占所有数据集的比重。进一步地,所述损失函数中还包含动态系数,其中,所述动态系数可根据当前模型训练自适应调整。
进一步地,将所述多个数据集按照第一顺序加载入数据加载器,并将加载入数据加载器中的数据集按照预设批数量进行划分。
进一步地,判断是否还存在下一批数据,其中,所述下一批数据为需要从数据加载器输出的数据;若有,则继续输出下一批数据。
进一步地,若没有,则所述多个数据集按照第二顺序加载入数据加载器,其中,所述第一顺序和所述第二顺序不同。
进一步地,将所述多个数据集中的部分数据集进行合并,得到多个合并后的数据集;将所述多个合并后的数据集逐个输入初始多任务模型。
进一步地,将带有相同任务标签的数据集进行合并。
进一步地,将带有相同任务标签比例大于合并阈值的数据集进行合并。
进一步地,将所述多个数据集中的部分数据集进行拆分,得到多个拆分后的数据集;将所述多个拆分后的数据集逐个输入初始多任务模型。
进一步地,将部分数据集按照任务标签随机拆分得到拆分后的数据集。根据本公开的另一个方面,还提供以下技术方案:
进一步地,包括:
获取待处理数据,其中,所述待处理数据包含多个数据集,每个数据集包含一个或多个任务标签;
利用前述的方法训练得到训练后的多任务模型对所述待处理数据进行处理,得到处理结果。
根据本公开的另一个方面,还提供以下技术方案:
第一获取模块,用于获取多个数据集;
输入装置,用于将所述多个数据集逐个输入初始多任务模型;其中,所述初始多任务模型包括共享参数和任务参数,所述共享参数为多个任务共有的模型参数,所述任务参数为多个任务中每个任务各自独有的模型参数;
训练模块,用于基于所述初始多任务模型的输出结果训练所述初始多任务模型,得到训练后的多任务模型。
根据本公开的另一个方面,还提供以下技术方案:
一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机上运行时,使得所述计算机执行如上述的模型训练方法。
根据本公开的另一个方面,还提供以下技术方案:
一种电子设备,包括处理器和存储器,所述存储器存储有计算机程序,其特征在于,所述处理器通过调用所述计算机程序,使得所述计算机执行上述任一方法中所述的步骤。
本公开实施例提供了模型训练方法、装置、存储介质及电子设备。其中,该模型训练方法包括:获取多个数据集;将多个数据集逐个输入初始多任务模型;其中,初始多任务模型包含模型参数,模型参数包含共享参数和任务参数;其中,共享参数为初始多任务模型中多个任务共有的模型参数,任务参数为初始多任务模型中多个任务中每个任务独有的模型参数;基于初始多任务模型的输出结果调整模型参数,得到训练后的多任务模型。该模型训练方法通过对数据集的逐个学习,每次只利用来自一组数据集的标签,只计算这组数据集所对应任务的学习损失值的方法,实现了同一多任务模型中,不同的数据依次训练不同的任务分支,解决了通常情况下,只能利用同时包含所有任务标签的数据集训练多任务学习模型的情况,极大地提升了可用数据的数量,也解决了逐任务训练的方法中,大量重复输入同一个或者数据集,比如一张图片,解决了模型训练效率低的问题。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为根据本公开一个实施例的多任务处理模型的训练方法的流程示意图;
图2为根据本公开另一个实施例的多任务处理模型的训练方法的流程示意图;
图3为根据本公开一个实施例的多任务处理模型的数据输入流程示意图;
图4为根据本公开一个实施例的多任务处理装置的结构示意图;
图5为根据本公开一个实施例的多任务处理模型训练的硬件装置的结构示意图。
具体实施方式
以下通过特定的具体实例说明本公开的实施方式,本领域技术人员可由本说明书所揭露的内容轻易地了解本公开的其他优点与功效。显然,所描述的实施例仅仅是本公开一部分实施例,而不是全部的实施例。本公开还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本公开的精神下进行各种修饰或改变。需说明的是,在不冲突的情况下,以下实施例及实施例中的特征可以相互组合。基于本公开中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本公开保护的范围。
需要说明的是,下文描述在所附权利要求书的范围内的实施例的各种方面。应显而易见,本文中所描述的方面可体现于广泛多种形式中,且本文中所描述的任何特定结构及/或功能仅为说明性的。基于本公开,所属领域的技术人员应了解,本文中所描述的一个方面可与任何其它方面独立地实施,且可以各种方式组合这些方面中的两者或两者以上。举例来说,可使用本文中所阐述的任何数目个方面来实施设备及/或实践方法。另外,可使用除了本文中所阐述的方面中的一或多者之外的其它结构及/或功能性实施此设备及/或实践此方法。
还需要说明的是,以下实施例中所提供的图示仅以示意方式说明本公开的基本构想,图式中仅显示与本公开中有关的组件而非按照实际实施时的组件数目、形状及尺寸绘制,其实际实施时各组件的型态、数量及比例可为一种随意的改变,且其组件布局型态也可能更为复杂。
另外,在以下描述中,提供具体细节是为了便于透彻理解实例。然而,所属领域的技术人员将理解,可在没有这些特定细节的情况下实践所述方面。
由于现有技术中利用人工智能模型进行任务处理时,对于每个单独的任务需要单独训练一个模型。当有多个任务时,需要单独训练一个应用于该任务的模型,而在这多个任务的训练为同一个训练样本时,需要将同一个训练样本分别输入至多个任务对应的多个模型,进行多次计算。因此,如何在训练多个任务处理模型时避免人工智能模型的大量重复计算成为了一些应用领域亟待解决的难题。
请参阅图1,图1为本申请实施例提供的模型训练方法的流程示意图。本申请实施例提供的模型训练方法的具体流程可以如下:
S102:获取多个数据集。
其中,数据集被分为M个,Di为其中的一个数据集,其中,0<i≤M,i为正整数。Di包含一个或多个任务标签,例如,D2包含了识别人像、动物、风景三个不同的任务。所有数据集所包含的任务大于或者等于人像、动物、风景三个不同的任务,还可以包括其他任务,例如,识别建筑物、植物等。这些数据集是用来训练多任务模型的,并且,样本数据可以是人工标注的,也可以是通过其他方式获得的,比如通过其他神经网络模型来标注的,在此不做限定。
其中,Di还可以被划为数据量m1,m2,…,mn,所有数据量形成Di的总数据量N。
本实施例中的多个任务可以是能够对于同一数据进行不同处理的任务,并且多个任务的处理结果之间可以互不影响。例如,本实施例中多个任务可以包括用于识别人像第一任务、用于识别动物的第二任务、用于识别风景的第三任务等。
S104:将多个数据集逐个输入初始多任务模型。
其中,初始多任务模型包括共享参数和任务参数,共享参数为多个任务共有的模型参数,任务参数为多个任务中每个任务各自独有的模型参数。
其中,将多个数据集逐个输入初始多任务模型中,每个数据集Di分别对初始多任务模型进行训练;利用数据集Di训练过程中,Di的数据量m1,m2,…,mn会被输入到初始多任务模型中,总数据量为N。
本公开实施例中,将多个数据集按照第一顺序加载入数据加载器,并将加载入数据加载器中的数据集按照预设批数量进行划分。其中,图2示出了训练过程,将多个数据集进行编号,如i=1、2、…、n,在编号之后,按照第一顺序,如数据集D1、数据集D2、数据集D3、…、数据集Dn,依照该顺序逐个将数据集输入到模型中。数据集D1对应图2中的数据集1,数据集D2对应图2中的数据集2,其他同理;此外,第一顺序不作具体限定,可以是多种排序方式。
多个数据集包含了所有的任务标签,即将所有的数据集输入初始模型之后,所有任务都进行了学习训练,并得到了所有的损失函数值。
请参见图3,将加载入数据加载器中的数据集按照预设批数量进行划分之后,判断是否还存在下一批数据,其中,下一批数据为需要从数据加载器输出的数据;如果有,则继续输出下一批数据。从而能够将所有的数据集都输入到初始多任务模型中去,保证所有的数据都能被利用到。
如果判断不存在下一批数据,则多个数据集按照第二顺序加载入数据加载器,其中,第一顺序和第二顺序不同。第二顺序可以是按照数据集D2、数据集D5、数据集D3、…、数据集Dn,或者第二顺序是随机产生的,不对其顺序做限定。
在按照第二顺序将所有的数据集都输入到初始多任务模型之后,可以再按照第三顺序将所有的数据集再输入到初始多任务模型中去,第三顺序和第一顺序以及第二顺序都不相同,第三顺序可以是随机产生的,也可以是按照某种特定的顺序。
本公开实施例中,还可以将多个数据集中的部分数据集进行合并,得到多个合并后的数据集;
将多个合并后的数据集逐个输入初始多任务模型。其中合并后的数据集是按照任务标签进行合并的,合并之后的数据集包含一个或多个任务标签。
其中,合并方式可以是将带有相同任务标签比例大于合并阈值的数据集进行合并。比如,数据集D1包含10个任务标签,数据集D2包含8个任务标签,这两个数据集之间相同的任务标签数量为6个,相同任务标签占D1总任务标签数量的60%,相同任务标签占D2总任务标签数量的75%,取最小的相同任务标签占数据集任务标签数量的占比60%,且如果合并阈值此处设置为50%,则由于带有相同任务标签比例大于合并阈值,所以讲数据集D1和D2进行合并。
通过合并之后,原本的多个数据集的个数会减小,从而训练的次数可以减小,但每个任务仍然会被训练,并得到了每个任务对应的损失函数值。
S106:基于初始多任务模型的输出结果训练初始多任务模型,得到训练后的多任务模型。
本公开实施例中,根据初始多任务模型的输出结果计算损失函数的损失值,并根据损失值对多任务模型的模型参数进行调整。模型训练过程中,将所有的数据集都输入到初始都任务模型之后,会得到每个数据集对应的任务的输出结果,根据所有的这些结果训练初始多任务模型,更新共享参数和任务参数,当模型收敛之后,从而得到训练后的多任务模型。
本公开实施例中,根据初始多任务模型的输出结果计算损失函数的总损失值,并根据损失值对多任务模型的模型参数进行调整。
每个数据集包含一个或多个任务标签,在每个数据集的任务标签数量为一个时,根据损失函数的损失值更新共享参数和多个任务对应的任务参数。
如果数据集任务标签数据为一个,则该种情况下计算损失函数的损失值比较简单,在计算某个数据集对应的损失值时,只需要计算对应的某个任务标签的损失值。例如,数据集D2包含的任务标签只有人像识别任务,则只需要将数据集D2输入到数据加载器中,然后所有的D2中的数据会流向人像识别任务,在计算该任务的损失函数时,只需要计算人脸识别任务的损失值即可。
本公开实施例中,在数据集的任务标签数量为多个时,根据多个任务标签分别得到的累积损失值更新共享参数和多个任务对应的任务参数。
如果数据集任务标签数据为多个,则该种情况下计算损失函数的损失值过程稍微多一些,在计算某个数据集对应的损失值时,需要计算所有的任务标签的损失值。例如,数据集D2包含的任务标签有人像识别任务、动物识别任务、风景识别任务三个任务,则需要将数据集D2输入到数据加载器中,然后所有的D2中的数据会流向人像识别任务、动物识别任务、风景识别任务三个任务,在计算这三个任务的损失函数时,需要计算这三个任务标签分别得到的累积损失值。在计算完成之后,更新共享参数和多个任务对应的任务参数。
本公开实施例中,总损失值为多个数据集对应的所有任务的损失值之和。损失函数表示为:
其中,LMTL代表多任务模型的联合损失,Li,j代表数据集Di中任务j的损失值,ci代表数据集Di在所有数据集中所占的比重。
例如,当数据集D2含有三个任务标签,即人像识别任务、动物识别任务、风景识别任务三个任务标签,则LMTL为D2中人像识别任务、动物识别任务、风景识别任务的损失函数之和,即LMTL=c1L2J1+c1L2J2+c1L2J3。
本公开实施例中,损失函数中还包含动态系数,其中,动态系数可根据当前模型训练自适应调整。例如,可以根据上一循环中对应任务在验证集上的损失值大小对当前循环中的损失值的大小进行动态调整。
本申请实施例提供的模型训练方法的具体流程可以如下:
获取待处理数据;
利用前述的方法训练得到的训练后的多任务模型对待处理数据进行处理,得到多个任务处理结果。
其中,待处理数据并不是用来训练初始多任务模型的,此时模型已经训练完毕,共享参数和任务都已经设置完毕;将需要处理的待处理数据输入到模型中,即可根据训练后的多任务模型进行数据处理,得到多个任务处理结果。
下面将结合附图4,对本申请实施例提供的模型训练装置进行详细介绍。本申请实施例提供的模型训练装置可以如下:
一种模型训练装置,包括:
第一获取模块301,用于获取多个数据集;
在可选的实现方式中,数据集被分为M个,Di为其中的一个数据集,其中,0<i≤M,i为正整数。Di包含一个或多个任务标签,例如,D2包含了识别人像、动物、风景三个不同的任务。所有数据集所包含的任务大于或者等于人像、动物、风景三个不同的任务,还可以包括其他任务,例如,识别建筑物、植物等。这些数据集是用来训练多任务模型的,并且,样本数据可以是人工标注的,也可以是通过其他方式获得的,比如通过其他神经网络模型来标注的,在此不做限定。
其中,Di被划为数据量m1,m2,…,mn和总数据量N。
在可选的实现方式中,多个任务可以是能够对于同一数据进行不同处理的任务,并且多个任务的处理结果之间可以互不影响。例如,本实施例中多个任务可以包括用于识别人像第一任务、用于识别动物的第二任务、用于风景的第三任务等。
输入模块302,用于将多个数据集逐个输入初始多任务模型;其中,初始多任务模型包括共享参数和任务参数,共享参数为多个任务共有的模型参数,任务参数为多个任务中每个任务各自独有的模型参数;
其中,将多个数据集逐个输入初始多任务模型中,每个数据集Di分别对初始多任务模型进行训练;利用数据集Di训练过程中,Di的数据量m1,m2,…,mn会被输入到初始多任务模型中,总数据量为N。
在可选的实现方式中,将所述多个数据集按照第一顺序加载入数据加载器,并将加载入数据加载器中的数据集按照预设批数量进行划分。其中,图2示出了训练过程图,将多个数据集进行编号,如数据集Di=1、2、…、n,在编号之后,按照第一顺序,如数据集D1、数据集D2、数据集D3、…、数据集Dn,依照该顺序逐个将数据集输入到模型中。
多个数据集包含了所有的任务标签,即将所有的数据集输入初始模型之后,所有任务都进行了学习训练,并得到了所有的损失函数值。
在可选的实现方式中,将加载入数据加载器中的数据集按照预设批数量进行划分之后,判断是否还存在下一批数据,其中,所述下一批数据为需要从数据加载器输出的数据;如果有,则继续输出下一批数据。从而能够将所有的数据集都输入到初始多任务模型中去,保证所有的数据都能被利用到。
在可选的实现方式中,如果判断不存在下一批数据,则多个数据集按照第二顺序加载入数据加载器,其中,第一顺序和第二顺序不同。第二顺序可以是按照数据集D2、数据集D5、数据集D3、…、数据集Dn,或者第二顺序是随机产生的,不对其顺序做限定。
在按照第二顺序将所有的数据集都输入到初始多任务模型之后,可以再按照第三顺序将所有的数据集再输入到初始多任务模型中去,第三顺序和第一顺序以及第二顺序都不相同,第三顺序可以是随机产生的。
训练模块303,用于基于初始多任务模型的输出结果训练初始多任务模型,得到训练后的多任务模型。
在可选的实现方式中,根据初始多任务模型的输出结果计算损失函数的损失值,并根据损失值对多任务模型的模型参数进行调整调整。模型训练过程中,将所有的数据集都输入到初始都任务模型之后,会得到每个数据集对应的任务的输出结果,根据所有的这些结果训练初始多任务模型,更新共享参数和任务参数,当模型收敛之后,从而得到训练后的多任务模型。
在可选的实现方式中,数据集包含一个或多个任务标签,在每个数据集的任务标签数量为一个时,根据损失函数的损失值更新共享参数和多个任务对应的任务参数。
如果数据集任务标签数据为一个,则该种情况下计算损失函数的损失值比较简单,在计算某个数据集对应的损失值时,只需要计算对应的某个任务标签的损失值。例如,数据集D2包含的任务标签只有人像识别任务,则只需要将数据集D2输入到数据加载器中,然后所有的D2中的数据会流向人像识别任务,在计算该任务的损失函数时,只需要计算人脸识别任务的损失值即可。
在可选的实现方式中,还包括,第二获取模块,用于获取待处理数据;
在可选的实现方式中,待处理数据并不是用来训练初始多任务模型的,此时模型已经训练完毕,共享参数和任务都已经设置完毕;将需要处理的待处理数据输入到模型中,即可根据训练后的多任务模型进行数据处理。
处理模块,用于利用前述的方法训练得到训练后的多任务模型对待处理数据进行处理,得到处理结果。
本申请实施例提供的电子设备40可以如下:
一种电子设备,如图5,包括处理器41和存储器42,存储器42存储有计算机程序,处理器41通过调用计算机程序,用于执行前述的模型训练方法。
本申请还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现上述方法的步骤。其中,计算机可读存储介质可以包括但不限于任何类型的盘,包括软盘、光盘、DVD、CD-ROM、微型驱动器以及磁光盘、ROM、RAM、EPROM、EEPROM、DRAM、VRAM、闪速存储器设备、磁卡或光卡、纳米系统(包括分子存储器IC),或适合于存储指令和/或数据的任何类型的媒介或设备。
本申请实施例还提供一种计算机程序产品,该计算机程序产品包括存储计算机程序的非瞬时性计算机可读存储介质,该计算机程序可操作来使计算机执行如上述方法实施例中记载的任何一种信息共享方法的部分或全部步骤。
本领域的技术人员可以清楚地了解到本申请的技术方案可借助软件和/或硬件来实现。本说明书中的“单元”和“模块”是指能够独立完成或与其他部件配合完成特定功能的软件和/或硬件,其中硬件例如可以是现场可编程门阵列(Field-Programmable GateArray,FPGA)、集成电路(Integrated Circuit,IC)等。
需要说明的是,对于前述的各方法实施例,为了简单描述,故将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本申请并不受所描述的动作顺序的限制,因为依据本申请,某些步骤可以采用其他顺序或者同时进行。其次,本领域技术人员也应该知悉,说明书中所描述的实施例均属于优选实施例,所涉及的动作和模块并不一定是本申请所必须的。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述的部分,可以参见其他实施例的相关描述。
在本申请所提供的几个实施例中,应该理解到,所揭露的装置,可通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些服务接口,装置或单元的间接耦合或通信连接,可以是电性或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储器中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储器中,包括若干指令用以使得一台计算机设备(可为个人计算机、服务器或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储器包括:U盘、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、移动硬盘、磁碟或者光盘等各种可以存储程序代码的介质。
本领域普通技术人员可以理解上述实施例的各种方法中的全部或部分步骤是可以通进程序来指令相关的硬件来完成,该程序可以存储于一计算机可读存储器中,存储器可以包括:闪存盘、只读存储器(Read-Only Memory,ROM)、随机存取器(Random AccessMemory,RAM)、磁盘或光盘等。
以上所述者,仅为本公开的示例性实施例,不能以此限定本公开的范围。即但凡依本公开教导所作的等效变化与修饰,皆仍属本公开涵盖的范围内。本领域技术人员在考虑说明书及实践这里的公开后,将容易想到本公开的其它实施方案。本申请旨在涵盖本公开的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本公开的一般性原理并包括本公开未记载的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本公开的范围和精神由权利要求限定。
Claims (17)
1.一种模型训练方法,包括:
获取多个数据集;
将多个数据集逐个输入初始多任务模型;其中,所述初始多任务模型包含模型参数,所述模型参数包含共享参数和任务参数;其中,所述共享参数为所述初始多任务模型中多个任务共有的模型参数,所述任务参数为所述初始多任务模型中多个任务中每个任务独有的模型参数;
基于所述初始多任务模型的输出结果调整所述模型参数,得到训练后的多任务模型。
2.根据权利要求1所述的训练方法,其特征在于,基于所述初始多任务模型的输出结果调整所述模型参数,得到训练后的多任务模型,包括:
根据所述初始多任务模型的输出结果计算损失函数的总损失值,并根据所述损失值对所述多任务模型的模型参数进行调整。
3.根据权利要求2所述的训练方法,其特征在于,每个数据集包含一个或多个任务标签,若每个数据集的任务标签数量为一个,则将该任务标签对应的任务的损失值作为该数据集的损失值,其中,根据每个数据集的损失值计算得到损失函数的总损失值。
4.根据权利要求2所述的训练方法,其特征在于,每个数据集包含一个或多个任务标签,若所述数据集的任务标签数量为多个,则将多个任务标签对应的多个任务的损失值之和作为该数据集的损失值,其中,根据每个数据集的损失值计算得到损失函数的总损失值。
5.根据权利要求2所述的训练方法,其特征在于,所述总损失值为多个数据集对应的所有任务的损失值之和。
6.根据权利要求5所述的训练方法,其特征在于,多个数据集中的每个数据集对应的损失函数乘以c,其中,c为每个数据集占所有数据集的比重。
7.根据权利要求5所述的训练方法,其特征在于,所述损失函数中还包含动态系数,其中,所述动态系数可根据当前模型训练自适应调整。
8.根据权利要求1所述的训练方法,其特征在于,所述获取多个数据集包括:
将所述多个数据集按照第一顺序加载入数据加载器,并将加载入数据加载器中的数据集按照预设批数量进行划分。
9.根据权利要求8所述的训练方法,其特征在于,所述将加载入数据加载器中的数据集按照预设批数量进行划分之后,还包括:
判断是否还存在下一批数据,其中,所述下一批数据为需要从数据加载器输出的数据;
若有,则继续输出下一批数据。
10.根据权利要求9所述的训练方法,其特征在于,
若没有,则所述多个数据集按照第二顺序加载入数据加载器,其中,所述第一顺序和所述第二顺序不同。
11.根据权利要求1所述的训练方法,其特征在于,将所述多个数据集中的部分数据集进行合并,得到多个合并后的数据集;
将所述多个合并后的数据集逐个输入初始多任务模型。
12.根据权利要求11所述的训练方法,其特征在于,所述将所述多个数据集中的部分数据集进行合并,包括:
将带有相同任务标签的数据集进行合并。
13.根据权利要求12所述的训练方法,其特征在于,将带有相同任务标签比例大于合并阈值的数据集进行合并。
14.根据权利要求1-13任一项所述的训练方法,其特征在于,包括:
获取待处理数据,其中,所述待处理数据包含多个数据集,每个数据集包含一个或多个任务标签;
利用权所述训练得到训练后的多任务模型对所述待处理数据进行处理,得到处理结果。
15.一种模型训练装置,包括:
第一获取模块,用于获取多个数据集;
输入模块,用于将所述多个数据集逐个输入初始多任务模型;其中,所述初始多任务模型包括共享参数和任务参数,所述共享参数为多个任务共有的模型参数,所述任务参数为多个任务中每个任务各自独有的模型参数;
训练模块,用于基于所述初始多任务模型的输出结果训练所述初始多任务模型,得到训练后的多任务模型。
16.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,当所述计算机程序在计算机上运行时,使得所述计算机执行如权利要求1至13任一项所述的模型训练方法。
17.一种电子设备,包括处理器和存储器,所述存储器存储有计算机程序,其特征在于,所述处理器通过调用所述计算机程序,用于执行如权利要求1至13任一项所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110412115.1A CN113516239A (zh) | 2021-04-16 | 2021-04-16 | 模型训练方法、装置、存储介质及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110412115.1A CN113516239A (zh) | 2021-04-16 | 2021-04-16 | 模型训练方法、装置、存储介质及电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113516239A true CN113516239A (zh) | 2021-10-19 |
Family
ID=78062528
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110412115.1A Withdrawn CN113516239A (zh) | 2021-04-16 | 2021-04-16 | 模型训练方法、装置、存储介质及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113516239A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114356540A (zh) * | 2021-10-30 | 2022-04-15 | 腾讯科技(深圳)有限公司 | 一种参数更新方法、装置、电子设备和存储介质 |
CN114821538A (zh) * | 2022-05-19 | 2022-07-29 | 北京地平线机器人技术研发有限公司 | 一种多任务模型的训练方法及装置 |
CN116756579A (zh) * | 2023-08-22 | 2023-09-15 | 腾讯科技(深圳)有限公司 | 大语言模型的训练方法及基于大语言模型的文本处理方法 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109447259A (zh) * | 2018-09-21 | 2019-03-08 | 北京字节跳动网络技术有限公司 | 多任务处理及多任务处理模型训练方法、装置和硬件装置 |
WO2019100724A1 (zh) * | 2017-11-24 | 2019-05-31 | 华为技术有限公司 | 训练多标签分类模型的方法和装置 |
CN110188358A (zh) * | 2019-05-31 | 2019-08-30 | 北京神州泰岳软件股份有限公司 | 自然语言处理模型的训练方法及装置 |
CN111027428A (zh) * | 2019-11-29 | 2020-04-17 | 北京奇艺世纪科技有限公司 | 一种多任务模型的训练方法、装置及电子设备 |
CN111353541A (zh) * | 2020-03-03 | 2020-06-30 | 浙江新再灵科技股份有限公司 | 一种多任务模型的训练方法 |
WO2020143304A1 (zh) * | 2019-01-07 | 2020-07-16 | 平安科技(深圳)有限公司 | 损失函数优化方法、装置、计算机设备及存储介质 |
-
2021
- 2021-04-16 CN CN202110412115.1A patent/CN113516239A/zh not_active Withdrawn
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019100724A1 (zh) * | 2017-11-24 | 2019-05-31 | 华为技术有限公司 | 训练多标签分类模型的方法和装置 |
CN109447259A (zh) * | 2018-09-21 | 2019-03-08 | 北京字节跳动网络技术有限公司 | 多任务处理及多任务处理模型训练方法、装置和硬件装置 |
WO2020143304A1 (zh) * | 2019-01-07 | 2020-07-16 | 平安科技(深圳)有限公司 | 损失函数优化方法、装置、计算机设备及存储介质 |
CN110188358A (zh) * | 2019-05-31 | 2019-08-30 | 北京神州泰岳软件股份有限公司 | 自然语言处理模型的训练方法及装置 |
CN111027428A (zh) * | 2019-11-29 | 2020-04-17 | 北京奇艺世纪科技有限公司 | 一种多任务模型的训练方法、装置及电子设备 |
CN111353541A (zh) * | 2020-03-03 | 2020-06-30 | 浙江新再灵科技股份有限公司 | 一种多任务模型的训练方法 |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114356540A (zh) * | 2021-10-30 | 2022-04-15 | 腾讯科技(深圳)有限公司 | 一种参数更新方法、装置、电子设备和存储介质 |
CN114821538A (zh) * | 2022-05-19 | 2022-07-29 | 北京地平线机器人技术研发有限公司 | 一种多任务模型的训练方法及装置 |
CN116756579A (zh) * | 2023-08-22 | 2023-09-15 | 腾讯科技(深圳)有限公司 | 大语言模型的训练方法及基于大语言模型的文本处理方法 |
CN116756579B (zh) * | 2023-08-22 | 2023-12-12 | 腾讯科技(深圳)有限公司 | 大语言模型的训练方法及基于大语言模型的文本处理方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111414353A (zh) | 智能化的缺失数据填充方法、装置及计算机可读存储介质 | |
CN113516239A (zh) | 模型训练方法、装置、存储介质及电子设备 | |
EP3961384A1 (en) | Automatic derivation of software engineering artifact attributes from product or service development concepts | |
CN110689136B (zh) | 一种深度学习模型获得方法、装置、设备及存储介质 | |
CN113705775A (zh) | 一种神经网络的剪枝方法、装置、设备及存储介质 | |
CN110264274A (zh) | 客群划分方法、模型生成方法、装置、设备及存储介质 | |
CN113032116B (zh) | 任务时间预测模型的训练方法、任务调度方法及相关装置 | |
CN113408570A (zh) | 一种基于模型蒸馏的图像类别识别方法、装置、存储介质及终端 | |
CN113128419A (zh) | 一种障碍物识别方法和装置、电子设备及存储介质 | |
CN112785005A (zh) | 多目标任务的辅助决策方法、装置、计算机设备及介质 | |
CN110532448B (zh) | 基于神经网络的文档分类方法、装置、设备及存储介质 | |
CN112287950A (zh) | 特征提取模块压缩方法、图像处理方法、装置、介质 | |
CN115146775B (zh) | 边缘设备推理加速方法、装置和数据处理系统 | |
CN116128044A (zh) | 一种模型剪枝方法、图像处理方法及相关装置 | |
CN115423031A (zh) | 一种模型训练的方法以及相关装置 | |
CN114723455A (zh) | 业务处理方法、装置、电子设备和存储介质 | |
CN113408934A (zh) | 催收任务分配方法、装置、设备、存储介质、程序产品 | |
CN109840926B (zh) | 一种图像生成方法、装置及设备 | |
CN113408571A (zh) | 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端 | |
CN112905792A (zh) | 基于非文本场景的文本聚类方法、装置、设备及存储介质 | |
CN111400050A (zh) | 一种分配资源执行任务的方法及装置 | |
CN112230911A (zh) | 模型部署方法、装置、计算机设备和存储介质 | |
CN110288091A (zh) | 参数学习方法、装置、终端设备及可读存储介质 | |
CN112101394B (zh) | 供应商分域部署方法、装置、计算设备及计算机存储介质 | |
CN116363262B (zh) | 图像生成方法、装置及电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
WW01 | Invention patent application withdrawn after publication |
Application publication date: 20211019 |
|
WW01 | Invention patent application withdrawn after publication |