CN116522988A - 基于图结构学习的联邦学习方法、系统、终端及介质 - Google Patents
基于图结构学习的联邦学习方法、系统、终端及介质 Download PDFInfo
- Publication number
- CN116522988A CN116522988A CN202310804065.0A CN202310804065A CN116522988A CN 116522988 A CN116522988 A CN 116522988A CN 202310804065 A CN202310804065 A CN 202310804065A CN 116522988 A CN116522988 A CN 116522988A
- Authority
- CN
- China
- Prior art keywords
- model parameters
- target user
- graph
- local model
- optimized
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 55
- 238000012549 training Methods 0.000 claims abstract description 45
- 238000005070 sampling Methods 0.000 claims abstract description 20
- 239000011159 matrix material Substances 0.000 claims description 45
- 238000005259 measurement Methods 0.000 claims description 21
- 230000002776 aggregation Effects 0.000 claims description 13
- 238000004220 aggregation Methods 0.000 claims description 13
- 230000004931 aggregating effect Effects 0.000 claims description 11
- 238000003860 storage Methods 0.000 claims description 11
- 238000000691 measurement method Methods 0.000 claims description 4
- 238000004891 communication Methods 0.000 abstract description 9
- 238000005457 optimization Methods 0.000 abstract description 6
- 230000008569 process Effects 0.000 description 11
- 238000004590 computer program Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 7
- 238000009826 distribution Methods 0.000 description 4
- 230000006870 function Effects 0.000 description 4
- 230000004044 response Effects 0.000 description 4
- 238000001514 detection method Methods 0.000 description 3
- 238000013473 artificial intelligence Methods 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 230000003247 decreasing effect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000012804 iterative process Methods 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 239000013598 vector Substances 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/042—Knowledge-based neural networks; Logical representations of neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/098—Distributed learning, e.g. federated learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Machine Translation (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于图结构学习的联邦学习方法、系统、终端及介质,每轮次训练时,在所有用户端中采样出若干个目标用户端参与本轮次的训练,根据全局模型参数更新目标用户端的本地模型参数后,目标用户端迭代优化本地模型获得优化后的本地模型参数,然后采用图网络模型学习目标用户端之间的异质性并根据异质性来聚合所有目标用户端的优化后的本地模型参数,更新全局模型参数,循环进行迭代直至完成模型的优化。通过对用户端采样以减少参与训练的用户端的数量,可以减少每轮次训练的通信开销;通过采用图网络模型学习目标用户端之间的异质性,能够自适应地聚合目标用户端的优化后的本地模型参数,提高训练效率,获得鲁棒性好的优化后的全局模型。
Description
技术领域
本发明涉及人工智能技术领域,尤其涉及的是一种基于图结构学习的联邦学习方法、系统、终端及介质。
背景技术
人工智能的发展需要大量的数据,并且需要许多高质量的数据。而在医疗、金融、通信等领域,受限于数据安全、个人隐私等约束条件,数据来源方无法直接交换数据,形成“数据孤岛”现象,制约着人工智能模型能力的进一步提高。因此,在这些领域常使用联邦学习方法来解决机器学习中的数据孤岛问题。
在联邦学习过程中,先设定聚类超参,然后对用户端进行聚类获得各个用户端的权重,再对用户端上传的本地模型参数进行加权聚合。然而由于存在数据异质性问题(如各个用户端的数据分布各不相同),难以准确地设定聚类超参,尤其是在大量用户端的联邦学习场合。因此,存在分类偏差,模型参数聚合不准确的情况,导致训练效率低、优化后的模型鲁棒性差。
发明内容
本发明的主要目的在于提供一种基于图结构学习的联邦学习方法、系统、智能终端及存储介质,能够解决联邦学习时训练效率低、优化后的模型鲁棒性差的问题。
为了实现上述目的,本发明第一方面提供一种基于图结构学习的联邦学习方法,所述方法包括:
初始化全局模型参数;
在所有用户端中采样,获得多个目标用户端;
根据所述全局模型参数更新所有所述目标用户端的本地模型参数;
基于所述目标用户端各自的训练数据和本地模型参数,在所述目标用户端迭代优化所述目标用户端的本地模型,获得每个目标用户端的优化后的本地模型参数;
将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数;
返回在所有用户端中采样以重新获得所述全局模型参数,直至满足预设条件并输出优化后的全局模型。
可选的,所述图网络模型为图注意力模型,所述将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数,包括:
根据所述优化后的本地模型参数,计算用于表征所述目标用户端之间的连接的邻接矩阵;
将所述邻接矩阵和所有所述优化后的本地模型参数输入所述图注意力模型,获得所述全局模型参数。
可选的,所述根据所述优化后的本地模型参数,计算用于表征所述目标用户端之间的连接的邻接矩阵,包括:
基于所述优化后的本地模型参数,根据余弦相似度度量方法计算每两个目标用户端之间的相似度度量值;
根据所有的相似度度量值构建所述邻接矩阵。
可选的,所述将所述邻接矩阵和所有优化后的本地模型参数输入图注意力模型,获得全局模型参数,包括:
基于所述邻接矩阵中的相似度度量值,采用图注意力模型获得每一类目标用户端的优化后的本地模型参数的更新值;
计算所有所述更新值的均值,获得所述全局模型参数。
可选的,还设有多层感知机,所述将所述邻接矩阵和所有优化后的本地模型参数输入图注意力模型,获得全局模型参数,包括:
基于所述邻接矩阵中的相似度度量值,采用图注意力模型获得每一类目标用户端的优化后的本地模型参数的更新值;
将所有的所述更新值输入多层感知机,获得每一类目标用户端的分值;
根据所述分值对所有所述更新值进行加权平均,获得所述全局模型参数。
本发明第二方面提供一种基于图结构学习的联邦学习系统,其中,上述系统包括:
初始化模块,用于初始化全局模型参数;
采样模块,用于在所有用户端中采样,获得多个目标用户端;
参数更新模块,用于根据所述全局模型参数更新所有目标用户端的本地模型参数;
训练模块,用于基于所述目标用户端各自的训练数据和本地模型参数,在所述目标用户端迭代优化所述目标用户端的本地模型,获得每个目标用户端的优化后的本地模型参数;
聚合模块,用于将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数;
迭代模块,用于返回在所有用户端中采样以重新获得所述全局模型参数,直至满足预设条件并输出优化后的全局模型。
可选的,所述图网络模型为图注意力模型,所述聚合模块还包括邻接矩阵单元,所述邻接矩阵单元用于根据所述优化后的本地模型参数,计算用于表征所述目标用户端之间的连接的邻接矩阵;所述图注意力模型用于根据输入的所述邻接矩阵和所有所述优化后的本地模型参数获得所述全局模型参数。
可选的,还设有多层感知机,所述图注意力模型用于基于所述邻接矩阵中的相似度度量值获得每一类目标用户端的优化后的本地模型参数的更新值,所述多层感知机用于基于所述更新值获得每一类目标用户端的分值,所述聚合模块用于根据所述分值对所有更新值进行加权平均,获得所述全局模型参数。
本发明第三方面提供一种智能终端,上述智能终端包括存储器、处理器以及存储在上述存储器上并可在上述处理器上运行的基于图结构学习的联邦学习程序,上述基于图结构学习的联邦学习程序被上述处理器执行时实现任意一项上述基于图结构学习的联邦学习方法的步骤。
本发明第四方面提供一种计算机可读存储介质,上述计算机可读存储介质上存储有基于图结构学习的联邦学习程序,上述基于图结构学习的联邦学习程序被处理器执行时实现任意一项上述基于图结构学习的联邦学习方法的步骤。
由上可见,本发明每轮次训练时,在所有用户端中采样出若干个目标用户端参与本轮次的训练,根据全局模型参数更新目标用户端的本地模型参数后,目标用户端迭代优化本地模型获得优化后的本地模型参数,然后采用图网络模型学习目标用户端之间的异质性并根据异质性来聚合所有目标用户端的优化后的本地模型参数以更新全局模型参数,循环进行迭代直至完成模型的优化。通过对用户端采样以减少参与训练的用户端的数量,可以减少每轮次训练的通信开销;通过采用图网络模型学习目标用户端之间的异质性,能够自适应地聚合目标用户端的优化后的本地模型参数,训练效率高,获得鲁棒性好的优化后模型。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其它的附图。
图1是本发明实施例提供的基于图结构学习的联邦学习方法流程示意图;
图2是图1实施例的联邦学习框架示意图;
图3是图1实施例中获得全局模型参数的具体流程示意图;
图4是本发明实施例根据图注意力模型获得全局模型参数的具体流程示意图;
图5是本发明另一实施例根据图注意力模型获得全局模型参数的具体流程示意图;
图6是本发明实施例提供的基于图结构学习的联邦学习系统的结构示意图;
图7是本发明实施例提供的一种智能终端的内部结构原理框图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本发明实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本发明。在其它情况下,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本发明的描述。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当…时”或“一旦”或“响应于确定”或“响应于检测到”。类似的,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述的条件或事件]”或“响应于检测到[所描述条件或事件]”。
下面结合本发明实施例的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明的一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
在下面的描述中阐述了很多具体细节以便于充分理解本发明,但是本发明还可以采用其它不同于在此描述的其它方式来实施,本领域技术人员可以在不违背本发明内涵的情况下做类似推广,因此本发明不受下面公开的具体实施例的限制。
典型的联邦学习由若干个用户端和一个服务端构成,目标是在分布式用户端上训练一个共享模型,同时避免暴露它们的训练数据,上述共享模型部署在用户端上时又称为本地模型。联邦学习过程分为自治和联合两部分。首先,两个或两个以上的用户端在各自终端安装初始化的共享模型,每个用户端拥有相同的模型,之后用户端可以使用当地的数据进行训练。由于用户端拥有不同的数据,最终各用户端训练完毕的模型也拥有不同的模型参数(即本地模型参数)。将不同的本地模型参数同时上传到服务端,服务端将完成本地模型参数的聚合与更新,并且将更新好的模型参数返回到用户端以更新用户端的本地模型参数,各个用户端再开始下一次的迭代。以上的迭代过程会一直重复,直到整个训练过程达到收敛条件。
常规的联邦学习通常假设用户的本地数据都是敏感信息,因此本地模型参数聚合时一视同仁。而实际的联邦学习场景常存在以下三个数据异质性,即统计异质性(Statistical heterogeneity)、隐私异质性(Privacy heterogeneity)和模型异质性(Model heterogeneity)。其中,统计异质性认为用户的数据在不同的客户端是不服从独立同分布的(即Non Independent Identically Distribution:non-IID数据),隐私异质性认为用户的本地数据应该包含公开的和敏感的信息,这样就需要对数据进行不同程度的隐私保护,模型异质性认为不同客户端的模型需要被自适应的在服务端进行聚合。
针对数据异质性的问题,虽然目前已采用设定聚类超参后对本地模型参数进行聚类(相当于对用户端分类)并加权聚合,但是聚类超参难以准确地设定,尤其是在大量用户端的联邦学习场景,导致用户端的本地模型训练效率低、鲁棒性差。
本发明提供了一种基于图结构学习的联邦学习方法,将联邦学习场景下对用户端的分类问题转化为一种基于图结构的软分类问题,即使用图网络模型对不同数据异质程度的用户端进行更为精细且灵活的分类,用户端的本地模型参数聚合的准确程度高,使得用户端的本地模型训练效率高、鲁棒性好。
本发明实施例提供了一种基于图结构学习的联邦学习方法,作为服务端部署在服务器上,用来对分布式金融终端的网络模型进行训练。具体的,如图1所示,本实施例包括如下步骤:
步骤S100:初始化全局模型参数;
步骤S200:在所有用户端中采样,获得多个目标用户端;
步骤S300:根据全局模型参数更新所有目标用户端的本地模型参数;
具体地,用户端为分布式部署的金融终端。全局模型参数保存在服务端,为用户端的本地模型参数聚合后的结果。
联邦学习训练开始前,需要保证所有用户端的本地模型参数相同。在服务端对全局模型参数做随机初始化,然后服务端将全局模型参数传输至用户端,使得用户端的本地模型参数在初始时保持一致。
考虑到较多用户端参与训练时,服务端的通信带宽有限导致的通信拥堵和用户端上传本地模型参数时产生的通信开销问题,每轮次的迭代训练时,在所有用户端中进行采样,随机采样出预设数量的用户端作为目标用户端参与当前通信轮次下的训练。预设数量不做限制,一般为联邦学习场景用户端总数的某一比例。需要说明的是,一个完整的轮次包括以下过程:服务端将全局模型参数传输至用户端更新本地模型参数、用户端对本地模型进行迭代优化以获得优化后的本地模型参数,将优化后的本地模型参数传输至服务端进行聚合。
通过采用用户端采样的方法能够使得通信开销得以降低,可用于较多用户端参与训练的场景。
步骤S400:基于目标用户端各自的训练数据、本地模型参数,在目标用户端迭代优化目标用户端的本地模型,获得每个目标用户端的优化后的本地模型参数;
具体地,各个目标用户端进行训练时,本地模型参数相同,训练数据不同。利用本地模型参数对目标用户端的网络模型进行初始化后,利用本地的训练数据对本地模型采用随机梯度下降方法进行本地迭代训练,本地迭代训练收敛后从每个目标用户端各自的本地模型得到各个目标用户端的优化后的本地模型参数,且每个目标用户端的优化后的本地模型参数各不相同,并将所有的优化后的本地模型参数传输至服务端。
步骤S500:将所有的优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据异质性聚合所有的优化后的本地模型参数,获得全局模型参数;
具体地,异质性体现了不同目标用户端之间训练数据的相似程度,异质性高的两个目标用户端的优化后的本地模型参数之间的差异也大。在对本地模型参数聚合过程中,目前通过人为设定聚类超参来对目标用户端的本地模型参数进行分类,较为武断,难以准确地反映目标用户端之间的异质性,目标用户端的分类出现偏差,导致根据目标用户端的分类对各个目标用户端的本地模型参数进行加权聚合获得的全局模型参数不准确,性能损失较大。因此,本发明采用图网络模型对优化后的本地模型参数进行分析来将异质性低的目标用户端进行组合,降低异质性低的目标用户端占用的权重,使得全局模型参数能够更加真实地反映联邦学习场景。
服务端在获得所有目标用户端的优化后的本地模型参数后,计算优化后的本地模型参数之间的相似程度,将相似程度作为目标用户端相互之间的异质性值。采用图结构描述目标用户端及其之间的异质性关系,以目标用户端为节点,目标用户端相互之间的异质性值为边。然后将图结构对应的图数据输入图网络模型,采用图网络模型学习目标用户端之间的异质性,根据异质性实现目标用户端的自动分类,并根据目标用户端之间的异质性自适应地调整各个目标用户端的权重来聚合所有的优化后的本地模型参数,获得更加准确和有效的全局模型参数,使得迭代训练收敛更快。具体地,图网络模型可以提取用户端的图结构的空间特征,利用异质性值对目标用户端进行聚合从而生成新的目标用户端表示,实现目标用户端的自动分类,然后累计每个目标用户端的边获得各个目标用户端的权重。图网络模型的基础架构可以为各种图神经网络(Graph Neural Networks:GNN) ,如:图卷积网络 (Graph Convolution Networks:GCN)、图注意力模型(Graph Attention Networks:GAT)、图生成网络(Graph Generative Networks:GGN)等。
参考如图2所示的联邦学习框架,本实施例中获得全局模型参数的具体步骤如图3所示,包括:
步骤S510:根据优化后的本地模型参数,计算用于表征目标用户端之间的连接的邻接矩阵;
具体地,获得优化后的本地模型参数后,将每个目标用户端的本地模型参数矢量化,然后根据余弦相似度度量方法计算每两个矢量之间的相似度,该相似度为每两个目标用户端之间的相似度度量值,再根据所有的相似度度量值构建行列式矩阵,获得用户结构邻接矩阵(以下简称邻接矩阵)。邻接矩阵中的每个元素能够表征目标用户端之间的连接关系,元素值越大,表示两个目标用户端越相似,目标用户端之间的连接越紧密。
步骤S520:将邻接矩阵和所有优化后的本地模型参数输入图注意力模型,获得全局模型参数。
具体地,在图注意力模型中,图中的每个节点能够根据相邻节点的特征,为其分配不同的权值;并且引入注意力机制后,只与相邻节点有关,即共享边的节点有关,无需得到整张图的信息。更加适应于联邦学习场景的不确定图结构问题。
将邻接矩阵和所有优化后的本地模型参数输入图注意力模型后,获得每个节点的权值,即获得每个目标用户端的权值,然后根据目标用户端的权值对各个目标用户端的本地模型参数进行加权,获得各个目标用户端的加权后本地模型参数(W1、W2、...Wn) ,然后再对加权后本地模型参数取均值,获得全局模型参数。
在另一个实施例中,采用图注意力模型直接对图结构的目标用户端进行分类聚合,获得每一类目标用户端的本地模型参数的更新值,然后进行加权平均,获得全局模型参数。如图4所示,根据图注意力模型获得全局模型参数具体包括如下步骤:
步骤A521:基于邻接矩阵中的相似度度量值,采用图注意力模型获得每一类目标用户端的优化后的本地模型参数的更新值;
步骤A522:对所有的更新值加权平均,获得全局模型参数。
具体地,图注意力模型根据邻接矩阵中的相似度度量值,将相似度度量值在预设阈值范围内的目标用户端归为同一类,并将同一类目标用户端中的所有目标用户端聚合为一个目标用户端。在目标用户端分类聚合的过程中,对同一类目标用户端中的各个目标用户端的本地模型参数进行处理,获得该类目标用户端的优化后的本地模型参数的更新值。然后取所有类别的更新值的均值,获得全局模型参数。在一个示例中,在每一类目标用户端中筛选出一个目标用户端,忽略其他的同类目标用户端。
通过采用邻接矩阵中的相似度度量值来对目标用户端进行归类和筛选,能够更加有效地降低同一类本地模型参数的权重,训练后的本地模型的鲁棒更高。
在一个实施例中,如图5所示,根据图注意力模型获得全局模型参数具体包括如下步骤:
步骤B521:基于邻接矩阵中的相似度度量值,采用图注意力模型获得每一类目标用户端的优化后的本地模型参数的更新值;
步骤B522:将所有的更新值输入多层感知机,获得每一类目标用户端的分值;
步骤B523:根据分值对所有更新值进行加权平均,获得全局模型参数。
具体地,根据邻接矩阵中的相似度度量值,得到每类目标用户端的优化后的本地模型参数的更新值后,还将所有的更新值输入多层感知机,通过多层感知机对每一类目标用户端进行打分,将多层感知机输出的分值作为各类目标用户端的权重,再采用该权重对各类的更新值进行加权平均,获得全局模型参数。
由上所述,通过采用多层感知机对各类目标用户端进行打分,能够评判各类目标用户端的重要性,然后再将分值作为权重进行加权平均。能够进一步地提高全局模型参数的有效性,获得更好的训练效果,提高模型的鲁棒性。
步骤S600:返回步骤S200,进行下一轮次的迭代优化以重新获得全局模型参数,直至满足预设条件并输出优化后的全局模型。
具体地,服务端更新了全局模型参数后,重新开始用户端的采样,使用下一批的目标用户端进行迭代训练和进行联邦学习,重新获得全局模型参数,直至满足预设条件。预设条件可以为全局模型参数的梯度更新值达到训练收敛条件或迭代次数达到预设的通信轮次数。当联邦学习训练完毕后,根据全局模型参数配置服务端的全局模型,获得优化后的全局模型,然后将优化后的全局模型传输至各个用户端以更新用户端的本地模型。
可选的,也可以将全局模型参数传给各个用户端,用户端根据全局模型参数配置用户端的本地模型,同样实现将优化后的全局模型更新至用户端的效果。
综上所述,由于采用图模型学习用户端之间的图结构,不需要预先设置聚类超参,能够自适应地进行用户端分类,进而有效促进用户端的本地模型参数信息在聚合时的利用。并且通过将用户聚类参数问题重定义为一个图模型在用户之间的连接学习问题,能够以图网络模型为基础构造上层用户结构学习模型,相比于已有的简单分类方法在分类上更加精细且有效,可以应对数据异质程度较高且用户数较大的场景。
示例性系统
如图6所示,对应于上述基于图结构学习的联邦学习方法,本发明实施例还提供基于图结构学习的联邦学习系统,上述系统包括:
初始化模块600,用于初始化全局模型参数;
采样模块610,用于在所有用户端中采样,获得多个目标用户端;
参数更新模块620,用于根据所述全局模型参数更新所有目标用户端的本地模型参数;
训练模块630,用于基于所述目标用户端各自的训练数据、本地模型参数,在目标用户端迭代优化所述目标用户端的本地模型,获得每个目标用户端的优化后的本地模型参数;
聚合模块640,用于将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数;
迭代模块650,用于返回在所有用户端中采样以重新获得所述全局模型参数,直至满足预设条件并输出优化后的全局模型。
具体的,本实施例中,上述基于图结构学习的联邦学习系统的各模块的具体功能可以参照上述基于图结构学习的联邦学习方法中的对应描述,在此不再赘述。
基于上述实施例,本发明还提供了一种智能终端,其原理框图可以如图7所示。上述智能终端包括通过系统总线连接的处理器、存储器、网络接口以及显示屏。其中,该智能终端的处理器用于提供计算和控制能力。该智能终端的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统和基于图结构学习的联邦学习程序。该内存储器为非易失性存储介质中的操作系统和基于图结构学习的联邦学习程序的运行提供环境。该智能终端的网络接口用于与外部的终端通过网络连接通信。该基于图结构学习的联邦学习程序被处理器执行时实现上述任意一种基于图结构学习的联邦学习方法的步骤。该智能终端的显示屏可以是液晶显示屏或者电子墨水显示屏。
本领域技术人员可以理解,图7中示出的原理框图,仅仅是与本发明方案相关的部分结构的框图,并不构成对本发明方案所应用于其上的智能终端的限定,具体的智能终端可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,提供了一种智能终端,上述智能终端包括存储器、处理器以及存储在上述存储器上并可在上述处理器上运行的基于图结构学习的联邦学习程序,上述基于图结构学习的联邦学习程序被上述处理器执行时进行以下操作指令:
初始化全局模型参数;
在所有用户端中采样,获得多个目标用户端;
根据所述全局模型参数更新所有目标用户端的本地模型参数;
基于所述目标用户端各自的训练数据和本地模型参数,在所述目标用户端迭代优化所述目标用户端的本地模型,获得每个目标用户端的优化后的本地模型参数;
将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数;
返回在所有用户端中采样以重新获得所述全局模型参数,直至满足预设条件并输出优化后的全局模型。
可选的,所述图网络模型为图注意力模型,所述将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数,包括:
根据所述优化后的本地模型参数,计算用于表征所述目标用户端之间的连接的邻接矩阵;
将所述邻接矩阵和所有优化后的本地模型参数输入图注意力模型,获得全局模型参数。
可选的,所述根据所述优化后的本地模型参数,计算用于表征所述目标用户端之间的连接的邻接矩阵,包括:
基于所述优化后的本地模型参数,根据余弦相似度度量方法计算每两个目标用户端之间的相似度度量值;
根据所有的相似度度量值构建所述邻接矩阵。
可选的,所述将所述邻接矩阵和所有优化后的本地模型参数输入图注意力模型,获得全局模型参数,包括:
基于所述邻接矩阵中的相似度度量值,采用图注意力模型获得每一类目标用户端的优化后的本地模型参数的更新值;
计算所有更新值的均值,获得所述全局模型参数。
可选的,还设有多层感知机,所述将所述邻接矩阵和所有优化后的本地模型参数输入图注意力模型,获得全局模型参数,包括:
基于所述邻接矩阵中的相似度度量值,采用图注意力模型获得每一类目标用户端的优化后的本地模型参数的更新值;
将所有的所述更新值输入多层感知机,获得每一类目标用户端的分值;
根据所述分值对所有所述更新值进行加权平均,获得所述全局模型参数。
本发明实施例还提供一种计算机可读存储介质,上述计算机可读存储介质上存储有基于图结构学习的联邦学习程序,上述基于图结构学习的联邦学习程序被处理器执行时实现本发明实施例提供的任意一种基于图结构学习的联邦学习方法的步骤。
应理解,上述实施例中各步骤的序号大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将上述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本发明的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各实例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟是以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
在本发明所提供的实施例中,应该理解到,所揭露的装置/终端设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/终端设备实施例仅仅是示意性的,例如,上述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以由另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。
上述集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本发明实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,上述计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,上述计算机程序包括计算机程序代码,上述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。上述计算机可读介质可以包括:能够携带上述计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,RandomAccess Memory)、电载波信号、电信信号以及软件分发介质等。需要说明的是,上述计算机可读存储介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减。
以上所述实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解;其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不是相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。
Claims (10)
1.基于图结构学习的联邦学习方法,其特征在于,所述方法包括:
初始化全局模型参数;
在所有用户端中采样,获得多个目标用户端;
根据所述全局模型参数更新所有所述目标用户端的本地模型参数;
基于所述目标用户端各自的训练数据和本地模型参数,在所述目标用户端迭代优化所述目标用户端的本地模型,获得每个目标用户端的优化后的本地模型参数;
将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数;
返回在所有用户端中采样以重新获得所述全局模型参数,直至满足预设条件并输出优化后的全局模型。
2.如权利要求1所述的基于图结构学习的联邦学习方法,其特征在于,所述图网络模型为图注意力模型,所述将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数,包括:
根据所述优化后的本地模型参数,计算用于表征所述目标用户端之间的连接的邻接矩阵;
将所述邻接矩阵和所有所述优化后的本地模型参数输入所述图注意力模型,获得所述全局模型参数。
3.如权利要求2所述的基于图结构学习的联邦学习方法,其特征在于,所述根据所述优化后的本地模型参数,计算用于表征所述目标用户端之间的连接的邻接矩阵,包括:
基于所述优化后的本地模型参数,根据余弦相似度度量方法计算每两个目标用户端之间的相似度度量值;
根据所有的相似度度量值构建所述邻接矩阵。
4.如权利要求3所述的基于图结构学习的联邦学习方法,其特征在于,所述将所述邻接矩阵和所有优化后的本地模型参数输入图注意力模型,获得全局模型参数,包括:
基于所述邻接矩阵中的相似度度量值,采用图注意力模型获得每一类目标用户端的优化后的本地模型参数的更新值;
计算所有所述更新值的均值,获得所述全局模型参数。
5.如权利要求3所述的基于图结构学习的联邦学习方法,其特征在于,还设有多层感知机,所述将所述邻接矩阵和所有优化后的本地模型参数输入图注意力模型,获得全局模型参数,包括:
基于所述邻接矩阵中的相似度度量值,采用图注意力模型获得每一类目标用户端的优化后的本地模型参数的更新值;
将所有的所述更新值输入多层感知机,获得每一类目标用户端的分值;
根据所述分值对所有所述更新值进行加权平均,获得所述全局模型参数。
6.基于图结构学习的联邦学习系统,其特征在于,所述系统包括:
初始化模块,用于初始化全局模型参数;
采样模块,用于在所有用户端中采样,获得多个目标用户端;
参数更新模块,用于根据所述全局模型参数更新所有目标用户端的本地模型参数;
训练模块,用于基于所述目标用户端各自的训练数据和本地模型参数,在所述目标用户端迭代优化所述目标用户端的本地模型,获得每个目标用户端的优化后的本地模型参数;
聚合模块,用于将所有的所述优化后的本地模型参数输入图网络模型以学习目标用户端之间的异质性,并根据所述异质性聚合所有的优化后的本地模型参数,获得全局模型参数;
迭代模块,用于返回在所有用户端中采样以重新获得所述全局模型参数,直至满足预设条件并输出优化后的全局模型。
7.如权利要求6所述的基于图结构学习的联邦学习系统,其特征在于,所述图网络模型为图注意力模型,所述聚合模块还包括邻接矩阵单元,所述邻接矩阵单元用于根据所述优化后的本地模型参数,计算用于表征所述目标用户端之间的连接的邻接矩阵;所述图注意力模型用于根据输入的所述邻接矩阵和所有所述优化后的本地模型参数获得所述全局模型参数。
8.如权利要求7所述的基于图结构学习的联邦学习系统,其特征在于,还设有多层感知机,所述图注意力模型用于基于所述邻接矩阵中的相似度度量值获得每一类目标用户端的优化后的本地模型参数的更新值,所述多层感知机用于基于所有的所述更新值获得每一类目标用户端的分值,所述聚合模块用于根据所述分值对所有所述更新值进行加权平均,获得所述全局模型参数。
9.智能终端,其特征在于,所述智能终端包括存储器、处理器以及存储在所述存储器上并可在所述处理器上运行的基于图结构学习的联邦学习程序,所述基于图结构学习的联邦学习程序被所述处理器执行时实现如权利要求1-5任意一项所述基于图结构学习的联邦学习方法的步骤。
10.计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有基于图结构学习的联邦学习程序,所述基于图结构学习的联邦学习程序被处理器执行时实现如权利要求1-5任意一项所述基于图结构学习的联邦学习方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310804065.0A CN116522988B (zh) | 2023-07-03 | 2023-07-03 | 基于图结构学习的联邦学习方法、系统、终端及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310804065.0A CN116522988B (zh) | 2023-07-03 | 2023-07-03 | 基于图结构学习的联邦学习方法、系统、终端及介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116522988A true CN116522988A (zh) | 2023-08-01 |
CN116522988B CN116522988B (zh) | 2023-10-31 |
Family
ID=87399760
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310804065.0A Active CN116522988B (zh) | 2023-07-03 | 2023-07-03 | 基于图结构学习的联邦学习方法、系统、终端及介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116522988B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116958149A (zh) * | 2023-09-21 | 2023-10-27 | 湖南红普创新科技发展有限公司 | 医疗模型训练方法、医疗数据分析方法、装置及相关设备 |
Citations (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20200220851A1 (en) * | 2019-12-13 | 2020-07-09 | TripleBlind, Inc. | Systems and methods for efficient computations on split data and split algorithms |
US20210158099A1 (en) * | 2019-11-26 | 2021-05-27 | International Business Machines Corporation | Federated learning of clients |
CN114912705A (zh) * | 2022-06-01 | 2022-08-16 | 南京理工大学 | 一种联邦学习中异质模型融合的优化方法 |
CN115146786A (zh) * | 2022-06-29 | 2022-10-04 | 支付宝(杭州)信息技术有限公司 | 联邦学习的实现方法、装置、系统、介质、设备以及产品 |
CN115511109A (zh) * | 2022-09-30 | 2022-12-23 | 中南大学 | 一种高泛化性的个性化联邦学习实现方法 |
CN115688913A (zh) * | 2022-12-29 | 2023-02-03 | 中南大学 | 一种云边端协同个性化联邦学习方法、系统、设备及介质 |
CN116205311A (zh) * | 2023-02-16 | 2023-06-02 | 同济大学 | 一种基于Shapley值的联邦学习方法 |
CN116227623A (zh) * | 2023-01-29 | 2023-06-06 | 深圳前海环融联易信息科技服务有限公司 | 联邦学习方法、装置、计算机设备及存储介质 |
-
2023
- 2023-07-03 CN CN202310804065.0A patent/CN116522988B/zh active Active
Patent Citations (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210158099A1 (en) * | 2019-11-26 | 2021-05-27 | International Business Machines Corporation | Federated learning of clients |
US20200220851A1 (en) * | 2019-12-13 | 2020-07-09 | TripleBlind, Inc. | Systems and methods for efficient computations on split data and split algorithms |
CN114912705A (zh) * | 2022-06-01 | 2022-08-16 | 南京理工大学 | 一种联邦学习中异质模型融合的优化方法 |
CN115146786A (zh) * | 2022-06-29 | 2022-10-04 | 支付宝(杭州)信息技术有限公司 | 联邦学习的实现方法、装置、系统、介质、设备以及产品 |
CN115511109A (zh) * | 2022-09-30 | 2022-12-23 | 中南大学 | 一种高泛化性的个性化联邦学习实现方法 |
CN115688913A (zh) * | 2022-12-29 | 2023-02-03 | 中南大学 | 一种云边端协同个性化联邦学习方法、系统、设备及介质 |
CN116227623A (zh) * | 2023-01-29 | 2023-06-06 | 深圳前海环融联易信息科技服务有限公司 | 联邦学习方法、装置、计算机设备及存储介质 |
CN116205311A (zh) * | 2023-02-16 | 2023-06-02 | 同济大学 | 一种基于Shapley值的联邦学习方法 |
Non-Patent Citations (2)
Title |
---|
"基于边缘的联邦学习模型清洗和设备聚类方法", 《计算机学报》, vol. 44, no. 12, pages 2515 - 2528 * |
CHAMATH PALIHAWADANA 等: "FedSim: similarity guided model aggregation for federated learning", 《NEUROCOMPUTING》, vol. 483, pages 432 - 445, XP086986944, DOI: 10.1016/j.neucom.2021.08.141 * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116958149A (zh) * | 2023-09-21 | 2023-10-27 | 湖南红普创新科技发展有限公司 | 医疗模型训练方法、医疗数据分析方法、装置及相关设备 |
CN116958149B (zh) * | 2023-09-21 | 2024-01-12 | 湖南红普创新科技发展有限公司 | 医疗模型训练方法、医疗数据分析方法、装置及相关设备 |
Also Published As
Publication number | Publication date |
---|---|
CN116522988B (zh) | 2023-10-31 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2017206936A1 (zh) | 基于机器学习的网络模型构造方法及装置 | |
WO2020098606A1 (zh) | 节点分类方法、模型训练方法、装置、设备及存储介质 | |
CN109617888B (zh) | 一种基于神经网络的异常流量检测方法及系统 | |
CN110968426B (zh) | 一种基于在线学习的边云协同k均值聚类的模型优化方法 | |
WO2021089013A1 (zh) | 空间图卷积网络的训练方法、电子设备及存储介质 | |
WO2021098618A1 (zh) | 数据分类方法、装置、终端设备及可读存储介质 | |
CN116522988B (zh) | 基于图结构学习的联邦学习方法、系统、终端及介质 | |
CN112163637B (zh) | 基于非平衡数据的图像分类模型训练方法、装置 | |
WO2023179099A1 (zh) | 一种图像检测方法、装置、设备及可读存储介质 | |
CN113987236B (zh) | 基于图卷积网络的视觉检索模型的无监督训练方法和装置 | |
CN115829055B (zh) | 联邦学习模型训练方法、装置、计算机设备及存储介质 | |
WO2023207013A1 (zh) | 一种基于图嵌入的关系图谱关键人员分析方法及系统 | |
EP4386579A1 (en) | Retrieval model training method and apparatus, retrieval method and apparatus, device and medium | |
WO2024040941A1 (zh) | 神经网络结构搜索方法、装置及存储介质 | |
CN112817563B (zh) | 目标属性配置信息确定方法、计算机设备和存储介质 | |
CN112348079A (zh) | 数据降维处理方法、装置、计算机设备及存储介质 | |
CN116431597A (zh) | 用于训练数据分类模型的方法、电子设备和计算机程序产品 | |
CN105809200B (zh) | 一种生物启发式自主抽取图像语义信息的方法及装置 | |
CN112836629A (zh) | 一种图像分类方法 | |
CN116561622A (zh) | 一种面向类不平衡数据分布的联邦学习方法 | |
CN116307078A (zh) | 账户标签预测方法、装置、存储介质及电子设备 | |
CN115795355A (zh) | 一种分类模型训练方法、装置及设备 | |
CN116010832A (zh) | 联邦聚类方法、装置、中心服务器、系统和电子设备 | |
CN115496954A (zh) | 眼底图像分类模型构建方法、设备及介质 | |
CN115170919A (zh) | 图像处理模型训练及图像处理方法、装置、设备和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |