CN116416508A - 一种加快全局联邦学习模型收敛的方法及联邦学习系统 - Google Patents
一种加快全局联邦学习模型收敛的方法及联邦学习系统 Download PDFInfo
- Publication number
- CN116416508A CN116416508A CN202310262721.9A CN202310262721A CN116416508A CN 116416508 A CN116416508 A CN 116416508A CN 202310262721 A CN202310262721 A CN 202310262721A CN 116416508 A CN116416508 A CN 116416508A
- Authority
- CN
- China
- Prior art keywords
- network
- model
- global
- training
- federal learning
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 46
- 230000006854 communication Effects 0.000 claims abstract description 58
- 238000004891 communication Methods 0.000 claims abstract description 55
- 238000012549 training Methods 0.000 claims abstract description 50
- 230000006870 function Effects 0.000 claims abstract description 44
- 238000005265 energy consumption Methods 0.000 claims abstract description 31
- 230000002787 reinforcement Effects 0.000 claims abstract description 31
- 230000008569 process Effects 0.000 claims abstract description 17
- 238000005457 optimization Methods 0.000 claims abstract description 14
- 238000004458 analytical method Methods 0.000 claims abstract description 9
- 238000004422 calculation algorithm Methods 0.000 claims description 23
- 230000005540 biological transmission Effects 0.000 claims description 16
- 230000009471 action Effects 0.000 claims description 15
- 230000004913 activation Effects 0.000 claims description 8
- 230000002776 aggregation Effects 0.000 claims description 7
- 238000004220 aggregation Methods 0.000 claims description 7
- 238000012360 testing method Methods 0.000 claims description 7
- 238000013507 mapping Methods 0.000 claims description 3
- 238000012545 processing Methods 0.000 claims description 3
- 238000013461 design Methods 0.000 claims 1
- 238000004088 simulation Methods 0.000 description 9
- 238000004364 calculation method Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 238000002474 experimental method Methods 0.000 description 3
- 238000013459 approach Methods 0.000 description 2
- 238000009826 distribution Methods 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000013434 data augmentation Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000007613 environmental effect Effects 0.000 description 1
- 230000003090 exacerbative effect Effects 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 238000003909 pattern recognition Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000003860 storage Methods 0.000 description 1
- 230000001360 synchronised effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/092—Reinforcement learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/098—Distributed learning, e.g. federated learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/94—Hardware or software architectures specially adapted for image or video understanding
- G06V10/95—Hardware or software architectures specially adapted for image or video understanding structured as a network, e.g. client-server architectures
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及一种加快全局联邦学习模型收敛的方法及联邦学习系统,属于工业物联网技术领域,所述方法包括以下步骤:S1、对联邦学习系统的时延建模分析;S2、对联邦学习系统的能耗建模分析;S3、确定优化目标;S4、构建图像分类网络和强化学习智能体网络;S5、设计图像分类网络的损失函数;S6、将节点选择问题转化为马尔可夫决策过程;S7、训练强化学习智能体网络;S8、使用Q网络指导联邦学习的设备选择。所述联邦学习系统,包括云服务器、多个边缘设备和无线网络。本发明通过强化学习智能体辅助联邦学习系统选择合适的设备参与训练,加速全局联邦学习模型的收敛速度,减少联邦训练所需要的通信轮次,降低通信和能耗成本。
Description
技术领域
本发明属于工业物联网技术领域,具体涉及一种加快全局联邦学习模型收敛的方法及联邦学习系统。
背景技术
工业物联网引入大量人工智能技术,目的是在异构和大规模网络中实现数据驱动的机器学习解决方案。然而由于隐私保护、法律法规等条件的限制,设备间、机构间形成“数据孤岛”。为了打破数据孤岛,充分利用分散在设备上的数据,联邦学习技术被提出并用在数据隐私保护等领域。联邦学习是一种保护用户隐私分布式机器学习框架,在不共享数据情况下完成联合建模。其主要思想是在边缘设备上训练本地模型、由中心云服务器执行联邦平均算法完成模型聚合,并将聚合模型下发给所有参与联邦训练边缘设备进行下一轮本地训练。边缘设备和云服务器重复以上过程直到全局模型达到目标精度。然而传统联邦学习面临以下方面挑战:
(1)设备异构性:不同边缘设备具有不同计算能力、存储空间、电池容量等;
(2)数据异构性:移动设备上收集的数据通常是非独立同分布的,这违反了分布式优化的基本假设。
(3)网络状态不稳定:边缘设备通信资源受限,数据传输速率受限环境影响波动较大,可能会增加联邦训练过程的通信成本以及传输功耗。
工业物联网(Industrial Internet of Things,IIoT)设备的计算和通信资源是受限的,很多IIoT设备使用电池供电,因此需要考虑设备的通信和能耗成本。此外,由于IIoT设备的通信质量不稳定,设备可能需要花费大量的时间用于本地模型的传输,增加通信成本。并且,数据的Non-IID分布会减慢全局模型收敛速度,降低模型精度,并且导致更多的通信轮次以达到模型收敛,因此加快全局联邦模型的收敛速度对降低训练的通信和能耗成本具有重要意义。
Li等人在其发表的论文”Federated Optimization in HeterogeneousNetworks”中提出了FedProx算法,通过使用正则项来平衡全局目标和局部目标之间的优化差异,降低Non-IID数据分布的影响,但是没有考虑动态网络场景下对系统通信成本和能耗成本的影响。
WEN等人在其发表的论文”Communication-Efficient Federated DataAugmentation on Non-IID Data”(Conference on Computer Vision and PatternRecognition Workshops,2022)提出使用自动编码器生成设备缺失的样本,降低设备数据集的非IID程度,但是这种方法给IIoT设备引入额外的计算负担。
YANG等人在其发表的论文”Improving Accuracy and Convergence in Group-based Federated Learning on Non-IID Data”(IEEE Transactions on NetworkScience and Engineering,2022)提出对边缘设备的本地模型进行分组聚类,在训练时从每组中随机抽取设备参加联邦训练,但是这种方法无法确定最佳分类簇数,可能会影响全局模型的收敛。
发明内容
为解决现有技术中存在的上述问题,本发明提供了一种加快全局联邦学习模型收敛的方法及联邦学习系统,提高了全局联邦模型的收敛速度,降低了联邦学习的通信和能耗成本。
本发明的目的可以通过以下技术方案实现:
本发明提供了一种加快全局联邦学习模型收敛的方法,包括以下步骤:
S1、对联邦学习系统的时延进行建模分析;
S2、对联邦学习系统的能耗进行建模分析;
S3、确定优化目标;
S4、构建图像分类网络和强化学习智能体网络;
S5、设计图像分类网络的损失函数;
S6、将节点选择问题转化为马尔可夫决策过程;
S7、训练强化学习智能体网络;
S8、使用Q网络指导联邦学习的设备选择;
所述强化学习智能体网络采用DDQN强化学习算法;所述强化学习智能体网络包括所述Q网络和Target Q网络,所述Q网络和Target Q网络采用相同的网络结构。
进一步地,所述对联邦学习系统的时延进行建模分析,具体包括:
其中,di为设备i的样本数量,ci为训练一个样本需要的cpu周期数,fi为工作频率,τ本地迭代的轮次;
计算第k次通信过程的训练时延Tk:
进一步地,所述对联邦学习系统的能耗进行建模分析,具体包括:
计算设备i在第k个通信轮次执行本地训练的能量消耗;
计算设备i进行模型传输的通信功耗;
结合所述能量消耗与所述通信功耗得到设备i的空闲能耗;
根据所述空闲能耗计算设备消耗的总能量。
进一步地,所述优化目标描述为:
min(Loss(x;θ)) (1)
fmin≤fi≤fmax (2)
Tk<Tmax (5)
其中公式(1)表示优化目标,它表示以最小化全局模型在测试集上的损失函数,x表示测试集样本,θ表示全局模型参数,公式(2)表示对设备工作频率的约束,公式(3)表示被选设备的总带宽不大于服务器带宽B,公式(4)表示在一个通信轮次中至少有一个设备被选中参与联邦训练,而最大数量不超过设备总量N,公式(5)表示在第k个通信轮次的训练时延Tk不能超过规定的最大时延Tmax。
进一步地,所述图像分类网络为两层的MLP网络;具体包括图像输入层、第一线性网络、第一激活函数层、第二线性网络、第二激活函数层和全连接网络。
进一步地,所述设计图像分类网络的损失函数,具体包括:
所述图像分类网络的损失函数包括交叉熵损失函数(lce)和最大均值差异损失(lMMD);
其中,lce用于图像分类任务;lMMD用于衡量全局模型和本地模型在本地数据样本输入下的输出差异;
将lce+lMMD作为本地网络模型的损失函数,并对其进行梯度下降,以更新本地模型参数。
进一步地,所述马尔可夫决策过程包括系统状态、动作空间、策略、奖励函数以及邻接状态;
其中,所述策略表示从状态空间到动作空间的映射;
所述奖励函数设置为最小化的时延与能耗的加权和。
进一步地,所述训练强化学习智能体网络,具体包括:
计算时序差分目标;
根据所述时序差分目标定义损失函数;
通过梯度下降法最小化损失函数以更新所述Q网络的网络参数。
进一步地,所述使用Q网络指导联邦学习的设备选择,具体包括:
云服务器向所有参与联邦学习的边缘设备发送全局模型参数并收集联邦学习系统的状态信息;
将状态信息输入Q网络,Q网络输出Q值,所述Q值表示每个动作的价值,并将排名前k个Q值对应的设备作为当前状态下的最佳设备子集;
最佳设备子集使用本地数据训练图像分类网络以更新本地模型,然后将本地模型上传至云服务器;
云服务器执行模型聚合算法,以更新全局模型;
不断执行以上过程,直到全局模型达到目标精度。
本发明还提供了一种加快全局联邦学习模型收敛的联邦学习系统,包括云服务器、多个边缘设备和无线网络;
所述云服务器,用于存储和更新全局模型参数、接收和发送消息,以及运行强化学习智能体网络和执行模型聚合算法;
所述边缘设备,用于存储和处理本地数据、计算节点评分、执行本地训练、接收和发送消息;
所述无线网络,用于连接云服务器和边缘设备。
本发明的有益效果为:
(1)通过强化学习智能体辅助联邦学习系统选择合适的设备参与训练,以最小全局模型在测试数据集上的模型损失函数为优化目标,以加快全局模型的收敛速度,降低联邦训练过程的通信和能耗成本。
(2)通过构建的强化学习智能体经过大量的交互数据不断改进节点选择策略,使得该方法具有较高的准确率和更高的鲁棒性。
(3)构建的强化学习智能体运行在云端服务器,不会给边缘设备引入额外的计算负担。
附图说明
为了便于本领域技术人员理解,下面结合附图对本发明作进一步的说明。
图1为本发明的方法流程示意图;
图2为本发明中图像分类器的结构示意图;
图3为本发明中Q网络的结构示意图;
图4为本发明中图像分类网络的损失函数构成示意图;
图5为本发明一实施例DDQN训练阶段的回报曲线图;
图6为本发明一实施例仿真实验的准确率对比曲线图。
具体实施方式
为更进一步阐述本发明为实现预定发明目的所采取的技术手段及功效,以下结合附图及较佳实施例,对依据本发明的具体实施方式、结构、特征及其功效,详细说明如下。
一种加快全局联邦学习模型收敛的方法,如图1所示,具体包括以下步骤:
S1、对联邦学习系统的时延进行建模分析。
在一个通信轮次中,设备的训练时延由本地模型训练时的计算时延和模型传输时的通信时延组成。假设设备i拥有di个样本,训练一个样本需要的cpu周期数为ci,工作频率为fi,并进行τ个轮次的本地迭代,那么设备i在第k个通信轮次中执行模型训练消耗的时间为:
由上述公式可知,设备i的数据传输速率与分配给设备i的带宽Bi,k、设备i的发射功率pi,k、信道增益gi,k以及噪声功率N0有关,环境状态的改变会影响数据传输速率,加剧通信时延的不确定性。
结合上述内容,设备i在第k个通信轮次中的本地训练时延为:
在联邦同步学习算法中,每个通信轮次的本地训练时间由最慢的设备决定,所以第k次通信过程的训练时延为:
S2、对联邦学习系统的能耗进行建模分析。
设备i在第k个通信轮次,执行本地训练的能量消耗为:
上式中,σ表示有效电容系数,σ与芯片本身的性质有关。
在模型传输时,设备i使用值为pi,k的功率进行模型传输,设备i的通信功耗为:
需说明的是,最先完成本地模型训练和模型传输的设备需要等待其他未完成的设备,执行速度快的设备存在空闲等待时间,在空闲等待时间消耗的能量称之为空闲能耗,那么设备i的空闲能耗等于空闲等待时间乘以空载状态下的单位能耗,计算公式如下:
由上述可知,设备i在第k个通信轮次消耗的能量为:
因此在第k个通信轮次,所有设备消耗的总能量为:
S3、确定优化目标。
目标是在动态网络场景下,考虑工业物联网(Industrial Internet of Things,IIoT)设备的异构性,以最小化全局模型在测试集上的损失函数为目标优化目标,并使用以下公式进行描述:
min(Loss(x;θ)) (1)
fmin≤fi≤fmax (2)
Tk<Tmax (5)
其中公式(1)为优化目标,它表示全局模型在测试集上的准确率,公式(2)表示对设备工作频率的约束,公式(3)表示被选设备的总带宽不大于服务器带宽B,公式(4)表示在一个通信轮次中至少有一个设备被选中参与联邦训练,而最大数量不超过设备总量N,公式(5)表示在第k个通信轮次的训练时延Tk不能超过规定的最大时延Tmax。
S4、构建图像分类网络和强化学习智能体网络。
如图2所示,通过构建一个两层的MLP(多层感知器)网络作为IIoT设备的本地图像分类网络,包括图像输入层、第一线性网络、第一激活函数层、第二线性网络、第二激活函数层和全连接网络;设置第一线性网络的输入维度为784,输出层维度为200,第二线性网络的输入维度为200,输出维度为200,全连接层的输入维度为200,输出维度为10;第一、二激活函数层均采用ReLU函数实现。
强化学习智能体网络采用DDQN(双重深度Q网络)强化学习算法,包括一个Q网络和一个Target Q网络,Q网络和Target Q网络采用相同的网络结构。其中,Q网络也称为强化学习智能体或智能体。
进一步地,Q网络和Target Q网络的网络结构由两层线性网络构成,如图3所示,第一层线性网络的输入维度为联邦学习系统的状态维度,输出维度设置为128,第一层线性网络的输出接ReLU激活函数。第二层线性网络的输入维度为128,输出维度等于边缘IIoT设备的数量。
本发明在实验时使用了20个IIoT设备,将每个设备当前时刻的数据传输速率、工作频率、信号发射功率、样本数量作为状态信息,因此,联邦学习系统的状态维度为80。
S5、设计图像分类网络的损失函数。
如图4所述,图像分类网络的损失函数由两部分构成:一部分是交叉熵损失函数(lce),用于图像分类任务;另一部分是最大均值差异损失(lMMD),用于衡量全局模型和本地模型在本地数据样本输入下的输出差异,将lce+lMMD作为本地网络模型的损失函数,并对其进行梯度下降,以更新本地模型参数。
S6、将节点选择问题转化为马尔可夫决策过程。
使用强化学习的方法解决联邦学习的节点选择问题,首先需要将该问题抽象为一个马尔可夫决策过程。一个马尔可夫决策过程包括系统状态S(t)、动作空间A(t)、策略π、奖励函数r以及邻接状态S(t+1),具体为:
系统状态S(t)由设备与服务器之间的数据传输速率β(t)、设备的工作频率ζ(t)、设备的信号发射功率Tp(t)以及拥有的样本数量ψ(t)组成。因此可定义时隙t的系统状态为:
S(t)={β(t),ζ(t),Tp(t),ψ(t)}
策略π表示从状态空间S(t)到动作空间A(t)的映射,即A(t)=π(S(t))。DRL(深度强化学习)的目标是学习一个最佳策略π,使得智能体根据当前状态做出的动作可以获得最大的期望奖励。
奖励函数的设置与优化目标一致,即最小化的时延与能耗的加权和,因此,奖励函数r表示为:
r=-Loss(x;θ)
邻接状态S(t+1)由当前状态S(t)以及策略π决定,具体的表达形式如下:
S(t+1)=S(t)+π(S(t))。
S7、训练强化学习智能体网络。
DDQN强化学习算法在训练时对Q网络参数进行更新,保持Target Q网络参数保持不变。每经过一定的迭代轮次,将Q网络的参数复制到Target Q网络中,从而避免估计误差和过度估计问题。
假设Q网络的网络参数记作θ,目标网络的网络参数记作θ-,时序差分目标的计算方式如下:
上式中,r表示系统返回的即时奖励,s'表示下一时刻的系统状态,a'表示采取的动作,A表示动作空间,表示使用Q网络估计下一个状态获得最大Q值采取的动作,/>表示使用Target Q网络根据下一时刻的状态以及Q网络估计的动作估计价值。
定义损失函数l(θ)=(Ytarget-Q(s,a;θ))2,通过对损失函数的反向传播更新θ,其中s和a分别表示当前时刻以及当前时刻做出的动作。通过梯度下降法最小化损失函数l(θ),θ的更新过程为:
本实施例中,α设置为0.001,Target Q网络的更新频率为Q网络每更新20次,Target Q网络使用Q网络更新一次参数。
DDQN强化学习算法根据联邦学习系统的交互数据和Target Q网络构建的损失函数更新所述Q网络的参数,使得获得的期望奖励达到收敛。
S8、使用Q网络指导联邦学习的设备选择。
云服务器向所有参与联邦学习的边缘设备发送全局模型参数并收集联邦学习系统的状态信息;将状态信息输入训练好的Q网络,Q网络输出每个动作的价值,也即Q值。将排名前k个Q值对应的设备作为当前状态下的最佳设备子集;最佳设备子集使用本地数据将接收的全局模型参数进行本地更新,然后将更新后的模型参数上传至服务器,由服务器执行模型聚合算法,以更新全局模型;不断执行以上过程,直到全局模型达到目标精度。
本发明还提供一种加快全局联邦学习模型收敛的联邦学习系统,包括一个云服务器、多个边缘设备和一个无线网络环境。其中,云服务器,用于存储和更新全局模型参数、接收和发送消息,以及运行强化学习智能体网络和执行模型聚合算法。边缘设备,用于存储和处理本地数据、计算节点评分、执行本地训练、接收和发送消息。无线网络环境,用于连接云服务器和边缘设备。
下面结合仿真实验对本发明的效果做进一步的说明:
(1)仿真实验条件:
仿真实验的硬件平台为:处理器为Intel(R)Core i7-12700H CPU,内存为16GB、显卡为NVIDIA GeForce RTX 3060。
仿真实验的软件平台为:win11操作系统,python 3.9.12,PyTorch1.12.1。
(2)仿真内容及仿真结果分析:
本发明仿真实验时使用了MNIST数据集,基于MNIST数据集构建IID数据集和Non-IID数据集。IID数据集是对每个类别的样本随机抽样50次,由总计500个样本组成。Non-IID数据集由主类样本和次类样本组成,其中主类样本占样本总数的70%,剩余的30%对次类样本均匀抽样得到。在实验中,将70%的设备分配Non-IID数据集,剩余30%的设备分配IID数据集。图5展示了强化学习智能体在训练阶段获得的回报随迭代轮次的变化曲线,随着迭代次数的增加,DDQN智能体在训练过程获得的奖励随迭代轮次的更加逐渐上升,并在60轮次迭代后趋近收敛。
本发明所提出的节点选择策略为通过Q网络指导联邦学习进行节点选择。下文结合仿真实验的描述中,将该节点选择策略称为LCNSFL-2算法,用来对比的两个算法是基于随机选择算法(Random Selection)和FedProx算法,从达到目标精度所需要的通信轮数、通信成本、能耗成本以及通信和能耗的加权成本机型对比,如表1所示。将目标精度设置为90%,通过实验结果发现,LCNSFL-2算法所需要的通信轮数最少,其通信、能耗以及加权成本远低于其他两种算法。
表1三种算法的性能对比
算法名称 | 数据集 | 通信轮数 | 通信成本 | 能耗成本 | 加权成本 |
LCNSFL-2 | MNIST | 18 | 3023.7 | 16175.0 | 9599.3 |
Random Selection | MNIST | 26 | 4116.3 | 25221.7 | 14669.0 |
FedProx | MNIST | 22 | 3549.5 | 19381.5 | 11465.5 |
通过以上所述以及图6可以得出,LCNSFL-2算法可以更快的达到目标精度,减少与服务器的通信轮次,从而降低了系统的通信和能耗成本。
以上所述,仅是本发明的较佳实施例而已,并非对本发明作任何形式上的限制,虽然本发明已以较佳实施例揭示如上,然而并非用以限定本发明,任何本领域技术人员,在不脱离本发明技术方案范围内,当可利用上述揭示的技术内容做出些许更动或修饰为等同变化的等效实施例,但凡是未脱离本发明技术方案内容,依据本发明的技术实质对以上实施例所作的任何简介修改、等同变化与修饰,均仍属于本发明技术方案的范围内。
Claims (10)
1.一种加快全局联邦学习模型收敛的方法,其特征在于:包括以下步骤:
S1、对联邦学习系统的时延进行建模分析;
S2、对联邦学习系统的能耗进行建模分析;
S3、确定优化目标;
S4、构建图像分类网络和强化学习智能体网络;
S5、设计图像分类网络的损失函数;
S6、将节点选择问题转化为马尔可夫决策过程;
S7、训练强化学习智能体网络;
S8、使用Q网络指导联邦学习的设备选择;
所述强化学习智能体网络采用DDQN强化学习算法;所述强化学习智能体网络包括所述Q网络和Target Q网络,所述Q网络和Target Q网络采用相同的网络结构。
3.根据权利要求1所述的一种加快全局联邦学习模型收敛的方法,其特征在于:所述对联邦学习系统的能耗进行建模分析,具体包括:
计算设备i在第k个通信轮次执行本地训练的能量消耗;
计算设备i进行模型传输的通信功耗;
结合所述能量消耗与所述通信功耗得到设备i的空闲能耗;
根据所述空闲能耗计算设备消耗的总能量。
4.根据权利要求1所述的一种加快全局联邦学习模型收敛的方法,其特征在于:所述优化目标描述为:
min(Loss(x;θ)) (1)
fmin≤fi≤fmax (2)
Tk<Tmax (5)
其中公式(1)表示优化目标,它表示以最小化全局模型在测试集上的损失函数,x表示测试集样本,θ表示全局模型参数,公式(2)表示对设备工作频率的约束,公式(3)表示被选设备的总带宽不大于服务器带宽B,公式(4)表示在一个通信轮次中至少有一个设备被选中参与联邦训练,而最大数量不超过设备总量N,公式(5)表示在第k个通信轮次的训练时延Tk不能超过规定的最大时延Tmax。
5.根据权利要求1所述的一种加快全局联邦学习模型收敛的方法,其特征在于:所述图像分类网络为两层的MLP网络;具体包括图像输入层、第一线性网络、第一激活函数层、第二线性网络、第二激活函数层和全连接网络。
6.根据权利要求1所述的一种加快全局联邦学习模型收敛的方法,其特征在于:所述设计图像分类网络的损失函数,具体包括:
所述图像分类网络的损失函数包括交叉熵损失函数(lce)和最大均值差异损失(lMMD);
其中,lce用于图像分类任务;lMMD用于衡量全局模型和本地模型在本地数据样本输入下的输出差异;
将lce+lMMD作为本地网络模型的损失函数,并对其进行梯度下降,以更新本地模型参数。
7.根据权利要求1所述的一种加快全局联邦学习模型收敛的方法,其特征在于:所述马尔可夫决策过程包括系统状态、动作空间、策略、奖励函数以及邻接状态;
其中,所述策略表示从状态空间到动作空间的映射;
所述奖励函数设置为最小化的时延与能耗的加权和。
8.根据权利要求1所述的一种加快全局联邦学习模型收敛的方法,其特征在于:所述训练强化学习智能体网络,具体包括:
计算时序差分目标;
根据所述时序差分目标定义损失函数;
通过梯度下降法最小化损失函数以更新所述Q网络的网络参数。
9.根据权利要求1所述的一种加快全局联邦学习模型收敛的方法,其特征在于:所述使用Q网络指导联邦学习的设备选择,具体包括:
云服务器向所有参与联邦学习的边缘设备发送全局模型参数并收集联邦学习系统的状态信息;
将状态信息输入Q网络,Q网络输出Q值,所述Q值表示每个动作的价值,并将排名前k个Q值对应的设备作为当前状态下的最佳设备子集;
最佳设备子集使用本地数据训练图像分类网络以更新本地模型,然后将本地模型上传至云服务器;
云服务器执行模型聚合算法,以更新全局模型;
不断执行以上过程,直到全局模型达到目标精度。
10.一种加快全局联邦学习模型收敛的联邦学习系统,其特征在于:包括云服务器、多个边缘设备和无线网络;
所述云服务器,用于存储和更新全局模型参数、接收和发送消息,以及运行强化学习智能体网络和执行模型聚合算法;
所述边缘设备,用于存储和处理本地数据、计算节点评分、执行本地训练、接收和发送消息;
所述无线网络,用于连接云服务器和边缘设备。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310262721.9A CN116416508A (zh) | 2023-03-17 | 2023-03-17 | 一种加快全局联邦学习模型收敛的方法及联邦学习系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310262721.9A CN116416508A (zh) | 2023-03-17 | 2023-03-17 | 一种加快全局联邦学习模型收敛的方法及联邦学习系统 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116416508A true CN116416508A (zh) | 2023-07-11 |
Family
ID=87055788
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310262721.9A Pending CN116416508A (zh) | 2023-03-17 | 2023-03-17 | 一种加快全局联邦学习模型收敛的方法及联邦学习系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116416508A (zh) |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117062132A (zh) * | 2023-10-12 | 2023-11-14 | 北京信息科技大学 | 兼顾时延与能耗的cf-uav智能传输信令交互方法 |
CN117076113A (zh) * | 2023-08-17 | 2023-11-17 | 重庆理工大学 | 一种基于联邦学习的工业异构设备多作业调度方法 |
CN117094381A (zh) * | 2023-08-21 | 2023-11-21 | 哈尔滨工业大学 | 一种兼顾高效通信和个性化的多模态联邦协同方法 |
CN117392483A (zh) * | 2023-12-06 | 2024-01-12 | 山东大学 | 基于增强学习的相册分类模型训练加速方法、系统及介质 |
CN117808125A (zh) * | 2024-02-29 | 2024-04-02 | 浪潮电子信息产业股份有限公司 | 模型聚合方法、装置、设备、联邦学习系统及存储介质 |
CN117910539A (zh) * | 2024-03-19 | 2024-04-19 | 电子科技大学 | 一种基于异构半监督联邦学习的家庭特征识别方法 |
CN117910539B (zh) * | 2024-03-19 | 2024-05-31 | 电子科技大学 | 一种基于异构半监督联邦学习的家庭特征识别方法 |
-
2023
- 2023-03-17 CN CN202310262721.9A patent/CN116416508A/zh active Pending
Cited By (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117076113A (zh) * | 2023-08-17 | 2023-11-17 | 重庆理工大学 | 一种基于联邦学习的工业异构设备多作业调度方法 |
CN117094381A (zh) * | 2023-08-21 | 2023-11-21 | 哈尔滨工业大学 | 一种兼顾高效通信和个性化的多模态联邦协同方法 |
CN117094381B (zh) * | 2023-08-21 | 2024-04-12 | 哈尔滨工业大学 | 一种兼顾高效通信和个性化的多模态联邦协同方法 |
CN117062132A (zh) * | 2023-10-12 | 2023-11-14 | 北京信息科技大学 | 兼顾时延与能耗的cf-uav智能传输信令交互方法 |
CN117062132B (zh) * | 2023-10-12 | 2024-01-09 | 北京信息科技大学 | 兼顾时延与能耗的cf-uav智能传输信令交互方法 |
CN117392483A (zh) * | 2023-12-06 | 2024-01-12 | 山东大学 | 基于增强学习的相册分类模型训练加速方法、系统及介质 |
CN117392483B (zh) * | 2023-12-06 | 2024-02-23 | 山东大学 | 基于增强学习的相册分类模型训练加速方法、系统及介质 |
CN117808125A (zh) * | 2024-02-29 | 2024-04-02 | 浪潮电子信息产业股份有限公司 | 模型聚合方法、装置、设备、联邦学习系统及存储介质 |
CN117808125B (zh) * | 2024-02-29 | 2024-05-24 | 浪潮电子信息产业股份有限公司 | 模型聚合方法、装置、设备、联邦学习系统及存储介质 |
CN117910539A (zh) * | 2024-03-19 | 2024-04-19 | 电子科技大学 | 一种基于异构半监督联邦学习的家庭特征识别方法 |
CN117910539B (zh) * | 2024-03-19 | 2024-05-31 | 电子科技大学 | 一种基于异构半监督联邦学习的家庭特征识别方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116416508A (zh) | 一种加快全局联邦学习模型收敛的方法及联邦学习系统 | |
CN112668128A (zh) | 联邦学习系统中终端设备节点的选择方法及装置 | |
CN109508812B (zh) | 一种基于深度记忆网络的航空器航迹预测方法 | |
WO2021017227A1 (zh) | 无人机轨迹优化方法、装置及存储介质 | |
CN113543176B (zh) | 基于智能反射面辅助的移动边缘计算系统的卸载决策方法 | |
CN113873022A (zh) | 一种可划分任务的移动边缘网络智能资源分配方法 | |
CN111629380B (zh) | 面向高并发多业务工业5g网络的动态资源分配方法 | |
CN111178486B (zh) | 一种基于种群演化的超参数异步并行搜索方法 | |
CN114492833A (zh) | 基于梯度记忆的车联网联邦学习分层知识安全迁移方法 | |
CN112116090A (zh) | 神经网络结构搜索方法、装置、计算机设备及存储介质 | |
CN113760511B (zh) | 一种基于深度确定性策略的车辆边缘计算任务卸载方法 | |
Dong et al. | Multi-exit DNN inference acceleration based on multi-dimensional optimization for edge intelligence | |
CN113778691B (zh) | 一种任务迁移决策的方法、装置及系统 | |
CN116390161A (zh) | 一种移动边缘计算中基于负载均衡的任务迁移方法 | |
CN115310360A (zh) | 基于联邦学习的数字孪生辅助工业物联网可靠性优化方法 | |
CN113887748B (zh) | 在线联邦学习任务分配方法、装置、联邦学习方法及系统 | |
Zhao et al. | Adaptive Swarm Intelligent Offloading Based on Digital Twin-assisted Prediction in VEC | |
CN112445617B (zh) | 一种基于移动边缘计算的负载策略选择方法及系统 | |
CN117255356A (zh) | 一种无线接入网中基于联邦学习的高效自协同方法 | |
CN116486192A (zh) | 一种基于深度强化学习的联邦学习方法及系统 | |
CN113722980A (zh) | 海洋浪高预测方法、系统、计算机设备、存储介质、终端 | |
Xiao et al. | Clustered federated multi-task learning with non-iid data | |
CN114371729B (zh) | 一种基于距离优先经验回放的无人机空战机动决策方法 | |
CN112396154A (zh) | 一种基于卷积神经网络训练的并行方法 | |
Fang et al. | Dependency-Aware Dynamic Task Offloading Based on Deep Reinforcement Learning in Mobile Edge Computing |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |