CN117196057A - 一种模型训练方法和相关设备 - Google Patents
一种模型训练方法和相关设备 Download PDFInfo
- Publication number
- CN117196057A CN117196057A CN202210583164.6A CN202210583164A CN117196057A CN 117196057 A CN117196057 A CN 117196057A CN 202210583164 A CN202210583164 A CN 202210583164A CN 117196057 A CN117196057 A CN 117196057A
- Authority
- CN
- China
- Prior art keywords
- models
- model
- training
- processor
- gradient information
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 179
- 238000000034 method Methods 0.000 title claims abstract description 52
- 238000012360 testing method Methods 0.000 claims description 38
- 230000010076 replication Effects 0.000 claims description 29
- 230000015654 memory Effects 0.000 claims description 28
- 238000004891 communication Methods 0.000 claims description 23
- 238000012545 processing Methods 0.000 claims description 21
- 238000004590 computer program Methods 0.000 claims description 6
- 230000006870 function Effects 0.000 description 18
- 230000000694 effects Effects 0.000 description 9
- 101100455978 Arabidopsis thaliana MAM1 gene Proteins 0.000 description 8
- 241001465754 Metazoa Species 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 238000011022 operating instruction Methods 0.000 description 6
- 238000002474 experimental method Methods 0.000 description 4
- 238000010801 machine learning Methods 0.000 description 4
- 238000004364 calculation method Methods 0.000 description 3
- 238000001514 detection method Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 230000003190 augmentative effect Effects 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000003064 k means clustering Methods 0.000 description 2
- 239000004973 liquid crystal related substance Substances 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000003068 static effect Effects 0.000 description 2
- 238000010200 validation analysis Methods 0.000 description 2
- 241000282326 Felis catus Species 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000013500 data storage Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 239000000203 mixture Substances 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 238000010897 surface acoustic wave method Methods 0.000 description 1
- 238000001356 surgical procedure Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请实施例公开了一种模型训练方法,用于基于权重对模型进行训练。在本申请中,通过使用P个任务训练集依次对模型M进行训练,得到M’以及P个梯度信息,P为大于等于2的整数,然后复制所述M’,得到p个复制模型,p为小于等于P的正整数,并根据所述P个梯度信息确定所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重,得到权重信息,最后使用所述P个任务训练集及基于所述权重信息分别训练所述M’的p个复制模型,得到训练后的p个模型,充分考虑不同任务之间联系与差异,降低了不同任务之间的干扰。
Description
技术领域
本申请涉及机器学习领域,尤其涉及一种模型训练方法和相关设备。
背景技术
传统的机器学习模型可以通过采取元学习的技术,提高模型的泛化能力,适用少样本场景的效果。目前,业界普遍采取的基于梯度信息的元学习的技术,例如模型无关元学习(model-agnostic meta-learning,MAML),其利用了梯度信息,可以充分地提取训练数据中高层次知识,从而实现模型中知识的迁移,简化了元学习的复杂度。
但是,基于梯度的元学习普遍忽视了任务之间的联系与差异,针对不同类型的任务只寻找一个全局的通用解,这使得一些困难任务往往被忽视,从而降低了其在测试集中的表现。
发明内容
本申请实施例提供了一种模型训练方法,用于基于权重对模型进行训练。
本申请第一方面提供了一种模型训练方法,首先通过使用P个任务训练集依次对模型M进行训练,得到M’以及P个梯度信息,P为大于等于2的整数,然后复制所述M’,得到p个复制模型,p为小于等于P的正整数,并根据所述P个梯度信息确定所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重,得到权重信息,最后使用所述P个任务训练集及基于所述权重信息分别训练所述M’的p个复制模型,得到训练后的p个模型,通过训练后的p个模型分别对数据进行预测,充分考虑不同任务之间联系与差异,降低了不同任务之间的干扰。
在一些可行的实现方式中,对所述P个梯度信息进行聚类,得到p个聚类中心,所述p个聚类中心与所述p个复制模型一一对应,然后确定所述P个梯度信息中各个梯度信息到所述p个聚类中心中各个聚类中心的距离,从而作为所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重,从而考虑到不同任务之间联系与差异,降低了不同任务之间的干扰。
在一些可行的实现方式中,所述聚类为K均值聚类(k-means),实现了对P个梯度信息的聚类。
在一些可行的实现方式中,还可以获取测试数据,并分别使用所述训练后的p个模型对所述测试数据进行预测,得到p个预测结果,最后基于所述权重信息对所述p个预测结果进行加权,得到所述测试数据的最终预测结果,充分考虑不同任务之间联系与差异,降低了不同任务之间的干扰。
在一些可行的实现方式中,所述M为元学习模型,提高了模型的泛化能力,适用少样本场景的效果。
第二方面,本申请提供一种模型训练设备,所述模型训练设备用于执行前述第一方面中任一项所述的方法。
第三方面,本申请提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述第一方面或第二方面或第三方面中任一项所述的方法。
本申请第四方面提供一种计算机程序产品,该计算机程序产品包括计算机执行指令,该计算机执行指令存储在计算机可读存储介质中;设备的至少一个处理器可以从计算机可读存储介质读取该计算机执行指令,至少一个处理器执行该计算机执行指令使得设备实施上述第一方面或者第一方面的任一种可能的实现方式所提供的方法。
本申请第五方面提供一种通信装置,该通信装置可以包括至少一个处理器、存储器和通信接口。至少一个处理器与存储器和通信接口耦合。存储器用于存储指令,至少一个处理器用于执行该指令,通信接口用于在至少一个处理器的控制下与其他通信装置进行通信。该指令在被至少一个处理器执行时,使至少一个处理器执行第一方面或第一方面的任意可能的实现方式中的方法。
本申请第六方面提供了一种芯片系统,该芯片系统包括处理器,用于支持实现上述第一方面或第一方面任意一种可能的实现方式中所涉及的功能。
在一种可能的设计中,芯片系统还可以包括存储器,存储器,用于保存必要的程序指令和数据。该芯片系统,可以由芯片构成,也可以包含芯片和其他分立器件。
其中,第二至第六方面或者其中任一种可能实现方式所带来的技术效果可参见第一方面或第一方面不同可能实现方式所带来的技术效果,此处不再赘述。
附图说明
图1-1为本申请实施例提供的一种模型训练设备的组成结构示意图;
图1-2为本申请实施例提供的一种模型训练设备的另一组成结构示意图;
图2-1为本申请实施例提供的一种模型训练方法的流程示意图;
图2-2为本申请实施例提供的一种模型训练方法的另一流程示意图;
图2-3为本申请实施例提供的一种模型训练和测试实验的实施例流程示意图;
图2-4为本申请实施例提供的一种模型训练的实施例流程示意图;
图3为本申请实施例提供的一种模型训练设备的结构示意图;
图4为本申请实施例提供的一种通信装置的结构示意图。
具体实施方式
本申请实施例提供了一种模型训练方法,用于基于权重对模型进行训练。
下面结合附图,对本申请的实施例进行描述。
本申请的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的术语在适当情况下可以互换,这仅仅是描述本申请的实施例中对相同属性的对象在描述时所采用的区分方式。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,以便包含一系列单元的过程、方法、系统、产品或设备不必限于那些单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它单元。
本申请实施例可应用于如图1-1所示的模型训练设备100中,模型训练设备100包括处理模块110、显示模块120、存储模块130、收发模块140和输入模块150(例如键盘、鼠标、触摸屏等,此处不做限定)。
处理模块110是模型训练设备100的控制中心,利用各种接口和线路连接模型训练设备100的各个部分,通过运行或执行存储在存储模块130内的软件程序和/或模块,以及调用存储在存储模块130内的数据,执行模型训练设备100的各种功能和处理数据,从而对模型训练设备100进行整体监控。可选的,处理模块110可包括一个或多个处理单元;优选的,处理模块110可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理模块110中。
显示模块120可用于显示由用户输入的信息或提供给用户的信息以及模型训练设备100的各种界面。显示模块120可包括显示面板,可选的,可以采用液晶显示器(LiquidCrystal Display,LCD)、有机发光二极管(Organic Light-Emitting Diode,OLED)等形式来配置显示面板。进一步的,触控面板可覆盖显示面板,当触控面板检测到在其上或附近的触摸操作后,传送给处理模块110以确定触摸事件的类型,随后处理模块110根据触摸事件的类型在显示面板上提供相应的视觉输出。
存储模块130可以包括只读存储器和随机存取存储器,并向处理模块110提供指令和数据。存储模块130的一部分还可以包括非易失性随机存取存储器(non-volatilerandom access memory,NVRAM)。存储模块130存储有处理器和操作指令、可执行模块或者数据结构,或者它们的子集,或者它们的扩展集,其中,操作指令可包括各种操作指令,用于实现各种操作。
收发模块140可用于接收输入的数字或字符信息,以及产生与模型训练设备100的相关设置以及功能控制有关的信号输入。收发模块140可用于通过第一接口输出数字或字符信息;收发模块140还可用于通过第一接口向磁盘组发送指令,以修改磁盘组中的数据;收发模块140还可以包括显示屏等显示设备。
输入模块150可用于接收输入的数字或字符信息,以及产生与模型训练设备100的用户设置以及功能控制有关的键信号输入。具体地,输入模块150可包括触控面板以及其他输入设备。触控面板,也称为触摸屏,可收集用户在其上或附近的触摸操作(比如用户使用手指、触笔等任何适合的物体或附件在触控面板上或在触控面板附近的操作),并根据预先设定的程式驱动相应的连接装置。可选的,触控面板可包括触摸检测装置和触摸控制器两个部分。其中,触摸检测装置检测用户的触摸方位,并检测触摸操作带来的信号,将信号传送给触摸控制器;触摸控制器从触摸检测装置上接收触摸信息,并将它转换成触点坐标,再送给处理模块110,并能接收处理模块110发来的命令并加以执行。此外,可以采用电阻式、电容式、红外线以及表面声波等多种类型实现触控面板。除了触控面板,输入模块150还可以包括其他输入设备。具体地,其他输入设备可以包括但不限于物理键盘、功能键(比如音量控制按键、开关按键等)、轨迹球、鼠标、操作杆等中的一种或多种。
在一些可能的实现方式中,模型训练设备100可以为终端设备也可以为服务器,此处不做限定。
本实施例中,服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、CDN、以及大数据和人工智能平台等基础云计算服务的云服务器。终端可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表等,但并不局限于此。终端以及服务器可以通过有线或无线通信方式进行直接或间接地连接,终端以及服务器可以连接组成区块链网络,本申请在此不做限制。
其中,终端设备可以用户设备(user equipment,UE)、移动台(mobile station,MS)、移动终端(mobile terminal,MT)等。终端设备可以是手机(mobile phone)、平板电脑(Pad)、带无线收发功能的电脑、虚拟现实(virtual reality,VR)终端设备、增强现实(augmented reality,AR)终端设备、工业控制(industrial control)中的无线终端、无人驾驶(self driving)中的无线终端、远程手术(remote medical surgery)中的无线终端、智能电网(smart grid)中的无线终端、运输安全(transportation safety)中的无线终端、智慧城市(smart city)中的无线终端、智慧家庭(smart home)中的无线终端等等。本申请的实施例对终端设备所采用的具体技术和具体设备形态不做限定。
例如,模型训练设备100是家用电脑(终端设备),如图1-2所示,模型训练设备100的显示模块120可以为显示器,输入模块150可以为键盘和/或鼠标,处理模块110、存储模块130、收发模块140可以集成在家用电脑的主机中。
传统的机器学习模型可以通过采取元学习的技术,提高模型的泛化能力,适用少样本场景的效果。
目前,业界普遍采取的基于梯度信息的元学习的技术,例如模型无关元学习(model-agnostic meta-learning,MAML),其利用了梯度信息,可以充分地提取训练数据中高层次知识,从而实现模型中知识的迁移,简化了元学习的复杂度。
但是,基于梯度的元学习普遍忽视了任务之间的联系与差异,针对不同类型的任务只寻找一个全局的通用解,这使得一些困难任务往往被忽视,从而降低了其在测试集中的表现。
为此,本申请提供了一种模型训练方法,首先通过使用P个任务训练集依次对模型M进行训练,得到M’以及P个梯度信息,P为大于等于2的整数,然后复制所述M’,得到p个复制模型,p为小于等于P的正整数,并根据所述P个梯度信息确定所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重,得到权重信息,最后使用所述P个任务训练集及基于所述权重信息分别训练所述M’的p个复制模型,得到训练后的p个模型,充分考虑不同任务之间联系与差异,降低了不同任务之间的干扰。
前述实施例介绍了本申请提供给的模型训练设备100,接下来介绍基于该模型训练设备100执行的模型训练方法,请参阅图2-1和图2-2所示,本申请实施例提供的模型训练方法主要包括如下步骤:
201、模型训练设备获取初始化的模型M。
在一些可行的实现方式,M可以为元学习模型,也可以为其他类型的机器学习模型,此处不做限定。
需要说明的是,元学习模型包括一个相似度函数S和一个support set,其中,S用于判断两个数据的相似度,支撑集(support set)包括多个标签,每个标签对应一个或若干个数据。当模型训练设备接收到询问数据(query)时,通过S将query和support set中各个标签进行相似度的计算,以判断query属于support set中的哪个标签。那么,M的参数即为S的参数,训练M即为训练S,从而得到S的参数。
202、模型训练设备使用P个任务训练集依次对模型M进行训练,得到M’以及P个梯度信息,P为大于等于2的整数。
在本申请实施例中,P个任务训练集中每个任务训练集中有多个训练数据及其标签。例如,训练数据为图片及其标签,其标签如猫、狗、熊等等,此处不做限定。
在一些可行的实现方式中,若M为元学习模型,可以基于预设的数据池构造M的P个任务训练集。其中,不同的任务训练集用于训练不同的模型的不同的能力,使得该模型具有相应的能力完成相应的任务。例如,任务训练集1用于使得模型具有识别动物的能力,任务训练集2用于使得模型具有识别植物的能力,任务训练集3用于使得模型具有识别人脸的能力,任务训练集4用于使得模型具有识别汽车的能力,任务训练集5用于使得模型具有识别文具的能力。在本申请实施例中,模型训练设备可以依次通过任务训练集1、任务训练集2、任务训练集3、任务训练集4、任务训练集5对M进行训练,使得M具有任务训练集1/2/3/4/5对应的能力。
示例性的,模型训练设备首先通过任务训练集1训练M,得到M1的模型参数,M1具有识别动物的能力;然后,模型训练设备通过任务训练集2训练M1,得到M2的模型参数,M2同时具有识别动物的能力以及识别植物的能力;然后,模型训练设备通过任务训练集3训练M2,得到M3的模型参数,M3同时具有识别动物的能力、识别植物的能力以及识别人脸的能力;然后,模型训练设备通过任务训练集4训练M3,得到M4的模型参数,M4同时具有识别动物的能力、识别植物的能力、识别人脸的能力以及识别汽车的能力;最后,然后,模型训练设备通过任务训练集5训练M4,得到M5的模型参数,M5同时具有识别动物的能力、识别植物的能力、识别人脸的能力、识别汽车的能力以及识别文具的能力。最后,得到的M5作为M’。
在本申请实施例中,对M/M1/M2/M3/M4的训练可以通过梯度下降法来训练。示例性的,模型训练设备首先构建损失函数,将任务训练集输入M/M1/M2/M3/M4,得到损失函数的值,若损失函数的值不低于预设值,则进行梯度下降,通过多次迭代后,得到M1/M2/M3/M4/M5的模型参数。
在一些可行的实现方式中,若M是元学习模型,对于P个任务训练集中任一任务训练集,可以构造正样本(x1,y1,0)和负样本(x2,x3,1),其中,x1,x2,x3,y1均表示数据池中的数据,其中,x1,x2,x3具有相同的标签,y1具有与x1,x2,x3不同的标签,(x1,y1,0)表示x1和y1的标签不同(标签不同表示为0),(x2,x3,1)表示x2和x3的标签相同(标签相同表示为1)。多个正样本和负样本构成训练数据,即可对M进行训练。
在本申请实施例中,模型训练设备在使用不同的任务训练集对M/M1/M2/M3/M4进行训练后,并收集梯度信息。需要说明的是,收集的梯度信息可以为对模型参数的首次迭代计算得到的梯度信息,也可以为对模型参数的最后一次迭代计算得到的梯度信息,也可以为对模型参数的任意一次迭代计算得到的梯度信息,此处不做限定。
示例性的,模型训练设备对M进行训练得到M1后,得到梯度信息1;模型训练设备对M1进行训练得到M2后,得到梯度信息2;模型训练设备对M2进行训练得到M3后,得到梯度信息3;模型训练设备对M3进行训练得到M4后,得到梯度信息4;模型训练设备对M4进行训练得到M5(即M’)后,得到梯度信息5。
在本申请实施例中,模型训练设备将不同任务训练集对应的梯度信息进行汇聚,得到分裂池,续上述例子,分裂池包括梯度信息1/2/3/4/5。
在本申请实施例中,当M为元学习模型时,步骤202亦称为元训练(meta-train)。
203、模型训练设备复制所述M’,得到p个复制模型,p为小于等于P的正整数。
在本申请实施例中,模型训练设备可以将训练后的M’进行复制,得到p个复制模型,后续对这p个复制模型分别进行训练,以充分考虑不同任务之间联系与差异,降低了不同任务之间的干扰。
在本申请实施例中,当M为元学习模型时,步骤203亦称为元分裂(meta-divide)。
204、模型训练设备对所述P个梯度信息进行聚类(cluster),得到p个聚类中心,所述p个聚类中心与所述p个复制模型一一对应。
在一些可行的实现方式,可以通过K均值聚类(k-means)的算法对P个梯度信息进行聚类。例如,设置k=3,即对各个梯度信息分成3类。其中,任务训练集1用于识别动物,任务训练集2用于识别植物,任务训练集3用于识别人脸,任务训练集4用于识别汽车,任务训练集5用于识别文具,进行聚类后,可以得到梯度信息1和梯度信息2对应聚类中心1(都是识别生物),梯度信息3对应聚类中心2(属于困难任务),梯度信息4和梯度信息5对应聚类中心3(都是识别非生物)。
在一些可行的实现方式,还可以使用其他聚类方法,例如均值漂移聚类、基于密度的聚类等等,此处不做限定。
在本申请实施例中,聚类中心的数量与复制模型的数量相同,复制模型和聚类中心一一对应。例如,前述将M’复制了3个复制模型,那么聚类中心的数量为3,例如聚类中心1对应复制模型1,聚类中心2对应复制模型2,聚类中心3对应复制模型3。
205、模型训练设备确定所述P个梯度信息中各个梯度信息到所述p个聚类中心中各个聚类中心的距离,作为所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重。
例如,任务训练集1对应梯度信息1,梯度信息1对应聚类中心1,模型训练设备可以计算梯度信息1与聚类中心1/2/3的距离11/12/13。同样的,模型训练设备还可以计算得到梯度信息2与聚类中心1/2/3的距离21/22/23,梯度信息3与聚类中心1/2/3的距离31/32/33,梯度信息4与聚类中心1/2/3的距离41/42/43,梯度信息5与聚类中心1/2/3的距离51/52/53。最后,对距离11/12/13/21/22/23/31/32/33/41/42/43/51/52/53进行归一化处理,得到任务训练集1/2/3/4/5在聚类中心1/2/3所对应的权重11/12/13/21/22/23/31/32/33/41/42/43/51/52/53。
例如,任务训练集1在聚类中心1的权重为0.7,在聚类中心2的权重为0.2,在聚类中心3的权重为0.1。
在本申请实施例中,对于确定权重信息的模块,亦称为权重计算器(weightscalculatior)。
206、模型训练设备使用所述P个任务训练集及基于所述权重信息分别训练所述M’的p个复制模型,得到训练后的p个模型。
示例性的,获取任务训练集1,确定任务训练集1在聚类中心1/2/3中的权重,例如为0.7,0.2和0.1,那么可以将任务训练集1分别输入到3个复制模型,让3个复制模型分别输出损失函数。例如,将任务训练集1输入到复制模型1,得到损失函数L1,并基于L1进行偏导求得梯度信息t1,然后令Q(i+1)=Q(i)-η*t1*w1,其中,Q(i)为上一轮求得的复制模型的模型参数,Q(i+1)为本轮求得的复制模型的模型参数,w1为任务训练集1在聚类中心1的权重。进行多次迭代后,直到迭代次数大于预设次数,或损失函数的值不大于预设值,则训练结束,得到训练后的模型1。分别使用P个任务训练集训练p个复制模型,即可得到训练后的p个模型。
在本申请实施例中,当M为元学习模型时,步骤206亦称为元集成(meta-emsemble)。
207、模型训练设备对测试数据进行预测(prediction)。
在本申请实施例中,模型训练设备首先获取测试数据,然后分别使用所述训练后的p个模型对所述测试数据进行预测,得到p个预测结果,最后基于所述权重信息对所述p个预测结果进行加权(亦称为微调finetune),可以得到所述测试数据的最终预测结果。例如,在模型1/2/3中的权重分别是w1/w2/w3,将测试集1分别输入模型1/2/3,分别得到概率分别p1/2/3,最后计算p1*w1+p2*w2+p3*w3,作为测试集1的概率分布。
例如,当模型训练设备中配置有support set,support set中有2个标签(分别为标签1和标签2),每个标签有1张图片。当模型训练设备接收到一个query时,模型训练设备根据标签1确定任务训练集1,再确定该任务训练集1在聚类中心1/2/3(即在小模型1/2/3)中的权重1/2/3。然后,模型训练设备计算query和标签1的图片1的相似性时,通过小模型1/2/3分别对query进行预测,得到概率分布p1/2/3,再使用权重1/2/3对p1/2/3进行加权,得到总模型输出的P。
在一些可能的实现方式中,模型训练设备可以从数据池中选择P个任务测试集,每个任务测试集包括多个数据及其标签,其中,P个任务测试集和P个任务训练集一一对应,P个任务测试集用于测试模型的预测效果,对应的任务测试集和任务训练集对应的任务相同,且对应的任务测试集和任务训练集不重合。
然后基于前述方法,依次将P个任务测试集中各个任务测试集的数据输入训练后的p个模型中的各个模型中,基于权重信息对得到的预测结果进行加权,得到最终的预测结果,并与对应的标签进行比对,得到训练得到的最终模型(包括权重信息和训练后的p个模型)的性能,例如准确率、召回率等。
下面,可以通过实验对比使用MAML与本申请的技术方案在准确率上的表现情况。本实验中的数据池来自复杂数据集MiniImageNet。如图2-3所示,首先从MiniImageNet选择训练数据集(train dataset)和测试数据集(test dataset),然后使用train dataset进行meta-train,得到M’;再对M’进行meta-divide,多个复制模型;再对多个复制模型进行meta-ensemble,得到权重信息,并基于权重信息和train dataset分别训练后的多个复制模型,得到训练后的多个模型;最后通过test dataset及通过权重信息进行微调(finetune)(即加权)测试训练后的多个复制模型,得到预测结果(prediction)。
选择MAML在验证集上表现最好的模型作为测试模型,得到测试结果如表1。
表1
如表1所示,在给定整体训练次数为35000的情况下,根据实验表明,MAML已经达到验证集最好表现,其在测试集上经过十次测试得到正确率均值为48.6%,标准差为0.46%。而本申请的技术方案命名为基于梯度的元学习增强框架(Task-Specific Meta-learning,TSML),其在各种配置下轻易取得了优于MAML的表现。如图2-4所示,为TSML的技术方案,初始化MAML后,得到M,在meta-train时对M迭代30000次,并进行meta-divide,并在进行meta-ensemble时迭代5000次,得到训练后的多个模型,训练后的多个模型在测试时取得了最优的表现效果:正确率均值51.2%,标准差为1.34%。该结果表明,TSML辅助MAML在复杂数据集的小样本场景下实现了大约2.6%的大幅度正确率增长,证明了本申请的技术方案对元学习算法的泛化能力的提升具有有效性。
需要说明的是,对于前述的各方法实施例,为了简单描述,故将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本申请并不受所描述的动作顺序的限制,因为依据本申请,某些步骤可以采用其他顺序或者同时进行。其次,本领域技术人员也应该知悉,说明书中所描述的实施例均属于优选实施例,所涉及的动作和模块并不一定是本申请所必须的。
为便于更好的实施本申请实施例的上述方案,下面还提供用于实施上述方案的相关装置。
请参阅图3所示,本申请实施例提供的一种模型训练设备300,可以包括:
训练模块301,用于使用P个任务训练集依次对模型M进行训练,得到M’以及P个梯度信息,P为大于等于2的整数;
处理模块302,用于复制所述M’,得到p个复制模型,p为小于等于P的正整数;
权重计算器303,用于根据所述P个梯度信息确定所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重,得到权重信息;
所述训练模块301,还用于使用所述P个任务训练集及基于所述权重信息分别训练所述M’的p个复制模型,得到训练后的p个模型。
在一些可行的实现方式中,所述权重计算器303,具体用于:对所述P个梯度信息进行聚类,得到p个聚类中心,所述p个聚类中心与所述p个复制模型一一对应;确定所述P个梯度信息中各个梯度信息到所述p个聚类中心中各个聚类中心的距离,作为所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重。
在一些可行的实现方式中,所述模型训练设备还包括:预测模块304,用于获取测试数据,分别使用所述训练后的p个模型对所述测试数据进行预测,得到p个预测结果,基于所述权重信息对所述p个预测结果进行加权,得到所述测试数据的最终预测结果。
需要说明的是,上述装置各模块/单元之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其带来的技术效果与本申请方法实施例相同,具体内容可参见本申请前述所示的方法实施例中的叙述,此处不再赘述。
本申请实施例还提供一种计算机存储介质,其中,该计算机存储介质存储有程序,该程序执行包括上述方法实施例中记载的部分或全部步骤。
接下来介绍本申请实施例提供的另一种通信装置,请参阅图4所示,通信装置400包括:
接收器401、发射器402、处理器403和存储器404。在本申请的一些实施例中,接收器401、发射器402、处理器403和存储器404可通过总线或其它方式连接,其中,图4中以通过总线连接为例。
存储器404可以包括只读存储器和随机存取存储器,并向处理器403提供指令和数据。存储器404的一部分还可以包括非易失性随机存取存储器(non-volatile randomaccess memory,NVRAM)。存储器404存储有操作系统和操作指令、可执行模块或者数据结构,或者它们的子集,或者它们的扩展集,其中,操作指令可包括各种操作指令,用于实现各种操作。操作系统可包括各种系统程序,用于实现各种基础业务以及处理基于硬件的任务。
处理器403控制通信装置的操作,处理器403还可以称为中央处理单元(centralprocessing unit,CPU)。具体的应用中,通信装置的各个组件通过总线系统耦合在一起,其中总线系统除包括数据总线之外,还可以包括电源总线、控制总线和状态信号总线等。但是为了清楚说明起见,在图中将各种总线都称为总线系统。
上述本申请实施例揭示的方法可以应用于处理器403中,或者由处理器403实现。处理器403可以是一种集成电路芯片,具有信号的处理能力。在实现过程中,上述方法的各步骤可以通过处理器403中的硬件的集成逻辑电路或者软件形式的指令完成。上述的处理器403可以是通用处理器、数字信号处理器(digital signal processing,DSP)、专用集成电路(application specific integrated circuit,ASIC)、现场可编程门阵列(field-programmable gate array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。可以实现或者执行本申请实施例中的公开的各方法、步骤及逻辑框图。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。结合本申请实施例所公开的方法的步骤可以直接体现为硬件译码处理器执行完成,或者用译码处理器中的硬件及软件模块组合执行完成。软件模块可以位于随机存储器,闪存、只读存储器,可编程只读存储器或者电可擦写可编程存储器、寄存器等本领域成熟的存储介质中。该存储介质位于存储器404,处理器403读取存储器404中的信息,结合其硬件完成上述方法的步骤。
接收器401可用于接收输入的数字或字符信息,以及产生与相关设置以及功能控制有关的信号输入,发射器402可包括显示屏等显示设备,发射器402可用于通过外接接口输出数字或字符信息。
本申请实施例中,处理器403,用于执行前述通信装置执行的模型训练方法。
在另一种可能的设计中,当通信装置为芯片时,包括:处理单元和通信单元,所述处理单元例如可以是处理器,所述通信单元例如可以是输入/输出接口、管脚或电路等。该处理单元可执行存储单元存储的计算机执行指令,以使该终端内的芯片执行上述第一方面任意一项的无线报告信息的发送方法。可选地,所述存储单元为所述芯片内的存储单元,如寄存器、缓存等,所述存储单元还可以是所述终端内的位于所述芯片外部的存储单元,如只读存储器(read-only memory,ROM)或可存储静态信息和指令的其他类型的静态存储设备,随机存取存储器(random access memory,RAM)等。
其中,上述任一处提到的处理器,可以是一个通用中央处理器,微处理器,ASIC,或一个或多个用于控制上述方法的程序执行的集成电路。
另外需说明的是,以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。另外,本申请提供的装置实施例附图中,模块之间的连接关系表示它们之间具有通信连接,具体可以实现为一条或多条通信总线或信号线。
通过以上的实施方式的描述,所属领域的技术人员可以清楚地了解到本申请可借助软件加必需的通用硬件的方式来实现,当然也可以通过专用硬件包括专用集成电路、专用CPU、专用存储器、专用元器件等来实现。一般情况下,凡由计算机程序完成的功能都可以很容易地用相应的硬件来实现,而且,用来实现同一功能的具体硬件结构也可以是多种多样的,例如模拟电路、数字电路或专用电路等。但是,对本申请而言更多情况下软件程序实现是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在可读取的存储介质中,如计算机的软盘、U盘、移动硬盘、ROM、RAM、磁碟或者光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述的方法。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。
所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存储的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘(Solid State Disk,SSD))等。
Claims (12)
1.一种模型训练方法,其特征在于,包括:
使用P个任务训练集依次对模型M进行训练,得到M’以及P个梯度信息,P为大于等于2的整数;
复制所述M’,得到p个复制模型,p为小于等于P的正整数;
根据所述P个梯度信息确定所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重,得到权重信息;
使用所述P个任务训练集及基于所述权重信息分别训练所述M’的p个复制模型,得到训练后的p个模型。
2.根据权利要求1所述方法,其特征在于,所述根据P个梯度信息确定所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重,包括:
对所述P个梯度信息进行聚类,得到p个聚类中心,所述p个聚类中心与所述p个复制模型一一对应;
确定所述P个梯度信息中各个梯度信息到所述p个聚类中心中各个聚类中心的距离,作为所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重。
3.根据权利要求2所述方法,其特征在于,所述聚类为K均值聚类。
4.根据权利要求1-3中任一项所述方法,其特征在于,所述方法还包括:
获取测试数据;
分别使用所述训练后的p个模型对所述测试数据进行预测,得到p个预测结果;
基于所述权重信息对所述p个预测结果进行加权,得到所述测试数据的最终预测结果。
5.根据权利要求1-3中任一项所述方法,其特征在于,所述M为元学习模型。
6.一种模型训练设备,其特征在于,包括:
训练模块,用于使用P个任务训练集依次对模型M进行训练,得到M’以及P个梯度信息,P为大于等于2的整数;
处理模块,用于复制所述M’,得到p个复制模型,p为小于等于P的正整数;
权重计算器,用于根据所述P个梯度信息确定所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重,得到权重信息;
所述训练模块,还用于使用所述P个任务训练集及基于所述权重信息分别训练所述M’的p个复制模型,得到训练后的p个模型。
7.根据权利要求6所述设备,其特征在于,所述权重计算器,具体用于:
对所述P个梯度信息进行聚类,得到p个聚类中心,所述p个聚类中心与所述p个复制模型一一对应;
确定所述P个梯度信息中各个梯度信息到所述p个聚类中心中各个聚类中心的距离,作为所述P个任务训练集中各个任务训练集在所述p个复制模型中各个复制模型的权重。
8.根据权利要求6或7所述设备,其特征在于,还包括:
预测模块,用于获取测试数据,分别使用所述训练后的p个模型对所述测试数据进行预测,得到p个预测结果,基于所述权重信息对所述p个预测结果进行加权,得到所述测试数据的最终预测结果。
9.一种计算机可读存储介质,其特征在于,该计算机可读存储介质存储有程序,所述程序使得计算机设备执行如权利要求1-5中任一项所述的方法。
10.一种计算机程序产品,其特征在于,所述计算机程序产品包括计算机执行指令,所述计算机执行指令存储在计算机可读存储介质中;设备的至少一个处理器从所述计算机可读存储介质中读取所述计算机执行指令,所述至少一个处理器执行所述计算机执行指令使得所述设备执行如权利要求1-5中任一项所述的方法。
11.一种通信装置,其特征在于,所述通信装置包括至少一个处理器、存储器和通信接口;
所述至少一个处理器与所述存储器和所述通信接口耦合;
所述存储器用于存储指令,所述处理器用于执行所述指令,所述通信接口用于在所述至少一个处理器的控制下与其他通信装置进行通信;
所述指令在被所述至少一个处理器执行时,使所述至少一个处理器执行如权利要求1-5中任一项所述的方法。
12.一种芯片系统,其特征在于,所述芯片系统包括处理器和存储器,所述存储器和所述处理器通过线路互联,所述存储器中存储有指令,所述处理器用于执行如权利要求1-5中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210583164.6A CN117196057A (zh) | 2022-05-26 | 2022-05-26 | 一种模型训练方法和相关设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210583164.6A CN117196057A (zh) | 2022-05-26 | 2022-05-26 | 一种模型训练方法和相关设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117196057A true CN117196057A (zh) | 2023-12-08 |
Family
ID=88994718
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210583164.6A Pending CN117196057A (zh) | 2022-05-26 | 2022-05-26 | 一种模型训练方法和相关设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117196057A (zh) |
-
2022
- 2022-05-26 CN CN202210583164.6A patent/CN117196057A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111242297A (zh) | 基于知识蒸馏的模型训练方法、图像处理方法及装置 | |
CN108140075A (zh) | 将用户行为分类为异常 | |
CN111950596A (zh) | 一种用于神经网络的训练方法以及相关设备 | |
WO2024041479A1 (zh) | 一种数据处理方法及其装置 | |
CN111813532A (zh) | 一种基于多任务机器学习模型的图像管理方法及装置 | |
EP4390753A1 (en) | Text data processing method, neural network training method, and related devices | |
CN114974397A (zh) | 蛋白质结构预测模型的训练方法和蛋白质结构预测方法 | |
CN114334036A (zh) | 一种模型训练的方法、相关装置、设备以及存储介质 | |
CN116684330A (zh) | 基于人工智能的流量预测方法、装置、设备及存储介质 | |
CN114428842A (zh) | 一种扩充问答库的方法、装置、电子设备及可读存储介质 | |
CN111507407B (zh) | 图像分类模型的训练方法及装置 | |
CN115203194A (zh) | 一种元数据信息的生成方法、相关装置、设备及存储介质 | |
CN116186295A (zh) | 基于注意力的知识图谱链接预测方法、装置、设备及介质 | |
CN117196057A (zh) | 一种模型训练方法和相关设备 | |
JP7236501B2 (ja) | 文書類似度学習に基づくディープラーニングモデルの転移学習方法およびコンピュータ装置 | |
CN109583583B (zh) | 神经网络训练方法、装置、计算机设备及可读介质 | |
CN115687146A (zh) | Bios测试方法、装置、计算机设备和存储介质 | |
US20220292393A1 (en) | Utilizing machine learning models to generate initiative plans | |
CN113760407A (zh) | 信息处理方法、装置、设备及存储介质 | |
CN112257812A (zh) | 一种标注样本确定方法、装置、机器可读介质及设备 | |
CN117194408A (zh) | 一种索引方案的选择方法和存储设备 | |
CN113628080B (zh) | 一种分数预测方法、装置、存储介质和电子设备 | |
WO2024055952A1 (zh) | 一种数据处理方法及其装置 | |
CN110598578B (zh) | 身份识别方法、身份识别系统的训练方法、装置及设备 | |
CN113657453B (zh) | 基于生成对抗网络和深度学习的有害网站的检测方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication |