CN113361598B - 基于分布式学习的模型训练方法、服务器及分布式系统 - Google Patents

基于分布式学习的模型训练方法、服务器及分布式系统 Download PDF

Info

Publication number
CN113361598B
CN113361598B CN202110624386.3A CN202110624386A CN113361598B CN 113361598 B CN113361598 B CN 113361598B CN 202110624386 A CN202110624386 A CN 202110624386A CN 113361598 B CN113361598 B CN 113361598B
Authority
CN
China
Prior art keywords
client
task
training
round
current training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110624386.3A
Other languages
English (en)
Other versions
CN113361598A (zh
Inventor
刘铎
李丽
段莫名
张宇
陈咸彰
任骜
谭玉娟
汪成亮
梁靓
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Chongqing University
Original Assignee
Chongqing University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Chongqing University filed Critical Chongqing University
Priority to CN202110624386.3A priority Critical patent/CN113361598B/zh
Publication of CN113361598A publication Critical patent/CN113361598A/zh
Application granted granted Critical
Publication of CN113361598B publication Critical patent/CN113361598B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了基于分布式学习的模型训练方法,其根据客户端的历史训练任务的完成情况,得到客户端的当前训练轮任务预测量,以使得客户端根据所述当前训练轮任务预测量和服务器下发的全局模型进行本地训练,得到客户端的当前训练轮的本地模型和当前训练轮实际任务量,进而将每一个客户端的当前训练轮的本地模型聚合成新的全局模型,并根据所述当前训练实际任务量对客户端的历史任务完成情况进行更新,其通过尽可能预测逼近客户端实际训练能力的任务量,进而适应性地对客户端的训练任务作出调整,使得客户端尽可能完成多的训练任务而不掉队,提升了全局模型的精度。相应地,本发明还提供了一种服务器和分布式学习系统。

Description

基于分布式学习的模型训练方法、服务器及分布式系统
技术领域
本发明涉及机器学习技术领域,尤其涉及一种基于分布式学习的模型训练方法、服务器及分布式系统。
背景技术
随着互联网技术的飞速发展,人类进入了一个海量数据的信息化时代。在这样的时代背景下,每个人都是数据的产生者和拥有者,各种数据呈爆发式增长。这些数据蕴含着丰富的信息,促生了数据挖掘、云计算等学科并对机器学习提出了新的挑战。然而,要从分散在各个设备的数据中获得有用的信息并非易事,因为其涉及到隐私、技术和道德等多个方面。其中,隐私问题尤为重要。传统的分布式机器学习尽管解决了分散训练的问题,但其往往需要将分散数据集中到服务器训练,这种方式对保护隐私尤为不利。联邦学习(FederatedLearning,FL)是一种新型的分布式机器学习,主要用于解决传统机器学习隐私泄露的问题。在联邦学习中,成千上万的移动边缘设备(如智能手机,个人电脑,平板等)使用本地数据在本地训练模型而无需将用户数据上传到数据中心训练模型从而避免了隐私泄露。由于其在隐私保护方面显现出了巨大作用,联邦学习已经被用于个人推荐应用和医疗训练。但是在实际的联邦学习场景中,由于网络带宽资源是有限的,并非所有的客户端都能参与训练。因此,每轮训练只选择部分客户端参与。经典的联邦学习算法FedAvg训练过程如下:
①服务器挑选K个客户端(K=C*N,C为挑选比例,N为客户端总数)。一般的算法中服务器都是随机选择客户端,有的算法会对客户端进行筛选而非随机选择;
②服务器广播全局模型和任务量(各个客户端拥有相同的任务量);
③客户端训练本地模型。客户端使用本地数据和服务器发送的全局模型进行本地模型训练,在这个过程中服务器为每个客户端规定了相同的工作量,即训练的epoch数目是相同的(一般客户端将全部数据训练一遍称为完成了一个epoch);
④客户端上传本地模型。客户端在完成本地训练之后将训练完成的模型上传至服务器,该过程可以采用同态加密、差分隐私等方式达到更安全的数据保护。通过上传训练模型取代上传隐私数据到服务器降低了隐私泄露的风险;
⑤服务器聚合客户端模型。服务器使用加权平均或者其他方式将获得的客户端模型聚合为一个全局模型,加权平均聚合公式如下:
Figure BDA0003100469220000021
其中
Figure BDA0003100469220000022
为客户端k在第t轮上传的本地模型权重,n为第t轮选中的K个客户端的样本总数,nk为第k个客户端的样本数。该公式的意思即服务器使用加权平均的方式(权重为客户端k的样本占本轮选中客户端样本总数的权重)将K个客户端上传的本地模型进行聚合,最后得到第t轮的全局模型。该全局模型将作为第t+1轮的初始模型广播给被选中的客户端。
步骤①-⑤即为一个完整的训练轮数,此后重复训练以上训练过程直到达到目标精度或完成目标训练轮数。
本发明人在实施上述过程中发现,真实的联邦学习计算设备有别于机房中的设备。联邦设备的网络状态、计算力、资源、电池等系统状态有限且异构,具体表现为不同的设备完成训练任务的能力不同。上述的分布式机器学习方法在设计过程中未考虑设备系统异构的情况,在步骤⑤中分配给不同设备相同的计算任务量。该任务量未经过考量而与设备实际能完成的任务量不匹配,造成了客户端掉队的现象(即客户端无法完成分配的任务由于资源消耗殆尽中途退出训练或分配给客户端的任务过载以至于客户端无法在可接受的时间内完成训练)。掉队客户端只完成了分配任务量的一部分,并且该结果无法上传到服务器。大量的客户端掉队减慢了训练的收敛速度且降低了训练的精度,严重影响了模型性能。
发明内容
本发明提供一种基于分布式学习的模型训练方法,其能有效解决现有的分布式学习方法存在的客户端掉队现象,进而严重影响模型的精度。
本发明提供的基于分布式学习的模型训练方法,其应用于服务器,包括:
在上一训练轮结束且当前训练轮开始前,获取每一个客户端上传的历史任务完成情况,并根据每一个所述客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量;
在当前训练轮中,向每一个所述客户端下发全局模型和所述当前训练轮任务预测量,以使得每一个所述客户端执行本地训练的操作;
接收每一个所述客户端返回的本地训练结果,其中,所述本地训练结果是由所述客户端在当前训练轮中根据所述全局模型、本地数据和所述当前训练轮任务预测量进行本地训练得到的当前训练轮的本地模型和当前训练轮实际任务量。
将每一个所述客户端的当前训练轮的本地模型聚合成新的全局模型,并根据每一所述客户端的所述当前训练轮实际任务量对每一个所述客户端的历史任务完成情况进行更新。
优选的,所述获取每一个客户端上传的历史任务完成情况,并根据每一个所述客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量,具体包括:
对于每一个所述客户端,根据获取到的所述客户端在上一个训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量;或,
对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量。
优选的,所述当前训练轮任务预测量是指当前训练轮的任务量预测值时,则所述对于每一个所述客户端,根据获取到的所述客户端在上一个训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
根据以下公式对所述客户端在当前训练轮的任务量进行预测,得到所述客户端的当前训练轮的任务量预测值:
Figure BDA0003100469220000041
其中,
Figure BDA0003100469220000042
表示客户端k在第t个训练轮的预测任务量,即上一个训练轮的任务量预测值;
Figure BDA0003100469220000043
表示客户端k在第t+1个训练轮的预测任务量,即当前训练轮的任务量预测值;u是控制增量的超参数。
优选的,所述当前训练轮任务预测量包括当前训练轮的预测任务量下限值和当前训练轮的预测任务量上限值,则所述对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
对于每一个所述客户端,根据以下公式对所述客户端在当前训练轮的任务量的下限和任务量的上限进行预测,得到所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值;
Figure BDA0003100469220000044
其中,
Figure BDA0003100469220000045
为客户端k在第t个训练轮的任务量预测下限值,即上一个训练轮的任务量预测下限值;
Figure BDA0003100469220000046
为客户端k在第t个训练轮的任务量预测上限值,即上一个训练轮的任务量预测上限值;
Figure BDA0003100469220000047
为客户端k在第t+1个训练轮的任务量预测下限值,即当前训练轮的任务量预测下限值;
Figure BDA0003100469220000051
为客户端k在第t+1个训练轮的任务量预测上限值,即当前训练轮的任务量预测上限值;u是控制增量的超参数。
优选的,所述当前训练轮任务预测量是指当前训练轮的任务量预测值,则所述对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
对于每一个所述客户端,根据以下公式计算所述客户端在过去所有训练轮的负载阈值:
Figure BDA0003100469220000052
其中,
Figure BDA0003100469220000053
为客户端k从第1个训练轮到第t个训练轮的实际任务量的移动加权平均,也即过去所有训练轮的负载阈值,
Figure BDA0003100469220000054
为客户端k从第1轮到第t-1轮的实际任务量的移动加权平均,
Figure BDA0003100469220000055
为客户端k在第t-1个训练轮实际能完成的工作量,α是平滑指数;
根据所述客户端的过去所有训练轮的负载阈值以及所述客户端在上一个训练轮的任务完成情况,确定所述客户端在上一个训练轮的状态:当
Figure BDA0003100469220000056
所述客户端在上一个训练轮处于启动阶段;当
Figure BDA0003100469220000057
所述客户端在上一个训练轮处于增长阶段;当所述客户端在上一个训练轮不能完成上一个训练轮的任务量预测值时,所述客户端在上一个训练轮掉队;
根据以下公式对所述客户端在当前训练轮的任务量进行预测,得到所述客户端的当前训练轮的任务量预测值:
Figure BDA0003100469220000058
其中,
Figure BDA0003100469220000059
表示客户端k在第t个训练轮的预测任务量,即上一个训练轮的任务量预测值;
Figure BDA00031004692200000510
表示客户端k在第t+1个训练轮的预测任务量,即当前训练轮的任务量预测值;γ1和γ2分别为启动阶段和增长阶段的增量,且γ1>γ2
优选的,所述当前训练轮任务预测量包括当前训练轮的预测任务量下限值和当前训练轮的预测任务量上限值,则所述对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
Figure BDA0003100469220000061
其中,
Figure BDA0003100469220000062
为客户端k从第1个训练轮到第t个训练轮的实际任务量的移动加权平均,也即过去所有训练轮的负载阈值,
Figure BDA0003100469220000063
为客户端k从第1轮到第t-1轮的实际任务量的移动加权平均,
Figure BDA0003100469220000064
为客户端k在第t-1个训练轮实际能完成的工作量,α是平滑指数;
根据所述客户端的过去所有训练轮的负载阈值以及所述客户端在上一个训练轮的任务完成情况,确定所述客户端在上一个训练轮基于上一个训练轮的预测任务量下限值训练的状态和基于上一个训练轮的预测任务量上限值训练的状态:当
Figure BDA0003100469220000065
所述客户端在上一个训练轮基于上一个训练轮的预测任务量下限值训练处于启动阶段;当
Figure BDA0003100469220000066
所述客户端基于上一个训练轮的预测任务量下限值训练处于增长阶段;当
Figure BDA0003100469220000067
所述客户端基于上一个训练轮的预测任务量上限值训练处于启动阶段;当
Figure BDA0003100469220000068
所述客户端基于上一个训练轮的预测任务量上限值训练处于增长阶段;当所述客户端在上一个训练轮不能完成上一个训练轮的预测任务量下限值时,所述客户端掉队;
对于每一个所述客户端,根据以下公式对所述客户端在当前训练轮的任务量的下限和任务量的上限进行预测,得到所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值;
(1)当所述客户端在上一个训练轮能够完成任务
Figure BDA0003100469220000069
时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure BDA0003100469220000071
(2)当客户端在上一个训练轮只能完成
Figure BDA0003100469220000072
而无法完成
Figure BDA0003100469220000073
时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure BDA0003100469220000074
(3)当客户端在上一个训练轮掉队时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure BDA0003100469220000075
其中,
Figure BDA0003100469220000076
为客户端k在第t个训练轮的任务量预测下限值,即上一个训练轮的任务量预测下限值;
Figure BDA0003100469220000077
为客户端k在第t个训练轮的任务量预测上限值,即上一个训练轮的任务量预测上限值;
Figure BDA0003100469220000078
为客户端k在第t+1个训练轮的任务量预测下限值,即当前训练轮的任务量预测下限值;
Figure BDA0003100469220000079
为客户端k在第t+1个训练轮的任务量预测上限值,即当前训练轮的任务量预测上限值;γ1和γ2分别为启动阶段和增长阶段的增量,且γ1>γ2
优选的,所述当前训练轮任务预测量包括当前训练轮的预测任务量下限值和当前训练轮的预测任务量上限值,则所述客户端的本地训练结果具体通过以下方式获得:
所述客户端接收所述服务器下发的全局模型、所述当前训练轮的任务量预测下限值和所述当前训练轮的任务量预测上限值;
所述客户端采用本地数据对所述全局模型进行训练,当在训练过程中检测到所述本地数据完成了所述当前训练轮的任务量预测下限值时,生成所述客户端的本地模型,并将所述客户端的本地模型发送给所述服务器;
响应于所述服务器发送的增大任务量对所述本地数据进行训练的命令,继续训练所述本地数据,并在训练过程中检测到所述本地数据完成了所述当前训练轮的任务量预测上限值时,更新所述客户端的当前训练轮的本地模型和当前训练轮实际任务量,并将所述当前训练轮的本地模型和当前训练轮实际任务量作为所述客户端的本地训练结果。
第二方面,本发明提供了一种基于分布式学习的模型训练方法,其应用于客户端,包括:
将所述客户端的历史任务完成情况发送给服务器,以使得所述服务器根据所述客户端上传的历史任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量;
接收所述服务器发送的全局模型和所述当前训练轮任务预测量,并根据所述全局模型、本地数据和所述当前训练轮任务预测量进行本地训练,得到所述客户端的当前训练轮的本地模型和当前训练轮实际任务量;
将所述当前训练轮的本地模型和所述当前训练轮实际任务量发送给所述服务器,以使得所述服务器将每一个所述客户端的当前训练轮的本地模型聚合成一个新的全局模型,并根据每一所述客户端的所述当前训练轮对每一个所述客户端的历史任务完成情况进行更新。
第三方面,本发明提供了一种服务器,包括处理器、存储器以及存储在所述存储器中且被配置为由所述处理器执行的计算机程序,所述处理器执行所述计算机程序时实现如第一方面提供的所述基于分布式学习的模型训练方法。
第四方面,本发明提供一种分布式学习系统,所述分布式学习系统包括若干个客户端和服务器,其中,所述服务器与若干个所述客户端通信连接;
所述服务器,用于根据获取到的每一个客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量,并将每一所述客户端的当前训练轮任务预测量和所述服务器的全局模型发送给每一所述客户端;
所述客户端,用于接收所述全局模型及所述当前训练轮任务预测量,并根据所述全局模型和所述当前训练轮任务预测量进行本地训练,得到所述客户端的当前训练轮的本地模型和当前训练轮实际任务量,并将所述客户端的本地模型和所述当前训练轮实际任务量发送给所述服务器;
所述服务器,还用于接收每一个所述客户端的当前训练轮的本地模型和所述当前训练轮实际任务量,并将每一个所述客户端的当前训练轮的本地模型聚合成新的全局模型,并根据每一个所述客户端的当前训练轮实际任务量对每一个所述客户端的历史任务完成情况进行更新。
与现有技术相比,本发明的有益效果在于:本发明提供了一种基于分布式学习的模型训练方法,其根据客户端的历史训练任务的完成情况对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,以使得所述客户端根据所述当前训练轮任务预测量和服务器下发的全局模型进行本地训练,得到所述客户端的当前训练轮的本地模型和当前训练轮实际任务量,进而将每一个所述客户端的当前训练轮的本地模型聚合成新的全局模型,并根据所述当前训练实际任务量对所述客户端的历史任务完全情况进行更新,其通过尽可能预测逼近客户端实际训练能力的任务量,进而适应性地对客户端的训练任务作出调整,使得客户端尽可能完成多的训练任务而不掉队,提升了全局模型的精度。相应地,本发明还提供了一种服务器和分布式学习系统。
附图说明
图1是本发明实施例一提供的基于分布式学习的模型训练方法的流程示意图;
图2是采用本发明实施例一提供的客户端的任务量预测算法的客户端的预测任务量的变化过程示意图;
图3是采用本发明实施例二提供的客户端的任务量预测算法得到的客户端的预测任务量的变化过程示意图;
图4是本发明实施例三提供的客户端的任务量预测算法的流程图;
图5是本发明实施例四提供的客户端的任务量预测算法的流程图;
图6是本发明实施例七提供的分布式学习系统的框架图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
实施例一
参见图1,其是本发明实施例一提供的基于分布式学习的模型训练方法的流程示意图。
本发明实施例一提供的基于分布式学习的模型训练方法,其应用于服务器,包括步骤S11到步骤S14:
步骤S11,在上一训练轮结束且当前训练轮开始前,获取每一个客户端上传的历史任务完成情况,并根据每一个所述客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量;
步骤S12,在当前训练轮中,向每一个所述客户端下发全局模型和每一所述客户端对应的所述当前训练轮任务预测量,以使得每一个所述客户端执行本地训练的操作;
步骤S13,接收每一个所述客户端返回的本地训练结果,其中,所述本地训练结果是由所述客户端在当前训练轮中根据所述全局模型、本地数据和所述当前训练轮任务预测量进行本地训练得到的当前训练轮的本地模型和当前训练轮实际任务量;
步骤S14,将每一个所述客户端的当前训练轮的本地模型聚合成新的全局模型,并根据每一个所述客户端的所述当前训练轮实际任务量对每一个所述客户端的历史任务完成情况进行更新。
在具体实施时,在第一轮训练中,所述客户端执行随机的一个任务量,此后每轮训练中均根据客户端历史训练轮的任务情况,对客户端在每一轮训练中的任务量进行预测,并根据每一轮任务预测量进行本地训练,以最大限度地利用客户端的资源,能够避免客户端掉队,从而减少客户端的掉队率并提升了分布式学习全局模型的精度。
具体的,所述步骤S11中的“获取每一个客户端上传的历史任务完成情况,并根据每一个所述客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量”,具体为:
对于每一个所述客户端,根据获取到的所述客户端在上一个训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量。
进一步的,所述当前训练轮任务预测量是指当前训练轮的任务量预测值,则在一种可选的实施方式中,所述对于每一个所述客户端,根据获取到的所述客户端在上一个训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
根据以下公式对所述客户端在当前训练轮的任务量进行预测,得到所述客户端的当前训练轮的任务量预测值:
Figure BDA0003100469220000111
其中,
Figure BDA0003100469220000112
表示客户端k在第t个训练轮的预测任务量,即上一个训练轮的任务量预测值;
Figure BDA0003100469220000121
表示客户端k在第t+1个训练轮的预测任务量,即当前训练轮的任务量预测值;u是控制增量的超参数。优选的,u设置为10。
参见图2,图2示出了采用本发明实施例一提供的客户端的任务量预算法得到的客户端的预测任务量的变化过程示意图。在本发明实施例中,如果客户端掉队,则当前训练轮的任务量预测值将变为上一个训练轮的任务量预测值的一半;如果客户端在上一个训练轮完成了上一个训练轮的任务量预测值,则当前训练轮的预测任务量在上一个训练轮的任务量预测值的基础上以增量
Figure BDA0003100469220000122
增加,不难看出,增量与上一个训练轮的任务量预测值成反比,即任务量越大增长速度越慢,这种方式以一种更谨慎的方式调整客户端的任务量,更大程度地避免了盲目分配客户端任务致使客户端掉队。并且客户端掉队以后,新一轮的任务量将会变为掉队轮任务量的一半,这种操作方便客户端的任务量快速回复到一个安全水平,同时也避免了客户端连续掉队的现象。
实施例二
本实施例与实施例一不同的是,本实施例在对客户端的当前训练轮任务预测量进行预测时,是根据客户端在过去所有训练轮的任务量完成情况对客户端的当前训练轮任务预测量进行预测,即,在图1提供的步骤S11步骤S14的方案的基础上,可替代性的,所述步骤S11中“获取每一个客户端上传的历史任务完成情况,并根据每一个所述客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量”,具体为:
对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量。
进一步的,所述对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体为:
对于每一个所述客户端,根据以下公式计算所述客户端在过去所有训练轮的负载阈值:
Figure BDA0003100469220000131
其中,
Figure BDA0003100469220000132
为客户端k从第1个训练轮到第t个训练轮的实际任务量的移动加权平均,也即过去所有训练轮的负载阈值,
Figure BDA0003100469220000133
为客户端k从第1轮到第t-1轮的实际任务量的移动加权平均,
Figure BDA0003100469220000134
为客户端k在第t-1个训练轮实际能完成的工作量,α是平滑指数;
根据所述客户端的过去所有训练轮的负载阈值以及所述客户端在上一个训练轮的任务完成情况,确定所述客户端在上一个训练轮的状态:当
Figure BDA0003100469220000138
所述客户端在上一个训练轮处于启动阶段;当
Figure BDA0003100469220000139
所述客户端在上一个训练轮处于增长阶段;当所述客户端在上一个训练轮不能完成上一个训练轮的任务量预测值时,所述客户端在上一个训练轮掉队;
根据以下公式对所述客户端在当前训练轮的任务量进行预测,得到所述客户端的当前训练轮的任务量预测值:
Figure BDA0003100469220000135
其中,
Figure BDA0003100469220000136
表示客户端k在第t个训练轮的预测任务量,即上一个训练轮的任务量预测值;
Figure BDA0003100469220000137
表示客户端k在第t+1个训练轮的预测任务量,即当前训练轮的任务量预测值;γ1和γ2分别为启动阶段和增长阶段的增量,且γ1>γ2
参见图3,图3是采用本发明实施例二提供的客户端的任务量预测算法得到的客户端的预测任务量的变化过程示意图。在本发明实施例中,充分利用了客户端在过去所有训练轮的历史任务量完成情况来预测当前训练轮的任务量预测值,并且在此过程中旧的训练轮数的参考任务量所占比重动态递减,阈值的重心始终在最近几个训练轮,这种方式在充分利用历史训练信息的同时也避免了过时训练信息的滥用。
实施例三
本实施例与实施例一不同的是,在对客户端每轮训练的任务量进行预测时,包括对客户端每轮训练的任务量的下限和上限的预测,即在本实施例中,所述当前训练轮任务预测量包括当前训练轮的预测任务量下限值和当前训练轮的预测任务量上限值。则,在实施例一提供的技术方案的基础上,作为一种替代的实施方式,所述对于每一个所述客户端,根据获取到的所述客户端在上一训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
对于每一个所述客户端,根据以下公式对所述客户端在当前训练轮的任务量的下限和任务量的上限进行预测,得到所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值;
Figure BDA0003100469220000141
其中,
Figure BDA0003100469220000142
为客户端k在第t个训练轮的任务量预测下限值,即上一个训练轮的任务量预测下限值;
Figure BDA0003100469220000143
为客户端k在第t个训练轮的任务量预测上限值,即上一个训练轮的任务量预测上限值;
Figure BDA0003100469220000144
为客户端k在第t+1个训练轮的任务量预测下限值,即当前训练轮的任务量预测下限值;
Figure BDA0003100469220000145
为客户端k在第t+l个训练轮的任务量预测上限值,即当前训练轮的任务量预测上限值;u是控制增量的超参数。
参见图4,图4是本发明实施例三提供的客户端的任务量预测算法的流程图。可以看到,在本发明实施例中利用最近一轮客户端训练任务的完成情况对当前训练轮的任务量的下限和上限进行预测,以使得客户端真实能完成的任务量落在预最小值(任务量预测下限值)和最大值(任务量预测上限值)之间,进而使得所述客户端即便没能完成任务量预测上限值也能完成任务量预测下限值而不至于掉队。
进一步的,在本实施例中,所述客户端的本地训练结果具体通过以下方式获得:
所述客户端接收所述服务器下发的全局模型、所述当前训练轮的任务量预测下限值和所述当前训练轮的任务量预测上限值;
所述客户端采用本地数据对所述全局模型进行训练,当在训练过程中检测到所述本地数据完成了所述当前训练轮的任务量预测下限值时,生成所述客户端的本地模型,并将所述客户端的本地模型发送给所述服务器;
响应于所述服务器发送的增大任务量对所述本地数据进行训练的命令,继续训练所述本地数据,并在训练过程中检测到所述本地数据完成了所述当前训练轮的任务量预测上限值时,更新所述客户端的当前训练轮的本地模型和当前训练轮实际任务量,并将所述当前训练轮的本地模型和当前训练轮实际任务量作为所述客户端的本地训练结果。
实施例四
参见图5,图5示出了本发明实施例四提供的客户端的任务量预测算法的流程图。本实施例与实施例一不同的是,本实施例在每轮训练中是对客户端每轮训练的任务量预测下限值和任务量预测上限值进行预测,且其是根据所述客户端在过去所有训练轮的任务完成情况进行预测。
即,在图1提供的步骤S11步骤S14的方案的基础上,作为一种替代的实施方式,所述步骤S11中“获取每一个客户端上传的历史任务完成情况,并根据每一个所述客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量”,具体为:
对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量。
进一步的,当所述当前训练轮任务预测量包括当前训练轮的预测任务量下限值和当前训练轮的预测任务量上限值时,所述对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体为:
对于每一个所述客户端,根据以下公式计算所述客户端在过去所有训练轮的负载阈值:
Figure BDA0003100469220000161
其中,
Figure BDA0003100469220000162
为客户端k从第1个训练轮到第t个训练轮的实际任务量的移动加权平均,也即过去所有训练轮的负载阈值,
Figure BDA0003100469220000163
为客户端k从第1轮到第t-1轮的实际任务量的移动加权平均,
Figure BDA0003100469220000164
为客户端k在第t-1个训练轮实际能完成的工作量,α是平滑指数;
根据所述客户端的过去所有训练轮的负载阈值以及所述客户端在上一个训练轮的任务完成情况,确定所述客户端在上一个训练轮基于上一个训练轮的预测任务量下限值训练的状态和基于上一个训练轮的预测任务量上限值训练的状态:当
Figure BDA0003100469220000165
所述客户端在上一个训练轮基于上一个训练轮的预测任务量下限值训练处于启动阶段;当
Figure BDA0003100469220000166
所述客户端基于上一个训练轮的预测任务量下限值训练处于增长阶段;当
Figure BDA0003100469220000167
所述客户端基于上一个训练轮的预测任务量上限值训练处于启动阶段;当
Figure BDA0003100469220000168
所述客户端基于上一个训练轮的预测任务量上限值训练处于增长阶段;当所述客户端在上一个训练轮不能完成上一个训练轮的预测任务量下限值时,所述客户端掉队;
对于每一个所述客户端,根据以下公式对所述客户端在当前训练轮的任务量的下限和任务量的上限进行预测,得到所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值;
(1)当所述客户端在上一个训练轮能够完成任务
Figure BDA0003100469220000169
时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure BDA0003100469220000171
(2)当客户端在上一个训练轮只能完成
Figure BDA0003100469220000172
而无法完成
Figure BDA0003100469220000173
时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure BDA0003100469220000174
(3)当客户端在上一个训练轮掉队时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure BDA0003100469220000175
其中,
Figure BDA0003100469220000176
为客户端k在第t个训练轮的任务量预测下限值,即上一个训练轮的任务量预测下限值;
Figure BDA0003100469220000177
为客户端k在第t个训练轮的任务量预测上限值,即上一个训练轮的任务量预测上限值;
Figure BDA0003100469220000178
为客户端k在第t+1个训练轮的任务量预测下限值,即当前训练轮的任务量预测下限值;
Figure BDA0003100469220000179
为客户端k在第t+1个训练轮的任务量预测上限值,即当前训练轮的任务量预测上限值;γ1和γ2分别为启动阶段和增长阶段的增量,且γ1>γ2
实施例五
本发明实施例提供的基于分布式学习的模型训练方法,其应用于客户端,包括步骤S21到步骤S23:
步骤S21,将所述客户端的历史任务完成情况发送给服务器,以使得所述服务器根据所述客户端上传的历史任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量;
步骤S22,接收所述服务器发送的全局模型和所述当前训练轮任务预测量,并根据所述全局模型、本地数据和所述当前训练轮任务预测量进行本地训练,得到所述客户端的当前训练轮的本地模型和当前训练轮实际任务量;
步骤S23,将所述当前训练轮的本地模型和所述当前训练轮实际任务量发送给所述服务器,以使得所述服务器将每一个所述客户端的当前训练轮的本地模型聚合成一个新的全局模型,并根据每一所述客户端的所述当前训练轮对每一个所述客户端的历史任务完成情况进行更新。
实施例六
本发明实施例提供一种服务器,包括处理器、存储器以及存储在所述存储器中且被配置为由所述处理器执行的计算机程序,所述处理器执行所述计算机程序时实现如上述的基于分布式学习的模型训练方法,例如,如图1中的步骤S11到步骤S14。
实施例七
本发明实施例提供一种分布式学习系统,所述分布式学习系统包括若干个客户端和服务器,其中,所述服务器与若干个所述客户端通信连接;
所述服务器,用于根据获取到的每一个客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量,并将每一所述客户端的当前训练轮任务预测量和所述服务器的全局模型发送给每一所述客户端;
所述客户端,用于接收所述全局模型及所述当前训练轮任务预测量,并根据所述全局模型和所述当前训练轮任务预测量进行本地训练,得到所述客户端的当前训练轮的本地模型和当前训练轮实际任务量,并将所述客户端的本地模型和所述当前训练轮实际任务量发送给所述服务器;
所述服务器,还用于接收每一个所述客户端的当前训练轮的本地模型和所述当前训练轮实际任务量,并将每一个所述客户端的当前训练轮的本地模型聚合成新的全局模型,并根据每一个所述客户端的当前训练轮实际任务量对每一个所述客户端的历史任务完成情况进行更新。
参见图6,图6示出了本发明实施例七提供的分布式学习系统的框架图,在本发明实施例中,所述客户端通过执行如图1中的步骤S11,对客户端可完成的任务量进行预测,使得不同客户端执行不同的任务量,同一客户端根据其不同状态在每轮也执行不同的任务量,预测依据为客户端训练的历史任务完成情况。其中,服务器一般指的是具有网络通讯能力、具有处理器集群的云计算设备,通常服务器的计算力较强存储容量较大。客户端一般指的是具有网络通讯能力、拥有至少一个处理器的移动设备,例如智能手机、平板电脑、PC等。客户端信息收集进程和任务量预测进程均部署在服务器上。初始化时,客户端执行随机的一个任务量,此后执行的任务量均为预测的任务量。通常在每轮训练开始前,服务器会和客户端进行通信以了解客户端的网络状态等其他信息,客户端训练的历史信息并在此时和通信结果一起返回给客户端。然后客户端根据客户端训练的历史信息,可选择前面实施例一到实施例四任一种预测方式对客户端的任务量进行预测,然后随全局模型一起下发至客户端。客户端在本地采用并行更新的方式训练模型。单个客户端完成模型训练后,将训练后的本地模型参数上传至服务器端,服务器端同意进行模型整合,是一种同步的更新方式。通常,模型更新的计算方法为小批量的随机梯度下降法,公式如下:
Figure BDA0003100469220000191
其中,
Figure BDA0003100469220000192
表示客户端k在第t轮训练的神经网络模型参数,b表示当前轮训练选取的批量数据,例如图片分类任务中一般表示批量的图片和图片对应的标签组成的数据对,η为训练神经网络使用的学习率,依据具体任务设定,通常设定的值为0.1、0.01。l为损失函数,可选的可以设置为平方误差函数或者负对数似然函数,
Figure BDA0003100469220000201
为微分符号,表示损失函数l对权重
Figure BDA0003100469220000202
求导数,
Figure BDA0003100469220000203
为模型更新后的模型参数。每次训练得到新的模型参数称为完成了一次模型更新,训练轮数t增加1,神经网络模型的训练过程通常由多轮模型更新构成。每轮训练完成后服务器手机客户端训练后的模型参数并聚合,然后得到新一轮的初始模型并下发给客户端,迭代此过程最后获得训练完成的全局模型。
与现有技术相比,本发明的有益效果如下:
(1)本发明所提出的基于分布式学习的模型训练方法能够自适应地预测分布学习式中分配给客户端的任务量,并最大限度地利用客户端资源,从而避免客户端掉队,最终减少客户端掉队率并提升了分布式学习的全局模型的精度。实验表明,相较于经典算法FedAvg,本发明在系统异构的分布式系统中平均提升了26.7%的全局模型测试精度并平均减少了90.3%的掉队设备。
(2)本发明针对的是系统异构的分布式机器学习场景,而非理想的实验场景,因此可以更方便地应用于实际的机器学习环境,拥有很强的应用性和可实现性。
以上所述是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也视为本发明的保护范围。

Claims (10)

1.一种基于分布式学习的模型训练方法,其应用于服务器,其特征在于,包括:
在上一训练轮结束且当前训练轮开始前,获取每一个客户端上传的历史任务完成情况,并根据每一个所述客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量;
在当前训练轮中,向每一个所述客户端下发全局模型和每一所述客户端对应的所述当前训练轮任务预测量,以使得每一个所述客户端执行本地训练的操作;
接收每一个所述客户端返回的本地训练结果,其中,所述本地训练结果是由所述客户端在当前训练轮中根据所述全局模型、本地数据和所述当前训练轮任务预测量进行本地训练得到的当前训练轮的本地模型和当前训练轮实际任务量;
将每一个所述客户端的当前训练轮的本地模型聚合成新的全局模型,并根据每一个所述客户端的所述当前训练轮实际任务量对每一个所述客户端的历史任务完成情况进行更新。
2.如权利要求1所述的基于分布式学习的模型训练方法,其特征在于,所述获取每一个客户端上传的历史任务完成情况,并根据每一个所述客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量,具体包括:
对于每一个所述客户端,根据获取到的所述客户端在上一个训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量;或,
对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量。
3.如权利要求2所述的基于分布式学习的模型训练方法,其特征在于,所述当前训练轮任务预测量是指当前训练轮的任务量预测值,则所述对于每一个所述客户端,根据获取到的所述客户端在上一个训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
根据以下公式对所述客户端在当前训练轮的任务量进行预测,得到所述客户端的当前训练轮的任务量预测值:
Figure FDA0003794572860000021
其中,
Figure FDA0003794572860000022
表示客户端k在第t个训练轮的预测任务量,即上一个训练轮的任务量预测值;
Figure FDA0003794572860000023
表示客户端k在第t+1个训练轮的预测任务量,即当前训练轮的任务量预测值;u是控制增量的超参数。
4.如权利要求2所述的基于分布式学习的模型训练方法,其特征在于,所述当前训练轮任务预测量包括当前训练轮的预测任务量下限值和当前训练轮的预测任务量上限值,则所述对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
对于每一个所述客户端,根据以下公式对所述客户端在当前训练轮的任务量的下限和任务量的上限进行预测,得到所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值:
Figure FDA0003794572860000031
其中,
Figure FDA0003794572860000032
为客户端k在第t个训练轮的任务量预测下限值,即上一个训练轮的任务量预测下限值;
Figure FDA0003794572860000033
为客户端k在第t个训练轮的任务量预测上限值,即上一个训练轮的任务量预测上限值;
Figure FDA0003794572860000034
为客户端k在第t+1个训练轮的任务量预测下限值,即当前训练轮的任务量预测下限值;
Figure FDA0003794572860000035
为客户端k在第t+1个训练轮的任务量预测上限值,即当前训练轮的任务量预测上限值;u是控制增量的超参数。
5.如权利要求2所述的基于分布式学习的模型训练方法,其特征在于,所述当前训练轮任务预测量是指当前训练轮的任务量预测值,则所述对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
对于每一个所述客户端,根据以下公式计算所述客户端在过去所有训练轮的负载阈值:
Figure FDA0003794572860000036
其中,
Figure FDA0003794572860000037
为客户端k从第1个训练轮到第t个训练轮的实际任务量的移动加权平均,也即过去所有训练轮的负载阈值,
Figure FDA0003794572860000038
为客户端k从第1轮到第t-1轮的实际任务量的移动加权平均,
Figure FDA0003794572860000039
为客户端k在第t-1个训练轮实际能完成的工作量,α是平滑指数;
根据所述客户端的过去所有训练轮的负载阈值以及所述客户端在上一个训练轮的任务完成情况,确定所述客户端在上一个训练轮的状态:当
Figure FDA00037945728600000310
所述客户端在上一个训练轮处于启动阶段;当
Figure FDA00037945728600000311
所述客户端在上一个训练轮处于增长阶段;当所述客户端在上一个训练轮不能完成上一个训练轮的任务量预测值时,所述客户端在上一个训练轮掉队;
根据以下公式对所述客户端在当前训练轮的任务量进行预测,得到所述客户端的当前训练轮的任务量预测值:
Figure FDA0003794572860000041
其中,
Figure FDA0003794572860000042
表示客户端k在第t个训练轮的预测任务量,即上一个训练轮的任务量预测值;
Figure FDA0003794572860000043
表示客户端k在第t+1个训练轮的预测任务量,即当前训练轮的任务量预测值;γ1和γ2分别为启动阶段和增长阶段的增量,且γ1>γ2
6.如权利要求2所述的基于分布式学习的模型训练方法,其特征在于,所述当前训练轮任务预测量包括当前训练轮的预测任务量下限值和当前训练轮的预测任务量上限值,则所述对于每一个所述客户端,根据获取到的所述客户端在过去所有训练轮的任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量,具体包括:
对于每一个所述客户端,根据以下公式计算所述客户端在过去所有训练轮的负载阈值:
Figure FDA0003794572860000044
其中,
Figure FDA0003794572860000045
为客户端k从第1个训练轮到第t个训练轮的实际任务量的移动加权平均,也即过去所有训练轮的负载阈值,
Figure FDA0003794572860000046
为客户端k从第1轮到第t-1轮的实际任务量的移动加权平均,
Figure FDA0003794572860000047
为客户端k在第t-1个训练轮实际能完成的工作量,α是平滑指数;
根据所述客户端的过去所有训练轮的负载阈值以及所述客户端在上一个训练轮的任务完成情况,确定所述客户端在上一个训练轮基于上一个训练轮的预测任务量下限值训练的状态和基于上一个训练轮的预测任务量上限值训练的状态:当
Figure FDA0003794572860000051
所述客户端在上一个训练轮基于上一个训练轮的预测任务量下限值训练处于启动阶段;当
Figure FDA0003794572860000052
所述客户端基于上一个训练轮的预测任务量下限值训练处于增长阶段;当
Figure FDA0003794572860000053
所述客户端基于上一个训练轮的预测任务量上限值训练处于启动阶段;当
Figure FDA0003794572860000054
所述客户端基于上一个训练轮的预测任务量上限值训练处于增长阶段;当所述客户端在上一个训练轮不能完成上一个训练轮的预测任务量下限值时,所述客户端掉队;
对于每一个所述客户端,根据以下公式对所述客户端在当前训练轮的任务量的下限和任务量的上限进行预测,得到所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值;
(1)当所述客户端在上一个训练轮能够完成任务
Figure FDA0003794572860000055
时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure FDA0003794572860000056
(2)当客户端在上一个训练轮只能完成
Figure FDA0003794572860000057
而无法完成
Figure FDA0003794572860000058
时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure FDA0003794572860000059
(3)当客户端在上一个训练轮掉队时,所述客户端的当前训练轮的任务量预测下限值和当前训练轮的任务量预测上限值为:
Figure FDA00037945728600000510
其中,
Figure FDA00037945728600000511
为客户端k在第t个训练轮的任务量预测下限值,即上一个训练轮的任务量预测下限值;
Figure FDA0003794572860000061
为客户端k在第t个训练轮的任务量预测上限值,即上一个训练轮的任务量预测上限值;
Figure FDA0003794572860000062
为客户端k在第t+1个训练轮的任务量预测下限值,即当前训练轮的任务量预测下限值;
Figure FDA0003794572860000063
为客户端k在第t+1个训练轮的任务量预测上限值,即当前训练轮的任务量预测上限值;γ1和γ2分别为启动阶段和增长阶段的增量,且γ1>γ2
7.如权利要求1所述的基于分布式学习的模型训练方法,其特征在于,所述当前训练轮任务预测量包括当前训练轮的预测任务量下限值和当前训练轮的预测任务量上限值,则所述客户端的本地训练结果具体通过以下方式获得:
所述客户端接收所述服务器下发的全局模型、所述当前训练轮的任务量预测下限值和所述当前训练轮的任务量预测上限值;
所述客户端采用本地数据对所述全局模型进行训练,当在训练过程中检测到所述本地数据完成了所述当前训练轮的任务量预测下限值时,生成所述客户端的本地模型,并将所述客户端的本地模型发送给所述服务器;
响应于所述服务器发送的增大任务量对所述本地数据进行训练的命令,继续训练所述本地数据,并在训练过程中检测到所述本地数据完成了所述当前训练轮的任务量预测上限值时,更新所述客户端的当前训练轮的本地模型和当前训练轮实际任务量,并将所述当前训练轮的本地模型和当前训练轮实际任务量作为所述客户端的本地训练结果。
8.一种基于分布式学习的模型训练方法,其应用于客户端,其特征在于,包括:
将所述客户端的历史任务完成情况发送给服务器,以使得所述服务器根据所述客户端上传的历史任务完成情况,对所述客户端的当前训练轮的任务量进行预测,得到所述客户端的当前训练轮任务预测量;
接收所述服务器发送的全局模型和所述当前训练轮任务预测量,并根据所述全局模型、本地数据和所述当前训练轮任务预测量进行本地训练,得到所述客户端的当前训练轮的本地模型和当前训练轮实际任务量;
将所述当前训练轮的本地模型和所述当前训练轮实际任务量发送给所述服务器,以使得所述服务器将每一个所述客户端的当前训练轮的本地模型聚合成一个新的全局模型,并根据每一所述客户端的所述当前训练轮实际任务量对每一个所述客户端的历史任务完成情况进行更新。
9.一种服务器,包括处理器、存储器以及存储在所述存储器中且被配置为由所述处理器执行的计算机程序,所述处理器执行所述计算机程序时实现如权利要求1至7任意一项所述的基于分布式学习的模型训练方法。
10.一种分布式学习系统,其特征在于:所述分布式学习系统包括若干个客户端和服务器,其中,所述服务器与若干个所述客户端通信连接;
所述服务器,用于根据获取到的每一个客户端上传的历史任务完成情况,对每一个所述客户端的当前训练轮的任务量进行预测,得到每一个所述客户端的当前训练轮任务预测量,并将每一所述客户端的当前训练轮任务预测量和所述服务器的全局模型发送给每一所述客户端;
所述客户端,用于接收所述全局模型及所述当前训练轮任务预测量,并根据所述全局模型、本地数据和所述当前训练轮任务预测量进行本地训练,得到所述客户端的当前训练轮的本地模型和当前训练轮实际任务量,并将所述客户端的本地模型和所述当前训练轮实际任务量发送给所述服务器;
所述服务器,还用于接收每一个所述客户端的当前训练轮的本地模型和所述当前训练轮实际任务量,并将每一个所述客户端的当前训练轮的本地模型聚合成新的全局模型,并根据每一个所述客户端的当前训练轮实际任务量对每一个所述客户端的历史任务完成情况进行更新。
CN202110624386.3A 2021-06-04 2021-06-04 基于分布式学习的模型训练方法、服务器及分布式系统 Active CN113361598B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110624386.3A CN113361598B (zh) 2021-06-04 2021-06-04 基于分布式学习的模型训练方法、服务器及分布式系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110624386.3A CN113361598B (zh) 2021-06-04 2021-06-04 基于分布式学习的模型训练方法、服务器及分布式系统

Publications (2)

Publication Number Publication Date
CN113361598A CN113361598A (zh) 2021-09-07
CN113361598B true CN113361598B (zh) 2022-10-11

Family

ID=77532152

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110624386.3A Active CN113361598B (zh) 2021-06-04 2021-06-04 基于分布式学习的模型训练方法、服务器及分布式系统

Country Status (1)

Country Link
CN (1) CN113361598B (zh)

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107871160A (zh) * 2016-09-26 2018-04-03 谷歌公司 通信高效联合学习
CN112232518A (zh) * 2020-10-15 2021-01-15 成都数融科技有限公司 一种轻量级分布式联邦学习系统及方法
CN112351503A (zh) * 2020-11-05 2021-02-09 大连理工大学 基于任务预测的多无人机辅助边缘计算资源分配方法

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107025205B (zh) * 2016-01-30 2021-06-22 华为技术有限公司 一种分布式系统中的训练模型的方法及设备
CN108009642B (zh) * 2016-10-31 2021-12-14 腾讯科技(深圳)有限公司 分布式机器学习方法和系统
US20210073639A1 (en) * 2018-12-04 2021-03-11 Google Llc Federated Learning with Adaptive Optimization
US11392843B2 (en) * 2019-04-01 2022-07-19 Accenture Global Solutions Limited Utilizing a machine learning model to predict a quantity of cloud resources to allocate to a customer

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107871160A (zh) * 2016-09-26 2018-04-03 谷歌公司 通信高效联合学习
CN112232518A (zh) * 2020-10-15 2021-01-15 成都数融科技有限公司 一种轻量级分布式联邦学习系统及方法
CN112351503A (zh) * 2020-11-05 2021-02-09 大连理工大学 基于任务预测的多无人机辅助边缘计算资源分配方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
联邦学习通信开销研究综述;邱鑫源 等;《计算机应用》;20210430;333-342 *

Also Published As

Publication number Publication date
CN113361598A (zh) 2021-09-07

Similar Documents

Publication Publication Date Title
CN105550323B (zh) 一种分布式数据库负载均衡预测方法和预测分析器
CN107911478B (zh) 基于化学反应优化算法的多用户计算卸载方法及装置
CN112882815B (zh) 基于深度强化学习的多用户边缘计算优化调度方法
CN115408151A (zh) 一种联邦学习训练加速方法
CN110955463B (zh) 支持边缘计算的物联网多用户计算卸载方法
CN114065863B (zh) 联邦学习的方法、装置、系统、电子设备及存储介质
CN112835715B (zh) 基于强化学习的无人机任务卸载策略的确定方法和装置
CN112261120B (zh) 一种配电物联网云边协同任务卸载方法及装置
CN110213327A (zh) 一种基于边缘计算的资源调度方法、装置及系统
CN113315716A (zh) 拥塞控制模型的训练方法和设备及拥塞控制方法和设备
CN112862112A (zh) 联邦学习方法、存储介质、终端、服务器、联邦学习系统
CN116361377B (zh) 基于工业物联网服务平台的负载预测系统、方法及介质
CN114169543A (zh) 一种基于模型陈旧性与用户参与度感知的联邦学习算法
CN115499376A (zh) 一种负载均衡方法、系统、电子设备及存储介质
CN113361598B (zh) 基于分布式学习的模型训练方法、服务器及分布式系统
CN112600869B (zh) 基于td3算法的计算卸载分配方法和装置
US9501321B1 (en) Weighted service requests throttling
CN117436485A (zh) 基于权衡时延和精度的多退出点的端-边-云协同系统及方法
CN113543160A (zh) 5g切片资源配置方法、装置、计算设备及计算机存储介质
CN112866358B (zh) 一种物联网服务重调度的方法、系统及装置
CN116339932A (zh) 资源调度方法、装置和服务器
CN113391897A (zh) 一种面向异构场景的联邦学习训练加速方法
CN117251276B (zh) 一种面向协作学习平台的灵活调度方法及装置
CN117076131B (zh) 一种任务分配方法、装置、电子设备及存储介质
CN115226130B (zh) 基于公平性感知的多无人机数据卸载方法及相关设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant