CN117557870A - 基于联邦学习客户端选择的分类模型训练方法及系统 - Google Patents
基于联邦学习客户端选择的分类模型训练方法及系统 Download PDFInfo
- Publication number
- CN117557870A CN117557870A CN202410022912.2A CN202410022912A CN117557870A CN 117557870 A CN117557870 A CN 117557870A CN 202410022912 A CN202410022912 A CN 202410022912A CN 117557870 A CN117557870 A CN 117557870A
- Authority
- CN
- China
- Prior art keywords
- client
- training
- clients
- matrix
- virtual queue
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 214
- 238000000034 method Methods 0.000 title claims abstract description 109
- 238000013145 classification model Methods 0.000 title claims abstract description 22
- 239000011159 matrix material Substances 0.000 claims abstract description 152
- 230000008569 process Effects 0.000 claims abstract description 41
- 238000004590 computer program Methods 0.000 claims description 13
- 238000004220 aggregation Methods 0.000 claims description 9
- 230000002776 aggregation Effects 0.000 claims description 9
- 238000003860 storage Methods 0.000 claims description 8
- 238000004364 calculation method Methods 0.000 claims description 6
- 238000009826 distribution Methods 0.000 description 14
- 230000008901 benefit Effects 0.000 description 8
- 230000006870 function Effects 0.000 description 5
- 238000005457 optimization Methods 0.000 description 5
- 230000006872 improvement Effects 0.000 description 4
- 238000006243 chemical reaction Methods 0.000 description 3
- 238000011161 development Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 230000007774 longterm Effects 0.000 description 3
- 238000010187 selection method Methods 0.000 description 3
- 238000013459 approach Methods 0.000 description 2
- 230000007786 learning performance Effects 0.000 description 2
- 238000004519 manufacturing process Methods 0.000 description 2
- 230000003068 static effect Effects 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 108010014172 Factor V Proteins 0.000 description 1
- 230000004931 aggregating effect Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000012804 iterative process Methods 0.000 description 1
- 238000005304 joining Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 238000005303 weighing Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Computing Systems (AREA)
- Databases & Information Systems (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请涉及一种基于联邦学习客户端选择的分类模型训练方法、系统及介质,其中,基于联邦学习客户端选择的分类模型训练方法包括:初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵;在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定参与本轮训练的K个客户端;将所述全局模型发送至所述K个客户端进行并行训练,得到聚合后的全局模型;更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,并重复所述迭代训练过程直至达到设定的迭代次数,获得训练好的全局模型;使用训练好的全局模型对目标数据集进行分类,得到分类结果,提高了图像分类的精度。
Description
技术领域
本申请涉及联邦学习技术领域,特别是涉及一种基于联邦学习客户端选择的分类模型训练方法及系统。
背景技术
联邦学习系统通常包括一个中央服务器和多个客户端。客户端使用本地数据来训练本地模型,然后将模型参数上传至服务器,服务器通过聚合多个客户端的模型参数来形成全局模型。然而,由于通信带宽的限制,不是所有客户端都能参与每一轮的训练。在实际训练过程中,通常只选择其中的一小部分客户端参与训练。这些客户端的数据来源和处理方式通常是高度异质的,因此选择哪些客户端参与训练成为影响联邦学习性能的关键问题。
目前,最常见的客户端方法是随机选择策略,即在每个训练轮次中随机选择固定数量的客户端。另外一种方法考虑客户端在训练过程中的损失值,当客户端的训练损失越大时,表示当前模型无法较好地学习本地数据,则该方法选择训练损失最大的若干个客户端参与训练。但是,这两种方法都没有考虑客户端数据分布的多样性和公平性约束,导致对联邦学习的性能改进非常有限,进而导致基于联邦学习客户端选择训练获得的分类模型精度较低。
发明内容
基于此,有必要针对上述技术问题,提供一种基于联邦学习客户端选择的分类模型训练方法、系统、设备及介质。
第一方面,本申请实施例提供了一种基于联邦学习客户端选择的分类模型训练方法,所述方法包括:
初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵;
在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端;
将所述全局模型发送至所述K个客户端进行并行训练,获得各所述客户端的梯度,并基于各所述客户端的梯度得到聚合后的全局模型;
更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,并重复所述迭代训练过程直至达到设定的迭代次数,获得训练好的全局模型;
使用训练好的全局模型对目标数据集进行分类,得到分类结果。
在其中一个实施例中,所述在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定参与本轮训练的K个客户端还包括:
若本轮训练为第一轮训练,则选择客户端备选集合中所有的客户端参与本轮训练,并更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵。
在其中一个实施例中,所述在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端包括:
初始化客户端选择集合为空集;
基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定所述客户端备选集合中任一第一客户端所对应的第二客户端;
在每一次选择客户端的过程中,判断各所述第一客户端以及对应的第二客户端是否在当前的客户端选择集合中,得到当前的被选中的结果;
基于所述当前的被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次在所述客户端备选集合中确定一个被选择客户端,将所述被选择客户端从所述客户端备选集合移出,并将其添加至所述客户端选择集合中;直至所述客户端选择集合中包含K个客户端。
在其中一个实施例中,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定所述客户端备选集合中任一第一客户端所对应的第二客户端包括:
对于所述客户端备选集合中任一第一客户端,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,搜索与所述第一客户端的相似度小于第一约束参数,且与所述第一客户端在之前所有训练轮次中被选中频率差别最大的客户端为第二客户端。
在其中一个实施例中,所述基于所述被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次在所述客户端备选集合中确定一个被选择客户端的计算公式如下:
;
;
;
;
其中,i m 为被选择客户端,Z i (t)和Q i (t)为虚拟队列,V为权衡因子,为公平约束参数,x i,t 为第一客户端在第t轮训练中是否在当前的客户端选择集合中,/>为第二客户端在第t轮训练中是否在当前的客户端选择集合中,
表示客户端i与客户端j之间的相似度,客户端i在所述客户端备选集合中,所述客户端j在所述客户端选择集合中,S t 为所述客户端选择集合。
在其中一个实施例中,所述更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵包括:
根据K个客户端进行并行训练的获得的梯度,更新所述客户端相似度矩阵,所述客户端相似度矩阵中第i行、第j列的元素更新方式为:
;
其中,为第t轮训练中第i个客户端进行并行训练后获得的梯度值,/>为第t轮训练中第j个客户端进行并行训练后获得的梯度值,S t 为所述客户端选择集合;
所述虚拟队列Z i (t)和Q i (t)的更新方式为:
;
;
其中,为公平约束参数,x i,t 为第一客户端在第t轮训练中是否被选中的结果,为第二客户端在第t轮训练中是否被选中的结果;
基于所述参与本轮训练的K个客户端,将所述客户端选中频率矩阵中的对应元素进行更新。
在其中一个实施例中,所述基于各所述客户端的梯度得到聚合后的全局模型包括:
;
其中,w t+1为第t+1轮聚合后的全局模型,w t 为第t轮聚合后的全局模型;为学习速率,/>为第t轮中第i个客户端进行并行训练的获得的梯度值。
第二方面,本申请实施例还提供了一种基于联邦学习客户端选择的分类模型训练系统,所述系统包括:
初始化模块,用于初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵;
训练模块,用于在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端;
获得模块,用于将所述全局模型发送至各所述客户端进行并行训练,获得各所述客户端的梯度,并基于各所述客户端的梯度得到聚合后的全局模型;
更新模块,用于更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,重复所述迭代训练过程直至达到设定的迭代次数,获得训练好的全局模型;
分类模块,用于使用训练好的全局模型对目标数据集进行分类,得到分类结果。
第三方面,本申请实施例还提供了一种计算机设备,包括存储器和处理器,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行如上述第一方面所述的方法。
第四方面,本申请实施例还提供了一种计算机可读存储介质,所述存储介质中存储有计算机程序,其中,所述计算机程序被处理器执行时实现如上述第一方面所述的方法。
上述基于联邦学习客户端选择的分类模型训练方法、系统及介质,通过初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵;在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端;将所述全局模型发送至所述K个客户端进行并行训练,获得各所述客户端的梯度,并基于各所述客户端的梯度得到聚合后的全局模型;更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,并重复所述迭代训练过程直至达到设定的迭代次数,获得训练好的全局模型;使用训练好的全局模型对目标数据集进行分类,得到分类结果。解决了联邦学习过程中没有考虑客户端数据分布的多样性和公平性约束,导致基于联邦学习客户端选择的图像分类精度较低的问题,提高了分类的精度。
本申请的一个或多个实施例的细节在以下附图和描述中提出,以使本申请的其他特征、目的和优点更加简明易懂。
附图说明
此处所说明的附图用来提供对本申请的进一步理解,构成本申请的一部分,本申请的示意性实施例及其说明用于解释本申请,并不构成对本申请的不当限定。在附图中:
图1是一个实施例中基于联邦学习客户端选择的分类模型训练方法的应用环境示意图;
图2是一个实施例中基于联邦学习客户端选择的分类模型训练方法的流程示意图;
图3是一个实施例中执行S202具体步骤的流程示意图;
图4是一个实施例中基于FMNIST数据集使用本申请方法与其他方法的图像分类准确率对比图;
图5是一个实施例中基于CIFAR数据集使用本申请方法与其他方法的图像分类准确率对比图;
图6是一个实施例中基于联邦学习客户端选择的分类模型训练系统的结构框图;
图7是一个实施例中计算机设备结构示意图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行描述和说明。应当理解,此处所描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。基于本申请提供的实施例,本领域普通技术人员在没有作出创造性劳动的前提下所获得的所有其他实施例,都属于本申请保护的范围。
显而易见地,下面描述中的附图仅仅是本申请的一些示例或实施例,对于本领域的普通技术人员而言,在不付出创造性劳动的前提下,还可以根据这些附图将本申请应用于其他类似情景。此外,还可以理解的是,虽然这种开发过程中所作出的努力可能是复杂并且冗长的,然而对于与本申请公开的内容相关的本领域的普通技术人员而言,在本申请揭露的技术内容的基础上进行的一些设计,制造或者生产等变更只是常规的技术手段,不应当理解为本申请公开的内容不充分。
在本申请中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域普通技术人员显式地和隐式地理解的是,本申请所描述的实施例在不冲突的情况下,可以与其它实施例相结合。
除非另作定义,本申请所涉及的技术术语或者科学术语应当为本申请所属技术领域内具有一般技能的人士所理解的通常意义。本申请所涉及的“一”、“一个”、“一种”、“该”等类似词语并不表示数量限制,可表示单数或复数。本申请所涉及的术语“包括”、“包含”、“具有”以及它们任何变形,意图在于覆盖不排他的包含;例如包含了一系列步骤或模块(单元)的过程、方法、系统、产品或设备没有限定于已列出的步骤或单元,而是可以还包括没有列出的步骤或单元,或可以还包括对于这些过程、方法、产品或设备固有的其它步骤或单元。本申请所涉及的“连接”、“相连”、“耦接”等类似的词语并非限定于物理的或者机械的连接,而是可以包括电气的连接,不管是直接的还是间接的。本申请所涉及的“多个”是指两个或两个以上。“和/或”描述关联对象的关联关系,表示可以存在三种关系,例如,“A和/或B”可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。本申请所涉及的术语“第一”、“第二”、“第三”等仅仅是区别类似的对象,不代表针对对象的特定排序。
随着互联网和信息技术的发展,各行各业积累了大量的数据。然而,这些数据通常以碎片化、离散化的形式分布在不同行业或移动设备中。受限于隐私法规约束和价值分配等难题,各个机构的数据无法直接交换。联邦学习是一种分布式的机器学习框架,能够在保证用户数据不出本地的前提下,通过传输加密后的模型参数或梯度,进而实现多方联合建模。联邦学习目前已经成为解决数据协作与隐私保护矛盾的新兴方法。目前,选择哪些客户端参与训练成为影响联邦学习性能的关键问题。
例如,考虑一个联邦学习系统,其中有4个客户端,客户端1和2的数据分布非常相似,客户端3和4的数据分布也非常相似。在这种情况下,当所有客户端都参与训练时,模型的性能最优。但如果每轮只能选择两个客户端参与训练,合适的选择策略应该是从客户端备选集合{1,2}和客户端备选集合{3,4}中分别选择一个客户端,例如选择客户端{1,3}参与训练。相反,如果同时选择客户端{1,2}参与训练,模型的泛化性能将会受损,因为模型没有学习到客户端3或4的信息。因此,本申请强调应该选择具有不同数据分布的客户端集合,以提高模型的泛化性能。
此外,虽然选择客户端{1,3}能够满足数据多样性的要求,但如果每轮训练都固定选择客户端{1,3},则会导致严重的不公平问题。具体来说,客户端1和2的分布是相似的,从公平性的角度考虑,它们在长期训练过程中被选中的概率也应该是相似的。如果某个客户端长期不被选中,则可能导致它退出联邦学习系统。在考虑公平性的情况下,合适的客户端选择策略应该是交替选择客户端{1,3},{1,4},{2,3},{2,4}。这样,客户端1和2都有类似的机会被选中,从而减轻不公平问题。另外,如果固定选择客户端{1,3},则模型在它们的数据集上过度训练可能导致过拟合问题。
由于现有技术中没有考虑客户端数据分布的多样性和公平性约束,导致对联邦学习的性能改进非常有限,进而导致基于联邦学习客户端选择的图像分类精度较低。针对上述问题,本申请实施例提供一种基于联邦学习客户端选择的分类模型训练方法。
本申请实施例提供的基于联邦学习客户端选择的分类模型训练方法,可以应用于如图1所示的应用环境中。其中,服务器102通过网络与N个客户端104进行通信。服务器102首先初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵,在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在具有N个客户端104的客户端备选集合中确定参与本轮训练的K个客户端104,并将所述全局模型发送至所述K个客户端104。K个客户端104利用各自的本地数据并行训练模型得到梯度值,并将梯度值上传至服务器102。服务器102根据客户端104上传的梯度更新相关参数,并对客户端104上传的梯度聚合,并得到聚合后的全局模型。服务器102和N个客户端104重复迭代过程,直到达到设定的训练次数。服务器102使用训练好的全局模型对目标数据集进行分类,得到分类结果。
本申请实施例提供了一种基于联邦学习客户端选择的分类模型训练方法,如图2所示,以该方法应用于图1中的应用环境为例进行说明,所述方法包括以下步骤:
S201,初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵。
具体的,服务器初始化全局模型w 0 ;并初始化虚拟队列Z i (0)=0和Q i (0)=0,其中i=1,…,N;初始化客户端相似度矩阵,N代表全体客户端的数目,矩阵中第i行、第j列的元素代表第i个客户端与第j个客户端之间的相似度,将客户端相似度矩阵中所有元素初始化为0;初始化客户端选中频率矩阵/>,T代表总的训练轮数,矩阵中第i行、第j列的元素代表第i个客户端在第t轮训练中是否被选中,将矩阵中所有元素初始化为0,若在后续训练过程中第i个客户端被选中则将矩阵中对应元素置为1。
S202,在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端。
具体的,本申请在每一次迭代训练过程中,根据所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在具有N个客户端备选集合中确定参与本轮训练的K个客户端,并有K<N。
S203,将所述全局模型发送至所述K个客户端进行并行训练,获得各所述客户端的梯度,并基于各所述客户端的梯度得到聚合后的全局模型。
S204,更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,并重复所述迭代训练过程直至达到设定的迭代次数,获得训练好的全局模型。
具体的,当本轮训练完成后,更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,并重复S202至S204,直到达到设定的迭代训练次数T,以获得训练好的全局模型。
S205,使用训练好的全局模型对目标数据集进行分类,得到分类结果。
本申请提供了一种新的联邦学习客户端选择的分类模型训练方法,此方法能够在保证公平性约束的同时最大化数据分布的多样性。相较于现有方法,本申请的优势体现在两个方面:第一个是性能方面,本申请能够加快联邦全局模型收敛的速度,并提高模型的预测准确率;第二个是公平性方面,本申请能够保证每个客户端都有机会参与训练,并且数据分布相似的客户端被选中的概率也相似,从而提升了客户端参与联邦学习系统的积极性,有利于联邦学习系统的可持续发展。通过本申请的联邦学习客户端选择方法,在每轮的迭代训练中得到满足公平性约束同时最大化数据多样性的客户端备选集合,遵循公平性约束,以确保每个客户端都有机会参与训练,可以缓解模型训练过拟合问题,进一步提升模型性能,从而提高了分类的精度。
在其中一个实施例中,所述在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定参与本轮训练的K个客户端还包括:若本轮训练为第一轮训练,则选择客户端备选集合中所有的客户端参与本轮训练,并更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵。
具体的,如果是第一轮训练,即当t=1时,则选择所有客户端参与训练,即,St为客户端选择集合,表示第t轮训练中被选中的客户端备选集合,当本轮训练完成后,更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵。如果/>时,则基于上一轮训练后更新的所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端,即。
在其中一个实施例中,如图3所示,所述在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端包括以下步骤:
S301,初始化客户端选择集合为空集。
具体的,初始化客户端选择集合为空集,即,以及构造一个客户端备选集合,客户端备选集合P中包含了全体的客户端。
S302,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定所述客户端备选集合中任一第一客户端所对应的第二客户端。
具体的,首先根据所述客户端相似度矩阵和所述客户端选中频率矩阵,对于客户端备选集合P中任意一个第一客户端(客户端i),确定与所述第一客户端对应的第二客户端(客户端)。
S303,在每一次选择客户端的过程中,判断各所述第一客户端以及对应的第二客户端是否在当前的客户端选择集合中,得到当前的被选中的结果。
S304,基于所述当前的被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次在所述客户端备选集合中确定一个被选择客户端,将所述被选择客户端从所述客户端备选集合移出,并将其添加至所述客户端选择集合中;直至所述客户端选择集合中包含K个客户端。
具体的,基于所述当前的被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次根据当前的客户端选择集合,在当前客户端备选集合P中确定一个被选择客户端,将所述被选择客户端从所述客户端备选集合P移出,并将其添加至所述客户端选择集合St中。重复S302至S304的方法,逐个选择第K个客户端,直到客户端选择集合St中包含K个客户端。具体的,在第一轮训练时,所述客户端选择集合为空集,所述第一客户端所对应的第二客户端都不在客户端选择集合中,则被选中的结果为未被选中,对应的被选中的结果均为0。随着选择过程的进行,某些客户端可能被选中并加入所述客户端选择集合,当这些被选中的客户端是第二客户端时,则对应的被选中的结果设置为1。
在其中一个实施例中,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定所述客户端备选集合中任一第一客户端所对应的第二客户端包括:
对于所述客户端备选集合中任一第一客户端,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,搜索与所述第一客户端的相似度小于第一约束参数,且与所述第一客户端在之前所有训练轮次中被选中频率差别最大的客户端为第二客户端。
其中,第一约束参数用表示,所述第一客户端与所述第二客户端的相似度可通过客户端相似度矩阵获得,所述第一客户端与所述第二客户端在本轮训练中被选中频率可通过客户端选中频率矩阵获得,其中,所述第一客户端在本轮训练中被选中频率为前t-1轮训练中被选中结果的平均值,所述第二客户端在本轮训练中被选中频率为前t-1轮训练中被选中结果的平均值。
在其中一个实施例中,所述基于所述被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次在所述客户端备选集合中确定一个被选择客户端的计算公式如下:
;
;
;
;
其中,i m 为被选择客户端,Z i (t)和Q i (t)为虚拟队列,V为权衡因子,为公平约束参数,x i,t 为第一客户端在第t轮训练中是否在当前的客户端选择集合中,/>为第二客户端在第t轮训练中是否在当前的客户端选择集合中,
表示客户端i与客户端j之间的相似度,客户端i在所述客户端备选集合中,所述客户端j在所述客户端选择集合中,St为所述客户端选择集合。
具体的,在第t轮训练过程中,计算从所述客户端备选集合中选择哪个客户端,本申请执行以下步骤:
步骤1:初始化客户端选择集合St为空集。
步骤2:将所述客户端备选集合中所有第一客户端对应的被选中结果x i,t 都设置为1,计算对应的,并将L值最大时对应的客户端作为被选中的客户端。
步骤3:将被选中的客户端从所述客户端备选集合中移出,放入所述客户端选择集合。
步骤4:确定被选中的客户端是哪个客户端所对应的第二客户端,并将被选中第二客户端的被选中结果设置为1;
步骤5:重复上述过程,直至从所述客户端备选集合中选出K个客户端。
在其中一个实施例中,所述更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵包括以下内容:
根据K个客户端进行并行训练的获得的梯度,更新所述客户端相似度矩阵,所述客户端相似度矩阵中第i行、第j列的元素更新方式为:
;
其中,为第t轮训练中第i个客户端进行并行训练后获得的梯度值,/>为第t轮训练中第j个客户端进行并行训练后获得的梯度值,S t 为所述客户端选择集合;
所述虚拟队列Z i (t)和Q i (t)的更新方式为:
;
;
其中,为公平约束参数,x i,t 为第一客户端在第t轮训练中是否被选中的结果,为第二客户端在第t轮训练中是否被选中的结果;
基于所述参与本轮训练的K个客户端,将所述客户端选中频率矩阵中的对应元素进行更新。
在其中一个实施例中,所述基于各所述客户端的梯度得到聚合后的全局模型使用以下公式:
;
其中,w t+1为第t+1轮聚合后的全局模型,w t 为第t轮聚合后的全局模型;为学习速率,/>为第t轮中第i个客户端进行并行训练的获得的梯度值。
下面通过优选实施例进行描述和说明。本实施例应用于一个典型的联邦学习系统,包含一个中央服务器和N个客户端。
本申请需要预先定义变量,包括:(1)权衡因子V,当V越大时代表方法更侧重多样性,当V越小时代表方法更侧重公平性;(2)学习速率,该值对应于训练阶段模型参数更新的步长;(3)全体客户端数目N,以及每轮训练中被选中的客户端数目K,并有K<N;(4)联邦学习的总训练轮数T;(5)公平性约束中的参数/>和/>。本申请最终输出是一个训练好的全局模型w T ,其中w代表模型参数,T代表总训练轮数。该方法的具体实现方式如下:
步骤1:服务器初始化模型w 0 ,并初始化虚拟队列Z i (0)=0,Q i (0)=0,其中i=1,…,N;初始化客户端相似度矩阵,矩阵中第i行、第j列的元素代表第i个客户端与第j个客户端之间的相似度,将客户端相似度矩阵中所有元素初始化为0;初始化客户端选中频率矩阵/>,T代表总的训练轮数,矩阵中第i行、第j列的元素代表第i个客户端在第t轮训练中是否被选中,将矩阵中所有元素初始化为0,若在后续训练过程中第i个客户端被选中则将矩阵中对应元素置为1。
步骤2:服务器选择参与本轮训练的客户端备选集合。具体分为以下步骤:如果是第一轮训练,即t=1时,则选择所有客户端参与训练,即客户端选择集合。如果时,则按照下述策略确定客户端选择集合St,以满足数据多样性和公平性的要求。
步骤2.1:初始化客户端选择集合为空集,即,以及构造一个临时集合为客户端备选集合/>,集合P中包含了全体的客户端。
步骤2.2:对于任意一个客户端i,基于客户端相似度矩阵D和客户端选中频率矩阵C,搜索与它的相似度小于但是被选中的频率差别最大的客户端/>。观察客户端i和客户端/>在第t轮是否被选中,并得到对应的被选中结果x i,t 和/>,进而根据下式计算m i,t 和n i,t ,其中:
;
;
步骤2.3:对于客户端备选集合P中所有客户端,根据下式求解的值,
步骤2.4:搜索使取值最大的客户端i m ,即:
步骤2.5:将客户端i m 加入集合St,并将i m 从集合P中移出,即:
步骤2.6:重复步骤2.2至步骤2.5,直到集合St中包含K个客户端。
步骤3:服务器将模型w t 发送给客户端选择集合St。
步骤4:客户端选择集合St中的各个客户端利用各自的本地数据并行训练模型,得到梯度值,并将梯度值上传至服务器。
步骤5:服务器根据客户端上传的梯度更新相关参数。
步骤5.1:首先更新客户端相似度矩阵D,其中第i行、第j列的元素更新方式为:
步骤5.2:按照下式更新虚拟队列Z i (t),Q i (t),
步骤5.3:根据参与本轮训练的K个客户端,将所述客户端选中频率矩阵中的对应元素置为1。
步骤6:服务器对客户端上传的梯度聚合,并得到聚合后的全局模型w t+1 ,
重复步骤2至步骤6,直到达到设定的训练次数T,训练好的全局模型w T 。
步骤7:使用训练好的全局模型w T 对目标数据集进行分类,得到分类结果。
为了便于理解,对本申请的具体实施方式中出现的公式做以下解释说明:
本申请考虑一个典型的联邦学习系统,即包含一个中央服务器和N个客户端,并使用表示全体客户端备选集合。假设每轮训练中,服务器选择K个客户端,并将第t轮被选中的客户端备选集合表示为St。
(1)数据多样性的度量方式
本申请认为全体客户端备选集合具有最大的数据多样性,通过量化数据分布的多样性,从而指导应该选择哪些客户端备选集合。并通过下式来量化客户端选择集合St的数据分布多样性:
在上式中,表示第i个客户端在第t轮训练中的模型梯度,/>表示全体客户端的模型梯度之和;/>表示被选中的客户端备选集合St中模型梯度的加权和,其中/>表示第j个客户端在第t轮训练时的权重;/>表示范数运算,用来刻画两者的差异,本申请使用的是二范数。当两者的差异越小时,代表客户端备选集合St的数据分布越接近全体的数据分布,则认为St中数据的多样性越高。
(2)个体公平约束
本申请通过x i,t 表示第i个客户端在第t轮训练是否被选中,如果被选中则x i,t =1,否则x i,t =0,本申请通过以下公式计算第i个客户端在整个训练过程中被选中的概率:
其中,p i 表示第i个客户端在整个训练过程中被选中的概率,它也等于第i个客户端在整个训练过程中被选中的次数,其中T是训练轮次的总数目,表示求平均值。为了使每个客户端都有机会参与训练,本申请引入了/>个体公平约束,即当两个客户端的相似度小于/>时,则在整个训练过程中它们被选择的概率差值应该小于/>。/>个体公平约束可以通过下式来描述:
其中d sim (i,j)是衡量两个客户端梯度相似性的指标。在本申请中相似性指标的计算公式为:
当两个客户端的梯度越相似时,该指标的值越大。
(3)优化目标
基于上述设定,客户端选择方法可以表示成下述优化问题:
其中是服务器在第1轮至第T轮选中的客户端备选集合,是客户端备选集合St与全体客户端数据分布的差异,差异越小代表客户端备选集合St的数据多样性越高。/>用来描述在整个T轮训练过程中两者差异的平均值,本申请通过最小化两者的差异以实现最大的数据多样性。此外,约束条件/>代表的是个体公平约束,约束条件/>代表客户端备选集合应该包含K个客户端。
(4)问题转化
在上述优化目标中,目标函数和被选中的概率pi都是整个T轮训练过程的平均值,这是一个长期约束。但是,在每轮训练开始时,服务器都需要实时在线地选择参与训练的客户端。因此,本申请引入李雅普诺夫(Lyapunov)函数来求解上述优化问题,从而将长期约束分解为每个训练轮次的在线选择问题。李雅普诺夫(Lyapunov)函数是用来证明动力系统或自治微分方程稳定性的函数,在动力系统稳定性理论及控制理论中具有重要的应用。
引入李雅普诺夫(Lyapunov)函数之后,上述优化目标可以转化为下述形式:
在上式中,
其中,x i,t 表示第i个客户端在第t轮训练是否被选中,如果被选中则x i,t =1,否则x i,t =0。通过客户端表示与当前客户端i的相似度小于/>但是被选中的概率差别最大的客户端,即:
表示第/>个客户端在第t轮训练是否被选中,/>是/>个体公平约束中预先定义的参数,Z i (t)与Q i (t)是为第i个客户端构造的两个虚拟队列,它们的初始条件为Z i (t)=0,Q i (t)=0,并且更新方式为:/>
其中,V是一个预先定义的权衡因子,通常取值在[0,1]之间,实验发现V=0.8的效果最好。
原问题需要在所有训练轮次T上求平均值才能得到一系列最优的解集合,而经过转换后只需要在当前训练轮次t求解最优的客户端备选集合St。这样处理有两个优势:第一个优势是原问题求解非常困难甚至无法求解,经过转换后简化了问题的求解方式;第二个优势是经过转换后问题可以快速求解,满足了联邦学习系统对实时性的要求。
为了验证本实施例所提供方法的有效性,本申请在FMNIST和CIFAR数据集上进行了验证。这两个数据集都是图像数据集,用于图像分类任务。本申请考虑联邦学习系统中共有100个客户端,并且服务器将每轮都选择10个客户端参与训练。本申请使用的预定义参数为。同时与现有的三种方法进行了对比,其中随机选择方法是指服务器随机地选择客户端,AFL方法和PowerD方法是选择训练损失较大的客户端,这些方法都没有考虑数据多样性和公平性。在FMNIST数据集中的验证结果如图4,在CIFAR数据集中的验证结果图5所示,横轴是训练轮次,纵轴是测试准确率,准确率越高代表模型的性能越好,图4和图5的结果表明本申请的方法在不同的数据集中取得了10%至20%的准确率提升,有效提高了图像分类的精度。需要说明的是,除了图像领域,本申请的方法也适用于其它领域的分类任务。
本申请实施例还提供了一种基于联邦学习客户端选择的分类模型训练系统,如图6所示,所述系统包括初始化模块10、训练模块20、获得模块30、更新模块40、分类模块50。
初始化模块10用于初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵;
训练模块20用于在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端;
获得模块30用于将所述全局模型发送至各所述客户端进行并行训练,获得各所述客户端的梯度,并基于各所述客户端的梯度得到聚合后的全局模型;
更新模块40用于更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,重复所述迭代训练过程直至达到设定的迭代次数,获得训练好的全局模型;
分类模块50用于使用训练好的全局模型对目标数据集进行分类,得到分类结果。
在其中一个实施例中,训练模块20还用于:若本轮训练为第一轮训练,则选择客户端备选集合中所有的客户端参与本轮训练,并更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵。
在其中一个实施例中,训练模块20还用于:初始化客户端选择集合为空集;
基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定所述客户端备选集合中任一第一客户端所对应的第二客户端;确定上一轮训练中各所述第一客户端以及对应的第二客户端是否被选中的结果;基于所述被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次在所述客户端备选集合中确定一个被选择客户端,将所述被选择客户端从所述客户端备选集合移出,并将其添加至所述客户端选择集合中;直至所述客户端选择集合中包含K个客户端。
在其中一个实施例中,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定所述客户端备选集合中任一第一客户端所对应的第二客户端包括:
对于所述客户端备选集合中任一第一客户端,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,搜索与所述第一客户端的相似度小于第一约束参数,且与所述第一客户端在之前所有训练轮次中被选中频率差别最大的客户端为第二客户端。
在其中一个实施例中,所述基于所述被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次在所述客户端备选集合中确定一个被选择客户端的计算公式如下:
;
;
;
;
其中,i m 为被选择客户端,Z i (t)和Q i (t)为虚拟队列,V为权衡因子,为公平约束参数,x i,t 为第一客户端在第t轮训练中是否在当前的客户端选择集合中,/>为第二客户端在第t轮训练中是否在当前的客户端选择集合中,
表示客户端i与客户端j之间的相似度,客户端i在所述客户端备选集合中,所述客户端j在所述客户端选择集合中,S t 为所述客户端选择集合。
在其中一个实施例中,更新模块40还用于根据K个客户端进行并行训练的获得的梯度,更新所述客户端相似度矩阵,所述客户端相似度矩阵中第i行、第j列的元素更新方式为:
;
其中,为第t轮训练中第i个客户端进行并行训练后获得的梯度值,/>为第t轮训练中第j个客户端进行并行训练后获得的梯度值,S t 为所述客户端选择集合;
所述虚拟队列Z i (t)和Q i (t)的更新方式为:
;
;/>
其中,为公平约束参数,x i,t 为第一客户端在第t轮训练中是否被选中的结果,为第二客户端在第t轮训练中是否被选中的结果;
基于所述参与本轮训练的K个客户端,将所述客户端选中频率矩阵中的对应元素进行更新。
在其中一个实施例中,所述基于各所述客户端的梯度得到聚合后的全局模型包括:
;
其中,w t+1为第t+1轮聚合后的全局模型,w t 为第t轮聚合后的全局模型;为学习速率,/>为第t轮中第i个客户端进行并行训练的获得的梯度值。
需要说明的是,上述各个模块可以是功能模块也可以是程序模块,既可以通过软件来实现,也可以通过硬件来实现。对于通过硬件来实现的模块而言,上述各个模块可以位于同一处理器中;或者上述各个模块还可以按照任意组合的形式分别位于不同的处理器中。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图7所示。该计算机设备包括通过系统总线连接的处理器、存储器和网络接口。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质和内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储图像数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种基于联邦学习客户端选择的分类模型训练方法。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现上述任一项基于联邦学习客户端选择的分类模型训练方法实施例中的步骤。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和易失性存储器中的至少一种。非易失性存储器可包括只读存储器(Read-Only Memory,ROM)、磁带、软盘、闪存或光存储器等。易失性存储器可包括随机存取存储器(RandomAccess Memory,RAM)或外部高速缓冲存储器。作为说明而非局限,RAM可以是多种形式,比如静态随机存取存储器(Static Random Access Memory,SRAM)或动态随机存取存储器(Dynamic Random Access Memory,DRAM)等。
以上所述实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请专利的保护范围应以所附权利要求为准。
Claims (10)
1.一种基于联邦学习客户端选择的分类模型训练方法,其特征在于,所述方法包括:
初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵;
在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端;
将所述全局模型发送至所述K个客户端进行并行训练,获得各所述客户端的梯度,并基于各所述客户端的梯度得到聚合后的全局模型;
更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,并重复所述迭代训练过程直至达到设定的迭代次数,获得训练好的全局模型;
使用训练好的全局模型对目标数据集进行分类,得到分类结果。
2.根据权利要求1所述的方法,其特征在于,所述在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定参与本轮训练的K个客户端还包括:
若本轮训练为第一轮训练,则选择客户端备选集合中所有的客户端参与本轮训练,并更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵。
3.根据权利要求2所述的方法,其特征在于,所述在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端包括:
初始化客户端选择集合为空集;
基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定所述客户端备选集合中任一第一客户端所对应的第二客户端;
在每一次选择客户端的过程中,判断各所述第一客户端以及对应的第二客户端是否在当前的客户端选择集合中,得到当前的被选中的结果;
基于所述当前的被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次在所述客户端备选集合中确定一个被选择客户端,将所述被选择客户端从所述客户端备选集合移出,并将其添加至所述客户端选择集合中;直至所述客户端选择集合中包含K个客户端。
4.根据权利要求3所述的方法,其特征在于,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,确定所述客户端备选集合中任一第一客户端所对应的第二客户端包括:
对于所述客户端备选集合中任一第一客户端,基于所述客户端相似度矩阵以及所述客户端选中频率矩阵,搜索与所述第一客户端的相似度小于第一约束参数,且与所述第一客户端在之前所有训练轮次中被选中频率差别最大的客户端为第二客户端。
5.根据权利要求3所述的方法,其特征在于,所述基于所述当前的被选中的结果、所述虚拟队列以及所述客户端相似度矩阵,每次在所述客户端备选集合中确定一个被选择客户端的计算公式如下:
;
;
;
;
其中,i m 为被选择客户端,Z i (t)和Q i (t)为虚拟队列,V为权衡因子,为公平约束参数,x i,t 为第一客户端在第t轮训练中是否在当前的客户端选择集合中,/>为第二客户端在第t轮训练中是否在当前的客户端选择集合中,
表示客户端i与客户端j之间的相似度,客户端i在所述客户端备选集合中,所述客户端j在所述客户端选择集合中,S t 为所述客户端选择集合。
6.根据权利要求5所述的方法,其特征在于,所述更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵包括:
根据K个客户端进行并行训练的获得的梯度,更新所述客户端相似度矩阵,所述客户端相似度矩阵中第i行、第j列的元素更新方式为:
;
其中,为第t轮训练中第i个客户端进行并行训练后获得的梯度值,/>为第t轮训练中第j个客户端进行并行训练后获得的梯度值,S t 为所述客户端选择集合;
所述虚拟队列Z i (t)和Q i (t)的更新方式为:
;
;
其中,为公平约束参数,x i,t 为第一客户端在第t轮训练中是否被选中的结果,/>为第二客户端在第t轮训练中是否被选中的结果;
基于所述参与本轮训练的K个客户端,将所述客户端选中频率矩阵中的对应元素进行更新。
7.根据权利要求1所述的方法,其特征在于,所述基于各所述客户端的梯度得到聚合后的全局模型包括:
;
其中,w t+1为第t+1轮聚合后的全局模型,w t 为第t轮聚合后的全局模型;为学习速率,为第t轮中第i个客户端进行并行训练的获得的梯度值。
8.一种基于联邦学习客户端选择的分类模型训练系统,其特征在于,所述系统包括:
初始化模块,用于初始化全局模型,并初始化虚拟队列、客户端相似度矩阵以及客户端选中频率矩阵;
训练模块,用于在每一次迭代训练过程中,基于所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,在客户端备选集合中确定参与本轮训练的K个客户端;
获得模块,用于将所述全局模型发送至各所述客户端进行并行训练,获得各所述客户端的梯度,并基于各所述客户端的梯度得到聚合后的全局模型;
更新模块,用于更新所述虚拟队列、所述客户端相似度矩阵以及所述客户端选中频率矩阵,重复所述迭代训练过程直至达到设定的迭代次数,获得训练好的全局模型;
分类模块,用于使用训练好的全局模型对目标数据集进行分类,得到分类结果。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至权利要求7中任一项所述的方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至权利要求7中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410022912.2A CN117557870B (zh) | 2024-01-08 | 2024-01-08 | 基于联邦学习客户端选择的分类模型训练方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410022912.2A CN117557870B (zh) | 2024-01-08 | 2024-01-08 | 基于联邦学习客户端选择的分类模型训练方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117557870A true CN117557870A (zh) | 2024-02-13 |
CN117557870B CN117557870B (zh) | 2024-04-23 |
Family
ID=89818802
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410022912.2A Active CN117557870B (zh) | 2024-01-08 | 2024-01-08 | 基于联邦学习客户端选择的分类模型训练方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117557870B (zh) |
Citations (14)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112052480A (zh) * | 2020-09-11 | 2020-12-08 | 哈尔滨工业大学(深圳) | 一种模型训练过程中的隐私保护方法、系统及相关设备 |
CN113052334A (zh) * | 2021-04-14 | 2021-06-29 | 中南大学 | 一种联邦学习实现方法、系统、终端设备及可读存储介质 |
CN113191484A (zh) * | 2021-04-25 | 2021-07-30 | 清华大学 | 基于深度强化学习的联邦学习客户端智能选取方法及系统 |
DE102021108101A1 (de) * | 2020-06-02 | 2021-12-02 | Samsung Electronics Co., Ltd. | System und Verfahren für föderales Lernen unter Verwendung von anonymisierter Gewichtungsfaktorisierung |
CN114417417A (zh) * | 2022-01-24 | 2022-04-29 | 山东大学 | 一种基于联邦学习的工业物联网隐私保护系统及方法 |
CN114492829A (zh) * | 2021-12-10 | 2022-05-13 | 中国科学院自动化研究所 | 基于联邦学习场景的训练参与方的选择方法及装置 |
CN114595396A (zh) * | 2022-05-07 | 2022-06-07 | 浙江大学 | 一种基于联邦学习的序列推荐方法和系统 |
CN115600691A (zh) * | 2022-09-22 | 2023-01-13 | 深圳大学(Cn) | 联邦学习中的客户端选择方法、系统、装置和存储介质 |
CN115796271A (zh) * | 2022-11-11 | 2023-03-14 | 中国科学技术大学苏州高等研究院 | 基于客户端选择和梯度压缩的联邦学习方法 |
WO2023036184A1 (en) * | 2021-09-08 | 2023-03-16 | Huawei Cloud Computing Technologies Co., Ltd. | Methods and systems for quantifying client contribution in federated learning |
KR20230063629A (ko) * | 2021-11-02 | 2023-05-09 | 광주과학기술원 | 합의 기반의 연합 학습 방법 |
CN116167452A (zh) * | 2022-12-13 | 2023-05-26 | 重庆邮电大学 | 一种基于模型相似性的集群联邦学习的方法 |
US20230177349A1 (en) * | 2020-06-01 | 2023-06-08 | Intel Corporation | Federated learning optimizations |
CN117217328A (zh) * | 2023-09-04 | 2023-12-12 | 西安电子科技大学 | 基于约束因子的联邦学习客户端选择方法 |
-
2024
- 2024-01-08 CN CN202410022912.2A patent/CN117557870B/zh active Active
Patent Citations (14)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20230177349A1 (en) * | 2020-06-01 | 2023-06-08 | Intel Corporation | Federated learning optimizations |
DE102021108101A1 (de) * | 2020-06-02 | 2021-12-02 | Samsung Electronics Co., Ltd. | System und Verfahren für föderales Lernen unter Verwendung von anonymisierter Gewichtungsfaktorisierung |
CN112052480A (zh) * | 2020-09-11 | 2020-12-08 | 哈尔滨工业大学(深圳) | 一种模型训练过程中的隐私保护方法、系统及相关设备 |
CN113052334A (zh) * | 2021-04-14 | 2021-06-29 | 中南大学 | 一种联邦学习实现方法、系统、终端设备及可读存储介质 |
CN113191484A (zh) * | 2021-04-25 | 2021-07-30 | 清华大学 | 基于深度强化学习的联邦学习客户端智能选取方法及系统 |
WO2023036184A1 (en) * | 2021-09-08 | 2023-03-16 | Huawei Cloud Computing Technologies Co., Ltd. | Methods and systems for quantifying client contribution in federated learning |
KR20230063629A (ko) * | 2021-11-02 | 2023-05-09 | 광주과학기술원 | 합의 기반의 연합 학습 방법 |
CN114492829A (zh) * | 2021-12-10 | 2022-05-13 | 中国科学院自动化研究所 | 基于联邦学习场景的训练参与方的选择方法及装置 |
CN114417417A (zh) * | 2022-01-24 | 2022-04-29 | 山东大学 | 一种基于联邦学习的工业物联网隐私保护系统及方法 |
CN114595396A (zh) * | 2022-05-07 | 2022-06-07 | 浙江大学 | 一种基于联邦学习的序列推荐方法和系统 |
CN115600691A (zh) * | 2022-09-22 | 2023-01-13 | 深圳大学(Cn) | 联邦学习中的客户端选择方法、系统、装置和存储介质 |
CN115796271A (zh) * | 2022-11-11 | 2023-03-14 | 中国科学技术大学苏州高等研究院 | 基于客户端选择和梯度压缩的联邦学习方法 |
CN116167452A (zh) * | 2022-12-13 | 2023-05-26 | 重庆邮电大学 | 一种基于模型相似性的集群联邦学习的方法 |
CN117217328A (zh) * | 2023-09-04 | 2023-12-12 | 西安电子科技大学 | 基于约束因子的联邦学习客户端选择方法 |
Non-Patent Citations (2)
Title |
---|
MA, ZEZHONG: ""Fast-convergent federated learning with"", 《JOURNAL OF SYSTEMS ARCHITECTURE》, 3 July 2021 (2021-07-03) * |
王亚?;: "面向数据共享交换的联邦学习技术发展综述", 无人系统技术, no. 06, 15 November 2019 (2019-11-15) * |
Also Published As
Publication number | Publication date |
---|---|
CN117557870B (zh) | 2024-04-23 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Alain et al. | Variance reduction in sgd by distributed importance sampling | |
WO2020094060A1 (zh) | 推荐方法及装置 | |
US20190095464A1 (en) | Dual deep learning architecture for machine-learning systems | |
CN111259738B (zh) | 人脸识别模型构建方法、人脸识别方法及相关装置 | |
CN113408209A (zh) | 跨样本联邦分类建模方法及装置、存储介质、电子设备 | |
CN108665089B (zh) | 一种用于选址问题的鲁棒优化模型求解方法 | |
CN117150821B (zh) | 基于智能化仿真的装备效能评估数据集的构建方法 | |
US20230140696A1 (en) | Method and system for optimizing parameter intervals of manufacturing processes based on prediction intervals | |
WO2023036184A1 (en) | Methods and systems for quantifying client contribution in federated learning | |
CN116645130A (zh) | 基于联邦学习与gru结合的汽车订单需求量预测方法 | |
CN114925854A (zh) | 一种基于梯度相似性度量的联邦学习节点选择方法及系统 | |
CN117557870B (zh) | 基于联邦学习客户端选择的分类模型训练方法及系统 | |
CN110113180B (zh) | 一种基于偏置张量分解的云服务响应时间预测方法和装置 | |
CN114677547B (zh) | 一种基于自保持表征扩展的类增量学习的图像分类方法 | |
CN116258923A (zh) | 图像识别模型训练方法、装置、计算机设备和存储介质 | |
CN115630566A (zh) | 一种基于深度学习和动力约束的资料同化方法和系统 | |
CN115941804A (zh) | 一种算力路径的推荐方法和装置 | |
CN111027709B (zh) | 信息推荐方法、装置、服务器及存储介质 | |
CN115730631A (zh) | 联邦学习的方法和装置 | |
US20240177063A1 (en) | Information processing apparatus, information processing method, and non-transitory recording medium | |
CN113780526B (zh) | 人脸识别网络训练的方法、电子设备及存储介质 | |
CN111369374B (zh) | 一种基于概率产生式的社交网络时序链接预测方法及装置 | |
CN117255041A (zh) | 一种基于张量Tucker分解的网络吞吐量预测方法和装置 | |
CN117972554A (zh) | 基于联邦学习和知识蒸馏的多分类模型优化方法和装置 | |
CN116306963A (zh) | 一种应用于云架构下的边缘学习建模训练方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |