CN117952180A - 多任务模型的训练方法、数据预测方法、装置及设备 - Google Patents
多任务模型的训练方法、数据预测方法、装置及设备 Download PDFInfo
- Publication number
- CN117952180A CN117952180A CN202311363143.4A CN202311363143A CN117952180A CN 117952180 A CN117952180 A CN 117952180A CN 202311363143 A CN202311363143 A CN 202311363143A CN 117952180 A CN117952180 A CN 117952180A
- Authority
- CN
- China
- Prior art keywords
- model
- subtask
- value
- measure
- values
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 158
- 238000000034 method Methods 0.000 title claims abstract description 117
- 238000012545 processing Methods 0.000 claims abstract description 57
- 238000012795 verification Methods 0.000 claims description 49
- 238000005259 measurement Methods 0.000 claims description 40
- 230000008569 process Effects 0.000 claims description 25
- 230000015654 memory Effects 0.000 claims description 24
- 230000000875 corresponding effect Effects 0.000 claims description 20
- 230000008859 change Effects 0.000 claims description 15
- 238000004590 computer program Methods 0.000 claims description 11
- 230000002596 correlated effect Effects 0.000 claims description 9
- 238000013507 mapping Methods 0.000 claims description 7
- 238000012216 screening Methods 0.000 claims description 5
- 230000004044 response Effects 0.000 claims description 4
- 230000006870 function Effects 0.000 description 20
- 230000011218 segmentation Effects 0.000 description 15
- 238000004891 communication Methods 0.000 description 8
- 238000012360 testing method Methods 0.000 description 8
- 238000010586 diagram Methods 0.000 description 6
- 230000000694 effects Effects 0.000 description 6
- 230000001360 synchronised effect Effects 0.000 description 6
- 238000013473 artificial intelligence Methods 0.000 description 5
- 238000004364 calculation method Methods 0.000 description 5
- 238000013461 design Methods 0.000 description 4
- 238000003062 neural network model Methods 0.000 description 4
- 230000000052 comparative effect Effects 0.000 description 3
- 238000002474 experimental method Methods 0.000 description 3
- 238000010801 machine learning Methods 0.000 description 3
- 230000001276 controlling effect Effects 0.000 description 2
- 230000007423 decrease Effects 0.000 description 2
- 230000003247 decreasing effect Effects 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000001902 propagating effect Effects 0.000 description 2
- 230000033764 rhythmic process Effects 0.000 description 2
- 238000013515 script Methods 0.000 description 2
- 238000010200 validation analysis Methods 0.000 description 2
- NAWXUBYGYWOOIX-SFHVURJKSA-N (2s)-2-[[4-[2-(2,4-diaminoquinazolin-6-yl)ethyl]benzoyl]amino]-4-methylidenepentanedioic acid Chemical compound C1=CC2=NC(N)=NC(N)=C2C=C1CCC1=CC=C(C(=O)N[C@@H](CC(=C)C(O)=O)C(O)=O)C=C1 NAWXUBYGYWOOIX-SFHVURJKSA-N 0.000 description 1
- 101000697856 Rattus norvegicus Bile acid-CoA:amino acid N-acyltransferase Proteins 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000003542 behavioural effect Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000033228 biological regulation Effects 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 230000001010 compromised effect Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000011084 recovery Methods 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000003595 spectral effect Effects 0.000 description 1
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请提供了一种多任务模型的训练方法、数据预测方法、装置及设备。训练方法包括:获取N个子任务模型的N个测度值,一个子任务模型对应一个测度值,一个测度值用于表征一个子任务模型的收敛进度;针对每个子任务模型,基于子任务模型的测度值,确定子任务模型的损失权重值,损失权重值与测度值表征的收敛进度负相关;通过多任务模型对训练集进行处理,得到N个子任务模型的N个预测损失值,一个子任务模型对应一个预测损失值;基于N个子任务模型的N个损失权重值,对N个预测损失值进行加权求和处理,得到多任务模型的总损失,并基于多任务模型的总损失更新多任务模型的参数。
Description
技术领域
本申请涉及人工智能技术,尤其涉及一种多任务模型的训练方法、数据预测方法、装置及设备。
背景技术
多任务模型是指一个模型中同时包含多个子任务模型的神经网络模型。每个子任务模型均有其对应的训练目标。例如:一个文本处理模型可包括两个子任务模型。两个子任务模型的训练目标分别为对文本进行多音字消岐、对文本进行分词。训练目标差异,可能导致模型训练的过程中各子任务模型的收敛进度存在差异,影响多任务模型训练的训练效果。
发明内容
本申请提供一种多任务模型的训练方法、数据预测方法、装置及设备,能够实现多任务模型的各子任务模型同步收敛。
本申请的技术方案是这样实现的:
本申请提供一种多任务模型的训练方法,多任务模型包括至少两个子任务模型,训练方法包括:
获取N个子任务模型的N个测度值,一个子任务模型对应一个测度值,一个测度值用于表征一个子任务模型的收敛进度;
针对每个子任务模型,基于子任务模型的测度值,确定子任务模型的损失权重值,损失权重值与测度值表征的收敛进度负相关;
通过多任务模型对训练集进行处理,得到N个子任务模型的N个预测损失值,一个子任务模型对应一个预测损失值;
基于N个子任务模型的N个损失权重值,对N个预测损失值进行加权求和处理,得到多任务模型的总损失,并基于多任务模型的总损失,更新多任务模型的参数。
本申请提供一种数据预测方法,训练好的多任务模型包括至少两个子任务模型,预测方法包括:
获取待处理数据;
通过训练好的多任务模型对待处理数据进行预测处理,得到每个子任务模型输出的预测结果;
其中,训练好的多任务模型是通过本申请提供的多任务模型的训练方法训练得到的;
基于所述N个预测结果执行数据处理任务。
本申请提供一种多任务模型的训练装置,多任务模型包括N个子任务模型,N为大于1的整数,多任务模型包括至少两个子任务模型,训练装置包括:
获取模块,用于获取N个子任务模型的N个测度值,一个子任务模型对应一个测度值,一个测度值用于表征一个子任务模型的收敛进度;
确定模块,用于针对每个子任务模型,基于子任务模型的测度值,确定子任务模型的损失权重值,损失权重值与测度值表征的收敛进度负相关;
第一处理模块,用于通过多任务模型对训练集进行处理,得到N个子任务模型的N个预测损失值,一个子任务模型对应一个预测损失值;
第二处理模块,用于基于N个子任务模型的N个损失权重值,对N个预测损失值进行加权求和处理,得到多任务模型的总损失,并基于多任务模型的总损失,更新多任务模型的参数。
本申请提供一种数据预测装置,训练好的多任务模型包括N个子任务模型,N为大于1的整数,所述数据预测装置包括:
获取模块,用于获取待处理数据;
预测模块,用于通过所述训练好的多任务模型对所述待处理数据进行预测处理,得到N个子任务模型输出的N个预测结果,一个子任务模型对应一个预测结果;其中,所述训练好的多任务模型是通过本申请提供的多任务模型的训练方法训练得到的;
执行模块,用于基于所述N个预测结果执行数据处理任务。
本申请提供一种电子设备,所述电子设备包括:存储器,用于存储计算机可执行指令;处理器,用于执行所述存储器中存储的计算机可执行指令时,实现本申请提供的多任务模型的训练方法,或者数据预测方法。
本申请提供一种计算机可读存储介质,存储有计算机程序或者计算机可执行指令,所述计算机程序或计算机可执行指令被处理器执行时实现本申请提供的多任务模型的训练方法,或者数据预测方法。
本申请提供一种计算机程序产品,包括计算机程序或计算机可执行指令,所述计算机程序或计算机可执行指令被处理器执行时实现本申请提供的多任务模型的训练方法,或者数据预测方法。
本申请具有以下有益效果:
本申请提供的训练方法可以根据每个子任务模型在训练过程中产生的测度值,动态调整各个子任务模型的损失权重,从而基于各个子任务模型的损失权重控制各个子任务模型的收敛速度,使得各子任务模型的收敛进度尽可能的同步收敛,从而提高多任务模型的训练效果。
附图说明
图1是本申请实施例提供的系统架构图;
图2A-图2B是本申请实施例提供的电子设备的结构示意图;
图3A-图3E是本申请实施例提供的多任务模型的训练方法的流程示意图;
图4是本申请实施例提供的基于多任务模型的训练方法的预测方法的流程示意图;
图5是本申请实施例提供的韵律模型的示意图;
图6是本申请实施例提供的子任务模型的回收率与训练步数的关系曲线。
具体实施方式
为了使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请作进一步地详细描述,所描述的实施例不应视为对本申请的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本申请保护的范围。
在以下的描述中,所涉及的术语“第一\第二”仅仅是是区别类似的对象,不代表针对对象的特定排序,可以理解地,“第一\第二”在允许的情况下可以互换特定的顺序或先后次序,以使这里描述的本申请实施例能够以除了在这里图示或描述的以外的顺序实施。
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本申请实施例的目的,不是旨在限制本申请。
对本申请实施例进行进一步详细说明之前,对本申请实施例中涉及的名词和术语进行说明,本申请实施例中涉及的名词和术语适用于如下的解释。
1)多任务模型:包含多个子任务模型的神经网络模型,每个子任务模型有不同训练目标。神经网络模型训练时,这些子任务模型同时训练。
2)数据批(batch):模型训练时,往往是以数量N为单位对数据集进行分批,每一批样本(数据)同时进行计算。其中,N>=1的正整数。
3)迭代(epoch):训练集中所有数据都进行过一轮的迭代。
4)一步(step):对一批数据,输入模型结构,进行前向计算,损失计算,并进行损失回传,梯度更新,参数更新这些步骤,执行完毕后则一步训练完成。简言之,对一个batch的数据,进行模型训练的一个完整过程(前向推导->损失计算->损失回传->梯度更新->参数更新等)叫做一步。
5)模型的训练:至少需要训练集(training dataset)和验证集(validatedataset)。一般的,还额外需要一个测试集(test dataset)来进行模型测试。训练集是直接参数模型训练的数据集,epoch既是指对训练集的一遍完整训练;验证集是用于在训练中观察模型收敛进度,以及模型训练情况的数据集;测试集(test dataset)是在模型收敛后,用于作模型效果测试的数据集。一般的,验证集的数据规模,以及测试集的数据规模,都比训练集小的多。一般的,这三个数据集中的数据没有交集,各自包含不同的样本。
6)精确率(precision,P):多任务模型的常用测度(检验指标)之一,是指模型预测结果为真的样本中,实际标记值也是真的样本所占比率。
7)召回率(recall,R):机器学习,特别是多任务模型的常用测度之一,即对标记为真的样本,在模型推理结果中,推理结果也为真的样本所占的比率。
8)F1得分(Score):机器学习,特别多任务模型中,希望模型的精确率,召回率越高越好,但是当模型训练到一定阶段时,两者的变化方向趋向不同。为了得到最佳的模型,F1=2*(P*R)/(P+R);F1也是多任务模型常用测度之一。
9)损失函数(Loss Function):机器学习中用于计算“用预测值或者预测分布,来表达真实值或者真实分布”时带来的误差(或称损失)值的函数。常见损失函数包括L1损失、MSE损失、CE损失、KL散度损失等等。
10)响应于:用于表示所执行的操作所依赖的条件或者状态,当满足所依赖的条件或状态时,所执行的一个或多个操作可以是实时的,也可以具有设定的延迟;在没有特别说明的情况下,所执行的多个操作不存在执行先后顺序的限制。
11)损失值:用于表示模型的预测结果与实际结果(或称标记结果)之间的差异。模型的训练主要涉及前向传播(Forward Propagation)及反向传播(Back Propagation)两个过程,以包括输入层、隐藏层及输出层的多任务模型为例,前向传播处理是指依次通过输入层、隐藏层及输出层进行处理,最终得到预测结果;反向传播处理是指根据计算出的损失值依次传播至输出层、隐藏层及输入层,从而对各个层中的权重参数进行更新。
12)分词标记法:一般标识字在分词的位置:B表示词语开始,I为词中间位置的字,E为词语结束位置,S表示一个字单独成词,O表示其它字符。
13)文本韵律模型:即是对于给定的无标记文本,预测它的韵律标记:在什么位置,发生什么韵律。预测结果为一个跟文本长度的数字序列。在数字序列中,每个位置的值表示文本在这个位置上的汉字或者字符,预测它应该标记什么韵律。对于不标记韵律的字符位置,填充一个特殊值0,表示无韵律。
本申请实施例提供一种多任务模型的训练方法、数据预测方法、装置、电子设备、计算机可读存储介质及计算机程序产品,通过动态调整多任务模型中各子任务模型的收敛速度,达到平衡各个子任务模型的收敛进度的目的,使得训练后,每个子任务都得到或者近似得到理想期望值。
本申请实施例所提供的多任务模型的训练方法或数据预测方法,可以由终端独自实现;也可以由终端和服务器协同实现,例如终端独自承担下文的多任务模型的训练方法或数据预测方法,或者,终端向服务器发送针对多任务模型的训练请求,服务器根据接收的针对多任务模型的训练请求执行多任务模型的训练方法,并基于训练后的多任务模型针对待处理数据执行预测任务。
本申请实施例提供的电子设备可以是各种类型的终端设备或服务器,其中,服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器;终端可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表等,但并不局限于此。终端以及服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。
以服务器为例,例如可以是部署在云端的服务器集群,向用户开放人工智能云服务(AI as a Service,AIaaS),AIaaS平台会把几类常见的AI服务进行拆分,并在云端提供独立或者打包的服务,这种服务模式类似于一个AI主题商城,所有的用户都可以通过应用程序编程接口的方式来接入使用AIaaS平台提供的一种或者多种人工智能服务。
例如:其中的一种人工智能云服务可以为多任务模型的训练服务,即云端的服务器封装有本申请实施例提供的多任务模型的训练程序。用户通过终端调用云服务中的多任务模型的训练服务,以使部署在云端的服务器调用封装的多任务模型的训练程序,通过验证集对多任务模型进行验证处理,得到每个子任务模型的测度值;基于每个子任务模型的测度值,确定每个子任务模型的损失权重值;通过多任务模型对训练集进行处理,得到每个子任务模型的预测损失值;基于每个子任务模型的损失权重值,对每个子任务模型的预测损失值进行加权求和处理,得到多任务模型的总损失;基于多任务模型的总损失,更新多任务模型的参数,直至多任务模型收敛。训练好的多任务模型针对待处理数据执行预测任务(例如:韵律的标记、文本的分词等)。
参见图1,图1是本申请实施例提供的系统10的架构示意图,终端200通过网络300连接服务器100,其中,网络300可以是广域网或者局域网,又或者是二者的组合。
终端200可以被用来获取待处理数据。例如:用户通过终端输入待处理数据,终端自动获取针对待处理数据的预测请求。
在一些实施例中,终端中运行的客户端中可以植入有多任务模型的训练插件以及基于多任务模型的预测插件,用以在客户端本地实现多任务模型的训练方法以及数据预测方法。例如:终端200调用多任务模型的训练插件,以实现多任务模型的训练方法,通过动态调整多任务模型各子任务模型的收敛速度,达到平衡各个子任务模型的收敛进度的目的,使得某个训练后,每个子任务都得到或者近似得到理想期望值。终端200基于针对待处理数据的预测请求调用基于多任务模型的预测插件,通过训练好的多任务模型实现数据预测方法。
值得说明的是,训练好的多任务模型可以存储于终端200本地,并在需要时调用。另外,对于待处理数据待识别文本,可以是终端20实时生成的,也可以是从其他电子设备中获取的。
在一些实施例中,终端200获取针对待处理数据的预测请求后,调用服务器100的多任务模型的训练接口以及基于多任务模型的预测接口(可以提供为云服务的形式,即多任务模型的训练服务以及基于多任务模型的预测服务),服务器100通过多任务模型的训练插件,实现多任务模型的训练方法。服务器100通过基于多任务模型的预测插件,实现数据预测方法。
值得说明的是,训练好的多任务模型可以存储于服务器100本地,并在需要时调用。
在一些实施例中,终端或服务器可以通过运行计算机程序来实现本申请实施例提供的多任务模型的训练方法以及数据预测方法。举例来说,计算机程序可以是操作系统中的原生程序或软件模块;可以是本地(Native)应用程序(APP,Application),即需要在操作系统中安装才能运行的程序,如直播类的应用程序;也可以是小程序,即只需要下载到浏览器环境中就可以运行的程序;还可以是能够嵌入至任意APP中的小程序。总而言之,上述计算机程序可以是任意形式的应用程序、模块或插件。
在一些实施例中,服务器100可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器,其中,云服务可以是多任务模型的训练服务以及数据预测方法服务,供终端进行调用。
在一些实施例中,多个服务器可组成为一区块链,而服务器100为区块链上的节点,区块链中的每个节点之间可以存在信息连接,节点之间可以通过上述信息连接进行信息传输。其中,本申请实施例提供的多任务模型的训练方法以及数据预测方法所相关的数据(例如:多任务模型的训练逻辑、训练好的多任务模型)可保存于区块链上。
下面说明本申请实施例提供的电子设备的结构,参见图2A,图2A是本申请实施例提供的电子设备500的结构示意图,以电子设备500是服务器为例说明,图2A所示的电子设备500包括:至少一个处理器510、存储器550、至少一个网络接口520和用户接口530。电子设备500中的各个组件通过总线系统540耦合在一起。可理解,总线系统540用于实现这些组件之间的连接通信。总线系统540除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。但是为了清楚说明起见,在图2A中将各种总线都标为总线系统540。
处理器510可以是一种集成电路芯片,具有信号的处理能力,例如:通用处理器、数字信号处理器(digital signal processor,DSP),或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其中,通用处理器可以是微处理器或者任何常规的处理器等。
存储器550包括易失性存储器或非易失性存储器,也可包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(Read Only Memory,ROM),易失性存储器可以是随机存取存储器(Random Access Memory,RAM)。本申请实施例描述的存储器550旨在包括任意适合类型的存储器。存储器550可选地包括在物理位置上远离处理器510的一个或多个存储设备。
在一些实施例中,存储器550能够存储数据以支持各种操作,这些数据的示例包括程序、模块和数据结构或者其子集或超集,下面示例性说明。
操作系统551,包括用于处理各种基本系统服务和执行硬件相关任务的系统程序,例如框架层、核心库层、驱动层等,用于实现各种基础业务以及处理基于硬件的任务。
网络通信模块553,用于经由一个或多个(有线或无线)网络接口520到达其他计算设备,例如网络接口520包括:蓝牙、无线相容性认证(WiFi)、和通用串行总线(USB,Universal Serial Bus)等;
在一些实施例中,本申请实施例提供的多任务模型的训练装置可以采用软件方式实现,图2A示出了存储在存储器550中的多任务模型的训练装置555,其可以是程序和插件等形式的软件,包括以下软件模块:获取模块5551、确定模块5552、第一处理模块5553、以及第二处理模块5554这些模块是逻辑上的,因此根据所实现的功能可以进行任意的组合或进一步拆分。将在下文中说明各个模块的功能。
在一些实施例中,本申请实施例提供的数据预测装置可以采用软件方式实现,图2B示出了存储在存储器550中的数据预测装置556,其可以是程序和插件等形式的软件,包括以下软件模块:获取模块5561、预测模块5562,以及执行模块5563,这些模块是逻辑上的,因此根据所实现的功能可以进行任意的组合或进一步拆分。将在下文中说明各个模块的功能。值得说明的是,图2B中除了示出的数据预测装置556外,其余结构均可与图2A相同。
需要说明的是,多任务模型的训练装置555与数据预测装置556可以集成在一个电子设备上,即电子设备可以同时实现多任务模型的训练方法与数据预测方法;多任务模型的训练装置555与数据预测装置556可以分别集成在两个电子设备上,即电子设备实现多任务模型的训练方法或数据预测方法。
如前所述,本申请实施例提供的多任务模型的训练方法可以由各种类型的电子设备实施。参见图3A,图3A是本申请实施例提供的多任务模型的训练方法的流程示意图,结合图3A示出的步骤进行说明,可以看出训练方法包括:步骤101-步骤104。
在步骤101中,获取N个子任务模型的N个测度值。
本申请实施例中,一个子任务模型对应一个测度值。一个测度值用于表征一个子任务模型的收敛进度。
本申请实施例涉及的测度值为表征子任务模型的收敛进度的指标。测度值可以包括但不限于:准确率(accuracy)、精确率(precision)、召回率(recall)、F1值、平均距离、平均相似度、平均散度、交叉熵、谱相似度、查全率、字错率、词错率、句错率、多通道特征相似度、多通道特征差值等。本申请实施例中的测度值还可以是多个测度值组合所得到的指标。例如:在一些应用场景下要求更强调准确率,召回率可以稍作妥协,测度值采用公式(1)计算得到。
测度值= 0.6 * 准确率 + 0.4 * 召回率 公式(1)
本申请实施例还具体公开一种获取测度值的实现方式。本实现方式中,任意一个子任务模型的测度值是对多任务模型进行M个验证周期的验证后计算得到的;M为大于1的整数;M个验证周期是指当前权重更新周期内所包括的验证周期;请参阅图3B,图3A中的步骤101可以通过执行步骤1011-步骤1013实现。
本申请实施例中,验证周期(validate_steps)可以理解为相邻的两次验证之间的迭代次数。验证周期的数值可以根据需求设定,本申请实施例不做具体的限定,例如:验证周期可以为10次迭代,即每完成10次迭代即达到验证周期。每个验证周期可以得到一个测度值。
本申请实施例中,权重更新周期可以理解为相邻的两次权重更新之间的验证周期。当前权重更新周期包括M个验证周期。本申请实施例不对当前权重更新周期包括验证周期的数值做具体的限定,例如:当前权重更新周期可以包括10次验证周期,即每完成10次验证到达当前权重更新周期,即M=10。
在步骤1011中,响应于当前时间到达权重更新周期规定的权重更新时间,针对每个子任务模型,基于当前权重更新周期内的M个验证周期的子任务模型的M个测度值,确定子任务模型的候选测度值。
作为一种可行性实现方式,候选测度值可以为M个验证周期的子任务模型的M个测度值中的一个测度值。例如,候选测度值可以是M个测度值中的中间值。
本申请实施例还具体公开一种确定候选测度值的实现方式,请参阅图3C,图3B中的步骤1011可以通过执行步骤10111-步骤10112实现。
步骤10111中,对当前权重更新周期内M个验证周期的子任务模型的M个测度值进行基于噪声的筛除处理。
本申请实施例具体公开一种筛除噪声的实现方式。具体的,可以参阅图3D,图3C中的步骤10111可以通过执行步骤101111-步骤101112实现。
步骤101111中,在M个测度值中,根据每两个相邻的测度值对应的变化趋势和测度差值,确定噪声测度值。
作为一种可行性实现方式,确定噪声测度值的方式可以为:若相邻的两个测度值的变化趋势与测度值的期望变化趋势相反,则将相邻的两个测度值中的后一个测度值作为噪声测度值。例如:若依次获取到的F1值为:0.3、0.2。0.3与0.2的变化趋势为降低(F1的期望变化趋势为升高),则将相邻的两个测度值中的后一个测度值(0.2)作为噪声测度值。
作为一种可行性实现方式,确定噪声测度值的方式可以为:若相邻的两个测度值的差值大于设定阈值,则将相邻的两个测度值中的后一个测度值作为噪声测度值。例如:设定阈值为0.1,依次获取到的F1值为:0.3、0.5。0.3与0.5的差值为0.2(大于设定阈值),则将相邻的两个测度值中的后一个测度值(0.5)作为噪声测度值。
值得注意的是,本申请实施例仅是示例性的介绍两种噪声测度值的确定方式,上述噪声测度值的确定并不构成具体的限定。
步骤101112中,将噪声测度值从M个测度值中筛除。
本实现方式中,将噪声测度值从M个测度值中筛除,可以降低噪声测度值影响,最终得到的候选测度值更加准确的表征子任务模型的收敛进度。
步骤10112中,对筛除处理后的测度值进行均值处理,得到子任务模型的候选测度值。
本实现方式中,对筛除处理后的测度值进行均值处理,得到子任务模型的候选测度值,可以降低噪声对候选测度值影响,最终得到的候选测度值更加准确的表征子任务模型的收敛进度。
作为一种可行性实现方式,可以对当前权重更新周期内得到的M个测度值进行处理,得到子任务模型的候选测度值。上述的处理可以包括:将M个测度值组合,得到候选测度值。其中,组合可以包括但不限于:线性组合、均值处理等。均值处理可以包括但不限于:算数平均、加权平均、几何平均等,例如:可以通过公式(2)得到候选测度值。
Mean(F1i) =ΣjF1ij / Lossλ_update_validates j∈[0,10) 式(2)
公式(2)中,Mean(F1i)为候选测度值,Lossλ_update_validates为当前权重更新周期内包含验证周期的数值,F1ij为当前权重更新周期内的验证周期j的第i个子任务模型的测度值。
本实现方式提供的确定候选测度值的实现方式,对当前权重更新周期内的M个测度值进行均值处理,得到子任务模型的候选测度值,候选测度值可以准确的表征子任务模型的收敛进度。
值得注意的是,本申请实施例仅是实例行的介绍几种确定候选测度值的实现方式,上述确定候选测度值实现方式并不构成具体的限定。
在步骤1012中,获取上一权重更新周期内的子任务模型的最佳测度值。
作为一种可行性实现方式,上一权重更新周期内的子任务模型的最佳测度值可以是上一权重更新周期内子任务模型的M个测度值中的最大值。作为一种可行性实现方式,上一权重更新周期内的子任务模型的最佳测度值还可以上一权重更新周期内子任务模型的M个测度值中的最小值。
在步骤1013中,从上一权重更新周期内的子任务模型的最佳测度值、子任务模型的候选测度值中,确定子任务模型的测度值。
作为一种可行性实现方式,若测度值与收敛进度正相关,则将上一权重更新周期内的子任务模型的最佳测度值、子任务模型的候选测度值中的最大值,作为当前权重更新周期内的子任务模型的测度值。例如:上一权重更新周期内的子任务模型的最佳测度值(F1=0.3),子任务模型的候选测度值中的最大值(F1=0.4),则选取F1=0.4作为当前权重更新周期内的子任务模型的测度值。
作为一种可行性实现方式,若测度值与收敛进度负相关,则将上一权重更新周期内的子任务模型的最佳测度值、子任务模型的候选测度值中的最小值,作为当前权重更新周期内的子任务模型的测度值。例如:上一权重更新周期内的子任务模型的最佳测度值(错误率=30%),子任务模型的候选测度值中的最大值(错误率=20%),则选取错误率=20%作为当前权重更新周期内的子任务模型的测度值。
本实现方式中,可以从上一权重更新周期内的子任务模型的最佳测度值、子任务模型的候选测度值中,确定子任务模型的测度值,最终得到的测度值更加准确的表征子任务模型的收敛进度。
在步骤102中,针对每个子任务模型,基于子任务模型的测度值,确定子任务模型的损失权重值。
本申请实施例涉及的损失权重值可以调整其对应的子任务模型在多任务模型中所占的比重,控制子任务模型的收敛速度。具体的,损失权重值越大,其对应的子任务模型在多任务模型中所占的训练份额越大,子任务模型的收敛速度越快。损失权重值越小,其对应的子任务模型在多任务模型中所占的训练份额越小,子任务模型的收敛速度越慢。
需要说明的是,本申请实施例涉及的损失权重值与测度值表征的收敛进度负相关。具体的:
在一些实施例中,可以采用与收敛进度正相关的测度,其中,与收敛进度正相关的测度可以包括但不限于:准确率、召回率、平均相似度等。在训练过程中,若采用与收敛进度正相关的测度,则损失权重值与测度值负相关。例如:如果采用准确率作为测度,在某步骤训练后,基于产生较高准确率的子任务模型(收敛进度快的子任务模型)的测度值,可以得到一个较小的损失权重值,较小的损失权重值可以降低其对应子任务模型的收敛速度,即降低收敛进度快的子任务模型的收敛速度。
在一些实施例中,可以采用与收敛进度负相关的测度,其中,与收敛进度负相关的测度可以包括但不限于:错误率、字错率、词错率、句错率等。在训练过程中,若采用与收敛进度负相关的测度,则损失权重值与测度值正相关。例如:如果采用错误率作为测度,在某步骤训练后产生较高错误率的子任务模型(收敛进度慢的子任务模型),基于该子任务模型的测度值可以得到一个较大的损失权重值,较大的损失值可以加快其对应子任务模型的收敛速度,即加快收敛进度慢的子任务模型的收敛速度。
为了保证训练前后,多任务模型的结构不会发生改变,本实现方式具体公开一种损失权重值的生成方式,可以参见图3E,图3A示出的步骤102中的“基于子任务模型的测度值,确定子任务模型的损失权重值”可以通过以下步骤1021至步骤1022实现,下面具体说明。
在步骤1021中,确定子任务模型的初始损失权重值。
本申请实施例中,初始损失权重值可以理解为训练前子任务模型的损失权重值。
在步骤1022中,基于所述子任务模型的测度值,对所述子任务模型的初始损失权重值进行映射处理,得到所述子任务模型的损失权重值。
本申请实施例不对初始损失权重值与损失权重值的映射关系做具体的限定,凡是可以满足:在测度值达到理想期望值时损失权重值等于初始损失权重值的映射关系,均可以应用到本申请实施例中,例如:初始损失权重值与损失权重值可以满足公式(3)中的映射关系。
λi=1+1-F1i,i∈[0,N] 公式(3)
公式(3)中,N为多任务模型中子任务模型的数量,λi为第i个子任务模型的损失权重值,F1i为第i个子任务模型的准确率(测度值)。以i=1为例,子任务模型的初始损失权重值均为1,在准确率达到100%(理想期望值),λ1=1+1-100%,子任务模型的损失权重值仍为1。
本实现方式公开的损失权重值生成方式,在测度值达到理想期望值时,损失权重值等于初始损失权重值。采用本实现方式提供的损失权重值的生成方式,在多任务模型收敛时(测度值达到理想期望值),每个子任务模型的损失权重值恢复为初始损失权重值,即训练前后多任务模型的模型结构未发生改变。
在步骤103中,通过多任务模型对训练集进行处理,得到N个子任务模型的N个预测损失值。
本申请实施例中,一个子任务模型对应一个预测损失值。
这里,步骤103可以通过以下方式实现:通过多任务模型对训练集进行预测处理,得到N个子任务模型的N个输出的预测结果,一个子任务模型对应一个预测结果;基于N个子任务模型的N个输出的预测结果,确定N个子任务模型的N个预测损失值,例如将N个预测结果代入N个损失函数,得到N个预测损失值,其中,一个预测结果与损失函数一一对应。在步骤104中,基于N个子任务模型的N个损失权重值,对N个预测损失值进行加权求和处理,得到多任务模型的总损失,并基于多任务模型的总损失,更新多任务模型的参数。
为了方便描述,本申请实施例将加权后的预测损失值称之为加权损失值。加权损失值可以参与多任务模型的总损失的计算。例如,作为一种可行性实现方式,可以采用公式(4),计算多任务模型的总损失。
Loss(N Task)=ΣLoss(λi Taski),i∈[1...N] 公式(4)
公式(4)中,N为多任务模型中子任务模型的数量,Loss(N Task)为总损失,Loss(λi Taski)为第i个子任务模型的加权损失值。
本申请实施例中涉及的总损失用于表示多任务模型的预测结果与实际结果之间的差异。多任务模型训练的目的在于训练后的多任务模型具有较低的总损失,进而保证采用训练后的多任务模型对待处理数据进行处理后预测结果与实际结果之间具有较小的差异。
基于多任务模型的总损失,更新多任务模型的参数的过程也可以称之为反向传播(Back Propagation)。反向传播处理是指根据计算出的总损失值依次传播至输出层、隐藏层及输入层,从而对多任务模型各个层中的参数进行更新,以降低多任务模型的总损失值。
综上,本申请实施例提供一种多任务模型的训练方法,该训练方法基于测度值生成的损失权重值,其中,测度值为表征子任务模型的收敛进度的指标,损失权重值可以控制子任务模型的收敛速度。因此,本申请实施例提供的训练方法可以根据子任务模型的收敛进度,调整各个子任务模型的损失权重值,从而影响各个子任务模型的收敛速度,使得这些子任务模型的尽可能达到同步收敛,以期望在某次迭代后,所有子任务模型都得到或者近似得到理想期望值,从而达成整个多任务模型的理想期望值。
如前所述,本申请实施例提供的数据预测方法可以由各种类型的电子设备实施。如图4所示,本申请实施例提供的数据预测方法通过以下步骤实现:
在步骤201中,获取待处理数据。
作为获取待处理数据的示例,用户可以通过终端输入待处理数据,终端自动获取针对待处理数据的预测请求(包括待处理数据),终端将针对待处理数据的预测请求发送至服务器,服务器解析针对待处理数据的预测请求,获取待处理数据。
在步骤202中,通过训练好的多任务模型对待处理数据进行预测处理,得到每个子任务模型输出的预测结果。
在步骤203中,基于N个预测结果执行数据处理任务。
本申请实施例在预测的过程中每个子任务模型输出的预测结果可以执行以下数据处理任务。数据处理任务可以包括但不限于信息推荐或者是统计不同的行为数据
作为一种可行性实现方式,可以是基于多个子任务模型输出的预测结果共同完成一个应用任务(即数据处理任务包括该应用任务),例如,输出的预测结果分别为待推荐信息的点击率以及转发率,需要基于预测的点击率以及转发率,计算待推荐信息的推荐结果,以实现信息推荐的应用任务。
作为一种可行性实现方式,可以是基于多个子任务模型输出的预测结果分别完成多个应用任务(即数据处理任务包括多个应用任务),例如,多个应用任务为针对某新闻的点赞数估计以及转发数估计,输出的预测结果分别为用户对某新闻的点赞率以及转发率,需要基于预测的点击率以及转发率,分别计算针对某新闻的点赞数以及转发数,以实现针对某新闻的点赞数估计以及转发数估计。
本申请实施例提供一种数据预测方法,该预测方法采用本申请实施例公开的训练方法得到训练好的多任务模型。训练好的多任务模型各子任务同步收敛,即各子任务同步收敛均达到理想期望值。进而保证采用训练后的多任务模型对待处理数据进行处理后预测结果与实际结果之间具有较小的差异。基于该训练后的多任务模型得到的预测结果准确度较高。
本申请实施例中训练好的多任务模型可以应用于各种业务场景。例如:对一个文本同时进行多音字消岐和文本分词处理。再例如:对一个文本不同的韵律进行标记(文本韵律标记)等。
下面将说明本申请实施例在一个实际的应用场景中的示例性应用。
多任务模型是指一个模型中同时包含多个子任务模型的神经网络模型。每个子任务模型均有其对应的训练目标和损失函数。例如:一个文本处理模型可包括两个子任务模型,其中,两个子任务模型的训练目标分别为:对文本进行多音字标记、对文本进行分词标记。
训练目标差异可以导致在多任务模型训练的过程中,各子任务模型的收敛进度存在差异,影响多任务模型训练的训练效果。
例如:对同一个文本同时进行多音字消岐和分词标记。文本中分词的分布比较均匀,且文本中分词出现密度较高。而不同文本中多音字的分布变化比较大,有的文本多音字标记相对密集一些,有的文本多音字标记相对稀疏一些。而在整体上呈现出的规律为:多音字标记密度(数据稠密度)小于比分词的标记密度。例如:文本“这么多人都去,谁来看家呢?。”对该文本进行分词标记和多音字标记。标记的结果如下:分词标记的结果可以为“这(B)么(I)多(E)人(S)都(B)去(E),(O)谁(B)来(E)看(B)家(E)呢(S)?(O)。”多音字标记的结果可以为“这么多人都(dou1)去,谁(shui2)来看(kan1)家呢?。”可以看出多音字的标记密度要大大的小于分词的标记密度。
多个子任务模型同时训练时,往往标记密度更高,训练目标的分类数更小的数据对应的子任务模型更容易学习(学习难度较低),该子任务模型的损失下降较快(收敛速度),可能只几个迭代后,该子任务模型的损失就收敛到很小。而标记密度较低的数据,以及目标分类数较多的数据对应的子任务模型学习难度较大,需要大量迭代训练,子任务模型才会收敛。
可见在多任务模型训练的过程中,各子任务模型的收敛速度的差异,导致每次训练结束后,各子任务模型的收敛进度存在差异。
为了平衡个子任务模型的收敛进度,相关技术一通过“模型设计”的方式来调整各个子任务模型的学习难度,进而平衡各个子任务模型的收敛进度。
通过“模型设计”来平衡各子任务模型的收敛进度,使各子任务模型尽可能的同步收敛实现方式存在如下的问题:模型设计主要是实现任务是目标效果最大化。在目标效果最大化基础上兼顾平衡各子任务模型的收敛进度,会导致多任务模型的设计、开发、实验等方面的成本较大。另外,开发出来的多任务模型只适合当前任务,后期的优化改进空间都很小。
一些研究者从多任务模型的训练过程入手,以期望找到可以平衡各个子任务模型的收敛进度的方式。多任务的训练过程大体为:将训练集输入多任务模型,进行一次前向计算过程,得到一组预测结果数据,这一组预测结果数据,分别对应不同子任务模型。将预测结果数据和真实标记数据(实际结果)输入到对应子任务模型的损失函数,得到该子任务模型在本次训练中的预测损失值。对所有子任务模型的预测损失值进行加权求和,得到多任务模型的总损失。对总损失进行损失反向传播,调整多任务模型的参数,完成一次迭代。经过若干迭代后,多任务模型的总损失越来越小,直到总损失值不再下降,则模型收敛。在一些训练过程中,也可以用额外的验证集数据对模型训练结果进行验证,用验证损失最小化来引导模型训练收敛过程。多任务训练时的总损失值为表征子任务模型的收敛进度的指标,例如:多任务模型的总损失值达到最小值时,可以认为该多任务模型收敛。
多任务模型的总损失值可以通过公式(5)得到。
Loss(N Task)=ΣLoss(λiTaski),i∈[1......N] 公式(5)
在公式(5)中,Loss(N Task)为多任务模型的总损失,Loss(λiTaski)为加权后的损失,λi为损失权重。损失权重值可以调整其对应的子任务模型在多任务模型中所占的比重,进而控制子任务模型的收敛速度。
考虑到损失权重值对子任务模型的收敛速度的影响,相关技术二通过对大量数据的分析为各子任务模型设定损失权重,以期望利用该损失权重可以平衡各子任务模型的收敛进度。相关技术二获得的损失权重通过大量的实验得出,获得这些损失权重的成本较高。
为了解决上述技术问题,本申请实施例提供一种多子任务模型的训练方法,该训练方法可以根据每个子任务模型在训练过程中产生的测度值,动态调整各个子任务模型的损失权重,从而控制各个子任务模型的收敛速度,使得各子任务模型的收敛进度尽可能的同步。进而达到某个训练/迭代后,所有子任务模型都得到或者近似得到理想期望值,从而达到多任务模型的理想期望值。
下面结合具体的应用场景对本申请实施例提供的方案作进一步的说明:
本实施例的应用场景为:TTS中常用的文本韵律的标记,采用的多任务模型为文韵律预测模型。
首先,介绍文本韵律模型的功能。
TTS任务中训练目标是通过文本韵律模型,将文本转换为语音。TTS往往需要采用对文本韵律进行标记,然后将标记后的文本输入文本韵律模型,来提升合成语音自然度。
韵律标记一般分为4级:#1表示韵律词,#2表示韵律短语,#3表示韵律短句,#4表示韵律句子。一个韵律标记后的文本可以为“北大#1女生#2穿#1马面裙#3参加#1毕业典礼#4。”最理想情况下,当文本韵律模型收敛时,输入一个句子:“北大女生穿马面裙参加毕业典礼。”期望得到的模型预测结果是:[0 1 0
2 1 0 0 3 0 1 0 0 0 4 0]。
下面结合图5介绍韵律模型的结构。
文本韵律模型包括:文本表示层(BERT)。BERT用于对每个训练集进行特征抽取。文本韵律模型还包括4个子任务模型,4个子任务模型在本实施例中可以称之为#1子任务模型、#2子任务模型、#3子任务模型、#4子任务模型。每个子任务模型用于对文本中每个字是否应该标记当前韵律的二值分类进行判断。
对比例1:中各子任务模型的损失权重值均为1,韵律模型的总损失值可以根据公式(6)生成。
TotalLoss=1*Loss1+1*Loss2+1*Loss3+1*Loss4 公式(6)
公式(6)中,TotalLoss为总损失值,Loss1为#1子任务模型的损失函数,Loss2为#2子任务模型的损失函数,Loss3为#3子任务模型的损失函数,Loss4为#4子任务模型的损失函数。
训练韵律模型的过程中,各子任务模型的召回率与迭代次数(训练步数)的关系曲线如图6所示,其中,图6中的(1)为#1子任务模型的召回率与迭代次数的关系曲线,图6中的(2)为#2子任务模型的召回率与迭代次数的关系曲线,图6中的(3)为#3子任务模型的召回率与迭代次数的关系曲线,图6中的(4)为#4子任务模型的召回率与迭代次数的关系曲线。
#1是韵律分词,一个句子必然有分词,也有很多分词,在文本中分词是稠密数据。因此,#1子任务模型的收敛速度较快。从图6中的(1)可以看出,在迭代次数达到1000时,#1子任务模型的召回率达到最大,随后,随着迭代次数的增加,#1子任务模型的召回率逐渐降低。
#2是韵律短语,韵律短语在一个句子(文本的一部分)中是否存在,存在多少都是不固定的。#2属于稀疏数据。因此#2子任务模型的收敛速度较慢。从图6中的(2)可以看出,在迭代次数达到1万5000时,#2子任务模型的召回率达到最大,随后,随着迭代次数的增加,#2子任务模型的召回率逐渐降低。
同理,#3属于稀疏数据,#3子任务模型的收敛速度较慢从图6中的(3)可以看出,在迭代次数达到2万5000时,#3子任务模型的召回率达到最大,随后,随着迭代次数的增加,#3子任务模型的召回率震荡上升。
#4是韵律句子,韵律句子每一句子必然出现且仅一次,#4属于稀疏数据。因此,#4子任务模型的收敛速度较慢,从图6中的(4)可以看出,在迭代次数达到5000时,#4子任务模型的召回率达到最大,随后,随着迭代次数的增加,#3子任务模型的召回率平稳。
对比例1提供的训练方案,最终韵律模型收敛时,#1子任务模型,#4子任务模型的召回率比较理想,可以达到92%以上,#2子任务模型的召回率和#3子任务模型的召回率很差,特别是#2召回率只有40%左右。
实施例1:
实施例1中,训练集为100万条,batch大小为100,每个epoch的训练步骤数为100W/100=1W,设置验证周期validate_steps=1000,损失权重更新周期Lossλ_update_validates=10,F1作为测度。
将4个子任务模型损失加权函数都设置为λi=1+1-F1ii∈[1,4],其中,λi为损失权重值。
训练过程中的损损失权重值重更新过程为:
(1)计算测度值:
每隔validate_steps=1000个训练步骤后,用韵律模型对验证集进行一次验证计算。即用验证集作为输入数据输入韵律模型,分别预测出对应的#1、#2、#3、#4的韵律标记结果(预测结果),并根据验证集中的人工标记的韵律真值(实际结果)和预测结果,分别计算出各子任务模型的精确率P和召回率R。然后,基于公式(7)计算出各子任务模型的测度值F1i。
F1i=(Pi*Ri)*2/(Pi+Ri) 公式(7)
(2)计算测度值的平均值Mean(F1i)。
训练过程中,每隔Lossλ_update_validates=10次上述的验证过程,根据公式(8)计算这10次所得测度值的平均值:
Mean(F1i)=ΣjF1ij/Lossλ_update_validates,j∈[0,10) 公式(8)
(3)计算最佳测度BestF1i。
根据公式(9)计算最佳测度。
当前的BestF1i=max(历史记录的BestF1i,Mean(F1i)) 公式(9)
公式(9)中,根据每个子任务模型的历史记录的最佳测度(上一权重更新周期内的子任务模型的最佳测度值)跟最近一次计算得到的该子任务模型的平均测度Mean(F1i)作比较,选取出当前权重更新周期的最佳测度BestF1i。
(4)计算和更新各子任务模型的损失权重λi。
基于公式(10)计算各子任务模型的损失权重λi。
λi=1+1-BestF1i 公式(10)
对每个子任务模型的预测损失值进行加权求和处理,得到多任务模型的总损失;基于多任务模型的总损失,更新多任务模型的参数。
对比例1提供的训练方法,各子任务训练的数据的稠密程度存在差异,各子任务模型的收敛速度存在差异,各子任务模型难以达成同步收敛。
实施例1提供的训练方法,是通过跟踪模型中各个子任务模型的测度值(模型收敛进度的一个指标),调整各个子任务模型的损失权重,进而控制各子模型的模型收敛进度,使得各子任务尽可能达成同步收敛。实施例1提供的训练方法执行成本低,通用性强,不需要额外的模型设计和权值选择实验。
至此已经结合本申请实施例提供的电子设备的示例性应用和实施,说明本申请实施例提供的多任务模型的训练方法以及数据预测方法。
本申请实施例还提供多任务模型的训练装置以及数据预测装置,实际应用中,多任务模型的训练装置以及数据预测装置中的各功能模块可以由电子设备(如终端设备、服务器或服务器集群)的硬件资源,如处理器等计算资源、通信资源(如用于支持实现光缆、蜂窝等各种方式通信)、存储器协同实现。图2A示出了存储在存储器550中的多任务模型的训练装置555以及图2B示出了存储在存储器550中的数据预测装置556,其可以是程序和插件等形式的软件,例如:软件C/C++、Java等编程语言设计的软件模块、C/C++、Java等编程语言设计的应用软件或大型软件系统中的专用软件模块、应用程序接口、插件、云服务等实现方式,下面对不同的实现方式举例说明。
其中,多任务模型的训练装置555包括:获取模块5551、确定模块5552、第一处理模块5553、第二处理模块5554。下面继续说明本申请实施例提供的多任务模型的训练装置555中各个模块配合实现多任务模型的训练方案。
获取模块5551,用于获取N个子任务模型的N个测度值,一个子任务模型对应一个测度值,一个测度值用于表征一个子任务模型的收敛进度;确定模块5552,用于针对每个子任务模型,基于子任务模型的测度值,确定子任务模型的损失权重值,损失权重值与测度值表征的收敛进度负相关;第一处理模块5553,用于通过多任务模型对训练集进行处理,得到N个子任务模型的N个预测损失值,一个子任务模型对应一个预测损失值;第二处理模块5554,用于基于N个子任务模型的N个损失权重值,对N个预测损失值进行加权求和处理,得到多任务模型的总损失,并基于多任务模型的总损失,更新多任务模型的参数。
上述技术方案中,若测度值与收敛进度正相关,则损失权重值与测度值负相关;若测度值与收敛进度负相关,则损失权重值与测度值正相关。
上述技术方案中,确定模块5552还用于确定每个子任务模型的初始损失权重值;基于每个子任务模型的测度值,对每个子任务模型的初始损失权重值进行映射处理,得到每个子任务模型的损失权重值;其中,若测度值达到理想期望值,则损失权重值为初始损失权重值。
上述技术方案中,获取模块5551还用于确定子任务模型的初始损失权重值;基于子任务模型的测度值,对子任务模型的初始损失权重值进行映射处理,得到子任务模型的损失权重值。
上述技术方案中,任意一个子任务模型的测度值是对多任务模型进行M个验证周期的验证后计算得到的;M为大于1的整数;M个验证周期是指当前权重更新周期内所包括的验证周期;获取模块5551,还用于响应于当前时间到达权重更新周期规定的权重更新时间,针对每个子任务模型,基于当前权重更新周期内的M个验证周期的子任务模型的M个测度值,确定子任务模型的候选测度值;获取上一权重更新周期内的子任务模型的最佳测度值;从上一权重更新周期内的子任务模型的最佳测度值、子任务模型的候选测度值中,确定子任务模型的测度值。
上述技术方案中,获取模块5551还用于若测度值与收敛进度正相关,则将上一权重更新周期内的子任务模型的最佳测度值和子任务模型的候选测度值中的最大值,作为子任务模型的测度值;若测度值与收敛进度负相关,则将上一权重更新周期内的子任务模型的最佳测度值和子任务模型的候选测度值中的最小值,作为子任务模型的测度值。
确定模块5552还用于对当前权重更新周期内M个验证周期的子任务模型的M个测度值进行均值处理,得到子任务模型的候选测度值。
上述技术方案中,确定模块5552还用对当前权重更新周期内M个验证周期的子任务模型的M个测度值进行基于噪声的筛除处理;对筛除处理后的测度值进行均值处理,得到子任务模型的候选测度值。
上述技术方案中,确定模块5552还用于在M个测度值中,根据每两个相邻的测度值对应的变化趋势和测度差值,确定噪声测度值;将噪声测度值从M个测度值中筛除。
上述技术方案中,确定模块5552还用于若相邻的两个测度值对应的变化趋势与期望变化趋势相反,则将相邻的两个测度值中的后一个测度值作为噪声;若相邻的两个测度值对应的变化趋势与期望变化趋势相同,但相邻的两个测度值之间的测度差值大于设定阈值,则将相邻的两个测度值中的后一个测度值作为噪声。其中,数据预测装置556包括一系列的模块,包括获取模块5561和预测模块5562。下面继续说明本申请实施例提供的数据预测装置556各个模块配合实现基于多任务模型的预测方案。
获取模块5561,用于获取待处理数据;预测模块5562,用于通过训练好的多任务模型对待处理数据进行预测处理,得到每个子任务模型模型输出的预测结果;其中,训练好的多任务模型是通过上述实施例提供的多任务模型的训练方法训练得到的;执行模块5563,用于基于N个预测结果执行数据处理任务。
本申请实施例提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。电子设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该电子设备执行本申请实施例上述的多任务模型的训练方法,或者数据预测方法。
本申请实施例提供一种存储有计算机可执行指令的计算机可读存储介质,其中存储有计算机可执行指令或者计算机程序,当计算机可执行指令或者计算机程序被处理器执行时,将引起处理器执行本申请实施例提供的多任务模型的训练方法,或者数据预测方法,例如:如图3A-图3E示出的多任务模型的训练方法,如图4示出的数据预测方法。
在一些实施例中,计算机可读存储介质可以是FRAM、ROM、PROM、EP ROM、EEPROM、闪存、磁表面存储器、光盘、或CD-ROM等存储器;也可以是包括上述存储器之一或任意组合的各种设备。
在一些实施例中,计算机可执行指令可以采用程序、软件、软件模块、脚本或代码的形式,按任意形式的编程语言(包括编译或解释语言,或者声明性或过程性语言)来编写,并且其可按任意形式部署,包括被部署为独立的程序或者被部署为模块、组件、子例程或者适合在计算环境中使用的其它单元。
作为示例,计算机可执行指令可以但不一定对应于文件系统中的文件,可以可被存储在保存其它程序或数据的文件的一部分,例如:存储在超文本标记语言(HTML,HyperText Markup Language)文档中的一个或多个脚本中,存储在专用于所讨论的程序的单个文件中,或者,存储在多个协同文件(例如:存储一个或多个模块、子程序或代码部分的文件)中。
作为示例,计算机可执行指令可被部署为在一个电子设备上执行,或者在位于一个地点的多个电子设备上执行,又或者,在分布在多个地点且通过通信网络互连的多个电子设备上执行。
可以理解的是,在本申请实施例中,涉及到用户信息等相关的数据,当本申请实施例运用到具体产品或技术中时,需要获得用户许可或者同意,且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。
以上所述,仅为本申请的实施例而已,并非用于限定本申请的保护范围。凡在本申请的精神和范围之内所作的任何修改、等同替换和改进等,均包含在本申请的保护范围之内。
Claims (14)
1.一种多任务模型的训练方法,其特征在于,所述多任务模型包括N个子任务模型,N为大于1的整数;所述训练方法包括:
获取所述N个子任务模型的N个测度值,一个子任务模型对应一个测度值,一个测度值用于表征一个子任务模型的收敛进度;
针对每个子任务模型,基于子任务模型的测度值确定所述子任务模型的损失权重值,所述损失权重值与所述测度值表征的收敛进度负相关;
通过所述多任务模型对训练集进行处理,得到所述N个子任务模型的N个预测损失值,一个子任务模型对应一个预测损失值;
基于所述N个子任务模型的N个损失权重值,对N个预测损失值进行加权求和处理,得到所述多任务模型的总损失,并基于所述多任务模型的总损失更新所述多任务模型的参数。
2.根据权利要求1所述的训练方法,其特征在于,若所述测度值与所述收敛进度正相关,则所述损失权重值与所述测度值负相关;若所述测度值与所述收敛进度负相关,则所述损失权重值与所述测度值正相关。
3.根据权利要求1或2所述的训练方法,其特征在于,所述基于子任务模型的测度值确定所述子任务模型的损失权重值,包括:
确定所述子任务模型的初始损失权重值;
基于所述子任务模型的测度值,对所述子任务模型的初始损失权重值进行映射处理,得到所述子任务模型的损失权重值。
4.根据权利要求1所述的训练方法,任意一个子任务模型的测度值是对所述多任务模型进行M个验证周期的验证后计算得到的;M为大于1的整数;所述M个验证周期是指当前权重更新周期内所包括的验证周期;
所述获取所述N个子任务模型的N个测度值,包括:
响应于当前时间到达权重更新周期规定的权重更新时间,针对每个子任务模型,基于所述当前权重更新周期内的M个验证周期的所述子任务模型的M个测度值,确定所述子任务模型的候选测度值;
获取上一权重更新周期内的所述子任务模型的最佳测度值;
从所述上一权重更新周期内的所述子任务模型的最佳测度值、所述子任务模型的候选测度值中,确定所述子任务模型的测度值。
5.根据权利要求4所述的训练方法,其特征在于,所述从所述上一权重更新周期内的所述子任务模型的最佳测度值、所述子任务模型的候选测度值中,确定所述子任务模型的测度值,包括:
若测度值与收敛进度正相关,则将所述上一权重更新周期内的所述子任务模型的最佳测度值和所述子任务模型的候选测度值中的最大值,作为所述子任务模型的测度值;
若测度值与收敛进度负相关,则将所述上一权重更新周期内的所述子任务模型的最佳测度值和所述子任务模型的候选测度值中的最小值,作为所述子任务模型的测度值。
6.根据权利要求4所述的训练方法,其特征在于,所述基于所述当前权重更新周期内的M个验证周期的所述子任务模型的M个测度值,确定所述子任务模型的候选测度值,包括:
对所述当前权重更新周期内M个验证周期的所述子任务模型的M个测度值进行均值处理,得到所述子任务模型的候选测度值。
7.根据权利要求4所述的训练方法,其特征在于,所述基于所述当前权重更新周期内的M个验证周期的所述子任务模型的M测度值,确定所述子任务模型的候选测度值,包括:
对所述当前权重更新周期内M个验证周期的所述子任务模型的M个测度值进行基于噪声的筛除处理;
对筛除处理后的测度值进行均值处理,得到所述子任务模型的候选测度值。
8.根据权利要求7所述的训练方法,其特征在于,所述M个验证周期的所述子任务模型的M个测度值按照验证周期的顺序排列;所述对所述当前权重更新周期内M个所述验证周期的所述子任务模型的M个测度值进行基于噪声的筛除处理,包括:
在M个测度值中,根据每两个相邻的测度值对应的变化趋势和测度差值,确定噪声测度值;
将所述噪声测度值从所述M个测度值中筛除。
9.根据权利要求8所述的方法,其特征在于,所述根据每两个相邻的测度值对应的变化趋势和测度差值,确定噪声测度值,包括:
若相邻的两个测度值对应的变化趋势与期望变化趋势相反,则将相邻的两个测度值中的后一个测度值作为噪声;
若相邻的两个测度值对应的变化趋势与期望变化趋势相同,但相邻的两个测度值之间的测度差值大于设定阈值,则将相邻的两个测度值中的后一个测度值作为所述噪声。
10.一种数据预测方法,其特征在于,包括:
获取待处理数据;
通过所述训练好的多任务模型对所述待处理数据进行预测处理,得到N个子任务模型输出的N个预测结果,一个子任务模型对应一个预测结果;训练好的多任务模型包括N个子任务模型,N为大于1的整数;
其中,所述训练好的多任务模型是通过上述权利要求1-9任一项所述多任务模型的训练方法训练得到的;
基于所述N个预测结果执行数据处理任务。
11.一种多任务模型的训练装置,其特征在于,所述多任务模型包括N个子任务模型,N为大于1的整数,所述训练装置包括:
获取模块,用于获取所述N个子任务模型的N个测度值,一个子任务模型对应一个测度值,一个测度值用于表征一个子任务模型的收敛进度;
确定模块,用于针对每个子任务模型,基于子任务模型的测度值,确定所述子任务模型的损失权重值,所述损失权重值与所述测度值表征的收敛进度负相关;
第一处理模块,用于通过所述多任务模型对训练集进行处理,得到所述N个子任务模型的N个预测损失值,一个子任务模型对应一个预测损失值;
第二处理模块,用于基于所述N个子任务模型的N个损失权重值,对N个预测损失值进行加权求和处理,得到所述多任务模型的总损失,并基于所述多任务模型的总损失,更新所述多任务模型的参数。
12.一种数据预测装置,其特征在于,训练好的多任务模型包括N个子任务模型,N为大于1的整数,所述数据预测装置包括:
获取模块,用于获取待处理数据;
预测模块,用于通过所述训练好的多任务模型对所述待处理数据进行预测处理,得到N个子任务模型输出的N个预测结果,一个子任务模型对应一个预测结果;其中,所述训练好的多任务模型是通过上述权利要求1-9任一项所述多任务模型的训练方法训练得到的;
执行模块,用于基于所述N个预测结果执行数据处理任务。
13.一种电子设备,其特征在于,所述电子设备包括:
存储器,用于存储计算机可执行指令;
处理器,用于执行所述存储器中存储的计算机可执行指令时,实现权利要求1至9任一项所述的多任务模型的训练方法,或权利要求10所述的数据预测方法。
14.一种计算机可读存储介质,其特征在于,存储有计算机程序或者计算机可执行指令,所述计算机程序或计算机可执行指令被处理器执行时实现权利要求1至9任一项所述的多任务模型的训练方法,或权利要求10所述的数据预测方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311363143.4A CN117952180A (zh) | 2023-10-19 | 2023-10-19 | 多任务模型的训练方法、数据预测方法、装置及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311363143.4A CN117952180A (zh) | 2023-10-19 | 2023-10-19 | 多任务模型的训练方法、数据预测方法、装置及设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117952180A true CN117952180A (zh) | 2024-04-30 |
Family
ID=90800328
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311363143.4A Pending CN117952180A (zh) | 2023-10-19 | 2023-10-19 | 多任务模型的训练方法、数据预测方法、装置及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117952180A (zh) |
-
2023
- 2023-10-19 CN CN202311363143.4A patent/CN117952180A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
JP7210531B2 (ja) | ニューラルアーキテクチャ検索 | |
KR102302609B1 (ko) | 신경망 아키텍처 최적화 | |
US10984319B2 (en) | Neural architecture search | |
CN110210032B (zh) | 文本处理方法及装置 | |
US12008445B2 (en) | Black-box optimization using neural networks | |
KR20200014510A (ko) | 기계 학습 기반의 예측 서비스 제공 방법 및 그 장치 | |
CN112925926B (zh) | 多媒体推荐模型的训练方法、装置、服务器以及存储介质 | |
CN116257363B (zh) | 资源调度方法、装置、设备及存储介质 | |
CN116684330A (zh) | 基于人工智能的流量预测方法、装置、设备及存储介质 | |
CN117422005B (zh) | 一种模拟电路仿真误差自动控制的方法及应用 | |
JP2023552048A (ja) | ハードウェアアクセラレータのためのニューラルアーキテクチャスケーリング | |
CN111489196B (zh) | 基于深度学习网络的预测方法、装置、电子设备及介质 | |
CN117312979A (zh) | 对象分类方法、分类模型训练方法及电子设备 | |
CN117952180A (zh) | 多任务模型的训练方法、数据预测方法、装置及设备 | |
CN113191527A (zh) | 一种基于预测模型进行人口预测的预测方法及装置 | |
CN114298329A (zh) | 一种模型训练方法、装置、设备及存储介质 | |
CN113934813A (zh) | 一种样本数据划分的方法、系统、设备及可读存储介质 | |
CN111898389B (zh) | 信息确定方法、装置、计算机设备及存储介质 | |
CN112766490B (zh) | 特征变量学习方法、装置、设备及计算机可读存储介质 | |
CN115146596B (zh) | 召回文本的生成方法、装置、电子设备及存储介质 | |
WO2023028996A1 (en) | Methods and devices for ensuring the reproducibility of software systems | |
Galici et al. | Agent-based approach for Decentralized Genetic Algorithm | |
CN117219190A (zh) | 分子生成模型的训练方法、装置、设备、介质及程序产品 | |
CN115953031A (zh) | 风险预测模型的训练方法及装置、计算机可读存储介质 | |
CN114418122A (zh) | 机器学习模型的超参数配置方法、装置以及可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |