CN111369009A - 一种能容忍不可信节点的分布式机器学习方法 - Google Patents

一种能容忍不可信节点的分布式机器学习方法 Download PDF

Info

Publication number
CN111369009A
CN111369009A CN202010143202.7A CN202010143202A CN111369009A CN 111369009 A CN111369009 A CN 111369009A CN 202010143202 A CN202010143202 A CN 202010143202A CN 111369009 A CN111369009 A CN 111369009A
Authority
CN
China
Prior art keywords
node
gradient
training
machine learning
buffer
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010143202.7A
Other languages
English (en)
Inventor
李武军
杨亦锐
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing University
Original Assignee
Nanjing University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing University filed Critical Nanjing University
Priority to CN202010143202.7A priority Critical patent/CN111369009A/zh
Publication of CN111369009A publication Critical patent/CN111369009A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开一种能容忍不可信节点的分布式机器学习方法,各工作节点从服务器节点获取最新的参数,根据本地存储的数据计算梯度后,将梯度发送给服务器节点,重复该步骤,直到收到服务器的中止消息。服务器节点设置有一定数量的缓冲器,每次接收到梯度信息后,根据发送方工作节点的编号,计算出对应的缓冲器编号,并将该缓冲器中的值更新为已收到对应该缓冲器的所有梯度的平均值;然后判断是否所有缓冲器都存有梯度,若是,则通过聚集函数,根据各缓冲器中的梯度计算出最终梯度,更新模型参数,清空所有缓冲器;再将最新参数发送给该工作节点;不断重复以上训练步骤,直到满足停止条件时,通知各个工作节点停止。

Description

一种能容忍不可信节点的分布式机器学习方法
技术领域
本发明涉及一种能容忍不可信节点的分布式机器学习方法,可以有效地减小分布式机器学习中不可信节点的错误梯度信息带来的负面影响,提升分布式机器学习的鲁棒性。
背景技术
许多机器学习模型可以被形式化为有限和优化问题:
Figure BDA0002399817540000011
其中w为模型的参数,n为训练样本的总数,ξi表示第i个样本,f(w;ξi)表示第i个样本所对应的损失函数,d为模型维度大小。随机梯度下降法(SGD)及其变体是目前应用最广的求解上述有限和优化问题的方法。
参数服务器架构(Parameter Server)是分布式机器学习中常用的一种架构。参数服务器架构包含了一个服务器节点集群以及多个工作节点集群。其中,服务器节点集群由多个服务器节点组成,每个服务器节点维护一部分全局共享参数。各个服务器节点可以彼此通信来复制和/或迁移参数以维持可靠性。一个工作节点集群通常在本地存储一部分训练样本,并利用本地存储的训练样本计算一些局部数据,比如梯度。各个工作节点之间无法直接通信,只能与服务器节点通信来更新和检索共享参数。
随着训练数据量的增大,机器学习训练过程的时间开销也将不断增加。分布式算法通过在多个节点上并行地进行训练来减少学习过程的时间开销。除此之外,在边缘计算、联邦学习这些应用中,训练数据存储在各个终端设备中,因为隐私保护的要求以及通信带宽的限制,服务器无法直接访问训练数据。在这些应用中,往往采用分布式机器学习算法。
在实现数据并行的随机梯度下降法时,工作节点使用不同的数据子集和本地模型副本并行地计算出梯度,并发送给服务器节点。中心化的参数服务器收集梯度,并用来更新参数,然后把更新后的参数发给工作节点。根据学习过程中各个节点是否需要在迭代轮次、模型参数、任务进度等方面保持一致,分布式机器学习算法可以分为同步算法和异步算法。在同步算法中,同步过程会带来额外的时间开销。另一方面,在边缘计算、联邦学习等应用中,大部分时间无法保证所有终端在线,甚至无法实现同步。因此,异步的分布式机器学习算法适用范围更广。
另外,传统的分布式机器学习方法大多假设所有节点是可信的,但在实际应用中,由于数据误标注、传输延迟、软硬件错误、恶意攻击等原因,部分工作节点可能不可信,而这些不可信工作节点向服务器发送的错误消息往往会使得参数错误地更新,导致方法失败。例如在图像识别、语音识别等有监督学习任务中,需要大量有标注数据。在人工标注过程中,不可避免地会出现误标注现象。这些误标注数据导致工作节点计算出错误的梯度消息。在这种情况下,工作节点可能是不可信的。除此之外,在边缘计算和联邦学习等应用中,服务器组织者对工作节点的控制很弱,因而难以保证工作节点的可信度。在这些工作节点可能不可信的情形下,目前已有的异步分布式机器学习,尚未能够较好地解决这一问题。
发明内容
发明目的:在目前的分布式异步机器学习中,若不可信工作节点中存在错误或者恶意攻击,会导致服务器节点收到一个错误的梯度,并将其用于模型参数更新,最终导致机器学习失败。针对上述问题与不足,提供一种能容忍不可信节点的分布式机器学习方法,基于梯度消息缓冲器,服务器节点收到梯度消息后,根据来源工作节点的编号,将其暂存到相应的缓冲器中,在所有缓冲器都存有梯度信息后,再根据聚集函数计算出最终梯度,用于更新参数。可以看出,本发明的方法中,服务器节点接收到梯度消息后,会将其暂存在缓冲器中,更新梯度时结合各个缓冲器中储存的梯度,通过聚集函数计算出最终梯度用于参数更新,因此能有效减小不可信工作节点发送的错误消息给算法带来的负面影响。
技术方案:一种能容忍不可信节点的分布式机器学习机制,其在服务器节点上训练流程的具体步骤为:
步骤100,输入机器学习模型w、工作节点数目m、样本总数n、学习率ηt、缓冲器数量B,以及聚集函数Aggr(·)和哈希函数hash(·);
步骤101,令t=0,初始化模型参数w=w0,并发送给所有工作节点;
步骤102,对于b=1,2,…,B,初始化缓冲器:hb=0,并令缓冲器中已存储梯度的计数器Nb=0;
步骤103,等待,直到收到来自任意工作节点的梯度信息g,并保存该工作节点的编号s;
步骤104,计算缓冲器编号:b=hash(s);
步骤105,更新对应缓冲器中的储存值:
Figure BDA0002399817540000031
步骤106,判断是否所有Nb>0(b=1,2,…,B),若否,则不进行参数更新,直接跳转到步骤110;
步骤107,计算最终梯度:G=Aggr([h1,h2,…,hB]);
步骤108,进行参数更新:wt+1=wtt·G,令t=t+1;
步骤109,清空所有缓冲器:对于b=1,2,…,B,令Nb=0,hb=0;
步骤110,将最新模型参数发送给编号s的工作节点;
步骤111,判断此时是否满足停止条件,若否,则返回步骤103继续训练;
步骤112,通知各个工作节点训练停止工作。
本发明的方法在第k个工作节点上训练流程的具体步骤为:
步骤200,输入训练样本集合的子集
Figure BDA0002399817540000032
(完整的训练样本集合
Figure BDA0002399817540000033
)、每次采样的批量大小l;
步骤201,接收服务器节点发送的模型参数w;
步骤202,从本地数据集
Figure BDA0002399817540000034
中随机挑选一个小批量数据
Figure BDA0002399817540000035
步骤203,根据挑选出的样本数据集
Figure BDA0002399817540000036
计算出随机梯度
Figure BDA0002399817540000037
其中
Figure BDA0002399817540000038
则表示第i个样本ξi所对应的损失函数在当前模型参数下的梯度;
步骤204,发送计算出的随机梯度g到服务器节点;
步骤205,判断是否收到服务器节点发送的停止工作消息,若否,则返回步骤201,继续训练;若是,则结束训练。
有益效果:本发明的方法是一种能容忍不可信节点的分布式机器学习方法,训练过程异步执行,既适用于数据中心的多机集群分布式机器学习,也适用于服务器作为云端、手机或嵌入式设备作为终端的端云协同分布式机器学习,包括边缘计算、联邦学习等应用。本方法利用服务器中设置的缓冲器以及聚集函数,能有效降低不可信节点发送的错误梯度信息(包括偶然错误和恶意攻击导致的错误)对分布式训练算法带来的负面影响,提升分布式机器学习的鲁棒性。
附图说明
图1为本发明实施的能容忍不可信节点的分布式机器学习方法在服务器节点上的工作流程图;
图2为本发明实施的能容忍不可信节点的分布式机器学习方法在工作节点上的工作流程图。
具体实施方式
下面结合具体实施例,进一步阐明本发明,应理解这些实施例仅用于说明本发明而不用于限制本发明的范围,在阅读了本发明之后,本领域技术人员对本发明的各种等价形式的修改均落于本申请所附权利要求所限定的范围。
本发明提供的能容忍不可信节点的分布式机器学习方法,既可应用于多机集群分布式机器学习,也可应用于边缘计算、联邦学习等应用,适合于待分类的数据集样本数多、所使用的机器学习模型参数量大的场景,也适用于数据分布在各个终端上,但由于种种原因无法发送训练数据的场景。本发明可适用于图像分类、文本分类、语音识别等多种任务。以图像分类为例,在本发明的方法中,训练数据存储在若干个工作节点上,而机器学习模型参数由若干个服务器节点共同维护,本发明方法在该应用中的具体工作流程如下:
能容忍不可信节点的分布式机器学习方法,在服务器节点上的工作流程如图1所示。首先输入机器学习模型w、工作节点数目m、样本总数n、学习率ηt、缓冲器数量B,以及聚集函数Aggr(·)和哈希函数hash(·)(步骤100);初始化迭代轮次计数t=0,初始化模型参数w=w0并发送模型参数w0到所有的工作节点(步骤101),初始化缓冲器,对于b=1,2,…,B,令Nb=0,hb=0(步骤102)。随后进入到模型训练的迭代阶段:先等待,直到收到来自任意工作节点的梯度信息g,保存该工作节点的编号s(步骤103);用哈希函数计算该工作节点对应的缓冲器编号:b=hash(s)(步骤104);并更新缓冲器加入该梯度后的储存值:Nb=Nb+1,
Figure BDA0002399817540000041
(步骤105)。之后进行判断,是否所有缓冲器都有储存值,即是否所有Nb>0(b=1,2,…,B)(步骤106),若是,则进入参数更新步骤:通过聚集函数计算最终梯度:G=Aggr([h1,h2,…,hB])(步骤107),用梯度下降法进行参数更新:wt+1=wtt·G,令t=t+1(步骤108),并在更新后清空缓冲器:对于b=1,2,…,B,令Nb=0,hb=0(步骤109)。之后,将最新模型发送回编号为s的工作节点(步骤110)。最后,判断此时是否满足停止条件(步骤111)。若否,则返回步骤103继续训练;若是,则通知各个工作节点训练停止工作(步骤112),结束训练。
能容忍不可信节点的分布式机器学习方法,在第k个工作节点上的工作流程如图2所示。首先输入训练样本集合的子集
Figure BDA0002399817540000051
(完整的训练样本集合
Figure BDA0002399817540000052
)和每次采样的批量大小l(步骤200),随后进入到模型训练阶段:接收服务器节点发送的模型参数w(步骤201),从本地数据集
Figure BDA0002399817540000053
中随机挑选一个小批量数据
Figure BDA0002399817540000054
(步骤202),根据挑选出的样本数据集
Figure BDA0002399817540000055
计算出随机梯度
Figure BDA0002399817540000056
(步骤203),其中
Figure BDA0002399817540000057
则表示第i个样本ξi所对应的损失函数在当前模型参数下的梯度。计算结束后,将梯度g发送到服务器节点(步骤204)。最后判断是否收到来自服务器的停止工作消息(步骤205)。若否,则跳转到步骤201继续训练;若是,则结束训练。
本发明的方法在图像分类数据集上进行了实验。实验过程中,统计了各个工作节点中训练模型精确度的平均值。实验结果表明,在无恶意攻击情况下,本发明的方法与传统的分布式异步随机梯度法精确度相同。在工作节点中存在恶意攻击时,传统的异步随机梯度法完全失效;而本发明提出的方法的预测精确度仅略微降低,能够抵抗不可信节点的恶意攻击,有效降低了不可信节点的错误梯度的负面影响,提升了分布式机器学习的鲁棒性。

Claims (4)

1.一种能容忍不可信节点的分布式机器学习方法,其特征在于,工作节点和服务器节点的主要任务分别如下:各工作节点从服务器节点获取最新的参数,根据本地存储的数据计算梯度后,将梯度发送给服务器节点,并不断重复该步骤,直到收到服务器的中止消息;服务器节点设置有一定数量的缓冲器,每次接收到梯度信息后,根据发送方工作节点的编号,计算出对应的缓冲器编号,并将该缓冲器中的值更新为已收到对应该缓冲器的所有梯度的平均值;然后判断是否所有缓冲器都存有梯度,若是,则通过聚集函数,根据各缓冲器中的梯度计算出最终梯度,更新模型参数,清空所有缓冲器;再将最新参数发送给该工作节点;不断重复以上训练步骤,直到满足停止条件时,通知各个工作节点停止。
2.如权利要求1所述的能容忍不可信节点的分布式机器学习方法,其特征在于,在服务器节点上训练流程的具体步骤为:
步骤100,输入机器学习模型w、工作节点数目m、样本总数n、学习率ηt、缓冲器数量B,以及聚集函数Aggr(·)和哈希函数hash(·);
步骤101,令t=t0,随机初始化模型参数w=w0,并发送给所有工作节点;
步骤102,对于b=1,2,…,B,初始化缓冲器:hb=0,并令缓冲器中已存储梯度的计数器Nb=0;
步骤103,等待,直到收到来自任意工作节点的梯度信息g,并保存该工作节点的编号s;
步骤104,计算缓冲器编号:b=hash(s);
步骤105,更新对应缓冲器中的储存值:
Figure FDA0002399817530000011
步骤106,判断是否所有Nb>0,若否,则不进行参数更新,直接跳转到步骤110;
步骤107,计算最终梯度:G=Aggr([h1,h2,…,hB]);
步骤108,进行参数更新:wt+1=wtt·G,令t=t+1;
步骤109,清空所有缓冲器:对于b=1,2,…,B,令Nb=0,hb=0;
步骤110,将最新模型参数发送给编号s的工作节点;
步骤111,判断此时是否满足停止条件,若否,则返回步骤103继续训练;
步骤112,通知各个工作节点训练停止工作。
3.如权利要求1所述的能容忍不可信节点的分布式机器学习方法,其特征在于,在第k个工作节点上训练流程的具体步骤为:
步骤200,输入训练样本集合的子集
Figure FDA0002399817530000021
和每次采样的批量大小l;
步骤201,接收服务器节点发送的模型参数w;
步骤202,从本地数据集
Figure FDA0002399817530000022
中随机挑选一个小批量数据
Figure FDA0002399817530000023
步骤203,根据挑选出的样本数据集
Figure FDA0002399817530000024
计算出随机梯度
Figure FDA0002399817530000025
其中
Figure FDA0002399817530000026
则表示第i个样本ξi所对应的损失函数在当前模型参数下的梯度;
步骤204,发送计算出的随机梯度g到服务器节点;
步骤205,判断是否收到服务器节点发送的停止工作消息,若否,则返回步骤201,继续训练;若是,则结束训练。
4.如权利要求2所述的能容忍不可信节点的分布式机器学习方法,其特征在于:步骤102以及步骤104-109中,生成若干个缓冲器,用于暂存服务器节点接收到的梯度信息,在所有缓冲器都已存有梯度信息后再通过聚集函数Aggr(·),计算出最终梯度,进行模型参数更新。
CN202010143202.7A 2020-03-04 2020-03-04 一种能容忍不可信节点的分布式机器学习方法 Pending CN111369009A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010143202.7A CN111369009A (zh) 2020-03-04 2020-03-04 一种能容忍不可信节点的分布式机器学习方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010143202.7A CN111369009A (zh) 2020-03-04 2020-03-04 一种能容忍不可信节点的分布式机器学习方法

Publications (1)

Publication Number Publication Date
CN111369009A true CN111369009A (zh) 2020-07-03

Family

ID=71208513

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010143202.7A Pending CN111369009A (zh) 2020-03-04 2020-03-04 一种能容忍不可信节点的分布式机器学习方法

Country Status (1)

Country Link
CN (1) CN111369009A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111709533A (zh) * 2020-08-19 2020-09-25 腾讯科技(深圳)有限公司 机器学习模型的分布式训练方法、装置以及计算机设备
CN111814968A (zh) * 2020-09-14 2020-10-23 北京达佳互联信息技术有限公司 用于机器学习模型的分布式训练的方法和装置
WO2022012129A1 (zh) * 2020-07-17 2022-01-20 华为技术有限公司 云服务系统的模型处理方法及云服务系统
CN114461392A (zh) * 2022-01-25 2022-05-10 西南交通大学 一种带宽感知的选择性数据多播方法
WO2022121804A1 (zh) * 2020-12-10 2022-06-16 华为技术有限公司 半异步联邦学习的方法和通信装置

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110084378A (zh) * 2019-05-07 2019-08-02 南京大学 一种基于本地学习策略的分布式机器学习方法
CN110287031A (zh) * 2019-07-01 2019-09-27 南京大学 一种减少分布式机器学习通信开销的方法

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110084378A (zh) * 2019-05-07 2019-08-02 南京大学 一种基于本地学习策略的分布式机器学习方法
CN110287031A (zh) * 2019-07-01 2019-09-27 南京大学 一种减少分布式机器学习通信开销的方法

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2022012129A1 (zh) * 2020-07-17 2022-01-20 华为技术有限公司 云服务系统的模型处理方法及云服务系统
CN111709533A (zh) * 2020-08-19 2020-09-25 腾讯科技(深圳)有限公司 机器学习模型的分布式训练方法、装置以及计算机设备
CN111814968A (zh) * 2020-09-14 2020-10-23 北京达佳互联信息技术有限公司 用于机器学习模型的分布式训练的方法和装置
CN111814968B (zh) * 2020-09-14 2021-01-12 北京达佳互联信息技术有限公司 用于机器学习模型的分布式训练的方法和装置
WO2022121804A1 (zh) * 2020-12-10 2022-06-16 华为技术有限公司 半异步联邦学习的方法和通信装置
CN114461392A (zh) * 2022-01-25 2022-05-10 西南交通大学 一种带宽感知的选择性数据多播方法
CN114461392B (zh) * 2022-01-25 2023-03-31 西南交通大学 一种带宽感知的选择性数据多播方法

Similar Documents

Publication Publication Date Title
CN111369009A (zh) 一种能容忍不可信节点的分布式机器学习方法
EP3540652B1 (en) Method, device, chip and system for training neural network model
US8904149B2 (en) Parallelization of online learning algorithms
CN108009642B (zh) 分布式机器学习方法和系统
US20180302297A1 (en) Methods and systems for controlling data backup
CN106919957B (zh) 处理数据的方法及装置
CN113760553B (zh) 一种基于蒙特卡洛树搜索的混部集群任务调度方法
CN110471621B (zh) 一种面向实时数据处理应用的边缘协同存储方法
CN110322931A (zh) 一种碱基识别方法、装置、设备及存储介质
WO2020236250A1 (en) Efficient freshness crawl scheduling
CN110414569A (zh) 聚类实现方法及装置
CN113887748B (zh) 在线联邦学习任务分配方法、装置、联邦学习方法及系统
CN109976873B (zh) 容器化分布式计算框架的调度方案获取方法及调度方法
US8756093B2 (en) Method of monitoring a combined workflow with rejection determination function, device and recording medium therefor
CN108595251B (zh) 动态图更新方法、装置、存储引擎接口和程序介质
CN112506658B (zh) 一种服务链中动态资源分配和任务调度方法
EP3896908B1 (en) Event stream processing system, event stream processing method, event stream processing program
CN106502842B (zh) 数据恢复方法及系统
CN116361271B (zh) 一种区块链数据修改迁移方法、电子设备及存储介质
Zhang et al. Txallo: Dynamic transaction allocation in sharded blockchain systems
US20210081878A1 (en) Automatic generation of tasks and retraining machine learning modules to generate tasks based on feedback for the generated tasks
CN110138723A (zh) 一种邮件网络中恶意社区的确定方法及系统
Goldsztajn et al. Utility maximizing load balancing policies
CN117290363B (zh) 一种面向救援活动的异构数据管理方法及系统
US11665110B1 (en) Using distributed services to continue or fail requests based on determining allotted time and processing time

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination