CN112101569A - 面向数据周期性的分布式多模型随机梯度下降方法 - Google Patents

面向数据周期性的分布式多模型随机梯度下降方法 Download PDF

Info

Publication number
CN112101569A
CN112101569A CN202010981089.XA CN202010981089A CN112101569A CN 112101569 A CN112101569 A CN 112101569A CN 202010981089 A CN202010981089 A CN 202010981089A CN 112101569 A CN112101569 A CN 112101569A
Authority
CN
China
Prior art keywords
model
data
gradient descent
training
global model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010981089.XA
Other languages
English (en)
Inventor
吴帆
吕承飞
丁雨成
牛超越
郑臻哲
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shanghai Jiaotong University
Original Assignee
Shanghai Jiaotong University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shanghai Jiaotong University filed Critical Shanghai Jiaotong University
Priority to CN202010981089.XA priority Critical patent/CN112101569A/zh
Publication of CN112101569A publication Critical patent/CN112101569A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

一种面向数据周期性的分布式多模型随机梯度下降方法,通过云端服务器向各个用户发布最新的全局模型,用户通过下载全局模型进行本地训练,执行多次随机梯度下降处理后将最新的本地模型上传至云端服务器进行聚合处理并更新全局模型,云端服务器读取当前所在的数据区块编号,将该区块的特制模型更新为当前的全局模型。本发明为不同的数据区块,即不同时间段内可用的用户群体提供特制的模型,对每个数据区块,通过保存在这个数据区块中产生的最新的全局模型得到对应的特制模型,显著提升训练得到的机器学习模型的准确率和鲁棒性。

Description

面向数据周期性的分布式多模型随机梯度下降方法
技术领域
本发明涉及的是一种机器学习领域的技术,具体是一种数据周期性的分布式多模型随机梯度下降方法,为一个周期中不同的数据区块提供不同的模型,应用于联合学习领域中,可显著提升测试精度。
背景技术
联合学习作为分布式机器学习的一个分支,允许大量客户端协作训练机器学习模型,而客户端可用性是部署联合学习系统的主要挑战之一:为了避免对用户体验造成负面影响,往往只有满足特定可用性条件(如充电,空闲,免费Wi-Fi)的客户端有机会被选择参与训练,这导致整个训练过程中,不同客户端参与训练的频率不同。联合学习领域的经典算法联合均值(FedAvg)忽略了这一问题,导致机器学习模型偏向高可用客户端,整体性能下降。
发明内容
本发明针对现有技术存在的上述不足,提出一种面向数据周期性的分布式多模型随机梯度下降方法,为不同的数据区块,即不同时间段内可用的用户提供特制的模型,对每个数据区块,通过聚合在这个数据区块中产生的历史模型得到一个为其特制的模型,显著提升训练得到的机器学习模型的准确率和鲁棒性。
本发明是通过以下技术方案实现的:
本发明涉及一种面向数据周期性的分布式多模型随机梯度下降方法,通过云端服务器向各个用户发布最新的全局模型,用户通过下载全局模型进行本地训练,执行多次随机梯度下降处理后将最新的本地模型上传至云端服务器进行聚合处理并更新全局模型,云端服务器读取当前所在的数据区块编号,将该区块的特制模型更新为当前的全局模型。
所述的全局模型是指:整个训练过程中云服务器上的机器学习模型,每轮迭代时,可用的客户端从云服务器上下载该全局模型,并根据该模型和本地数据计算所需的模型更新信息。
所述的本地模型是指:每个用户从云端服务器下载最新的全局模型后在本地进行训练得到的模型。
所述的数据区块是指:整个训练过程可以看作有若干个训练周期,每个周期可以由服务器划分成多个时间段(根据实际场景灵活划分,例如一个训练过程包含10天,每天可划分成6个4小时),每个时间段内有若干的可用用户,这些用户对应的整体数据分布为一个数据区块。
所述的特制模型,经多次更新全局模型后最终为每个数据区块都得到一个特制模型,其值为在该区块中产生的最新的全局模型。
本发明涉及一种实现上述方法的系统,包括:联合学习训练单元、数据区块划分单元、特制模型记录单元,其:联合学习训练单元执行全局模型的训练过程,数据区块划分单元通过时间戳对一个训练周期进行划分并对数据区块编号,特制模型记录单元记录了不同数据区块内产生的最新的全局模型。
技术效果
本发明整体解决了现有联合学习技术中,因为用户的周期性可用而造成的周期性训练数据问题。本发明消除传统的联合均值算法中单一的全局模型会偏向当前时间段内的数据区块的问题,为每个数据区块都提供了一个特制模型。
附图说明
图1为多模型随机梯度下降法的示意图;
图2为该算法与传统的联合均值算法的比较示意图;
图3为本地迭代次数对该算法的影响示意图;
图4为数据区块的数目对该算法的影响示意图。
具体实施方式
本实施例涉及一种面向数据周期性的分布式多模型随机梯度下降方法,采用pytorch教程的深度学习模型架构,其中包括两个卷积层,和三个全连接层,参数总数约为62006。
本实施例包括以下步骤。
步骤1、数据划分与初始化:将CIFAR-10数据集(包括训练集50000张图片和测试集10000张图片)按照标签划分成M(缺省值为5)个不同的数据区块。对每个数据区块,再按照标签分配给N(缺省值为100)个用户,使得每个用户的数据分布和数据量不同,其中,每个用户的数据量符合正态分布。云端服务器将当前迭代数初始化为0,并利用pytorch对全局模型进行随机初始化。
步骤2、云端服务器将全局模型发送给所有可用用户,每个用户利用自己分配到的数据进行本地训练。本地模型的更新方式为随机梯度下降法,在本地迭代I(缺省值为10)次后将更新后的本地模型发送给服务器。
步骤3、服务器收集到所有可用用户的本地模型后进行模型聚合,得到新的全局模型。
步骤4、服务器根据迭代数确定当前所在的数据区块数,将该区块对应的特制模型更新为当前的全局模型。
所述的特制模型,即每个数据区块都有一个特制的模型,其值为在该区块内产生的最新的全局模型。
重复步骤2至4,直至模型测试准确率收敛。最终为每个数据区块,即每个时间段内可用的用户群体得到了一个特制模型。
如图2所示,为周期性数据下多模型随机梯度下降法,传统的联合均值算法,以及在理想情况下(数据不具有周期性)的联合均值算法的测试准确率曲线。可以看到,本发明的算法准确率达到65%,高于理想情况下联合均值算法的62%。而联合均值算法在周期训练数据下无法收敛,准确率在56%-59%之间震荡。
如图3所示,为本地训练次数对于该算法的影响,可见在一定范围内增大本地训练次数可以有效减少收敛所需的轮数。
如图4所示,为数据区块数目对联合均值算法和分布式多模型随机梯度下降法的影响,显示出数据区块数目的增大会显著影响联合均值算法的准确率,而多模型随机梯度下降法的准确率却不受数据区块数目的影响。
本方法通过保存每个数据区块内产生的最新全局模型,为该数据区块对应的用户群体提供一个特制模型。经过具体实际实验,在INTELi9-9900K型号CPU、NVIDIARTX-2080Ti型号GPU的服务器上,在python3.6下,利用pytorch1.3.1库编程后设置学习率γ=0.01,训练周期数C=10,每个周期内数据区块数目M=5,每个数据区块内N=100个客户并行训练E=200轮,每轮包含本地更新次数I=10的实验设置下运行上述方法,最终M个特制模型分别在相应的数据区块上的平均测试准确率为65%。
与现有技术相比,本方法提升了6%的测试准确率。并且最终的测试曲线趋于平滑,而联合均值算法的测试准确率会周期性地上下震荡。
上述具体实施可由本领域技术人员在不背离本发明原理和宗旨的前提下以不同的方式对其进行局部调整,本发明的保护范围以权利要求书为准且不由上述具体实施所限,在其范围内的各个实现方案均受本发明之约束。

Claims (7)

1.一种面向数据周期性的分布式多模型随机梯度下降方法,其特征在于,通过云端服务器向各个用户发布最新的全局模型,用户通过下载全局模型进行本地训练,执行多次随机梯度下降处理后将最新的本地模型上传至云端服务器进行聚合处理并更新全局模型,云端服务器读取当前所在的数据区块编号,将该区块的特制模型更新为当前的全局模型。
2.根据权利要求1所述的分布式多模型随机梯度下降方法,其特征是,所述的全局模型是指:整个训练过程中云服务器上的机器学习模型,每轮迭代时,可用的客户端从云服务器上下载该全局模型,并根据该模型和本地数据计算所需的模型更新信息。
3.根据权利要求1所述的分布式多模型随机梯度下降方法,其特征是,所述的本地模型是指:每个用户从云端服务器下载最新的全局模型后在本地进行训练得到的模型。
4.根据权利要求1所述的分布式多模型随机梯度下降方法,其特征是,所述的数据区块是指:整个训练过程可以看作有若干个训练周期,每个周期可以由服务器划分成多个时间段,每个时间段内有若干的可用用户,这些用户对应的整体数据分布为一个数据区块。
5.根据权利要求1所述的分布式多模型随机梯度下降方法,其特征是,所述的特制模型,经多次更新全局模型后最终为每个数据区块都得到一个特制模型,其值为在该区块中产生的最新的全局模型。
6.根据权利要求1~5中任一所述的分布式多模型随机梯度下降方法,其特征是,具体步骤包括:
步骤1、数据划分与初始化:将CIFAR-10数据集中的训练集50000张图片和测试集10000张图片按照标签划分成M个不同的数据区块;对每个数据区块,再按照标签分配给N个用户,使得每个用户的数据分布和数据量不同,其中,每个用户的数据量符合正态分布;云端服务器将当前迭代数初始化为0,并利用pytorch对全局模型进行随机初始化;
步骤2、云端服务器将全局模型发送给所有可用用户,每个用户利用自己分配到的数据进行本地训练,本地模型的更新方式为随机梯度下降法,在本地迭代I次后将更新后的本地模型发送给服务器;
步骤3、服务器收集到所有可用用户的本地模型后进行模型聚合,得到新的全局模型;
步骤4、服务器根据迭代数确定当前所在的数据区块数,将该区块对应的特制模型更新为当前的全局模型;
重复步骤2至步骤4直至模型测试准确率收敛;最终为每个数据区块,即每个时间段内可用的用户群体得到了一个特制模型。
7.一种实现上述任一权利要求所述方法的系统,其特征在于,包括:联合学习训练单元、数据区块划分单元、特制模型记录单元,其:联合学习训练单元执行全局模型的训练过程,数据区块划分单元通过时间戳对一个训练周期进行划分并对数据区块编号,特制模型记录单元记录了不同数据区块内产生的最新的全局模型。
CN202010981089.XA 2020-09-17 2020-09-17 面向数据周期性的分布式多模型随机梯度下降方法 Pending CN112101569A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010981089.XA CN112101569A (zh) 2020-09-17 2020-09-17 面向数据周期性的分布式多模型随机梯度下降方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010981089.XA CN112101569A (zh) 2020-09-17 2020-09-17 面向数据周期性的分布式多模型随机梯度下降方法

Publications (1)

Publication Number Publication Date
CN112101569A true CN112101569A (zh) 2020-12-18

Family

ID=73759511

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010981089.XA Pending CN112101569A (zh) 2020-09-17 2020-09-17 面向数据周期性的分布式多模型随机梯度下降方法

Country Status (1)

Country Link
CN (1) CN112101569A (zh)

Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110287031A (zh) * 2019-07-01 2019-09-27 南京大学 一种减少分布式机器学习通信开销的方法

Patent Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110287031A (zh) * 2019-07-01 2019-09-27 南京大学 一种减少分布式机器学习通信开销的方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
YUUCHENG DING: ""Distributed Optimization over Block Cyclic Data"", 《ARXIV》, pages 2 - 5 *

Similar Documents

Publication Publication Date Title
CN111353582B (zh) 一种基于粒子群算法的分布式深度学习参数更新方法
CN113191484A (zh) 基于深度强化学习的联邦学习客户端智能选取方法及系统
CN108564164B (zh) 一种基于spark平台的并行化深度学习方法
WO2020086214A1 (en) Deep reinforcement learning for production scheduling
CN108873936B (zh) 一种基于势博弈的飞行器自主编队方法
CN107368891A (zh) 一种深度学习模型的压缩方法和装置
JP7095675B2 (ja) 情報処理装置、情報処理方法、並びにプログラム
CN109543726A (zh) 一种训练模型的方法及装置
CN110502323B (zh) 一种云计算任务实时调度方法
CN113206887A (zh) 边缘计算下针对数据与设备异构性加速联邦学习的方法
CN113469372A (zh) 强化学习训练方法、装置、电子设备以及存储介质
CN109508146A (zh) 一种区块链的周期性存储空间回收方法
Al-Saedi et al. Reducing communication overhead of federated learning through clustering analysis
CN106327251A (zh) 模型训练系统和方法
CN114745392A (zh) 流量调度方法
CN112101569A (zh) 面向数据周期性的分布式多模型随机梯度下降方法
WO2015124668A1 (en) Tree-structure storage method for managing computation offloading data
CN117453391A (zh) 基于rnn和粒子群的端边云异构资源调度方法及装置
CN110995790B (zh) 一种解决区块链网络共识不确定性的方法
CN109272151B (zh) 一种基于Spark的车辆路径规划算法优化方法
CN115115064A (zh) 一种半异步联邦学习方法及系统
CN115329985B (zh) 无人集群智能模型训练方法、装置和电子设备
CN107577808B (zh) 一种多级列表页排序的方法、装置、服务器及介质
CN114970103A (zh) 考虑配送员经验与随机行驶时间的即时配送路径优化方法
CN113015179B (zh) 基于深度q网络的网络资源选择方法、装置以及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination