CN110084378B - 一种基于本地学习策略的分布式机器学习方法 - Google Patents

一种基于本地学习策略的分布式机器学习方法 Download PDF

Info

Publication number
CN110084378B
CN110084378B CN201910375050.0A CN201910375050A CN110084378B CN 110084378 B CN110084378 B CN 110084378B CN 201910375050 A CN201910375050 A CN 201910375050A CN 110084378 B CN110084378 B CN 110084378B
Authority
CN
China
Prior art keywords
local
machine learning
gradient
distributed machine
parameters
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN201910375050.0A
Other languages
English (en)
Other versions
CN110084378A (zh
Inventor
李武军
高昊
赵申宜
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing University
Original Assignee
Nanjing University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing University filed Critical Nanjing University
Priority to CN201910375050.0A priority Critical patent/CN110084378B/zh
Publication of CN110084378A publication Critical patent/CN110084378A/zh
Application granted granted Critical
Publication of CN110084378B publication Critical patent/CN110084378B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Image Analysis (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种基于本地学习策略的分布式机器学习方法,基于参数服务器架构,既适用于数据中心的多机集群分布式机器学习,也适用于服务器作为云端、手机或嵌入式设备作为终端的端云协同分布式机器学习。包括以下步骤:首先服务器节点累计所有工作节点计算的本地梯度和得到全梯度,并将全梯度广播给所有工作节点;随后每个工作节点各自进行若干次参数更新后将本地的参数发送给服务器节点;最后服务器节点将从工作节点收集到的参数求均值作为最新参数广播给所有工作节点;上述过程迭代多轮直到达到收敛条件。本发明的方法基于本地学习策略,不需要在工作节点每次参数更新后都进行通信,从而减少了分布式机器学习中的通信开销。

Description

一种基于本地学习策略的分布式机器学习方法
技术领域
本发明提供了一种基于本地学习策略的分布式机器学习方法,涉及机器学习领域的分布式算法,可以有效地减少分布式机器学习中的通信开销。
背景技术
大部分机器学习模型可以被形式化为以下优化问题:
Figure BDA0002051360910000011
其中w代表了模型的参数,n代表了训练样本的总数,fi(·)则表示第i个样本所对应的损失函数。为了求解上述优化问题,随机梯度下降法(SGD)以及它的变体是目前应用最为广泛的方法。随着训练数据量的增大,很多机器学习问题的训练过程需要花费大量的时间,分布式算法将训练数据分散到多个节点上并行地进行训练,以此来加速机器学习的训练过程。
参数服务器架构(Parameter Server)是分布式机器学习中常用的一种架构,该架构具有良好的可扩展性和容错性,同时也支持灵活的一致性模型。参数服务器架构中包含两种类型的节点:模型参数存储在一个或是多个服务器节点(Server)上,训练样本数据存储在多个工作节点(Worker)上。
在基于参数服务器架构的分布式随机梯度下降法中,每一次参数更新可以描述为以下过程:首先服务器节点将当前的模型参数广播给所有工作节点;随后每个工作节点在本地的样本集合中随机选取一个样本(假设其样本编号为i),并计算出该样本所对应的损失函数的梯度
Figure BDA0002051360910000012
最后所有工作节点将梯度
Figure BDA0002051360910000013
发送给服务器节点,在服务器节点收集到所有工作节点上的梯度后,使用随机梯度下降法更新模型参数。一次机器学习问题的训练过程,往往要经历很多次参数更新才能接近全局最优解或是局部最优解。
随着机器学习模型的增大和参与分布式计算的节点数增多,每次参数更新时节点之间的通信往往会成为性能瓶颈。
发明内容
发明目的:目前的分布式随机梯度下降法在每次参数更新时节点之间都需要进行通信以同步参数,随着机器学习模型的增大和参与分布式计算的节点数增多,这样的方法中所消耗的通信时间也会越来越长,通信开销往往会成为性能瓶颈。针对上述问题与不足,提供一种基于本地学习策略的分布式机器学习方法,基于本地学习策略,每个工作节点在接收到服务器节点所发送的当前模型参数后,会在本地使用类似于随机方差缩减梯度下降法(SVRG)的方式进行若干次参数更新并得到一个本地参数。在这之后服务器节点才会与所有工作节点进行通信,并将所有工作节点上本地参数的均值作为新的模型参数。由此可见,本发明的方法中通信频率明显降低,所以能有效地减少分布式机器学习中的通信开销,从而达到加速分布式机器学习训练过程的目的。
技术方案:一种基于本地学习策略的分布式机器学习方法,其在服务器节点上训练流程的具体步骤为:
步骤100,输入机器学习模型w以及总共的迭代轮数T、工作节点数目p、样本总数n;
步骤101,随机初始化模型参数w=w0
步骤102,将当前的模型参数wt广播给所有的工作节点;
步骤103,收集所有工作节点计算的本地梯度和zk
步骤104,计算出全梯度
Figure BDA0002051360910000021
步骤105,将全梯度z广播给所有的工作节点;
步骤106,收集所有工作节点计算的本地参数uk
步骤107,更新模型参数
Figure BDA0002051360910000022
步骤108,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则输出并保存模型w;否则返回步骤102继续进行训练。
本发明的方法在第k个工作节点上训练流程的具体步骤为:
步骤200,输入训练样本集合的子集
Figure BDA0002051360910000023
(完整的训练样本集合
Figure BDA0002051360910000024
以及总共的迭代轮数T、学习率η、本地更新次数M;
步骤201,接受服务器节点发送的模型参数wt
步骤202,根据本地的样本数据
Figure BDA0002051360910000025
计算出本地梯度和
Figure BDA0002051360910000026
其中
Figure BDA0002051360910000027
则表示第i个样本所对应的损失函数在当前模型参数下的梯度;
步骤203,将本地梯度和zk发送给服务器节点;
步骤204,接受服务器节点发送的全梯度z;
步骤205,根据当前的模型参数wt、全梯度z与本地的样本数据
Figure BDA0002051360910000034
进行M次本地参数更新;
步骤206,将本地参数uk发送给服务器节点;
步骤207,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则结束训练流程;否则返回步骤201继续进行训练。
在第k个工作节点上进行步骤205的本地参数更新的具体流程为:首先输入当前模型参数wt、全梯度z、本地的样本数据
Figure BDA0002051360910000031
以及学习率η、本地更新次数M;随后初始化本地参数uk,0=wt;最后从本地的样本数据
Figure BDA0002051360910000032
中随机选取一个样本(假设其样本编号为ik,m),并按照以下公式更新本地参数uk:
Figure BDA0002051360910000033
其中uk,m代表第m次更新时的本地参数,c为人工设置的超参数,c(uk,m-wt)这一项用于减小本地学习策略所带来的偏差,从而保证本发明方法的收敛性。此外,如果每个工作节点的本地样本数据分布和全局样本数据分布相差不大,c可以设置为0。重复以上的步骤M次后即完成了本地参数更新的流程。
有益效果:本发明提供的基于本地学习策略的分布式机器学习方法,既适用于数据中心的多机集群分布式机器学习,也适用于服务器作为云端、手机或嵌入式设备作为终端的端云协同分布式机器学习。本发明的方法基于本地学习策略,节点之间每经历若干次本地的参数更新之后才会进行一次通信,与现有技术相比,本发明的方法不需要在每次参数更新后都进行通信以同步参数,从而减少了分布式机器学习中的通信开销。
附图说明
图1为本发明实施的基于本地学习策略的分布式机器学习方法在服务器节点上的工作流程图;
图2为本发明实施的基于本地学习策略的分布式机器学习方法在工作节点上的工作流程图;
图3为本发明实施的在工作节点上进行本地参数更新的工作流程图。
具体实施方式
下面结合具体实施例,进一步阐明本发明,应理解这些实施例仅用于说明本发明而不用于限制本发明的范围,在阅读了本发明之后,本领域技术人员对本发明的各种等价形式的修改均落于本申请所附权利要求所限定的范围。
本发明提供的基于本地学习策略的分布式机器学习方法,可应用于图像分类、文本分类等领域,适合于待分类的数据集样本数多、所使用的机器学习模型参数量大的场景。以图像分类应用为例,在本发明的方法中,训练图像数据将分布式的存储在若干个工作节点上,而机器学习模型参数将由若干个服务器节点共同维护,在图像分类应用中的具体工作流程如下所述:
基于本地学习策略的分布式机器学习方法,在服务器节点上的工作流程如图1所示。首先输入机器学习模型w以及总共的迭代轮数T、工作节点数目p、样本总数n(步骤100),并随机初始化模型参数w=w0(步骤101)。接下来初始化迭代轮数计数器t=0(步骤102),随后进入到模型训练的迭代阶段:先将当前模型参数wt广播给所有的工作节点(步骤103),并收集所有工作节点计算的本地梯度和zk(步骤104);根据收集的本地梯度和计算出全梯度
Figure BDA0002051360910000041
(步骤105),并将全梯度z广播给所有工作节点(步骤106);最后收集所有工作节点计算的本地参数uk(步骤107)并更新模型参数
Figure BDA0002051360910000042
(步骤108)。每次迭代结束时将迭代轮数计数器增加1(步骤109)并进行判断是否达到停止条件t=T(步骤110),若未达到停止条件则继续迭代,否则输出训练结果并保存模型(步骤111)。
基于本地学习策略的分布式机器学习方法,在第k个工作节点上的工作流程如图2所示。首先输入本地训练图像数据
Figure BDA0002051360910000043
以及总共的迭代轮数T、学习率η、本地更新次数M(步骤200),本地训练图像数据为完整训练图像数据集合的一个子集(完整训练图像数据集合
Figure BDA0002051360910000044
)。接下来初始化迭代轮数计数器t=0(步骤201),随后进入到模型训练的迭代阶段:先接受服务器节点发送的模型参数wt(步骤202),并根据本地训练图像数据
Figure BDA0002051360910000045
计算出本地梯度和
Figure BDA0002051360910000046
Figure BDA0002051360910000047
(步骤203);随后将本地梯度和zk发送给服务器节点(步骤204),并接受服务器节点发送的全梯度z(步骤205);最后根据全梯度z、本地训练图像数据
Figure BDA0002051360910000051
以及当前模型参数wt进行M次本地参数更新(步骤206),并在更新结束后将本地参数uk发送给服务器节点(步骤207)。每次迭代结束时将迭代轮数计数器增加1(步骤208)并进行判断是否达到停止条件t=T(步骤209),若未达到停止条件则继续迭代,否则结束训练流程(步骤210)。
在第k个工作节点上进行本地参数更新的工作流程图如图3所示。首先读取当前的模型参数wt、全梯度z、本地训练图像数据
Figure BDA0002051360910000052
以及学习率η、本地更新次数M(步骤2060),并初始化本地参数uk,0=wt(步骤2061)和更新次数计数器m=0(步骤2062)。随后进入迭代更新的过程:先从本地训练图像数据
Figure BDA0002051360910000053
中随机选取一张编号为ik,m的图像样本(步骤2063),随后按照以下公式更新本地参数uk(步骤2064):
Figure BDA0002051360910000054
最后将更新次数计数器m增加1(步骤2065);重复上述步骤,直到满足停止条件m=M(步骤2066),输出本地参数模型uk(步骤2067)。
本发明的方法在多个图像分类、文本分类数据集上进行了实验。实验结果表明,本发明提出的方法相比于其他分布式机器学习方法具有更高的效率。

Claims (3)

1.一种基于本地学习策略的分布式机器学习方法,其特征在于,其在服务器节点上训练流程的具体步骤为:
步骤100,输入机器学习模型w以及总共的迭代轮数T、工作节点数目p、样本总数n;
步骤101,随机初始化模型参数w=w0
步骤102,将当前的模型参数wt广播给所有的工作节点;
步骤103,收集所有工作节点计算的本地梯度和zk
步骤104,计算出全梯度
Figure FDA0004072335060000011
步骤105,将全梯度z广播给所有的工作节点;
步骤106,收集所有工作节点计算的本地参数uk
步骤107,更新模型参数
Figure FDA0004072335060000012
步骤108,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则输出并保存模型w;否则返回步骤102继续进行训练;
在第k个工作节点上训练流程的具体步骤为:
步骤200,输入训练样本集合的子集
Figure FDA0004072335060000013
以及总共的迭代轮数T、学习率η、本地更新次数M;完整的训练样本集合
Figure FDA0004072335060000014
步骤201,接受服务器节点发送的模型参数wt
步骤202,根据本地的样本数据
Figure FDA0004072335060000015
计算出本地梯度和
Figure FDA0004072335060000016
其中
Figure FDA0004072335060000017
则表示第i个样本所对应的损失函数在当前模型参数下的梯度;
步骤203,将本地梯度和zk发送给服务器节点;
步骤204,接受服务器节点发送的全梯度z;
步骤205,根据当前的模型参数wt、全梯度z与本地的样本数据
Figure FDA0004072335060000018
进行M次本地参数更新;
步骤206,将本地参数uk发送给服务器节点;
步骤207,判断当前已完成的迭代轮数t是否达到总共的迭代轮数T,如果是则结束训练流程;否则返回步骤201继续进行训练。
2.如权利要求1所述的基于本地学习策略的分布式机器学习方法,其特征在于,在第k个工作节点上进行本地参数更新的具体流程为:首先输入当前模型参数wt、全梯度z、本地的样本数据
Figure FDA0004072335060000021
以及学习率η、本地更新次数M;随后初始化本地参数uk,0=wt;最后从本地的样本数据
Figure FDA0004072335060000022
中随机选取一个样本,设其样本编号为ik,m,并按照以下公式更新本地参数uk
Figure FDA0004072335060000023
其中uk,m代表第m次更新时的本地参数,c为人工设置的超参数,c(uk,m-wt)这一项用于减小本地学习策略所带来的偏差;
重复以上的步骤M次后即完成了本地参数更新的流程。
3.如权利要求1所述的基于本地学习策略的分布式机器学习方法,其特征在于:分布式机器学习方法是基于参数服务器架构的,既适用于数据中心的多机集群分布式机器学习,也适用于服务器作为云端、手机或嵌入式设备作为终端的端云协同分布式机器学习。
CN201910375050.0A 2019-05-07 2019-05-07 一种基于本地学习策略的分布式机器学习方法 Active CN110084378B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201910375050.0A CN110084378B (zh) 2019-05-07 2019-05-07 一种基于本地学习策略的分布式机器学习方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201910375050.0A CN110084378B (zh) 2019-05-07 2019-05-07 一种基于本地学习策略的分布式机器学习方法

Publications (2)

Publication Number Publication Date
CN110084378A CN110084378A (zh) 2019-08-02
CN110084378B true CN110084378B (zh) 2023-04-21

Family

ID=67418970

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201910375050.0A Active CN110084378B (zh) 2019-05-07 2019-05-07 一种基于本地学习策略的分布式机器学习方法

Country Status (1)

Country Link
CN (1) CN110084378B (zh)

Families Citing this family (15)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110929878B (zh) * 2019-10-30 2023-07-04 同济大学 一种分布式随机梯度下降方法
US11379727B2 (en) * 2019-11-25 2022-07-05 Shanghai United Imaging Intelligence Co., Ltd. Systems and methods for enhancing a distributed medical network
CN111027708A (zh) * 2019-11-29 2020-04-17 杭州电子科技大学舟山同博海洋电子信息研究院有限公司 一种面向分布式机器学习的参数通信优化方法
CN111369009A (zh) * 2020-03-04 2020-07-03 南京大学 一种能容忍不可信节点的分布式机器学习方法
CN111444021B (zh) * 2020-04-02 2023-03-24 电子科技大学 基于分布式机器学习的同步训练方法、服务器及系统
CN111325417B (zh) * 2020-05-15 2020-08-25 支付宝(杭州)信息技术有限公司 实现隐私保护的多方协同更新业务预测模型的方法及装置
CN113946434A (zh) * 2020-07-17 2022-01-18 华为技术有限公司 云服务系统的模型处理方法及云服务系统
CN111709533B (zh) * 2020-08-19 2021-03-30 腾讯科技(深圳)有限公司 机器学习模型的分布式训练方法、装置以及计算机设备
CN112085524B (zh) * 2020-08-31 2022-11-15 中国人民大学 一种基于q学习模型的结果推送方法和系统
CN111814968B (zh) * 2020-09-14 2021-01-12 北京达佳互联信息技术有限公司 用于机器学习模型的分布式训练的方法和装置
CN112381218B (zh) * 2020-11-20 2022-04-12 中国人民解放军国防科技大学 一种用于分布式深度学习训练的本地更新方法
CN114548356A (zh) * 2020-11-27 2022-05-27 华为技术有限公司 一种机器学习方法、装置和系统
CN112561078B (zh) * 2020-12-18 2021-12-28 北京百度网讯科技有限公司 分布式的模型训练方法及相关装置
CN115633031B (zh) * 2022-09-06 2024-02-23 鹏城实验室 一种启发式指导的异步历史优化方法及相关设备
CN116070720B (zh) * 2023-03-23 2023-07-21 山东海量信息技术研究院 基于分布式集群的数据处理方法、系统、设备及存储介质

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US9269054B1 (en) * 2011-11-09 2016-02-23 BigML, Inc. Methods for building regression trees in a distributed computing environment
CN108829441B (zh) * 2018-05-14 2022-10-18 中山大学 一种分布式深度学习的参数更新优化系统
CN109600255A (zh) * 2018-12-04 2019-04-09 中山大学 一种去中心化的参数服务器优化算法

Also Published As

Publication number Publication date
CN110084378A (zh) 2019-08-02

Similar Documents

Publication Publication Date Title
CN110084378B (zh) 一种基于本地学习策略的分布式机器学习方法
CN110287031B (zh) 一种减少分布式机器学习通信开销的方法
US11514309B2 (en) Method and apparatus for accelerating distributed training of a deep neural network
CN109561148B (zh) 边缘计算网络中基于有向无环图的分布式任务调度方法
CN111444021B (zh) 基于分布式机器学习的同步训练方法、服务器及系统
CN102158417A (zh) 实现多约束QoS路由选择的优化方法及装置
CN111369009A (zh) 一种能容忍不可信节点的分布式机器学习方法
Lee et al. Accurate and fast federated learning via IID and communication-aware grouping
CN113313349B (zh) 卫星任务资源匹配优化方法、装置、存储介质和电子设备
CN115034615A (zh) 一种用于作业车间调度的提高遗传规划调度规则中特征选择效率的方法
CN113391907A (zh) 一种任务的放置方法、装置、设备和介质
Badri et al. A sample average approximation-based parallel algorithm for application placement in edge computing systems
Stützle et al. Automatic (offline) configuration of algorithms
CN103944748B (zh) 基于遗传算法的网络关键节点的自相似流量生成简化方法
CN114331709A (zh) 基于偏序的区块链多版本交易时序化方法及系统
CN110119268B (zh) 基于人工智能的工作流优化方法
CN111711702A (zh) 一种基于通信拓扑的分布式协同交互方法及系统
CN115766475B (zh) 基于通信效率的半异步电力联邦学习网络及其通信方法
CN117787440A (zh) 一种面向非独立同分布数据的车联网多阶段联邦学习方法
Liu et al. QoS multicast routing based on particle swarm optimization
CN115242838B (zh) 一种车载边缘计算中服务协同卸载的方法
CN112861315B (zh) 一种电力系统非凸单目标最优潮流全局解的一维下降搜索法
CN114330743A (zh) 一种用于最小-最大化问题的跨设备联邦学习方法
CN110135747B (zh) 基于神经网络的流程定制方法
Kaynar et al. The cross-entropy method with patching for rare-event simulation of large Markov chains

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant