CN117993478A - 基于双向知识蒸馏和联邦学习的模型训练方法及装置 - Google Patents

基于双向知识蒸馏和联邦学习的模型训练方法及装置 Download PDF

Info

Publication number
CN117993478A
CN117993478A CN202410130275.0A CN202410130275A CN117993478A CN 117993478 A CN117993478 A CN 117993478A CN 202410130275 A CN202410130275 A CN 202410130275A CN 117993478 A CN117993478 A CN 117993478A
Authority
CN
China
Prior art keywords
model
global
client
knowledge distillation
local
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202410130275.0A
Other languages
English (en)
Inventor
吴少智
陈宝智
刘欣刚
苏涵
王婷婷
冯承霖
张立澄
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Yangtze River Delta Research Institute of UESTC Huzhou
Original Assignee
Yangtze River Delta Research Institute of UESTC Huzhou
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Yangtze River Delta Research Institute of UESTC Huzhou filed Critical Yangtze River Delta Research Institute of UESTC Huzhou
Priority to CN202410130275.0A priority Critical patent/CN117993478A/zh
Publication of CN117993478A publication Critical patent/CN117993478A/zh
Pending legal-status Critical Current

Links

Landscapes

  • Image Analysis (AREA)

Abstract

本发明公开了一种基于双向知识蒸馏和联邦学习的模型训练方法、装置、电子设备及存储介质。该方法由客户端执行,包括:接收服务器发送的全局模型;对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型;将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。上述技术方案,提升了模型训练精度。

Description

基于双向知识蒸馏和联邦学习的模型训练方法及装置
技术领域
本发明涉及人工智能技术领域,尤其涉及一种基于双向知识蒸馏和联邦学习的模型训练方法、装置、电子设备及存储介质。
背景技术
随着深度学习技术的发展,深度学习模型被广泛应用于各种预测任务中。
在实现本发明的过程中,发明人发现现有技术中至少存在以下技术问题:现有基于联邦学习的深度学习模型训练方法,存在模型训练精度低的问题。
发明内容
本发明提供了一种基于双向知识蒸馏和联邦学习的模型训练方法、装置、电子设备及存储介质,以提升模型训练精度。
根据本发明的一方面,提供了一种基于双向知识蒸馏和联邦学习的模型训练方法,由客户端执行,包括:
接收服务器发送的全局模型;
对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型;
将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
根据本发明的另一方面,提供了一种基于双向知识蒸馏和联邦学习的模型训练方法,由服务器执行,包括:
将全局模型分别发送至各客户端,对于任一客户端,所述客户端对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,将所述中间全局模型发送至服务器;
对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
根据本发明的另一方面,提供了一种基于双向知识蒸馏和联邦学习的模型训练装置,由客户端执行,包括:
全局模型接收模块,用于接收服务器发送的全局模型;
双向知识蒸馏训练模块,用于对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型;
中间全局模型发送模块,用于将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
根据本发明的另一方面,提供了一种基于双向知识蒸馏和联邦学习的模型训练装置,由服务器执行,包括:
全局模型发送模块,用于将全局模型分别发送至各客户端,对于任一客户端,所述客户端对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,将所述中间全局模型发送至服务器;
模型聚合模块,用于对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
根据本发明的另一方面,提供了一种电子设备,所述电子设备包括:
至少一个处理器;
以及与所述至少一个处理器通信连接的存储器;
其中,所述存储器存储有可被所述至少一个处理器执行的计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行本发明任一实施例所述的基于双向知识蒸馏和联邦学习的模型训练方法。
根据本发明的另一方面,提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现本发明任一实施例所述的基于双向知识蒸馏和联邦学习的模型训练方法。
本发明实施例的技术方案,通过双向知识蒸馏使客户端局部模型向全局模型学习,同时让全局模型向客户端局部模型学习,提升了模型训练精度。
应当理解,本部分所描述的内容并非旨在标识本发明的实施例的关键或重要特征,也不用于限制本发明的范围。本发明的其它特征将通过以下的说明书而变得容易理解。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是根据本发明实施例一提供的一种基于双向知识蒸馏和联邦学习的模型训练方法的流程图;
图2是根据本发明实施例二提供的一种基于双向知识蒸馏和联邦学习的模型训练方法的流程图;
图3是根据本发明实施例三提供的一种基于双向知识蒸馏和联邦学习的模型训练方法的流程图;
图4是根据本发明实施例四提供的一种基于双向知识蒸馏和联邦学习的模型训练方法的流程图;
图5是根据本发明实施例五提供的一种基于双向知识蒸馏和联邦学习的模型训练装置的结构示意图;
图6是根据本发明实施例六提供的一种基于双向知识蒸馏和联邦学习的模型训练装置的结构示意图;
图7是根据本发明实施例提供的一种基于双向知识蒸馏和联邦学习的模型训练系统的示意图;
图8是实现本发明实施例的基于双向知识蒸馏和联邦学习的模型训练方法的电子设备的结构示意图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“初始”、“目标”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。本申请技术方案中对数据的获取、存储、使用、处理等均符合国家法律法规的相关规定。
实施例一
图1为本发明实施例一提供的一种基于双向知识蒸馏和联邦学习的模型训练方法的流程图,本实施例可适用于应对数据特征非独立同分布数据的个性化联邦学习的情况,该方法可以由基于双向知识蒸馏和联邦学习的模型训练装置来执行,该基于双向知识蒸馏和联邦学习的模型训练装置可以采用硬件和/或软件的形式实现,该基于双向知识蒸馏和联邦学习的模型训练装置可配置于客户端中。如图1所示,该方法包括:
S110、接收服务器发送的全局模型。
其中,对于本公开实施例中的联邦学习系统,可以包括多个客户端和一个服务器,各客户端与服务器之间通讯连接,客户端可以为个人计算机、手机等电子设备。
具体地,对于任一客户端,可以接收服务发送的全局模型或者全局模型的模型参数,以供客户端进行模型训练。
S120、对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型。
其中,双向知识蒸馏可以使客户端局部模型向全局模型学习,同时让全局模型向客户端局部模型学习,从而可以提升中间全局模型的训练精度。
S130、将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
具体地,各客户端可以将中间全局模型发送至服务器,服务器对各客户端发送的中间全局模型进行模型聚合,得到预测性能更优的目标全局模型。
本发明实施例的技术方案,通过双向知识蒸馏使客户端局部模型向全局模型学习,同时让全局模型向客户端局部模型学习,提升了模型训练精度。
实施例二
图2为本发明实施例二提供的一种基于双向知识蒸馏和联邦学习的模型训练方法的流程图,本实施例的方法与上述实施例中提供的基于双向知识蒸馏和联邦学习的模型训练方法中各个可选方案可以结合。本实施例提供的基于双向知识蒸馏和联邦学习的模型训练方法进行了进一步优化。可选的,所述对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,包括:以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行模型训练,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行模型训练,得到中间全局模型。
如图2所示,该方法包括:
S210、接收服务器发送的全局模型。
S220、以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行模型训练,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行模型训练,得到中间全局模型。
其中,特征提取器是指模型中用于特征提取的网络,可以包括卷积层、池化层等,在此不做具体限定。
需要说明的是,基于双向知识蒸馏,可以让客户端局部模型向全局模型学习数据特征提取过程,同时让全局模型向客户端局部模型学习客户端本地数据特征,从而提升模型训练精度。
可选地,以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行模型训练,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行模型训练,得到中间全局模型,包括:基于本地数据对所述全局模型和所述位于客户端的局部模型进行监督学习,得到监督学习损失;以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行知识蒸馏学习,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行知识蒸馏学习,得到知识蒸馏损失;基于所述监督学习损失和所述知识蒸馏损失,对所述全局模型和所述位于客户端的局部模型进行梯度下降,直至满足模型训练停止条件,得到中间全局模型。
示例性地,客户端本地训练过程包括:根据本地数据对全局模型和位于客户端的局部模型进行监督学习,并计算得到监督学习损失,进而以位于客户端的局部模型的特征提取器为学生,以全局模型的特征提取器为教师进行知识蒸馏学习,同时以全局模型为学生,以位于客户端的局部模型为教师进行知识蒸馏学习,并计算得到知识蒸馏损失,将监督学习损失和知识蒸馏损失之和,对全局模型和位于客户端的局部模型进行梯度下降,完成预先设定次数的迭代训练后,得到训练完成的中间全局模型。
S230、将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
本发明实施例的技术方案,基于双向知识蒸馏,可以让客户端局部模型向全局模型学习数据特征提取过程,同时让全局模型向客户端局部模型学习客户端本地数据特征,从而提升模型训练精度。
实施例三
图3为本发明实施例三提供的一种基于双向知识蒸馏和联邦学习的模型训练方法的流程图,本实施例可适用于应对数据特征非独立同分布数据的个性化联邦学习的情况,该方法可以由基于双向知识蒸馏和联邦学习的模型训练装置来执行,该基于双向知识蒸馏和联邦学习的模型训练装置可以采用硬件和/或软件的形式实现,该基于双向知识蒸馏和联邦学习的模型训练装置可配置于服务器中。如图3所示,该方法包括:
S310、将全局模型分别发送至各客户端,对于任一客户端,所述客户端对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,将所述中间全局模型发送至服务器。
S320、对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
示例性地,服务器可以将本轮全局模型广播至各客户端,并基于预设比例选取客户端参与本轮训练。任一客户端的本地训练过程包括:根据本地数据对全局模型和位于客户端的局部模型进行监督学习,并计算得到监督学习损失,进而以位于客户端的局部模型的特征提取器为学生,以全局模型的特征提取器为教师进行知识蒸馏学习,同时以全局模型为学生,以位于客户端的局部模型为教师进行知识蒸馏学习,并计算得到知识蒸馏损失,将监督学习损失和知识蒸馏损失之和,对全局模型和位于客户端的局部模型进行梯度下降,完成预先设定次数的迭代训练后,得到训练完成的中间全局模型。服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
本发明实施例的技术方案,通过双向知识蒸馏使客户端局部模型向全局模型学习,同时让全局模型向客户端局部模型学习,提升了模型训练精度。
实施例四
图4为本发明实施例四提供的一种基于双向知识蒸馏和联邦学习的模型训练方法的流程图,本实施例的方法与上述实施例中提供的基于双向知识蒸馏和联邦学习的模型训练方法中各个可选方案可以结合。本实施例提供的基于双向知识蒸馏和联邦学习的模型训练方法对进行了进一步优化。可选的,所述对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型,包括:根据各客户端局部模型预测器的逆网络确定各客户端的全局模型聚合权重;基于所述各客户端的全局模型聚合权重,对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
如图4所示,该方法包括:
S410、将全局模型分别发送至各客户端,对于任一客户端,所述客户端对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,将所述中间全局模型发送至服务器。
S420、根据各客户端局部模型预测器的逆网络确定各客户端的全局模型聚合权重。
其中,逆网络可以为全连接层的逆网络,可以用于拟合客户端本地数据的条件概率分布,客户端本地数据的条件概率分布可以用于评估客户端本地数据特征之间的差异。
可选地,根据各客户端局部模型预测器的逆网络确定各客户端的全局模型聚合权重,包括:确定各客户端局部模型预测器的逆网络对应的全局逆网络的条件概率分布;基于全局模型的损失函数,确定全局逆网络的条件概率分布下的各客户端的全局模型聚合权重。
示例性地,服务器模型聚合步骤可以包括:确定各客户端局部模型预测器的逆网络的条件概率分布,对各客户端局部模型预测器的逆网络的条件概率分布进行平均处理,得到全局逆网络的条件概率分布,进而通过全局模型的损失函数,求解全局逆网络的条件概率分布下的最优权重,从而得到各客户端的全局模型聚合权重,其中,全局模型的损失函数可以为最优权重求解公式可以如下:
其中,α*表示最优权重,Qglobal(z|y)表示全局逆网络的条件概率分布,z表示特征提取器提取得到的数据特征,y表示预测器的预测结果,表示预测器,/>表示预测器参数,gglobal(y;wg,global)表示客户端局部模型预测器的逆网络,wg,global表示全局逆网络的参数,αi表示第i个客户端的全局模型的聚合权重,i∈[1,N],t表示全局训练迭代次数。
S430、基于所述各客户端的全局模型聚合权重,对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
本发明实施例的技术方案,在服务器上通过构造客户端个性化网络部分的逆网络来衡量客户端数据特征分布之间的差异,以此计算模型加权聚合的权重,实现了全局模型在每个客户端数据上预测性能的平衡以及使客户端局部模型对本地数据的具有专家级别的预测性能。
实施例五
图5为本发明实施例五提供的一种基于双向知识蒸馏和联邦学习的模型训练装置的结构示意图。如图5所示,该装置包括:
全局模型接收模块510,用于接收服务器发送的全局模型;
双向知识蒸馏训练模块520,用于对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型;
中间全局模型发送模块530,用于将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
本发明实施例的技术方案,通过双向知识蒸馏使客户端局部模型向全局模型学习,同时让全局模型向客户端局部模型学习,提升了模型训练精度。
在一些可选的实施方式中,双向知识蒸馏训练模块520,包括:
学生-教师双向训练单元,用于以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行模型训练,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行模型训练,得到中间全局模型。
在一些可选的实施方式中,学生-教师双向训练单元,还具体用于:
基于本地数据对所述全局模型和所述位于客户端的局部模型进行监督学习,得到监督学习损失;
以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行知识蒸馏学习,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行知识蒸馏学习,得到知识蒸馏损失;
基于所述监督学习损失和所述知识蒸馏损失,对所述全局模型和所述位于客户端的局部模型进行梯度下降,直至满足模型训练停止条件,得到中间全局模型。
本发明实施例所提供的基于双向知识蒸馏和联邦学习的模型训练装置可执行本发明任意实施例所提供的基于双向知识蒸馏和联邦学习的模型训练方法,具备执行方法相应的功能模块和有益效果。
实施例六
图6为本发明实施例六提供的一种基于双向知识蒸馏和联邦学习的模型训练装置的结构示意图。如图6所示,该装置包括:
全局模型发送模块610,用于将全局模型分别发送至各客户端,对于任一客户端,所述客户端对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,将所述中间全局模型发送至服务器;
模型聚合模块620,用于对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
本发明实施例的技术方案,通过双向知识蒸馏使客户端局部模型向全局模型学习,同时让全局模型向客户端局部模型学习,提升了模型训练精度。
在一些可选的实施方式中,模型聚合模块620,包括:
全局模型聚合权重确定单元,用于根据各客户端局部模型预测器的逆网络确定各客户端的全局模型聚合权重;
全局模型聚合单元,用于基于所述各客户端的全局模型聚合权重,对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
在一些可选的实施方式中,全局模型聚合权重确定单元,还具体用于:
确定各客户端局部模型预测器的逆网络对应的全局逆网络的条件概率分布;
基于全局模型的损失函数,确定所述全局逆网络的条件概率分布下的各客户端的全局模型聚合权重。
本发明实施例所提供的基于双向知识蒸馏和联邦学习的模型训练装置可执行本发明任意实施例所提供的基于双向知识蒸馏和联邦学习的模型训练方法,具备执行方法相应的功能模块和有益效果。
图7是根据本发明实施例提供的一种基于双向知识蒸馏和联邦学习的模型训练系统的示意图。该系统包括多个客户端和一个服务器。具体而言,服务器对全局模型进行初始化,进而服务器将本轮全局模型广播至所有客户端,进而服务器基于预设比例选取客户端参与本轮训练,对各客户端的预测器构建对应的逆网络与对参与本轮训练的客户端并行进行本地训练并行处理,以提升数据处理效率,降低资源冲突。对于任一客户端,客户端可以根据本地数据对全局模型和位于客户端的局部模型进行监督学习,并计算得到监督学习损失,进而以位于客户端的局部模型的特征提取器为学生,以全局模型的特征提取器为教师进行知识蒸馏学习,同时以全局模型为学生,以位于客户端的局部模型为教师进行知识蒸馏学习,并计算得到知识蒸馏损失,将监督学习损失和知识蒸馏损失之和,对全局模型和位于客户端的局部模型进行梯度下降,完成预先设定次数的迭代训练后,得到训练完成的中间全局模型,并将中间全局模型发送至服务器。服务器根据各客户端对应逆网络计算本轮聚合权重,将各客户端训练出的中间全局模型加权聚合,得到目标全局模型,在完成全局训练轮次的情况下,结束模型训练。
本发明实施例的技术方案,在客户端上通过双向知识蒸馏使客户端局部模型向全局模型学习,同时让全局模型向客户端局部模型学习,提升了模型训练精度。在服务器上通过构造客户端个性化网络部分的逆网络来衡量客户端数据特征分布之间的差异,以此计算模型加权聚合的权重,实现了全局模型在每个客户端数据上预测性能的平衡以及使客户端局部模型对本地数据的具有专家级别的预测性能。
实施例七
图8示出了可以用来实施本发明的实施例的电子设备10的结构示意图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字助理、蜂窝电话、智能电话、可穿戴设备(如头盔、眼镜、手表等)和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本发明的实现。
如图8所示,电子设备10包括至少一个处理器11,以及与至少一个处理器11通信连接的存储器,如只读存储器(ROM)12、随机访问存储器(RAM)13等,其中,存储器存储有可被至少一个处理器执行的计算机程序,处理器11可以根据存储在只读存储器(ROM)12中的计算机程序或者从存储单元18加载到随机访问存储器(RAM)13中的计算机程序,来执行各种适当的动作和处理。在RAM 13中,还可存储电子设备10操作所需的各种程序和数据。处理器11、ROM 12以及RAM 13通过总线14彼此相连。I/O接口15也连接至总线14。
电子设备10中的多个部件连接至I/O接口15,包括:输入单元16,例如键盘、鼠标等;输出单元17,例如各种类型的显示器、扬声器等;存储单元18,例如磁盘、光盘等;以及通信单元19,例如网卡、调制解调器、无线通信收发机等。通信单元19允许电子设备10通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
处理器11可以是各种具有处理和计算能力的通用和/或专用处理组件。处理器11的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的处理器、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。处理器11执行上文所描述的各个方法和处理,例如基于双向知识蒸馏和联邦学习的模型训练方法,该方法包括:
接收服务器发送的全局模型;
对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型;
将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
在一些实施例中,基于双向知识蒸馏和联邦学习的模型训练方法可被实现为计算机程序,其被有形地包含于计算机可读存储介质,例如存储单元18。在一些实施例中,计算机程序的部分或者全部可以经由ROM 12和/或通信单元19而被载入和/或安装到电子设备10上。当计算机程序加载到RAM 13并由处理器11执行时,可以执行上文描述的基于双向知识蒸馏和联邦学习的模型训练方法的一个或多个步骤。备选地,在其他实施例中,处理器11可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行基于双向知识蒸馏和联邦学习的模型训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、现场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、系统级芯片(SOC)、复杂可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本发明的方法的计算机程序可以采用一个或多个编程语言的任何组合来编写。这些计算机程序可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器,使得计算机程序当由处理器执行时使流程图和/或框图中所规定的功能/操作被实施。计算机程序可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本发明的上下文中,计算机可读存储介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的计算机程序。计算机可读存储介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。备选地,计算机可读存储介质可以是机器可读信号介质。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在电子设备上实施此处描述的系统和技术,该电子设备具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给电子设备。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)、区块链网络和互联网。
计算系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决了传统物理主机与VPS服务中,存在的管理难度大,业务扩展性弱的缺陷。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发明中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本发明的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本发明保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本发明的精神和原则之内所作的修改、等同替换和改进等,均应包含在本发明保护范围之内。

Claims (10)

1.一种基于双向知识蒸馏和联邦学习的模型训练方法,其特征在于,由客户端执行,包括:
接收服务器发送的全局模型;
对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型;
将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
2.根据权利要求1所述的方法,其特征在于,所述对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,包括:
以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行模型训练,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行模型训练,得到中间全局模型。
3.根据权利要求2所述的方法,其特征在于,所述以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行模型训练,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行模型训练,得到中间全局模型,包括:
基于本地数据对所述全局模型和所述位于客户端的局部模型进行监督学习,得到监督学习损失;
以所述位于客户端的局部模型的特征提取器为学生,以所述全局模型的特征提取器为教师进行知识蒸馏学习,同时以所述全局模型为学生,以所述位于客户端的局部模型为教师进行知识蒸馏学习,得到知识蒸馏损失;
基于所述监督学习损失和所述知识蒸馏损失,对所述全局模型和所述位于客户端的局部模型进行梯度下降,直至满足模型训练停止条件,得到中间全局模型。
4.一种基于双向知识蒸馏和联邦学习的模型训练方法,其特征在于,由服务器执行,包括:
将全局模型分别发送至各客户端,对于任一客户端,所述客户端对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,将所述中间全局模型发送至服务器;
对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
5.根据权利要求4所述的方法,其特征在于,所述对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型,包括:
根据各客户端局部模型预测器的逆网络确定各客户端的全局模型聚合权重;
基于所述各客户端的全局模型聚合权重,对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
6.根据权利要求5所述的方法,其特征在于,所述根据各客户端局部模型预测器的逆网络确定各客户端的全局模型聚合权重,包括:
确定各客户端局部模型预测器的逆网络对应的全局逆网络的条件概率分布;
基于全局模型的损失函数,确定所述全局逆网络的条件概率分布下的各客户端的全局模型聚合权重。
7.一种基于双向知识蒸馏和联邦学习的模型训练装置,其特征在于,由客户端执行,包括:
全局模型接收模块,用于接收服务器发送的全局模型;
双向知识蒸馏训练模块,用于对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型;
中间全局模型发送模块,用于将所述中间全局模型发送至所述服务器,以使所述服务器对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
8.一种基于双向知识蒸馏和联邦学习的模型训练装置,其特征在于,由服务器执行,包括:
全局模型发送模块,用于将全局模型分别发送至各客户端,对于任一客户端,所述客户端对所述全局模型和位于客户端的局部模型进行双向知识蒸馏训练,得到中间全局模型,将所述中间全局模型发送至服务器;
模型聚合模块,用于对各客户端发送的中间全局模型进行模型聚合,得到目标全局模型。
9.一种电子设备,其特征在于,所述电子设备包括:
至少一个处理器;
以及与所述至少一个处理器通信连接的存储器;
其中,所述存储器存储有可被所述至少一个处理器执行的计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-3或者权利要求4-6中任一项所述的基于双向知识蒸馏和联邦学习的模型训练方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现权利要求1-3或者权利要求4-6中任一项所述的基于双向知识蒸馏和联邦学习的模型训练方法。
CN202410130275.0A 2024-01-30 2024-01-30 基于双向知识蒸馏和联邦学习的模型训练方法及装置 Pending CN117993478A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202410130275.0A CN117993478A (zh) 2024-01-30 2024-01-30 基于双向知识蒸馏和联邦学习的模型训练方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202410130275.0A CN117993478A (zh) 2024-01-30 2024-01-30 基于双向知识蒸馏和联邦学习的模型训练方法及装置

Publications (1)

Publication Number Publication Date
CN117993478A true CN117993478A (zh) 2024-05-07

Family

ID=90896775

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202410130275.0A Pending CN117993478A (zh) 2024-01-30 2024-01-30 基于双向知识蒸馏和联邦学习的模型训练方法及装置

Country Status (1)

Country Link
CN (1) CN117993478A (zh)

Similar Documents

Publication Publication Date Title
CN115147687A (zh) 学生模型训练方法、装置、设备及存储介质
CN113850394B (zh) 联邦学习方法、装置、电子设备及存储介质
CN114742237A (zh) 联邦学习模型聚合方法、装置、电子设备及可读存储介质
CN114065864A (zh) 联邦学习方法、联邦学习装置、电子设备以及存储介质
CN114860411B (zh) 多任务学习方法、装置、电子设备和存储介质
CN117993478A (zh) 基于双向知识蒸馏和联邦学习的模型训练方法及装置
CN115907926A (zh) 商品的推荐方法、装置、电子设备及存储介质
CN115359322A (zh) 一种目标检测模型训练方法、装置、设备和存储介质
CN114999665A (zh) 数据处理方法、装置、电子设备及存储介质
CN116933896B (zh) 一种超参数确定及语义转换方法、装置、设备及介质
CN116662788B (zh) 一种车辆轨迹处理方法、装置、设备和存储介质
CN117251295B (zh) 一种资源预测模型的训练方法、装置、设备及介质
CN115578583B (zh) 图像处理方法、装置、电子设备和存储介质
CN116662194A (zh) 软件质量度量方法、装置、设备及介质
CN114816758B (zh) 资源分配方法和装置
CN117933353A (zh) 强化学习模型训练方法、装置、电子设备及存储介质
CN115017145A (zh) 数据扩展方法、装置及存储介质
CN116823510A (zh) 节点影响力度量方法、装置、设备及存储介质
CN115658826A (zh) 一种轨迹停留点确定方法、装置、设备及存储介质
CN117851208A (zh) 一种芯片评估方法、装置、电子设备及介质
CN117593007A (zh) 一种异常交易识别方法、装置、设备及存储介质
CN113836242A (zh) 数据处理方法、装置、电子设备及可读存储介质
CN117611324A (zh) 信用评级方法、装置、电子设备和存储介质
CN116992150A (zh) 一种研发组件推荐方法、装置、设备及存储介质
CN114372624A (zh) 一种主体的效能预测方法、装置、存储介质及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination