CN116910541A - 一种基于集群训练与梯度稀疏的联邦学习方法和装置 - Google Patents
一种基于集群训练与梯度稀疏的联邦学习方法和装置 Download PDFInfo
- Publication number
- CN116910541A CN116910541A CN202310792905.6A CN202310792905A CN116910541A CN 116910541 A CN116910541 A CN 116910541A CN 202310792905 A CN202310792905 A CN 202310792905A CN 116910541 A CN116910541 A CN 116910541A
- Authority
- CN
- China
- Prior art keywords
- training
- cluster
- round
- clusters
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 110
- 238000000034 method Methods 0.000 title claims abstract description 39
- 230000002776 aggregation Effects 0.000 claims abstract description 30
- 238000004220 aggregation Methods 0.000 claims abstract description 30
- 238000012360 testing method Methods 0.000 claims abstract description 11
- 230000005540 biological transmission Effects 0.000 claims description 7
- 230000006870 function Effects 0.000 claims description 7
- 239000011159 matrix material Substances 0.000 claims description 6
- 238000004590 computer program Methods 0.000 claims description 2
- 238000004891 communication Methods 0.000 abstract description 47
- 238000009826 distribution Methods 0.000 abstract description 20
- 238000004422 calculation algorithm Methods 0.000 abstract description 18
- 230000008569 process Effects 0.000 description 9
- 238000010586 diagram Methods 0.000 description 7
- 230000000694 effects Effects 0.000 description 5
- 238000002474 experimental method Methods 0.000 description 5
- 238000013527 convolutional neural network Methods 0.000 description 3
- 230000007423 decrease Effects 0.000 description 3
- 230000007547 defect Effects 0.000 description 3
- 238000011156 evaluation Methods 0.000 description 3
- 238000011144 upstream manufacturing Methods 0.000 description 3
- 238000010801 machine learning Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 239000012141 concentrate Substances 0.000 description 1
- 230000007786 learning performance Effects 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000005192 partition Methods 0.000 description 1
- 238000000926 separation method Methods 0.000 description 1
- 238000000638 solvent extraction Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
- 239000002699 waste material Substances 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Bioethics (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Medical Informatics (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Databases & Information Systems (AREA)
- Computer Hardware Design (AREA)
- Computer Security & Cryptography (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种基于集群训练与梯度稀疏的联邦学习方法和装置,所述方法包括:通过客户端获取本地的数据集;将客户端随机划分为N个集群,获取每个集群的客户端节点;将中央服务器初始化后的全局模型,发送给N个集群的头节点;根据本地的数据集对接收的全局模型进行多轮训练,更新模型参数;在N个集群完成一轮训练后,通过每个集群的尾结点将更新后的模型参数发送至中央服务器进行梯度聚合,并下载聚合后的模型参数,传输给下一轮N个集群的头节点继续训练,直到到达预设定的迭代轮次;输出联邦学习训练出的模型以及测试精度;本发明使得通信效率得到提升的同时,并且相较于FedAvg算法,在数据非独立同分布的情况下有更高的效率。
Description
技术领域
本发明涉及一种基于集群训练与梯度稀疏的联邦学习方法和装置,属于机器学习技术领域。
背景技术
联邦学习[1](Federated Learning)是一种分布式机器学习技术,旨在训练一个全局模型而无需将数据集中在一个中央服务器上。在联邦学习中,参与者(设备或数据中心)使用本地数据集训练自己的本地模型,然后将本地模型的参数更新发送给中央服务器。中央服务器聚合这些本地模型参数更新,以更新全局模型。该过程可以重复多次,直到全局模型收敛或达到预定的停止标准。
联邦学习的主要优点是能够在不暴露用户数据的情况下训练模型,从而保护用户隐私。此外,由于参与者的本地数据集通常反映了不同的特征和分布,因此联邦学习还可以提高模型的泛化能力和鲁棒性。
在传统的联邦学习中,存在着数据不平衡、通信效率低下、隐私安全缺陷等问题。其中数据不平衡问题指的是在实际的联邦学习场景中,参与训练的客户端数据来源不同。例如,在医疗领域中,各个医院的数据来源可能不同,例如不同病人的数据分布、不同医生的诊断方式等,导致各个医院的数据分布不同,这就导致了Non-IID问题。在这种情况下,传统的联邦学习算法收敛速度会变慢,甚至可能无法收敛。而在传统算法的联邦学习场景假设中,所有的客户端的数据是满足独立同分布的,即各个客户端的数据分布相同并且相互独立,这使得最后的准确率接近于数据集中式训练(Centralizedtraining,CL)。
通常情况下,由于CL集中了所有的数据到一个中心服务器进行训练,可以充分利用所有的数据来更新模型,所以CL可以获得更高的准确率。但是由于隐私和安全考虑,很多情况下并不适合将所有数据集中到一个中心服务器进行训练。而在联邦学习中,数据被分散在不同的设备上进行训练,每个设备只能用自己本地的数据来更新模型,因此会导致模型性能下降。
Non-IID场景对联邦学习算法有两个主要的影响:
1.影响本地模型更新。由于各个客户端的数据分布不同,因此每个客户端训练出来的模型效果也不同。在传统联邦学习算法中,每个客户端更新上一轮的全局梯度方向有较大偏差。这种偏差可能会导致全局模型在训练过程中出现较大的震荡,甚至无法收敛。
影响全局模型聚合。各个客户端数据量的差距可能会导致全局模型更加关注某些客户端的数据,从而使得大数据量的客户端在聚合时权重占比较大,同时也会导致小数据量的客户端对全局模型的更新贡献占比小,从而影响全局模型的性能。
如图1所示,对比了数据集中式训练CL与在联邦学习场景中FedAvg与单客户端训练的梯度方向。其中N为联邦学习参与方的数量;w0表示初始梯度;表示Client1在只进行T轮本地更新后的梯度方向;/>表示N个参与方在FedAvg算法经过T轮模型参数聚合后的全局梯度方向;/>与/>类似,表示Client n在仅本地更新T轮后的梯度方向。从图中可以看出,IID场景下,由于各客户端数据分布情况差距小,FedAvg聚合得到的全局梯度方向与CL方法相似。而在Non-IID场景下,由于各客户端数据分布的异构性,导致FedAvg聚合的全局模型无法很好地适应各个客户端上的数据,因此FedAvg与CL方法之间表现出较大的差距。
联邦学习的通信分为上行链路和下行链路。上行链路通常指参与方将本地模型参数上传至中央服务器,下行链路则是服务器将每轮聚合后的全局模型参数发送给参与方。通常情况下,上行链路与下行链路具有不同的通信状态。
参与方需要频繁地进行上行链路通信以上传本轮本地模型的参数,在这个过程中,需要保证上传的数据量尽可能小,以减少网络拥塞和带宽浪费。由于参与方的计算能力和数据量不同,导致参与方上传本地模型参数的时刻也不同,会出现一些参与方上传完毕后需要等待其他参与方完成上传,才能进行下一轮模型更新。上行链路通信频率过高也会导致中央服务器通信负载过大,FedAvg相较于FedSGD算法减少了数倍的通信频率,同时也使得模型收敛的速度成倍增加。
下行链路通常发生在较长的时间间隔内,因为要等待中央服务器聚合完才能进行下行链路通信。下行链路通信频率相较于上行链路通信较低,但是传输的数据量较大,因此需要确保数据传输的准确性和完整性。
除此之外,为了保证通信的安全性,联邦学习还需要采取一些安全措施,例如差分隐私机制对原始数据添加噪声、同态加密时对原始数据进行扰动和转化以实现计算和保护隐私等操作,都会增加一定传输的数据量。
在传统的联邦学习中,存在着数据不平衡、通信效率低下、隐私安全缺陷等问题。
发明内容
本发明的目的在于克服现有技术中的不足,提供一种基于集群训练与梯度稀疏的联邦学习方法,使得通信效率得到提升的同时,并且相较于FedAvg算法,在数据非独立同分布的情况下有更高的效率。
为达到上述目的,本发明是采用下述技术方案实现的:
第一方面,本发明提供了一种基于集群训练与梯度稀疏的联邦学习方法,适用于中央服务器和多个客户端,所述方法包括:
通过客户端获取本地的数据集;
将客户端随机划分为N个集群,获取每个集群的客户端节点,包括头结点、尾结点和中继节点;
将中央服务器初始化后的全局模型,发送给N个集群的头节点;
在N个集群的头节点中,根据本地的数据集对接收的全局模型进行多轮训练,更新模型参数;
在N个集群完成一轮训练后,通过每个集群的尾结点将更新后的模型参数发送至中央服务器进行梯度聚合,并下载聚合后的模型参数,传输给下一轮N个集群的头节点继续训练,直到到达预设定的迭代轮次;
输出联邦学习训练出的模型以及测试精度。
进一步的,所述根据本地的数据集对接收的全局模型进行多轮训练,在奇数轮时,头节点接收上一轮由中央服务器聚合的N个集群的梯度,头节点本地训练的公式如下:
其中,表示奇数轮头节点更新的本地模型参数;w*r-1表示上一轮中央服务器聚合后进行稀疏操作的全局梯度,其中r表示当前训练轮次;η与F分别表示学习率和损失函数;/>表示头节点客户端上的训练数据,b表示本地训练的轮数。
进一步的,所述模型参数在传输前,先通过客户端对本地模型参数做梯度稀疏操作,其中,头结点做梯度稀疏操作时,公式如下:
表示头节点对本地更新模型参数/>进行Sparsek()操作后得到的参数。
进一步的,所述Sparsek()操作包括如下步骤:
对梯度信息进行top-K操作,取前K个绝对值最大的梯度;
将top-K操作之后的稠密矩阵转化为稀疏矩阵。
进一步的,在集群中,所述客户端按照设定顺序进行串行训练,在奇数轮训练时,头节点训练完本地模型并做topk()操作后将把模型参数/>传输给集群内中继节点以此类推,直到尾节点/>完成本地更新执行Sparsek()函数,则集群n得到了即w* n,则参与该轮集群内训练的中继节点与尾节点本地更新公式如下:
进一步的,所述通过每个集群的尾结点将更新后的模型参数发送至中央服务器进行梯度聚合,包括:
通过每个集群的尾节点向中央服务器Server传输代表集群n的模型参数w* n,Server收到w*={w* 1,w* 2,…,w* N}进行模型聚合操作得到wr,公式如下:
其中r表示当前轮次。
进一步的,在模型聚合操作之后,中央服务器将对wr进行Sparsek()操作得到ω*r,并将ω*r传输给下一轮N个集群的头节点。
进一步的,所述全局模型在完成一轮训练后,需要将全局模型传递给与上一轮不同的头节点,并且在新的一轮训练中,参与集群内训练的中继节点和尾节点也与上一轮不同,且互斥。
第二方面,本发明提供一种基于集群训练与梯度稀疏的联邦学习装置,包括:
获取模块,用于通过客户端获取本地的数据集;
划分模块,用于将客户端随机划分为N个集群,获取每个集群的客户端节点,包括头结点、尾结点和中继节点;
初始化模块,用于将中央服务器初始化后的全局模型,发送给N个集群的头节点;
训练模块,用于在N个集群的头节点中,根据本地的数据集对接收的全局模型进行多轮训练,更新模型参数;
聚合模块,用于在N个集群完成一轮训练后,通过每个集群的尾结点将更新后的模型参数发送至中央服务器进行梯度聚合,并下载聚合后的模型参数,传输给下一轮N个集群的头节点继续训练,直到到达预设定的迭代轮次;
输出模块,用于输出联邦学习训练出的模型以及测试精度。
第三方面,本发明提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现前述任一项所述方法的步骤。
与现有技术相比,本发明所达到的有益效果:
本发明提供一种基于集群训练与梯度稀疏的联邦学习方法和装置,利用集群训练的思想来减小非独立同分布(Non-IID)数据对模型性能的影响,同时减小中央服务器的通信负载,并采用梯度稀疏的思想来减少客户端与客户端之间、客户端与中央服务器之间的通信消耗,从而提升通信效率,经过在MINST数据集上的实验表明,我们提出的联邦学习框架能在联邦学习的场景中高效训练模型。
附图说明
图1是本发明背景技术提供的提供的现有技术梯度更新方向示意图;
图2是本发明实施例提供的FedAvg与FedOES(s=0.5)的收敛情况示意图;
图3是本发明实施例提供的FedOES在各稀疏率下的收敛情况示意图;
图4是本发明实施例提供的传统的联邦学习架构示意图;
图5是本发明实施例提供的基于集群训练的奇偶轮联邦学习架构示意图;
图6是本发明实施例提供的各客户端MNIST数据集的标签分布示意图。
具体实施方式
下面结合附图对本发明作进一步描述。以下实施例仅用于更加清楚地说明本发明的技术方案,而不能以此来限制本发明的保护范围。
实施例1
本实施例介绍一种基于集群训练与梯度稀疏的联邦学习方法,适用于中央服务器和多个客户端,所述方法包括:
通过客户端获取本地的数据集;
将客户端随机划分为N个集群,获取每个集群的客户端节点,包括头结点、尾结点和中继节点;
将中央服务器初始化后的全局模型,发送给N个集群的头节点;
在N个集群的头节点中,根据本地的数据集对接收的全局模型进行多轮训练,更新模型参数;
在N个集群完成一轮训练后,通过每个集群的尾结点将更新后的模型参数发送至中央服务器进行梯度聚合,并下载聚合后的模型参数,传输给下一轮N个集群的头节点继续训练,直到到达预设定的迭代轮次;
输出联邦学习训练出的模型以及测试精度。
本实施例提供的基于集群训练与梯度稀疏的联邦学习方法,其应用过程具体涉及如下步骤:
基于分轮集群训练的联邦学习算法(Federated Learning with odd-even roundand Gradient Sparsification,FedOES)的核心思想是将参与训练的客户端划分成若干集群,集群内部按奇偶轮次的顺序选择部分客户端参与集群内部的串行训练,每次传输都会对梯度进行一定的稀疏操作,从而降低通信数据量。中央服务器每轮拿到各个集群产出的梯度信息进行聚合。在这个过程中,联邦学习的通信效率得到了提高,并且基于集群训练的方法对Non-IID的联邦学习场景有较强的鲁棒性。
参与联邦学习的整个流程的角色有两种,分别为客户端和中央服务器。其中客户端的职责是根据本地的数据集对接收的模型进行训练,从而更新模型参数。中央服务器的职责是接收联邦学习上行链路中传输的模型,进行聚合操作。数据分布情况会对联邦学习效果产生重要影响,在IID场景中,客户端之间的数据分布相似;而在Non-IID场景中,客户端之间的数据分布差异较大。本实施例研究了IID与Non-IID两种数据划分方式在FedOES算法中的表现。
FedOES联邦学习架构可以被分为三个步骤,分别为初始化、集群训练、模型聚合。
为了更好地利用联邦学习中多个客户端的数据资源,以提高模型的性能和泛化能力,本实施例将客户端随机划分为N个集群,假设每个集群中的客户端数量为偶数2k,那么一个集群中的客户端节点可记为
一个集群的节点中有三种角色,分别为头节点(head node)、尾节点(tail node)和中继节点(relaying node)。集群内采用分轮训练,本实施例讨论分奇偶轮的训练情况。由于训练过程被分为奇偶轮,所以集群中分别有两个头节点、两个尾节点以及若干中继节点。在集群n中,头节点用表示,尾节点用/>表示;剩下的中继节点被标记为特别地,参与奇数轮训练的头节点被标记为/>尾节点被标记为/>参与偶数轮训练的头节点被标记为/>尾节点被标记为/>
下面对FedOES的流程进行详细分析。如图2所示为FedOES的算法架构图,如图4所示为FedOES的算法伪代码。
初始化:中央服务器Server初始化全局模型w0,发送给N个集群的头节点
集群内训练:以集群n中的奇数轮为例,头节点接收上一轮由中央服务器Server聚合的N个集群的梯度。头节点本地训练的公式如下:
其中表示奇数轮头节点更新的本地模型参数;w*r-1表示上一轮中央服务器聚合后进行稀疏操作的全局梯度,其中r表示当前训练轮次;η与F分别表示学习率和损失函数;表示头节点客户端上的训练数据,b表示本地训练的轮数。
为了减少模型参数的传输通信消耗,每个客户端要对本地模型参数做梯度稀疏操作,以头节点为例,公式如下:
表示头节点对本地更新模型参数/>进行Sparsek()操作后得到的参数,Sparsek()会使得模型参数占用的内存更小,有利于模型参数的传输。
本实施例的Sparsek()函数分为两步操作:(1)对梯度信息进行top-K操作,取前K个绝对值最大的梯度。(2)将top-K操作之后的稠密矩阵转化为稀疏矩阵。
由于在集群中,客户端是按照一定的顺序串行训练的。以奇数轮为例,在头节点训练完本地模型并做topk()操作后将把模型参数/>传输给集群内中继节点/>以此类推,直到尾节点/>完成本地更新执行Sparsek()函数,则集群n得到了/>即w* n。参与该轮集群内训练的中继节点与尾节点本地更新公式如下:
模型聚合:当N个集群完成本轮训练后,每个集群的尾节点将向中央服务器Server传输代表集群n的模型参数w* n。Server收到w*={w* 1,w* 2,…,w* N}进行模型聚合操作得到wr,公式如下:
其中r表示当前轮次。模型聚合操作之后,中央服务器将对wr进行Sparek()操作得到ω*r,并将ω*r传输给下一轮N个集群的头节点。特别地,由于本实施例提出的联邦学习架构采用分轮训练的方式,每轮参与集群内训练的客户端不同,因此需要将全局模型传递给与上一轮不同的头节点,并且在新的一轮训练中,参与集群内训练的中继节点和尾节点也与上一轮不同,且互斥。这一举措是为了确保每个节点能够在分轮训练的模式下,避免在相邻的轮次中重复参与训练。
模型的应用步骤:
第一步:各客户端获取本地的MNIST数据集图像。
第二步:选择合适该任务的卷积神经网络。
第三步:在每一轮中,各集群将模型参数发送给中央服务器进行梯度聚合并下载聚合后的模型参数,直到到达预设定的迭代轮次。
第四步:输出联邦学习训练出的模型以及测试精度。
在该方法实施有以下效果:
1.在数据独立同分布的场景下,准确率接近集中式训练;在数据非独立同分布的场景下,准确率高于传统联邦学习算法FedAvg。
2.相较于FedAvg算法,减少了大量的通信成本消耗。
本实施例的实验以联邦学习每轮模型聚合后的测试准确率以及每轮的通信消耗作为评价指标,下面给出准确率Accuracy和通信消耗Communication cost的定义。
测试准确率是用来衡量联邦学习算法性能最关键的指标,如公式所示:
其中,yn表示样本n的标签,表示模型对yn的预测值。/>则表示模型预测结果与真实标签相同的样本数。
假设有N个参与者,M个轮次的联邦学习训练过程,则可以用以下公式计算通信消耗:
公式(4-7)和公式(4-8)分别代表上传数据量和下载数据量,其中ni,j表示第i个参数者在第j轮上传的数据量,di,j表示第i个参数者在第j轮下载的数据量。
本实验考虑了分别在数据IID与Non-IID的情况下,验证基于分轮集群训练的联邦学习算法FedOES的学习性能。本实验将最常用的联邦学习算法FedAvg与FedOES在不同稀疏率s下的表现作对比分析,其中s代表选择的绝对值最大的权重张量中参数的占比,即top-K操作选择的占比。如图6所示,是本实验设置的Non-IID场景下的数据分布情况。
(1)FedAvg与FedOES(s=0.5)的实验情况
如图2和表1所示,展示了FedAvg与FedOES分别在数据IID与Non-IID下的性能对比。在每一轮的训练中,所有的客户端都参与FedAvg训练,而在FedOES中,每轮只有一半的客户端参与训练。实验结果显示,FedOES可以在较少的通信轮次达到收敛。当稀疏率s取0.5时,FedOES在第5轮、第10轮、第20轮的通信轮次中,无论是在IID还是Non-IID场景下,准确率显著高于FedAvg。经过多轮通信,FedOES的准确率略高于FedAvg,分别达到了98.85%和98.10%,高于FedAvg的98.83%和97.97%。
表1 FedAvg与FedOES(s=0.5)在不同轮次的表现
在本实验中,FedAvg的上游通信是指100个参与训练的客户端将本地更新的模型参数上传给中央服务器Server,而下游通信是指100个参与方下载每轮聚合的全局模型参数。而在FedOES中,上游通信是指各集群内部参与当前轮次训练的客户端将模型参数传递给集群训练内部的下一个客户端节点,以及各集群尾节点将模型参数传递给中央服务器Server。下游通信是指各集群中参与当前轮次训练的头节点下载上一轮的全局模型参数。因此,在FedOES中,以本实验设置为例,每轮有10个客户端将模型参数上传至Server,新一轮有10个客户端下载Server聚合的模型参数,极大地减小了Server的通信负载。
表中展示了FedAvg与FedOES分别在第5轮、第10轮、第20轮、第50轮、第100轮累计的通信开销。与FedAvg相比,FedOES的上游通信消耗降低了1.45倍,下游通信消耗降低了11.31倍。
(2)FedOES在不同稀疏率下的表现
如图3所示,FedOES在稀疏率s为0.3与0.2时,联邦学习模型在轮次较高时收敛的过程并不是完全平稳的,存在一定的波动。如表2所示,可以看出在不同通信轮次下,FedOES算法在不同稀疏率s下的准确率和通信消耗(upload和download)情况。与IID数据分布相比,在Non-IID数据分布下,准确率相对较低。随着通信轮次的增加,准确率也逐渐提高。同时,在相同的通信轮次、数据分布下,随着稀疏率s的下降,上传和下载的通信消耗逐渐下降,模型精度也逐渐下降。例如,在第100轮中,当稀疏率s为0.5时,FedOES算法在IID数据分布下的准确率为98.85%,上传和下载的通信消耗分别为365.40MB和73.08MB;而当稀疏率从0.5降至0.2时,在IID数据分布下的准确率降至97.82%,上传和下载的通信消耗也将至149.00MB和29.80MB。
表2FedOES各轮次不同稀疏率下的表现
综上所述,本实施例提出的FedOES在稀疏率s为0.5的情况下相较于FedAvg每轮的总通信量减少了3.1倍。在稀疏率s为0.4、0.3、0.2的情况下,总通信量分别减少了5.1、6.8、10.1倍,但是测试精度都有了不同程度的下降,主要原因是稀疏度高导致部分梯度信息丢失,上传的梯度变得更加分散和不连续,最终导致了模型性能的下降。因此为了获得更好的性能,选择合适的稀疏率s非常重要。
下面结合一个优选实施例,对上述实施例中涉及到的内容进行说明。
在这项研究中,我们使用了MNIST数据集进行验证,其包含70000个手写数字图像,其中60000个用于训练,10000个用于测试。为评估效果,我们设计了两种数据分离方法:IID和Non-IID。采用了100个客户端参与训练,将客户端分为10个集群,每个集群中包含10个客户端,每个客户端随机分配了600个训练示例。参与联邦学习训练的客户端训练样本划分如图4-图6所示。本研究采用了包含两个卷积层和两个池化层的卷积神经网络(CNN)。本实施例的参数如表3所示。
表3.实验参数
参数 | 值 |
客户端数量 | 100 |
通信轮数 | 100 |
本地批量大小 | 10 |
客户端每轮本地训练轮数 | 5 |
学习率 | 0.01 |
动量 | 0.5 |
集群数量 | 10 |
每个集群中客户端数量 | 10 |
FedOES的流程:实施例1的前述步骤中给出了FedOES的详细过程;
评价指标:实施例1的前述步骤中给出了本实施例的评价指标;
本实施例利用集群训练的思想来减小非独立同分布(Non-IID)数据对模型性能的影响,同时减小中央服务器的通信负载,并采用梯度稀疏的思想来减少客户端与客户端之间、客户端与中央服务器之间的通信消耗,从而提升通信效率。经过在MINST数据集上的实验表明,我们提出的联邦学习框架能在联邦学习的场景中高效训练模型。
以上所述仅是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明技术原理的前提下,还可以做出若干改进和变形,这些改进和变形也应视为本发明的保护范围。
Claims (10)
1.一种基于集群训练与梯度稀疏的联邦学习方法,适用于中央服务器和多个客户端,其特征在于,所述方法包括:
通过客户端获取本地的数据集;
将客户端随机划分为N个集群,获取每个集群的客户端节点,包括头结点、尾结点和中继节点;
将中央服务器初始化后的全局模型,发送给N个集群的头节点;
在N个集群的头节点中,根据本地的数据集对接收的全局模型进行多轮训练,更新模型参数;
在N个集群完成一轮训练后,通过每个集群的尾结点将更新后的模型参数发送至中央服务器进行梯度聚合,并下载聚合后的模型参数,传输给下一轮N个集群的头节点继续训练,直到到达预设定的迭代轮次;
输出联邦学习训练出的模型以及测试精度。
2.根据权利要求1所述的基于集群训练与梯度稀疏的联邦学习方法,其特征在于,所述根据本地的数据集对接收的全局模型进行多轮训练,在奇数轮时,头节点接收上一轮由中央服务器聚合的N个集群的梯度,头节点本地训练的公式如下:
其中,表示奇数轮头节点更新的本地模型参数;w*r-1表示上一轮中央服务器聚合后进行稀疏操作的全局梯度,其中r表示当前训练轮次;η与F分别表示学习率和损失函数;/>表示头节点客户端上的训练数据,b表示本地训练的轮数。
3.根据权利要求2所述的基于集群训练与梯度稀疏的联邦学习方法,其特征在于,所述模型参数在传输前,先通过客户端对本地模型参数做梯度稀疏操作,其中,头结点做梯度稀疏操作时,公式如下:
表示头节点对本地更新模型参数/>进行Sparsek()操作后得到的参数。
4.根据权利要求3所述的基于集群训练与梯度稀疏的联邦学习方法,其特征在于,所述Sparsek()操作包括如下步骤:
对梯度信息进行top-K操作,取前K个绝对值最大的梯度;
将top-K操作之后的稠密矩阵转化为稀疏矩阵。
5.根据权利要求4所述的基于集群训练与梯度稀疏的联邦学习方法,其特征在于,在集群中,所述客户端按照设定顺序进行串行训练,在奇数轮训练时,头节点训练完本地模型并做topk()操作后将把模型参数/>传输给集群内中继节点/>以此类推,直到尾节点完成本地更新执行Sparsek()函数,则集群n得到了/>即w* n,则参与该轮集群内训练的中继节点与尾节点本地更新公式如下:
6.根据权利要求5所述的基于集群训练与梯度稀疏的联邦学习方法,其特征在于,所述通过每个集群的尾结点将更新后的模型参数发送至中央服务器进行梯度聚合,包括:
通过每个集群的尾节点向中央服务器Server传输代表集群n的模型参数w* n,Server收到w*={w* 1,w* 2,…,w* N}进行模型聚合操作得到wr,公式如下:
其中r表示当前轮次。
7.根据权利要求6所述的基于集群训练与梯度稀疏的联邦学习方法,其特征在于,在模型聚合操作之后,中央服务器将对wr进行Sparsek()操作得到ω*r,并将ω*r传输给下一轮N个集群的头节点。
8.根据权利要求7所述的基于集群训练与梯度稀疏的联邦学习方法,其特征在于,所述全局模型在完成一轮训练后,需要将全局模型传递给与上一轮不同的头节点,并且在新的一轮训练中,参与集群内训练的中继节点和尾节点也与上一轮不同,且互斥。
9.一种基于集群训练与梯度稀疏的联邦学习装置,其特征在于,包括:
获取模块,用于通过客户端获取本地的数据集;
划分模块,用于将客户端随机划分为N个集群,获取每个集群的客户端节点,包括头结点、尾结点和中继节点;
初始化模块,用于将中央服务器初始化后的全局模型,发送给N个集群的头节点;
训练模块,用于在N个集群的头节点中,根据本地的数据集对接收的全局模型进行多轮训练,更新模型参数;
聚合模块,用于在N个集群完成一轮训练后,通过每个集群的尾结点将更新后的模型参数发送至中央服务器进行梯度聚合,并下载聚合后的模型参数,传输给下一轮N个集群的头节点继续训练,直到到达预设定的迭代轮次;
输出模块,用于输出联邦学习训练出的模型以及测试精度。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于:该程序被处理器执行时实现权利要求1~8任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310792905.6A CN116910541A (zh) | 2023-06-30 | 2023-06-30 | 一种基于集群训练与梯度稀疏的联邦学习方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310792905.6A CN116910541A (zh) | 2023-06-30 | 2023-06-30 | 一种基于集群训练与梯度稀疏的联邦学习方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116910541A true CN116910541A (zh) | 2023-10-20 |
Family
ID=88364007
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310792905.6A Pending CN116910541A (zh) | 2023-06-30 | 2023-06-30 | 一种基于集群训练与梯度稀疏的联邦学习方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116910541A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117196069A (zh) * | 2023-11-07 | 2023-12-08 | 中电科大数据研究院有限公司 | 联邦学习方法 |
CN117932544A (zh) * | 2024-01-29 | 2024-04-26 | 福州城投新基建集团有限公司 | 基于多源传感器数据融合的预测方法、装置和存储介质 |
-
2023
- 2023-06-30 CN CN202310792905.6A patent/CN116910541A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117196069A (zh) * | 2023-11-07 | 2023-12-08 | 中电科大数据研究院有限公司 | 联邦学习方法 |
CN117196069B (zh) * | 2023-11-07 | 2024-01-30 | 中电科大数据研究院有限公司 | 联邦学习方法 |
CN117932544A (zh) * | 2024-01-29 | 2024-04-26 | 福州城投新基建集团有限公司 | 基于多源传感器数据融合的预测方法、装置和存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116910541A (zh) | 一种基于集群训练与梯度稀疏的联邦学习方法和装置 | |
CN106062786B (zh) | 用于训练神经网络的计算系统 | |
CN106297774B (zh) | 一种神经网络声学模型的分布式并行训练方法及系统 | |
CN113705610B (zh) | 一种基于联邦学习的异构模型聚合方法和系统 | |
Jiang et al. | Fedmp: Federated learning through adaptive model pruning in heterogeneous edge computing | |
WO2016119429A1 (zh) | 用于神经网络中训练参数集的系统和方法 | |
CN113469373B (zh) | 基于联邦学习的模型训练方法、系统、设备及存储介质 | |
CN112637883B (zh) | 电力物联网中对无线环境变化具有鲁棒性的联邦学习方法 | |
CN113206887A (zh) | 边缘计算下针对数据与设备异构性加速联邦学习的方法 | |
Yoon et al. | Bitwidth heterogeneous federated learning with progressive weight dequantization | |
Arouj et al. | Towards energy-aware federated learning on battery-powered clients | |
CN112235062A (zh) | 一种对抗通信噪声的联邦学习方法和系统 | |
Jiang et al. | Computation and communication efficient federated learning with adaptive model pruning | |
CN115829055A (zh) | 联邦学习模型训练方法、装置、计算机设备及存储介质 | |
CN114925854A (zh) | 一种基于梯度相似性度量的联邦学习节点选择方法及系统 | |
Liu et al. | FedAGL: A communication-efficient federated vehicular network | |
Li et al. | AFedAvg: Communication-efficient federated learning aggregation with adaptive communication frequency and gradient sparse | |
Zhang et al. | FedSL: A Communication Efficient Federated Learning With Split Layer Aggregation | |
Cai et al. | High-efficient hierarchical federated learning on non-IID data with progressive collaboration | |
Li et al. | FedOES: An efficient federated learning approach | |
CN117196058A (zh) | 一种基于节点贡献聚类的公平联邦学习方法 | |
CN116010832A (zh) | 联邦聚类方法、装置、中心服务器、系统和电子设备 | |
CN117313832A (zh) | 基于双向知识蒸馏的联合学习模型训练方法、装置及系统 | |
CN115345320A (zh) | 一种在分层联邦学习框架下实现个性化模型的方法 | |
Beitollahi et al. | DSFL: Dynamic sparsification for federated learning |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |