CN115374954A - 一种基于联邦学习的模型训练方法、终端以及存储介质 - Google Patents
一种基于联邦学习的模型训练方法、终端以及存储介质 Download PDFInfo
- Publication number
- CN115374954A CN115374954A CN202210806260.2A CN202210806260A CN115374954A CN 115374954 A CN115374954 A CN 115374954A CN 202210806260 A CN202210806260 A CN 202210806260A CN 115374954 A CN115374954 A CN 115374954A
- Authority
- CN
- China
- Prior art keywords
- model
- parameters
- gradient
- client
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明提供一种基于联邦学习的模型训练方法、终端以及存储介质,该方法包括:S101:客户端接收服务端下发的模型参数,根据模型参数、本地数据计算模型的梯度;S102:通过梯度更新模型的参数,根据模型的参数计算优化后的模型的梯度和优化模型;S103:将模型优化后的参数发送至服务端,其中,服务端根据参数进行参数聚合得到更新后的元模型参数,通过元模型参数进行蒸馏获取新的模型参数,将新的模型参数发送给客户端进行循环训练。本发明能够在提高模型准确度的同时,降低了客户端的数据泄漏风险和训练成本,收敛速度快、学习性能好,适应了模型训练要求。
Description
技术领域
本发明涉及模型训练技术领域,尤其涉及一种基于联邦学习的模型训练方法、终端以及存储介质。
背景技术
机器学习模型需要大量广泛的数据训练模型以提高模型准确度,但目前数据存放比较分散,考虑到数据安全各机构之间的数据也难以共享,给数据充分利用带来困难。而机构仅使用自身数据训练模型,模型效果很难得到提升,联邦学习技术通过同态加密算法将各机构间的模型参数进行安全交换,形成学习各个机构数据的全局模型,解决数据不出本地,又可以充分利用各家机构数据的问题。
然而,与集中式学习相比,联邦学习的分散式学习面临很多挑战,由于各机构或者各个设备之间数据分布不同,会出现非独立同分布问题,造成收敛速度慢和学习性能恶化,难以满足模型训练要求。
发明内容
为了克服现有技术的不足,本发明提出一种基于联邦学习的模型训练方法、终端以及存储介质,在模型训练过程中,通过服务端向客户端下发模型参数,客户端根据该模型参数和本地数据训练模型以适应本地数据,并在模型训练结束后将模型的参数发送给服务端,通过服务端汇总参数得到新的模型参数,通过该模型参数做进一步的循环训练,从而在提高模型准确度的同时,降低了客户端的数据泄漏风险和训练成本,收敛速度快、学习性能好,适应了模型训练要求。
为解决上述问题,本发明采用的一个技术方案为:一种基于联邦学习的模型训练方法,所述基于联邦学习的模型训练方法应用于客户端,包括:S101:客户端接收服务端下发的模型参数,根据模型参数、本地数据计算模型的梯度;S102:通过所述梯度更新模型的参数,根据模型的参数计算优化后的模型的梯度和优化模型;S103:将所述模型优化后的参数发送至服务端,其中,所述服务端根据参数进行参数聚合得到更新后的元模型参数,通过元模型参数进行蒸馏获取新的模型参数,将新的模型参数发送给客户端进行循环训练。
进一步地,所述根据模型参数、本地数据计算模型的梯度的步骤具体包括:根据用户的数据分布选取部分本地数据,通过选取的本地数据计算无偏梯度,基于所述无偏梯度计算模型的梯度。
进一步地,所述通过所述梯度更新模型的参数的步骤具体包括:通过公式优化模型的参数,其中,表示模型经k轮循环训练本地迭代t-1次后的无偏梯度,k表示客户端与服务端之间循环训练的次数,t是客户端的本地迭代次数,i指客户端i,表示经k轮循环和t次迭代后模型的参数,α为学习步长。
进一步地,所述根据模型的参数计算优化后的模型的梯度和优化模型的步骤具体包括:通过公式计算优化后模型的梯度,根据所述梯度更新模型,其中,β为学习率,为经k轮循环和t-1次迭代后模型的梯度,表示经k轮循环和t次迭代后模型的参数。
进一步地,所述通过元模型参数进行蒸馏获取新的模型参数的步骤具体包括:通过公式进行蒸馏获取新的模型参数,其中,wk+1,j为模型经k轮循环j次迭代后的元模型参数,wk+1,j-1表示经k轮循环j-1次迭代后的元模型参数,kl表示Kullback-Leibler散度,σ是softmax函数,η表示步长,k是客户端与服务端之间的循环次数,j是迭代次数,Ak表示客户端教师模型数量,d表示未标记数据,为偏微分符号,表示第i个客户端的模型经k轮循环后的损失函数,f(wk+1,j-1,d)表示经k轮循环j-1次迭代后的元模型参数对应的损失函数。
进一步地,所述将新的模型参数发送给客户端进行循环训练的步骤之前还包括:判断循环训练的次数是否达到预设值;若是,则确定模型训练成功,停止循环训练;若否,则将新的模型参数发送给客户端进行循环训练。
基于相同的发明构思,本发明还提出一种基于联邦学习的模型训练方法,包括:S201:服务端选取待训练的客户端,将模型参数下发给所述客户端,其中,所述客户端根据模型参数计算模型的梯度,通过所述梯度更新模型的参数,根据模型的参数计算优化后的模型的梯度和优化模型,将所述模型优化后的参数发送至服务端;S202:根据参数进行参数聚合得到更新后的元模型参数,通过元模型参数进行蒸馏获取新的模型参数,将新的模型参数发送给客户端进行循环训练。
基于相同的发明构思,本发明还提出一种智能终端,所述智能终端包括处理器、存储器,所述处理器与所述存储器通信连接,所述存储器存储有计算机程序,所述处理器通过所述计算机程序执行如上所述的基于联邦学习的模型训练方法。
基于相同的发明构思,本发明还提出一种计算机可读存储介质,所述计算机可读存储介质存储有程序数据,所述程序数据被用于执行如上所述的基于联邦学习的模型训练方法。
相比现有技术,本发明的有益效果在于:在模型训练过程中,通过服务端向客户端下发模型参数,客户端根据该模型参数和本地数据训练模型以适应本地数据,并在模型训练结束后将模型的参数发送给服务端,通过服务端汇总参数得到新的模型参数,通过该模型参数做进一步的循环训练,从而在提高模型准确度的同时,降低了客户端的数据泄漏风险和训练成本,收敛速度快、学习性能好,适应了模型训练要求。
附图说明
图1为本发明基于联邦学习的模型训练方法一实施例的流程图;
图2为本发明基于联邦学习的模型训练方法另一实施例的流程图;
图3为本发明基于联邦学习的模型训练方法又一实施例的流程图;
图4为本发明智能终端一实施例的结构图;
图5为本发明计算机可读存储介质一实施例的结构图。
具体实施方式
以下通过特定的具体实例说明本申请的实施方式,本领域技术人员可由本说明书所揭露的内容轻易地了解本申请的其他优点与功效。本申请还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本申请的精神下进行各种修饰或改变。需说明的是,通常在此处附图中描述和示出的各本公开实施例在不冲突的前提下,可相互组合,其中的结构部件或功能模块可以以各种不同的配置来布置和设计。因此,以下对在附图中提供的本公开的实施例的详细描述并非旨在限制要求保护的本公开的范围,而是仅仅表示本公开的选定实施例。基于本公开中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本公开保护的范围。
在本申请公开使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本公开。在本公开和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。还应当理解,本文中使用的术语“和/或”是指并包含一个或多个相关联的列出项目的任何或所有可能组合。
请参阅图1-2,其中,图1为本发明基于联邦学习的模型训练方法一实施例的流程图;图2为本发明基于联邦学习的模型训练方法另一实施例的流程图,结合图1-2对本发明的基于联邦学习的模型训练方法进行说明。
在本实施例中,应用基于联邦学习的模型训练方法的客户端可以为手机、电脑、服务器、云平台以及其他能够存储本地数据和利用接收的模型参数进行模型训练的智能终端。
具体的,基于联邦学习的模型训练方法包括:
S101:客户端接收服务端下发的模型参数,根据模型参数、本地数据计算模型的梯度。
在本实施例中,服务端可以为服务器以及其他能够与多个客户端通信连接进行模型训练的智能终端。服务端选取预设数量的客户端,向选取的客户端发送模型参数,客户端接收到该模型参数后,使用本地数据、模型参数对客户端中的模型进行训练,其中,通过计算模型的梯度的方式实现对模型的训练。
在本实施例中,因模型训练需要进行多次循环训练,如果每次梯度计算都要基于用户的所有数据,会需要较大的计算量。因此,本发明每次只根据用户的数据分布pi选取本地数据Di的一个部分来计算一个无偏梯度。具体的,根据模型参数、本地数据计算模型的梯度的步骤具体包括:根据用户的数据分布选取部分本地数据,通过选取的本地数据计算无偏梯度,基于无偏梯度计算模型的梯度。
在一个具体的实施例中,通过公式计算无偏梯度,其中,li(w;x,y)模型w在给定输入x∈Xi预测真实标签y∈Yi时的误差,x,y为本地数据Di中的数据,表示模型w的无偏梯度,:=表示迭代多次,|Di|表示本地数据Di中选取的样本数量,x,y为本地数据集Di中的样本。
S102:通过梯度更新模型的参数,根据模型的参数计算优化后的模型的梯度和优化模型。
在计算得到模型的梯度后,利用该梯度对模型进行优化,在本实施例中,通过梯度更新模型的参数的步骤具体包括:通过公式 优化模型的参数,其中,表示模型经k轮循环训练本地迭代t-1次后的无偏梯度,k表示客户端与服务端之间循环训练的次数,t是客户端的本地迭代次数,i指客户端i,表示经k轮循环和t次迭代后模型的参数,α为学习步长。
对模型优化完成后,通过优化后的模型计算模型的预测值和真实值之间的损失,即计算模型的梯度。其中,根据模型的参数计算优化后的模型的梯度和优化模型的步骤具体包括:通过公式计算优化后模型的梯度,根据梯度更新模型,其中,β为学习率,为经k轮循环和t-1次迭代后模型的梯度,表示经k轮循环和t次迭代后模型的参数,表示k轮循环和t次迭代后模型的参数。
其中,通过公式 计算平表示客户端第t次迭代时从本地数据Di中抽取的数据,且和中的数据不同。表示数据为时经k轮循环客户端第t-1次迭代后模型的二阶偏微分,表示数据为对经k轮循环客户端第t-1次迭代前模型的损失函数。
S103:将模型优化后的参数发送至服务端,其中,服务端根据参数进行参数聚合得到更新后的元模型参数,通过元模型参数进行蒸馏获取新的模型参数,将新的模型参数发送给客户端进行循环训练。
在本实施例中,客户端与服务端可以通过有线连接或无线连接的方式进行数据传输,具体连接方式可根据实际需求进行设置,在此不做限定。
通过元模型参数进行蒸馏获取新的模型参数的步骤具体包括:通过公式进行蒸馏获取新的模型参数,其中,wk+1,j为模型经k轮循环j次迭代后的元模型参数,wk+1,j-1表示经k轮循环j-1次迭代后的元模型参数,kl表示Kullback-Leibler散度,σ是softmax函数,η表示步长,k是客户端与服务端之间的循环次数,1次迭代后的元模型参数对应的损失函数。其中,每次通过抽取客户端的小批量未标记数据进行循环训练。
假设客户端、服务端构成的系统包含p个不同的模型原型组,这些模型原型组可能在神经网络,结构和数值精度存在差异。通过集成蒸馏,每个模型架构组通过模型平均获取知识,通过输出信息交互达到横跨多个结构知识共享,在下一轮中,每个激活的客户端都会收到相应的融合原型模型。值得注意的是,由于融合发生在服务器端,所以不会给客户带来额外的负担和干扰。
本发明使用未标记数据或人工生成的样本来辅助所有参与客户端的知识提取,该方法可以应用于同构和异构设置,提高了模型训练方法的应用范围。
其中,将新的模型参数发送给客户端进行循环训练的步骤之前还包括:判断循环训练的次数是否达到预设值;若是,则确定模型训练成功,停止循环训练;若否,则将新的模型参数发送给客户端进行循环训练。
下面通过该模型训练方法的整体工作流程对本发明基于联邦学习的模型训练方法做进一步说明。
输入:
Input:初始化模型参数w0,活动用户比例r.
for k:0 to K-1 do
服务器随机选择大小为m的用户;
服务器发送wk(模型参数)到所有用户Ak;
for i∈Akdo
for t:1 toτdo
End for
End for
for j in{1,...N}do
选取一小批样本d(1)无标签数据集,(2)通过AVGLOGITS
wk+1←wk+1,N
End for。
有益效果:本发明基于联邦学习的模型训练方法在模型训练过程中,通过服务端向客户端下发模型参数,客户端根据该模型参数和本地数据训练模型以适应本地数据,并在模型训练结束后将模型的参数发送给服务端,通过服务端汇总参数得到新的模型参数,通过该模型参数做进一步的循环训练,从而在提高模型准确度的同时,降低了客户端的数据泄漏风险和训练成本,收敛速度快、学习性能好,适应了模型训练要求。
基于相同的发明构思,本发明还提出一种基于联邦学习的模型训练方法,请参阅图3,图3为本发明基于联邦学习的模型训练方法一实施例的流程图,结合图3对本发明基于联邦学习的模型训练方法进行说明。
在本实施例中,基于联邦学习的模型训练方法应用于客户端,包括:
S201:服务端选取待训练的客户端,将模型参数下发给客户端,其中,客户端根据模型参数计算模型的梯度,通过梯度更新模型的参数,根据模型的参数计算优化后的模型的梯度和优化模型,将模型优化后的参数发送至服务端。
S202:根据参数进行参数聚合得到更新后的元模型参数,通过元模型参数进行蒸馏获取新的模型参数,将新的模型参数发送给客户端进行循环训练。
其中,该模型训练方法已经在上文的实施例中进行详细描述,在此不再赘述。
基于相同的发明构思,本发明还提出一种智能终端,请参阅图4,图4为本发明智能终端一实施例的结构图,结合图4对本发明的智能终端进行具体说明。
在本实施例中,智能终端包括处理器、存储器,所述处理器与所述存储器通信连接,存储器存储有计算机程序,计算机程序被用于执行如上述实施例所述的基于联邦学习的模型训练方法。
在一些实施例中,存储器可能包括但不限于高速随机存取存储器、非易失性存储器。例如一个或多个磁盘存储设备、闪存设备或其他非易失性固态存储设备。处理器可以是通用处理器,包括中央处理器(Central Processing Unit,简称CPU)、网络处理器(NetworkProcessor,简称NP)等;还可以是数字信号处理器(Digital Signal Processing,简称DSP)、专用集成电路(Application Specific Integrated Circuit,简称ASIC)、现场可编程门阵列(Field-Programmable Gate Array,简称FPGA)或者其他可编程功能器件、分立门或者晶体管功能器件、分立硬件组件。
基于相同的发明构思,本发明还提出一种计算机可读存储介质,请参阅图5,图5为本发明计算机可读存储介质一实施例的结构图,结合图5对本发明的计算机可读存储介质进行说明。
在本实施例中,计算机可读存储介质存储有程序数据,该程序数据被用于执行如上述实施例所述的基于联邦学习的模型训练方法。
其中,计算机可读存储介质可包括,但不限于,软盘、光盘、CD-ROM(紧致盘-只读存储器)、磁光盘、ROM(只读存储器)、RAM(随机存取存储器)、EPROM(可擦除可编程只读存储器)、EEPROM(电可擦除可编程只读存储器)、磁卡或光卡、闪存或适于存储机器可执行指令的其他类型的介质/机器可读介质。该计算机可读存储介质可以是未接入计算机设备的产品,也可以是已接入计算机设备使用的部件。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。
对所公开的实施例的上述说明,使本领域专业技术人员能够实现或使用本发明。对这些实施例的多种修改对本领域的专业技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其他实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。
Claims (10)
1.一种基于联邦学习的模型训练方法,其特征在于,所述基于联邦学习的模型训练方法应用于客户端,包括:
S101:客户端接收服务端下发的模型参数,根据模型参数、本地数据计算模型的梯度;
S102:通过所述梯度更新模型的参数,根据模型的参数计算优化后的模型的梯度和优化模型;
S103:将所述模型优化后的参数发送至服务端,其中,所述服务端根据参数进行参数聚合得到更新后的元模型参数,通过元模型参数进行蒸馏获取新的模型参数,将新的模型参数发送给客户端进行循环训练。
2.如权利要求1所述的基于联邦学习的模型训练方法,其特征在于,所述根据模型参数、本地数据计算模型的梯度的步骤具体包括:
根据用户的数据分布选取部分本地数据,通过选取的本地数据计算无偏梯度,基于所述无偏梯度计算模型的梯度。
7.如权利要求6所述的基于联邦学习的模型训练方法,其特征在于,所述将新的模型参数发送给客户端进行循环训练的步骤之前还包括:
判断循环训练的次数是否达到预设值;
若是,则确定模型训练成功,停止循环训练;
若否,则将新的模型参数发送给客户端进行循环训练。
8.一种基于联邦学习的模型训练方法,其特征在于,所述基于联邦学习的模型训练方法应用于客户端,包括:
S201:服务端选取待训练的客户端,将模型参数下发给所述客户端,其中,所述客户端根据模型参数计算模型的梯度,通过所述梯度更新模型的参数,根据模型的参数计算优化后的模型的梯度和优化模型,将所述模型优化后的参数发送至服务端;
S202:根据参数进行参数聚合得到更新后的元模型参数,通过元模型参数进行蒸馏获取新的模型参数,将新的模型参数发送给客户端进行循环训练。
9.一种智能终端,其特征在于,所述智能终端包括处理器、存储器,所述处理器与所述存储器通信连接,所述存储器存储有计算机程序,所述处理器通过所述计算机程序执行如权利要求1-8任一项所述的基于联邦学习的模型训练方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有程序数据,所述程序数据被用于执行如权利要求1-8任一项所述的基于联邦学习的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210806260.2A CN115374954A (zh) | 2022-07-08 | 2022-07-08 | 一种基于联邦学习的模型训练方法、终端以及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210806260.2A CN115374954A (zh) | 2022-07-08 | 2022-07-08 | 一种基于联邦学习的模型训练方法、终端以及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115374954A true CN115374954A (zh) | 2022-11-22 |
Family
ID=84062074
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210806260.2A Pending CN115374954A (zh) | 2022-07-08 | 2022-07-08 | 一种基于联邦学习的模型训练方法、终端以及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115374954A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115879467A (zh) * | 2022-12-16 | 2023-03-31 | 浙江邦盛科技股份有限公司 | 一种基于联邦学习的中文地址分词方法及装置 |
-
2022
- 2022-07-08 CN CN202210806260.2A patent/CN115374954A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115879467A (zh) * | 2022-12-16 | 2023-03-31 | 浙江邦盛科技股份有限公司 | 一种基于联邦学习的中文地址分词方法及装置 |
CN115879467B (zh) * | 2022-12-16 | 2024-04-30 | 浙江邦盛科技股份有限公司 | 一种基于联邦学习的中文地址分词方法及装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111091199B (zh) | 一种基于差分隐私的联邦学习方法、装置及存储介质 | |
WO2022257730A1 (zh) | 实现隐私保护的多方协同更新模型的方法、装置及系统 | |
Liu et al. | DeepSlicing: Deep reinforcement learning assisted resource allocation for network slicing | |
CN114741611B (zh) | 联邦推荐模型训练方法以及系统 | |
EP3688673A1 (en) | Neural architecture search | |
CN111030861A (zh) | 一种边缘计算分布式模型训练方法、终端和网络侧设备 | |
WO2023124296A1 (zh) | 基于知识蒸馏的联合学习训练方法、装置、设备及介质 | |
CN112926897A (zh) | 基于联邦学习的客户端贡献计算方法和装置 | |
JP2023542901A (ja) | スパース性を誘導する連合機械学習 | |
CN115374954A (zh) | 一种基于联邦学习的模型训练方法、终端以及存储介质 | |
Liu et al. | Multi-job intelligent scheduling with cross-device federated learning | |
CN112187670A (zh) | 一种基于群体智能的网络化软件共享资源分配方法及装置 | |
CN113094180B (zh) | 无线联邦学习调度优化方法及装置 | |
CN108289115B (zh) | 一种信息处理方法及系统 | |
CN106327236B (zh) | 一种确定用户行动轨迹的方法及装置 | |
CN116911403B (zh) | 联邦学习的服务器和客户端的一体化训练方法及相关设备 | |
Chen et al. | Energy‐efficiency fog computing resource allocation in cyber physical internet of things systems | |
CN114691630B (zh) | 一种智慧供应链大数据共享方法及系统 | |
Liu et al. | Online quantification of input model uncertainty by two-layer importance sampling | |
CN114676272A (zh) | 多媒体资源的信息处理方法、装置、设备及存储介质 | |
CN113988158A (zh) | 一种基于ftrl和学习率的纵向联邦逻辑回归训练方法及装置 | |
CN114022731A (zh) | 基于drl的联邦学习节点选择方法 | |
CN114528893A (zh) | 机器学习模型训练方法、电子设备及存储介质 | |
CN112036418A (zh) | 用于提取用户特征的方法和装置 | |
CN112738815A (zh) | 一种可接入用户数的评估方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication |