CN116070719B - 一种跨计算节点分布式训练高效通信方法及系统 - Google Patents

一种跨计算节点分布式训练高效通信方法及系统 Download PDF

Info

Publication number
CN116070719B
CN116070719B CN202310271228.3A CN202310271228A CN116070719B CN 116070719 B CN116070719 B CN 116070719B CN 202310271228 A CN202310271228 A CN 202310271228A CN 116070719 B CN116070719 B CN 116070719B
Authority
CN
China
Prior art keywords
local
global
distributed training
update amount
updating
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202310271228.3A
Other languages
English (en)
Other versions
CN116070719A (zh
Inventor
彭涵阳
秦爽
余跃
王进
王晖
李革
高文
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Peng Cheng Laboratory
Original Assignee
Peng Cheng Laboratory
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Peng Cheng Laboratory filed Critical Peng Cheng Laboratory
Priority to CN202310271228.3A priority Critical patent/CN116070719B/zh
Publication of CN116070719A publication Critical patent/CN116070719A/zh
Application granted granted Critical
Publication of CN116070719B publication Critical patent/CN116070719B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D30/00Reducing energy consumption in communication networks
    • Y02D30/70Reducing energy consumption in communication networks in wireless communication networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Cable Transmission Systems, Equalization Of Radio And Reduction Of Echo (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种跨计算节点分布式训练高效通信方法及系统,所述方法包括:在中心服务器上构建分布式训练机器学习模型;获取分布式训练机器学习模型中每个计算节点的本地更新量,并对本地更新量进行量化,得到量化后的本地更新量;根据量化后的本地更新量得到全局更新量,并对全局更新量进行量化,得到量化后的全局更新量;在各计算节点中,根据量化后的全局更新量更新分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型。本发明通过量化方法将计算节点间所需要通信的数据进行压缩以减少通讯数据量,不会影响最终收敛特性,从而减少通信时间,提高系统的整体训练效率。

Description

一种跨计算节点分布式训练高效通信方法及系统
技术领域
本发明涉及计算机深度学习技术领域,具体涉及一种跨计算节点分布式训练高效通信方法及系统。
背景技术
如今机器学习模型的规模越来越大,在单计算节点上训练大模型变得非常低效甚至变得不可能。超大规模智能模型在多计算节点甚至跨地域计算中心计算节点上进行分布式并行训练成为必然趋势。分布式并行训练模型的过程中,为保持最终模型的有效性,各计算节点需要频繁且大量地通信交换优化器所需要的数据,因此通信时间可能比各计算节点的本地计算时间更长,通信效率低下导致无法高效训练。
因此,现有技术还有待于改进和发展。
发明内容
本发明要解决的技术问题在于,针对现有技术的上述缺陷,提供一种跨计算节点分布式训练高效通信方法及系统,旨在解决现有技术中通信效率低下导致无法高效训练的问题。
本发明解决技术问题所采用的技术方案如下:
第一方面,本发明提供一种跨计算节点分布式训练高效通信方法,其中,所述方法包括:
在中心服务器上构建分布式训练机器学习模型;
获取所述分布式训练机器学习模型中每个计算节点的本地更新量,并对所述本地更新量进行量化,得到量化后的本地更新量;
根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量;
在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型。
在一种实现方式中,所述在中心服务器上构建分布式训练机器学习模型,包括:
构建所述分布式训练机器学习模型为
Figure SMS_1
其中,
Figure SMS_3
是所述分布式训练机器学习模型的d-维模型参数,/>
Figure SMS_5
是参数维度,/>
Figure SMS_7
是分布式计算节点的数目,/>
Figure SMS_4
表示第/>
Figure SMS_6
个计算节点标示,/>
Figure SMS_8
是在第/>
Figure SMS_9
个计算节点上随机采样的样本,/>
Figure SMS_2
表示损失函数。
在一种实现方式中,所述在中心服务器上构建分布式训练机器学习模型后,还包括:
初始化所述分布式训练机器学习模型的模型参数;其中,所述分布式训练机器学习模型上所有计算节点的模型参数
Figure SMS_11
都初始化为/>
Figure SMS_14
,学习率为/>
Figure SMS_17
,冲量因子固定为/>
Figure SMS_12
,第/>
Figure SMS_15
个计算节点上的本地冲量初始化为/>
Figure SMS_18
,/>
Figure SMS_19
,第/>
Figure SMS_10
个计算节点上的本地误差补偿为/>
Figure SMS_13
,全局误差补偿初始化为/>
Figure SMS_16
在一种实现方式中,所述获取所述分布式训练机器学习模型中每个计算节点的本地更新量,包括:
在第
Figure SMS_20
个计算节点上随机采样,得到样本/>
Figure SMS_21
根据所述样本
Figure SMS_22
,得到第/>
Figure SMS_23
个计算节点上的本地梯度为
Figure SMS_24
,其中/>
Figure SMS_25
为梯度算子,/>
Figure SMS_26
为更新时刻,/>
Figure SMS_27
为/>
Figure SMS_28
时刻的模型参数;
根据所述第
Figure SMS_29
个计算节点上的本地梯度,得到第/>
Figure SMS_30
个计算节点上的两个本地冲量为
Figure SMS_31
和/>
Figure SMS_32
,其中/>
Figure SMS_33
为冲量因子;
根据所述第
Figure SMS_34
个计算节点上的两个本地冲量,得到第/>
Figure SMS_35
个计算节点上的所述本地更新量为
Figure SMS_36
,其中/>
Figure SMS_37
,/>
Figure SMS_38
为所述本地冲量;
将在第
Figure SMS_39
个计算节点上的所述本地更新量加上本地误差补偿,更新所述本地更新量为
Figure SMS_40
,其中/>
Figure SMS_41
为本地误差补偿。
在一种实现方式中,所述对所述本地更新量进行量化,得到量化后的本地更新量,包括:
在第
Figure SMS_42
个计算节点上采用伯努利二值分布法将所述本地更新量进行量化,得到所述量化后的本地更新量为
Figure SMS_43
其中
Figure SMS_44
在一种实现方式中,所述对所述本地更新量进行量化,得到量化后的本地更新量之后,包括:
在第
Figure SMS_45
个计算节点上更新误差补偿,得到更新的误差补偿为
Figure SMS_46
在一种实现方式中,所述根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量,包括:
将各计算节点的所述量化后的本地更新量进行平均,得到所述全局更新量为
Figure SMS_47
将所述全局更新量加上全局误差补偿,更新所述全局更新量为
Figure SMS_48
,其中/>
Figure SMS_49
为所述全局误差补偿;
对所述全局更新量采用伯努利二值分布法进行量化,得到所述量化后的全局更新量为
Figure SMS_50
其中
Figure SMS_51
在一种实现方式中,所述根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量之后,还包括:
更新所述全局误差补偿,得到更新的全局误差补偿为
Figure SMS_52
在一种实现方式中,所述在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型,包括:
将所述量化后的全局更新量
Figure SMS_53
下发到各计算节点上;
在第
Figure SMS_54
个计算节点上更新所述模型参数为/>
Figure SMS_55
第二方面,本发明实施例还提供一种跨计算节点分布式训练高效通信装置,其中,所述装置包括:
模型构建模块,用于在中心服务器上构建分布式训练机器学习模型;
本地更新量量化模块,用于获取所述分布式训练机器学习模型中每个计算节点的本地更新量,并对所述本地更新量进行量化,得到量化后的本地更新量;
全局更新量量化模块,用于根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量;
模型更新模块,用于在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型。
在一种实现方式中,所述本地更新量量化模块包括:
本地更新量量化单元,用于在第
Figure SMS_56
个计算节点上采用伯努利二值分布法将所述本地更新量进行量化,得到所述量化后的本地更新量为
Figure SMS_57
其中,
Figure SMS_58
其中,
Figure SMS_59
为更新时刻,/>
Figure SMS_60
为/>
Figure SMS_61
时刻第/>
Figure SMS_62
个计算节点上的本地更新量,/>
Figure SMS_63
是参数维度。
第三方面,本发明实施例还提供一种跨计算节点分布式训练高效通信系统,所述系统包括中心服务器、多个计算节点以及在所述系统上运行的跨计算节点分布式训练高效通信程序,所述处理器执行所述跨计算节点分布式训练高效通信程序时,实现如以上任一项所述的跨计算节点分布式训练高效通信方法的步骤。
第四方面,本发明实施例还提供一种计算机可读存储介质,其中,所述计算机可读存储介质上存储有跨计算节点分布式训练高效通信程序,所述跨计算节点分布式训练高效通信程序被处理器执行时,实现如以上任一项所述的跨计算节点分布式训练高效通信方法的步骤。
有益效果:与现有技术相比,本发明提供了一种跨计算节点分布式训练高效通信方法,首先分布式训练机器学习模型,并获取所述分布式训练机器学习模型中每个计算节点的本地更新量,然后对所述本地更新量进行量化,得到量化后的本地更新量。通过对本地更新量进行量化,可将每一次迭代步中计算节点间所需要通信的数据从32比特量化压缩到1比特,而不会影响最终收敛特性,从而减小通信时间,提高系统的整体训练效率。然后,对所述全局更新量进行量化,得到量化后的全局更新量以更新分布式训练机器学习模型,通过进一步压缩下发全局更新量时通信数据的比特值,提高通讯效率,以保证高效训练。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明中记载的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例提供的跨计算节点分布式训练高效通信方法流程示意图。
图2是本发明实施例提供的SGD和BinSGD在IMAGENET上训练ResNet-50网络时损失函数趋势图。
图3是本发明实施例提供的跨计算节点分布式训练高效通信装置的原理框图。
图4是本发明实施例提供的跨计算节点分布式训练高效通信系统的内部结构原理框图。
具体实施方式
为使本发明的目的、技术方案及效果更加清楚、明确,以下参照附图并举实施例对本发明进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本发明,并不用于限定本发明。
本技术领域技术人员可以理解,除非特意声明,这里使用的单数形式“一”、“一个”、“所述”和“该”也可包括复数形式。应该进一步理解的是,本发明的说明书中使用的措辞“包括”是指存在所述特征、整数、步骤、操作、元件和/或组件,但是并不排除存在或添加一个或多个其他特征、整数、步骤、操作、元件、组件和/或它们的组。应该理解,当我们称元件被“连接”或“耦接”到另一元件时,它可以直接连接或耦接到其他元件,或者也可以存在中间元件。此外,这里使用的“连接”或“耦接”可以包括无线连接或无线耦接。这里使用的措辞“和/或”包括一个或更多个相关联的列出项的全部或任一单元和全部组合。
本技术领域技术人员可以理解,除非另外定义,这里使用的所有术语(包括技术术语和科学术语),具有与本发明所属领域中的普通技术人员的一般理解相同的意义。还应该理解的是,诸如通用字典中定义的那些术语,应该被理解为具有与现有技术的上下文中的意义一致的意义,并且除非像这里一样被特定定义,否则不会用理想化或过于正式的含义来解释。
数据并行是指在分布式系统中不同计算节点同时并行运行训练同一批数据的不同子集并且在每次迭代步中都需要将不同计算节点间需要通信以聚合所有计算节点的梯度,它是分布式并行训练的最基本的一种并行技术。在保证最终训练后的评价指标无显著下降的前提下,将数据并行过程中将各计算节点和参数服务器之间通信交换数据进行压缩,可大大降低通信数据量,减少通信时间,提高训练效率。如今机器学习模型的规模越来越大,在单计算节点上训练大模型变得非常低效甚至变得不可能。超大规模智能模型在多计算节点甚至跨地域计算中心计算节点上进行分布式并行训练成为必然趋势。分布式并行训练模型的过程中,为保持最终模型的有效性,各计算节点需要频繁且大量地通信交换优化器所需要的数据,因此通信时间可能比各计算节点的本地计算时间更长,成为高效训练的瓶颈。
为了解决上述问题,本实施例提供了一种跨计算节点分布式训练高效通信方法,首先分布式训练机器学习模型,并获取所述分布式训练机器学习模型中每个计算节点的本地更新量,然后对所述本地更新量进行量化,得到量化后的本地更新量。通过对本地更新量进行量化,可将每一次迭代步中计算节点间所需要通信的数据从32比特量化压缩到1比特,而不会影响最终收敛特性,从而减小通信时间,提高系统的整体训练效率。然后,对所述全局更新量进行量化,得到量化后的全局更新量以更新分布式训练机器学习模型,通过进一步压缩下发全局更新量时通信数据的比特值,提高通讯效率,以保证高效训练。
示例性方法
本实施例提供一种跨计算节点分布式训练高效通信方法。如图1所示,所述方法包括如下步骤:
步骤S100、在中心服务器上构建分布式训练机器学习模型为
Figure SMS_64
其中,
Figure SMS_66
是所述分布式训练机器学习模型的d-维模型参数,/>
Figure SMS_69
是参数维度,/>
Figure SMS_71
是分布式计算节点的数目,/>
Figure SMS_67
表示第/>
Figure SMS_68
个计算节点标示,/>
Figure SMS_70
是在第/>
Figure SMS_72
个计算节点上随机采样的样本,/>
Figure SMS_65
表示损失函数。
具体地,分布式计算是一种多计算节点协同计算的方法,和集中式计算是相对的。随着计算技术的发展,有些应用需要非常巨大的计算能力才能完成,如果采用集中式计算,需要耗费相当长的时间来完成。分布式计算将该应用分解成许多小的部分,分配给多台计算机进行处理。这样可以节约整体计算时间,大大提高计算效率。本发明构建的分布式训练机器学习模型可用于图像处理、卫星遥感、气象预测与数据分析等多个领域。
需要注意的是,本发明中所述的分布式训练机器学习模型包括常规的各计算节点和中心参数服务器通信的主从通信拓扑结构,也包括其它非主从式通信拓扑结点的通信上下行拓扑结构。
在一种实现方式中,所述步骤S100之后包括:
步骤M100、初始化所述分布式训练机器学习模型的模型参数;其中,所述分布式训练机器学习模型上所有计算节点的模型参数
Figure SMS_74
都初始化为/>
Figure SMS_77
,学习率为/>
Figure SMS_80
,冲量因子固定为/>
Figure SMS_75
,第/>
Figure SMS_78
个计算节点上的本地冲量初始化为/>
Figure SMS_81
,/>
Figure SMS_82
第/>
Figure SMS_73
个计算节点上的本地误差补偿为/>
Figure SMS_76
,全局误差补偿初始化为/>
Figure SMS_79
具体地,在分布式训练机器学习模型的优化器中,每一次迭代步中每一个计算节点的计算过程中所需要的要素为:权重参数,学习率和更新量。其中权重参数和学习率每一个计算节点的本地都会维护一个副本,不需要计算节点通信。
举例说明,本发明可应用于图像处理,在数据集IMAGENET上训练ResNet-50网络。训练的具体参数如下所示,共有8台计算节点服务器,每台服务器上有8个Nvidia-A100GPU,计算节点服务器之间用10Gbps的以太网连接。训练集中的数据是图像,我们将图像的分辨率设置为224X224,每个GPU上每一轮训练放置32张图像。我们将全精度的随机梯度下降法(StochasticGradientDescent,SGD)作为对比基准,本实施例提出的算法命名为二值随机梯度下降法(BinaryStochasticGradientDescent,BinSGD)。SGD的初始学习率为0.2,学习率在30,60,90epoch时分别减小10倍,冲量因子
Figure SMS_83
设置为0.9,权重衰减(WeightDecay)设置为0.0001。BinSGD的初始学习率为0.002,学习率在30,60,90epoch时分别减小10倍,冲量因子/>
Figure SMS_84
设置为0.95,权重衰减(WeightDecay)设置为0.1。
步骤S200、获取所述分布式训练机器学习模型中每个计算节点的本地更新量,并对所述本地更新量进行量化,得到量化后的本地更新量;
具体地,本地更新量是各个计算节点分别计算出的各自更新量的平均值,需要各计算节点通信交换才能得到。若本地更新量的数值过大,会导致各计算节点需要频繁且大量地通信交换优化器所需要的数据,因此通信时间可能比各计算节点的本地计算时间更长,成为高效训练的瓶颈。在本发明的分布式训练机器学习模型中,各计算节点的本地更新量在通信前的各元素通过伯努利概率分布随机将元素的数值进行量化,以达到减小通信数据量的效果。
在一种实现方式中,所述步骤S200具体包括:
步骤S201、在第
Figure SMS_85
个计算节点上随机采样,得到样本/>
Figure SMS_86
步骤S202、根据所述样本
Figure SMS_87
,得到第/>
Figure SMS_88
个计算节点上的本地梯度为
Figure SMS_89
,其中/>
Figure SMS_90
为梯度算子,/>
Figure SMS_91
为更新时刻,/>
Figure SMS_92
为/>
Figure SMS_93
时刻的模型参数;
步骤S203、根据所述第
Figure SMS_94
个计算节点上的本地梯度,得到第/>
Figure SMS_95
个计算节点上的两个本地冲量为
Figure SMS_96
和/>
Figure SMS_97
,其中/>
Figure SMS_98
为冲量因子;
步骤S204、根据所述第
Figure SMS_99
个计算节点上的两个本地冲量,得到第/>
Figure SMS_100
个计算节点上的所述本地更新量为
Figure SMS_101
,其中/>
Figure SMS_102
,/>
Figure SMS_103
为所述本地冲量;
具体地,步骤S203可以保证
Figure SMS_104
Figure SMS_105
中的对应元素/>
Figure SMS_106
和/>
Figure SMS_107
恒有/>
Figure SMS_108
,因此步骤S204中的/>
Figure SMS_109
中的元素/>
Figure SMS_110
值一定在[-1,1]之间。
步骤S205、将在第
Figure SMS_111
个计算节点上的所述本地更新量加上本地误差补偿,更新所述本地更新量为
Figure SMS_112
,其中/>
Figure SMS_113
为本地误差补偿。
具体地,在第
Figure SMS_114
个计算节点上的所述本地更新量上添加本地误差补偿可以使模型在训练过程中收敛更快,最终的推断性能更好。本地误差补偿可根据本地更新量进行更新。
步骤S206、在第
Figure SMS_115
个计算节点上采用伯努利二值分布法将所述本地更新量进行量化,得到所述量化后的本地更新量为
Figure SMS_116
其中
Figure SMS_117
具体地,
Figure SMS_118
的值一定在[0,1]之间,因此可以直接利用伯努利二值分布随机将/>
Figure SMS_119
中的元素/>
Figure SMS_120
量化到1或者-1。本地更新量量化后的数据的期望值和未量化的数据相等,也就是说在此过程中只是带来了方差。在应用全精度随机梯度下降优化算法在小批量数据上训练也会带来方差,而且此方差一般来说比本发明中的本地更新量量化带来的方差更大,因此本专利提出的算法对收敛速率影响较小。
在一种实现方式中,所述步骤S200之后包括:
步骤M200、在第
Figure SMS_121
个计算节点上更新误差补偿,得到更新的误差补偿为
Figure SMS_122
步骤S300、根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量;
在一种实现方式中,所述步骤S300具体包括:
步骤S301、将各计算节点的所述量化后的本地更新量进行平均,得到所述全局更新量为
Figure SMS_123
步骤S302、将所述全局更新量加上全局误差补偿,更新所述全局更新量为
Figure SMS_124
,其中/>
Figure SMS_125
为所述全局误差补偿;
步骤S303、对所述全局更新量采用伯努利二值分布法进行量化,得到所述量化后的全局更新量为
Figure SMS_126
其中
Figure SMS_127
具体地,各计算节点将量化后的本地更新量数据通过上传通信链路上传到参数服务器。在参数服务器上,将接收到的各计算节点更新量数据进行平均,此时平均更新量的各元素的数值一定在[-1,1]之间,然后再次通过伯努利概率分布随机将元素的数值量化到+1或者-1。全局更新量量化后的数据的期望值和未量化的数据相等,也就是说在此过程中只是带来了方差。在应用全精度随机梯度下降优化算法在小批量数据上训练也会带来方差,而且此方差一般来说比本发明中的全局更新量量化带来的方差更大,因此本专利提出的算法对收敛速率影响较小。
举例说明,本实施中步骤M100中所述的分布式训练机器学习模型,具体收敛特性如图2所示,虽然与全精度32比特的SGD相比,BinSGD将计算节点服务器间的通信数据量化到1比特,计算节点间的通信量直接减少了32倍,但是BinSGD和SGD的收敛速率是基本相当的,从而从实践证明了BinSGD的有效性。
在一种实现方式中,所述步骤S300之后包括:
步骤M300、更新所述全局误差补偿,得到更新的全局误差补偿为
Figure SMS_128
具体地,误差补偿就是人为地造出一种新的原始误差去抵消当前成为问题的原有的原始误差,并应尽量使两者大小相等,方向相反,从而达到减少加工误差,提高加工精度的目的。本实施例中,添加全局误差补偿可以使模型在训练过程中收敛更快,最终的推断性能更好。
步骤S400、在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型。
将所述量化后的全局更新量
Figure SMS_129
下发到各计算节点上;
在第
Figure SMS_130
个计算节点上更新所述模型参数为/>
Figure SMS_131
具体地,参数服务器再通过下载通信链路将更新量数据再下发到各计算节点上。最后,各计算节点优化器的所需要所有要素后进行一次迭代计算。经过优化的分布式训练机器学习模型,可以将每一次迭代步中计算节点间所需要通信的数据从32比特量化压缩到1比特,而不会影响最终收敛特性,从而减小通信时间,提高系统的整体训练效率。
需要注意的是,本方法中的通信拓扑结构,除常规的各计算节点和中心服务器通信的主从通信拓扑结构之外,还包括其它按本方法的量化方式的非主从式通信拓扑结点的通信上下行拓扑结构。
示例性装置
如图3中所示,本实施例还提供一种跨计算节点分布式训练高效通信装置,所述装置包括:
模型构建模块10,用于在中心服务器上构建分布式训练机器学习模型;
本地更新量量化模块20,用于获取所述分布式训练机器学习模型中每个计算节点的本地更新量,并对所述本地更新量进行量化,得到量化后的本地更新量;
全局更新量量化模块30,用于根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量;
模型更新模块40,用于在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型。
在一种实现方式中,所述模型构建模块10包括:
模型构建单元,用于构建所述分布式训练机器学习模型为
Figure SMS_132
其中,
Figure SMS_133
是所述分布式训练机器学习模型的d-维模型参数,/>
Figure SMS_135
是参数维度,/>
Figure SMS_138
是分布式计算节点的数目,/>
Figure SMS_134
表示第/>
Figure SMS_137
个计算节点标示,/>
Figure SMS_139
是在第/>
Figure SMS_140
个计算节点上随机采样的样本,/>
Figure SMS_136
表示损失函数。
在一种实现方式中,所述装置还包括:
初始化单元,用于初始化所述分布式训练机器学习模型的模型参数;其中,所述分布式训练机器学习模型上所有计算节点的模型参数
Figure SMS_142
都初始化为/>
Figure SMS_145
,学习率为/>
Figure SMS_149
,冲量因子固定为/>
Figure SMS_143
,第/>
Figure SMS_146
个计算节点上的本地冲量初始化为/>
Figure SMS_148
,/>
Figure SMS_150
,第/>
Figure SMS_141
个计算节点上的本地误差补偿为/>
Figure SMS_144
,全局误差补偿初始化为/>
Figure SMS_147
在一种实现方式中,所述本地更新量量化模块20包括:
采样单元,用于在第
Figure SMS_151
个计算节点上随机采样,得到样本/>
Figure SMS_152
本地梯度获取单元,用于根据所述样本
Figure SMS_153
,得到第/>
Figure SMS_154
个计算节点上的本地梯度为
Figure SMS_155
,其中/>
Figure SMS_156
为梯度算子,/>
Figure SMS_157
为更新时刻,/>
Figure SMS_158
为/>
Figure SMS_159
时刻的模型参数;
本地冲量获取单元,用于根据所述第
Figure SMS_160
个计算节点上的本地梯度,得到第/>
Figure SMS_161
个计算节点上的两个本地冲量为
Figure SMS_162
和/>
Figure SMS_163
,其中/>
Figure SMS_164
为冲量因子;
本地更新量获取单元,用于根据所述第
Figure SMS_165
个计算节点上的两个本地冲量,得到第/>
Figure SMS_166
个计算节点上的所述本地更新量为
Figure SMS_167
,其中/>
Figure SMS_168
,/>
Figure SMS_169
为所述本地冲量;
本地更新量更新单元,用于将在第
Figure SMS_170
个计算节点上的所述本地更新量加上本地误差补偿,更新所述本地更新量为
Figure SMS_171
,其中/>
Figure SMS_172
为本地误差补偿。
本地更新量量化单元,用于在第
Figure SMS_173
个计算节点上采用伯努利二值分布法将所述本地更新量进行量化,得到所述量化后的本地更新量为
Figure SMS_174
其中
Figure SMS_175
在一种实现方式中,所述装置还包括:
第一误差补偿单元,用于在第
Figure SMS_176
个计算节点上更新误差补偿,得到更新的误差补偿为
Figure SMS_177
在一种实现方式中,所述全局更新量量化模块30包括:
全局更新量获取单元,用于将各计算节点的所述量化后的本地更新量进行平均,得到所述全局更新量为
Figure SMS_178
全局更新量更新单元,用于将所述全局更新量加上全局误差补偿,更新所述全局更新量为
Figure SMS_179
,其中/>
Figure SMS_180
为所述全局误差补偿;
全局更新量量化单元,用于对所述全局更新量采用伯努利二值分布法进行量化,得到所述量化后的全局更新量为
Figure SMS_181
其中
Figure SMS_182
在一种实现方式中,所述装置还包括:
第二误差补偿单元,用于更新所述全局误差补偿,得到更新的全局误差补偿为
Figure SMS_183
在一种实现方式中,所述模型更新模块40,包括:
数据下发单元,用于将所述量化后的全局更新量
Figure SMS_184
下发到各计算节点上;
模型参数更新单元,用于在第
Figure SMS_185
个计算节点上更新所述模型参数为
Figure SMS_186
在一个实施例中,如图4所示,提供一种跨计算节点分布式训练高效通信系统,所述系统包括中心服务器、多个计算节点以及在所述系统上运行的跨计算节点分布式训练高效通信程序,所述处理器执行所述跨计算节点分布式训练高效通信程序时,实现如下操作指令:
在中心服务器上构建分布式训练机器学习模型;
获取所述分布式训练机器学习模型中每个计算节点的本地更新量,并对所述本地更新量进行量化,得到量化后的本地更新量;
根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量;
在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本发明所提供的各实施例中所使用的对存储器、存储、运营数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双运营数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
综上,本发明公开了一种跨计算节点分布式训练高效通信方法及系统,所述方法包括:在中心服务器上构建分布式训练机器学习模型;获取分布式训练机器学习模型中每个计算节点的本地更新量,并对本地更新量进行量化,得到量化后的本地更新量;根据量化后的本地更新量得到全局更新量,并对全局更新量进行量化,得到量化后的全局更新量;在各计算节点上,根据量化后的全局更新量更新分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型。本发明通过量化方法将计算节点间所需要通信的数据进行压缩以减少通讯数据量,不会影响最终收敛特性,从而减少通信时间,提高系统的整体训练效率。
最后应说明的是:以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。

Claims (10)

1.一种跨计算节点分布式训练高效通信方法,其特征在于,所述方法包括:
在中心服务器上构建分布式训练机器学习模型;
获取所述分布式训练机器学习模型中每个计算节点的本地更新量,并对所述本地更新量进行量化,得到量化后的本地更新量;
根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量;
在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型;
所述在中心服务器上构建分布式训练机器学习模型,包括:
在中心服务器上构建所述分布式训练机器学习模型为
Figure QLYQS_1
其中,
Figure QLYQS_2
是所述分布式训练机器学习模型的/>
Figure QLYQS_3
-维模型参数,/>
Figure QLYQS_4
是参数维度,/>
Figure QLYQS_5
是分布式计算节点的数目,/>
Figure QLYQS_6
表示第i个计算节点标示,/>
Figure QLYQS_7
是在第i个计算节点上随机采样的样本,/>
Figure QLYQS_8
表示损失函数;
所述在中心服务器上构建分布式训练机器学习模型后,还包括:
初始化所述分布式训练机器学习模型的模型参数和训练参数;其中,所述分布式训练机器学习模型上所有计算节点的模型参数
Figure QLYQS_10
都初始化为/>
Figure QLYQS_13
,学习率为/>
Figure QLYQS_15
,冲量因子固定为
Figure QLYQS_11
,第i个计算节点上的本地冲量初始化为/>
Figure QLYQS_12
,/>
Figure QLYQS_14
,第i个计算节点上的本地误差补偿为/>
Figure QLYQS_16
,全局误差补偿初始化为/>
Figure QLYQS_9
将所述模型参数和训练参数通过通信链路下传到各计算节点;
所述获取所述分布式训练机器学习模型中每个计算节点的本地更新量,包括:
在第i个计算节点上随机采样,得到样本
Figure QLYQS_17
根据所述样本
Figure QLYQS_18
,得到第i个计算节点上的本地梯度为
Figure QLYQS_19
,其中/>
Figure QLYQS_20
为梯度算子,/>
Figure QLYQS_21
为更新时刻,/>
Figure QLYQS_22
为/>
Figure QLYQS_23
时刻的模型参数;
根据所述第i个计算节点上的本地梯度,得到第i个计算节点上的两个本地冲量为
Figure QLYQS_24
和/>
Figure QLYQS_25
,其中/>
Figure QLYQS_26
为冲量因子;
根据所述第i个计算节点上的两个本地冲量,得到第i个计算节点上的所述本地更新量为
Figure QLYQS_27
,其中/>
Figure QLYQS_28
,/>
Figure QLYQS_29
为所述本地冲量;
将在第i个计算节点上的所述本地更新量加上本地误差补偿,更新所述本地更新量为
Figure QLYQS_30
,其中/>
Figure QLYQS_31
为本地误差补偿。
2.根据权利要求1所述的跨计算节点分布式训练高效通信方法,其特征在于,所述对所述本地更新量进行量化,得到量化后的本地更新量,包括:
在第i个计算节点上采用伯努利二值分布法将所述本地更新量进行量化,得到所述量化后的本地更新量为
Figure QLYQS_32
其中
Figure QLYQS_33
3.根据权利要求2所述的跨计算节点分布式训练高效通信方法,其特征在于,所述对所述本地更新量进行量化,得到量化后的本地更新量之后,包括:
在第i个计算节点上更新误差补偿,得到更新的误差补偿为
Figure QLYQS_34
4.根据权利要求2所述的跨计算节点分布式训练高效通信方法,其特征在于,所述根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量,包括:
在中心服务器上将各计算节点的所述量化后的本地更新量进行平均,得到所述全局更新量为
Figure QLYQS_35
将所述全局更新量加上全局误差补偿,更新所述全局更新量为
Figure QLYQS_36
,其中/>
Figure QLYQS_37
为所述全局误差补偿;
对所述全局更新量采用伯努利二值分布法进行量化,得到中心服务器上的所述量化后的全局更新量为
Figure QLYQS_38
其中
Figure QLYQS_39
5.根据权利要求4所述的跨计算节点分布式训练高效通信方法,其特征在于,所述根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量之后,还包括:
更新所述全局误差补偿,得到更新的全局误差补偿为
Figure QLYQS_40
6.根据权利要求4所述的跨计算节点分布式训练高效通信方法,其特征在于,所述在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型,包括:
将中心服务器上的所述量化后的全局更新量
Figure QLYQS_41
下发到各计算节点上;
在第i个计算节点上更新所述模型参数为
Figure QLYQS_42
7.一种跨计算节点分布式训练高效通信装置,其特征在于,所述装置包括:
模型构建模块,用于在中心服务器上构建分布式训练机器学习模型;
本地更新量量化模块,用于获取所述分布式训练机器学习模型中每个计算节点的本地更新量,并对所述本地更新量进行量化,得到量化后的本地更新量;
全局更新量量化模块,用于根据所述量化后的本地更新量得到全局更新量,并对所述全局更新量进行量化,得到量化后的全局更新量;
模型更新模块,用于在各计算节点中,根据所述量化后的全局更新量更新所述分布式训练机器学习模型参数,得到更新后的分布式训练机器学习模型;
所述模型构建模块包括:
模型构建单元,用于在中心服务器上构建所述分布式训练机器学习模型为
Figure QLYQS_43
其中,
Figure QLYQS_44
是所述分布式训练机器学习模型的/>
Figure QLYQS_45
-维模型参数,/>
Figure QLYQS_46
是参数维度,/>
Figure QLYQS_47
是分布式计算节点的数目,/>
Figure QLYQS_48
表示第i个计算节点标示,/>
Figure QLYQS_49
是在第i个计算节点上随机采样的样本,/>
Figure QLYQS_50
表示损失函数;
所述装置还包括:
初始化单元,用于初始化所述分布式训练机器学习模型的模型参数和训练参数;其中,所述分布式训练机器学习模型上所有计算节点的模型参数
Figure QLYQS_52
都初始化为/>
Figure QLYQS_54
,学习率为/>
Figure QLYQS_56
,冲量因子固定为/>
Figure QLYQS_53
,第i个计算节点上的本地冲量初始化为/>
Figure QLYQS_55
,/>
Figure QLYQS_57
,第i个计算节点上的本地误差补偿为/>
Figure QLYQS_58
,全局误差补偿初始化为/>
Figure QLYQS_51
;将所述模型参数和训练参数通过通信链路下传到各计算节点;
所述本地更新量量化模块包括:
采样单元,用于在第i个计算节点上随机采样,得到样本
Figure QLYQS_59
本地梯度获取单元,用于根据所述样本
Figure QLYQS_60
,得到第i个计算节点上的本地梯度为
Figure QLYQS_61
,其中/>
Figure QLYQS_62
为梯度算子,/>
Figure QLYQS_63
为更新时刻,/>
Figure QLYQS_64
为/>
Figure QLYQS_65
时刻的模型参数;
本地冲量获取单元,用于根据所述第i个计算节点上的本地梯度,得到第i个计算节点上的两个本地冲量为
Figure QLYQS_66
和/>
Figure QLYQS_67
,其中/>
Figure QLYQS_68
为冲量因子;
本地更新量获取单元,用于根据所述第i个计算节点上的两个本地冲量,得到第i个计算节点上的所述本地更新量为
Figure QLYQS_69
,其中/>
Figure QLYQS_70
,/>
Figure QLYQS_71
为所述本地冲量;
本地更新量更新单元,用于将在第i个计算节点上的所述本地更新量加上本地误差补偿,更新所述本地更新量为
Figure QLYQS_72
,其中/>
Figure QLYQS_73
为本地误差补偿。
8.根据权利要求7所述的跨计算节点分布式训练高效通信装置,其特征在于,所述本地更新量量化模块包括:
本地更新量量化单元,用于在第i个计算节点上采用伯努利二值分布法将本地更新量进行量化,得到量化后的本地更新量为
Figure QLYQS_74
其中,
Figure QLYQS_75
其中,
Figure QLYQS_76
为更新时刻,/>
Figure QLYQS_77
为/>
Figure QLYQS_78
时刻第i个计算节点上的本地更新量,/>
Figure QLYQS_79
是参数维度。
9.一种跨计算节点分布式训练高效通信系统,其特征在于,所述系统包括中心服务器、多个计算节点以及在所述系统上运行的跨计算节点分布式训练高效通信程序,所述系统执行所述跨计算节点分布式训练高效通信程序时,实现如权利要求1-6任一项所述的跨计算节点分布式训练高效通信方法的步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有跨计算节点分布式训练高效通信程序,所述跨计算节点分布式训练高效通信程序被处理器执行时,实现如权利要求1-6任一项所述的跨计算节点分布式训练高效通信方法的步骤。
CN202310271228.3A 2023-03-20 2023-03-20 一种跨计算节点分布式训练高效通信方法及系统 Active CN116070719B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310271228.3A CN116070719B (zh) 2023-03-20 2023-03-20 一种跨计算节点分布式训练高效通信方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310271228.3A CN116070719B (zh) 2023-03-20 2023-03-20 一种跨计算节点分布式训练高效通信方法及系统

Publications (2)

Publication Number Publication Date
CN116070719A CN116070719A (zh) 2023-05-05
CN116070719B true CN116070719B (zh) 2023-07-14

Family

ID=86180462

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310271228.3A Active CN116070719B (zh) 2023-03-20 2023-03-20 一种跨计算节点分布式训练高效通信方法及系统

Country Status (1)

Country Link
CN (1) CN116070719B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117035123B (zh) * 2023-10-09 2024-01-09 之江实验室 一种并行训练中的节点通信方法、存储介质、设备

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109754060A (zh) * 2017-11-06 2019-05-14 阿里巴巴集团控股有限公司 一种神经网络机器学习模型的训练方法及装置
CN113886460A (zh) * 2021-09-26 2022-01-04 中国空间技术研究院 低带宽分布式深度学习方法

Family Cites Families (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2020018394A1 (en) * 2018-07-14 2020-01-23 Moove.Ai Vehicle-data analytics
CN114761974A (zh) * 2019-09-24 2022-07-15 华为技术有限公司 用于量化神经网络的权重和输入的训练方法
CN113128696A (zh) * 2019-12-31 2021-07-16 香港理工大学深圳研究院 分布式机器学习通信优化方法、装置、服务器及终端设备
CN111382844B (zh) * 2020-03-11 2023-07-07 华南师范大学 一种深度学习模型的训练方法及装置
CN112288097B (zh) * 2020-10-29 2024-04-02 平安科技(深圳)有限公司 联邦学习数据处理方法、装置、计算机设备及存储介质
CN113591145B (zh) * 2021-07-28 2024-02-23 西安电子科技大学 基于差分隐私和量化的联邦学习全局模型训练方法
CN115033878A (zh) * 2022-08-09 2022-09-09 中国人民解放军国防科技大学 快速自博弈强化学习方法、装置、计算机设备和存储介质

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109754060A (zh) * 2017-11-06 2019-05-14 阿里巴巴集团控股有限公司 一种神经网络机器学习模型的训练方法及装置
CN113886460A (zh) * 2021-09-26 2022-01-04 中国空间技术研究院 低带宽分布式深度学习方法

Also Published As

Publication number Publication date
CN116070719A (zh) 2023-05-05

Similar Documents

Publication Publication Date Title
EP3660754A1 (en) Communication efficient federated learning
Durham et al. Adaptive sequential posterior simulators for massively parallel computing environments
CN116070719B (zh) 一种跨计算节点分布式训练高效通信方法及系统
CN106022521B (zh) 基于Hadoop架构的分布式BP神经网络的短期负荷预测方法
EP3504666A1 (en) Asychronous training of machine learning model
CN112183750A (zh) 神经网络模型训练方法、装置、计算机设备及存储介质
CN110782030A (zh) 深度学习权值更新方法、系统、计算机设备及存储介质
CN113282470B (zh) 一种性能预测方法及装置
CN113111576A (zh) 一种基于混合编码粒子群-长短期记忆神经网络出水氨氮软测量方法
CN116401238A (zh) 偏离度监测方法、装置、设备、存储介质和程序产品
US20200143282A1 (en) Quantizing machine learning models with balanced resolution via damped encoding
CN115392348A (zh) 联邦学习梯度量化方法、高效通信联邦学习方法及相关装置
US20230072535A1 (en) Error mitigation for sampling on quantum devices
CN109992631B (zh) 一种动态异质信息网络嵌入方法、装置和电子设备
CN111240606A (zh) 一种基于安全内存的存储优化方法及系统
CN113610709A (zh) 模型量化方法、装置、电子设备和计算机可读存储介质
CN114584476A (zh) 一种流量预测方法、网络训练方法、装置及电子设备
CN110929849A (zh) 一种神经网络模型的压缩方法和装置
CN110764696B (zh) 向量信息存储及更新的方法、装置、电子设备及存储介质
Aufiero et al. Surrogate-based global sensitivity analysis with statistical guarantees via floodgate
Barreiro‐Ures et al. Analysis of interval‐grouped data in weed science: The binnednp Rcpp package
CN115577618B (zh) 高压换流阀厅环境因子预测模型构建方法与预测方法
CN116432735A (zh) 一种数据处理方法、装置及边缘计算设备
CN116778254A (zh) 一种图像分类模型生成方法、装置、设备及存储介质
O'Gorman Constructing narrower confidence intervals by inverting adaptive tests

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant