CN117151208B - 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质 - Google Patents
基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质 Download PDFInfo
- Publication number
- CN117151208B CN117151208B CN202310985134.2A CN202310985134A CN117151208B CN 117151208 B CN117151208 B CN 117151208B CN 202310985134 A CN202310985134 A CN 202310985134A CN 117151208 B CN117151208 B CN 117151208B
- Authority
- CN
- China
- Prior art keywords
- gradient
- global
- representing
- learning rate
- neural network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 77
- 238000003860 storage Methods 0.000 title claims abstract description 8
- 238000012549 training Methods 0.000 claims abstract description 38
- 230000006870 function Effects 0.000 claims abstract description 23
- 238000003062 neural network model Methods 0.000 claims description 50
- 230000001360 synchronised effect Effects 0.000 claims description 20
- 230000003044 adaptive effect Effects 0.000 claims description 16
- 230000008569 process Effects 0.000 claims description 11
- 238000004590 computer program Methods 0.000 claims description 9
- 238000011156 evaluation Methods 0.000 abstract description 5
- 230000002776 aggregation Effects 0.000 description 15
- 238000004220 aggregation Methods 0.000 description 15
- 238000010586 diagram Methods 0.000 description 9
- 238000004891 communication Methods 0.000 description 8
- 238000009826 distribution Methods 0.000 description 7
- 238000004422 calculation algorithm Methods 0.000 description 5
- 230000002596 correlated effect Effects 0.000 description 4
- 230000003111 delayed effect Effects 0.000 description 4
- 230000000694 effects Effects 0.000 description 4
- 238000002474 experimental method Methods 0.000 description 4
- 238000012545 processing Methods 0.000 description 4
- 238000013459 approach Methods 0.000 description 3
- 238000004364 calculation method Methods 0.000 description 3
- 238000010801 machine learning Methods 0.000 description 3
- 230000032683 aging Effects 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 2
- 230000001934 delay Effects 0.000 description 2
- 230000007246 mechanism Effects 0.000 description 2
- 230000009467 reduction Effects 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 230000002411 adverse Effects 0.000 description 1
- 230000002238 attenuated effect Effects 0.000 description 1
- 230000000875 corresponding effect Effects 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000013499 data model Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 230000009977 dual effect Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/098—Distributed learning, e.g. federated learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质,为了解决异步联邦学习面临着双重挑战:陈旧性问题和数据集不平衡问题,本发明分别在中心服务器和工作节点解决如上问题,中心服务器接收完K个梯度后,首先进行无偏梯度估计,并实施一种基于余弦相似度的新型评估方法,以衡量延迟梯度的陈旧度;同时进一步调整学习速率,更新并广播模型参数和迭代次数。对于数据集不平衡问题,工作节点引入了一个类平衡损失函数,可以处理异质性数据对于模型训练的影响,本发明根据延时程度自适应调整学习速率,提高了模型的预测精度。
Description
技术领域
本发明属于数据安全领域,尤其涉及一种基于自适应学习率的异步联邦学习参数更新方法、设备及系统。
背景技术
近年来,随着移动和边缘设备已经广泛采用,并为各种应用生成了大量有价值的数据。这些设备也增加了机器学习的需求,以实现个性化和低延迟的AI应用。然而,由于隐私和带宽限制,集中式数据收集和模型训练是不可行的。因此,针对这些问题,谷歌提出了联邦学习用于解决机器学习模型训练的数据需求与用户数据隐私保护之间的矛盾。联邦学习已经成为一种新的范式,使得在大量边缘设备(客户端)之间进行协作机器学习成为可能,而无需共享他们的数据。联邦学习也可以用于用户数据必须保密或不能离开其原始环境的场景,例如在医疗和金融领域。
经典的联邦学习方法大多数是运行在同步的系统中,在每一次迭代中,中心服务器会随机抽取一些工作节点基于本地数据完成本地训练,工作节点将训练好的模型上传至中心服务器,随后中心服务器将收集到的模型参数进行聚合,再向每个工作节点发送更新后的模型。但是在设备异质性和网络异质性的情景中,经典的联邦学习方法面临着拖延者效应,会导致每一轮迭代的运行时间变长,所以联邦学习每一轮迭代的运行时间由最慢的学习者决定。
部分学者已经提出异步联合学习来解决这个问题,每个客户端独立地更新全局模型,这显示了更大的灵活性和可扩展性。在每一轮迭代中,完成本地训练的工作节点上传其模型参数,当中心服务器收到K个更新之后,中心服务器开始进行参数聚合。没有参加本轮聚合的工作节点继续完成其本地训练,等待参与下一轮的更新。异步联邦学习可以降低下一轮迭代中本地训练消耗的时间,从而缓解拖延者效应。
尽管K异步联邦学习方法具有以上优点,但是在实践中经常面临以下两个问题:1)延迟的模型梯度更新是基于陈旧模型进行的,因此延迟梯度相较于当前最新梯度具有一定的方向误差;2)由于多个工作节点上的数据类别分布通常不能服从独立同分布,这会造成不同工作节点的本地梯度更新方向均与中心服务器不一致,从而降低了模型的效应性,甚至会导致不收敛的问题。为了解决上述问题,现有的工作提出了基于两阶段训练策略的异步联邦学习方法,以加速训练并降低数据异质性的影响。但是该工作没有考虑到两阶段训练带来的巨大计算量和通信成本,同时现有工作中衡量梯度陈旧度的方法的策略是通过迭代滞后轮次或本地训练时间。显然,只有少数低延时的梯度会被聚合,大部分高延时的梯度将被过滤掉。
因此现有的技术需要一种能够既能有效缓解数据不平衡问题,又能解决延时梯度的异步联邦学习方法。
发明内容
发明目的:该发明旨在解决异步联邦学习中由于数据不平衡和延迟梯度导致的模型效用降低的问题。为此,本发明提出了一种基于自适应学习率的异步联邦学习参数更新方法,以解决异步联邦学习中的不平衡问题和陈旧问题。
在第一方面上,根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,用于中心服务器,包括
S110.中心服务器接收更新,更新包括由工作节点发出的梯度;
S120.中心服务器根据同步梯度估计全局无偏梯度;
S130.中心服务器根据全局无偏梯度计算延迟梯度的陈旧度;
S140.中心服务器根据陈旧度为延迟梯度调整学习率;
S150.中心服务器根据学习率更新全局神经网络模型;
S160.中心服务器将更新的全局神经网络模型的参数发出,更新的全局神经网络模型的参数由工作节点接收。
其中,同步梯度是工作节点依据最新的全局神经网络模型计算的梯度,延迟梯度是工作节点依据非最新的全局神经网络模型计算的梯度。
根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,步骤S10中,中心服务器接收的更新还包括迭代次数,中心服务器根据迭代次数达到预先定义的次数停止更新。
根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,步骤S10中,中心服务器将接收的更新加入队列,当队列的长度达到设定阈值,中心服务器执行步骤S20。
根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,所述步骤S110中还包括中心服务器将当前的迭代次数和当前的全局神经网络模型参数广播,广播由本地节点接收。
根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,估计全局无偏梯度由如下公式表示:
式中,g(wj)表示全局无偏梯度,n表示M个工作节点的本地样本总量,nm表示第m个工作节点的本地样本数量,g(wj,m,ξj,m)表示第m个工作节点的同步梯度,j表示当前更新轮次,wj,m表示第j轮第m个工作节点的神经网络模型的参数,ξj,m表示第j轮第m个工作节点的样本;
延迟梯度的陈旧度由如下公式表示:
式中,cos(Gt,Gt-τ)表示延迟梯度与全局无偏梯度的余弦相似性,用于表示梯度下降中的方向相似性,Gt表示全局无偏梯度g(wj),Gt-τ表示延迟梯度,∈表示超参数,s(τ)表示延迟梯度的陈旧度;
为延迟梯度调整学习率由如下公式表示:
式中,ητ表示调整之后的学习率,η0表示初始学习率,a表示陈旧度的阈值;
更新全局神经网络模型由如下公式表示:
wj+1表示更新后的全局神经网络模型,wj表示第j轮的全局神经网络模型,K表示参与当前更新的工作节点的个数,i表示第i个工作节点,j表示当前的第j个更新轮次,ηj,i表示第j轮第i个工作节点的学习率。
根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,工作节点的梯度更新的损失函数由如下公式表示:
式中,gi(x;w)表示第i类的logit,gy(x;w)是第y类的logit,γ是损失函数的超参数,i表示第i个类别,y表示第y个类别,表示第y类的实例数,/>表示第i类的实例数,c表示样本类别的总个数。
在第二方面上,根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,用于工作节点,包括
S210.工作节点接收全局神经网络模型的参数,全局神经网络模型的参数由中心服务器发出;
S220.工作节点依据全局神经网络模型的参数训练工作节点的本地模型;
S230.工作节点的本地模型进行梯度下降得到更新的参数;
S240.工作节点发出更新,更新包括梯度,工作节点依据最新的全局神经网络模型计算的梯度是同步梯度,工作节点依据非最新的全局神经网络模型计算的梯度是延迟梯度,其中,工作节点发出的同步梯度是由中心服务器接收的同步梯度,用于中心服务器估计全局无偏梯度,根据全局无偏梯度计算延迟梯度的陈旧度,根据陈旧度为延迟梯度调整学习率,根据学习率更新全局神经网络模型。
根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,所述步骤S230还包括工作节点根据更新次数达到预先定义的次数执行步骤S240,否则执行步骤S220。
根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,估计全局无偏梯度由如下公式表示:
式中,g(wj)表示全局无偏梯度,n表示M个工作节点的本地样本总量,nm表示第m个工作节点的本地样本数量,g(wj,m,ξj,m)表示第m个工作节点的同步梯度,j表示当前更新轮次,wj,m表示第j轮第m个工作节点的神经网络模型的参数,ξj,m表示第j轮第m个工作节点的样本;
延迟梯度的陈旧度由如下公式表示:
式中,cos(Gt,Gt-τ)表示延迟梯度与全局无偏梯度的余弦相似性,用于表示梯度下降中的方向相似性,Gt表示全局无偏梯度g(wj),Gt-τ表示延迟梯度,∈表示超参数,s(τ)表示延迟梯度的陈旧度;
为延迟梯度调整学习率由如下公式表示:
式中,ητ表示调整之后的学习率,η0表示初始学习率,a表示陈旧度的阈值;
更新全局神经网络模型由如下公式表示:
wj+1表示更新后的全局神经网络模型,wj表示第j轮的全局神经网络模型,K表示参与当前更新的工作节点的个数,i表示第i个工作节点,j表示当前的第j个更新轮次,ηj,i表示第j轮第i个工作节点的学习率。
根据本申请一些实施例的基于自适应学习率的异步联邦学习参数更新方法,工作节点的梯度由如下公式表示:
式中,gi(x;w)表示第i类的logit,gy(x;w)是第y类的logit,γ是损失函数的超参数,i表示第i个类别,y表示第y个类别,表示第y类的实例数,/>表示第i类的实例数,c表示样本类别的总个数。
本发明与现有技术相比,具有如下有益效果:
在第一方面上,本发明通过搭建基于不平衡数据分布的异步联邦学习算法,一方面可以协助多方共同学习一个准确且通用的神经网络模型,而无需公开和共享他们的本地用户数据集;另一方面本发明采取了一种双管齐下的方法,旨在分别解决客户端和服务器端的数据集不平衡和梯度陈旧性问题。本发明整合了一种新的评估方法,采用余弦相似度来衡量延迟梯度的陈旧性,进一步优化了服务器上的聚合算法,以提高异步联邦学习的性能。此外,还加入了一个类平衡损失函数来克服数据集不平衡问题,可以处理数据异质性的问题。这使得工作节点能够以一致的目标训练一个通用的分类器,而不考虑具体的类分布。从而提高了异步联邦学习训练速度的稳定度。
在第二方面上,本发明从梯度下降方向性的角度将延迟梯度的陈旧度进行了新的定义,现有计算延迟梯度陈旧度的方法认为延迟梯度的陈旧度和版本延迟呈正相关,本发明通过实验验证延迟梯度和同步梯度具有方向误差,但方向误差和版本延迟并不呈绝对的正相关,为此,本发明从梯度下降方向性的角度对陈旧度进行了新的定义,考虑了方向误差和版本延迟并不呈绝对的正相关,因此,本发明能够更好的利用延迟梯度促进模型收敛。
在第三方面,本发明解决了异步联邦学习面临着双重挑战:陈旧性问题和数据集不平衡问题,中心服务器接收完K个梯度后,首先进行无偏梯度估计,并实施一种基于余弦相似度的新型评估方法,以衡量延迟梯度的陈旧度;同时进一步调整学习速率,更新并广播模型参数和迭代次数。对于数据集不平衡问题,工作节点引入了一个类平衡损失函数,可以处理异质性数据对于模型训练的影响,本发明根据延时程度自适应调整学习速率,提高了模型的预测精度。
附图说明
图1为本发明实施例提供的一种基于自适应学习率的异步联邦学习参数更新方法流程图。
图2为本发明实施例提供的中心服务器端流程图。
图3为本发明实施例提供的工作节点端流程图。
图4为本发明实施例提供的基于加权聚合联邦学习的网络流量分类架构图。
图5为本发明实施例提供的不同联邦学习策略的实验对比图。
具体实施方式
下面将结合附图和技术方案,对本发明的实施过程进行详细描述。
实施例1:本发明涉及一种基于自适应学习率的异步联邦学习参数更新方法,还提出实现所述方法的相应的电子设备及可读存储介质。
本实施例的基于自适应学习率的异步联邦学习参数更新方法,用于中心服务器,所述方法包括以下步骤:
中心服务器初始化全局神经网络模型w0,学习率η,全局通信轮次T,并且初始化全局模型版本为version=0,参与K异步联邦学习的工作节点个数为K;
各种参数初始化完成后,中心服务器向工作节点分发神经网络模型,等待最快的K个工作节点发来梯度更新;
在第j轮全局迭代中,中心服务器接收到K个梯度更新。具体地说,接收到来自第i个节点的id和来自第i个节点的梯度g(wj,i,ξj,i),中心服务器在K个梯度中选择基于最新模型更新的梯度为同步梯度,则其余梯度则为陈旧梯度;
中心服务器根据同步梯度的本地样本量计算全局无偏梯度估计,具体为:
其中,g(wj)表示全局无偏梯度,n表示M个工作节点的本地样本总量,nm表示第m个工作节点的本地样本数量,g(wj,m,ξj,m)表示第m个工作节点的同步梯度,j表示当前更新轮次,wj,m表示第j轮第m个工作节点的神经网络模型的参数,ξj,m表示第j轮第m个工作节点的样本;
对于延迟梯度中心服务器计算其陈旧度,并且根据各个梯度的陈旧度为其赋予不同的学习率;
其中,中心服务器依据如下公式计算当前延迟梯度的陈旧度:
式中,Gt指的是上一步中计算得到的全局无偏梯度估计,Gt-τ是指陈旧梯度。cos(Gt,Gt-τ)表示延迟梯度与全局无偏梯度估计的余弦相似性,也即梯度下降中的方向相似性;∈表示超参数,可以根据不同的数据集或者训练任务进行调节,s(T)表示当前延迟梯度的陈旧性。
中心服务器根据K个梯度的陈旧度自适应调整学习率;
其中,中心服务器依据如下公式调整其学习率:
式中,ητ表示调整之后的学习率,η0表示初始学习率,a表示陈旧度的阈值,当陈旧度小于a时,无需调整学习率。
在全局无偏梯度估计、计算陈旧度和调整学习率完成后,更新当前的全局模型wj和全局模型版本version;
其中,依据如下公式进行全局模型的更新:
wj+1表示更新后的全局神经网络模型,wj表示第j轮的全局神经网络模型,K表示参与当前更新的工作节点的个数,i表示第i个工作节点,j表示当前的第j个更新轮次,ηj,i表示第j轮第i个工作节点的学习率。
本实施例的基于自适应学习率的异步联邦学习参数更新方法,用于工作节点,所述方法包括以下步骤:
工作节点接收来自中心服务器的发送的初始模型参数,模型版本version;
在本地使用类平衡损失函数进行训练,以克服本地数据集不平衡带来的负面影响;
其中,工作节点使用如下类平衡损失函数进行训练:
其中,式中,gi(x;w)表示第i类的logit,gy(x;w)是第y类的logit,γ是损失函数的超参数,i表示第i个类别,y表示第y个类别,表示第y类的实例数,/>表示第i类的实例数,c表示样本类别的总个数。该损失是softmax的一个无偏扩展,旨在补偿训练和测试之间的类别分布变化。以促进小类实例在训练中要求更大的gi(x;w),以克服测试中的特征偏差。
在本地训练t轮过后,工作节点将其训练得到的梯度g(wj,i,ξj,i)发送给中心服务器;等待来自中心服务器的更新;
利用更新后的权重进行下一轮训练。
本实施例的基于自适应学习率的异步联邦学习参数更新系统,包括中心服务器以及与中心服务器通信相连的多个工作节点,中心服务器与工作节点基于异步联邦学习机制进行参数聚合更新,所述中心服务器按照上述方法进行参数聚合更新,所述工作节点按照上述方法完成参数更新。
基于上述方法,一种用于在参数服务器端进行基于异步联邦学习的参数聚合更新的设备,所述设备包括:
存储器,存储有一个或多个计算机程序,所述一个或多个计算机程序被一个或多个处理器执行时,致使所述一个或多个处理器执行如本发明第一方面所述的参数聚合更新方法。
基于上述方法,一种用于在工作节点端进行基于异步联邦学习的参数聚合更新的设备,所述设备包括:
存储器,存储有一个或多个计算机程序,所述一个或多个计算机程序被一个或多个处理器执行时,致使所述一个或多个处理器执行如本发明第二方面所述的参数聚合更新方法。
本发明的有益效果:本发明通过搭建基于不平衡数据分布的异步联邦学习算法,一方面可以协助多方共同学习一个准确且通用的神经网络模型,而无需公开和共享他们的本地用户数据集;另一方面本发明采取了一种双管齐下的方法,旨在分别解决客户端和服务器端的数据集不平衡和梯度陈旧性问题。本发明整合了一种新的评估方法,采用余弦相似度来衡量延迟梯度的陈旧性,进一步优化了服务器上的聚合算法,以提高异步联邦学习的性能。此外,还加入了一个类平衡损失函数来克服数据集不平衡问题,可以处理数据异质性的问题。这使得工作节点能够以一致的目标训练一个通用的分类器,而不考虑具体的类分布。从而提高了异步联邦学习训练速度的稳定度。
实施例2:为了解决异步联邦学习中因为数据不平衡和延迟梯度导致的模型效用降低的问题,本发明提出一种基于自适应学习率的异步联邦学习参数更新方法,解决异步联邦学习中的不平衡问题和陈旧问题。该方法的中心服务器端包括如下步骤:
S1、中心服务器初始化模型,模型参数w0、初始学习率η0、全局迭代轮次T、超参数∈,γ以及初始化队列Q;
S2、中心服务器向连接的工作节点广播当前的全局通信轮次、当前的模型版本和最新的模型参数w0,等待来自工作节点的梯度更新;
S3、中心服务器与工作节点保持网络连接,接收来自工作节点的更新梯度g(wj,ξj),同时将接收到的梯度加入队列中;
S4、判断当前队列中是否接收到K个梯度更新,若当前接收到的更新数少于K个,继续接收来自工作节点的更新,等待更新最快的K个节点发送工作更新。如果队列中的更新数等于K,进行下一步;
S5、中心服务器在接收到的K个梯度中选择基于最新模型更新的梯度为同步梯度,则其余梯度则为陈旧梯度,中心服务器依据以下规则进行全局无偏梯度估计:
式中,g(wj)表示全局无偏梯度,n表示M个工作节点的本地样本总量,nm表示第m个工作节点的本地样本数量,g(wj,m,ξj,m)表示第m个工作节点的同步梯度,j表示当前更新轮次,wj,m表示第j轮第m个工作节点的神经网络模型的参数,ξj,m表示第j轮第m个工作节点的样本。
S6、中心服务器计根据余弦相似性算延迟梯度的陈旧性,并且根据各个梯度的陈旧度为其赋予不同的学习率;
到目前为止,衡量局部梯度过时程度的现有策略是通过迭代滞后τ的数量或通过局部训练时间。这些策略在解决实验中的陈旧模型问题方面已经证明了一定的有效性。然而,它们在实际场景中有明显的局限性。例如,一些具有低延迟的梯度可以与当前最新梯度有较高的方向一致性,而一些具有高延迟的梯度可能不会与当前最优梯度方向偏离太多。如果通过迭代滞后或局部训练时间测量的这些梯度的陈旧性,延迟梯度的陈旧度一旦超过某个阈值,则可能会错误地丢弃这些梯度。这会对训练模型的收敛产生不利影响,并减慢训练过程。在实践中,这种方法不能准确地测量陈旧的梯度是否有助于全局模型的收敛。因此本发明设计了基于余弦相似性的延迟梯度陈旧性的衡量方法;
S6.1、中心服务器依据如下公式计算当前延迟梯度的陈旧度:
式中,Gt指的是上一步中计算得到的全局无偏梯度估计,Gt-τ是指陈旧梯度。cos(Gt,Gt-τ)表示延迟梯度与全局无偏梯度估计的余弦相似性,也即梯度下降中的方向相似性;∈表示超参数,可以根据不同的数据集或者训练任务进行调节,s(τ)表示当前延迟梯度的陈旧性。
从梯度下降方向性的角度将延迟梯度的陈旧度进行了新的定义,现有计算延迟梯度陈旧度的方法认为延迟梯度的陈旧度和版本延迟呈正相关,本发明通过实验验证延迟梯度和同步梯度具有方向误差,但方向误差和版本延迟并不呈绝对的正相关,为此,本发明从梯度下降方向性的角度对陈旧度进行了新的定义,考虑了方向误差和版本延迟并不呈绝对的正相关,因此,本发明能够更好的利用延迟梯度促进模型收敛。
S7、学习率衰减。在异步联邦学习中,不同客户端的梯度关系可能存在陈旧性,即与最新的全局梯度相比有一定的延迟。这种陈旧性会影响全局模型的更新和收敛性能。为了减少陈旧性的影响,一种常用的方法是对陈旧客户端的权重进行学习率衰减,即降低其在全局权重更新中的贡献。学习率衰减的原则是,陈旧性越大,学习率越小。一种常见的学习率衰减策略是根据客户端的陈旧度τ来调整其权重的更新系数ητ,其中τ是当前延迟梯度的陈旧度参数。具体地,可以定义如下:
式中,ητ表示调整之后的学习率,η0表示初始学习率,a表示陈旧度的阈值,当陈旧度小于a时,无需调整学习率,其中η0是一个初始学习率,范围是(0,1),a是超参数。这样,当客户端的陈旧度超过一个阈值a时,其学习率会按照一个幂函数衰减。这种策略可以有效地平衡不同客户端的权重更新,提高全局模型的收敛性能。
S8、中心服务器更新模型。在全局无偏梯度估计、计算陈旧度和调整学习率完成后,更新当前的全局模型wj和全局模型版本version;
其中,依据如下公式进行全局模型的更新:
wj+1表示更新后的全局神经网络模型,wj表示第j轮的全局神经网络模型,K表示参与当前更新的工作节点的个数,i表示第i个工作节点,j表示当前的第j个更新轮次,ηj,i表示第j轮第i个工作节点的学习率。
S9、一轮更新结束之后,中心服务器判断当前轮次是否等于预先定义的总沟通轮次,若无,继续执行当前循环,若完成了T轮训练,则代表全局模型已经训练完成,因此程序训练结束。
在一种实施实例中,本发明的若干工作节点包含如下步骤:
S10、工作节点初始化模型和本地轮次t;
S11、工作节点从中心服务器接收最新的全局模型权重wt。这一步需要工作节点和中心服务器之间有可靠的通信连接,以及中心服务器能够及时地将全局模型权重广播给所有的工作节点。如果通信连接不稳定或者中心服务器的广播能力不足,可能会导致工作节点收到过时的或者错误的全局模型权重,影响本地训练的效果;
S12、工作节点用自己的数据集Di对wt进行本地训练,同时为了克服数据集不平衡带来的影响,需要将一般的经验损失函数改为类平衡损失函数,得到本地更新后的梯度g(wj,i,ξj,i)。这一步需要工作节点有足够的计算能力和数据量,以及合适的训练参数,如学习率、批量大小、训练轮数等。如果工作节点的计算能力或数据量不足,或者训练参数不合理,可能会导致本地训练的速度和质量不高,影响全局模型的收敛性能;
S13、本地模型进行梯度下降算法,得到更新的参数。工作节点根据梯度的方向和一个预设的学习率,更新参数向量,使目标函数沿着梯度下降的方向移动一小步,工作节点的梯度更新的损失函数由如下公式表示:
式中,gi(x;w)表示第i类的logit,gy(x;w)是第y类的logit,γ是损失函数的超参数,i表示第i个类别,y表示第y个类别,表示第y类的实例数,/>表示第i类的实例数,c表示样本类别的总个数。
S14、工作节点判断本地更新次数是否等于预先定义的t轮,若小于t轮,继续循环S12-S13,直到达到预设的训练轮数,若达到预定轮次,则本地训练结束;
S15、将梯度g(wj,i,ξj,i)发送给中心服务器,并等待下一轮的全局模型权重。如果有的工作节点训练时间超时,则继续当前轮次的本地更新,等待下一轮再参与全局模型的更新。
根据本发明的另一实施例,提供一种用于在工作节点端进行基于异步联邦学习的中心聚合更新的设备,设备包括:存储器,存储有一个或多个计算机程序,所述一个或多个计算机程序被一个或多个处理器执行时,所述一个或多个处理器执行上述方法实施例中的步骤。
本发明实例提供了基于异步联邦学习的聚合更新方法的实施步骤,需要说明的是,虽然在流程图中给出了逻辑流程顺序,但是在某些情况下,可以以不同的执行顺序所示或描述的步骤。
本发明还提供一种基于异步联邦学习的参数聚合更新系统,包括中心服务器以及与中心服务器通信相连的多个工作节点,参数服务器与工作节点基于异步联邦学习机制进行参数聚合更新,中心服务器根据步骤S1-S9所述的方法进行参数聚合更新;工作节点根据步骤S10-S15所述的方法完成参数更新。
本发明公开了一种基于自适应学习率的异步联邦学习参数更新方法、设备及系统。为了解决异步联邦学习面临着双重挑战:陈旧性问题和数据集不平衡问题,本方法分别在中心服务器和工作节点解决如上问题。中心服务器接收完K个梯度后,首先进行无偏梯度估计,并实施一种基于余弦相似度的新型评估方法,以衡量延迟梯度的陈旧度;同时进一步调整学习速率,更新并广播模型参数和迭代次数。对于数据集不平衡问题,工作节点引入了一个类平衡损失函数,可以处理异质性数据对于模型训练的影响。本发明根据延时程度自适应调整学习速率,提高了模型的预测精度。
本领域内的技术人员应明白,本发明的实施例可提供为方法、设备、装置、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等,上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统,、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。最后应当说明的是:以上实施例仅用以说明本发明的技术方案而非对其限制,尽管参照上述实施例对本发明进行了详细的说明,所属领域的普通技术人员应当理解:依然可以对本发明的具体实施方式进行修改或者等同替换,而未脱离本发明精神和范围的任何修改或者等同替换,其均应涵盖在本发明的权利要求保护范围之内。
Claims (12)
1.一种基于自适应学习率的异步联邦学习参数更新方法,其特征在于,用于中心服务器,包括
S110.中心服务器接收更新,更新包括由工作节点发出的梯度;
S120.中心服务器根据同步梯度估计全局无偏梯度;
S130.中心服务器根据全局无偏梯度计算延迟梯度的陈旧度;
S140.中心服务器根据陈旧度为延迟梯度调整学习率;
S150.中心服务器根据学习率更新全局神经网络模型;
S160.中心服务器将更新的全局神经网络模型的参数发出,更新的全局神经网络模型的参数由工作节点接收;
其中,同步梯度是工作节点依据最新的全局神经网络模型计算的梯度,延迟梯度是工作节点依据非最新的全局神经网络模型计算的梯度;
其中,延迟梯度的陈旧度由如下公式表示:
式中,表示延迟梯度与全局无偏梯度的余弦相似性,用于表示梯度下降中的方向相似性,/>表示全局无偏梯度/>,/>表示延迟梯度,/> 表示超参数,表示延迟梯度的陈旧度;
延迟梯度调整学习率由如下公式表示:
式中,表示调整之后的学习率,/>表示初始学习率,/>表示陈旧度的阈值。
2.根据权利要求1的基于自适应学习率的异步联邦学习参数更新方法,其特征在于,步骤S110中,中心服务器接收的更新还包括迭代次数,中心服务器根据迭代次数达到预先定义的次数停止更新。
3.根据权利要求1的基于自适应学习率的异步联邦学习参数更新方法,其特征在于,步骤S110中,中心服务器将接收的更新加入队列,当队列的长度达到设定阈值,中心服务器执行步骤S120。
4.根据权利要求1的基于自适应学习率的异步联邦学习参数更新方法,其特征在于,所述步骤S110中还包括中心服务器将当前的迭代次数和当前的全局神经网络模型参数广播,广播由本地节点接收。
5.根据权利要求1-4中任一项所述的基于自适应学习率的异步联邦学习参数更新方法,其特征在于,估计全局无偏梯度由如下公式表示:
式中,表示全局无偏梯度,/>表示/>个工作节点的本地样本总量,/>表示第/>个工作节点的本地样本数量,/>表示第/>个工作节点的同步梯度,/>表示当前更新轮次,/>表示第/>轮第/>个工作节点的神经网络模型的参数,/>表示第/>轮第/>个工作节点的样本;
更新全局神经网络模型由如下公式表示:
表示更新后的全局神经网络模型,/>表示第j轮的全局神经网络模型,/>表示参与当前更新的工作节点的个数,/>表示第/>个工作节点,/>表示当前的第/>个更新轮次,/>表示第/>轮第/>个工作节点的学习率。
6.根据权利要求1-4中任一项所述的基于自适应学习率的异步联邦学习参数更新方法,其特征在于,工作节点的梯度由如下公式表示:
式中,表示第/>类的logit,/>是第/>类的logit,/>是损失函数的超参数,/>表示第/>个类别,/>表示第/>个类别,/>表示第/>类的实例数,/>表示第/>类的实例数,/>表示样本类别的总个数。
7.一种基于自适应学习率的异步联邦学习参数更新方法,其特征在于,用于工作节点,包括
S210.工作节点接收全局神经网络模型的参数,全局神经网络模型的参数由中心服务器发出;
S220.工作节点依据全局神经网络模型的参数训练工作节点的本地模型;
S230.工作节点的本地模型进行梯度下降得到更新的参数;
S240.工作节点发出更新,更新包括梯度,工作节点依据最新的全局神经网络模型计算的梯度是同步梯度,工作节点依据非最新的全局神经网络模型计算的梯度是延迟梯度,其中,工作节点发出的同步梯度是由中心服务器接收的同步梯度,用于中心服务器估计全局无偏梯度,根据全局无偏梯度计算延迟梯度的陈旧度,根据陈旧度为延迟梯度调整学习率,根据学习率更新全局神经网络模型;
其中,延迟梯度的陈旧度由如下公式表示:
式中,表示延迟梯度与全局无偏梯度的余弦相似性,用于表示梯度下降中的方向相似性,/>表示全局无偏梯度/>,/>表示延迟梯度,/> 表示超参数,表示延迟梯度的陈旧度;
延迟梯度调整学习率由如下公式表示:
式中,表示调整之后的学习率,/>表示初始学习率,/>表示陈旧度的阈值。
8.根据权利要求7所述的基于自适应学习率的异步联邦学习参数更新方法,其特征在于,所述步骤S230还包括工作节点根据更新次数达到预先定义的次数执行步骤S240,否则执行步骤S220。
9.根据权利要求7-8中任一项所述的基于自适应学习率的异步联邦学习参数更新方法,其特征在于,估计全局无偏梯度由如下公式表示:
式中,表示全局无偏梯度,/>表示/>个工作节点的本地样本总量,/>表示第/>个工作节点的本地样本数量,/>表示第/>个工作节点的同步梯度,/>表示当前更新轮次,/>表示第/>轮第/>个工作节点的神经网络模型的参数,/>表示第/>轮第/>个工作节点的样本;
更新全局神经网络模型由如下公式表示:
表示更新后的全局神经网络模型,/>表示第j轮的全局神经网络模型,/>表示参与当前更新的工作节点的个数,/>表示第/>个工作节点,/>表示当前的第/>个更新轮次,/>表示第/>轮第/>个工作节点的学习率。
10.根据权利要求7-8中任一项所述的基于自适应学习率的异步联邦学习参数更新方法,其特征在于,工作节点的梯度由如下公式表示:
式中,表示第/>类的logit,/>是第/>类的logit,/>是损失函数的超参数,/>表示第/>个类别,/>表示第/>个类别,/>表示第/>类的实例数,/>表示第/>类的实例数,/>表示样本类别的总个数。
11.一种电子设备,所述电子设备包括:一个或多个处理器,存储器,以及,一个或多个程序;其中,所述一个或多个程序被存储在所述存储器中,所述一个或多个程序包括指令,当所述指令被所述电子设备执行时,使得所述电子设备执行权利要求1~10任一项所述方法。
12.一种计算机可读存储介质,所述计算机可读存储介质包括计算机程序,当计算机程序在电子设备上运行时,使得所述电子设备执行权利要求1~10任一项所述方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310985134.2A CN117151208B (zh) | 2023-08-07 | 2023-08-07 | 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310985134.2A CN117151208B (zh) | 2023-08-07 | 2023-08-07 | 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117151208A CN117151208A (zh) | 2023-12-01 |
CN117151208B true CN117151208B (zh) | 2024-03-22 |
Family
ID=88899630
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310985134.2A Active CN117151208B (zh) | 2023-08-07 | 2023-08-07 | 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117151208B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117436515B (zh) * | 2023-12-07 | 2024-03-12 | 四川警察学院 | 联邦学习方法、系统、装置以及存储介质 |
Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111460443A (zh) * | 2020-05-28 | 2020-07-28 | 南京大学 | 一种联邦学习中数据操纵攻击的安全防御方法 |
WO2021120676A1 (zh) * | 2020-06-30 | 2021-06-24 | 平安科技(深圳)有限公司 | 联邦学习网络下的模型训练方法及其相关设备 |
CN113095407A (zh) * | 2021-04-12 | 2021-07-09 | 哈尔滨理工大学 | 一种降低通信次数的高效异步联邦学习方法 |
CN113435604A (zh) * | 2021-06-16 | 2021-09-24 | 清华大学 | 一种联邦学习优化方法及装置 |
CN113989561A (zh) * | 2021-10-29 | 2022-01-28 | 河海大学 | 基于异步联邦学习的参数聚合更新方法、设备及系统 |
CN113988308A (zh) * | 2021-10-27 | 2022-01-28 | 东北大学 | 一种基于延迟补偿机制的异步联邦梯度平均算法 |
CN114117926A (zh) * | 2021-12-01 | 2022-03-01 | 南京富尔登科技发展有限公司 | 一种基于联邦学习的机器人协同控制算法 |
CN114565103A (zh) * | 2022-02-28 | 2022-05-31 | 杭州卷积云科技有限公司 | 基于梯度选择和自适应学习率的加权k异步联邦学习方法、系统及装置 |
WO2022193432A1 (zh) * | 2021-03-17 | 2022-09-22 | 深圳前海微众银行股份有限公司 | 模型参数更新方法、装置、设备、存储介质及程序产品 |
CN115470937A (zh) * | 2022-09-26 | 2022-12-13 | 广西师范大学 | 一种基于设备特性的异步联邦学习的任务调度方法 |
CN116488906A (zh) * | 2023-04-25 | 2023-07-25 | 重庆邮电大学 | 一种安全高效的模型共建方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11087864B2 (en) * | 2018-07-17 | 2021-08-10 | Petuum Inc. | Systems and methods for automatically tagging concepts to, and generating text reports for, medical images based on machine learning |
-
2023
- 2023-08-07 CN CN202310985134.2A patent/CN117151208B/zh active Active
Patent Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111460443A (zh) * | 2020-05-28 | 2020-07-28 | 南京大学 | 一种联邦学习中数据操纵攻击的安全防御方法 |
WO2021120676A1 (zh) * | 2020-06-30 | 2021-06-24 | 平安科技(深圳)有限公司 | 联邦学习网络下的模型训练方法及其相关设备 |
WO2022193432A1 (zh) * | 2021-03-17 | 2022-09-22 | 深圳前海微众银行股份有限公司 | 模型参数更新方法、装置、设备、存储介质及程序产品 |
CN113095407A (zh) * | 2021-04-12 | 2021-07-09 | 哈尔滨理工大学 | 一种降低通信次数的高效异步联邦学习方法 |
CN113435604A (zh) * | 2021-06-16 | 2021-09-24 | 清华大学 | 一种联邦学习优化方法及装置 |
CN113988308A (zh) * | 2021-10-27 | 2022-01-28 | 东北大学 | 一种基于延迟补偿机制的异步联邦梯度平均算法 |
CN113989561A (zh) * | 2021-10-29 | 2022-01-28 | 河海大学 | 基于异步联邦学习的参数聚合更新方法、设备及系统 |
CN114117926A (zh) * | 2021-12-01 | 2022-03-01 | 南京富尔登科技发展有限公司 | 一种基于联邦学习的机器人协同控制算法 |
CN114565103A (zh) * | 2022-02-28 | 2022-05-31 | 杭州卷积云科技有限公司 | 基于梯度选择和自适应学习率的加权k异步联邦学习方法、系统及装置 |
CN115470937A (zh) * | 2022-09-26 | 2022-12-13 | 广西师范大学 | 一种基于设备特性的异步联邦学习的任务调度方法 |
CN116488906A (zh) * | 2023-04-25 | 2023-07-25 | 重庆邮电大学 | 一种安全高效的模型共建方法 |
Non-Patent Citations (4)
Title |
---|
Distributed asynchronous optimization with unbounded delays: How slow can you go?;Z. Zhou 等;《International Conference on Machine Learning》;20181231;5970-5979 * |
FedACA: An Adaptive Communication-Efficient Asynchronous Framework for Federated Learning;Shuang Zhou 等;《2022 IEEE International Conference on Autonomic Computing and Self-Organizing Systems (ACSOS)》;20221231;71-80 * |
Towards Efficient and Stable K-Asynchronous Federated Learning With Unbounded Stale Gradients on Non-IID Data;Zihao Zhou 等;《IEEE Transactions on Parallel and Distributed Systems》;20221201;第33卷(第12期);3291-3305 * |
基于卷积神经网络的异步联邦学习研究;张曦镱;《中国优秀硕士学位论文全文数据库 信息科技辑》;20230115;第2023年卷(第1期);I138-195 * |
Also Published As
Publication number | Publication date |
---|---|
CN117151208A (zh) | 2023-12-01 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113591145B (zh) | 基于差分隐私和量化的联邦学习全局模型训练方法 | |
CN111708640A (zh) | 一种面向边缘计算的联邦学习方法和系统 | |
CN111754000A (zh) | 质量感知的边缘智能联邦学习方法及系统 | |
CN117151208B (zh) | 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质 | |
CN110889509B (zh) | 一种基于梯度动量加速的联合学习方法及装置 | |
CN113112027A (zh) | 一种基于动态调整模型聚合权重的联邦学习方法 | |
CN110968426B (zh) | 一种基于在线学习的边云协同k均值聚类的模型优化方法 | |
CN113139662A (zh) | 联邦学习的全局及局部梯度处理方法、装置、设备和介质 | |
CN109460793A (zh) | 一种节点分类的方法、模型训练的方法及装置 | |
CN113989561B (zh) | 基于异步联邦学习的参数聚合更新方法、设备及系统 | |
CN113206887A (zh) | 边缘计算下针对数据与设备异构性加速联邦学习的方法 | |
CN113691594B (zh) | 一种基于二阶导数解决联邦学习中数据不平衡问题的方法 | |
CN117349672B (zh) | 基于差分隐私联邦学习的模型训练方法、装置及设备 | |
CN115525038A (zh) | 一种基于联邦分层优化学习的设备故障诊断方法 | |
WO2020028770A1 (en) | Artificial neural network growth | |
CN113191504A (zh) | 一种面向计算资源异构的联邦学习训练加速方法 | |
CN116781343A (zh) | 一种终端可信度的评估方法、装置、系统、设备及介质 | |
CN112019547B (zh) | 网络流量评估方法、攻击检测方法、服务器及存储介质 | |
CN114401192A (zh) | 一种多sdn控制器协同训练方法 | |
CN113556780A (zh) | 一种拥塞控制方法及装置 | |
CN114580578A (zh) | 具有约束的分布式随机优化模型训练方法、装置及终端 | |
US20240028911A1 (en) | Efficient sampling of edge-weighted quantization for federated learning | |
CN118277891A (zh) | 一种考虑拜占庭容错的联邦学习方法、设备与系统 | |
CN117892805B (zh) | 基于超网络和层级别协作图聚合的个性化联邦学习方法 | |
CN117829274B (zh) | 模型融合方法、装置、设备、联邦学习系统及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |