CN115878989A - 模型训练方法、装置及存储介质 - Google Patents
模型训练方法、装置及存储介质 Download PDFInfo
- Publication number
- CN115878989A CN115878989A CN202111131777.8A CN202111131777A CN115878989A CN 115878989 A CN115878989 A CN 115878989A CN 202111131777 A CN202111131777 A CN 202111131777A CN 115878989 A CN115878989 A CN 115878989A
- Authority
- CN
- China
- Prior art keywords
- local
- model
- target
- parameter
- model parameters
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请实施例公开了一种模型训练方法、装置及存储介质,属于人工智能领域。所述方法包括:目标分节点在接收到来自中心节点的全局模型参数后,将接收到的全局模型参数和目标分节点自身目标模型的本地模型参数进行融合获得融合模型参数,这样,获得的融合模型参数中同时包含了目标分节点的本地信息和全局信息,之后,采用同时包含了本地信息和全局信息的融合模型参数对目标模型进行更新并采用本地训练集对该目标模型进行训练,能够有效提高目标模型的精度,这样,即使参与训练的各分节点采用的本地训练集中的样本数据分布不同或各分节点上的目标模型需要完成的任务不同,采用本申请实施例方法训练出的目标模型仍然能够满足各分节点的需求。
Description
技术领域
本申请涉及人工智能(artificial intelligence,AI)领域,特别涉及一种模型训练方法、装置及存储介质。
背景技术
当前,AI模型被广泛的应用于各行各业中。其中,模型的训练需要用到大量数据,但随着人们数据保护意识的增强,数据的获取难度越来越大,跨节点的模型训练越来越受到重视。
相关技术中,中心节点可以下发目标模型的初始化全局模型参数给参与学习的各个分节点,各个分节点可以将该初始化全局模型参数作为自身之上的该目标模型的模型参数,并通过本地训练集对该目标模型进行训练,然后将训练后获得的该目标模型的模型参数上传至中心节点。中心节点可以将接收到的各个分节点上报的模型参数进行联合平均,从而得到更新后的全局模型参数,之后,中心节点可以将更新后的全局模型参数再次下发至各个分节点,各个分节点重复上述过程直至模型收敛为止,将中心节点最后一次下发的更新后的全局模型参数作为自身之上的目标模型的模型参数。
然而,由于各分节点上用于训练目标模型的训练集中的样本数据分布可能不同,而且各分节点上的目标模型的任务可能也不同,所以,在所有分节点均采用相同的全局模型参数作为自身的目标模型的模型参数来对目标模型进行训练的情况下,各分节点最终得到的拥有相同模型参数的目标模型并不能满足各自的需求。
发明内容
本申请实施例提供了一种模型训练方法、装置和存储介质,能够使得各个分节点训练得到的模型更贴合自身需求,提高了各个分节点上的模型的精度。所述技术方案如下:
第一方面,提供了一种模型训练方法,所述方法包括:接收来自中心节点的目标模型的全局模型参数;对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数,所述本地模型参数是指基于本地训练集对所述目标模型进行训练得到的模型参数;将所述目标模型的本地模型参数更新为所述融合模型参数,并根据所述目标分节点的本地训练集对更新后的目标模型进行训练。
其中,全局模型参数是中心节点根据每个分节点的本地训练集的样本数据的数量对接收到的各分节点的本地模型参数进行加权平均后得到的。
在本申请实施例中,由于融合模型参数是目标分节点通过对自身的本地模型参数和接收到的中心节点的全局模型参数进行融合后得到的,所以融合模型参数中同时包含了目标分节点的本地信息和全局信息,采用这样的融合模型参数对目标模型进行更新并采用本地训练集对该目标模型进行训练,能够有效提高目标模型的精度。
可选地,所述对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数的实现过程可以为:根据所述目标分节点的参数融合规则,对所述目标模型的全局模型参数和本地模型参数进行融合,得到所述融合模型参数。
可选地,所述目标分节点的参数融合规则包括基于所述目标分节点的属性信息确定的所述目标模型的本地模型参数中各个本地参数组的替换概率。在这种情况下,所述根据所述目标分节点的参数融合规则,对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数的实现过程可以为:将所述本地模型参数中替换概率大于第一阈值的本地参数组替换为所述全局模型参数中对应的参数组,得到所述融合模型参数。
其中,目标分节点的属性信息可以包括本地训练集中的样本数据的数据量和分布特征、目标分节点的计算能力以及目标模型的规模信息中的至少一项。本地参数组是对本地模型参数中的多个本地参数值进行划分得到。
在本申请实施例中,由于不同分节点的属性信息可能不同,各分节点根据自身的属性信息确定目标模型的本地模型参数中各个本地参数组的替换概率,并将本地模型参数中替换概率大于第一阈值的本地参数组替换为全局模型参数中对应的参数组,得到所述融合模型参数,这样各个分节点获得的融合模型参数能够更适合各分节点自身的需求,能够有效降低因各分节点上用于训练目标模型的训练集中的样本数据分布不同或各分节点上的目标模型的任务不同对训练出的模型精度的影响,能够提高各分节点最终训练得到的目标模型的精度。
可选地,在对目标模型的全局模型参数和本地模型参数进行融合之前,还可以根据所述目标分节点的属性信息,确定参数搜索粒度;根据所述参数搜索粒度,对所述本地模型参数包括的多个本地参数值进行分组,得到多个本地参数组;获取所述多个本地参数组中每个本地参数组的待优化替换概率;根据所述本地训练集对每个本地参数组的待优化替换概率进行迭代优化,得到每个本地参数组的替换概率。
其中,参数搜索粒度用于指示将模型参数包括的多个参数值划分为参数组时每个参数组中包括的参数值的数量,参数组中包括的参数值的数量越多,则说明参数搜索粒度越粗,参数组中包括的参数值的数量越少,则说明参数搜索粒度越细。
在本申请实施例中,由于各分节点的本地训练集的样本数据的数据量、本地训练集中的样本数据的复杂程度以及各分节点的计算能力和各分节点部署的目标模型的规模可能不同,所以各分节点根据自身的属性信息确定出的参数搜索粒度更适合自身的需求。
可选地,所述根据所述本地训练集对每个本地参数组的待优化替换概率进行迭代优化,得到每个本地参数组的替换概率的实现过程可以为:根据每个本地参数组的待优化替换概率、所述多个本地参数组和所述全局模型参数,确定验证模型参数;将所述目标模型的本地模型参数替换为所述验证模型参数,得到验证模型;根据所述本地训练集对所述验证模型进行训练,得到更新后的验证模型;根据验证集对所述更新后的验证模型进行测试;如果测试结果不满足参考条件,对每个本地参数组的待优化替换概率进行更新,将更新后的概率作为所述待优化替换概率,并返回执行所述根据每个本地参数组的待优化替换概率、所述多个本地参数组和所述全局模型参数,确定验证模型参数的步骤,直至所述测试结果满足所述参考条件时,将最后一次更新后的概率作为每个本地参数组的替换概率。
其中,参考条件可以为模型的损失函数、精准率、召回率、准确率、平均交并比等。
在本申请实施例中,各分节点是根据自身的验证集对自身的验证模型进行测试,并在测试结果不满足参考条件的情况下,对各本地参数组的待优化替换概率进行不断优化,之后,通过优化得到的本地参数组的替换概率来将本地模型参数中的部分参数值替换为全局模型参数中的对应参数值,从而得到融合模型参数,此时,该融合模型参数更贴合该分节点的数据分布和检测任务,在此基础上,基于该融合模型参数对该目标模型训练,能够使得训练得到的模型的精度更高。
可选地,所述目标分节点的参数融合规则包括根据所述目标分节点的属性信息预先设置的待替换的本地参数组的索引,所述本地参数组为对所述本地模型参数包括的多个本地参数值进行分组得到。在这种情况下,所述根据所述目标分节点的参数组合规则,对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数的实现过程可以为:根据所述待替换的本地参数组的索引,将所述本地模型参数中待替换的本地参数组替换为所述全局模型参数中对应的参数组,得到所述融合模型参数。
在本申请实施例中,基于人工经验预先设置待替换的参数组的索引,能够减少模型训练过程的计算量。
第二方面,提供了一种模型训练装置,所述模型训练装置具有实现上述第一方面中模型训练方法行为的功能。所述模型训练装置包括至少一个模块,该至少一个模块用于实现上述第一方面所提供的模型训练方法。
第三方面,提供了一种模型训练装置,所述模型训练装置的结构中包括处理器和存储器,所述存储器用于存储支持模型训练装置执行上述第一方面所提供的模型训练方法的程序,以及存储用于实现上述第一方面所提供的模型训练方法所涉及的数据。所述处理器被配置为用于执行所述存储器中存储的程序。
第四方面,提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述第一方面所述的模型训练方法。
第五方面,提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述第一方面所述的模型训练方法。
上述第二方面、第三方面、第四方面和第五方面所获得的技术效果与第一方面中对应的技术手段获得的技术效果近似,在这里不再赘述。
本申请实施例提供的技术方案带来的有益效果至少包括:
在本申请实施例中,目标分节点在接收到来自中心节点的全局模型参数后,将接收到的全局模型参数和目标分节点自身目标模型的本地模型参数进行融合获得融合模型参数,这样,获得的融合模型参数中同时包含了目标分节点的本地信息和全局信息,之后,采用同时包含了本地信息和全局信息的融合模型参数对目标模型进行更新并采用本地训练集对该目标模型进行训练,能够有效提高目标模型的精度,这样,即使参与训练的各分节点采用的本地训练集中的样本数据分布不同或各分节点上的目标模型需要完成的任务不同,采用本申请实施例方法训练出的目标模型仍然能够满足各分节点的需求。
附图说明
图1是本申请实施例提供的一种模型训练方法所涉及的系统架构图;
图2是本申请实施例提供的一种计算机设备的结构示意图;
图3是本申请实施例提供的一种模型训练方法的流程图;
图4是本申请实施例提供的一种确定每个本地参数组替换概率的方法流程图;
图5是本申请实施例提供的一种目标分节点A获得融合模型参数的流程示意图;
图6是本申请实施例提供的一种模型训练装置的结构示意图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
在对本申请实施例进行详细的解释说明之前,先对本申请实施例的应用场景予以说明。
本申请实施例提供的模型训练方法可以用于联邦学习、分布式学习等跨节点的深度学习场景中。例如,在自动驾驶场景中,由于各公司采集的自动驾驶数据受自身采集设备以及自身所处地区的天气、建筑风格、交通标识符等的影响,所以不同公司拥有的自动驾驶数据不同,在这种情况下,各公司间一般不进行跨区域数据共享,而是可以采用本申请实施例的方法对拥有不同自动驾驶数据的各公司的自动驾驶模型进行联合训练,可以降低联邦学习过程中因各公司自动驾驶数据不同对训练出的模型精度的影响,能够提高各公司最终训练得到的模型的精度。再例如,在商品信息推送场景中,电商客户端中包含有为客户推送所售商品的商品信息的商品信息推送模型。其中,各个客户端为客户推送的商品信息中包含的商品种类可能不同,也即,各个客户端中的商品信息推送模型的任务不同,但是各个客户端拥有的客户群体有可能相同。在这种场景中,采用本申请实施例的方法对这多个客户端中的商品信息推送模型进行联合训练,训练得到的商品信息推送模型能够更符合各个客户端的要求。
需要说明的是,上述仅是本申请实施例给出的一些示例性的应用场景,并不构成对本申请实施例提供的模型训练方法的应用场景的限定。
图1是本申请实施例提供的一种模型训练方法所涉及的系统架构图。如图1所示,该系统包括中心节点101和多个分节点102,其中,多个分节点102可以通过有线或无线网络与中心节点101连接。
在本申请实施例中,中心节点101和多个分节点102上可以均部署有目标模型。中心节点101下发目标模型的全局模型参数给参与学习的多个分节点102,相应地,各分节点102接收中心节点101下发的全局模型参数并采用本申请实施例的方法对各分节点102自身的目标模型进行训练。
需要说明的是,在一种可能的实现方式中,中心节点101可以为一台服务器或一个服务器集群,或者,是一个能够对联邦学习过程进行协调的云平台,各个分节点102可以为诸如智能手机、平板电脑、笔记本电脑等用户终端。
可选地,中心节点101和各个分节点102可以部署在云环境中,例如,该中心节点101和多个分节点102均为部署在云数据中心中的服务器或虚拟机上。或者,该中心节点101和各个分节点可以是部署在边缘环境中的计算机设备,本申请实施例对此不做限定。
图2是本申请实施例提供的一种计算机设备的结构示意图。图1中的中心节点和/或目标分节点均可以通过该计算机设备来实现。参见图2,该计算机设备包括至少一个处理器201,通信总线202,存储器203以及至少一个通信接口204。
其中,处理器201可以包括一个通用中央处理器(central processing unit,CPU)、图形处理器(graphics processing unit,GPU)、网络处理器(network processor,NP)、微处理器、或者一个或多个用于实现本申请方案的集成电路,例如,专用集成电路(application-specific integrated circuit,ASIC)、可编程逻辑器件(programmablelogic device,PLD)或其组合。上述PLD可以是复杂可编程逻辑器件(complexprogrammable logic device,CPLD)、现场可编程逻辑门阵列(field-programmable gatearray,FPGA)、通用阵列逻辑(generic array logic,GAL)或其任意组合。
通信总线202用于在上述组件之间传送信息。通信总线202可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
存储器203可以是只读存储器(read-only memory,ROM)、随机存取存储器(randomaccess memory,RAM)、可存储静态信息和指令的其它类型的静态存储设备、可存储信息和指令的其它类型的动态存储设备,也可以是电可擦可编程只读存储器(electricallyerasable programmable read-only memory,EEPROM)、只读光盘(compact disc read-only memory,CD-ROM),或者其它光盘存储、光碟存储(包括压缩光碟、激光碟、光碟、数字通用光碟、蓝光光碟等)、磁盘存储介质,或者其它磁存储设备,或者能够用于携带或存储具有指令或数据结构形式的期望的程序代码并能够由计算机设备存取的任何其它介质,但不限于此。存储器203可以是独立存在,或者通过通信总线202与处理器201相连接,或者和处理器201集成在一起。
通信接口204使用任何收发器一类的装置,用于与其它设备或通信网络通信。通信接口204包括有线通信接口,还可以包括无线通信接口。其中,有线通信接口例如可以为以太网接口。以太网接口可以是光接口,电接口或其组合。无线通信接口可以为无线局域网(wireless local area networks,WLAN)接口,蜂窝网络通信接口或其组合等。
在一种实施例,处理器201可以包括一个或多个CPU,例如图2中所示的CPU0和CPU1。
在一种实施例,该物理服务器可以包括多个处理器,例如图2中所示的处理器201和处理器205。这些处理器中的每一个可以是一个单核处理器(single-CPU),也可以是一个多核处理器(multi-CPU)。这里的处理器可以指一个或多个设备、电路、和/或用于处理数据的处理核。
在一种实施例,计算机设备还可以包括输出设备206和输入设备207。输出设备206和处理器201通信,可以以多种方式来显示信息。例如,输出设备206可以是液晶显示器(liquid crystal display,LCD)、发光二级管(light emitting diode,LED)显示设备、阴极射线管(cathode ray tube,CRT)显示设备或投影仪(projector)等。输入设备207和处理器201通信,可以以多种方式接收用户的输入。例如,输入设备207可以是鼠标、键盘、触摸屏设备或传感设备等。
其中,存储器203用于存储执行本申请方案的程序代码208,处理器201用于执行存储器203中存储的程序代码208。该计算机设备可以通过处理器201以及存储器203中的程序代码208,来实现下文图3实施例所提供的模型训练方法。
图3是本申请实施例提供的一种模型训练方法的流程图。该方法可以应用于图1中所示的模型训练系统中的任一分节点上,下文中以其中一个分节点为例进行说明,为了方便叙述,将该分节点称为目标分节点,参见图3,该方法包括以下步骤:
步骤301:接收来自中心节点的目标模型的全局模型参数。
在本申请实施例中,中心节点和各分节点上均部署有目标模型,各个分节点均具有自己的本地训练集,在开始训练该目标模型时,各个分节点根据自己的本地训练集对目标模型进行训练得到目标模型的本地模型参数,之后,各分节点将本地模型参数上传至中心节点,相应地,中心节点接收各分节点上传的本地模型参数,并根据每个分节点的本地训练集的样本数量对接收到的各分节点的本地模型参数进行加权平均后得到全局模型参数。之后,中心节点将得到的全局模型参数下发给各分节点,相应地,目标分节点接收中心节点下发的目标模型的全局模型参数。
需要说明的是,各分节点上传至中心节点的本地模型参数可以包括自身部署的目标模型的全部本地参数值,也可以为部分本地参数值。例如,某个分节点可以上传目标模型的所有层的本地参数值,或者,上传目标模型中某几个层的本地参数值。中心节点在接收到各分节点上传的本地模型参数后,可以对所有分节点上传的本地模型参数中目标模型的相同部分的参数值进行加权平均,得到全局模型参数。例如,一个分节点上传目标模型的第1至5卷积层的本地参数值至中心节点,另一个分节点上传目标模型第1至10卷积层的本地参数值至中心节点,中心节点可以将两个分节点上传的本地模型参数中的第1至5卷积层的本地参数值进行加权平均,获得第1至5卷积层的全局参数值,并将获得的第1至5卷积层的全局参数值作为全局模型参数下发至各分节点。
示例性的,对于全局模型参数中的全局参数值ωk,可以通过下述公式确定得到:
其中,i为参与模型训练的第i个分节点,ni为第i个分节点的本地训练集的样本数量,n为参与模型训练的所有分节点的本地训练集的样本数量之和,为第i个分节点的本地模型参数中与该全局参数值ωk对应的本地参数值,其中,k为大于0且不小于全局模型参数的个数的正整数。
步骤302:对目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数。
在本申请实施例中,目标分节点可以根据自身的参数融合规则,对目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数。
在一种实现方式中,目标分节点的参数融合规则可以为基于目标分节点的属性信息确定的目标模型的本地模型参数中各个本地参数组的替换概率。其中,该本地参数组是对本地模型参数中的多个本地参数值进行划分得到。在这种情况下,目标分节点首先可以基于目标分节点的属性信息确定多个本地参数组,进而确定各个本地参数组的替换概率,之后,将本地模型参数中替换概率大于第一阈值的本地参数组替换为全局模型参数中对应的参数组,得到融合模型参数。
示例性地,如图4所示,目标分节点可以通过下述步骤3021-3024来确定本地参数组,并确定本地模型参数中各个本地参数组的替换概率。
3021:根据目标分节点的属性信息,确定参数搜索粒度。
在本申请实施例中,目标分节点的属性信息可以包括本地训练集中的样本数据的数据量和分布特征、目标分节点的计算能力以及目标模型的规模信息中的至少一项。
其中,本地训练集中的样本数据的分布特征用于指示本地训练集中的样本数据的复杂程度或目标分节点的本地训练集的样本数据与其他分节点的本地训练集的样本数据之间的差异性。目标分节点的计算能力可以通过目标分节点的每秒浮点运算次数来表征,当然也可以通过其他参数来表征。其中,目标分节点的每秒浮点运算次数由目标分节点的硬件设备的性能决定,例如,当目标分节点为笔记本电脑时,则目标分节点的计算能力由笔记本电脑的CPU芯片、内存大小、带宽等决定。目标模型的规模信息用于指示目标模型的大小,例如,该目标模型的规模信息包括目标模型的层数或目标模型的参数量的多少。
对于不同的属性信息,目标分节点可以采用不同的方式来确定参数搜索粒度。其中,该参数搜索粒度用于指示将模型参数包括的多个参数值划分为参数组时每个参数组中包括的参数值的数量。其中,参数组中包括的参数值的数量越多,则说明参数搜索粒度越粗,参数组中包括的参数值的数量越少,则说明参数搜索粒度越细。
示例性地,当目标分节点的属性信息为本地训练集中的样本数据的数据量时,目标分节点可以根据本地训练集中样本数据的数据量的多少来确定参数搜索粒度。其中,如果本地训练集中的样本数据的数据量较少,则目标分节点可以选择较粗的参数搜索粒度,例如,将一个卷积层的参数值划分为一组,也即,作为一个参数组,或者将相邻的几个卷积层的参数值划分为一组。如果本地训练集中的样本数据的数据量较多时,则目标分节点可以选择较细的参数搜索粒度,例如,可以将每个参数值作为一个参数组。
示例性的,目标分节点中可以存储有中心节点预先设置的多个数据量范围以及每个数据量范围对应的参数搜索粒度,基于此,目标分节点可以先确定本地训练集中的样本数据的数据量属于预先设置的多个数据量范围中的哪个范围,然后将自身的样本数据的数据量所属的数据量范围对应的参数搜索粒度作为自身的参数搜索粒度。
例如,预先设置的数据量范围为三个,其中,第一数据量范围中样本数据的数据量为0至100,对应的参数搜索粒度为将两个卷积层的参数值作为一个参数组;第二数据量范围中样本数据的数据量为101至300,对应的参数搜索粒度为将一个卷积层的参数值作为一个参数组;第三数据量范围中样本数据的数据量为大于300,对应的参数搜索粒度为将一个参数值作为一个参数组。这样,当目标分节点的本地训练集中的样本数据的数据量为180时,对应第二数据量范围,这时目标分节点的参数搜索粒度为将一个卷积层的参数值作为一个参数组。
当目标分节点的属性信息为本地训练集中样本数据的分布特征,且该分布特征用于指示本地训练集中样本数据的复杂程度时,目标分节点可以根据本地训练集中样本数据的复杂程度来确定参数搜索粒度。其中,本地训练集中的样本数据越复杂,则选择越细的参数搜索粒度,本地训练集中的样本数据越简单,则选择越粗的参数搜索粒度。
示例性地,本地训练集的样本数据的复杂程度可以通过本地训练集中的样本数据的标注量来表征,其中,本地训练集中的样本数据的标注量是指本地训练集中各样本数据中包含的标注信息的数量总和。在这种情况下,目标分节点可以根据本地训练集中样本数据的标注量的多少来确定参数搜索粒度,当本地训练集中的样本数据的标注量较多时,表明本地训练集中的样本数据的复杂程度较高,可以选择较细的参数搜索粒度,当本地训练集中的样本数据的标注量较少时,表明本地训练集中的样本数据的复杂程度较低,可以选择较粗的参数搜索粒度。
例如,目标分节点中可以存储有中心节点预先设置的多个标注量范围以及每个标注量范围对应的参数搜索粒度,模型训练过程中,目标分节点可以先确定自身的本地训练集中的样本数据的标注量属于预先设置的多个标注量范围中的哪个范围,然后将自身样本数据的标注量所属的标注量范围对应的参数搜索粒度作为自身的参数搜索粒度。
可选地,本地训练集中的样本数据的复杂程度也可以通过本地训练集中的样本数据之间的相关关系来表征。在这种情况下,目标分节点可以先获取本地训练集中的样本数据的相似度矩阵,之后,根据该相似度矩阵得到本地训练集中的样本数据之间的相似度,该相似度为一个0至1之间的数值,数值越大说明本地训练集中的样本数据的相似度越高,则本地训练集中的样本数据的复杂程度越低,这时可以选择较粗的参数搜索粒度,相应地,数值越小,说明本地训练集中的样本数据的相似度越低,则本地训练集中的样本数据的复杂度越高,这时可以选择较细的参数搜索粒度。
例如,当本地训练集的样本数据的相似度为0.4至0.8之间的一个数值时,目标分节点可以选择较粗的参数搜索粒度,例如,将一个卷积层的参数值作为一个参数组,当本地训练集中的样本数据的相似度为0至0.4之间的一个数值时,可以选择较细的参数搜索粒度,例如,将每个参数值作为一个参数组。
可选地,当目标分节点的属性信息为本地训练集中样本数据的分布特征,且该分布特征用于指示目标分节点的本地训练集的样本数据与其他分节点的本地训练集的样本数据之间的差异性时,目标分节点可以根据该差异性来确定参数搜索粒度。其中,目标分节点的本地训练集的样本数据与其他分节点的本地训练集的样本数据之间差异性越大,则选择越细的参数搜索粒度,目标分节点的本地训练集的样本数据与其他分节点的本地训练集的样本数据之间差异性越小,则选择越粗的参数搜索粒度。
示例性地,目标分节点的本地训练集的样本数据与其他分节点的本地训练集的样本数据之间的差异性可以通过目标分节点的本地模型参数与接收的来自中心节点的全局模型参数的余弦相似度来表征。在这种情况下,目标分节点可以用本地模型参数中的参数值减去全局模型参数中对应的参数值,得到第一向量,再通过本地训练集对目标模型进行训练来对本地模型参数进行更新,得到参考模型参数。之后,目标分节点用本地模型参数中的参数值减去参考模型参数中对应的参数值,得到第二向量。计算第一向量和第二向量的余弦相似度。其中,该余弦相似度为一个位于-1至1之间的数值。余弦相似度越大,表明目标分节点的本地训练集中的样本数据与其它分节点的本地训练集的样本数据的差异性越小,则目标分节点的本地训练集的样本数据的复杂程度越低,此时,则选择越粗的参数搜索粒度,余弦相似度越小,表明目标分节点的本地训练集的样本数据与其它分节点的本地训练集的样本数据的差异性越大,则目标分节点的本地训练集中的样本数据的复杂程度越高,此时,则选择越细的参数搜索粒度。例如,当目标分节点的余弦相似度小于0时,可以选择较细的参数搜索粒度,当目标分节点的余弦相似度大于等于0时,可以选择较粗的参数搜索粒度。
当目标分节点的属性信息为自身的计算能力时,目标分节点可以根据自身计算能力的强弱来确定参数搜索粒度。需要说明的是,目标分节点选择的参数搜索粒度越细,后续步骤中本地模型参数划分出的本地参数组的数量越多,生成每个本地参数组的待优化替换概率时所需要的计算量就越大,也就要求目标分节点具有更强计算能力。基于此。当目标分节点的计算能力较强时,可以选择较细的参数搜索粒度,当目标分节点的计算能力较弱时,可以选择较粗的参数搜索粒度。
示例性的,当目标分节点的计算能力通过目标分节点的每秒浮点运算次数来表征时,该目标分节点中可以存储有中心节点预先设置的多个浮点运算次数范围以及每个浮点运算次数范围对应的参数搜索粒度。基于此,目标分节点可以先确定自身的每秒浮点运算次数属于预先设置的多个浮点运算次数范围中的哪个范围,然后将自身的每秒浮点运算次数所属的浮点运算次数范围对应的参数搜索粒度作为自身的参数搜索粒度。
当目标模型的属性信息为目标模型的规模信息时,可以根据目标模型的层数的多少或目标模型的参数量的多少来确定参数搜索粒度,当目标模型的层数较多或参数量较多时,可以选择较粗的参数搜索粒度;当目标模型的层数较少或参数量较少时,可以选择较细的参数搜索粒度。
例如,当目标模型的层数在3至10层时,目标分节点确定的参数搜索粒度可以为将一个参数值作为一个参数组;当目标模型的层数为11至20层时,目标分节点确定的参数搜索粒度可以为将一个层的参数值作为一个参数组;当目标模型的层数大于20层,目标分节点可以将两个或三个相邻层的参数值作为一个参数组。
可选地,在一些可能的实现方式中,目标分节点也可以将自身属性信息中的多种信息进行结合来确定参数搜索粒度,本申请实施例在此不再赘述。
3022:根据参数搜索粒度,对本地模型参数中的多个本地参数值进行分组,得到多个本地参数组。
在确定参数搜索粒度之后,目标分节点可以根据该参数搜索粒度所指示的参数组内包含的参数值的数量,对本地模型参数的多个本地参数值进行分组,得到多个本地参数组。
例如,如果该参数搜索粒度指示一个参数组内包含一个层的参数值,则当该目标模型包括m个层时,根据该参数搜索粒度,可以将m个层中每个层的本地参数值作为一个本地参数组,从而得到m个本地参数组。
另外,目标分节点还可以采用相同的参数搜索粒度对全局模型参数包括的多个全局参数值进行分组,从而得到相同数量且与多个本地参数组一一对应的多个全局参数组。
3023:获取多个本地参数组中每个本地参数组的待优化替换概率。
在一种实现方式中,目标分节点可以随机生成每个本地参数组的待优化替换概率,或者,该目标分节点也可以将每个本地参数组的待优化替换概率设置为预设值。
3024:根据本地训练集对每个参数组的待优化替换概率进行迭代优化,得到每个参数组的替换概率。
示例性地,目标分节点可以通过下述步骤A-E根据本地训练集对每个本地参数组的待优化替换概率进行迭代优化,得到每个本地参数组的替换概率。
A:根据每个本地参数组的待优化替换概率、多个本地参数组和全局模型参数,确定验证模型参数。
在本申请实施例中,目标分节点可以将本地模型参数中待优化替换概率大于第一阈值的本地参数组替换为全局模型参数中对应的全局参数组,将替换后的本地模型参数作为验证模型参数。
例如,对目标模型的本地模型参数中的多个本地参数值进行划分后得到的本地参数组为每个本地参数组的待优化替换概率为对全局模型参数包括的全局参数值划分得到的全局参数组为ω1,ω2,ω3,…,ωm,当待优化替换概率/>大于第一阈值时,目标分节点则可以将本地模型参数中的参数组/>分别替换为全局模型参数中对应的全局参数组ω1,ω2,得到验证模型参数ω1,ω2,/>
B:将目标模型的本地模型参数替换为验证模型参数,得到验证模型。
在得到验证模型参数之后,目标分节点可以将目标模型的本地模型参数包括的多个本地参数值直接替换为验证模型参数包括的参数值,从而得到验证模型。
C:根据本地训练集对验证模型进行训练,得到更新后的验证模型。
在得到验证模型之后,目标分节点可以根据本地训练集对该验证模型进行训练,以对该验证模型的模型参数进行更新,从而得到更新后的验证模型。其中,目标分节点可以根据本地训练集对该验证模型进指定次数的迭代训练,以得到更新后的验证模型,或者,目标分节点可以根据本地训练集进行多轮迭代训练,直至更新后的验证模型收敛为止。
在得到更新后的验证模型后,在一种实现方式中,目标分节点可以将更新后的验证模型的损失函数与预设的参考损失值进行比较分节点的要求,如果该更新后的验证模型的损失函数不大于预设的参考损失值时,则该更新后的验证模型满足目标分节点的要求,此时,目标分节点可以将当前的各个本地参数组的待优化替换概率作为最终的替换概率。如果该更新后的验证模型的损失函数大于预设的参考损失值时,则说明该更新后的验证模型不满足目标分节点的要求,此时,目标分节点可以采用强化学习、重参数化方法、多臂老虎机法或其他能够对待优化替换概率进行更新的方法对每个本地参数组的待优化替换概率进行更新,将更新后的概率作为待优化替换概率,并返回执行上述步骤A至C,直至验证模型的损失函数满足目标分节点的要求后,将最后一次更新后的概率作为每个本地参数组的替换概率。
可选地,在另一种可能的实现方式中,在得到更新后的验证模型之后,目标分节点还可以通过下述步骤D和E来对更新后的验证模型进行测试,进而根据测试结果来对本地参数组的待优化替换概率进行更新。
D:根据验证集对更新后的验证模型进行测试。
目标分节点采用验证集中的样本数据对更新后的验证模型进行测试,得到测试结果,之后,目标分节点可以判断测试结果是否满足参考条件。
其中,参考条件可以为模型的损失函数、精准率、召回率、准确率、平均交并比等能够评价模型好坏的指标,本申请实施例对此不做限定。
例如,当参考条件为精准率时,目标分节点可以将验证集中的各个样本数据输入至该验证模型中,验证模型对各个样本数据进行识别,从而输出各个样本数据的识别结果。之后,目标分节点可以统计验证模型得到的识别结果中正确的识别结果的数量与识别结果的总数量之间的比例,从而得到该验证模型的精准率,也即,该验证模型的测试结果。之后,将该验证模型的精准率与预设的参考精准率进行比较,如果该验证模型的精准率大于预设的参考精准率时,则表明该验证模型的测试结果满足参考条件,如果该验证模型的精准率不大于预设的参考精准率,则表明该验证模型的测试结果不满足参考条件。
E:如果测试结果不满足参考条件,对每个本地参数组的待优化替换概率进行更新,将更新后的概率作为待优化替换概率,并返回执行步骤A,直至测试结果满足参考条件时,将最后一次更新后的概率作为每个本地参数组的替换概率。
如果通过上述步骤D确定测试结果不满足参考条件,则目标分节点可以采用强化学习、重参数化方法、多臂老虎机法或其他能够对待优化替换概率进行更新的方法对每个本地参数组的待优化替换概率进行更新,将更新后的概率作为待优化替换概率,并返回执行步骤A,直至验证模型的测试结果满足参考条件为止,将最后一次更新后的概率作为每个本地参数组的替换概率。
在通过上述步骤获得每个本地参数组的替换概率之后,将本地模型参数中每个本地参数组的替换概率依次和第一阈值进行比较,如果替换概率大于第一阈值,则将本地模型参数中该替换概率对应的本地参数组替换为对应的全局参数组,如果替换概率不大于第一阈值,则保留本地模型参数中该替换概率对应的本地参数组,将本地模型参数中替换概率大于第一阈值的本地参数组全部替换为对应的全局参数组后,得到目标模型的融合模型参数。其中,替换概率大于第一阈值的本地参数组的数量可能为一个,也可能为多个,本申请实施例对此不作限定。
在另一种实现方式中,目标分节点的参数融合规则可以为根据目标分节点的属性信息预先设置的本地模型参数中待替换的本地参数组的索引,基于此,目标分节点可以根据预先设置的待替换的本地参数组的索引将本地模型参数中待替换的本地参数组替换为全局模型参数中对应的参数组,得到融合模型参数。其中,待替换的本地参数组的索引可以为一个或多个卷积层的参数组的索引。
需要说明的是,在该种实现方式中,待替换的本地参数组的索引可以是人工基于目标分节点的属性信息预先配置在该目标分节点中,也可以是在对本地模型参数中的参数值进行划分得到本地参数组之后,目标分节点自动根据自身的属性信息确定得到的,本申请实施例对此不做限定。
其中,当目标分节点自动根据自身的属性信息确定待替换的本地参数组的索引时,该目标分节点可以参考上述基于自身的属性信息确定参数搜索粒度的实现原理来控制待替换的本地参数组的索引的数量,也即控制待替换的本地参数组的数量,进而基于确定的索引数量来选取待替换的本地参数组的索引。
示例性地,当目标分节点的属性信息为本地训练集中样本数据的数据量时,如果本地训练集中样本数据的数据量较少,则目标分节点可以选取较少数量的索引作为预先设置的待替换的本地参数组的索引。例如,当本地模型参数中一层的参数值为一个本地参数组时,如果本地训练集中的数据量较少,则目标分节点可以选取一个本地参数组的索引,也即,将一个本地参数组作为待替换的参数组。如果本地训练集中的样本数据的数据量较多,则目标分节点可以选取多个本地参数组的索引作为待替换的本地参数组的索引,例如,将第2、第3、第6和第7层这4个层中的本地参数组的索引预先设置为待替换参数组的索引。
当目标分节点的属性信息为本地训练集中的样本数据的分布特征时,如果该分布特征指示本地训练集中的样本数据较为复杂或与其他节点的样本数据的差异性较大,则目标分节点可以选择较多的本地参数组的索引作为待替换参数组的索引。如果该分布特征指示本地训练集中的样本数据较为简单或与其他节点的样本数据的差异性较小,则目标分节点可以选择较少的本地参数组的索引作为待替换参数组的索引。其中,判断本地训练集中的样本数据的复杂程度或与其他节点的样本数据的差异性大小的方式可以参考前文介绍,本申请实施例在此不再赘述。
当目标分节点的属性信息为目标分节点的计算能力时,如果目标分节点自身的计算能力较强,则可以选择较多的本地参数组的索引作为待替换参数组的索引,如果目标分节点自身的计算能力较弱,则可以选择较少的本地参数组的索引作为待替换参数组的索引。其中,判断目标分节点自身计算能力强弱的方式可以参考前文介绍,本申请实施例在此不再赘述。
当目标分节点的属性信息为目标模型的规模信息时,如果该规模信息指示目标模型的层数较多或目标模型参数的数量较多,则目标分节点可以选择较多的本地参数组的索引作为待替换参数组的索引,如果该规模信息指示目标模型的层数较少或目标模型参数的数量较少,则目标分节点可以选择较少的本地参数组的索引作为待替换参数组的索引。其中,判断目标模型的层数的多少或参数来的多少的方式可以参考前文介绍,本申请实施例在此不再赘述。
步骤303:将目标模型的本地模型参数更新为融合模型参数,并根据目标分节点的本地训练集对更新后的目标模型进行训练。
在得到融合模型参数之后,目标分节点将目标模型的本地模型参数全部替换为融合模型参数,之后,采用本地训练集中的样本数据对模型参数为融合模型参数的目标模型进行训练,并将训练后获得的模型参数作为更新后的本地模型参数,并上传至中心节点。中心节点在接收到目标分节点上传的本地模型参数以及其他分节点上传的本地模型参数之后,可以继续采用步骤301的方法对接收到的更新后的本地模型参数进行加权平均,得到更新后的全局模型参数,并再次将更新后的全局模型参数下发至参与学习的各个分节点,相应地,各个分节点可以采用上述步骤301-303的方法继续对自身的目标模型进行训练。如此循环多轮,直至目标分节点上的目标模型收敛为止,该目标分节点即得到了训练好的目标模型,此时,该目标分节点可以停止训练。
示例性的,参考图5,图5中的为目标分节点A的本地模型参数,ΩGlobal=[ω1,ω2,ω3,…,ωm]为目标分节点接收的来自中心节点的全局模型参数,将本地模型参数/>与全局模型参数ΩGlobal=[ω1,ω2,ω3,…,ωm]进行融合后,获得的/>为融合模型参数,采用融合模型参数/>替换目标模型中的本地模型参数 后,采用目标分节点A的本地数据集对目标模型进行训练。
在本申请实施例中,目标分节点在接收到来自中心节点的全局模型参数后,将接收到的全局模型参数和目标分节点自身目标模型的本地模型参数进行融合获得融合模型参数,这样,获得的融合模型参数中同时包含了目标分节点的本地信息和全局信息,之后,采用同时包含了本地信息和全局信息的融合模型参数对目标模型进行更新并采用本地训练集对该目标模型进行训练,能够有效提高目标模型的精度,这样,即使在参与训练的各分节点采用的本地训练集中的样本数据分布不同或各分节点上的模型需要完成的任务不同的情况下,采用本申请实施例方法训练出的模型仍然能够满足各分节点的需求。
接下来对本申请实施例提供的模型训练方法和其他模型训练方法在两种不同应用场景下的目标模型的训练效果进行示例说明。
第一种应用场景:联邦学习系统中包括4个分节点和一个中心节点。其中,4个分节点和中心节点上均部署有能够完成分类任务的目标模型。4个分节点上的本地训练集中的图像分别为产品图、艺术图、剪贴画和真实图像这四种不同风格的图像,并且,每个分节点的本地训练集的图像中包含的物体的类别相同。例如,分节点1中的本地训练集中包含有铅笔的产品图,分节点2-4的本地训练集中分别包含有铅笔的艺术图、铅笔的剪贴画和铅笔的真实图像。在这种情况下,采用本申请实施例的方法联合训练得到的4个分节点上的目标模型与采用其他方法训练得到的目标模型的分类准确率如表1所示:
表1不同方法训练得到的各个分节点上的模型的分类准确率对比表
第二种应用场景:联邦学习系统中包括3个分节点和一个中心节点。其中,3个分节点和中心节点上均部署有用于完成检测任务的目标模型。3个分节点上的目标模型分别需要完成人脸检测、无人驾驶检测和交通标识符检测这三种不同的检测任务,由于3个分节点的检测任务不同,所以3分节点的本地训练集中的样本数据也不相同。在这种情况下,采用本申请实施例的方法联合训练得到的3个分节点上的目标模型与采用其他方法训练得到的目标模型在完成检测任务时的准确率如表2所示:
表2不同方法训练得到的各个分节点上的模型的检测准确率对比表
分节点1 | 分节点2 | 分节点3 | 平均准确率 | |
单节点模型训练方法 | 54.99 | 59.07 | 94.71 | 69.59 |
联邦平均模型训练方法 | 50.99 | 59.91 | 93.39 | 68.10 |
异构网络的联邦学习模型训练方法 | 48.32 | 61.98 | 92.28 | 67.73 |
联邦批归一化模型训练方法 | 52.58 | 60.44 | 94.28 | 69.10 |
本申请实施例的模型训练方法 | 56.04 | 63.44 | 94.25 | 71.24 |
由上述两种应用场景下的训练效果对比表可以看出,采用本申请实施例的训练方法对各个分节点的模型进行联合训练后,能够更好的保证多个分节点上的模型的识别准确率。
接下来对本申请实施例提供的模型训练装置进行介绍。
参见图6,本申请实施例提供了一种模型训练装置600,该装置600包括:
接收模块601,用于执行前述实施例中的步骤301;
融合模块602,用于执行前述实施例中的步骤302;
训练模块603,用于执行前述实施例中的步骤303。
可选地,融合模块602用于:
根据目标分节点的参数融合规则,对目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数。
可选地,目标分节点的参数融合规则包括基于目标分节点的属性信息确定的目标模型的本地模型参数中各个本地参数组的替换概率,融合模块602主要用于:
将本地模型参数中替换概率大于第一阈值的本地参数组替换为全局模型参数中对应的参数组,得到融合模型参数。
可选地,融合模块602还包括:
确定单元,用于根据目标分节点的属性信息,确定参数搜索粒度;
分组单元,用于根据参数搜索粒度,对本地模型参数包括的多个本地参数值进行分组,得到多个本地参数组;
获取单元,用于获取多个本地参数组中每个本地参数组的待优化替换概率;
迭代单元,用于根据本地训练集对每个本地参数组的待优化替换概率进行迭代优化,得到每个本地参数组的替换概率。
可选地,迭代单元主要用于:
根据每个本地参数组的待优化替换概率、多个本地参数组和全局模型参数,确定验证模型参数;
将目标模型的本地模型参数替换为验证模型参数,得到验证模型;
根据本地训练集对验证模型进行训练,得到更新后的验证模型;
根据验证集对更新后的验证模型进行测试;
如果测试结果不满足参考条件,对每个本地参数组的待优化替换概率进行更新,将更新后的概率作为待优化替换概率,并返回执行根据每个本地参数组的待优化替换概率、多个本地参数组和全局模型参数,确定验证模型参数的步骤,直至测试结果满足参考条件时,将最后一次更新后的概率作为每个本地参数组的替换概率。
可选地,目标分节点的参数融合规则包括根据目标分节点的属性信息预先设置的待替换的本地参数组的索引,本地参数组为对本地模型参数包括的多个本地参数值进行分组得到,融合模块602还用于:
根据待替换的本地参数组的索引,将本地模型参数中待替换的本地参数组替换为全局模型参数中对应的参数组,得到融合模型参数。
可选地,目标分节点的属性信息包括本地训练集中的样本数据的数据量和分布特征、目标分节点的计算能力以及目标模型的规模信息中的至少一项。
综上所述,本发明实施例的目标分节点在接收到来自中心节点的全局模型参数后,将接收到的全局模型参数和目标分节点自身目标模型的本地模型参数进行融合获得融合模型参数,这样,获得的融合模型参数中同时包含了目标分节点的本地信息和全局信息,之后,采用同时包含了本地信息和全局信息的融合模型参数对目标模型进行更新并采用本地训练集对该目标模型进行训练,能够有效提高目标模型的精度,这样,即使在参与训练的各分节点采用的本地训练集中的样本数据分布不同或各分节点上的模型需要完成的任务不同的情况下,采用本申请实施例方法训练出的模型仍然能够满足各分节点的需求。
需要说明的是:上述实施例提供的模型训练装置在对目标模型进行训练时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的模型训练装置与模型训练方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意结合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如:同轴电缆、光纤、数据用户线(digital subscriber line,DSL))或无线(例如:红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质(例如:软盘、硬盘、磁带)、光介质(例如:数字通用光盘(digital versatile disc,DVD))、或者半导体介质(例如:固态硬盘(solid state disk,SSD))等。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
以上所述并不用以限制本申请实施例,凡在本申请实施例的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请实施例的保护范围之内。
Claims (15)
1.一种模型训练方法,其特征在于,应用于目标分节点,所述方法包括:
接收来自中心节点的目标模型的全局模型参数;
对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数,所述本地模型参数是指基于本地训练集对所述目标模型进行训练得到的模型参数;
将所述目标模型的本地模型参数更新为所述融合模型参数,并根据所述目标分节点的本地训练集对更新后的目标模型进行训练。
2.根据权利要求1所述的方法,其特征在于,所述对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数,包括:
根据所述目标分节点的参数融合规则,对所述目标模型的全局模型参数和本地模型参数进行融合,得到所述融合模型参数。
3.根据权利要求2所述的方法,其特征在于,所述目标分节点的参数融合规则包括基于所述目标分节点的属性信息确定的所述目标模型的本地模型参数中各个本地参数组的替换概率,所述根据所述目标分节点的参数融合规则,对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数,包括:
将所述本地模型参数中替换概率大于第一阈值的本地参数组替换为所述全局模型参数中对应的参数组,得到所述融合模型参数。
4.根据权利要求3所述的方法,其特征在于,所述方法还包括:
根据所述目标分节点的属性信息,确定参数搜索粒度;
根据所述参数搜索粒度,对所述本地模型参数包括的多个本地参数值进行分组,得到多个本地参数组;
获取所述多个本地参数组中每个本地参数组的待优化替换概率;
根据所述本地训练集对每个本地参数组的待优化替换概率进行迭代优化,得到每个本地参数组的替换概率。
5.根据权利要求4所述的方法,其特征在于,所述根据所述本地训练集对每个本地参数组的待优化替换概率进行迭代优化,得到每个本地参数组的替换概率,包括:
根据每个本地参数组的待优化替换概率、所述多个本地参数组和所述全局模型参数,确定验证模型参数;
将所述目标模型的本地模型参数替换为所述验证模型参数,得到验证模型;
根据所述本地训练集对所述验证模型进行训练,得到更新后的验证模型;
根据验证集对所述更新后的验证模型进行测试;
如果测试结果不满足参考条件,对每个本地参数组的待优化替换概率进行更新,将更新后的概率作为所述待优化替换概率,并返回执行所述根据每个本地参数组的待优化替换概率、所述多个本地参数组和所述全局模型参数,确定验证模型参数的步骤,直至所述测试结果满足所述参考条件时,将最后一次更新后的概率作为每个本地参数组的替换概率。
6.根据权利要求2所述的方法,其特征在于,所述目标分节点的参数融合规则包括根据所述目标分节点的属性信息预先设置的待替换的本地参数组的索引,所述本地参数组为对所述本地模型参数包括的多个本地参数值进行分组得到,所述根据所述目标分节点的参数组合规则,对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数,包括:
根据所述待替换的本地参数组的索引,将所述本地模型参数中待替换的本地参数组替换为所述全局模型参数中对应的参数组,得到所述融合模型参数。
7.根据权利要求3-6任一所述的方法,其特征在于,所述目标分节点的属性信息包括所述本地训练集中的样本数据的数据量和分布特征、所述目标分节点的计算能力以及所述目标模型的规模信息中的至少一项。
8.一种模型训练装置,其特征在于,应用于目标分节点,所述装置包括:
接收模块,用于接收来自中心节点的目标模型的全局模型参数;
融合模块,用于对所述目标模型的全局模型参数和本地模型参数进行融合,得到融合模型参数,所述本地模型参数是指基于本地训练集对所述目标模型进行训练得到的模型参数;
训练模块,用于将所述目标模型的本地模型参数更新为所述融合模型参数,并根据所述目标分节点的本地训练集对更新后的目标模型进行训练。
9.根据权利要求8所述的装置,其特征在于,所述融合模块用于:
根据所述目标分节点的参数融合规则,对所述目标模型的全局模型参数和本地模型参数进行融合,得到所述融合模型参数。
10.根据权利要求9所述的装置,其特征在于,所述目标分节点的参数融合规则包括基于所述目标分节点的属性信息确定的所述目标模型的本地模型参数中各个本地参数组的替换概率,所述融合模块主要用于:
将所述本地模型参数中替换概率大于第一阈值的本地参数组替换为所述全局模型参数中对应的参数组,得到所述融合模型参数。
11.根据权利要求10所述的装置,其特征在于,所融合模块还包括:
确定单元,用于根据所述目标分节点的属性信息,确定参数搜索粒度;
分组单元,用于根据所述参数搜索粒度,对所述本地模型参数包括的多个本地参数值进行分组,得到多个本地参数组;
获取单元,用于获取所述多个本地参数组中每个本地参数组的待优化替换概率;
迭代单元,用于根据所述本地训练集对每个本地参数组的待优化替换概率进行迭代优化,得到每个本地参数组的替换概率。
12.根据权利要求11所述的装置,其特征在于,所述迭代单元主要用于:
根据每个本地参数组的待优化替换概率、所述多个本地参数组和所述全局模型参数,确定验证模型参数;
将所述目标模型的本地模型参数替换为所述验证模型参数,得到验证模型;
根据所述本地训练集对所述验证模型进行训练,得到更新后的验证模型;
根据验证集对所述更新后的验证模型进行测试;
如果测试结果不满足参考条件,对每个本地参数组的待优化替换概率进行更新,将更新后的概率作为所述待优化替换概率,并返回执行所述根据每个本地参数组的待优化替换概率、所述多个本地参数组和所述全局模型参数,确定验证模型参数的步骤,直至所述测试结果满足所述参考条件时,将最后一次更新后的概率作为每个本地参数组的替换概率。
13.根据权利要求9所述的装置,其特征在于,所述目标分节点的参数融合规则包括根据所述目标分节点的属性信息预先设置的待替换的本地参数组的索引,所述本地参数组为对所述本地模型参数包括的多个本地参数值进行分组得到,所述融合模块主要用于:
根据所述待替换的本地参数组的索引,将所述本地模型参数中待替换的本地参数组替换为所述全局模型参数中对应的参数组,得到所述融合模型参数。
14.根据权利要求10-13任一所述的装置,其特征在于,所述目标分节点的属性信息包括所述本地训练集中的样本数据的数据量和分布特征、所述目标分节点的计算能力以及所述目标模型的规模信息中的至少一项。
15.一种计算机可读存储介质,其特征在于,所述存储介质内存储有计算机程序,所述计算机程序被计算机执行时实现权利要求1-7任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111131777.8A CN115878989A (zh) | 2021-09-26 | 2021-09-26 | 模型训练方法、装置及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111131777.8A CN115878989A (zh) | 2021-09-26 | 2021-09-26 | 模型训练方法、装置及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115878989A true CN115878989A (zh) | 2023-03-31 |
Family
ID=85762708
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111131777.8A Pending CN115878989A (zh) | 2021-09-26 | 2021-09-26 | 模型训练方法、装置及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115878989A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116382599A (zh) * | 2023-06-07 | 2023-07-04 | 之江实验室 | 一种面向分布式集群的任务执行方法、装置、介质及设备 |
-
2021
- 2021-09-26 CN CN202111131777.8A patent/CN115878989A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116382599A (zh) * | 2023-06-07 | 2023-07-04 | 之江实验室 | 一种面向分布式集群的任务执行方法、装置、介质及设备 |
CN116382599B (zh) * | 2023-06-07 | 2023-08-29 | 之江实验室 | 一种面向分布式集群的任务执行方法、装置、介质及设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110554958B (zh) | 图数据库测试方法、系统、设备和存储介质 | |
US20170124178A1 (en) | Dynamic clustering for streaming data | |
CN106919957B (zh) | 处理数据的方法及装置 | |
CN109189876B (zh) | 一种数据处理方法及装置 | |
CN114418035A (zh) | 决策树模型生成方法、基于决策树模型的数据推荐方法 | |
CN115546525A (zh) | 多视图聚类方法、装置、电子设备及存储介质 | |
CN110135428B (zh) | 图像分割处理方法和装置 | |
CN110378739B (zh) | 一种数据流量匹配方法及装置 | |
CN113642727B (zh) | 神经网络模型的训练方法和多媒体信息的处理方法、装置 | |
CN114610825A (zh) | 关联网格集的确认方法、装置、电子设备及存储介质 | |
CN115878989A (zh) | 模型训练方法、装置及存储介质 | |
CN108830302B (zh) | 一种图像分类方法、训练方法、分类预测方法及相关装置 | |
CN116932935A (zh) | 地址匹配方法、装置、设备、介质和程序产品 | |
CN106651408B (zh) | 一种数据分析方法及装置 | |
CN115412401B (zh) | 训练虚拟网络嵌入模型及虚拟网络嵌入的方法和装置 | |
CN111738290A (zh) | 图像检测方法、模型构建和训练方法、装置、设备和介质 | |
CN110321435B (zh) | 一种数据源划分方法、装置、设备和存储介质 | |
CN117112880A (zh) | 信息推荐、多目标推荐模型训练方法、装置和计算机设备 | |
CN114417886A (zh) | 热点数据的处理方法、装置、电子设备及存储介质 | |
CN114398434A (zh) | 结构化信息抽取方法、装置、电子设备和存储介质 | |
CN110688508A (zh) | 图文数据扩充方法、装置及电子设备 | |
CN114693995B (zh) | 应用于图像处理的模型训练方法、图像处理方法和设备 | |
CN114241243B (zh) | 图像分类模型的训练方法、装置、电子设备和存储介质 | |
CN114547448B (zh) | 数据处理、模型训练方法、装置、设备、存储介质及程序 | |
US20220383626A1 (en) | Image processing method, model training method, relevant devices and electronic device |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |