CN114819190A - 基于联邦学习的模型训练方法、装置、系统、存储介质 - Google Patents
基于联邦学习的模型训练方法、装置、系统、存储介质 Download PDFInfo
- Publication number
- CN114819190A CN114819190A CN202210706292.5A CN202210706292A CN114819190A CN 114819190 A CN114819190 A CN 114819190A CN 202210706292 A CN202210706292 A CN 202210706292A CN 114819190 A CN114819190 A CN 114819190A
- Authority
- CN
- China
- Prior art keywords
- training
- model parameters
- clients
- precision
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Medical Informatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请实施例提供了一种基于联邦学习的模型训练方法、装置、系统、存储介质,属于人工智能技术领域。该方法包括:获取客户端的设备标记信息;通过标记信息从N个客户端中筛选出K1个客户端作为训练参与方;将预设的原始全局模型参数传输给训练参与方;获取训练参与方发送的精度值和本地模型参数,精度值由训练参与方对原始全局模型参数进行精度验证得到;根据精度值和本地模型参数更新原始全局模型参数,得到目标全局模型参数;将目标全局模型参数发送给更新参与方,目标全局模型参数用于对更新参与方的本地模型进行更新,更新参与方是从N个客户端中筛选出K2个客户端。本申请实施例能够保障全局模型在不损失性能的基础上实现更公平的表现。
Description
技术领域
本申请涉及人工智能技术领域,尤其涉及一种基于联邦学习的模型训练方法、装置、系统、存储介质。
背景技术
人工智能的应用中,经常需要对模型进行训练,一个足够有效的模型需要海量数据进行训练,而在一些较敏感的场景下(例如不同医院的患者数据,或不同车辆的驾驶数据等等),单个设备(客户端)可能没有足够数量和质量的数据来学习一个更加健壮的模型,多个设备(多个客户端)共同训练又可能造成隐私泄露的问题。联邦学习常用于该解决隐私泄露的共同训练场景,在联邦学习中,每个客户端利用本地数据训练局部模型或进行参数更新,然后仅传递模型参数到服务器,在服务器将所有参数聚合,从而在不需要交换数据的情况下获得一个联合的模型;但是实际应用中,该方式容易降低准确率和延长训练时间,并损害公平性。
发明内容
本申请实施例的主要目的在于提出一种基于联邦学习的模型训练方法、装置、系统、存储介质,旨在提高模型训练中的公平性、并提高训练效率。
为实现上述目的,本申请实施例的第一方面提出了一种基于联邦学习的模型训练方法,应用于服务器端,所述基于联邦学习的模型训练方法包括:
获取客户端的设备标记信息;其中,所述客户端包括本地模型;
通过标记信息从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,所述第一预设数量为K1,K1<N;
将预设的原始全局模型参数传输给所述训练参与方;
获取所述训练参与方发送的精度值和本地模型参数;其中,所述精度值由所述训练参与方对所述原始全局模型参数进行精度验证得到,所述本地模型参数由所述训练参与方通过对本地模型进行训练得到,所述本地模型是所述训练参与方本地的模型;
根据所述精度值和所述本地模型参数更新原始全局模型参数,得到目标全局模型参数;
将所述目标全局模型参数发送给更新参与方;其中,所述目标全局模型参数用于对所述更新参与方的本地模型进行更新,所述更新参与方是从N个所述客户端中筛选出第二预设数量的客户端,第二预设数量为K2,K2<N。
在一些实施例,所述根据所述精度值和所述本地模型参数更新原始全局模型参数,得到目标全局模型参数,包括:
根据所述精度值计算聚合权重;
根据所述聚合权重和所述本地模型参数更新原始全局模型参数,得到所述目标全局模型参数。
在一些实施例,所述根据所述精度值计算聚合权重,包括:
将所述精度值输入至预设的神经网络;其中,所述预设的神经网络为强化学习网络,所述强化学习网络包括智能体,将所述本地模型参数作为所述强化学习网络的状态;
通过所述强化学习网络对所述精度值进行奖励函数计算,得到奖励函数;
所述智能体对所述本地模型参数进行强化学习,得到策略分布;
所述智能体根据所述策略分布计算出所述训练参与方的历史权重,并根据所述奖励函数对所述历史权重进行奖励计算,得到奖励值;其中,所述历史权重作为所述强化学习网络在t-1轮迭代中的动作;
所述智能体根据所述历史权重和所述奖励值计算所述聚合权重;其中,所述聚合权重作为所述强化学习网络在t轮迭代中的动作。
在一些实施例,所述通过所述强化学习网络对所述精度值进行奖励函数计算,得到奖励函数,包括:
根据所述精度值计算平均精度;
根据所述精度值和所述平均精度计算基尼系数;
根据所述平均精度和所述基尼系数计算所述奖励函数。
为实现上述目的,本申请实施例的第二方面提出了一种基于联邦学习的模型训练方法,应用于客户端,所述基于联邦学习的模型训练方法包括:
向服务器端发送所述客户端自身的设备标记信息;其中,所述客户端包括本地模型,所述标记信息用于所述服务器端从N个所述客户端中筛选出第一预设数量的客户端作为训练参与方;其中,所述第一预设数量为K1,K1<N;
所述训练参与方接收服务器端传输的原始全局模型参数;
所述训练参与方对所述原始全局模型参数进行精度验证,得到精度值;
所述训练参与方对所述本地模型进行训练,得到本地模型参数;
所述训练参与方将所述精度值和所述本地模型参数发送给所述服务器端;
更新参与方接收所述服务器端根据所述精度值和所述本地模型参数对所述服务器的原始全局模型参数进行更新得到的目标全局模型参数,并根据接收到的目标全局模型参数更新自身的本地模型参数;其中,所述更新参与方是所述服务器端从N个所述客户端中筛选出第二预设数量的客户端得到,所述第二预设数量为K2,K2<N。
在一些实施例,所述训练参与方对所述原始全局模型参数进行精度验证,得到精度值,包括:
所述训练参与方通过预设的验证集对所述原始全局模型参数进行测试,得到预测正确的正确样本数;其中,所述验证集包括样本总数;
将所述正确样本数除以所述样本总数,得到所述精度值。
为实现上述目的,本申请实施例的第三方面提出了一种基于联邦学习的模型训练装置,应用于服务器端,所述基于联邦学习的模型训练装置包括:
设备标记获取模块,用于获取客户端的设备标记信息;其中,所述客户端包括本地模型;
客户端筛选模块,用于通过标记信息从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,所述第一预设数量为K1,K1<N;
模型参数传输模块,用于将预设的原始全局模型参数传输给所述训练参与方;
数据获取模块,用于获取所述训练参与方发送的精度值和本地模型参数;其中,所述精度值由所述训练参与方对所述原始全局模型参数进行精度验证得到,所述本地模型参数由所述训练参与方通过对本地模型进行训练得到,所述本地模型是所述训练参与方本地的模型;
更新模块,用于根据所述精度值和所述本地模型参数更新原始全局模型参数,得到目标全局模型参数;
参数发送模块,用于将所述目标全局模型参数发送给更新参与方;其中,所述目标全局模型参数用于对所述更新参与方的本地模型进行更新,所述更新参与方是从N个所述客户端中筛选出第二预设数量的客户端,第二预设数量为K2,K2<N。
为实现上述目的,本申请实施例的第四方面提出了一种基于联邦学习的模型训练装置,应用于客户端,所述基于联邦学习的模型训练装置包括:
设备标记发送模块,向服务器端发送所述客户端自身的设备标记信息;其中,所述客户端包括本地模型,所述标记信息用于所述服务器端从N个所述客户端中筛选出第一预设数量的客户端作为训练参与方;其中,所述第一预设数量为K1,K1<N;
模型参数接收模块,所述模型参数接收模块用于使述训练参与方接收服务器端传输的原始全局模型参数;
精度验证模块,所述精度验证模块用于使所述训练参与方用于对所述原始全局模型参数进行精度验证,得到精度值;
本地模型训练模块,所述本地模型训练模块用于使所述训练参与方对所述本地模型进行训练,得到本地模型参数;
数据发送模块,所述数据发送模块用于使所述训练参与方将所述精度值和所述本地模型参数发送给所述服务器端;
调参模块,所述调参模块用于使更新参与方接收所述服务器端发送的目标全局模型参数,并用于根据接收到的目标全局模型参数更新自身的本地模型参数;其中,所述目标全局模型参数是由所述服务器端根据所述精度值和所述本地模型参数对所述服务器的原始全局模型参数进行更新得到,所述更新参与方是所述服务器端从N个所述客户端中筛选出第二预设数量的客户端得到,所述第二预设数量为K2,K2<N。
为实现上述目的,本申请实施例的第五方面提出了一种基于联邦学习的模型训练系统,所述基于联邦学习的模型训练系统包括存储器、处理器、存储在所述存储器上并可在所述处理器上运行的程序以及用于实现所述处理器和所述存储器之间的连接通信的数据总线,所述程序被所述处理器执行时实现:
如第一方面所述的方法;
或者,
如第二方面所述的方法。
为实现上述目的,本申请实施例的第六方面提出了一种存储介质,所述存储介质为计算机可读存储介质,用于计算机可读存储,所述存储介质存储有一个或者多个程序,所述一个或者多个程序可被一个或者多个处理器执行,以实现:
如第一方面所述的方法;
或者,
如第二方面所述的方法。
本申请提出的基于联邦学习的模型训练方法、装置、系统、存储介质,其通过客户端的标记信息从N个客户端中筛选出K1个客户端作为训练参与方,并将服务器端自身预设的原始全局模型参数传输给训练参与方,从而训练参与方对原始全局模型参数进行精度验证得到精度值、并通过客户端端自身的本地模型进行训练得到本地模型参数;训练参与方将根据精度值和本地模型参数发送给服务器端,从而服务器端根据精度值和本地模型参数更新原始全局模型参数,以得到目标全局模型参数,并将目标全局模型参数发送给更新参与方(该更新参与方是从N个客户端中筛选出的K2个客户端)进行更新,从而实现联邦学习的模型训练,实现动态地调整参数,并基于强化学习的方法,通过训练强化学习智能体学习一个更公平的聚合策略,保障全局模型在不损失性能的基础上实现更公平的表现。
附图说明
图1是本申请实施例提供的应用于服务器端的基于联邦学习的模型训练方法的流程图;
图2是图1中的步骤105的流程图;
图3是图2中的步骤201的流程图;
图4是图3中的步骤302的流程图;
图5是本申请实施例实现基于联邦学习的模型训练方法的整体架构图;
图6是本申请实施例提供的应用于客户端的基于联邦学习的模型训练方法的流程图;
图7是本申请实施例提供的应用于服务器端的基于联邦学习的模型训练装置的结构示意图;
图8是本申请实施例提供的应用于客户端的基于联邦学习的模型训练装置的结构示意图;
图9是本申请实施例提供的基于联邦学习的模型训练系统的硬件结构示意图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本申请,并不用于限定本申请。
需要说明的是,虽然在装置示意图中进行了功能模块划分,在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于装置中的模块划分,或流程图中的顺序执行所示出或描述的步骤。说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本申请实施例的目的,不是旨在限制本申请。
首先,对本申请中涉及的若干名词进行解析:
人工智能(artificial intelligence,AI):是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学;人工智能是计算机科学的一个分支,人工智能企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能可以对人的意识、思维的信息过程的模拟。人工智能还是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。
联邦学习:联邦学习是一种新型的分布式学习范式,也是一种带有隐私保护、安全加密技术的分布式机器学习框架,旨在解决从分散数据中进行深度网络的通信高效学习的问题,可以实现让分散的各参与方在满足不向其他参与者披露隐私数据的前提下,协作进行机器学习的模型训练。
强化学习(Reinforcement Learning,RL):强化学习又称再励学习、评价学习或增强学习,是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题;学习从环境状态到行为的映射,使得智能体选择的行为能够获得环境最大的奖赏,使得外部环境对学习系统在某种意义下的评价(或整个系统的运行性能)为最佳。强化学习关键要素:智能体(agent),奖励(reward),动作(action),策略(policy),状态(state),环境(environment),马尔科夫决策过程(Markov decision process,MDP)。其基本原理是:如果agent的某个行为策略导致环境正的奖赏(强化信号),那么agent以后产生这个行为策略的趋势便会加强。agent的目标是在每个离散状态发现最优策略以使期望的折扣奖赏和最大。强化学习把学习看作试探评价过程,agent选择一个动作(action)用于环境(environment),环境接受该动作后状态(state)发生变化,同时产生一个强化信号(奖或惩)反馈给agent,agent根据强化信号和环境当前状态再选择下一个动作(action)。强化学习中由环境提供的强化信号是agent对所产生动作的好坏作一种评价(通常为标量信号),而不是告诉agent如何去产生正确的动作(action)。由于外部环境提供了很少的信息,agent必须靠自身的经历进行学习。通过这种方式,agent在动作评价的环境中获得知识,改进动作方案以适应环境受到正强化(奖)的概率增大。选择的动作不仅影响立即强化值,而且影响环境下一时刻的状态及最终的强化值。强化学习系统学习的目标是动态地调整参数,以达到强化信号最大。通常所说的强化学习,智能体agent作为学习系统,获取外部环境的当前状态(state)信息s,对环境采取试探行为u,并获取环境反馈的对此动作的评价r和新的环境状态。如果智能体的某动作u导致环境正的奖赏,那么智能体以后产生这个动作的趋势便会加强;反之,智能体产生这个动作的趋势将减弱。在学习系统的控制行为与环境反馈的状态及评价的反复的交互作用中,以学习的方式不断修改从状态到动作的映射策略,以达到优化系统性能目的。强化学习包括Value-based(基于价值)和Policy-based(基于策略)两类,其中Value-based是学习价值函数,从价值函数采取出策略,确定一个策略at,是一种间接产生策略的方法;Value-Base中的action-value估计值最终会收敛到对应的true values(通常是不同的有限数,可以转化为0到1之间的概率),因此通常会获得一个确定的策略(deterministic policy);Policy-based是学习策略函数,直接产生策略的方法,会产生各个动作的概率πθ(a∣s);Policy-Based通常不会收敛到一个确定性的值;Policy-Based适用于连续的动作空间,在连续的动作空间中,可以不用计算每个动作的概率,而是通过Gaussian distribution (正态分布)选择action。
马尔科夫决策过程(Markov decision process,MDP):一个马尔科夫决策过程由一个五元组构成〈S,A,P,R,γ〉,其中S代表了状态(state)的集合,A代表了动作(action)的集合;P描述了状态转移矩阵,R表示奖励函数,R(s,a)描述了在状态(state)S做动作A的奖励(reward),γ表示衰减因子,γ∈[0,1]。马尔科夫决策过程就是一个典型的序列决策过程的一种公式化,序列决策过程的示例:智能体(agent)与环境(environment)一直在互动;在每个时刻t,agent会接收到来自environment的状态(state)S,基于这个状态S,agent会做出动作(action)A,然后这个动作作用在environment上,于是agent可以接收到一个奖励Rt+1,并且agent就会到达新的状态。所以,其实agent与environment之间的交互就是产生了一个序列S0,A0,R1,S1,A1,R2,...。
深度强化学习(Deep Reinforcement Learning,DRL):深度强化学习是深度学习和强化学习的结合。强化学习定义了优化的目标,深度学习给出了运行机制:表征问题的方式以及解决问题的方式。将强化学习和深度学习结合在一起,寻求一个能够解决任何人类级别任务的代理,得到了能够解决很多复杂问题的一种能力。
人工智能中经常需要对模型进行训练,一个足够有效的模型需要海量数据进行训练,而在一些较敏感的场景下(例如不同医院的患者数据,或不同车辆的驾驶数据等等),单个设备(客户端)可能没有足够数量和质量的数据来学习一个更加健壮的模型,多个设备(多个客户端)共同训练又可能造成隐私泄露的问题。常用于该解决隐私泄露的共同训练场景,在联邦学习中,每个客户端利用本地数据训练局部模型或参数更新,然后仅传递模型参数到服务器,在云端将所有参数聚合,从而在不需要交换数据的情况下获得一个联合的模型。但是实际应用中,异质性导致高达9.2%的准确率下降和2.32倍训练时间的延长,并损害了公平性。
当前的横向联邦学习原理:客户端通过迭代聚合来自不同客户端的本地模型,通过特定的聚合方式训练共享的全局模型。在每轮迭代中,服务器随机选择一定数量的客户端传输全局模型参数,参与训练的客户端用下载的全局模型训练后上传本地训练模型参数并在服务器聚合新的全局模型;基于上述过程,可以概括为一个C分类问题,如下式(1)所示:
其中,fi(w)=Ex~Pi[fi(w,x)] ,是第i个客户端的本地损失函数。将不同客户端的损失函数聚合,假设有N个客户端对数据进行了分区,其中Di是客户端i上数据点的索引集的数目,第i个客户端在全局模型的聚合权重定义为式(2)所示:
由此,可以简单地认为一个客户端对于全局模型的影响由其样本大小决定。训练是在所有样本的联合上进行的均匀分布,其中所有样本都是一致加权的。
上述聚合方案虽然考虑到了客户端数据量的影响,但是由于真实应用中联邦学习中不同客户端上数据大小和分布的异质性,在大型网络中单纯以最小化总损失为目标,可能会不成比例地提升或降低模型在一些客户端上性能,即,导致设备的结果均匀性损失。如,虽然联邦平均精度很高,但网络中单个设备的精度无法得到保障。此时,一些情况下表现较差的客户端更偏向于退出联盟。此外,由于客户端上数据分布的不同,一些数据质量较高的客户端可能具有比其他客户端更重要的预测能力,换句话说,全局可能会过度依赖于个别的客户端的训练模型。因此当我们开始关注公平性、避免这种影响时,总体模型的收敛速度和预测精确度都可能会受到影响。
基于此,本申请实施例提供了一种基于联邦学习的模型训练方法、装置、系统、存储介质,旨在通过训练深度强化学习智能体学习一个更公平的聚合策略,保障全局模型在不损失性能的基础上实现更公平的表现。
本申请实施例可以基于人工智能技术对相关的数据进行获取和处理。其中,人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。
人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、机器人技术、生物识别技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
本申请实施例提供的基于联邦学习的模型训练方法,涉及人工智能技术领域。本申请实施例提供的基于联邦学习的模型训练方法可应用于终端中,也可应用于服务器端中,还可以是运行于终端或服务器端中的软件。在一些实施例中,终端可以是智能手机、平板电脑、笔记本电脑、台式计算机等;服务器端可以配置成独立的物理服务器,也可以配置成多个物理服务器构成的服务器集群或者分布式系统,还可以配置成提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、CDN以及大数据和人工智能平台等基础云计算服务的云服务器;软件可以是实现基于联邦学习的模型训练方法的应用等,但并不局限于以上形式。
本申请可用于众多通用或专用的计算机系统环境或配置中。例如:个人计算机、服务器计算机、手持设备或便携式设备、平板型设备、多处理器系统、基于微处理器的系统、置顶盒、可编程的消费电子设备、网络PC、小型计算机、大型计算机、包括以上任何系统或设备的分布式计算环境等等。本申请可以在由计算机执行的计算机可执行指令的一般上下文中描述,例如程序模块。一般地,程序模块包括执行特定任务或实现特定抽象数据类型的例程、程序、对象、组件、数据结构等等。也可以在分布式计算环境中实践本申请,在这些分布式计算环境中,由通过通信网络而被连接的远程处理设备来执行任务。在分布式计算环境中,程序模块可以位于包括存储设备在内的本地和远程计算机存储介质中。
需要说明的是,在本申请的各个具体实施方式中,当涉及到需要根据用户端的相关数据(本地模型参数、用户信息、用户行为数据、用户历史数据以及用户位置信息等),与用户身份或特性相关的数据进行相关处理时,都会先获得用户的许可或者同意,而且,对这些数据的收集、使用和处理等,都会遵守相关国家和地区的相关法律法规和标准。此外,当本申请实施例需要获取用户的敏感个人信息时,会通过弹窗或者跳转到确认页面等方式获得用户的单独许可或者单独同意,在明确获得用户的单独许可或者单独同意之后,再获取用于使本申请实施例能够正常运行的必要的用户相关数据。
本申请实施例提供的基于联邦学习的模型训练方法、装置、系统、存储介质,具体通过如下实施例进行说明,首先描述本申请实施例中的基于联邦学习的模型训练方法。
图1是本申请实施例提供的基于联邦学习的模型训练方法的一个可选的流程图,图1中所示的方法应用于服务器端,图1所示的方法可以包括但不限于包括步骤101至步骤106。
步骤101,获取客户端的设备标记信息;其中,客户端包括本地模型;
步骤102,通过标记信息从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,第一预设数量为K1,K1<N;
步骤103,将预设的原始全局模型参数传输给训练参与方;
步骤104,获取训练参与方发送的精度值和本地模型参数;其中,精度值由训练参与方对原始全局模型参数进行精度验证得到,本地模型参数由训练参与方通过对本地模型进行训练得到,本地模型是训练参与方本地的模型;
步骤105,根据精度值和本地模型参数更新原始全局模型参数,得到目标全局模型参数;
步骤106,将目标全局模型参数发送给更新参与方;其中,目标全局模型参数用于对更新参与方的本地模型进行更新,更新参与方是从N个客户端中筛选出第二预设数量的客户端,第二预设数量为K2,K2<N。
本申请实施例的步骤101至步骤106,通过客户端的标记信息从N个客户端中筛选出K1个客户端作为训练参与方,并将服务器端自身预设的原始全局模型参数传输给训练参与方,从而训练参与方对原始全局模型参数进行精度验证得到精度值、并通过客户端端自身的本地模型进行训练得到本地模型参数;训练参与方将根据精度值和本地模型参数发送给服务器端,从而服务器端根据精度值和本地模型参数更新原始全局模型参数,以得到目标全局模型参数,并将目标全局模型参数发送给更新参与方(该更新参与方是从N个客户端中筛选出的K2个客户端)进行更新,从而实现联邦学习的模型训练,实现动态地调整参数,并基于强化学习的方法,通过训练强化学习智能体学习一个更公平的聚合策略,保障全局模型在不损失性能的基础上实现更公平的表现。
在一些实施例的步骤101中,将一个可用设备作为一个客户端,每一客户端的设备标记信息是每一可用设备的设备标记信息,设备标记信息用于对客户端进行标识,每一客户端具有唯一的设备标记信息,一个可用设备和一个客户端是一一对应的,从而通过设备标记信息可以区分不同的客户端。其中,可用设备可以是手机等移动电子设备。
在执行图1所示的基于联邦学习的模型训练方法,客户端需先将设备标记信息传输给服务器端,从而客户端可以签入服务器端,在客户端签入服务器端后,服务器端可以执行图1所示的基于联邦学习的模型训练方法。
在一些实施例的步骤102中,通过上述的设备标记信息,可以从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,第一预设数量K1,可以根据实际需要进行设置,设置规则可以包括:K1=N*C1,其中,C1为预设比例;由于实际应用场景中,N的取值较大,而训练参与方的数量将小于N,因此C1是一个小于1的数值。此外,在一些实施例中,K个训练参与方可以从N个客户端中随机选取,选出的K个训练参与方,是通过设备标记信息进行唯一标识。其中,N个客户端中的任意两个客户端的数据大小和数据分布可以相同,也可以不相同,且至少有两个客户端的数据大小和数据分布不完全相同;也就是说,N个客户端中的所述客户端的数据大小和数据分布不能完全相同。
步骤103中,在一些实施例,原始全局模型参数可以是通过调用预先训练好的模型参数得到;在另一些实施例,原始全局模型参数也可以是随机初始化的模型参数得到。该原始全局模型参数可以表示为Winit。在一些应用场景,若原始全局模型参数Winit是随机初始化的模型参数,则该原始全局模型参数Winit可以是:均值为0.5且方差为1的高斯分布中的随机取值。
在一些实施例,步骤104的精度值是由训练参与方对原始全局模型参数进行精度验证得到,步骤104的本地模型参数是由训练参与方通过对训练参与方的本地模型进行训练得到。在一具体应用场景中,精度值是验证集上预测正确的样本数除以全部样本数得到,即可以表示为:精度值acck=预测正确的样本数/全部样本数;精度值acck表征的是验证集上的准确率;本申请实施例,以精度值acck表征准确率为例进行说明。在另一具体应用场景中,精度值是验证集上预测正确的正例样本数除以全部正例样本数得到,即可以表示为:精度值acck=预测正确的正例样本数/全部正例样本数;精度值acck表征的是验证集上的精确率;精度值acck表征精确率的实现原理与精度值acck表征准确率的实现原理类似,在此不再赘述。
请参阅图2,在一些实施例,步骤105可以包括但不限于包括步骤201至步骤202:
步骤201,根据精度值计算聚合权重;
步骤202,根据聚合权重和本地模型参数更新原始全局模型参数,得到目标全局模型参数。
具体地,聚合权重表示为Pk,来源于训练参与方自身的本地模型参数表示为Wk,更新得到的目标全局模型参数Wglobal=sum(Pk*Wk),也即:在t轮迭代中,目标全局模型参数为Wglobal=sum(Pk*Wk)。
具体地,请参阅图3,在一些实施例,步骤201可以包括但不限于包括步骤301至步骤305:
在一些实施例,根据精度值计算聚合权重,包括:
步骤301,将精度值输入至预设的神经网络;其中,预设的神经网络为强化学习网络,强化学习网络包括智能体,将本地模型参数作为强化学习网络的状态;
步骤302,通过强化学习网络对精度值进行奖励函数计算,得到奖励函数;
步骤303,智能体对本地模型参数进行强化学习,得到策略分布;
步骤304,智能体根据策略分布计算出训练参与方的历史权重,并根据奖励函数对历史权重进行奖励计算,得到奖励值;其中,历史权重作为强化学习网络在t-1轮迭代中的动作;
步骤305,智能体根据历史权重和奖励值计算聚合权重;其中,聚合权重作为强化学习网络在t轮迭代中的动作。
在一应用场景中,步骤201的实现是通过强化学习原理。在步骤301中,将精度值输入至强化学习网络,通过强化学习网络进行处理,输出聚合权重。在该强化学习网络中,t轮迭代中的K个训练参与方的本地模型参数Wk作为状态statet,statet={Wk t,k∈K};t轮迭代中的聚合权重Pk t作为动作actiont,表示为:actiont={Pk t,k∈K}。
请参阅图3,在一些实施例,在一些实施例的步骤302可以包括但不限于包括步骤401至步骤403:
步骤401,根据精度值计算平均精度;
步骤402,根据精度值和平均精度计算基尼系数;
步骤403,根据平均精度和基尼系数计算奖励函数。
在一些实施例的步骤401中,通过预设的强化学习网络对所获得的所有精度值acck计算平均值,得到平均精度μ。
在一些实施例的步骤402中,预设的强化学习网络根据精度值acck和平均精度μ计
算基尼系数Gini:。即该基尼系数Gini计算方式为:将所有的精度值进
行两两比较得到的精度值求平均再除以两倍的平均精度μ。
在一些实施例的步骤403中,预设的强化学习网络计算t轮迭代中奖励的公式为:奖励rewardt=-μt*log(Ginit)。通过该奖励rewardt、状态statet、动作actiont,训练智能体尽可能快地收敛到目标精度和公平性。由于强化学习是一个连续的随机过程。通过上述的改进,整个联邦学习训练的过程可以视作一个马尔科夫过程,即上一步的状态(本地模型参数)会决定了下一步的各种可能走向的概率(权重)。奖励函数作为强化学习网络训练目标(越大越好)用于指导网络改进。
在一些实施例的步骤303中,强化学习网络中的智能体通过对本地模型参数进行强化学习,可以得到相应的策略分布,该策略分布包括了每一轮迭代中所做出的决策:actiont。在一具有应用场景中,该策略分布中的策略参数ϕ可以表示为:ϕt+1←ϕt+β*rewardt▽ϕlogπ(statet,actiont)。其中,是β相关系数,是一个数值。
在一些实施例的步骤304中,智能体可以根据策略分布π确定训练参与方在t-1轮迭代中的动作actiont-1(即历史权重),并通过奖励函数rewardt=-μt*log(Ginit)对相应的动作actiont-1进行评分,以得到相应的奖励值,该奖励值可以表征对动作actiont-1的奖励或者惩罚,即通过奖励函数对策略分布π中的策略进行评价,若所做策略是好策略,则对该策略(动作action)进行奖励,若所做策略是不好的策略,则对该策略(动作action)进行惩罚。
在一些实施例的步骤305中,智能体根据t-1轮迭代中的动作actiont-1和对动作actiont-1的奖励值进行计算,得到t轮迭代中K个客户端的聚合权重Pk t={Pk t,k∈K}。
在申请实施例t轮迭代中,K个训练参与方的本地模型参数wk作为状态statet,statet={Wk t,k∈K},动作actiont表示为:actiont={Pk t,k∈K};通过奖励函数rewardt对动作actiont作出奖励;通过对上一步的状态Wk t-1决定下一步的各种走向的概率(聚合权重)。本申请实施例,采用强化学习的方式,通过最大化奖励函数来寻求最优分配策略。同时,强化学习鼓励随机的探索,因此将不同本地模型在全局模型所占比重的分配问题建模成为一个强化学习问题,探索最优聚合策略。
上述实施例中,将精度值作为强化学习网络的输入,强化学习网络进行强化学习处理后输出聚合权重,该聚合权重是一个长度为K维的向量,在该向量中,每个元素就是对应位置的客户端的权重。
本申请实施例,应用强化学习的方法,解决联邦学习训练机制下的公平性问题。本申请实施提出的方法原理可以作为一个通用插件,应用于各类不涉及聚合方式改动的联邦学习算法当中,实现公平的联邦学习算法,让传统的联邦学习算法在不损失精度和收敛效率的基础上,取得更加公平的表现,提升实际情况下联邦学习算法的可用性,保障用户的参与意愿,提升用户参与的积极性。
传统的联邦学习算法仅仅是简单地以各客户端的样本数量作为聚合权重;与传统的联邦学习算法相比,而本申请实施例使用的是一个强化学习网络,以模型在各客户端上的精度值作为输入,强化学习网络的输出作为该客户端的聚合权重。这样可以保证所有客户端的验证精度相差不会太大,增强公平性和鲁棒性。
在一些实施例的步骤106中,更新参与方是从N个客户端中随机筛选出的K2个客户端。在一具体应用场景中,K1=K2;可以理解的是,在一实施例,上述K1个训练参与方与该K2个更新参与方中可以存在共同的客户端,在另一实施例中,K1个训练参与方与该K2个更新参与方也可以完全不相同。
本申请实施例,实现联邦学习的模型训练,以实现动态地调整参数,并基于强化学习的方法,通过训练深度强化学习智能体学习一个更公平的聚合策略,保障全局模型在不损失性能的基础上实现更公平的表现。
请参阅图5,在一些实施例,图5示意了基于联邦学习的模型训练方法的整体架构
图,服务器端与K个客户端之间进行联邦学习,服务器端获取客户端的设备标记信息后,根
据设备标记信息从N个客户端中筛选出K个客户端作为K1个训练参与方,其中K=K1;服务器
端将自身预设的原始全局模型参数传输给K1个训练参与方(该K个客户端);在t轮迭代中,
K1个训练参与方对原始全局模型参数进行精度验证得到精度值acck t,K1个训练参与方通过
对本地模型进行训练得到本地模型参数wk,K1个训练参与方将精度值acck t和本地模型参数
Wk发送给服务器端。服务器端根据精度值acck t计算平均精度μ,根据精度值acck t和平均精度
μt计算基尼系数Ginit,并根据基尼系数Ginit计算聚合权重Pk,最后根据聚合权重Pk和本地
模型参数Wk更新原始全局模型参数,得到目标全局模型参数Wglobal,;
并将目标全局模型参数Wglobal发送给K2个更新参与方,其中该K2个更新参与方是从N个客户
端中随机筛选出的K个客户端K2=K=K1。在该应用场景中,服务器端基于强化学习网络的原
理更新原始全局模型参数,其中,在该强化学习网络中,K个训练参与方的本地模型参数Wk
作为状态statet,statet={Wk t,k∈K};根据平均精度μ和基尼系数Gini计算奖励rewardt,
rewardt=-μt*log(Ginit);将聚合权重Pk作为动作actiont,actiont={Pk t,k∈K};其中,t表示
为第t轮迭代。为了简化架构图,本实施例中,将K1个训练参与方与该K2个更新参与方统称
为K个客户端,也就是图5中的K个客户端可以指该K1个训练参与方,也可以指该K2个更新参
与方;可以理解的是,参照上述实施例的描述,该K1个训练参与方与该K2个更新参与方中可
以存在共同的客户端,K1个训练参与方与该K2个更新参与方也可以完全不相同。
在本申请实施例中,服务器端希望尽可能地增加全局准确率,同时各客户端希望最终模型表现与其他参与方的差异不大。因此,这个问题可以被描述为一个总成本分配的博弈论问题:个体动作者的自利目标(稳定性)和降低总成本的总体目标(最优性)。由于隐私问题需要得到保障,在模型训练中,无法直接对每个客户端上的原始数据进行访问,因此无法分析客户端上的数据分布。然而,客户端上的训练样本分布与基于这些样本训练的模型参数之间存在隐式联系。本申请实施例提出的基于联邦学习的模型训练方法的框架,可以被视为联邦学习算法的附加插件;本申请实施例是基于策略梯度的强化学习算法,根据客户端的本地模型参数,旨在通过在每一轮聚合中给参与更新的客户端分配不同的聚合权重方式,对于上述的问题做出权衡,可以保证所有客户端的验证精度相差不会太大,增强公平性和鲁棒性。
图6是本申请实施例提供的基于联邦学习的模型训练方法的一个可选的流程图,图6中所示的方法应用于客户端,图6中所示的方法可以包括但不限于包括步骤501至步骤506。
步骤501,向服务器端发送客户端自身的设备标记信息;其中,客户端包括本地模型,标记信息用于服务器端从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,第一预设数量为K1,K1<N;
步骤502,训练参与方接收服务器端传输的原始全局模型参数;
步骤503,训练参与方对原始全局模型参数进行精度验证,得到精度值;
步骤504,训练参与方对本地模型进行训练,得到本地模型参数;
步骤505,训练参与方将精度值和本地模型参数发送给服务器端;
步骤506,更新参与方接收服务器端根据精度值和本地模型参数对服务器端的原始全局模型参数进行更新得到的目标全局模型参数,并根据接收到的目标全局模型参数更新自身的本地模型参数;其中,更新参与方是服务器端从N个客户端中筛选出第二预设数量的客户端得到,第二预设数量为K2,K2<N。
在一些实施例的步骤501中,至少有N个客户端参与模型训练;将一个可用设备作为一个客户端,每一客户端具有一个唯一的设备标记信息,设备标记信息用于对客户端进行标识,一个可用设备和一个客户端是一一对应的,从而通过设备标记信息可以区分不同的客户端。每一客户端包括自身的本地模型。其中,可用设备可以是手机等移动电子设备。客户端需先将自身的设备标记信息传输给服务器端,从而客户端可以签入服务器端,在客户端签入服务器端后,服务器端根据设备标记信息从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,第一预设数量为K1,K1<N。
在一些实施例中,在执行步骤502之前,服务器端会将自身的原始全局模型参数传输给该K1个训练参与方,从而该K1个训练参与方中的每一个训练参与方会接收服务器端传输的原始全局模型参数。如上述实施例所述,服务器端自身的原始全局模型参数Winit可以是通过调用预先训练好的模型参数得到,也可以是随机初始化的模型参数得到,在一些应用场景,以原始全局模型参数Winit是随机初始化的模型参数为例进行说明,则该原始全局模型参数Winit可以是:均值为0.5且方差为1的高斯分布中的随机取值。
在一些实施例的步骤503可以包括但不限于包括:
训练参与方通过预设的验证集对原始全局模型参数进行测试,得到预测正确的正确样本数;其中,验证集包括样本总数;
将正确样本数除以样本总数,得到精度值。
具体地,精度值是指验证精度值,是模型在验证集上的测试准确率;K1个训练参与方对原始全局模型参数进行精度的验证过程可以参考测试过程,即为计算模型预测的准确度,但其结果可以用来对训练的模型进行改进(例如指导调参);本申请实施例的精度值的验证原理可以参照常规的精度验证原理,本申请实施例不做限定;此外,应理解,验证集可以使用二八原则得到,即随机抽取数据集的20%作为验证集,数据集的80%作为测试集;此外,也可以使用其他划分方式获得验证集,本申请实施例不做限定。精度值是验证集上测试正确的正确样本数除以样本总数得到,即可以表示为:精度值acck=正确样本数/样本总数;精度值acck表征的是验证集上的准确率;本申请实施例,以精度值acck表征准确率为例进行说明。在另一具体应用场景中,精度值是验证集上预测正确的正例样本数除以全部正例样本数得到,即可以表示为:精度值acck=预测正确的正例样本数/全部正例样本数;精度值acck表征的是验证集上的精确率;精度值acck表征精确率的实现原理与精度值acck表征准确率的实现原理类似,在此不再赘述。
服务器端先选中一定比例的客户端(K1个训练参与方)参与训练,在t轮迭代中,K1
个训练参与方用上一轮的本地模型参数Wk t-1进行本地训练后,上传更新后的精度值acck t和
更新后的本地模型参数Wk t至服务器端,服务器端根据精度值acck t计算平均精度μt,根据精
度值acck t和平均精度μt计算基尼系数Ginit,并根据基尼系数Ginit计算聚合权重Pk,最后根
据聚合权重Pk和本地模型参数Wk更新原始全局模型参数,得到目标全局模型参数Wglobal,。服务器端再选中一定比例的客户端(K2个更新参与方)参与更新。
在一些实施例的步骤504中,K1个训练参与方对本地模型进行训练,K1个训练参与方中的每一个训练参与方根据自身的本地数据进行训练得到本地模型参数wk。
在一些实施例的步骤505中,K1个训练参与方分别将各自的精度值acck t和本地模型参数Wk发送给服务器端。
在一些实施例,执行步骤506前,服务器端根据K1个训练参与方分别发送的精度值
acck计算平均精度μ,并根据精度值acck和平均精度μ计算基尼系数Gini,再根据基尼系数
Gini计算聚合权重Pk,并根据聚合权重Pk和本地模型参数Wk对服务器端自身的原始全局模
型参数进行更新得到目标全局模型参数;并且,服务器端从N个客户端
中筛选出K2个客户端作为更新训练方;从而K2个更新训练方可以执行步骤506,接收服务器
端发送的目标全局模型参数,并根据接收到的目标全局模型参数更新自身的本地模型参
数。
具体地,K个客户端进行联邦学习训练的原理为:K1个训练参与方中的每一个训练
参与方根据自身的本地数据进行训练得到本地模型参数Wk,并得到更新后的本地模型,每
一个训练参与方根据自身的验证集对更新后的本地模型进行精确度计算,得到精度值
acck,并将该精度值acck输入至图4所示的强化学习网络中,该强化学习网络根据该精度值
acck输出聚合权重Pk,通过加权平均,对原始全局模型参数进行更新,即W=sum(Pk*Wk),重复
该过程,直到服务器端的全局模型收敛,即得到目标全局模型参数。在
该强化学习网络中,K个训练参与方的本地模型参数wk作为状态statet,statet={Wk t,k∈K};
根据平均精度μ和基尼系数Gini计算奖励rewardt,rewardt=-μt*log(Ginit);将聚合权重Pk
作为动作actiont,actiont={Pk t,k∈K};其中,t表示为第t轮迭代。
在该联邦学习训练中,由于每轮随机抽取一定比例的客户端(K个)参与更新,因此权重分配中存在不可微的优化瓶颈。本申请实施例中,针对不可微的优化瓶颈,采用强化学习的方式,通过最大化奖励函数来寻求最优分配策略。同时,强化学习鼓励随机的探索,因此将不同本地模型在全局模型所占比重的分配问题建模成为一个强化学习问题,探索最优聚合策略。
因此,本申请实施例中的联邦学习过程可以建模成一个马尔科夫决策过程(MDP),状态statet由每轮中每个客户端的模型参数表示statet={Wk t,k∈K},给定当前状态,强化学习智能体会学习到一个策略分布,根据策略计算每个客户端对应的聚合权重Pk,奖励rewardt=-μt*log(Ginit),动作actiont={Pk t,k∈K}从而更新全局模型,得到目标全局模型参数,并将目标全局模型参数传输给对应的客户端,客户端在本地验证集上验证精度,得到精度值acck,并计算平均精度μt和基尼系数Ginit,得到强化学习智能体的奖励函数rewardt=-μt*log(Ginit),目的是训练智能体尽可能快地收敛到目标精度和公平性。此外,本申请实施例的强化学习是深度强化学习(DRL)。
由于强化学习是一个连续的随机过程。通过上述的改进,整个联邦学习训练的过程可以视作一个马尔科夫过程,即上一步的状态(模型在各客户端上的本地模型参数Wk t-1)会决定了下一步的各种可能走向的概率(权重)。奖励函数作为强化学习网络训练目标(越大越好)用于指导网络改进。
本申请实施例,应用强化学习的方法,解决联邦学习训练机制下的公平性问题。本申请实施提出的方法原理可以作为一个通用插件,应用于各类不涉及聚合方式改动的联邦学习算法当中,实现公平的联邦学习算法,让传统的联邦学习算法在不损失精度和收敛效率的基础上,取得更加公平的表现,提升实际情况下联邦学习算法的可用性,保障用户的参与意愿,提升用户参与的积极性。
传统的联邦学习算法仅仅是简单地以各客户端的样本数量作为聚合权重;与传统的联邦学习算法相比,而本申请实施例使用的是一个强化学习网络,以模型在各客户端上的精度值作为输入,强化学习网络的输出作为该客户端的聚合权重。这样可以保证所有客户端的验证精度相差不会太大,增强公平性和鲁棒性。
请参阅图7,本申请实施例还提供一种基于联邦学习的模型训练装置,应用于服务器端,可以实现上述应用于服务器端的基于联邦学习的模型训练方法,该基于联邦学习的模型训练装置包括:
设备标记获取模块,用于获取客户端的设备标记信息;其中,客户端包括本地模型;
客户端筛选模块,用于通过标记信息从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,第一预设数量为K1,K1<N;
模型参数传输模块,用于将预设的原始全局模型参数传输给训练参与方;
数据获取模块,用于获取训练参与方发送的精度值和本地模型参数;其中,精度值由训练参与方对原始全局模型参数进行精度验证得到,本地模型参数由训练参与方通过对本地模型进行训练得到,本地模型是训练参与方本地的模型;
更新模块,用于根据精度值和本地模型参数更新原始全局模型参数,得到目标全局模型参数;
参数发送模块,用于将目标全局模型参数发送给更新参与方;其中,目标全局模型参数用于对更新参与方的本地模型进行更新,更新参与方是从N个客户端中筛选出第二预设数量的客户端,第二预设数量为K2,K2<N。
该应用于服务器端的基于联邦学习的模型训练装置的具体实施方式与上述应用于服务器端的基于联邦学习的模型训练方法的具体实施例基本相同,在此不再赘述。
请参阅图8,本申请实施例还提供一种基于联邦学习的模型训练装置,应用于客户端,可以实现上述应用于客户端的基于联邦学习的模型训练方法,该基于联邦学习的模型训练装置包括:
设备标记发送模块,向服务器端发送客户端自身的设备标记信息;其中,客户端包括本地模型,标记信息用于服务器端从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,第一预设数量为K1,K1<N;
模型参数接收模块,模型参数接收模块用于使述训练参与方接收服务器端传输的原始全局模型参数;
精度验证模块,精度验证模块用于使训练参与方用于对原始全局模型参数进行精度验证,得到精度值;
本地模型训练模块,本地模型训练模块用于使训练参与方对本地模型进行训练,得到本地模型参数;
数据发送模块,数据发送模块用于使训练参与方将精度值和本地模型参数发送给服务器端;
调参模块,调参模块用于使更新参与方接收服务器端发送的目标全局模型参数,并用于根据接收到的目标全局模型参数更新自身的本地模型参数;其中,目标全局模型参数是由服务器端根据精度值和本地模型参数对服务器的原始全局模型参数进行更新得到,更新参与方是服务器端从N个客户端中筛选出第二预设数量的客户端得到,第二预设数量为K2,K2<N。
该应用于客户端的基于联邦学习的模型训练装置的具体实施方式与上述应用于客户端的基于联邦学习的模型训练方法的具体实施例基本相同,在此不再赘述。
本申请实施例还提供了一种基于联邦学习的模型训练系统,包括:服务器端、客户端;其中,服务器端用于实现上述应用于服务器端的基于联邦学习的模型训练方法,客户端用于实现上述应用于客户端的基于联邦学习的模型训练方法。
本申请实施例还提供了另一实施例的一种基于联邦学习的模型训练系统,该基于联邦学习的模型训练系统可以是硬件形式的电子设备,包括:存储器、处理器、存储在存储器上并可在处理器上运行的程序以及用于实现处理器和存储器之间的连接通信的数据总线,程序被处理器执行时实现上述基于联邦学习的模型训练方法。该基于联邦学习的模型训练系统可以为包括平板电脑、服务器等任意智能终端。
请参阅图9,图9示意了另一实施例的基于联邦学习的模型训练系统的硬件结构,该基于联邦学习的模型训练系统包括:
处理器801,可以采用通用的CPU(CentralProcessingUnit,中央处理器)、微处理器、应用专用集成电路(ApplicationSpecificIntegratedCircuit,ASIC)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本申请实施例所提供的技术方案;
存储器802,可以采用只读存储器(ReadOnlyMemory,ROM)、静态存储设备、动态存储设备或者随机存取存储器(RandomAccessMemory,RAM)等形式实现。存储器802可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器802中,并由处理器801来调用执行本申请实施例的基于联邦学习的模型训练方法;
输入/输出接口803,用于实现信息输入及输出;
通信接口804,用于实现本设备与其他设备的通信交互,可以通过有线方式(例如USB、网线等)实现通信,也可以通过无线方式(例如移动网络、WIFI、蓝牙等)实现通信;
总线805,在设备的各个组件(例如处理器801、存储器802、输入/输出接口803和通信接口804)之间传输信息;
其中处理器801、存储器802、输入/输出接口803和通信接口804通过总线805实现彼此之间在设备内部的通信连接。
本申请实施例还提供了一种存储介质,存储介质为计算机可读存储介质,用于计算机可读存储,存储介质存储有一个或者多个程序,一个或者多个程序可被一个或者多个处理器执行,以实现上述基于联邦学习的模型训练方法。
存储器作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序以及非暂态性计算机可执行程序。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施方式中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至该处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
本申请实施例提供的基于联邦学习的模型训练方法、装置、系统、存储介质,通过客户端的标记信息从N个客户端中筛选出K1个客户端作为训练参与方,并将服务器端自身预设的原始全局模型参数传输给训练参与方,从而训练参与方对原始全局模型参数进行精度验证得到精度值、并通过客户端端自身的本地模型进行训练得到本地模型参数;训练参与方将根据精度值和本地模型参数发送给服务器端,从而服务器端根据精度值和本地模型参数更新原始全局模型参数,以得到目标全局模型参数,并将目标全局模型参数发送给更新参与方(该更新参与方是从N个客户端中筛选出的K2个客户端)进行更新,从而实现联邦学习的模型训练,实现动态地调整参数,并基于强化学习的方法,通过训练强化学习智能体学习一个更公平的聚合策略,保障全局模型在不损失性能的基础上实现更公平的表现。与传统的仅简单地以各客户端的样本数量作为聚合权重联的邦学习算法相比,本申请实施例使用的是一个强化学习网络,以模型在各客户端上的精度值作为输入,强化学习网络的输出作为该客户端的聚合权重。这样可以保证所有客户端的验证精度相差不会太大,增强公平性和鲁棒性。本申请实施例,应用强化学习的方法,解决联邦学习训练机制下的公平性问题;实现公平的联邦学习算法,让传统的联邦学习算法在不损失精度和收敛效率的基础上,取得更加公平的表现,提升实际情况下联邦学习算法的可用性,保障用户的参与意愿,提升用户参与的积极性。
本申请实施例描述的实施例是为了更加清楚的说明本申请实施例的技术方案,并不构成对于本申请实施例提供的技术方案的限定,本领域技术人员可知,随着技术的演变和新应用场景的出现,本申请实施例提供的技术方案对于类似的技术问题,同样适用。
本领域技术人员可以理解的是,图1-6中示出的技术方案并不构成对本申请实施例的限定,可以包括比图示更多或更少的步骤,或者组合某些步骤,或者不同的步骤。
以上所描述的装置实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。
本领域普通技术人员可以理解,上文中所公开方法中的全部或某些步骤、系统、设备中的功能模块/单元可以被实施为软件、固件、硬件及其适当的组合。
本申请的说明书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
应当理解,在本申请中,“至少一个(项)”是指一个或者多个,“多个”是指两个或两个以上。“和/或”,用于描述关联对象的关联关系,表示可以存在三种关系,例如,“A和/或B”可以表示:只存在A,只存在B以及同时存在A和B三种情况,其中A,B可以是单数或者复数。字符“/”一般表示前后关联对象是一种“或”的关系。“以下至少一项(个)”或其类似表达,是指这些项中的任意组合,包括单项(个)或复数项(个)的任意组合。例如,a,b或c中的至少一项(个),可以表示:a,b,c,“a和b”,“a和c”,“b和c”,或“a和b和c”,其中a,b,c可以是单个,也可以是多个。
在本申请所提供的几个实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,上述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
上述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括多指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例的方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-Only Memory,简称ROM)、随机存取存储器(Random Access Memory,简称RAM)、磁碟或者光盘等各种可以存储程序的介质。
以上参照附图说明了本申请实施例的优选实施例,并非因此局限本申请实施例的权利范围。本领域技术人员不脱离本申请实施例的范围和实质内所作的任何修改、等同替换和改进,均应在本申请实施例的权利范围之内。
Claims (10)
1.一种基于联邦学习的模型训练方法,应用于服务器端,其特征在于,所述基于联邦学习的模型训练方法包括:
获取客户端的设备标记信息;其中,所述客户端包括本地模型;
通过标记信息从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,所述第一预设数量为K1,K1<N;
将预设的原始全局模型参数传输给所述训练参与方;
获取所述训练参与方发送的精度值和本地模型参数;其中,所述精度值由所述训练参与方对所述原始全局模型参数进行精度验证得到,所述本地模型参数由所述训练参与方通过对本地模型进行训练得到,所述本地模型是所述训练参与方本地的模型;
根据所述精度值和所述本地模型参数更新原始全局模型参数,得到目标全局模型参数;
将所述目标全局模型参数发送给更新参与方;其中,所述目标全局模型参数用于对所述更新参与方的本地模型进行更新,所述更新参与方是从N个所述客户端中筛选出第二预设数量的客户端,第二预设数量为K2,K2<N。
2.根据权利要求1所述的基于联邦学习的模型训练方法,其特征在于,所述根据所述精度值和所述本地模型参数更新原始全局模型参数,得到目标全局模型参数,包括:
根据所述精度值计算聚合权重;
根据所述聚合权重和所述本地模型参数更新原始全局模型参数,得到所述目标全局模型参数。
3.根据权利要求2所述的基于联邦学习的模型训练方法,其特征在于,所述根据所述精度值计算聚合权重,包括:
将所述精度值输入至预设的神经网络;其中,所述预设的神经网络为强化学习网络,所述强化学习网络包括智能体,将所述本地模型参数作为所述强化学习网络的状态;
通过所述强化学习网络对所述精度值进行奖励函数计算,得到奖励函数;
所述智能体对所述本地模型参数进行强化学习,得到策略分布;
所述智能体根据所述策略分布计算出所述训练参与方的历史权重,并根据所述奖励函数对所述历史权重进行奖励计算,得到奖励值;其中,所述历史权重作为所述强化学习网络在t-1轮迭代中的动作;
所述智能体根据所述历史权重和所述奖励值计算所述聚合权重;其中,所述聚合权重作为所述强化学习网络在t轮迭代中的动作。
4.根据权利要求3所述的基于联邦学习的模型训练方法,其特征在于,所述通过所述强化学习网络对所述精度值进行奖励函数计算,得到奖励函数,包括:
根据所述精度值计算平均精度;
根据所述精度值和所述平均精度计算基尼系数;
根据所述平均精度和所述基尼系数计算所述奖励函数。
5.一种基于联邦学习的模型训练方法,应用于客户端,其特征在于,所述基于联邦学习的模型训练方法包括:
向服务器端发送所述客户端自身的设备标记信息;其中,所述客户端包括本地模型,所述标记信息用于所述服务器端从N个所述客户端中筛选出第一预设数量的客户端作为训练参与方;其中,所述第一预设数量为K1,K1<N;
所述训练参与方接收服务器端传输的原始全局模型参数;
所述训练参与方对所述原始全局模型参数进行精度验证,得到精度值;
所述训练参与方对所述本地模型进行训练,得到本地模型参数;
所述训练参与方将所述精度值和所述本地模型参数发送给所述服务器端;
更新参与方接收所述服务器端根据所述精度值和所述本地模型参数对所述服务器的原始全局模型参数进行更新得到的目标全局模型参数,并根据接收到的目标全局模型参数更新自身的本地模型参数;其中,所述更新参与方是所述服务器端从N个所述客户端中筛选出第二预设数量的客户端得到,所述第二预设数量为K2,K2<N。
6.根据权利要求5所述的基于联邦学习的模型训练方法,其特征在于,所述训练参与方对所述原始全局模型参数进行精度验证,得到精度值,包括:
所述训练参与方通过预设的验证集对所述原始全局模型参数进行测试,得到预测正确的正确样本数;其中,所述验证集包括样本总数;
将所述正确样本数除以所述样本总数,得到所述精度值。
7.一种基于联邦学习的模型训练装置,应用于服务器端,其特征在于,所述基于联邦学习的模型训练装置包括:
设备标记获取模块,用于获取客户端的设备标记信息;其中,所述客户端包括本地模型;
客户端筛选模块,用于通过标记信息从N个客户端中筛选出第一预设数量的客户端作为训练参与方;其中,所述第一预设数量为K1,K1<N;
模型参数传输模块,用于将预设的原始全局模型参数传输给所述训练参与方;
数据获取模块,用于获取所述训练参与方发送的精度值和本地模型参数;其中,所述精度值由所述训练参与方对所述原始全局模型参数进行精度验证得到,所述本地模型参数由所述训练参与方通过对本地模型进行训练得到,所述本地模型是所述训练参与方本地的模型;
更新模块,用于根据所述精度值和所述本地模型参数更新原始全局模型参数,得到目标全局模型参数;
参数发送模块,用于将所述目标全局模型参数发送给更新参与方;其中,所述目标全局模型参数用于对所述更新参与方的本地模型进行更新,所述更新参与方是从N个所述客户端中筛选出第二预设数量的客户端,第二预设数量为K2,K2<N。
8.一种基于联邦学习的模型训练装置,应用于客户端,其特征在于,所述基于联邦学习的模型训练装置包括:
设备标记发送模块,向服务器端发送所述客户端自身的设备标记信息;其中,所述客户端包括本地模型,所述标记信息用于所述服务器端从N个所述客户端中筛选出第一预设数量的客户端作为训练参与方;其中,所述第一预设数量为K1,K1<N;
模型参数接收模块,所述模型参数接收模块用于使述训练参与方接收服务器端传输的原始全局模型参数;
精度验证模块,所述精度验证模块用于使所述训练参与方用于对所述原始全局模型参数进行精度验证,得到精度值;
本地模型训练模块,所述本地模型训练模块用于使所述训练参与方对所述本地模型进行训练,得到本地模型参数;
数据发送模块,所述数据发送模块用于使所述训练参与方将所述精度值和所述本地模型参数发送给所述服务器端;
调参模块,所述调参模块用于使更新参与方接收所述服务器端发送的目标全局模型参数,并用于根据接收到的目标全局模型参数更新自身的本地模型参数;其中,所述目标全局模型参数是由所述服务器端根据所述精度值和所述本地模型参数对所述服务器的原始全局模型参数进行更新得到,所述更新参与方是所述服务器端从N个所述客户端中筛选出第二预设数量的客户端得到,所述第二预设数量为K2,K2<N。
9.一种基于联邦学习的模型训练系统,其特征在于,所述基于联邦学习的模型训练系统包括存储器、处理器、存储在所述存储器上并可在所述处理器上运行的程序以及用于实现所述处理器和所述存储器之间的连接通信的数据总线,所述程序被所述处理器执行时实现:
如权利要求1至4任一项所述的方法的步骤;
或者,
如权利要求5至6任一项所述的方法的步骤。
10.一种存储介质,所述存储介质为计算机可读存储介质,用于计算机可读存储,其特征在于,所述存储介质存储有一个或者多个程序,所述一个或者多个程序可被一个或者多个处理器执行,以实现:
如权利要求1至4任一项所述的方法的步骤;
或者,
如权利要求5至6任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210706292.5A CN114819190A (zh) | 2022-06-21 | 2022-06-21 | 基于联邦学习的模型训练方法、装置、系统、存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210706292.5A CN114819190A (zh) | 2022-06-21 | 2022-06-21 | 基于联邦学习的模型训练方法、装置、系统、存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114819190A true CN114819190A (zh) | 2022-07-29 |
Family
ID=82521153
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210706292.5A Pending CN114819190A (zh) | 2022-06-21 | 2022-06-21 | 基于联邦学习的模型训练方法、装置、系统、存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114819190A (zh) |
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115145966A (zh) * | 2022-09-05 | 2022-10-04 | 山东省计算中心(国家超级计算济南中心) | 一种面向异构数据的对比联邦学习方法及系统 |
CN115829055A (zh) * | 2022-12-08 | 2023-03-21 | 深圳大学 | 联邦学习模型训练方法、装置、计算机设备及存储介质 |
CN115840965A (zh) * | 2022-12-27 | 2023-03-24 | 光谷技术有限公司 | 一种信息安全保障模型训练方法和系统 |
CN116415978A (zh) * | 2023-04-15 | 2023-07-11 | 广州芳禾数据有限公司 | 基于联邦学习和多方计算的文旅消费数据分析方法和装置 |
CN116935136A (zh) * | 2023-08-02 | 2023-10-24 | 深圳大学 | 处理类别不平衡医学图像分类问题的联邦学习方法 |
CN117173750A (zh) * | 2023-09-14 | 2023-12-05 | 中国民航大学 | 生物信息处理方法、电子设备及存储介质 |
WO2024041130A1 (zh) * | 2022-08-25 | 2024-02-29 | 华为技术有限公司 | 权益分配方法及装置 |
CN117992941A (zh) * | 2024-04-02 | 2024-05-07 | 广东创能科技股份有限公司 | 一种自助终端登录状态监控和主动安全保护的方法 |
-
2022
- 2022-06-21 CN CN202210706292.5A patent/CN114819190A/zh active Pending
Cited By (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2024041130A1 (zh) * | 2022-08-25 | 2024-02-29 | 华为技术有限公司 | 权益分配方法及装置 |
CN115145966A (zh) * | 2022-09-05 | 2022-10-04 | 山东省计算中心(国家超级计算济南中心) | 一种面向异构数据的对比联邦学习方法及系统 |
CN115145966B (zh) * | 2022-09-05 | 2022-11-11 | 山东省计算中心(国家超级计算济南中心) | 一种面向异构数据的对比联邦学习方法及系统 |
CN115829055A (zh) * | 2022-12-08 | 2023-03-21 | 深圳大学 | 联邦学习模型训练方法、装置、计算机设备及存储介质 |
CN115829055B (zh) * | 2022-12-08 | 2023-08-01 | 深圳大学 | 联邦学习模型训练方法、装置、计算机设备及存储介质 |
CN115840965A (zh) * | 2022-12-27 | 2023-03-24 | 光谷技术有限公司 | 一种信息安全保障模型训练方法和系统 |
CN115840965B (zh) * | 2022-12-27 | 2023-08-08 | 光谷技术有限公司 | 一种信息安全保障模型训练方法和系统 |
CN116415978A (zh) * | 2023-04-15 | 2023-07-11 | 广州芳禾数据有限公司 | 基于联邦学习和多方计算的文旅消费数据分析方法和装置 |
CN116415978B (zh) * | 2023-04-15 | 2024-03-22 | 广州芳禾数据有限公司 | 基于联邦学习和多方计算的文旅消费数据分析方法和装置 |
CN116935136A (zh) * | 2023-08-02 | 2023-10-24 | 深圳大学 | 处理类别不平衡医学图像分类问题的联邦学习方法 |
CN117173750A (zh) * | 2023-09-14 | 2023-12-05 | 中国民航大学 | 生物信息处理方法、电子设备及存储介质 |
CN117992941A (zh) * | 2024-04-02 | 2024-05-07 | 广东创能科技股份有限公司 | 一种自助终端登录状态监控和主动安全保护的方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114819190A (zh) | 基于联邦学习的模型训练方法、装置、系统、存储介质 | |
Alain et al. | Variance reduction in sgd by distributed importance sampling | |
CN110610242A (zh) | 一种联邦学习中参与者权重的设置方法及装置 | |
CN111754000A (zh) | 质量感知的边缘智能联邦学习方法及系统 | |
CN110807207B (zh) | 数据处理方法、装置、电子设备及存储介质 | |
CN110826725A (zh) | 基于认知的智能体强化学习方法、装置、系统、计算机设备及存储介质 | |
CN111460528A (zh) | 一种基于Adam优化算法的多方联合训练方法及系统 | |
WO2023206771A1 (zh) | 基于决策流图的环境建模方法、装置和电子设备 | |
CN109313540A (zh) | 口语对话系统的两阶段训练 | |
CN114330125A (zh) | 基于知识蒸馏的联合学习训练方法、装置、设备及介质 | |
CN110175678A (zh) | 用于模拟复杂的强化学习环境的系统和方法 | |
CN114021188A (zh) | 一种联邦学习协议交互安全验证方法、装置及电子设备 | |
CN114261400A (zh) | 一种自动驾驶决策方法、装置、设备和存储介质 | |
CN114492718A (zh) | 飞行决策生成方法和装置、计算机设备、存储介质 | |
US8914505B2 (en) | Methods and apparatus for tuning a network for optimal performance | |
CN110874638B (zh) | 面向行为分析的元知识联邦方法、装置、电子设备及系统 | |
Gummadi et al. | Mean field analysis of multi-armed bandit games | |
CN113726545A (zh) | 基于知识增强生成对抗网络的网络流量生成方法及装置 | |
US20230385611A1 (en) | Apparatus and method for training parametric policy | |
Huang et al. | Online crowd learning with heterogeneous workers via majority voting | |
CN117033997A (zh) | 数据切分方法、装置、电子设备和介质 | |
CN113782217A (zh) | 人体健康状况分级方法及装置 | |
CN112949850A (zh) | 超参数确定方法、装置、深度强化学习框架、介质及设备 | |
CN114169906A (zh) | 电子券推送方法、装置 | |
CN111461188A (zh) | 一种目标业务控制方法、装置、计算设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20220729 |
|
RJ01 | Rejection of invention patent application after publication |