CN116680565A - 一种联合学习模型训练方法、装置、设备及存储介质 - Google Patents

一种联合学习模型训练方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN116680565A
CN116680565A CN202310612394.5A CN202310612394A CN116680565A CN 116680565 A CN116680565 A CN 116680565A CN 202310612394 A CN202310612394 A CN 202310612394A CN 116680565 A CN116680565 A CN 116680565A
Authority
CN
China
Prior art keywords
client
model parameter
parameter vector
local
local model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310612394.5A
Other languages
English (en)
Inventor
郭一川
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Xinao Xinzhi Technology Co ltd
Original Assignee
Xinao Xinzhi Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Xinao Xinzhi Technology Co ltd filed Critical Xinao Xinzhi Technology Co ltd
Priority to CN202310612394.5A priority Critical patent/CN116680565A/zh
Publication of CN116680565A publication Critical patent/CN116680565A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Mathematical Physics (AREA)
  • Medical Informatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Image Analysis (AREA)

Abstract

本申请公开了一种联合学习模型训练方法、装置、设备及存储介质,涉及模型训练技术领域,用以快捷准确的对联合学习模型进行训练。本申请服务器在对各客户端的联合学习模型的每轮迭代训练过程中,接收各客户端发送的上一轮输出的本地模型参数向量;针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的本轮全局模型参数向量发送给该客户端,该客户端采用该本轮全局模型参数向量,对其当前保存的本地模型参数向量进行更新,基于此,实现快捷准确的对联合学习模型进行训练的目的。

Description

一种联合学习模型训练方法、装置、设备及存储介质
技术领域
本申请涉及模型训练技术领域,尤其涉及一种联合学习模型训练方法、装置、设备及存储介质。
背景技术
对于拥有大量数据的单一数据中心(也可称为客户端),可以采用集中式训练的方式得到其所需要的模型。然而在单一数据中心的数据量不足,希望利用其他数据中心的数据的情况下,考虑到数据安全问题,通过采用联合学习训练的方式来得到所需要的模型是非常优秀的方法。
在不同数据中心的数据分布较为相似时,通过常规的联合学习训练方式可以较好地完成模型训练任务。然而,在不同数据中心的数据为非独立同分布或数据差异较大等时,常规的联合学习训练方式一方面可能产生模型收敛困难甚至不收敛的情况,另一方面还可能使得模型的精度受损。
因此,亟需探索出一种可以快捷准确的对联合学习模型进行训练的技术方案。
发明内容
本申请提供了一种联合学习模型训练方法、装置、设备及存储介质,用以快捷准确的对联合学习模型进行训练。
第一方面,本申请提供了一种联合学习模型训练方法,所述方法应用于服务器,所述方法包括:
在对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
接收各客户端发送的上一轮输出的本地模型参数向量;
针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的所述本轮全局模型参数向量发送给该客户端,使该客户端采用所述本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值之后,所述基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量之前,所述方法还包括:
基于所述相似度值以及设定权重算法,确定各客户端各自对应的权重;其中相似度值与权重呈正相关或呈负相关,本地模型参数向量之间越相似,相应权重越大;
根据所述权重,对相应的相似度值进行更新,基于更新后的相似度值进行后续步骤。
在一种可能的实施方式中,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,包括:
基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述基于所述相似度值以及设定权重算法,确定各客户端分别对应的权重,包括:
基于所述相似度值以及注意力引入函数,确定各其他客户端分别对应的权重;
基于设定权重总值与各其他客户端对应的权重和的差值,确定该客户端对应的权重。
在一种可能的实施方式中,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值包括:
基于余弦值相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述基于所述相似度值以及设定权重算法,确定各客户端分别对应的权重,包括:
确定该客户端对应的权重为预设值;
基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重。
第二方面,本申请提供了一种联合学习模型训练方法,所述方法应用于客户端,所述方法包括:
在参与对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
将所述客户端自身上一轮输出的本地模型参数向量发送给服务器,使所述服务器确定所述客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;并使所述服务器基于各客户端的本地模型参数向量与确定的相似度值,确定所述客户端的联合学习模型的本轮全局模型参数向量;
接收所述服务器发送的所述客户端的本轮全局模型参数向量,并采用所述本轮全局模型参数向量对所述客户端当前保存的本地模型参数向量进行更新;
基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练,包括:
基于更新后的本地模型参数向量、所述客户端的本地私有数据以及邻近点法,对所述客户端的联合学习模型进行本轮迭代训练;
所述方法还包括:
获得本轮输出的本地模型参数向量。
第三方面,本申请提供了一种联合学习模型训练装置,所述装置应用于服务器,所述装置包括:
接收模块,用于在对各客户端的联合学习模型的每轮迭代训练过程中,接收各客户端发送的上一轮输出的本地模型参数向量;
确定模块,用于针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的所述本轮全局模型参数向量发送给该客户端,使该客户端采用所述本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述确定模块,还用于:基于所述相似度值以及设定权重算法,确定各客户端各自对应的权重;其中相似度值与权重呈正相关或呈负相关,本地模型参数向量之间越相似,相应权重越大;根据所述权重,对相应的相似度值进行更新,基于各客户端的本地模型参数向量与更新后的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量。
在一种可能的实施方式中,所述确定模块,具体用于:基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述确定模块,具体用于:基于所述相似度值以及注意力引入函数,确定各其他客户端分别对应的权重;基于设定权重总值与各其他客户端对应的权重和的差值,确定该客户端对应的权重。
在一种可能的实施方式中,所述确定模块,具体用于:基于余弦值相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述确定模块,具体用于:确定该客户端对应的权重为预设值;基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重。
第四方面,本申请提供了一种联合学习模型训练装置,所述装置应用于客户端,所述装置包括:
发送模块,用于在参与对各客户端的联合学习模型的每轮迭代训练过程中,将所述客户端自身上一轮输出的本地模型参数向量发送给服务器,使所述服务器确定所述客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;并使所述服务器基于各客户端的本地模型参数向量与确定的相似度值,确定所述客户端的联合学习模型的本轮全局模型参数向量;
更新模块,用于接收所述服务器发送的所述客户端的本轮全局模型参数向量,并采用所述本轮全局模型参数向量对所述客户端当前保存的本地模型参数向量进行更新;
训练模块,用于基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述训练模块,具体用于:基于更新后的本地模型参数向量、所述客户端的本地私有数据以及邻近点法,对所述客户端的联合学习模型进行本轮迭代训练;并获得本轮输出的本地模型参数向量。
第五方面,本申请提供了一种电子设备,所述电子设备至少包括处理器和存储器,所述处理器用于执行存储器中存储的计算机程序时实现如上述任一所述方法的步骤。
第六方面,本申请提供了一种计算机可读存储介质,其存储有计算机程序,所述计算机程序被处理器执行时实现如上述任一所述方法的步骤。
由于本申请服务器(云端)在对各客户端的联合学习模型的每轮迭代训练过程中,可以至少执行以下步骤:接收各客户端发送的上一轮输出的本地模型参数向量;针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的本轮全局模型参数向量发送给该客户端,使该客户端采用该本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。由于本申请可以基于注意力机制,使客户端获得的联合学习模型既可以在模型聚合过程中利用到其他客户端的数据,在模型聚合过程中受益的基础上,还可以保留客户端本地私有数据的特异性,每个客户端均可以获得适合其自身专有的联合学习模型,兼顾客户端的数据差异性以及客户端之间的合作性,从而可以实现快捷准确的对联合学习模型进行训练的目的。
附图说明
为了更清楚地说明本申请实施例或相关技术中的实施方式,下面将对实施例或相关技术描述中所需要使用的附图作一简单地介绍,显而易见地,下面描述中的附图是本申请的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的附图。
图1示出了一些实施例提供的第一种联合学习模型训练过程示意图;
图2示出了一些实施例提供的第二种联合学习模型训练过程示意图;
图3示出了一些实施例提供的第三种联合学习模型训练过程示意图;
图4示出了一些实施例提供的第四种联合学习模型训练过程示意图;
图5示出了一些实施例提供的第五种联合学习模型训练过程示意图;
图6示出了一些实施例提供的第六种联合学习模型训练过程示意图;
图7示出了一些实施例提供的一种联合学习模型训练装置示意图;
图8示出了一些实施例提供的另一种联合学习模型训练装置示意图;
图9示出了一些实施例提供的一种电子设备结构示意图;
图10示出了一些实施例提供的另一种电子设备结构示意图。
具体实施方式
为了使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请作进一步的详细描述,显然,本申请所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本申请保护的范围。
需要说明的是,本申请中对于术语的简要说明,仅是为了方便理解接下来描述的实施方式,而不是意图限定本申请的实施方式。除非另有说明,这些术语应当按照其普通和通常的含义理解。
本申请中说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”等是用于区别类似或同类的对象或实体,而不必然意味着限定特定的顺序或先后次序,除非另外注明。应该理解这样使用的用语在适当情况下可以互换。
术语“包括”和“具有”以及他们的任何变形,意图在于覆盖但不排他的包含,例如,包含了一系列组件的产品或设备不必限于清楚地列出的所有组件,而是可包括没有清楚地列出的或对于这些产品或设备固有的其它组件。
术语“模块”是指任何已知或后来开发的硬件、软件、固件、人工智能、模糊逻辑或硬件或/和软件代码的组合,能够执行与该元件相关的功能。
最后应说明的是:以上各实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述各实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的范围。
为了快捷准确的对联合学习模型进行训练,本申请提供了一种联合学习模型训练方法、装置、设备及存储介质。
本申请实施例所有实施方式对数据的获取、存储、使用、处理等均符合国家法律法规的相关规定。
实施例1:
图1示出了一些实施例提供的第一种联合学习模型训练过程示意图,该方法应用于服务器。如图1所示,服务器在对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
S101:接收各客户端发送的上一轮输出的本地模型参数向量。
在一种可能的实施方式中,在对各客户端的联合学习模型的任意一轮(为方便描述,以第k轮进行举例说明)迭代训练过程中,各客户端可以将上一轮(如第k-1轮)输出的联合学习模型的本地模型参数向量发送给服务器。其中,为方便描述,将客户端输出的联合学习模型的模型参数向量称为本地模型参数向量。将任意一个客户端,如第i个客户端(也可称为客户端i)在第k轮迭代训练过程中输出的本地模型参数向量用表示。将第i个客户端(客户端i)在第k-1轮迭代训练过程中输出的本地模型参数向量用/>表示。
S102:针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的所述本轮全局模型参数向量发送给该客户端,使该客户端采用所述本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,服务器接收到每个客户端发送的上一轮输出的本地模型参数向量之后,为了可以快捷准确的对联合学习模型进行训练,得到各客户端各自适合的联合学习模型,服务器可以基于注意力机制(Attention Mechanism),将各客户端在上一轮输出的本地模型参数向量进行加权融合(为方便描述,后续可称为模型聚合),从而确定适合每个客户端的联合学习模型。
具体的,针对每个客户端,服务器可以先确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,然后基于各客户端的本地模型参数向量与确定的相似度值,确定适合该客户端的联合学习模型的本轮全局模型参数向量(其中,为方便描述,将服务器确定的适合任一客户端的联合学习模型称为云端模型,将服务器确定的联合学习模型的模型参数向量称为全局模型参数向量)。示例性的,可以是两个客户端的本地模型参数向量之间越相似(越接近),这两个客户端的本地模型参数向量之间的相似度值越大,这种情况下,针对每个客户端,服务器可以基于各客户端的本地模型参数向量与相应相似度值的加权值,确定适合该客户端的联合学习模型的本轮全局模型参数向量。当然,也可以是两个客户端的本地模型参数向量之间越相似(越接近),这两个客户端的本地模型参数向量之间的相似度值越小,这种情况下,针对每个客户端,服务器可以基于各客户端的本地模型参数向量与相应相似度值的倒数等的加权值,确定适合该客户端的联合学习模型的本轮全局模型参数向量,本申请对此不作具体限定,可以根据需求灵活设置。
在一种可能的实施方式中,为了快捷准确的确定每个客户端各自的联合学习模型,针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值之后,基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量之前,还可以先基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、以及设定权重算法,确定各客户端分别对应的权重;然后可以根据确定的各客户端各自对应的权重,对相应的相似度值进行更新,基于各客户端本地模型参数向量与更新后的相似度值(即各客户端分别对应的权重),确定该客户端的联合学习模型的本轮全局模型参数向量。示例性的,针对每个客户端,可以基于各客户端的本地模型参数向量与各客户端各自对应的权重的加权值,确定该客户端的联合学习模型的本轮全局模型参数向量。
在一种可能的实施方式中,在基于相似度值以及设定权重算法,确定各客户端各自对应的权重时,相似度值与权重可以呈正相关,也可以呈负相关,其中,本地模型参数向量之间越相似,相应权重越大。也就是说,可以是两个客户端的本地私有数据的分布越接近,这两个客户端输出的本地模型参数向量之间越相似,这两个客户端的本地模型参数向量之间的相似度值可以越大,或者也可以越小;当这两个客户端的本地模型参数向量之间的相似度值越大时,相似度值与权重可以呈正相关,即相应权重可以越大;而当这两个客户端的本地模型参数向量之间的相似度值越小时,相似度值与权重可以呈负相关,即确定的相应权重可以越小。
示例性的,如果第1个客户端与第2个客户端的本地模型参数向量之间较相似,第1个客户端与第3个客户端的本地模型参数向量不相似,则在确定适合第1个客户端的联合学习模型的本轮全局模型参数向量时,第2个客户端对应的权重可以较大,而第3个客户端对应的权重可以较小。
相较于相关技术中在模型聚合过程中各客户端的贡献(权重)都是均等的,各客户端均得到同一个联合学习模型,在不同客户端的数据为非独立同分布或数据差异较大等时,联合学习模型收敛困难甚至不收敛以及模型精度受损而言,本申请在模型聚合过程中,可以基于注意力机制让数据分布更接近的客户端获得更大的权重,实现“越相似合作越密切”,使客户端获得的联合学习模型既可以在模型聚合过程中利用到其他客户端的数据,在模型聚合过程中受益的基础上,还可以保留客户端本地私有数据的特异性,每个客户端均可以获得适合其自身专有的联合学习模型,兼顾客户端的数据差异性以及客户端之间的合作性,从而可以实现快捷准确的对联合学习模型进行训练的目的。
为方便理解,下面用公式形式对本申请确定第i个客户端(客户端i)在第k轮的联合学习模型的本轮全局模型参数向量的过程进行解释说明。
其中,第i个客户端在第k轮的联合学习模型的本轮全局模型参数向量用表示。假设总共有m个客户端,其中每个客户端相对于第i个客户端的权重(为方便描述,后续称为每个客户端对应的权重)用ξ表示,ξi,1为第1个客户端对应的权重,/>为第1个客户端在上一轮输出的本地模型参数向量,ξi,2为第2个客户端对应的权重,/>为第2个客户端在上一轮输出的本地模型参数向量,ξi,m为第m个客户端对应的权重,/>为第m个客户端在上一轮输出的本地模型参数向量。
针对第i个客户端,可以基于各客户端的本地模型参数向量与各客户端各自对应的权重的加权值,确定第i个客户端的联合学习模型的本轮全局模型参数向量可选的,各客户端的权重总值可以为设定数值,其中,本申请对各客户端的权重总值不作具体限定,可以根据需求灵活设置,示例性的,各客户端的权重总值可以为1等设定数值,如ξi,1i,2+…+ξi,m=1。
确定了每个客户端的联合学习模型的本轮全局模型参数向量之后,针对每个客户端,服务器可以将该客户端的本轮全局模型参数向量发送给该客户端。针对每个客户端,该客户端接收到服务器发送的本轮全局模型参数向量之后,可以采用该本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,即将当前保存的本地模型参数向量更新为该本轮全局模型参数向量,然后基于更新后的本地模型参数向量(即本轮全局模型参数向量)对待训练的联合学习模型进行本轮迭代训练。
其中,每个客户端本地均可以拥有自己专有的数据集(为方便描述,称为本地私有数据,用符号Di表示)。每个客户端可以分别基于其本地私有数据以及更新后的本地模型参数向量(本轮全局模型参数向量),对每个客户端的联合学习模型进行本轮迭代训练。
为方便理解,下面通过一个具体实施例对本申请提供的联合学习模型训练过程进行解释说明。
参阅图2,图2示出了一些实施例提供的第二种联合学习模型训练过程示意图,该过程包括以下步骤:
S201:在准备开始训练时,服务器(云端)可以初始化云端模型,各客户端也可以初始化本地模型,服务器可以将初始化云端模型(全局模型参数向量)发送(下发)给每个客户端。
可选的,在准备开始训练时,服务器发送给每个客户端的全局模型参数向量(云端模型)可以是相同的。
S202:针对每个客户端,该客户端可以采用接收到的全局模型参数向量对当前保存的本地模型参数向量进行更新,基于更新后的本地模型参数向量以及本地私有数据,对待训练的联合学习模型进行迭代训练,获得本轮输出的本地模型参数向量。
S203:各客户端可以分别将输出的本地模型参数向量发送给服务器。
S204:针对每个客户端,服务器可以确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,基于相似度值以及设定权重算法,确定各客户端各自对应的权重,其中相似度值与权重呈正相关或呈负相关,本地模型参数向量之间越相似,相应权重越大;根据权重,对相应的相似度值进行更新,基于各客户端的本地模型参数向量与更新后的相似度值(权重),确定该客户端的联合学习模型的下一轮全局模型参数向量,服务器将该客户端的下一轮全局模型参数向量发送给该客户端。
S205:客户端判断是否达到终止训练条件,若是,则进行S206;若否,则重复进行S202。
其中,本申请对终止训练条件不作具体限定,可以根据需求灵活设置,例如可以是训练总轮数达到设定轮数阈值,也可以是模型已经收敛等。
S206:根据接收到的本轮全局模型参数向量对当前保存的本地模型参数向量进行更新,并结束训练。
由于本申请服务器(云端)在对各客户端的联合学习模型的每轮迭代训练过程中,可以至少执行以下步骤:接收各客户端发送的上一轮输出的本地模型参数向量;针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的本轮全局模型参数向量发送给该客户端,使该客户端采用该本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。由于本申请可以基于注意力机制,使客户端获得的联合学习模型既可以在模型聚合过程中利用到其他客户端的数据,在模型聚合过程中受益的基础上,还可以保留客户端本地私有数据的特异性,每个客户端均可以获得适合其自身专有的联合学习模型,兼顾客户端的数据差异性以及客户端之间的合作性,从而可以实现快捷准确的对联合学习模型进行训练的目的。
实施例2:
为了准确的确定客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,在上述实施例的基础上,在本申请实施例中,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,包括:
基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,针对每个客户端,可以基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。示例性的,以计算第i个客户端在上一轮(第k-1)的本地模型参数向量与第j个客户端在上一轮(第k-1)的本地模型参数向量/>之间的相似度值为例,可以基于欧式距离相似度算法:/>来计算/>与/>之间的相似度值。可选的,可以是两个客户端的本地模型参数向量之间越相似,这两个客户端的本地模型参数向量之间的欧氏距离越小,相似度值也越小。
在一种可能的实施方式中,为了准确的确定各客户端的全局模型参数向量,针对每个客户端,在确定各其他客户端相对于该客户端的权重时,可以是分别基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值以及注意力引入函数,来确定各其他客户端分别对应的权重(各其他客户端分别相对于该客户端的权重)。例如,注意力引入函数可以为:A(x)=1-e-x,其在[0,+∞)区间内为单调递减函数,可以基于 来确定各其他客户端分别对应的权重,即/>其中,本地模型参数向量与第i个客户端(客户端i)的本地模型参数向量/>越相似的其他客户端,与客户端i之间的欧式距离越小,该其他客户端对应的权重越大,其对该客户端(客户端i)的全局模型参数向量/>(云端模型)的贡献也就越大。
在一种可能的实施方式中,确定了各其他客户端分别对应的权重之后,可以基于设定权重总值与各其他客户端对应的权重和的差值,确定该客户端对应的权重(也可称为该客户端相对于该客户端自身的权重,用ξi,i表示)。示例性的,假设设定权重总值为1,可以用1减去各其他客户端对应的权重和,从而得到该客户端对应的权重ξi,i
为方便理解,下面通过一个具体实施例对本申请提供的联合学习模型训练过程进行解释说明。
参阅图3,图3示出了一些实施例提供的第三种联合学习模型训练过程示意图,在对各客户端的联合学习模型的每轮迭代训练过程中,服务器执行以下步骤:
S301:服务器接收各客户端发送的上一轮(如第k-1轮)输出的本地模型参数向量。
S302:针对每个客户端,服务器基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,基于相似度值以及注意力引入函数,确定各其他客户端分别对应的权重,基于设定权重总值与各其他客户端对应的权重和的差值,确定该客户端对应的权重;根据权重,对相应的相似度值进行更新,基于各客户端的本地模型参数向量与更新后的相似度值(权重),确定该客户端的联合学习模型的本轮(如第k轮)全局模型参数向量,服务器将该客户端的本轮全局模型参数向量发送给该客户端。
S303:针对每个客户端,该客户端采用该客户端的本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并基于更新后的本地模型参数向量对该客户端待训练的联合学习模型进行本轮迭代训练。
其中,在对联合学习模型的训练过程中,可以重复循环进行S301-S303,直到得到训练完成的目标联合学习模型。
本申请可以基于欧式距离相似度算法,快捷准确的确定客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。另外,本申请可以基于由欧式距离相似度算法获得的相似度值以及注意力引入函数,快捷的确定各其他客户端分别对应的权重,并可以基于设定权重总值及各其他客户端对应的权重,确定该客户端对应的权重,基于确定的各客户端分别对应的权重,可以快捷准确的确定各客户端对该客户端的全局模型参数向量所做的贡献大小,可以提高确定的客户端的全局模型参数向量的准确性。
实施例3:
为了准确的确定客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,在上述各实施例的基础上,在本申请实施例中,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值包括:
基于余弦值相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,在计算相似度值,除了可以用上述实施例介绍的欧式距离相似度算法之外,针对每个客户端,还可以基于余弦值相似度算法,来确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。例如,可以基于余弦值相似度算法:来计算第i个客户端在上一轮(第k-1)的本地模型参数向量/>与第j个客户端在上一轮(第k-1)的本地模型参数向量/>之间的相似度值。其中,两个客户端的本地模型参数向量(矩阵)越相似,这两个客户端的本地模型参数向量之间的余弦值越接近1,即越大,相似度值也越大。两个客户端的本地模型参数向量(矩阵)越不相似,这两个客户端的本地模型参数向量之间的余弦值越小,相似度值也越小。其中,本申请对确定相似度值的具体方式不作具体限定,例如可以采用欧式距离相似度算法,也可以采用余弦值相似度算法等,可以根据需求灵活选择。示例性的,当模型参数向量较大时,可以用余弦相似度算法来计算相似度值。
在上述各实施例的基础上,在本申请实施例中,所述基于所述相似度值以及设定权重算法,确定各客户端分别对应的权重,包括:
确定该客户端对应的权重为预设值;
基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重。
在一种可能的实施方式中,在基于余弦值相似度算法确定各相似度值之后,在确定各客户端分别对应的权重时,以确定各客户端分别相对于客户端i的权重为例,可以先确定该客户端对应的权重(也可称为该客户端相对于该客户端自身的权重ξi,i)为预设值,其中,本申请对ξi,i的具体取值不做具体限定,可以根据需求灵活设置。示例性的,ξi,i可以小于1。确定了该客户端对应的权重ξi,i之后,可以基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重ξi,i,确定各其他客户端分别对应的权重。示例性的,以计算客户端j对应的权重(客户端j相对于客户端i的权重)ξi,j为例,可以基于公式:来计算。其中,σ为超参数,可以为固定值。假设总共有m个客户端,h可以为除i之外的任一数值,例如h可以为1、2、3、……m,且h不等于i。
为方便理解,下面通过一个具体实施例对本申请提供的联合学习模型训练过程进行解释说明。
参阅图4,图4示出了一些实施例提供的第四种联合学习模型训练过程示意图,在对各客户端的联合学习模型的每轮迭代训练过程中,服务器执行以下步骤:
S401:服务器接收各客户端发送的上一轮(如第k-1轮)输出的本地模型参数向量。
S402:针对每个客户端,服务器基于余弦值相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,并确定该客户端对应的权重为预设值,基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重;根据权重,对相应的相似度值进行更新,基于各客户端的本地模型参数向量与更新后的相似度值(权重),确定该客户端的联合学习模型的本轮(如第k轮)全局模型参数向量,服务器将该客户端的本轮全局模型参数向量发送给该客户端。
S403:针对每个客户端,该客户端采用该客户端的本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并基于更新后的本地模型参数向量对该客户端待训练的联合学习模型进行本轮迭代训练。
其中,在对联合学习模型的训练过程中,可以重复循环进行S401-S403,直到得到训练完成的目标联合学习模型。
本申请可以基于余弦值相似度算法,快捷准确的确定客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。另外,本申请可以确定该客户端对应的权重为预设值,并基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重,基于确定的各客户端分别对应的权重,可以快捷准确的确定各客户端对该客户端的全局模型参数向量所做的贡献大小,可以提高确定的客户端的全局模型参数向量的准确性。
实施例4:
基于相同的技术构思,本申请提供了一种联合学习模型训练方法,该方法应用于任一客户端。参阅图5,图5示出了一些实施例提供的第五种联合学习模型训练过程示意图。在参与对各客户端的联合学习模型的每轮迭代训练过程中,每个客户端至少执行以下步骤:
S501:将所述客户端自身上一轮输出的本地模型参数向量发送给服务器,使所述服务器确定所述客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;并使所述服务器基于各客户端的本地模型参数向量与确定的相似度值,确定所述客户端的联合学习模型的本轮全局模型参数向量。
S502:接收所述服务器发送的所述客户端的本轮全局模型参数向量,并采用所述本轮全局模型参数向量对所述客户端当前保存的本地模型参数向量进行更新。
S503:基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,在上述各实施例的基础上,在本申请实施例中,所述基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练,包括:
基于更新后的本地模型参数向量、所述客户端的本地私有数据以及邻近点法,对所述客户端的联合学习模型进行本轮迭代训练;
所述方法还包括:
获得本轮输出的本地模型参数向量。
在一种可能的实施方式中,为了更好地结合服务器(云端)所采用的注意力机制,可以使用邻近点法对客户端的联合学习模型进行优化训练。具体的,可以基于更新后的本地模型参数向量、所述客户端的本地私有数据以及邻近点法,对所述客户端的联合学习模型进行本轮迭代训练。其中优化训练的目标可以为:
其中,为客户端i在第k轮输出的本地模型参数向量,/>表示本地模型参数向量w的维度,Fi(w)表示本地模型在本地私有数据集上的损失,/>为服务器确定的客户端i的本轮全局模型参数向量(云端模型),/>为超参数。基于临近点法对客户端的联合学习模型进行训练时,可以基于正则项:/>使得本地模型参数向量尽可能靠近全局模型参数向量(云端模型),可以快捷准确的对联合学习模型进行训练。
为方便理解,下面再通过一个具体实施例对本申请提供的联合学习模型的训练过程进行解释说明。参阅图6,图6示出了一些实施例提供的第六种联合学习模型训练过程示意图。
假设共有m个客户端,分别命名为客户端1、客户端2、……客户端i……、客户端m。可选的,各客户端可以为用于预测用气负荷的客户端,不同的客户端由于所归属的城燃公司所在地区的地域、天气、文化等差异,每个客户端的本地私有数据的差异较大。每个客户端的本地私有数据可以分别用D1、D2、……、Di、……Dm表示。
在第k-1轮的迭代训练过程中,各客户端分别输出了各自的本地模型参数向量,其中,客户端1输出的本地模型参数向量用表示,客户端2输出的本地模型参数向量用表示,客户端i输出的本地模型参数向量用/>表示,客户端m输出的本地模型参数向量用/>表示。各客户端可以均将其输出的本地模型参数向量发送给服务器(云端)。针对每个客户端,服务器分别确定各客户端对应的权重,其中,确定各客户端的权重的过程与上述实施例中介绍的确定权重的过程相同,在此不再赘述。其中,以确定相对于客户端i的权重为例,确定的客户端1对应的权重用ξi,1表示,客户端2对应的权重用ξi,2表示,客户端i对应的权重用ξi,i表示,客户端m对应的权重用ξi,m表示。服务器可以基于各客户端的本地模型参数向量与相应权重的加权值,确定该客户端的联合学习模型的本轮全局模型参数向量。其中,确定各客户端的本轮全局模型参数向量的过程与上述实施例相同,在此不再赘述。为方便描述,将确定的客户端1适合的本轮全局模型参数向量用/>表示,客户端2适合的本轮全局模型参数向量用/>表示,客户端i适合的本轮全局模型参数向量用/>表示,客户端m适合的本轮全局模型参数向量用/>表示。针对每个客户端,服务器可以该客户端的本轮全局模型参数向量发送给该客户端,该客户端接收到该本轮全局模型参数向量之后,可以采用该本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练,在此不再赘述。本申请提供的联合学习模型训练方法可以基于注意力机制,获得精度良好的适合每个客户端自身的联合学习模型。
实施例5:
基于相同的技术构思,本申请提供了一种联合学习模型训练装置,所述装置应用于服务器,参阅图7,图7示出了一些实施例提供的一种联合学习模型训练装置示意图,所述装置包括:
接收模块71,用于在对各客户端的联合学习模型的每轮迭代训练过程中,接收各客户端发送的上一轮输出的本地模型参数向量;
确定模块72,用于针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的所述本轮全局模型参数向量发送给该客户端,使该客户端采用所述本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述确定模块72,还用于:基于所述相似度值以及设定权重算法,确定各客户端各自对应的权重;其中相似度值与权重呈正相关或呈负相关,本地模型参数向量之间越相似,相应权重越大;根据所述权重,对相应的相似度值进行更新,基于各客户端的本地模型参数向量与更新后的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量。
在一种可能的实施方式中,所述确定模块72,具体用于:基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述确定模块72,具体用于:基于所述相似度值以及注意力引入函数,确定各其他客户端分别对应的权重;基于设定权重总值与各其他客户端对应的权重和的差值,确定该客户端对应的权重。
在一种可能的实施方式中,所述确定模块72,具体用于:基于余弦值相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述确定模块72,具体用于:确定该客户端对应的权重为预设值;基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重。
实施例6:
基于相同的技术构思,本申请提供了一种联合学习模型训练装置,所述装置应用于客户端,参阅图8,图8示出了一些实施例提供的另一种联合学习模型训练装置示意图,所述装置包括:
发送模块81,用于在参与对各客户端的联合学习模型的每轮迭代训练过程中,将所述客户端自身上一轮输出的本地模型参数向量发送给服务器,使所述服务器确定所述客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;并使所述服务器基于各客户端的本地模型参数向量与确定的相似度值,确定所述客户端的联合学习模型的本轮全局模型参数向量;
更新模块82,用于接收所述服务器发送的所述客户端的本轮全局模型参数向量,并采用所述本轮全局模型参数向量对所述客户端当前保存的本地模型参数向量进行更新;
训练模块83,用于基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述训练模块83,具体用于:基于更新后的本地模型参数向量、所述客户端的本地私有数据以及邻近点法,对所述客户端的联合学习模型进行本轮迭代训练;并获得本轮输出的本地模型参数向量。
实施例7:
基于相同的技术构思,本申请还提供了一种电子设备,图9示出了一些实施例提供的一种电子设备结构示意图,如图9所示,电子设备包括:处理器91、通信接口92、存储器93和通信总线94,其中,处理器91,通信接口92,存储器93通过通信总线94完成相互间的通信;
所述存储器93中存储有计算机程序,当所述程序被所述处理器91执行时,使得所述处理器91执行如下步骤:
在对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
接收各客户端发送的上一轮输出的本地模型参数向量;
针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的所述本轮全局模型参数向量发送给该客户端,使该客户端采用所述本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述处理器91,还用于:基于所述相似度值以及设定权重算法,确定各客户端各自对应的权重;其中相似度值与权重呈正相关或呈负相关,本地模型参数向量之间越相似,相应权重越大;根据所述权重,对相应的相似度值进行更新,基于各客户端的本地模型参数向量与更新后的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量。
在一种可能的实施方式中,所述处理器91,具体用于:基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述处理器91,具体用于:基于所述相似度值以及注意力引入函数,确定各其他客户端分别对应的权重;基于设定权重总值与各其他客户端对应的权重和的差值,确定该客户端对应的权重。
在一种可能的实施方式中,所述处理器91,具体用于:基于余弦值相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述处理器91,具体用于:确定该客户端对应的权重为预设值;基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重。
上述电子设备提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口92用于上述电子设备与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non-Volatile Memory,NVM),例如至少一个磁盘存储器。可选地,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述处理器可以是通用处理器,包括中央处理器、网络处理器(NetworkProcessor,NP)等;还可以是数字指令处理器(Digital Signal Processing,DSP)、专用集成电路、现场可编程门陈列或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。
基于相同的技术构思,本申请还提供了一种电子设备,参阅图10所示,图10示出了一些实施例提供的另一种电子设备结构示意图,电子设备包括:处理器101、通信接口102、存储器103和通信总线104,其中,处理器101,通信接口102,存储器103通过通信总线104完成相互间的通信;
所述存储器103中存储有计算机程序,当所述程序被所述处理器101执行时,使得所述处理器101执行如下步骤:
在参与对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
将所述客户端自身上一轮输出的本地模型参数向量发送给服务器,使所述服务器确定所述客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;并使所述服务器基于各客户端的本地模型参数向量与确定的相似度值,确定所述客户端的联合学习模型的本轮全局模型参数向量;
接收所述服务器发送的所述客户端的本轮全局模型参数向量,并采用所述本轮全局模型参数向量对所述客户端当前保存的本地模型参数向量进行更新;
基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述处理器,具体用于:基于更新后的本地模型参数向量、所述客户端的本地私有数据以及邻近点法,对所述客户端的联合学习模型进行本轮迭代训练;并获得本轮输出的本地模型参数向量。
上述电子设备提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口102用于上述电子设备与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non-Volatile Memory,NVM),例如至少一个磁盘存储器。可选地,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述处理器可以是通用处理器,包括中央处理器、网络处理器(NetworkProcessor,NP)等;还可以是数字指令处理器(Digital Signal Processing,DSP)、专用集成电路、现场可编程门陈列或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。
实施例8:
基于相同的技术构思,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质内存储有可由电子设备执行的计算机程序,当所述程序在所述电子设备上运行时,使得所述电子设备执行时实现如下步骤:
在对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
接收各客户端发送的上一轮输出的本地模型参数向量;
针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的所述本轮全局模型参数向量发送给该客户端,使该客户端采用所述本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值之后,所述基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量之前,所述方法还包括:
基于所述相似度值以及设定权重算法,确定各客户端各自对应的权重;其中相似度值与权重呈正相关或呈负相关,本地模型参数向量之间越相似,相应权重越大;
根据所述权重,对相应的相似度值进行更新,基于更新后的相似度值进行后续步骤。
在一种可能的实施方式中,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,包括:
基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述基于所述相似度值以及设定权重算法,确定各客户端分别对应的权重,包括:
基于所述相似度值以及注意力引入函数,确定各其他客户端分别对应的权重;
基于设定权重总值与各其他客户端对应的权重和的差值,确定该客户端对应的权重。
在一种可能的实施方式中,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值包括:
基于余弦值相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
在一种可能的实施方式中,所述基于所述相似度值以及设定权重算法,确定各客户端分别对应的权重,包括:
确定该客户端对应的权重为预设值;
基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重。
基于相同的技术构思,本申请还提供了一种计算机可读存储介质,所述计算机可读存储介质内存储有可由电子设备执行的计算机程序,当所述程序在所述电子设备上运行时,使得所述电子设备执行时实现如下步骤:
在参与对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
将所述客户端自身上一轮输出的本地模型参数向量发送给服务器,使所述服务器确定所述客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;并使所述服务器基于各客户端的本地模型参数向量与确定的相似度值,确定所述客户端的联合学习模型的本轮全局模型参数向量;
接收所述服务器发送的所述客户端的本轮全局模型参数向量,并采用所述本轮全局模型参数向量对所述客户端当前保存的本地模型参数向量进行更新;
基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
在一种可能的实施方式中,所述基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练,包括:
基于更新后的本地模型参数向量、所述客户端的本地私有数据以及邻近点法,对所述客户端的联合学习模型进行本轮迭代训练;
所述方法还包括:
获得本轮输出的本地模型参数向量。
上述计算机可读存储介质可以是电子设备中的处理器能够存取的任何可用介质或数据存储设备,包括但不限于磁性存储器如软盘、硬盘、磁带、磁光盘(MO)等、光学存储器如CD、DVD、BD、HVD等、以及半导体存储器如ROM、EPROM、EEPROM、非易失性存储器(NANDFLASH)、固态硬盘(SSD)等。
基于相同的技术构思,本申请提供了一种计算机程序产品,所述计算机程序产品包括:计算机程序代码,当所述计算机程序代码在计算机上运行时,使得计算机执行时实现上述应用于电子设备的任一方法实施例所述的方法。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令,在计算机上加载和执行所述计算机指令时,全部或部分地产生按照本申请实施例所述的流程或功能。
本领域内的技术人员应明白,本申请的实施例可提供为方法、系统、或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本申请是参照根据本申请的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
显然,本领域的技术人员可以对本申请进行各种改动和变型而不脱离本申请的精神和范围。这样,倘若本申请的这些修改和变型属于本申请权利要求及其等同技术的范围之内,则本申请也意图包含这些改动和变型在内。

Claims (10)

1.一种联合学习模型训练方法,其特征在于,所述方法应用于服务器,所述方法包括:
在对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
接收各客户端发送的上一轮输出的本地模型参数向量;
针对每个客户端,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量;将该客户端的所述本轮全局模型参数向量发送给该客户端,使该客户端采用所述本轮全局模型参数向量对该客户端当前保存的本地模型参数向量进行更新,并使该客户端基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
2.根据权利要求1所述的方法,其特征在于,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值之后,所述基于各客户端的本地模型参数向量与确定的相似度值,确定该客户端的联合学习模型的本轮全局模型参数向量之前,所述方法还包括:
基于所述相似度值以及设定权重算法,确定各客户端各自对应的权重;其中相似度值与权重呈正相关或呈负相关,本地模型参数向量之间越相似,相应权重越大;
根据所述权重,对相应的相似度值进行更新,基于更新后的相似度值进行后续步骤。
3.根据权利要求1或2所述的方法,其特征在于,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值,包括:
基于欧式距离相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
4.根据权利要求3所述的方法,其特征在于,所述基于所述相似度值以及设定权重算法,确定各客户端分别对应的权重,包括:
基于所述相似度值以及注意力引入函数,确定各其他客户端分别对应的权重;
基于设定权重总值与各其他客户端对应的权重和的差值,确定该客户端对应的权重。
5.根据权利要求1或2所述的方法,其特征在于,所述确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值包括:
基于余弦值相似度算法,确定该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值。
6.根据权利要求5所述的方法,其特征在于,所述基于所述相似度值以及设定权重算法,确定各客户端分别对应的权重,包括:
确定该客户端对应的权重为预设值;
基于该客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值、该客户端对应的权重,确定各其他客户端分别对应的权重。
7.一种联合学习模型训练方法,其特征在于,所述方法应用于客户端,所述方法包括:
在参与对各客户端的联合学习模型的每轮迭代训练过程中,至少执行以下步骤:
将所述客户端自身上一轮输出的本地模型参数向量发送给服务器,使所述服务器确定所述客户端的本地模型参数向量分别与各其他客户端的本地模型参数向量之间的相似度值;并使所述服务器基于各客户端的本地模型参数向量与确定的相似度值,确定所述客户端的联合学习模型的本轮全局模型参数向量;
接收所述服务器发送的所述客户端的本轮全局模型参数向量,并采用所述本轮全局模型参数向量对所述客户端当前保存的本地模型参数向量进行更新;
基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练。
8.根据权利要求7所述的方法,其特征在于,所述基于更新后的本地模型参数向量对待训练的联合学习模型进行本轮迭代训练,包括:
基于更新后的本地模型参数向量、所述客户端的本地私有数据以及邻近点法,对所述客户端的联合学习模型进行本轮迭代训练;
所述方法还包括:
获得本轮输出的本地模型参数向量。
9.一种电子设备,其特征在于,所述电子设备至少包括处理器和存储器,所述处理器用于执行存储器中存储的计算机程序时实现如权利要求1-8任一所述方法的步骤。
10.一种计算机可读存储介质,其特征在于,其存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1-8任一所述方法的步骤。
CN202310612394.5A 2023-05-29 2023-05-29 一种联合学习模型训练方法、装置、设备及存储介质 Pending CN116680565A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310612394.5A CN116680565A (zh) 2023-05-29 2023-05-29 一种联合学习模型训练方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310612394.5A CN116680565A (zh) 2023-05-29 2023-05-29 一种联合学习模型训练方法、装置、设备及存储介质

Publications (1)

Publication Number Publication Date
CN116680565A true CN116680565A (zh) 2023-09-01

Family

ID=87778429

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310612394.5A Pending CN116680565A (zh) 2023-05-29 2023-05-29 一种联合学习模型训练方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN116680565A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117057442A (zh) * 2023-10-09 2023-11-14 之江实验室 一种基于联邦多任务学习的模型训练方法、装置及设备

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117057442A (zh) * 2023-10-09 2023-11-14 之江实验室 一种基于联邦多任务学习的模型训练方法、装置及设备

Similar Documents

Publication Publication Date Title
CN109313586B (zh) 使用基于云端的度量迭代训练人工智能的系统
CN110770764A (zh) 超参数的优化方法及装置
CN116680565A (zh) 一种联合学习模型训练方法、装置、设备及存储介质
CN102298569A (zh) 在线学习算法的并行化
JP7009020B2 (ja) 学習方法、学習システム、学習装置、方法、適用装置、及びコンピュータプログラム
CN114261400B (zh) 一种自动驾驶决策方法、装置、设备和存储介质
WO2020034593A1 (zh) 人群绩效特征预测中的缺失特征处理方法及装置
CN111174793A (zh) 路径规划方法及装置、存储介质
CN109616224B (zh) 一种航迹关联置信度评估方法、电子设备和存储介质
CN116756536B (zh) 数据识别方法、模型训练方法、装置、设备及存储介质
CN117151208B (zh) 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质
CN113220883A (zh) 一种文本分类模型性能优化方法、装置及存储介质
CN111710153A (zh) 交通流量的预测方法、装置、设备及计算机存储介质
US20220309398A1 (en) Decentralized control of beam generating devices
CN110533158B (zh) 模型建构方法、系统及非易失性电脑可读取记录介质
US20230419172A1 (en) Managing training of a machine learning model
CN114742644A (zh) 训练多场景风控系统、预测业务对象风险的方法和装置
CN112416560A (zh) 众包场景中对于数值任务的真值推断和在线任务分配方法
CN112488831A (zh) 区块链网络交易方法、装置、存储介质及电子设备
CN109581284A (zh) 基于交互多模型的非视距误差消除方法
CN114580578B (zh) 具有约束的分布式随机优化模型训练方法、装置及终端
CN117076131B (zh) 一种任务分配方法、装置、电子设备及存储介质
CN112506673B (zh) 面向智能边缘计算的协同模型训练任务配置方法
Du et al. CDA-MBPO: Corrected Data Aggregation for Model-Based Policy Optimization
CN115292037A (zh) 边缘网络下的任务可靠性保证方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination