CN115034356A - 一种用于横向联邦学习的模型融合方法及系统 - Google Patents

一种用于横向联邦学习的模型融合方法及系统 Download PDF

Info

Publication number
CN115034356A
CN115034356A CN202210498743.0A CN202210498743A CN115034356A CN 115034356 A CN115034356 A CN 115034356A CN 202210498743 A CN202210498743 A CN 202210498743A CN 115034356 A CN115034356 A CN 115034356A
Authority
CN
China
Prior art keywords
model
global
iteration
node
local
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210498743.0A
Other languages
English (en)
Inventor
武星
裴洁
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
University of Shanghai for Science and Technology
Original Assignee
University of Shanghai for Science and Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by University of Shanghai for Science and Technology filed Critical University of Shanghai for Science and Technology
Priority to CN202210498743.0A priority Critical patent/CN115034356A/zh
Publication of CN115034356A publication Critical patent/CN115034356A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Abstract

本发明涉及横向联邦机器学习领域,公开了一种用于横向联邦学习的模型融合方法及系统。本发明中,首先,云服务器初始化全局服务模型以及所需的超参数,每个终端用户设备利用本地数据进行模型训练,得到更新的本地模型。其次,云服务器的仲裁调度模块,利用压缩策略选出最佳的本地网络模型,授权相应的终端用户设备节点上传本地局部模型到云服务器,以及根据仲裁激励机制制定模型融合策略。最后,云服务器聚集上传的终端用户设备模型,根据融合策略计算得到全局模型。本发明解决了横向联邦学习技术中,终端设备与云服务器频繁传递模型参数所导致的高通信开销,保证数据和模型的安全性,并提高模型融合的性能。

Description

一种用于横向联邦学习的模型融合方法及系统
技术领域
本发明涉及一种基于横向联邦学习的模型融合方法及系统,用于解决联邦学习技术中云服务节点与终端用户节点模型融合成全局模型所导致的模型劣化问题,属于横向联邦学习领域。
背景技术
机器学习作为人工智能领域的一个重要理论,被广泛应用于数据挖掘、语音识别和计算机视觉等各个领域。机器学习网络模型需要经过训练完成后得到最终模型才能投入使用。
终端用户设备计算能力有限,现有机器学习网络模型通过采用云服务的方式在云端完成训练。终端用户设备含有的图片、音频或者文本数据,需要集中上传至云服务中心来训练模型。然而,一方面这种方式需要较高的通信带宽来上传数据和较高的存储空间来管理数据,另一方面终端用户数据在上传过程中存在泄密的风险,上传的真实数据也容易泄露其隐私。从长远来看,这阻碍了机器学习技术的落地和应用。
为了保护终端用户设备的隐私数据,强化模型的学习能力,联邦学习应运而生。在联邦学习框架中,终端设备无需上传本地数据至云服务器,而只需共享本地训练的模型。云服务器收集终端用户设备的模型进行融合得到全局模型,经过多次迭代计算后得到最终的完备模型。然而在现有的横向联邦学习技术中,终端用户设备与云服务器之间需要频繁传递模型参数,导致较高的通信开销。
发明内容
本发明要解决的技术问题是:在现有的横向联邦学习技术中,终端用户设备与云服务器之间需要频繁传递模型参数,导致较高的通信开销。
为了解决上述技术问题,本发明的一个技术方案是提供了一种用于横向联邦学习的模型融合方法,其特征在于,包括以下步骤:
步骤S1:初始化全局云服务器节点,包括构建全局服务模型后,初始化全局模型参数ω1,从而获得用于进行第一轮迭代的全局服务模型W1,初始化迭代轮次t=1;
步骤S2:全局云服务器节点将第t迭代的全局服务模型Wt下发至设备节点集合U={u1,u2,…,ui,…,um}中的每个终端用户设备节点,由每个终端用户设备节点利用本地数据对模型进行局部模型训练,从而得到第t轮迭代的局部模型集合
Figure BDA0003634481990000021
并计算得到第t轮迭代时局部模型集合中每个局部模型的模型训练损失,其中,ui表示第i个终端用户设备节点,i=1,2,…,m,
Figure BDA0003634481990000022
表示第i个终端用户设备节点训练得到第t轮迭代的第i个局部模型;
步骤S3:全局云服务器节点中的仲裁调度模块收集各终端用户设备节点上传的模型训练损失,进而获得第t轮迭代的模型训练损失集合Lt
Figure BDA0003634481990000023
Figure BDA0003634481990000024
仲裁调度模块基于模型训练损失集合Lt从局部模型集合M中选出最佳的K个局部模型,将K个局部模型所对应的各终端用户设备节点重新定义为策略节点,则所有策略节点构成的第t轮迭代的策略节点集合
Figure BDA0003634481990000025
Figure BDA0003634481990000026
Figure BDA0003634481990000027
表示第t轮迭代的第k个策略节点;全局云服务器节点授权各策略节点上传本地局部模型;
步骤S4:全局云服务器节点根据融合策略Se,基于各策略节点上传的第t轮迭代时的局部模型的模型参数,采用平均融合算法或者自适应融合算法计算第t轮迭代时的全局模型参数ωt
步骤S5:全局云服务器节点对选定的K个局部模型采用仲裁激励算法裁定出第t+1轮迭代时全局云服务器节点所需的融合策略Se
步骤S6:全局云服务器节点根据步骤S45得到的第t轮迭代时的全局模型参数ωt更新全局服务模型,t←t+1,返回步骤S2,全局服务模型训练进入下一轮迭代,直至全局服务模型收敛,模型训练结束。
优选地,步骤S3中,所述全局云服务器节点根据K=β·m计算需要筛选的模型个数K,β为预设的筛选因子,随后从模型训练损失集合Lt中挑选出K个最小的模型训练损失所对应的局部模型。
优选地,步骤S4中,全局云服务器节点根据自适应融合算法计算第t轮迭代时的全局模型参数ωt包括以下子步骤:
步骤S4-1-1:计算第t轮迭代的全局服务模型各层权重参数与各策略节点上传的局部模型各层权重参数的差异性,其中,全局服务模型第l层权重参数ωl与第k个策略节点上传的局部模型第l层权重参数
Figure BDA0003634481990000028
的差异性表示为
Figure BDA0003634481990000029
则有:
Figure BDA0003634481990000031
式中,JS(·)表示詹森香农散度,KL(·)表示Kullback-Leible散度,||表示两模型权重的相对熵。
步骤S4-1-2:计算第t轮迭代时各策略节点对于全局服务模型各层权重参数的贡献度,其中,第k个策略节点对全局服务模型第l层权重参数ωl的贡献度为
Figure BDA0003634481990000032
则有:
Figure BDA0003634481990000033
步骤S4-1-3:计算第t轮迭代时全局服务模型的各层权重参数,设第t轮迭代时全局服务模型的第l层权重参数表示为
Figure BDA0003634481990000034
则有:
Figure BDA0003634481990000035
式中,η表示模型训练的学习率,
Figure BDA0003634481990000036
表示梯度算子,L(·)表示全局服务模型的损失。
优选地,步骤S4中,全局云服务器节点根据平均融合算法计算第t轮迭代时的全局模型参数ωt,如下式所示:
Figure BDA0003634481990000037
式中,nk为第k个策略节点的数据量,n为总数据量,
Figure BDA0003634481990000038
为第t轮迭代时第k个策略节点上传的局部模型的模型参数。
优选地,所述步骤S5包括以下步骤:
步骤S5-1:全局云服务器节点计算得到第t-1轮迭代和第t轮迭代的损失差异Δ,Δ←|L(ωt)-L(ωt-1)|;
步骤S5-2,根据损失差异Δ计算全局云服务器节点所需的融合策略Se,则有:
Figure BDA0003634481990000039
式中,ω是选定选择平均融合算法还是自适应融合算法的损失阈值。
优选地,所述步骤S1至所述步骤S6中,所述全局云服务器节点与各终端用户设备节点之间传输的是加密数据。
本发明的另一个技术方案是提供了一种基于横向联邦学习的模型融合系统,其特征在于,包括:
初始化模块,用于初始化全局云服务器节点
终端设备模型训练模块:每个终端用户设备节点通过终端设备模型训练模块从全局云服务器节点获得下发的第t迭代的全局服务模型Wt,随后每个终端用户设备节点采用终端设备模型训练模块利用本地数据对模型进行局部模型训练,从而得到第t轮迭代的局部模型集合
Figure BDA0003634481990000041
并且终端设备模型训练模块计算得到第t轮迭代时局部模型集合中每个局部模型的模型训练损失,其中,ui表示第i个终端用户设备节点,i=1,2,…,m,
Figure BDA0003634481990000042
表示第i个终端用户设备节点训练得到第t轮迭代的第i个局部模型;
仲裁调度模块,进一步包括以下子模块:
本地模型损失输入子模块,用于收集各终端用户设备节点上传的模型训练损失,进而获得第t轮迭代的模型训练损失集合Lt
Figure BDA0003634481990000043
压缩策略子模块,用于根据模型训练损失集合Lt筛选最佳的K个局部模型;
模型选定子模块,用于授权K个局部模型所对应的终端用户设备节点上传本地局部模型到全局云服务器节点,获得授权的终端用户设备节点被定义为策略节点;
本地模型数据量输入子模块,用于获取各策略节点的本地数据量;
全局模型更新策略模块,进一步包括以下子模块:
仲裁阈值输入子模块,用于获取融合算法的损失阈值∈;
模型更新判断子模块:模型更新判断子模块计算得到第t轮迭代和第t-1轮迭代的损失差异Δ,Δ←|L(ωt)-L(ωt-1)|,再根据损失差异Δ计算全局云服务器节点所需的融合策略Se,则有:
Figure BDA0003634481990000044
当Se=1时,进入自适应融合更新子模块;当Se=0时,进入平均融合更新子模块;
自适应融合更新子模块:计算第t轮迭代的全局服务模型各层权重参数与各策略节点上传的局部模型各层权重参数的差异性,依据差异性计算得到第t轮迭代时各策略节点对于全局服务模型各层权重参数的贡献度,然后根据贡献度计算得到第t轮迭代时全局服务模型的各层权重参数;
平均融合更新子模块:根据平均融合算法计算第t轮迭代时的全局模型参数ωt,如下式所示:
Figure BDA0003634481990000051
式中,nk为第k个策略节点的数据量,n为总数据量,
Figure BDA0003634481990000052
为第t轮迭代时第k个策略节点上传的局部模型的模型参数;
终止判定子模块:用于判定全局服务模型是否收敛,若收敛,则结束模型训练;反之,进入第t+1轮模型训练。
优选地,所述初始化模块进一步包括:
全局模型构建子模块,用于构建初始的全局服务模型,包括对全局服务模型的输入单元,隐藏单元,输出单元、输入单元、隐藏单元及输出单元的神经节点个数和各神经节点的连接路径进行设计;
全局模型初始化子模块,用于初始化全局服务模型,包括初始化全局模型参数ω1,从而获得用于进行第一轮迭代的全局服务模型W1
全局变量初始化子模块,用于初始化全局联邦变量,该全局联邦变量包括由m个终端用户设备节点组成的设备节点集合U={u1,u2,…,ui,…,um},ui表示第i个终端用户设备节点,i=1,2,…,m;本地训练次数ε;筛选因子β;迭代轮次t。
优选地,所述终端设备模型训练模块进一步包括:
全局模型输入子模块:每个终端用户设备节点通过全局模型输入子模块与全局云服务器节点通信,经过全局云服务器节点的身份认证后获取第t迭代的全局服务模型Wt
本地模型训练子模块:每个终端用户设备节点采用本地模型训练子模块利用本地数据训练所获得的模型ε次,得到第t轮迭代的局部模型,同时,计算得到第t轮迭代时当前局部模型的模型训练损失,并上传至全局云服务节点;
并行训练子模块,用于并行执行全局模型输入子模块和本地模型训练子模块,得到得到第t轮迭代的局部模型集合
Figure BDA0003634481990000053
本发明中,首先云服务器初始化全局模型,下发模型至用户终端设备。然后,用户终端设备利用本地数据进行局部模型训练,将训练得到的局部模型上传至云服务器。最后,云服务器根据上传的局部模型按照设计的模型融合策略进行融合更新,计算得到准确的全局模型代替局部模型。
本发明根据终端用户设备节点训练的本地模型优劣情况进行自适应式模型融合优化,通过结合自适应融合和平均融合机制来设计全局模型更新策略。云服务器利用压缩策略筛选最佳终端设备模型来更新全局模型,解决了横向联邦学习技术中,用户终端设备与云服务器频繁传递模型参数所导致的高通信开销,提高模型融合的性能。在本发明所提供的方法以及系统中,在整个系统运行中传输的是加密数据,能够避免数据泄露,从而保证数据和模型的安全性。
附图说明
图1为本发明实施例提供的总体方法流程图;
图2为本发明实施例提供的总体原理示意图;
图3为本发明实施例基于横向联邦学习的模型融合系统模块设计示意图。
具体实施方式
下面结合具体实施例,进一步阐述本发明。应理解,这些实施例仅是为了助于本技术领域的普通技术人员对本发明原理和知识的理解,而不用于限制本发明的范围,不能认为是限制本发明的应用场景。此外应理解,在阅读了本发明讲授的内容之后,本领域技术人员可以对本发明作各种改动或修改,但基于本发明的原理和宗旨对实施例所做的变形、变化和转换同样落于本申请所附权利要求书所限定的范围。并且显而易见的是,本说明书仅以优选的实施方式作为举例,无需详尽所有的实施方式。
下面以100个终端用户设备联合训练长短期神经网络模型为例,来阐述本发明的具体实施步骤。
结合图1以及图2,本发明的基于横向联邦学习的模型融合方法的实施例具体步骤包括:
步骤S1:全局云服务器节点初始化,包括构建初始模型,初始化全局服务模型以及所需的超参数,进一步包括以下子步骤:
步骤S1-1:构建初始全局服务模型,包括设计全局服务模型中的输入单元、隐藏单元以及输出单元的神经节点个数和连接路径。
本实施例中,全局服务模型为长短期神经网络模型,构建初始长短期神经网络模型时,设计模型输入层和模型输出层分别有298和1个神经元节点,并设计模型输入层中298个神经元节点的连接路径。
步骤S1-2:初始化步骤S1-1建立的全局服务模型,包括初始化全局服务模型的全局模型参数ω1,从而获得用于进行第一轮迭代的全局服务模型W1
步骤S1-3:初始化全局联邦变量,该全局联邦变量包括:由m个终端用户设备节点组成的设备节点集合U={u1,u2,…,ui,…,um},ui表示第i个终端用户设备节点,i=1,2,…,m;本地训练次数ε;筛选因子β;迭代轮次t。
本实施例中:m=100,则U={u1,u,……,u100};ε=10;β=0.8;t=1。
步骤S2:终端用户设备节点模型训练。全局云服务器节点将初始化后的全局服务模型下发至设备节点集合U={u1,u2,…,ui,…,um}中的每个终端用户设备节点,由每个终端用户设备节点利用本地数据对模型进行局部模型训练,从而得到第t轮迭代的局部模型集合
Figure BDA0003634481990000071
第i个终端用户设备节点训练得到第t轮迭代的第i个局部模型
Figure BDA0003634481990000072
具体包括以下子步骤:
步骤S2-1:终端用户设备节点ui与全局云服务器节点通信,经过全局云服务器节点的身份认证后获取全局服务模型Wt,Wt为第t轮迭代的全局服务模型。
本实施例中,假定当前迭代轮次t=10,则终端用户设备节点u50从全局云服务器节点获取全局服务模型W10
步骤S2-2:终端用户设备节点ui利用本地数据训练全局服务模型Wt,共训练ε次,得到第t轮迭代更新后的局部模型
Figure BDA0003634481990000073
同时,终端用户设备节点ui计算出第t轮迭代的模型训练损失
Figure BDA0003634481990000074
并将其上传至全局云服务器节点。
本实施例中,假定当前迭代轮次t=10,则终端用户设备节点u50利用本地数据训练全局服务模型W10,共10次,得到第t=10轮迭代的局部模型
Figure BDA0003634481990000075
步骤S3:局部模型筛选。仲裁调度模块收集终端用户设备节点上传的模型训练损失,利用压缩策略算法选出第t轮迭代的最佳的K个局部模型,并授权相应的终端用户设备节点上传本地局部模型到全局云服务器节点。步骤S3具体包括以下子步骤:
步骤S3-1:全局云服务器节点设计有仲裁调度模块,用于管理终端用户设备节点,仲裁调度模块获取各终端用户设备节点上传的模型训练损失,所有模型训练损失构成第t轮迭代的模型训练损失集合Lt
Figure BDA0003634481990000081
本实施例中,仲裁调度模块获取的模型训练损失集合
Figure BDA0003634481990000082
步骤S3-2:全局云服务器节点根据公式K=β·m计算需要筛选的模型个数K。
本实施例中,K=β·m=0.8×100=80。
步骤S3-3:根据公式
Figure BDA0003634481990000083
筛选第t轮迭代的模型质量最佳的K个本地网络模型。公式
Figure BDA0003634481990000084
表示从模型训练损失集合Lt中挑选出K个最小的模型训练损失所对应的局部模型。
本实施例中,筛选出模型质量最佳的80个局部模型。
步骤S3-4:将步骤3-3选出的局部模型所对应的各终端用户设备节点重新定义为策略节点,则所有策略节点构成的第t轮迭代的策略节点集合记为
Figure BDA0003634481990000085
Figure BDA0003634481990000086
表示第t轮迭代的第k个策略节点。全局云服务器节点授权各策略节点上传本地局部模型。
本实施例中,选定的策略节点集合记为
Figure BDA0003634481990000087
步骤S4:全局云服务器节点根据收集到的终端用户设备模型和上轮选定的融合策略Se执行全局模型整合,得到第t轮迭代时的全局模型参数ωt,全局模型参数包括全局服务模型的各层权重参数,具体包括以下子步骤:
步骤S4-1:若融合策略Se=0,则全局云服务器节点根据平均融合算法计算第t轮迭代时的全局模型参数ωt
若融合策略Se=1,则全局云服务器节点根据自适应融合算法计算第t轮迭代时的全局模型参数ωt
其中,全局云服务器节点根据自适应融合算法计算第t轮迭代时的全局模型参数ωt包括以下子步骤:
步骤S4-1-1:计算第t轮迭代的全局服务模型各层权重参数与各策略节点上传的局部模型各层权重参数的差异性,其中,全局服务模型第l层权重参数ωl与第k个策略节点上传的局部模型第l层权重参数
Figure BDA0003634481990000088
的差异性表示为
Figure BDA0003634481990000089
则有:
Figure BDA00036344819900000810
式中,JS(·)表示詹森香农散度,KL(·)表示Kullback-Leible散度,||表示两模型权重的相对熵。
步骤S4-1-2:计算第t轮迭代时各策略节点对于全局服务模型各层权重参数的贡献度,其中,第k个策略节点对全局服务模型第l层权重参数ωl的贡献度为
Figure BDA0003634481990000098
则有:
Figure BDA0003634481990000091
步骤S4-1-3:计算第t轮迭代时全局服务模型的各层权重参数,设第t轮迭代时全局服务模型的第l层权重参数表示为
Figure BDA0003634481990000092
则有:
Figure BDA0003634481990000093
式中,η表示模型学习率,
Figure BDA0003634481990000094
表示梯度算子,L(·)表示全局服务模型的损失。
全局云服务器节点根据平均融合算法计算第t轮迭代时的全局模型参数ωt,如下式所示:
Figure BDA0003634481990000095
式中,nk为第k个策略节点的数据量,n为总数据量,
Figure BDA0003634481990000096
为第t轮迭代时第k个策略节点上传的局部模型的模型参数。
步骤S5:全局云服务器节点对选定的K个局部模型采用仲裁激励算法裁定出第t+1轮迭代时全局云服务器节点所需的融合策略Se,具体包括以下子步骤:
步骤S5-1:全局云服务器节点计算得到上一轮第t-1轮迭代和当前第t轮迭代的全局模型参数损失差异Δ,Δ←|L(ωt)-L(ωt-1)|。
本实施例中,假定当前迭代轮次t=10,则全局云服务器节点利用训练损失计算轮次9和轮次10的损失差异Δ=0.2。
步骤S5-2,根据损失差异Δ计算全局云服务器节点所需的融合策略Se,则有:
Figure BDA0003634481990000097
式中,∈是选定选择平均融合算法还是自适应融合算法的损失阈值。
本实施例中,若设置∈=0.5,当Δ=0.2时,计算得到的下一轮次的融合策略Se=0。
步骤S6:全局云服务器节点根据步骤S4得到的第t轮迭代时的全局模型参数ω更新全局服务模型,t←t+1,返回步骤S2,全局服务模型训练进入下一轮迭代,直至全局服务模型收敛,模型训练结束。
上述步骤S1至步骤S6中,所述全局云服务器节点与各终端用户设备节点之间传输的是加密数据。
本发明实施例所提供的上述方法可采用计算机理论技术实现自动执行流程,为本领域技术人员的常识,此次不再赘述。
应当注意的是,本说明书的描述术语“实例”、“实施例”或“样例”等意在阐述本发明实施例的结构、功能或者特征。上述术语的表述是示范性的,不具有本发明实施例的限定性。而且,表述的结构、功能或者特征可以在多个实例中以符合实际的方式组合。

Claims (9)

1.一种用于横向联邦学习的模型融合方法,其特征在于,包括以下步骤:
步骤S1:初始化全局云服务器节点,包括构建全局服务模型后,初始化全局模型参数ω1,从而获得用于进行第一轮迭代的全局服务模型W1,初始化迭代轮次t=1;
步骤S2:全局云服务器节点将第t迭代的全局服务模型Wt下发至设备节点集合U={u1,u2,...,ui,...,um}中的每个终端用户设备节点,由每个终端用户设备节点利用本地数据对模型进行局部模型训练,从而得到第t轮迭代的局部模型集合
Figure FDA0003634481980000011
并计算得到第t轮迭代时局部模型集合中每个局部模型的模型训练损失,其中,ui表示第i个终端用户设备节点,i=1,2,...,m,
Figure FDA0003634481980000017
表示第i个终端用户设备节点训练得到第t轮迭代的第i个局部模型;
步骤S3:全局云服务器节点中的仲裁调度模块收集各终端用户设备节点上传的模型训练损失,进而获得第t轮迭代的模型训练损失集合Lt
Figure FDA0003634481980000012
Figure FDA0003634481980000013
仲裁调度模块基于模型训练损失集合Lt从局部模型集合M中选出最佳的K个局部模型,将K个局部模型所对应的各终端用户设备节点重新定义为策略节点,则所有策略节点构成的第t轮迭代的策略节点集合
Figure FDA0003634481980000014
Figure FDA0003634481980000015
Figure FDA0003634481980000016
表示第t轮迭代的第k个策略节点;全局云服务器节点授权各策略节点上传本地局部模型;
步骤S4:全局云服务器节点根据融合策略Se,基于各策略节点上传的第t轮迭代时的局部模型的模型参数,采用平均融合算法或者自适应融合算法计算第t轮迭代时的全局模型参数ωt
步骤S5:全局云服务器节点对选定的K个局部模型采用仲裁激励算法裁定出第t轮迭代时全局云服务器节点所需的融合策略Se
步骤S6:全局云服务器节点根据步骤S4得到的第t轮迭代时的全局模型参数ωt更新全局服务模型,t←t+1,返回步骤S2,全局服务模型训练进入下一轮迭代,直至全局服务模型收敛,模型训练结束。
2.如权利要求1所述的一种用于横向联邦学习的模型融合方法,其特征在于,步骤S3中,所述全局云服务器节点根据K=β·m计算需要筛选的模型个数K,β为预设的筛选因子,随后从模型训练损失集合Lt中挑选出K个最小的模型训练损失所对应的局部模型。
3.如权利要求1所述的一种用于横向联邦学习的模型融合方法,其特征在于,步骤S4中,全局云服务器节点根据自适应融合算法计算第t轮迭代时的全局模型参数ωt包括以下子步骤:
步骤S4-1-1:计算第t轮迭代的全局服务模型各层权重参数与各策略节点上传的局部模型各层权重参数的差异性,其中,全局服务模型第l层权重参数ωl与第k个策略节点上传的局部模型第l层权重参数
Figure FDA0003634481980000021
的差异性表示为
Figure FDA0003634481980000022
则有:
Figure FDA0003634481980000023
式中,JS(·)表示詹森香农散度,KL(·)表示Kullback-Leible散度,||表示两模型权重的相对熵;
步骤S4-1-2:计算第t轮迭代时各策略节点对于全局服务模型各层权重参数的贡献度,其中,第k个策略节点对全局服务模型第l层权重参数ωl的贡献度为
Figure FDA0003634481980000029
则有:
Figure FDA0003634481980000024
步骤S4-1-3:计算第t轮迭代时全局服务模型的各层权重参数,设第t轮迭代时全局服务模型的第l层权重参数表示为
Figure FDA0003634481980000025
则有:
Figure FDA0003634481980000026
式中,η表示模型学习率,
Figure FDA0003634481980000027
表示梯度算子,L(·)表示全局服务模型的损失。
4.如权利要求1所述的一种用于横向联邦学习的模型融合方法,其特征在于,所述步骤S5包括以下步骤:
步骤S5-1:全局云服务器节点计算得到第t轮迭代和第t-1轮迭代的损失差异Δ,Δ←|L(ωt)-L(ωt-1)|;
步骤S5-2,根据损失差异Δ计算全局云服务器节点所需的融合策略Se,则有:
Figure FDA0003634481980000028
式中,∈是选定选择平均融合算法还是自适应融合算法的损失阈值。
5.如权利要求1所述的一种用于横向联邦学习的模型融合方法,其特征在于,步骤S4中,全局云服务器节点根据平均融合算法计算第t轮迭代时的全局模型参数ωt,如下式所示:
Figure FDA0003634481980000031
式中,nk为第k个策略节点的数据量,n为总数据量,
Figure FDA0003634481980000032
为第t轮迭代时第k个策略节点上传的局部模型的模型参数。
6.如权利要求1所述的一种用于横向联邦学习的模型融合方法,其特征在于,所述步骤S1至所述步骤S6中,所述全局云服务器节点与各终端用户设备节点之间传输的是加密数据。
7.一种基于横向联邦学习的模型融合系统,其特征在于,包括:
初始化模块,用于初始化全局云服务器节点
终端设备模型训练模块:每个终端用户设备节点通过终端设备模型训练模块从全局云服务器节点获得下发的第t迭代的全局服务模型Wt,随后每个终端用户设备节点采用终端设备模型训练模块利用本地数据对模型进行局部模型训练,从而得到第t轮迭代的局部模型集合
Figure FDA0003634481980000033
并且终端设备模型训练模块计算得到第t轮迭代时局部模型集合中每个局部模型的模型训练损失,其中,ui表示第i个终端用户设备节点,i=1,2,...,m,
Figure FDA0003634481980000034
表示第i个终端用户设备节点训练得到第t轮迭代的第i个局部模型;
仲裁调度模块,进一步包括以下子模块:
本地模型损失输入子模块,用于收集各终端用户设备节点上传的模型训练损失,进而获得第t轮迭代的模型训练损失集合Lt
Figure FDA0003634481980000035
压缩策略子模块,用于根据模型训练损失集合Lt筛选最佳的K个局部模型;
模型选定子模块,用于授权K个局部模型所对应的终端用户设备节点上传本地局部模型到全局云服务器节点,获得授权的终端用户设备节点被定义为策略节点;
本地模型数据量输入子模块,用于获取各策略节点的本地数据量;
全局模型更新策略模块,进一步包括以下子模块:
仲裁阈值输入子模块,用于获取融合算法的损失阈值∈;
模型更新判断子模块:模型更新判断子模块计算得到第t轮迭代和第t-1轮迭代的损失差异Δ,Δ←|L(ωt)-L(ωt-1)|,再根据损失差异Δ计算全局云服务器节点所需的融合策略Se,则有:
Figure FDA0003634481980000041
当Se=1时,进入自适应融合更新子模块;当Se=0时,进入平均融合更新子模块;
自适应融合更新子模块:计算第t轮迭代的全局服务模型各层权重参数与各策略节点上传的局部模型各层权重参数的差异性,依据差异性计算得到第t轮迭代时各策略节点对于全局服务模型各层权重参数的贡献度,然后根据贡献度计算得到第t轮迭代时全局服务模型的各层权重参数;
平均融合更新子模块:根据平均融合算法计算第t轮迭代时的全局模型参数ωt,如下式所示:
Figure FDA0003634481980000042
式中,nk为第k个策略节点的数据量,n为总数据量,
Figure FDA0003634481980000043
为第t轮迭代时第k个策略节点上传的局部模型的模型参数;
终止判定子模块:用于判定全局服务模型是否收敛,若收敛,则结束模型训练;反之,进入第t+1轮模型训练。
8.如权利要求7所述的一种基于横向联邦学习的模型融合系统,其特征在于,所述初始化模块进一步包括:
全局模型构建子模块,用于构建初始的全局服务模型,包括对全局服务模型的输入单元,隐藏单元,输出单元、输入单元、隐藏单元及输出单元的神经节点个数和各神经节点的连接路径进行设计;
全局模型初始化子模块,用于初始化全局服务模型,包括初始化全局模型参数ω1,从而获得用于进行第一轮迭代的全局服务模型W1
全局变量初始化子模块,用于初始化全局联邦变量,该全局联邦变量包括由m个终端用户设备节点组成的设备节点集合U={u1,u2,...,ui,...,um},ui表示第i个终端用户设备节点,i=1,2,...,m;本地训练次数ε;筛选因子β;迭代轮次t。
9.如权利要求8所述的一种基于横向联邦学习的模型融合系统,其特征在于,所述终端设备模型训练模块进一步包括:
全局模型输入子模块:每个终端用户设备节点通过全局模型输入子模块与全局云服务器节点通信,经过全局云服务器节点的身份认证后获取第t迭代的全局服务模型Wt
本地模型训练子模块:每个终端用户设备节点采用本地模型训练子模块利用本地数据训练所获得的模型ε次,得到第t轮迭代的局部模型,同时,计算得到第t轮迭代时当前局部模型的模型训练损失,并上传至全局云服务节点;
并行训练子模块,用于并行执行全局模型输入子模块和本地模型训练子模块,得到得到第t轮迭代的局部模型集合
Figure FDA0003634481980000051
CN202210498743.0A 2022-05-09 2022-05-09 一种用于横向联邦学习的模型融合方法及系统 Pending CN115034356A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210498743.0A CN115034356A (zh) 2022-05-09 2022-05-09 一种用于横向联邦学习的模型融合方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210498743.0A CN115034356A (zh) 2022-05-09 2022-05-09 一种用于横向联邦学习的模型融合方法及系统

Publications (1)

Publication Number Publication Date
CN115034356A true CN115034356A (zh) 2022-09-09

Family

ID=83119507

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210498743.0A Pending CN115034356A (zh) 2022-05-09 2022-05-09 一种用于横向联邦学习的模型融合方法及系统

Country Status (1)

Country Link
CN (1) CN115034356A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115277264A (zh) * 2022-09-28 2022-11-01 季华实验室 一种基于联邦学习的字幕生成方法、电子设备及存储介质

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115277264A (zh) * 2022-09-28 2022-11-01 季华实验室 一种基于联邦学习的字幕生成方法、电子设备及存储介质
CN115277264B (zh) * 2022-09-28 2023-03-24 季华实验室 一种基于联邦学习的字幕生成方法、电子设备及存储介质

Similar Documents

Publication Publication Date Title
US11461654B2 (en) Multi-agent cooperation decision-making and training method
CN111309824B (zh) 实体关系图谱显示方法及系统
CN109990790B (zh) 一种无人机路径规划方法及装置
US11966837B2 (en) Compression of deep neural networks
AU2024200810A1 (en) Training tree-based machine-learning modeling algorithms for predicting outputs and generating explanatory data
CN111506405A (zh) 一种基于深度强化学习的边缘计算时间片调度方法
CN108111860B (zh) 基于深度残差网络的视频序列丢失帧预测恢复方法
CN108615231B (zh) 一种基于神经网络学习融合的全参考图像质量客观评价方法
CN111401344A (zh) 人脸识别方法和装置及人脸识别系统的训练方法和装置
CN112311578A (zh) 基于深度强化学习的vnf调度方法及装置
CN115034356A (zh) 一种用于横向联邦学习的模型融合方法及系统
CN112269729A (zh) 面向网络购物平台大规模服务器集群的负载智能分析方法
CN114584406B (zh) 一种联邦学习的工业大数据隐私保护系统及方法
CN112256916A (zh) 一种基于图胶囊网络的短视频点击率预测方法
CN115563858A (zh) 一种工作机稳态性能提升的方法、装置、设备及介质
CN110968512A (zh) 软件质量评估方法、装置、设备及计算机可读存储介质
Shan et al. A hybrid knowledge-based system for urban development
CN106228029B (zh) 基于众包的量化问题求解方法和装置
Qu et al. Learning-based multi-drone network edge orchestration for video analytics
CN116432053A (zh) 基于模态交互深层超图神经网络的多模态数据表示方法
Rădulescu et al. Analysing congestion problems in multi-agent reinforcement learning
CN115762147A (zh) 一种基于自适应图注意神经网络的交通流量预测方法
Lu et al. A network traffic prediction model based on reinforced staged feature interaction and fusion
CN114399901A (zh) 一种控制交通系统的方法和设备
CN114792187A (zh) 基于意愿和信任双重约束的群智感知团队招募方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination