CN113361721B - 模型训练方法、装置、电子设备、存储介质及程序产品 - Google Patents

模型训练方法、装置、电子设备、存储介质及程序产品 Download PDF

Info

Publication number
CN113361721B
CN113361721B CN202110730081.0A CN202110730081A CN113361721B CN 113361721 B CN113361721 B CN 113361721B CN 202110730081 A CN202110730081 A CN 202110730081A CN 113361721 B CN113361721 B CN 113361721B
Authority
CN
China
Prior art keywords
training
target
terminal devices
model
global
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110730081.0A
Other languages
English (en)
Other versions
CN113361721A (zh
Inventor
刘吉
周晨娣
窦德景
贾俊铖
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202110730081.0A priority Critical patent/CN113361721B/zh
Publication of CN113361721A publication Critical patent/CN113361721A/zh
Application granted granted Critical
Publication of CN113361721B publication Critical patent/CN113361721B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F9/00Arrangements for program control, e.g. control units
    • G06F9/06Arrangements for program control, e.g. control units using stored programs, i.e. using an internal store of processing equipment to receive or retain programs
    • G06F9/46Multiprogramming arrangements
    • G06F9/48Program initiating; Program switching, e.g. by interrupt
    • G06F9/4806Task transfer initiation or dispatching
    • G06F9/4843Task transfer initiation or dispatching by program, e.g. task dispatcher, supervisor, operating system
    • G06F9/4881Scheduling strategies for dispatcher, e.g. round robin, multi-level priority queues
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F9/00Arrangements for program control, e.g. control units
    • G06F9/06Arrangements for program control, e.g. control units using stored programs, i.e. using an internal store of processing equipment to receive or retain programs
    • G06F9/46Multiprogramming arrangements
    • G06F9/50Allocation of resources, e.g. of the central processing unit [CPU]
    • G06F9/5005Allocation of resources, e.g. of the central processing unit [CPU] to service a request
    • G06F9/5027Allocation of resources, e.g. of the central processing unit [CPU] to service a request the resource being a machine, e.g. CPUs, Servers, Terminals

Landscapes

  • Engineering & Computer Science (AREA)
  • Software Systems (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本公开提供了一种模型训练方法、装置、电子设备、存储介质及程序产品,涉及人工智能技术领域,尤其涉及分布式计算技术领域。该方法包括:针对多个全局模型中目标全局模型的一轮训练,根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备;将目标全局模型的全局模型参数发送给至少两个目标终端设备;接收至少两个目标终端设备发送的本地模型参数,并根据至少两个目标终端设备发送的本地模型参数,更新目标全局模型参数,本地模型参数为至少两个目标终端设备各自根据本地训练样本对目标全局模型进行训练得到的。提高了多个全局模型的训练效率。

Description

模型训练方法、装置、电子设备、存储介质及程序产品
技术领域
本公开涉及人工智能技术领域中的分布式计算技术,尤其涉及一种模型训练方法、装置、电子设备、存储介质及程序产品。
背景技术
联邦学习是一种新的分布式学习机制,利用分布式的数据和计算资源进行机器学习模型的协作训练。联邦学习系统通常包括一个服务器和多个终端设备,联邦学习是由服务器将待训练的全局模型下发给各终端设备,由各终端设备利用本地的私有数据各自训练更新模型参数,并将更新的模型参数上传给服务器,最后由服务器将各终端设备更新的模型参数聚合得到新的全局模型,重复执行进行多轮上述训练,直至全局模型达到收敛。
在多任务连邦学习中,联邦学习系统中存在多个待训练的全局模型,如果同一时间各终端设备只能训练一个全局模型,这无疑会增加其他全局模型的等待时间且训练效率极低,为此可以选择令多个全局模型在多个终端设备之间并行训练。那么,如何为每个全局模型分配终端设备以提高多个全局模型的训练效率是一个亟待解决的问题。
发明内容
本公开提供了一种提高了多全局模型的训练效率的模型训练方法、装置、电子设备、存储介质及程序产品。
根据本公开的一方面,提供了一种模型训练方法,所述方法包括:
针对多个全局模型中目标全局模型的一轮训练,根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备;
将所述目标全局模型的全局模型参数发送给所述至少两个目标终端设备;
接收所述至少两个目标终端设备发送的本地模型参数,并根据所述至少两个目标终端设备发送的本地模型参数,更新所述目标全局模型参数,所述本地模型参数为所述至少两个目标终端设备各自根据本地训练样本对所述目标全局模型进行训练得到的。
根据本公开的另一方面,提供了一种模型训练装置,所述装置包括:
选择模块,用于针对多个全局模型中目标全局模型的一轮训练,根据训练各个全局模型所需要的时间,从所述多个终端设备中选择至少两个目标终端设备;
发送模块,用于将所述目标全局模型的全局模型参数发送给所述至少两个目标终端设备;
接收模块,用于接收所述至少两个目标终端设备发送的本地模型参数,所述本地模型参数为所述至少两个目标终端设备各自根据本地训练样本对所述目标全局模型进行训练得到的;
更新模块,用于根据所述至少两个目标终端设备发送的本地模型参数,更新所述目标全局模型参数。
根据本公开的再一方面,提供了一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行上述第一方面所述的方法。
根据本公开的又一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行上述第一方面所述的方法。
根据本公开的又一方面,提供了一种计算机程序产品,所述程序产品包括:计算机程序,所述计算机程序存储在可读存储介质中,电子设备的至少一个处理器可以从所述可读存储介质读取所述计算机程序,所述至少一个处理器执行所述计算机程序使得电子设备执行第一方面所述的方法。
根据本公开的技术方案,提高了多全局模型的训练效率。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是根据本公开实施例提供的联邦学习系统的示意图;
图2是根据本公开实施例提供的模型训练方法的流程示意图;
图3是根据本公开实施例提供的模型训练装置的结构示意图;
图4是用来实现本公开实施例的模型训练方法的电子设备的示意性框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
在多任务联邦学习中,联邦学习系统中有多个待训练的全局模型,每个待训练的全局模型即为一个任务,例如图像分类、语音识别、文本生成等任务。为了提高训练效率,多任务联邦学习中由多个终端设备并行训练这些任务,以使所有任务都能够尽快收敛。然而,由于终端设备的资源有限,多个任务需要共享终端设备的资源,因此如何为每个任务分配恰当的终端设备是必须解决的问题。
为此,本公开实施例的方案中,在为每一个任务分配终端设备时,考虑到当前任务对其他任务的影响,整体考虑所有任务完成训练的时间来为每一个任务分配终端设备,从而最大程度地减少所有任务的总体收敛时间,提高训练效率。
本公开提供一种模型训练方法、装置、电子设备、存储介质及程序产品,应用于人工智能技术领域的分布式计算领域,具体可以应用于移动式边缘计算、物联网云服务、数据联邦平台等场景中,以达到提高训练效率的目的。
在本公开实施例中,假设联邦学习系统由一台服务器和K个终端设备组成,如图1中所示。将K个终端设备的集合记为其中,各终端设备的索引k={1,2,...,K}。联邦学习系统共同参与M个任务(M个不同的全局模型)的训练,其中,各任务的索引m={1,2,...,M}。每个终端设备上都具有M个任务的本地训练样本,其中,第k个设备上第m个任务的本地训练样本为/>其中,/>为/>的样本个数,/>为终端设备k上第m个任务的第d个nm维的输入数据向量,/>为数据/>的标签。任务m在所有终端设备上的训练样本为/>Dm的样本个数为/>
每个终端设备上都有全部任务的本地训练样本,多任务联邦学习是通过不同任务的损失函数从相应的本地训练样本中学习各自的模型参数wm。多任务联邦学习的全局学习问题可以通过以下公式(1)表示:
其中,W={w1,w2,...,wm}是所有任务的模型参数的集合,/>是第m个任务的输入输出/>数据对在参数为wm上的模型损失。
为了解决公式(1)的问题,服务器需要根据多任务联邦学习方案连续地为不同任务选择终端设备,以在选择的终端设备迭代地更新全局模型,直到所有任务的模型收敛。在确保收敛精度的同时,如何使所有任务尽快完成收敛是本公开的方案主要关注的问题。为此,本公开实施例的方法中考虑所有任务完成训练的时间来为每一个任务分配终端设备。
第k个终端设备在接收到第m个任务的全局模型参数后,完成第r轮训练所需要的时间主要由计算时间和通信时间/>决定。对于每个任务,每一轮训练所需要的时间由训练该任务的终端设备中速度最慢的设备决定。假设终端设备与服务器的通信是并行的,第m个任务的第r轮训练所需的时间如下:
为了提高多任务的训练效率,提出以下效率优化问题:
其中,
S={S1,S2,…,SM),
其中,为任务m的第r轮训练的终端设备,sm为任务m的所有轮训练的终端设备的集合,S表示所有任务的所有轮训练的终端设备的集合,/>表示任务m的收敛曲线的参数,lm为任务m的预期损失值或达到收敛的损失值,Rm表示实现损失lm所需的训练轮数。在理想情况下,给定任务的收敛精度lm,根据FedAvg实验得到收敛到相应精度所需的轮数Rm,然后采用最小二乘法即可拟合出此公式(7),也即得到参数/> 公式(7)实现了对任务收敛精度的约束,在给定任务的收敛精度lm后,任务收敛所需的训练轮数Rm也随之确定。
上述优化问题是使得所有任务的收敛所需的时间最小化。由于M个任务并行进行训练,本地训练样本的大小和全局模型的复杂度都有所不同,因此同一设备完成不同任务的更新所需的时间也可能有所不同。为了描述局部模型更新所需时间的随机性,假设终端设备完成一轮训练所需的时间遵循位移指数分布:
其中,参数ak>0是终端设备k的计算能力的最大值,参数μk>0是终端设备k的计算能力的波动值。
待解决的公式(3)和(4)是一个组合优化问题,为了解决此问题,服务器在接收到空闲终端设备的资源信息后,可以根据接收到的设备资源信息调度各任务所需的终端设备。此外,每个任务的训练轮数不需要一致,各任务之间也不需要互相等待。一般情况下,给定全局模型的收敛精度,如公式(7)所示,收敛所需的训练轮数也大致确定。
在理想情况下,即所有终端设备的资源和状态保持不变,服务器可以根据所有终端设备的资源信息,为每个任务一次调度完成所有轮训练所需的终端设备。然而,在实际的计算环境中,终端设备的资源和状态会发生改变。例如,终端设备可能当前是空闲和可用的,但在一段时间之后,设备可能是忙碌且不可用或者一部分资源被占用。因此,一次性完成所有轮训练的终端设备调度是不现实的,为此,本公开实施例中,考虑在实际调度时,每次为待训练的任务调度当前这一轮训练所需的终端设备,并确保当前时间节点,所有任务所需的一轮训练时间最短,也即每个任务的每轮训练均由服务器根据各个任务的训练时间为其安排相应的终端设备,从而提高所有任务的训练效率。
下面,将通过具体的实施例对本公开提供的模型训练方法进行详细地说明。可以理解的是,下面这几个具体的实施例可以相互结合,对于相同或相似的概念或过程可能在某些实施例不再赘述。
图2是根据本公开实施例提供的一种模型训练方法的流程示意图。该方法的执行主体为前述的联邦学习系统中的服务器,服务器中包括多个待训练的全局模型。如图2所示,该方法包括:
S201、针对多个全局模型中目标全局模型的一轮训练,根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备。
目标全局模型多个全局模型中的任一全局模型,即本公开实施例中,针对多个全局模型中的任一全局模型的每一轮训练,均根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备,其中,多个终端设备是连邦学习系统中当前空闲的终端设备。为全局模型的一轮训练所选择的终端设备的数量可以预先确定,例如可以按照联邦学习系统中终端设备的数量K乘以相应的系数C,0<C<1,得到为全局模型分配的终端设备的数量。
对于目标全局模型当前的一轮训练,从当前空闲的终端设备中所选择的至少两个目标终端设备,需要保证当前这一轮训练各个全局模型所需要的时间最小。
S202、将目标全局模型的全局模型参数发送给至少两个目标终端设备。
服务器将目标全局模型的全局模型参数发送给选择的至少两个目标终端设备,以使该至少两个目标终端设备对目标全局模型进行训练。可以理解的是,服务器为不同的任务选择各自的终端设备,不同的全局模型的训练互不影响,实现并行训练。
S203、接收至少两个目标终端设备发送的本地模型参数,并根据至少两个目标终端设备发送的本地模型参数,更新目标全局模型参数。
本地模型参数为至少两个目标终端设备各自根据本地训练样本对目标全局模型进行训练得到的。
每个终端设备上均具有各个全局模型的本地训练样本,该至少两个目标终端设备采用各自的本地训练样本对目标全局模型进行训练,完成一轮训练即可得到各自训练获得的本地模型参数,并将各自训练获得的本地模型参数发送给服务器,以使得服务器对该至少两个目标终端设备这一轮的本地模型参数进行聚合,得到新的目标全局模型参数。
之后,服务器重复执行本实施例的步骤,为下一轮训练选择终端设备,并将新的目标全局模型参数发送给下一轮选择的终端设备,以进行下一轮训练,直至目标全局模型收敛。
本公开实施例提供的模型训练方法,在为待训练的每个任务调度当前这一轮训练所需的终端设备时,考虑每个任务对其他任务的影响,根据训练各个任务所需的时间,确保当前时间节点,所有任务所需的一轮训练时间最短,从而提高所有任务的训练效率。
以下结合具体示例对本公开实施例的方法进行说明。
针对多个全局模型中目标全局模型的一轮训练,服务器获取各终端设备当前的资源状态;根据各终端设备当前的资源状态,确定训练各个全局模型所需要的时间;根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备。
示例的,针对多个全局模型中目标全局模型的一轮训练,服务器向各终端设备发送资源请求,并根据各终端设备的响应确定各终端设备当前的资源状态,基于各终端设备当前的资源状态,确定训练各个全局模型所需要的时间,从而根据训练各个全局模型所需要的时间,从多个终端设备中为目标全局模型当前的一轮训练选择至少两个目标终端设备,使得所有全局模型的收敛时间最小。
在终端设备调度过程中,终端设备参与训练的公平性和参与训练数据分布的均衡性是影响收敛速度的关键因素。如果过度选择训练较快的设备,尽管这能加速每一轮的训练速度,但会使得全局模型的训练集中在一小部分设备上,最终导致任务的收敛精度下降。而训练的最终目标是使所有联邦学习任务尽快收敛,同时还要确保模型的准确性。因此,本申请实施例可以在尽可能确保终端设备参与的公平性的前提下来进行设备调度。
针对多个全局模型中目标全局模型的一轮训练,服务器确定多个终端设备中每个终端设备在当前一轮训练之前参与训练目标全局模型的次数是否大于预设值;若存在第一终端设备参与训练目标全局模型的次数大于预设值,则将第一终端设备从多个终端设备中剔除;并且相应的,根据训练各个全局模型所需要的时间,从剔除第一终端设备后的多个终端设备中选择至少两个目标终端设备。从而,确保终端设备参与训练的公平性,防止一些终端设备过度参与训练,并且避免在调度过程中偏向选择较快的终端设备,在确保任务精度的前提下,提升任务的收敛速度。
对于参与训练的数据的均衡性,可以将其与训练时间共同作为优化的目标。假设服务器在第r轮训练中为任务j调度的至少两个终端设备的集合为其中j={1,2,...,M}。任务j训练所需的所有本地训练样本共分为Lj类,存在大小为Lj的集合Qj,任务j在所有类别上的本地训练样本的数量的初始值Qj[l]=0,l=0,l={1,2,...,Lj}。在第r轮训练时,将第r轮之前所有参与任务j的训练的终端设备的本地训练样本按类别进行统计,并将结果放入集合Qj。即,获取各个终端设备训练目标全局模型所使用的本地训练样本的类别;确定各类别的本地训练样本的数量;根据类别和数量,确定波动值。
示例的,可以根据以下公(9)来衡量当前所有参与任务j的训练的本地训练样本在类别上的波动值,该波动值的大小表示数据在类别上的波动程度:
参与模型训练的数据越均衡,则模型收敛越快越稳定。同时,考虑任务j的当前一轮调度会对其他任务的调度产生影响,因此在为当前任务j的一轮训练调度终端设备时会考虑所有任务的一轮训练的运行情况,将本公开实施例中的效率优化问题由公式(3)和(4)优化为以下公式(10),即,在为任务j的第r轮训练调度终端设备时所解决的问题可以写为:
其中,Sr表示所有任务的第r轮训练的终端设备的集合,λ为波动值/>的权重参数,其他各参数的含义与前述相同。
上述调度问题仍然是一个组合优化问题。本公开实施例中可以采用两种方法解决公式(10)提出的问题。
在一种方法中,采用贪心算法,根据各个全局模型所需要的时间,确定各终端设备训练所有全局模型所需要的时间;按照各终端设备训练所有全局模型所需要的时间从小到大的顺序,选择至少两个目标终端设备。
将各终端设备分别带入公式(10)中,可以得到各终端设备训练各个全局模型所需要的时间,并求和得到各终端设备训练所有全局模型所需要的时间,再加上波动值参数,得到优化目标公式(10)中的值,并按照该值对对应的终端设备进行排序。由于/>与终端设备无关,因此,将各终端设备带入公式(10)中时,/>的值相同,因此,在排序时可以仅按照各终端设备训练所有全局模型所需要的时间/>从小到大的顺序对各终端设备进行排序,并选择至少两个目标终端设备作为任务j的第r轮训练的终端设备/>从而,在每一轮调度终端设备时,都能保证这一轮所有任务的收敛时间最小,提高训练效率。
在另一种方法中,采用贝叶斯优化的方法解决公式(10)的优化问题。将公式(10)中的定义为一个目标函数f(x),其中x是决策向量。
对于超参数组合x={x1,x2,...,xM},有最优调度方案xopt
xopt=argmlnx f(X) (13)
超参数组合x表示的即为公式(10)中的xopt即为使得f(x)最小的Sr
由于目标函数f(x)的数学性质是未知的,因此它不能用简单的数学推导出。因此,假设函数f(x)是能够进行随机森林(Random Forest,RF)模型拟合的。而随机森林模型是一种集成学习方法,通过组合多个弱学习器来提升预测精度,并且随机森林模型固有的并行性,以及对数据进行下采样的特点,使得随机森林模型非常适合于大规模数据集。函数f(x)是贝叶斯优化中的求解目标,用来拟合f(x)的RF模型则是贝叶斯优化的概率代理模型,而贝叶斯优化的另一重要部分,采集函数(Acquisition Function)则可以采用改善概率函数(Probability of Improvement,PI)。采用贝叶斯优化的方法解决公式(10)的问题的方法如下:
S1:从多个终端设备中随机选择多个设备样本x,即每个设备样本包括用于训练各全局模型的终端设备。
S2:计算目标全局模型的本地训练样本类别的波动值;根据训练各个全局模型所需要的时间,计算设备样本训练所有全局模型所需要的时间,并确定波动值和设备样本训练所有全局模型所需要的时间的和值,即确定f(x)的值。
S3:将设备样本以及与设备样本对应的和值添加到观测集合Π0中,并采用观测集合Π0对初始随机森林模型进行训练,得到一个训练后的第一随机森林模型。
S4:再次从多个终端设备中获取多个新的设备样本x,并根据新的设备样本和第一随机森林模型,确定为任务j选择的至少两个目标终端设备
本步骤可以通过N次迭代实现,对于t=1,2...,N次的每一次迭代:
获取新的设备样本x,将新的设备样本输入第一随机森林模型中,得到预测值,根据预测值的方差和均值计算采集函数αPI(x;∏t-1),将新的设备样本中采集函数的值最小的第一设备样本xt=argminxαPI(x;Πt-1)以及与第一设备样本对应的和值yt=f(xt)添加到观测集合中,得到新的观测集合Πt=Πt-1∪(xt,yt);采用新的观测集合Πt对第一随机森林模型进行训练,得到第二随机森林模型,并将第二随机森林模型作为新的第一随机森林模型,重复执行迭代,直至执行次数达到预设值N,即可得到最终的新的观测集合ΠN;从新的观测集合ΠN中确定为任务j调度的至少两个目标终端设备从新的观测集合ΠN中,确定和值f(x)最小的目标设备样本xopt,即f(x)最小的/>从目标设备样本中确定目标全局模型(任务j)对应的至少两个目标终端设备/>
在贝叶斯优化方法中,针对于某一任务j的当前轮的终端设备调度,通过上述方法可以得到所有任务的当前轮的终端设备的调度方案但是仅采用任务j的调度方案/>其他任务的调度方案并不采用。对于其他任务的任一轮调度方案同样采样上述方法,按照公式(10)优化目标进行确定。从而,在保证收敛精度的基础上,提高所有任务的训练效率。
图3是根据本公开实施例提供的一种模型训练装置的结构示意图。如图3所示,模型训练装置300包括:
选择模块301,用于针对多个全局模型中目标全局模型的一轮训练,根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备;
发送模块302,用于将目标全局模型的全局模型参数发送给至少两个目标终端设备;
接收模块303,用于接收至少两个目标终端设备发送的本地模型参数,本地模型参数为至少两个目标终端设备各自根据本地训练样本对目标全局模型进行训练得到的;
更新模块304,用于根据至少两个目标终端设备发送的本地模型参数,更新目标全局模型参数。
在一种实施方式中,选择模块301包括:
获取子模块,用于获取各终端设备当前的资源状态;
确定子模块,用于根据各终端设备当前的资源状态,确定训练各个全局模型所需要的时间;
第一选择子模块,用于根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备。
在一种实施方式中,第一选择子模块包括:
第一确定单元,用于根据各个全局模型所需要的时间,确定各终端设备训练所有全局模型所需要的时间;
第一选择单元,用于按照各终端设备训练所有全局模型所需要的时间从小到大的顺序,选择至少两个目标终端设备。
在一种实施方式中,第一选择子模块包括:
计算单元,用于计算目标全局模型的本地训练样本类别的波动值;
第二确定单元,用于根据训练各个全局模型所需要的时间,计算设备样本训练所有全局模型所需要的时间,并确定波动值和设备样本训练所有全局模型所需要的时间的和值;设备样本包括用于训练各全局模型的终端设备;
训练单元,用于将设备样本以及与设备样本对应的和值添加到观测集合中,并采用观测集合对初始随机森林模型进行训练,得到第一随机森林模型;
第三确定单元,用于再次从多个终端设备中获取新的设备样本,并根据新的设备样本和第一随机森林模型,确定至少两个目标终端设备。
在一种实施方式中,第三确定单元包括:
迭代单元,用于将新的设备样本输入第一随机森林模型中,得到预测值,根据预测值的方差和均值计算采集函数,将新的设备样本中采集函数的值最小的第一设备样本以及与第一设备样本对应的和值添加到观测集合中,得到新的观测集合;采用新的观测集合对第一随机森林模型进行训练,得到第二随机森林模型,再次获取新的设备样本,并将第二随机森林模型作为新的第一随机森林模型,重复执行此步骤,直至执行次数达到预设值;
第四确定单元,用于从新的观测集合中确定至少两个目标终端设备。
在一种实施方式中,第四确定单元包括:
第一确定子单元,用于从新的观测集合中,确定和值最小的目标设备样本;
第二确定子单元,用于从目标设备样本中确定目标全局模型对应的至少两个目标终端设备。
在一种实施方式中,计算单元包括:
获取子单元,用于获取各个终端设备训练目标全局模型所使用的本地训练样本的类别;
第三确定子单元,用于确定各类别的本地训练样本的数量;
第四确定子单元,用于根据类别和数量,确定波动值。
在一种实施方式中,模型训练装置300还包括:
判断模块,用于确定多个终端设备中每个终端设备参与训练目标全局模型的次数是否大于预设值;
剔除模块,用于若存在第一终端设备参与训练目标全局模型的次数大于预设值,则将第一终端设备从多个终端设备中剔除;
选择模块301包括:第二选择子模块,用于根据训练各个全局模型所需要的时间,从剔除第一终端设备后的多个终端设备中选择至少两个目标终端设备。
本公开实施例的装置可用于执行上述方法实施例中的模型训练方法,其实现原理和技术效果类似,此处不再赘述。
根据本公开的实施例,本公开还提供了一种电子设备和存储有计算机指令的非瞬时计算机可读存储介质。
根据本公开的实施例,本公开还提供了一种计算机程序产品,程序产品包括:计算机程序,计算机程序存储在可读存储介质中,电子设备的至少一个处理器可以从可读存储介质读取计算机程序,至少一个处理器执行计算机程序使得电子设备执行上述任一实施例提供的方案。
图4是用来实现本公开实施例的模型训练方法的电子设备的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图4所示,电子设备400包括计算单元401,其可以根据存储在只读存储器(ROM)402中的计算机程序或者从存储单元408加载到随机访问存储器(RAM)403中的计算机程序,来执行各种适当的动作和处理。在RAM 403中,还可存储设备400操作所需的各种程序和数据。计算单元401、ROM 402以及RAM 403通过总线404彼此相连。输入/输出(I/O)接口405也连接至总线404。
设备400中的多个部件连接至I/O接口405,包括:输入单元406,例如键盘、鼠标等;输出单元407,例如各种类型的显示器、扬声器等;存储单元408,例如磁盘、光盘等;以及通信单元409,例如网卡、调制解调器、无线通信收发机等。通信单元409允许设备400通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元401可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元401的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元401执行上文所描述的各个方法和处理,例如模型训练方法。例如,在一些实施例中,模型训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元408。在一些实施例中,计算机程序的部分或者全部可以经由ROM 402和/或通信单元409而被载入和/或安装到设备400上。当计算机程序加载到RAM 403并由计算单元401执行时,可以执行上文描述的模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元401可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行模型训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决了传统物理主机与VPS服务("Virtual Private Server",或简称"VPS")中,存在的管理难度大,业务扩展性弱的缺陷。服务器也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发申请中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。

Claims (12)

1.一种模型训练方法,所述方法包括:
针对多个全局模型中目标全局模型的一轮训练,根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备;
将所述目标全局模型的全局模型参数发送给所述至少两个目标终端设备;
接收所述至少两个目标终端设备发送的本地模型参数,并根据所述至少两个目标终端设备发送的本地模型参数,更新所述目标全局模型参数,所述本地模型参数为所述至少两个目标终端设备各自根据本地训练样本对所述目标全局模型进行训练得到的;
其中,所述根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备,包括:
获取各终端设备当前的资源状态;根据所述各终端设备当前的资源状态,确定训练各个全局模型所需要的时间;根据训练所述各个全局模型所需要的时间,从所述多个终端设备中选择至少两个目标终端设备;
其中,所述根据训练所述各个全局模型所需要的时间,从所述多个终端设备中选择至少两个目标终端设备,包括:
计算所述目标全局模型的本地训练样本类别的波动值;根据训练所述各个全局模型所需要的时间,计算设备样本训练所有全局模型所需要的时间,并确定所述波动值和所述设备样本训练所有全局模型所需要的时间的和值;所述设备样本包括用于训练各全局模型的终端设备;将所述设备样本以及与所述设备样本对应的和值添加到观测集合中,并采用所述观测集合对初始随机森林模型进行训练,得到第一随机森林模型;再次从所述多个终端设备中获取新的设备样本,并根据所述新的设备样本和所述第一随机森林模型,确定所述至少两个目标终端设备。
2.根据权利要求1所述的方法,其中,所述根据所述新的设备样本和所述第一随机森林模型,确定所述至少两个目标终端设备,包括:
将所述新的设备样本输入所述第一随机森林模型中,得到预测值,根据所述预测值的方差和均值计算采集函数,将所述新的设备样本中采集函数的值最小的第一设备样本以及与所述第一设备样本对应的和值添加到所述观测集合中,得到新的观测集合;采用所述新的观测集合对所述第一随机森林模型进行训练,得到第二随机森林模型,再次获取新的设备样本,并将所述第二随机森林模型作为新的第一随机森林模型,重复执行此步骤,直至执行次数达到预设值;
从新的观测集合中确定所述至少两个目标终端设备。
3.根据权利要求2所述的方法,其中,所述从新的观测集合中确定所述至少两个目标终端设备,包括:
从所述新的观测集合中,确定和值最小的目标设备样本;
从所述目标设备样本中确定所述目标全局模型对应的所述至少两个目标终端设备。
4.根据权利要求1-3任一项所述的方法,所述计算所述目标全局模型的本地训练样本类别的波动值,包括:
获取各个终端设备训练所述目标全局模型所使用的本地训练样本的类别;
确定各所述类别的本地训练样本的数量;
根据所述类别和所述数量,确定所述波动值。
5.根据权利要求1-3任一项所述的方法,所述根据训练各个全局模型所需要的时间,从所述多个终端设备中选择至少两个目标终端设备之前,所述方法还包括:
确定所述多个终端设备中每个终端设备参与训练所述目标全局模型的次数是否大于预设值;
若存在第一终端设备参与训练所述目标全局模型的次数大于所述预设值,则将所述第一终端设备从所述多个终端设备中剔除;
所述根据训练各个全局模型所需要的时间,从所述多个终端设备中选择至少两个目标终端设备,包括:
根据训练各个全局模型所需要的时间,从剔除所述第一终端设备后的多个终端设备中选择所述至少两个目标终端设备。
6.一种模型训练装置,所述装置包括:
选择模块,用于针对多个全局模型中目标全局模型的一轮训练,根据训练各个全局模型所需要的时间,从多个终端设备中选择至少两个目标终端设备;
发送模块,用于将所述目标全局模型的全局模型参数发送给所述至少两个目标终端设备;
接收模块,用于接收所述至少两个目标终端设备发送的本地模型参数,所述本地模型参数为所述至少两个目标终端设备各自根据本地训练样本对所述目标全局模型进行训练得到的;
更新模块,用于根据所述至少两个目标终端设备发送的本地模型参数,更新所述目标全局模型参数;
所述选择模块包括:
获取子模块,用于获取各终端设备当前的资源状态;
确定子模块,用于根据所述各终端设备当前的资源状态,确定训练各个全局模型所需要的时间;
第一选择子模块,用于根据训练所述各个全局模型所需要的时间,从所述多个终端设备中选择至少两个目标终端设备;
其中,所述第一选择子模块包括:
计算单元,用于计算所述目标全局模型的本地训练样本类别的波动值;
第二确定单元,用于根据训练所述各个全局模型所需要的时间,计算设备样本训练所有全局模型所需要的时间,并确定所述波动值和所述设备样本训练所有全局模型所需要的时间的和值;所述设备样本包括用于训练各全局模型的终端设备;
训练单元,用于将所述设备样本以及与所述设备样本对应的和值添加到观测集合中,并采用所述观测集合对初始随机森林模型进行训练,得到第一随机森林模型;
第三确定单元,用于再次从所述多个终端设备中获取新的设备样本,并根据所述新的设备样本和所述第一随机森林模型,确定所述至少两个目标终端设备。
7.根据权利要求6所述的装置,其中,所述第三确定单元包括:
迭代单元,用于将所述新的设备样本输入所述第一随机森林模型中,得到预测值,根据所述预测值的方差和均值计算采集函数,将所述新的设备样本中采集函数的值最小的第一设备样本以及与所述第一设备样本对应的和值添加到所述观测集合中,得到新的观测集合;采用所述新的观测集合对所述第一随机森林模型进行训练,得到第二随机森林模型,再次获取新的设备样本,并将所述第二随机森林模型作为新的第一随机森林模型,重复执行此步骤,直至执行次数达到预设值;
第四确定单元,用于从新的观测集合中确定所述至少两个目标终端设备。
8.根据权利要求7所述的装置,其中,所述第四确定单元包括:
第一确定子单元,用于从所述新的观测集合中,确定和值最小的目标设备样本;
第二确定子单元,用于从所述目标设备样本中确定所述目标全局模型对应的所述至少两个目标终端设备。
9.根据权利要求6-8任一项所述的装置,所述计算单元包括:
获取子单元,用于获取各个终端设备训练所述目标全局模型所使用的本地训练样本的类别;
第三确定子单元,用于确定各所述类别的本地训练样本的数量;
第四确定子单元,用于根据所述类别和所述数量,确定所述波动值。
10.根据权利要求6-8任一项所述的装置,所述装置还包括:
判断模块,用于确定所述多个终端设备中每个终端设备参与训练所述目标全局模型的次数是否大于预设值;
剔除模块,用于若存在第一终端设备参与训练所述目标全局模型的次数大于所述预设值,则将所述第一终端设备从所述多个终端设备中剔除;
所述选择模块包括:
第二选择子模块,用于根据训练各个全局模型所需要的时间,从剔除所述第一终端设备后的多个终端设备中选择所述至少两个目标终端设备。
11.一种电子设备,包括:
至少一个处理器;以及与至少一个处理器通信连接的存储器;
其中,存储器存储有可被至少一个处理器执行的指令,指令被至少一个处理器执行,以使至少一个处理器能够执行权利要求1-5中任一项的方法。
12.一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行权利要求1-5中任一项所述的方法。
CN202110730081.0A 2021-06-29 2021-06-29 模型训练方法、装置、电子设备、存储介质及程序产品 Active CN113361721B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110730081.0A CN113361721B (zh) 2021-06-29 2021-06-29 模型训练方法、装置、电子设备、存储介质及程序产品

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110730081.0A CN113361721B (zh) 2021-06-29 2021-06-29 模型训练方法、装置、电子设备、存储介质及程序产品

Publications (2)

Publication Number Publication Date
CN113361721A CN113361721A (zh) 2021-09-07
CN113361721B true CN113361721B (zh) 2023-07-18

Family

ID=77537077

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110730081.0A Active CN113361721B (zh) 2021-06-29 2021-06-29 模型训练方法、装置、电子设备、存储介质及程序产品

Country Status (1)

Country Link
CN (1) CN113361721B (zh)

Families Citing this family (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113850394B (zh) * 2021-09-18 2023-02-28 北京百度网讯科技有限公司 联邦学习方法、装置、电子设备及存储介质
CN114118437B (zh) * 2021-09-30 2023-04-18 电子科技大学 一种面向微云中分布式机器学习的模型更新同步方法
CN114065864B (zh) * 2021-11-19 2023-08-11 北京百度网讯科技有限公司 联邦学习方法、联邦学习装置、电子设备以及存储介质
CN114217933A (zh) * 2021-12-27 2022-03-22 北京百度网讯科技有限公司 多任务调度方法、装置、设备以及存储介质
CN114548426B (zh) * 2022-02-17 2023-11-24 北京百度网讯科技有限公司 异步联邦学习的方法、业务服务的预测方法、装置及系统
CN116187473B (zh) * 2023-01-19 2024-02-06 北京百度网讯科技有限公司 联邦学习方法、装置、电子设备和计算机可读存储介质

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106845327A (zh) * 2015-12-07 2017-06-13 展讯通信(天津)有限公司 人脸对齐模型的训练方法、人脸对齐方法和装置
CN109345302A (zh) * 2018-09-27 2019-02-15 腾讯科技(深圳)有限公司 机器学习模型训练方法、装置、存储介质和计算机设备
CN109871702A (zh) * 2019-02-18 2019-06-11 深圳前海微众银行股份有限公司 联邦模型训练方法、系统、设备及计算机可读存储介质
WO2020224205A1 (zh) * 2019-05-07 2020-11-12 清华大学 基于区块链的安全协作深度学习方法及装置
CN112365007A (zh) * 2020-11-11 2021-02-12 深圳前海微众银行股份有限公司 模型参数确定方法、装置、设备及存储介质
CN112906864A (zh) * 2021-02-20 2021-06-04 深圳前海微众银行股份有限公司 信息处理方法、装置、设备、存储介质及计算机程序产品
CN113011602A (zh) * 2021-03-03 2021-06-22 中国科学技术大学苏州高等研究院 一种联邦模型训练方法、装置、电子设备和存储介质

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10460255B2 (en) * 2016-07-29 2019-10-29 Splunk Inc. Machine learning in edge analytics
US11836576B2 (en) * 2018-04-13 2023-12-05 International Business Machines Corporation Distributed machine learning at edge nodes

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106845327A (zh) * 2015-12-07 2017-06-13 展讯通信(天津)有限公司 人脸对齐模型的训练方法、人脸对齐方法和装置
CN109345302A (zh) * 2018-09-27 2019-02-15 腾讯科技(深圳)有限公司 机器学习模型训练方法、装置、存储介质和计算机设备
CN109871702A (zh) * 2019-02-18 2019-06-11 深圳前海微众银行股份有限公司 联邦模型训练方法、系统、设备及计算机可读存储介质
WO2020224205A1 (zh) * 2019-05-07 2020-11-12 清华大学 基于区块链的安全协作深度学习方法及装置
CN112365007A (zh) * 2020-11-11 2021-02-12 深圳前海微众银行股份有限公司 模型参数确定方法、装置、设备及存储介质
CN112906864A (zh) * 2021-02-20 2021-06-04 深圳前海微众银行股份有限公司 信息处理方法、装置、设备、存储介质及计算机程序产品
CN113011602A (zh) * 2021-03-03 2021-06-22 中国科学技术大学苏州高等研究院 一种联邦模型训练方法、装置、电子设备和存储介质

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
面向移动终端智能的自治学习系统;徐梦炜;刘渊强;黄康;刘譞哲;黄罡;;软件学报(第10期);全文 *

Also Published As

Publication number Publication date
CN113361721A (zh) 2021-09-07

Similar Documents

Publication Publication Date Title
CN113361721B (zh) 模型训练方法、装置、电子设备、存储介质及程序产品
CN113516250B (zh) 一种联邦学习方法、装置、设备以及存储介质
CN112561078B (zh) 分布式的模型训练方法及相关装置
CN114298322B (zh) 联邦学习方法和装置、系统、电子设备、计算机可读介质
WO2018102240A1 (en) Joint language understanding and dialogue management
CN112560996B (zh) 用户画像识别模型训练方法、设备、可读存储介质及产品
CN112559007A (zh) 多任务模型的参数更新方法、装置及电子设备
CN112580733B (zh) 分类模型的训练方法、装置、设备以及存储介质
CN114065864B (zh) 联邦学习方法、联邦学习装置、电子设备以及存储介质
CN114723966B (zh) 多任务识别方法、训练方法、装置、电子设备及存储介质
CN113627536B (zh) 模型训练、视频分类方法,装置,设备以及存储介质
CN113850394B (zh) 联邦学习方法、装置、电子设备及存储介质
CN114818913A (zh) 决策生成方法和装置
CN114792125B (zh) 基于分布式训练的数据处理方法、装置、电子设备和介质
US20220374775A1 (en) Method for multi-task scheduling, device and storage medium
CN114841341B (zh) 图像处理模型训练及图像处理方法、装置、设备和介质
CN113591709B (zh) 动作识别方法、装置、设备、介质和产品
CN113361575B (zh) 模型训练方法、装置和电子设备
CN113313049A (zh) 超参数的确定方法、装置、设备、存储介质以及计算机程序产品
CN114461923B (zh) 社群发现方法、装置、电子设备和存储介质
CN117421126A (zh) 一种多租户无服务器平台资源管理方法
CN117575553A (zh) 岗位匹配方法、装置、设备及存储介质
CN117634590A (zh) 模型蒸馏方法、装置、设备及存储介质
CN115598967A (zh) 参数整定模型训练、参数确定方法、装置、设备及介质
Kan et al. Edge computing dynamic unloading based on deep reinforcement learning

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant