CN110889509A - 一种基于梯度动量加速的联合学习方法及装置 - Google Patents
一种基于梯度动量加速的联合学习方法及装置 Download PDFInfo
- Publication number
- CN110889509A CN110889509A CN201911095913.5A CN201911095913A CN110889509A CN 110889509 A CN110889509 A CN 110889509A CN 201911095913 A CN201911095913 A CN 201911095913A CN 110889509 A CN110889509 A CN 110889509A
- Authority
- CN
- China
- Prior art keywords
- momentum
- interval
- parameter
- parameters
- model parameters
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种基于梯度动量加速的联合学习方法及装置,所述联合学习方法包括:在每个边缘节点上嵌入相同的机器学习模型,并在当前聚合区间内执行动量梯度下降算法获取当前聚合区间内每个时刻的模型参数和动量参数;中心服务器聚合这些模型参数得到全局模型参数,中心服务器聚合这些动量参数得到全局动量参数;将当前聚合区间内的全局模型参数代入损失函数公式获得损失函数值与上一聚合区间内获得的损失函数值比较,获得优化的全局模型参数,所有聚合区间计算完成后,获得最优化的全局模型参数;本发明的优点在于:将动量梯度下降算法用于联合学习的本地更新过程也即边缘节点的参数更新过程中,算法收敛速度较快。
Description
技术领域
本发明涉及联合学习领域,更具体涉及一种基于梯度动量加速的联合学习方法及装置。
背景技术
FL(Federated Learning,联合学习)是一种用于分布式机器学习技术,它能够有效的利用边缘节点有限的计算和通信资源训练出最优的模型学习性能。FL的结构包括一个CS(Central Server,中心端服务器)和许多的EN(EdgeNode,边缘节点)。在EN上,原始数据被收集和存储在EN的存储单元中,一个嵌于EN的机器学习模型用于训练这些本地的数据,所以EN不需要将这些本地数据发送到CS上。FL的CS与EN之间只同步更新节点的机器学习模型参数,我们称为权重(Weight)。这不但能够减少节点与服务器间通信的数据量,还能保护用户数据的隐私(中心端服务器接触不到用户数据)。FL的学习过程分为两个步骤,它们是本地更新(LocalUpdate)和全局聚合(GlobalAggregation)。在本地更新步骤中,各个EN基于本地数据集执行优化算法(例如:GD(GradientDescent,梯度下降)和牛顿法)去调整本地学习模型权重,使模型的损失函数值最小。各个节点经过设置好的本地迭代次数之后,FL执行全局聚合步骤。所有的EN的权重被同步发送到CS,经过CS的加权平均处理后,一个更新全局模型权重被发送给所有的EN。FL的学习过程是本地更新和全局聚合的不断轮替。
在联合学习领域,由于EN端有限的通信和计算资源,加速联合学习过程意味在更少的本地更新和全局聚合步骤内实现更高的联合学习性能和更有效的资源利用效率。现有技术的联合学习算法使用GD(GradientDescent,梯度下降)算法执行本地更新步骤,没有考虑之前的权重变化对算法收敛的改善,算法收敛较慢。
发明内容
本发明所要解决的技术问题在于如何提供一种基于梯度动量加速的联合学习方法及装置,以提高算法收敛速度。
本发明通过以下技术手段实现解决上述技术问题的:一种基于梯度动量加速的联合学习方法,所述联合学习方法采用分布式系统,应用于图像识别和语音识别,所述分布式系统包括若干个边缘节点和一个连接所有边缘节点的中心服务器;所述联合学习方法包括:
步骤一:将训练过程分为若干个聚合区间,每一个聚合区间对应设定的时长;在每个边缘节点上嵌入相同的机器学习模型,并在当前聚合区间内执行动量梯度下降算法获取当前聚合区间内每个时刻的模型参数和动量参数;
步骤二:每个边缘节点在当前聚合区间末将模型参数和动量参数同时发送给中心服务器,中心服务器聚合这些模型参数得到全局模型参数,中心服务器聚合这些动量参数得到全局动量参数;
步骤三:将当前聚合区间内的全局模型参数代入损失函数公式获得损失函数值与上一聚合区间内获得的损失函数值比较,获得优化的全局模型参数,将当前聚合区间获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为当前聚合区间的下个聚合区间的初始化值,重复步骤一和步骤二,直至达到预设的中心服务器聚合次数,停止执行以上步骤;
步骤四:所有聚合区间计算完成后,获得最优化的全局模型参数。
在中心式的学习环境下,MGD(MomentumGradientDescent,动量梯度下降)的收敛速度比GD(GradientDescent,梯度下降)的收敛速度快,因此,本发明为了加速联合学习的收敛速度,将动量梯度下降算法用于联合学习的本地更新过程也即边缘节点的参数更新过程中,并且将当前聚合区间末获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为下个聚合区间的初始化值,考虑到了本地更新过程中之前的权重变化对算法收敛的改善,算法收敛速度较快。
优选的,所述步骤一,包括:
计算当前聚合区间((k-1)τ,kτ]内每个时刻的模型参数和动量参数,其中,t为时刻,τ为每个聚合区间的间隔,k为中心服务器的聚合次数,为第i个边缘节点的动量参数,为第i个边缘节点的模型参数,η为第i个边缘节点执行的动量梯度下降算法的学习步长,γ为第i个边缘节点执行的动量梯度下降算法的动量衰减因子,Fi()为第i个边缘节点的损失函数,为梯度算符,为第i个边缘节点的损失函数的梯度。
优选的,所述步骤二,包括:
每个边缘节点在当前聚合区间末t=kτ将模型参数和动量参数同时发送给中心服务器,中心服务器通过公式聚合这些模型参数得到全局模型参数,中心服务器通过公式聚合这些动量参数得到全局动量参数,其中,d(t)为全局动量参数,D为中心服务器的全局数据集的样本数,Di为第i个边缘节点的数据集中的样本数,N为边缘节点的总数,∑为求和符号,||为绝对值符号,为当前聚合区间末t=kτ时第i个边缘节点的动量参数;w(t)为全局模型参数,为当前聚合区间末t=kτ时第i个边缘节点的模型参数。
优选的,所述步骤三,还包括:根据公式对当前聚合区间内的获得的损失函数值与上一聚合区间内获得的损失函数值比较,其中,Wf为优化的全局模型参数,argmin()为求最小值的集合函数,为恒等号,K为预设的中心服务器聚合次数。
优选的,所述步骤三,还包括:将当前聚合区间获得的优化的全局模型参数Wf和全局动量参数d(kτ)发送给所有的边缘节点作为下一个聚合区间(kτ,(k+1)τ]的初始化值,即当前聚合区间获得的优化的全局模型参数Wf作为当前聚合区间的下一个聚合区间(kτ,(k+1)τ]的模型参数的初始化值,当前聚合区间末的全局动量参数d(kτ)作为下一个聚合区间(kτ,(k+1)τ]的动量参数的初始化值。
本发明还提供一种基于梯度动量加速的联合学习装置,所述联合学习装置采用分布式系统,应用于图像识别和语音识别,所述分布式系统包括若干个边缘节点和一个连接所有边缘节点的中心服务器;所述联合学习装置包括:
参数获取模块,用于将训练过程分为若干个聚合区间,每一个聚合区间对应设定的时长;在每个边缘节点上嵌入相同的机器学习模型,并在当前聚合区间内执行动量梯度下降算法获取当前聚合区间内每个时刻的模型参数和动量参数;
聚合模块,用于每个边缘节点在当前聚合区间末将模型参数和动量参数同时发送给中心服务器,中心服务器聚合这些模型参数得到全局模型参数,中心服务器聚合这些动量参数得到全局动量参数;
优化模块,用于将当前聚合区间内的全局模型参数代入损失函数公式获得损失函数值与上一聚合区间内获得的损失函数值比较,获得优化的全局模型参数,将当前聚合区间获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为当前聚合区间的下个聚合区间的初始化值,重复执行参数获取模块和聚合模块,直至达到预设的中心服务器聚合次数,停止执行以上模块;
最优化模块,用于所有聚合区间计算完成后,获得最优化的全局模型参数。
优选的,所述参数获取模块,还用于:
计算当前聚合区间((k-1)τ,kτ]内每个时刻的模型参数和动量参数,其中,t为时刻,τ为每个聚合区间的间隔,k为中心服务器的聚合次数,为第i个边缘节点的动量参数,为第i个边缘节点的模型参数,η为第i个边缘节点执行的动量梯度下降算法的学习步长,γ为第i个边缘节点执行的动量梯度下降算法的动量衰减因子,Fi()为第i个边缘节点的损失函数,为梯度算符,为第i个边缘节点的损失函数的梯度。
优选的,所述聚合模块,还用于:
每个边缘节点在当前聚合区间末t=kτ将模型参数和动量参数同时发送给中心服务器,中心服务器通过公式聚合这些模型参数得到全局模型参数,中心服务器通过公式聚合这些动量参数得到全局动量参数,其中,d(t)为全局动量参数,D为中心服务器的全局数据集的样本数,Di为第i个边缘节点的数据集中的样本数,N为边缘节点的总数,∑为求和符号,||为绝对值符号,为当前聚合区间末t=kτ时第i个边缘节点的动量参数;w(t)为全局模型参数,为当前聚合区间末t=kτ时第i个边缘节点的模型参数。
优选的,所述优化模块,还用于:根据公式
优选的,所述优化模块,还用于:将当前聚合区间获得的优化的全局模型参数Wf和全局动量参数d(kτ)发送给所有的边缘节点作为下一个聚合区间(kτ,(k+1)τ]的初始化值,即当前聚合区间获得的优化的全局模型参数Wf作为当前聚合区间的下一个聚合区间(kτ,(k+1)τ]的模型参数的初始化值,当前聚合区间末的全局动量参数d(kτ)作为下一个聚合区间(kτ,(k+1)τ]的动量参数的初始化值。
本发明的优点在于:MGD(MomentumGradientDescent,动量梯度下降)是一种用于中心式机器学习的优化算法。不同于一阶的梯度下降算法,MGD是一种二阶的梯度下降方法,它的下一步更新步骤由当前的梯度和上一次权重变化共同决定。上一次权重变化(Momentum Term,动量项)的引入能够加速算法的收敛。比起GD,MGD有更快的收敛率。
在联合学习领域,由于EN端有限的通信和计算资源,加速联合学习过程意味在更少的本地更新和全局聚合步骤内实现更高的联合学习性能和更有效的资源利用效率。在中心式的学习环境下,动量梯度下降算法的收敛速度比梯度下降算法的收敛速度快,因此,本发明为了加速联合学习的收敛速度,将动量梯度下降算法用于联合学习的本地更新过程也即边缘节点的参数更新过程中,并且将当前聚合区间末获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为下个聚合区间的初始化值,考虑到了本地更新过程中之前的权重变化对算法收敛的改善,算法收敛速度较快。
附图说明
图1为本发明实施例1所公开的一种基于梯度动量加速的联合学习方法的结构图;
图2为本发明实施例1所公开的一种基于梯度动量加速的联合学习方法的设计流程图;
图3为本发明实施例1所公开的一种基于梯度动量加速的联合学习方法中SVM模型在FL、MFL与MGD的损失函数收敛曲线的比较图;
图4为本发明实施例1所公开的一种基于梯度动量加速的联合学习方法中基于SVM模型的测试集精度随着本地更新次数收敛的曲线;
图5为本发明实施例1所公开的一种基于梯度动量加速的联合学习方法中线性回归模型下测试收敛曲线与本地迭代次数的关系图;
图6为本发明实施例1所公开的一种基于梯度动量加速的联合学习方法中逻辑回归模型下测试收敛曲线与本地迭代次数的关系图;
图7为本发明实施例2所公开的一种基于梯度动量加速的联合学习装置的结构框图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
实施例1
如图1所示为本发明提供的一种基于梯度动量加速的联合学习方法的结构图,图中,本地学习模型指的是每个边缘节点上嵌入的机器学习模型,全局学习模型指的是全局动量参数d(t)以及全局模型参数w(t)的求解公式,一种基于梯度动量加速的联合学习方法,所述联合学习方法采用分布式系统,应用于图像识别和语音识别,所述分布式系统包括若干个边缘节点和一个连接所有边缘节点的中心服务器;所述联合学习方法包括:
步骤一:在每个边缘节点上嵌入相同的机器学习模型,并在当前聚合区间内执行动量梯度下降算法获取当前聚合区间内每个时刻的模型参数和动量参数;具体过程为:如图2所示,为基于梯度动量加速的联合学习方法的设计流程图。图中,本地更新指的是每个EN计算模型参数和动量参数的过程,全局聚合指的是中心服务器聚合这些模型参数和动量参数的过程。间隔[k]指的是当前聚合区间的迭代间隔,[k+1]是下一聚合区间的迭代间隔。
计算当前聚合区间((k-1)τ,kτ]内每个时刻的模型参数和动量参数,其中,t为时刻,τ为每个聚合区间的间隔,k为中心服务器的聚合次数,为第i个边缘节点的动量参数,为第i个边缘节点的模型参数,η为第i个边缘节点执行的动量梯度下降算法的学习步长,γ为第i个边缘节点执行的动量梯度下降算法的动量衰减因子,Fi()为第i个边缘节点的损失函数,为梯度算符,为第i个边缘节点的损失函数的梯度。
步骤二:每个边缘节点在当前聚合区间末将模型参数和动量参数同时发送给中心服务器,中心服务器聚合这些模型参数得到全局模型参数,中心服务器聚合这些动量参数得到全局动量参数;具体过程为:每个边缘节点在当前聚合区间末t=kτ将模型参数和动量参数同时发送给中心服务器,中心服务器通过公式聚合这些模型参数得到全局模型参数,中心服务器通过公式聚合这些动量参数得到全局动量参数,其中,d(t)为全局动量参数,D为中心服务器的全局数据集的样本数,Di为第i个边缘节点的数据集中的样本数,N为边缘节点的总数,∑为求和符号,||为绝对值符号,为当前聚合区间末t=kτ时第i个边缘节点的动量参数;w(t)为全局模型参数,为当前聚合区间末t=kτ时第i个边缘节点的模型参数。
步骤三:将当前聚合区间内的全局模型参数代入损失函数公式获得损失函数值与上一聚合区间内获得的损失函数值比较,获得优化的全局模型参数,将当前聚合区间获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为当前聚合区间的下个聚合区间的初始化值,重复步骤一和步骤二,直至达到预设的中心服务器聚合次数,停止执行以上步骤;具体过程为:将当前聚合区间内的全局模型参数w(kτ)代入损失函数公式获得损失函数值与上一聚合区间内获得的损失函数值通过公式比较,其中,Wf为优化的全局模型参数,argmin()为求最小值的集合函数,为恒等号,K为预设的中心服务器聚合次数,损失函数值小的为优化的全局模型参数。将当前聚合区间获得的优化的全局模型参数Wf和全局动量参数d(kτ)发送给所有的边缘节点作为下一个聚合区间(kτ,(k+1)τ]的初始化值,即当前聚合区间获得的优化的全局模型参数Wf作为当前聚合区间的下一个聚合区间(kτ,(k+1)τ]的模型参数的初始化值,当前聚合区间末的全局动量参数d(kτ)作为下一个聚合区间(kτ,(k+1)τ]的动量参数的初始化值。需要注意的是,若当前聚合区间为第一个聚合区间,那么当前聚合区间内的全局模型参数代入损失函数公式计算的损失函数值与模型参数的初始化值比较。
步骤四:所有聚合区间计算完成后,经过每一次比较,留下了损失函数值较小的全局模型参数,最终所有区间计算完成,得到的是损失函数值最小的全局模型参数,即获得最优化的全局模型参数。
需要注意的是,每个边缘节点上都具有计算单元,存储单元和信号发送与接收单元。其中存储单元存储机器学习所需的训练样本集;计算单元执行对于特定机器学习损失函数的动量梯度优化,损失函数优化过程用到的是本存储的数据集。信号发送与接收单元对机器学习模型的参数进行发送与接收。中心服务器上具有计算单元,存储单元以及发送与接收单元。计算单元执行所有边缘节点的模型参数的聚合运算;存储单元存储边缘节点发送的模型参数;发送与接收单元用于模型参数的发送与接收过程。对于边缘节点和中心服务器内部的计算单元、存储单元和信号发送与接收单元,属于现有技术的硬件架构,在此不做过多描述。
以下采用3种嵌入的机器学习模型,分别为SVM(supportvectormachine,支持向量机)模型,线性回归模型和逻辑回归模型,对本发明的联合学习方法(即MomentumFederatedLearning,动量联合学习,以下简称MFL)、MGD(MomentumGradientDescent,动量梯度下降)以及FL(Federated Learning,联合学习)进行仿真验证。仿真验证是基于python环境,分别使用SVM模型、线性回归模型和逻辑回归模型训练MNIST数据集。我们设置η=0·002y=0·5τ=4以及总的本地更新次数kτ=1000。MFL、FL和中心式MGD在EN上被执行来优化以上三种学习模型。其中SVM在第i个EN上的损失函数为:
线性回归模型在第i个EN上的损失函数为:
逻辑回归模型在第i个EN上的损失函数为:
其中,
w为机器学习模型的全局模型参数的矩阵形式,即上文所述的w(t)的矩阵形式,对于以上三个机器学习模型,xj是第i个EN上的第j个样本的输入机器学习模型向量,yj是第j个样本的对应的机器学习模型期望输出标量。使用上面3种机器学习模型来训练MNIST数据集得到损失函数或者测试精度曲线。MNIST数据集为现有技术的数据集。
图3比较了在SVM模型在FL,MFL与MGD的损失函数收敛曲线的比较。从图3可以看出MFL的收敛速率要快于FL,这与预期的采用MGD执行本地更新会加速FL收敛的效果一致。当然也可以看出MGD有最快的收敛速率,由于MGD执行中心式的机器学习(MGD的学习的数据集是全局的数据集,被收集在CS端进行中心式的学习),聚合频率τ不会对MGD的梯度更新产生滞后的影响。
图4比较了基于SVM学习模型的测试集精度随着本地更新次数收敛的曲线。仍然可以看出在相同迭代次数下,MGD具有最好的测试精度,MFL的测试精度总是好于FL。这依然说明了MFL可以加速联合学习的收敛。
图5与图6分别在线性回归与逻辑回归模型下测试了它们的收敛曲线与本地迭代次数的关系。可以看出在基于SVM、线性回归与逻辑回归的学习模型,比起FL架构,我们提出的MFL架构总是有更快的收敛速度。
需要说明的是,中心式的MGD用于中心式场景,它需要先将分布在EN上的数据都收集到CS上,再在CS上执行MGD。但是MFL不需要数据收集过程,直接以分布式的形式,利用EN上的通信与计算资源完成学习过程。因此,中心式MGD虽然收敛性能好于τ>1情况下的MFL,但是执行它先要收集原始数据,这一过程耗费的通信资源是巨大的。
本发明的工作过程和工作原理为:在中心式的学习环境下,动量梯度下降算法的收敛速度比梯度下降算法的收敛速度快,因此,本发明为了加速联合学习的收敛速度,将动量梯度下降算法用于联合学习的本地更新过程也即边缘节点的参数更新过程中,并且将当前聚合区间末获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为下个聚合区间的初始化值,考虑到了本地更新过程中之前的权重变化对算法收敛的改善,算法收敛速度较快。
实施例2
与本发明实施例1相对应的,本发明实施例2提供一种基于梯度动量加速的联合学习装置,所述联合学习装置采用分布式系统,应用于图像识别和语音识别,所述分布式系统包括若干个边缘节点和一个连接所有边缘节点的中心服务器;所述联合学习装置包括:
参数获取模块,用于将训练过程分为若干个聚合区间,每一个聚合区间对应设定的时长;在每个边缘节点上嵌入相同的机器学习模型,并在当前聚合区间内执行动量梯度下降算法获取当前聚合区间内每个时刻的模型参数和动量参数;
聚合模块,用于每个边缘节点在当前聚合区间末将模型参数和动量参数同时发送给中心服务器,中心服务器聚合这些模型参数得到全局模型参数,中心服务器聚合这些动量参数得到全局动量参数;
优化模块,用于将当前聚合区间内的全局模型参数代入损失函数公式获得损失函数值与上一聚合区间内获得的损失函数值比较,获得优化的全局模型参数,将当前聚合区间获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为当前聚合区间的下个聚合区间的初始化值,重复执行参数获取模块和聚合模块,直至达到预设的中心服务器聚合次数,停止执行以上模块;
最优化模块,用于所有聚合区间计算完成后,获得最优化的全局模型参数。
具体的,所述参数获取模块,还用于:
计算当前聚合区间((k-1)τ,kτ]内每个时刻的模型参数和动量参数,其中,t为时刻,τ为每个聚合区间的间隔,k为中心服务器的聚合次数,为第i个边缘节点的动量参数,为第i个边缘节点的模型参数,η为第i个边缘节点执行的动量梯度下降算法的学习步长,γ为第i个边缘节点执行的动量梯度下降算法的动量衰减因子,Fi()为第i个边缘节点的损失函数,为梯度算符,为第i个边缘节点的损失函数的梯度。
具体的,所述聚合模块,还用于:
每个边缘节点在当前聚合区间末t=kτ将模型参数和动量参数同时发送给中心服务器,中心服务器通过公式聚合这些模型参数得到全局模型参数,中心服务器通过公式聚合这些动量参数得到全局动量参数,其中,d(t)为全局动量参数,D为中心服务器的全局数据集的样本数,Di为第i个边缘节点的数据集中的样本数,N为边缘节点的总数,∑为求和符号,| |为绝对值符号,为当前聚合区间末t=kτ时第i个边缘节点的动量参数;w(t)为全局模型参数,为当前聚合区间末t=kτ时第i个边缘节点的模型参数。
具体的,所述优化模块,还用于:根据公式
具体的,所述优化模块,还用于:将当前聚合区间获得的优化的全局模型参数Wf和全局动量参数d(kτ)发送给所有的边缘节点作为下一个聚合区间(kτ,(k+1)τ]的初始化值,即当前聚合区间获得的优化的全局模型参数Wf作为当前聚合区间的下一个聚合区间(kτ,(k+1)τ]的模型参数的初始化值,当前聚合区间末的全局动量参数d(kτ)作为下一个聚合区间(kτ,(k+1)τ]的动量参数的初始化值。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (10)
1.一种基于梯度动量加速的联合学习方法,其特征在于,所述联合学习方法采用分布式系统,应用于图像识别和语音识别,所述分布式系统包括若干个边缘节点和一个连接所有边缘节点的中心服务器;所述联合学习方法包括:
步骤一:将训练过程分为若干个聚合区间,每一个聚合区间对应设定的时长;在每个边缘节点上嵌入相同的机器学习模型,并在当前聚合区间内执行动量梯度下降算法获取当前聚合区间内每个时刻的模型参数和动量参数;
步骤二:每个边缘节点在当前聚合区间末将模型参数和动量参数同时发送给中心服务器,中心服务器聚合这些模型参数得到全局模型参数,中心服务器聚合这些动量参数得到全局动量参数;
步骤三:将当前聚合区间内的全局模型参数代入损失函数公式获得损失函数值与上一聚合区间内获得的损失函数值比较,获得优化的全局模型参数,将当前聚合区间获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为当前聚合区间的下个聚合区间的初始化值,重复步骤一和步骤二,直至达到预设的中心服务器聚合次数,停止执行以上步骤;
步骤四:所有聚合区间计算完成后,获得最优化的全局模型参数。
6.根据权利要求5所述的一种基于梯度动量加速的联合学习方法,其特征在于,所述步骤三,还包括:将当前聚合区间获得的优化的全局模型参数Wf和全局动量参数d(kτ)发送给所有的边缘节点作为下一个聚合区间(kτ,(k+1)τ]的初始化值,即当前聚合区间获得的优化的全局模型参数Wf作为当前聚合区间的下一个聚合区间(kτ,(k+1)τ]的模型参数的初始化值,当前聚合区间末的全局动量参数d(kτ)作为下一个聚合区间(kτ,(k+1)τ]的动量参数的初始化值。
7.一种基于梯度动量加速的联合学习装置,其特征在于,所述联合学习装置采用分布式系统,应用于图像识别和语音识别,所述分布式系统包括若干个边缘节点和一个连接所有边缘节点的中心服务器;所述联合学习装置包括:
参数获取模块,用于将训练过程分为若干个聚合区间,每一个聚合区间对应设定的时长;在每个边缘节点上嵌入相同的机器学习模型,并在当前聚合区间内执行动量梯度下降算法获取当前聚合区间内每个时刻的模型参数和动量参数;
聚合模块,用于每个边缘节点在当前聚合区间末将模型参数和动量参数同时发送给中心服务器,中心服务器聚合这些模型参数得到全局模型参数,中心服务器聚合这些动量参数得到全局动量参数;
优化模块,将当前聚合区间内的全局模型参数代入损失函数公式获得损失函数值与上一聚合区间内获得的损失函数值比较,获得优化的全局模型参数,将当前聚合区间获得的优化的全局模型参数和全局动量参数发送给所有的边缘节点作为当前聚合区间的下个聚合区间的初始化值,重复执行参数获取模块和聚合模块,直至达到预设的中心服务器聚合次数,停止执行以上模块;
最优化模块,用于所有聚合区间计算完成后,获得最优化的全局模型参数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911095913.5A CN110889509B (zh) | 2019-11-11 | 2019-11-11 | 一种基于梯度动量加速的联合学习方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911095913.5A CN110889509B (zh) | 2019-11-11 | 2019-11-11 | 一种基于梯度动量加速的联合学习方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110889509A true CN110889509A (zh) | 2020-03-17 |
CN110889509B CN110889509B (zh) | 2023-04-28 |
Family
ID=69747302
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911095913.5A Active CN110889509B (zh) | 2019-11-11 | 2019-11-11 | 一种基于梯度动量加速的联合学习方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110889509B (zh) |
Cited By (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112183612A (zh) * | 2020-09-24 | 2021-01-05 | 重庆邮电大学 | 一种基于参数扩充的联合学习方法、装置及系统 |
CN112488183A (zh) * | 2020-11-27 | 2021-03-12 | 平安科技(深圳)有限公司 | 一种模型优化方法、装置、计算机设备及存储介质 |
US20210218757A1 (en) * | 2020-01-09 | 2021-07-15 | Vmware, Inc. | Generative adversarial network based predictive model for collaborative intrusion detection systems |
CN113312177A (zh) * | 2021-05-11 | 2021-08-27 | 南京航空航天大学 | 一种基于联邦学习的无线边缘计算系统、优化方法 |
CN115086437A (zh) * | 2022-06-15 | 2022-09-20 | 中国科学技术大学苏州高等研究院 | 基于分簇和xdp技术的梯度聚合加速方法和装置 |
WO2022221997A1 (en) * | 2021-04-19 | 2022-10-27 | Microsoft Technology Licensing, Llc | Parallelizing moment-based optimizations with blockwise model-update filtering |
CN115989530A (zh) * | 2020-08-26 | 2023-04-18 | 瑞典爱立信有限公司 | 生成并处理视频数据 |
CN116007597A (zh) * | 2022-12-19 | 2023-04-25 | 北京工业大学 | 基于动量梯度下降法对框架柱的垂直度测量方法及装置 |
CN116049267A (zh) * | 2022-12-26 | 2023-05-02 | 上海朗晖慧科技术有限公司 | 一种多维智能识别的化学物品搜索显示方法 |
CN116781836A (zh) * | 2023-08-22 | 2023-09-19 | 云视图研智能数字技术(深圳)有限公司 | 一种全息远程教学方法及系统 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109952580A (zh) * | 2016-11-04 | 2019-06-28 | 易享信息技术有限公司 | 基于准循环神经网络的编码器-解码器模型 |
CN110287031A (zh) * | 2019-07-01 | 2019-09-27 | 南京大学 | 一种减少分布式机器学习通信开销的方法 |
US20190318268A1 (en) * | 2018-04-13 | 2019-10-17 | International Business Machines Corporation | Distributed machine learning at edge nodes |
-
2019
- 2019-11-11 CN CN201911095913.5A patent/CN110889509B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109952580A (zh) * | 2016-11-04 | 2019-06-28 | 易享信息技术有限公司 | 基于准循环神经网络的编码器-解码器模型 |
US20190318268A1 (en) * | 2018-04-13 | 2019-10-17 | International Business Machines Corporation | Distributed machine learning at edge nodes |
CN110287031A (zh) * | 2019-07-01 | 2019-09-27 | 南京大学 | 一种减少分布式机器学习通信开销的方法 |
Non-Patent Citations (1)
Title |
---|
孙娅楠;林文斌;: "梯度下降法在机器学习中的应用" * |
Cited By (17)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11811791B2 (en) * | 2020-01-09 | 2023-11-07 | Vmware, Inc. | Generative adversarial network based predictive model for collaborative intrusion detection systems |
US20210218757A1 (en) * | 2020-01-09 | 2021-07-15 | Vmware, Inc. | Generative adversarial network based predictive model for collaborative intrusion detection systems |
CN115989530A (zh) * | 2020-08-26 | 2023-04-18 | 瑞典爱立信有限公司 | 生成并处理视频数据 |
CN112183612A (zh) * | 2020-09-24 | 2021-01-05 | 重庆邮电大学 | 一种基于参数扩充的联合学习方法、装置及系统 |
CN112488183A (zh) * | 2020-11-27 | 2021-03-12 | 平安科技(深圳)有限公司 | 一种模型优化方法、装置、计算机设备及存储介质 |
CN112488183B (zh) * | 2020-11-27 | 2024-05-10 | 平安科技(深圳)有限公司 | 一种模型优化方法、装置、计算机设备及存储介质 |
WO2022221997A1 (en) * | 2021-04-19 | 2022-10-27 | Microsoft Technology Licensing, Llc | Parallelizing moment-based optimizations with blockwise model-update filtering |
CN113312177A (zh) * | 2021-05-11 | 2021-08-27 | 南京航空航天大学 | 一种基于联邦学习的无线边缘计算系统、优化方法 |
CN113312177B (zh) * | 2021-05-11 | 2024-03-26 | 南京航空航天大学 | 一种基于联邦学习的无线边缘计算系统、优化方法 |
CN115086437A (zh) * | 2022-06-15 | 2022-09-20 | 中国科学技术大学苏州高等研究院 | 基于分簇和xdp技术的梯度聚合加速方法和装置 |
CN115086437B (zh) * | 2022-06-15 | 2023-08-22 | 中国科学技术大学苏州高等研究院 | 基于分簇和xdp技术的梯度聚合加速方法和装置 |
CN116007597A (zh) * | 2022-12-19 | 2023-04-25 | 北京工业大学 | 基于动量梯度下降法对框架柱的垂直度测量方法及装置 |
CN116007597B (zh) * | 2022-12-19 | 2024-06-11 | 北京工业大学 | 基于动量梯度下降法对框架柱的垂直度测量方法及装置 |
CN116049267B (zh) * | 2022-12-26 | 2023-07-18 | 上海朗晖慧科技术有限公司 | 一种多维智能识别的化学物品搜索显示方法 |
CN116049267A (zh) * | 2022-12-26 | 2023-05-02 | 上海朗晖慧科技术有限公司 | 一种多维智能识别的化学物品搜索显示方法 |
CN116781836A (zh) * | 2023-08-22 | 2023-09-19 | 云视图研智能数字技术(深圳)有限公司 | 一种全息远程教学方法及系统 |
CN116781836B (zh) * | 2023-08-22 | 2023-12-01 | 云视图研智能数字技术(深圳)有限公司 | 一种全息远程教学方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN110889509B (zh) | 2023-04-28 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110889509A (zh) | 一种基于梯度动量加速的联合学习方法及装置 | |
US20220391771A1 (en) | Method, apparatus, and computer device and storage medium for distributed training of machine learning model | |
CN111708640A (zh) | 一种面向边缘计算的联邦学习方法和系统 | |
CN106297774B (zh) | 一种神经网络声学模型的分布式并行训练方法及系统 | |
EP3540652B1 (en) | Method, device, chip and system for training neural network model | |
CN109299781B (zh) | 基于动量和剪枝的分布式深度学习系统 | |
CN110968426B (zh) | 一种基于在线学习的边云协同k均值聚类的模型优化方法 | |
CN108111335B (zh) | 一种调度和链接虚拟网络功能的方法及系统 | |
CN109710404B (zh) | 分布式系统中的任务调度方法 | |
CN108446770B (zh) | 一种基于采样的分布式机器学习慢节点处理系统及方法 | |
CN103678004A (zh) | 一种基于非监督特征学习的主机负载预测方法 | |
CN113469325A (zh) | 一种边缘聚合间隔自适应控制的分层联邦学习方法、计算机设备、存储介质 | |
CN113778691B (zh) | 一种任务迁移决策的方法、装置及系统 | |
CN112287990A (zh) | 一种基于在线学习的边云协同支持向量机的模型优化方法 | |
CN116156563A (zh) | 基于数字孪生的异构任务与资源端边协同调度方法 | |
CN113191504B (zh) | 一种面向计算资源异构的联邦学习训练加速方法 | |
CN117829307A (zh) | 一种面向数据异构性的联邦学习方法及系统 | |
Zhou et al. | AdaptCL: Efficient collaborative learning with dynamic and adaptive pruning | |
CN114841341B (zh) | 图像处理模型训练及图像处理方法、装置、设备和介质 | |
Zhou et al. | DRL-Based Workload Allocation for Distributed Coded Machine Learning | |
Esfahanizadeh et al. | Stream iterative distributed coded computing for learning applications in heterogeneous systems | |
CN103886169A (zh) | 一种基于AdaBoost的链路预测算法 | |
CN110276455B (zh) | 基于全局率权重的分布式深度学习系统 | |
CN117118836A (zh) | 基于资源预测的服务功能链多阶段节能迁移方法 | |
CN114758130A (zh) | 图像处理及模型训练方法、装置、设备和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |