CN110287031B - 一种减少分布式机器学习通信开销的方法 - Google Patents
一种减少分布式机器学习通信开销的方法 Download PDFInfo
- Publication number
- CN110287031B CN110287031B CN201910583390.2A CN201910583390A CN110287031B CN 110287031 B CN110287031 B CN 110287031B CN 201910583390 A CN201910583390 A CN 201910583390A CN 110287031 B CN110287031 B CN 110287031B
- Authority
- CN
- China
- Prior art keywords
- machine learning
- memory
- gradients
- distributed machine
- momentum
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F9/00—Arrangements for program control, e.g. control units
- G06F9/06—Arrangements for program control, e.g. control units using stored programs, i.e. using an internal store of processing equipment to receive or retain programs
- G06F9/46—Multiprogramming arrangements
- G06F9/50—Allocation of resources, e.g. of the central processing unit [CPU]
- G06F9/5061—Partitioning or combining of resources
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Software Systems (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Computer And Data Communications (AREA)
- Telephonic Communication Services (AREA)
Abstract
本发明公开了一种减少分布式机器学习通信开销的方法,基于参数服务器架构,既适用于数据中心的多机集群分布式机器学习,也适用于服务器作为云端、手机或嵌入式设备作为终端的端云协同分布式机器学习。包括以下步骤:首先所有工作节点计算梯度,并结合两轮参数差求出全局动量,将全局动量与上一轮记忆梯度求和得到新一轮记忆梯度,对其取部分发给服务器节点,剩下部分进行累积;随后服务器节点累积所有稀疏的记忆梯度和,以此更新参数并将两轮的参数差广播给所有工作节点;最后工作节点接收两轮参数差,并更新参数。本发明的方法基于全局梯度压缩,工作节点与服务器节点间通信时只传递全局动量的一部分,从而减小了分布式机器学习中的通信开销。
Description
技术领域
本发明提供了一种减少分布式机器学习通信开销的方法,可以有效地减少分布式机器学习中的通信开销。
背景技术
大部分机器学习模型可以被形式化为以下优化问题:
其中w代表了模型的参数,n代表了训练样本的总数,ξi代表了第i个样本,f(w;ξi)则表示第i个样本所对应的损失函数,d表示模型大小。为了求解上述优化问题,随机梯度下降法(SGD)以及它的变体是目前应用最为广泛的方法。变种中的动量梯度下降法(MSGD)对梯度使用指数加权平均,使得本次梯度影响减少,波动情况减小,在接近最小值时收敛更加稳定。
参数服务器架构(Parameter Server)是分布式机器学习中常用的一种架构,该架构具有良好的可扩展性和容错性,同时也支持灵活的一致性模型。参数服务器架构中包含一个服务器节点集群和多个工作节点集群,服务器节点集群包含多个服务器节点,一个服务器节点维护全局共享参数的一部分,服务器节点彼此通信来复制和/或迁移参数以用于可靠性和缩放。一个工作节点集群通常本地存储一部分训练数据,运行一个应用程序来计算一些局部数据,比如梯度。工作节点互相之间不通信,只与服务器节点通信来更新和检索共享参数。
随着训练数据量的增大,很多机器学习问题的训练过程需要花费大量的时间,分布式算法将训练数据分散到多个节点上并行地进行训练,以此来加速机器学习的训练过程。在实现数据并行的随机梯度下降法时,工作节点使用不同的数据子集和本地模型副本并行地计算出梯度,并发送给服务器节点。中心化的参数服务器收集所有梯度,并对他们求平均用来更新参数,然后把更新后的参数发给所有的工作节点。在算法拓展性较好的时候,数据并行使得增加训练节点的数可以显著减少模型训练时间。然而,随着分布式集群的规模越来越大,梯度的传递和参数的同步延长了通信时间,成为了进一步提高效率的瓶颈。
发明内容
发明目的:目前的分布式随机梯度下降法在参数更新时,服务器节点需要从每一个工作节点接受一个高维向量。随着机器学习模型的增大和工作节点数增多,这样的方法中所消耗的通信时间也会越来越长,最终导致通信堵塞,算法收敛速度减慢。针对上述问题与不足,提供一种减少分布式机器学习通信开销的方法,基于全局动量压缩,工作节点计算出本地的梯度,加到全局动量上,再加上上一轮的记忆梯度,然后根据某种方法取和的一部分到参数服务器,在参数服务器中汇总更新参数后广播到所有工作节点。可以看出,本发明的方法中,在工作节点与服务器节点之间通信时,只发送参数差和记忆梯度的一部分,所以能有效地减少分布式机器学习中的通信开销,达到加速分布式机器学习训练过程的目的。
技术方案:一种减少分布式机器学习通信开销的方法,其在服务器节点上训练流程的具体步骤为:
步骤100,输入机器学习模型w以及总共的迭代轮数T、工作节点数目p、样本总数n、学习率ηt;
步骤101,随机初始化模型参数w=w0;
步骤103,更新模型参数wt+1=wt-ηtvt,k;
步骤104,将参数差wt+1-wt广播给所有的工作节点;
步骤105,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则输出并保存模型参数w;否则返回步骤102继续进行训练。
本发明的方法在第k个工作节点上训练流程的具体步骤为:
步骤201,初始化记忆动量u0,k=0。
步骤202,接受服务器节点发送的模型参数差wt-wt-1;
步骤203,更新模型参数wt=wt-1-(wt-wt-1);
步骤207,生成一个稀疏向量mt,k∈{0,1}d,||mt,k||0=dρ;
步骤208,发送mt,k⊙(ut,k+gt,k)到服务器节点;
步骤209,更新记忆梯度ut+1,k=(1-mt,k)⊙(ut,k+gt,k),k=1,2,…,p;
步骤210,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则结束训练流程;否则返回步骤202继续进行训练。
有益效果:本发明提供的减少分布式机器学习通信开销的方法,既适用于数据中心的多机集群分布式机器学习,也适用于服务器作为云端、手机或嵌入式设备作为终端的端云协同分布式机器学习。本发明的方法基于全局动量压缩,在参数服务器架构下实现分布式动量梯度下降法,与现有技术相比,本发明使用记忆梯度克服了结合分布式随机梯度下降法和随机坐标下降法会带来误差的缺点,使用全局动量弥补了随机梯度和全梯度之间的误差,在保证预测精确度基本不降低的情况下,可以将通信量减少到传统动量梯度下降法的1%甚至更低。
附图说明
图1为本发明实施的减少分布式机器学习通信开销的方法在服务器节点上的工作流程图;
图2为本发明实施的减少分布式机器学习通信开销的方法在工作节点上的工作流程图。
具体实施方式
下面结合具体实施例,进一步阐明本发明,应理解这些实施例仅用于说明本发明而不用于限制本发明的范围,在阅读了本发明之后,本领域技术人员对本发明的各种等价形式的修改均落于本申请所附权利要求所限定的范围。
本发明提供的减少分布式机器学习通信开销的方法,可应用于图像分类、文本分类等领域,适合于待分类的数据集样本数多、所使用的机器学习模型参数量大的场景。以图像分类应用为例,在本发明的方法中,训练图像数据将分布式的存储在若干个工作节点上,而机器学习模型参数将由若干个服务器节点共同维护,在图像分类应用中的具体工作流程如下所述:
减少分布式机器学习通信开销的方法,在服务器节点上的工作流程如图1所示。首先输入机器学习模型w以及总共的迭代轮数T、工作节点数目p、样本总数n和学习率ηt;(步骤100),随机初始化模型参数w=w0并广播模型参数w0到所有的工作节点(步骤101)。接下来初始化迭代轮数计数器t=0(步骤102),随后进入到模型训练的迭代阶段:累计所有工作节点发送过来的稀疏动量(步骤103),并更新模型参数wt+1=wt-ηtvt,k(步骤104);然后将参数差wt+1-wt广播给所有的工作节点(步骤105)。每次迭代结束时将迭代轮数计数器增加1(步骤106)并进行判断是否达到停止条件t=T(步骤117),若未达到停止条件则继续迭代,否则输出训练结果并保存模型(步骤108)。
减少分布式机器学习通信开销的方法,在第k个工作节点上的工作流程如图2所示。首先输入本地训练图像数据以及总共的迭代轮数T、学习率ηt、稀疏度ρ、批量大小b、动量系数β(步骤200),本地训练图像数据为完整训练图像数据集合的一个子集(完整训练图像数据集合接下来初始化迭代轮数计数器t=0,接收模型初始参数w0,令w-1=w0,初始化记忆动量u0,k=0(步骤201),随后进入到模型训练的迭代阶段:接收服务器节点发送的模型参数差wt-wt-1(步骤202),并更新模型参数wt=wt-1-(wt-wt-1)(步骤203);从本地数据集中随机挑选一个小批量数据(步骤204),并计算全局动量(步骤205);然后从ut,k+gt,k中随机选择一部分元素作为S,然后取S中第ρ|S|大元素的值作为阈值θ(步骤206),生成一个稀疏向量mt,k=(ut,k+gt,k)>θ,mt,k∈{0,1}d,||mt,k||0=dρ(步骤207),并发送mt,k⊙(ut,k+gt,k)到服务器节点(步骤208),最后更新记忆梯度ut+1,k=(1-mt,k)⊙(ut,k+gt,k),k=1,2,…,p(步骤209)。每次迭代结束时将迭代轮数计数器增加1(步骤210)并进行判断是否达到停止条件t=T(步骤211),若未达到停止条件则继续迭代,否则结束训练流程(步骤212)。
本发明的方法在多个图像分类数据集上进行了实验。实验过程中,在服务器端统计了一轮迭代中服务器接收来自所有工作节点的比特数和发送给所有工作节点的比特数,通信压缩比为本发明的算法一轮通信比特数与传统动量梯度下降法一轮通信的比特数之比。实验结果表明,本发明提出的方法在保证预测精确度基本不降低的情况下,可以将通信量减少到传统动量梯度下降法的1%甚至更低。
Claims (2)
1.一种减少分布式机器学习通信开销的方法,其特征在于,其在服务器节点上训练流程的具体步骤为:
步骤100,输入机器学习模型w以及总共的迭代轮数T、工作节点数目p、样本总数n、学习率ηt;
步骤101,随机初始化模型参数w=w0;
步骤103,更新模型参数wt+1=wt-ηtvt,k;
步骤104,将参数差wt+1-wt广播给所有的工作节点;
步骤105,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则输出并保存模型参数w;否则返回步骤102继续进行训练;
其在第k个工作节点上训练流程的具体步骤为:
步骤201,初始化记忆动量u0,k=0;
步骤202,接受服务器节点发送的模型参数差wt-wt-1;
步骤203,更新模型参数wt=wt-1-(wt-wt-1);
步骤207,生成一个稀疏向量mt,k∈{0,1}d,||mt,k||0=dρ;
步骤208,发送mt,k⊙(ut,k+gt,k)到服务器节点;
步骤209,更新记忆梯度ut+1,k=(1-mt,k)⊙(ut,k+gt,k),k=1,2,…,p;
步骤210,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则结束训练流程;否则返回步骤202继续进行训练。
2.如权利要求1所述的减少分布式机器学习通信开销的方法,其特征在于:步骤207-209中,生成一个稀疏向量mt,k∈{0,1}d,||mt,k||0=dρ,发送mt,k⊙(ut,k+gt,k)到服务器节点,更新记忆梯度ut+1,k=(1-mt,k)⊙(ut,k+gt,k),k=1,2,…,p。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910583390.2A CN110287031B (zh) | 2019-07-01 | 2019-07-01 | 一种减少分布式机器学习通信开销的方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910583390.2A CN110287031B (zh) | 2019-07-01 | 2019-07-01 | 一种减少分布式机器学习通信开销的方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110287031A CN110287031A (zh) | 2019-09-27 |
CN110287031B true CN110287031B (zh) | 2023-05-09 |
Family
ID=68020322
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910583390.2A Active CN110287031B (zh) | 2019-07-01 | 2019-07-01 | 一种减少分布式机器学习通信开销的方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110287031B (zh) |
Families Citing this family (17)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110929878B (zh) * | 2019-10-30 | 2023-07-04 | 同济大学 | 一种分布式随机梯度下降方法 |
CN110889509B (zh) * | 2019-11-11 | 2023-04-28 | 安徽超清科技股份有限公司 | 一种基于梯度动量加速的联合学习方法及装置 |
US11379727B2 (en) * | 2019-11-25 | 2022-07-05 | Shanghai United Imaging Intelligence Co., Ltd. | Systems and methods for enhancing a distributed medical network |
CN110990155B (zh) * | 2019-11-29 | 2022-03-22 | 杭州电子科技大学 | 一种面向大规模安全监控的参数通信方法 |
CN112948105B (zh) * | 2019-12-11 | 2023-10-17 | 香港理工大学深圳研究院 | 一种梯度传输方法、梯度传输装置及参数服务器 |
CN111369008A (zh) * | 2020-03-04 | 2020-07-03 | 南京大学 | 一种阶段性增大批量的机器学习方法 |
CN111369009A (zh) * | 2020-03-04 | 2020-07-03 | 南京大学 | 一种能容忍不可信节点的分布式机器学习方法 |
CN111625603A (zh) * | 2020-05-28 | 2020-09-04 | 浪潮电子信息产业股份有限公司 | 一种分布式深度学习的梯度信息更新方法及相关装置 |
CN111709533B (zh) * | 2020-08-19 | 2021-03-30 | 腾讯科技(深圳)有限公司 | 机器学习模型的分布式训练方法、装置以及计算机设备 |
CN111784002B (zh) * | 2020-09-07 | 2021-01-19 | 腾讯科技(深圳)有限公司 | 分布式数据处理方法、装置、计算机设备及存储介质 |
CN112235344B (zh) * | 2020-09-07 | 2022-12-23 | 上海大学 | 一种面向分布式机器学习的稀疏通信模型的实现方法 |
CN112101569A (zh) * | 2020-09-17 | 2020-12-18 | 上海交通大学 | 面向数据周期性的分布式多模型随机梯度下降方法 |
CN112686383B (zh) * | 2020-12-30 | 2024-04-16 | 中山大学 | 一种通信并行的分布式随机梯度下降的方法、系统及装置 |
CN112966438A (zh) * | 2021-03-05 | 2021-06-15 | 北京金山云网络技术有限公司 | 机器学习算法选择方法、分布式计算系统 |
CN113159287B (zh) * | 2021-04-16 | 2023-10-10 | 中山大学 | 一种基于梯度稀疏的分布式深度学习方法 |
CN113300890B (zh) * | 2021-05-24 | 2022-06-14 | 同济大学 | 一种网络化机器学习系统的自适应通信方法 |
CN114118437B (zh) * | 2021-09-30 | 2023-04-18 | 电子科技大学 | 一种面向微云中分布式机器学习的模型更新同步方法 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20040071363A1 (en) * | 1998-03-13 | 2004-04-15 | Kouri Donald J. | Methods for performing DAF data filtering and padding |
US20180068216A1 (en) * | 2015-03-13 | 2018-03-08 | Institute Of Acoustics, Chinese Academy Of Sciences | Big data processing method based on deep learning model satisfying k-degree sparse constraint |
CN109600255A (zh) * | 2018-12-04 | 2019-04-09 | 中山大学 | 一种去中心化的参数服务器优化算法 |
CN109902741A (zh) * | 2019-02-28 | 2019-06-18 | 上海理工大学 | 一种制冷系统故障诊断方法 |
CN109951438A (zh) * | 2019-01-15 | 2019-06-28 | 中国科学院信息工程研究所 | 一种分布式深度学习的通信优化方法及系统 |
-
2019
- 2019-07-01 CN CN201910583390.2A patent/CN110287031B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20040071363A1 (en) * | 1998-03-13 | 2004-04-15 | Kouri Donald J. | Methods for performing DAF data filtering and padding |
US20180068216A1 (en) * | 2015-03-13 | 2018-03-08 | Institute Of Acoustics, Chinese Academy Of Sciences | Big data processing method based on deep learning model satisfying k-degree sparse constraint |
CN109600255A (zh) * | 2018-12-04 | 2019-04-09 | 中山大学 | 一种去中心化的参数服务器优化算法 |
CN109951438A (zh) * | 2019-01-15 | 2019-06-28 | 中国科学院信息工程研究所 | 一种分布式深度学习的通信优化方法及系统 |
CN109902741A (zh) * | 2019-02-28 | 2019-06-18 | 上海理工大学 | 一种制冷系统故障诊断方法 |
Also Published As
Publication number | Publication date |
---|---|
CN110287031A (zh) | 2019-09-27 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110287031B (zh) | 一种减少分布式机器学习通信开销的方法 | |
CN106297774B (zh) | 一种神经网络声学模型的分布式并行训练方法及系统 | |
CN110084378B (zh) | 一种基于本地学习策略的分布式机器学习方法 | |
Tao et al. | {eSGD}: Communication efficient distributed deep learning on the edge | |
CN113191484B (zh) | 基于深度强化学习的联邦学习客户端智能选取方法及系统 | |
Zhang et al. | Deep learning for wireless coded caching with unknown and time-variant content popularity | |
Feng et al. | Mobility-aware cluster federated learning in hierarchical wireless networks | |
CN111382844B (zh) | 一种深度学习模型的训练方法及装置 | |
CN111243045B (zh) | 一种基于高斯混合模型先验变分自编码器的图像生成方法 | |
CN108282501B (zh) | 一种云服务器资源信息同步方法、装置和系统 | |
Brunner et al. | Robust event-triggered MPC for constrained linear discrete-time systems with guaranteed average sampling rate | |
Mitra et al. | Achieving linear convergence in federated learning under objective and systems heterogeneity | |
CN113206887A (zh) | 边缘计算下针对数据与设备异构性加速联邦学习的方法 | |
CN111369009A (zh) | 一种能容忍不可信节点的分布式机器学习方法 | |
CN115374853A (zh) | 基于T-Step聚合算法的异步联邦学习方法及系统 | |
CN114169543A (zh) | 一种基于模型陈旧性与用户参与度感知的联邦学习算法 | |
Mu et al. | Communication and storage efficient federated split learning | |
Wu et al. | From deterioration to acceleration: A calibration approach to rehabilitating step asynchronism in federated optimization | |
Jin et al. | Simulating aggregation algorithms for empirical verification of resilient and adaptive federated learning | |
Sun et al. | On the role of server momentum in federated learning | |
Bhatnagar et al. | Multiscale Q-learning with linear function approximation | |
Cui et al. | The Data Value based Asynchronous Federated Learning for UAV Swarm under Unstable Communication Scenarios | |
CN115423393A (zh) | 一种基于lstm的动态自适应调度周期的订单调度方法及装置 | |
De et al. | Variance reduction for distributed stochastic gradient descent | |
CN116611506B (zh) | 用户分析模型训练方法、用户标签确定方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |