CN116451593A - 基于数据质量评估的强化联邦学习动态采样方法及设备 - Google Patents

基于数据质量评估的强化联邦学习动态采样方法及设备 Download PDF

Info

Publication number
CN116451593A
CN116451593A CN202310700718.0A CN202310700718A CN116451593A CN 116451593 A CN116451593 A CN 116451593A CN 202310700718 A CN202310700718 A CN 202310700718A CN 116451593 A CN116451593 A CN 116451593A
Authority
CN
China
Prior art keywords
client
determining
model
federal learning
cost function
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202310700718.0A
Other languages
English (en)
Other versions
CN116451593B (zh
Inventor
梁美玉
赵泽华
杜军平
薛哲
李昂
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing University of Posts and Telecommunications
Original Assignee
Beijing University of Posts and Telecommunications
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing University of Posts and Telecommunications filed Critical Beijing University of Posts and Telecommunications
Priority to CN202310700718.0A priority Critical patent/CN116451593B/zh
Publication of CN116451593A publication Critical patent/CN116451593A/zh
Application granted granted Critical
Publication of CN116451593B publication Critical patent/CN116451593B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F30/00Computer-aided design [CAD]
    • G06F30/20Design optimisation, verification or simulation
    • G06F30/27Design optimisation, verification or simulation using machine learning, e.g. artificial intelligence, neural networks, support vector machines [SVM] or training a model
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F17/00Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F17/10Complex mathematical operations
    • G06F17/15Correlation function computation including computation of convolution operations
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/092Reinforcement learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/098Distributed learning, e.g. federated learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Artificial Intelligence (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computing Systems (AREA)
  • Mathematical Optimization (AREA)
  • Mathematical Analysis (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Pure & Applied Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computational Mathematics (AREA)
  • Geometry (AREA)
  • Medical Informatics (AREA)
  • Computer Hardware Design (AREA)
  • Algebra (AREA)
  • Databases & Information Systems (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请提供一种基于数据质量评估的强化联邦学习动态采样方法及设备,能够根据初始梯度信息构建初始全局模型,并根据初始全局模型的模型性能确定每个客户端的贡献指数,将贡献指数应用到联邦学习的客户端采样中,可以基于贡献指数评估每个客户端的数据质量。然后根据贡献指数和预设的目标精度确定每个客户端的最优动作价值函数值,因为最优动作价值函数综合考虑了模型性能和模型精度,所以根据最优动作价值函数值对预设数量个客户端进行采样,能够有效地在大量参与联邦学习的客户端中智能化地筛选出高数据质量的客户端,利用具有高数据质量的客户端进行强化联邦学习,可以提高联邦学习得到的全局模型的质量和精度。

Description

基于数据质量评估的强化联邦学习动态采样方法及设备
技术领域
本申请涉及数据处理技术领域,尤其涉及一种基于数据质量评估的强化联邦学习动态采样方法及设备。
背景技术
在联邦学习中,参与的客户端数量通常非常庞大且客户端拥有的数据质量复杂多样,因为模型分发和重新上传的带宽受限,在所有参与设备上并行执行模型更新和聚合是不切实际的,所以一般只选取一部分客户端参与联邦学习的训练过程。因此,客户端采样方法对于降低联邦学习的通信开销,提高联邦训练过程中的收敛速度和最终模型精度等至关重要。相关技术中基于客户端上的数据样本数量占整个训练样本的比例进行客户端采样来降低联邦学习的通信开销,但是,简单地将数据量作为评判客户端质量的指标,忽略了数据量大的客户端可能数据质量较低的可能,此时根据数据量选择客户端反而会降低模型质量并影响最终模型精度。
发明内容
有鉴于此,本申请的目的在于提出一种基于数据质量评估的强化联邦学习动态采样方法及设备,用于提高联邦学习得到的模型的质量和精度。
基于上述目的,本申请的第一方面提供了一种基于深度强化学习和数据质量评估的联邦学习客户端动态采样方法,包括:
确定客户端的初始梯度信息;
根据所述初始梯度信息构建联邦学习在当前通信回合的初始全局模型;
根据所述初始全局模型的模型性能确定每个客户端的贡献指数;
根据贡献指数和预设的目标精度确定每个所述客户端的最优动作价值函数值;
根据所述最优动作价值函数值对预设数量个客户端进行采样。
可选地,所述根据所述梯度信息构建联邦学习在当前通信回合的初始全局模型,包括:
根据联邦学习的通信轮次数确定历史全局模型;
确定每个所述客户端在联邦学习的当前通信轮次的样本量;
根据所述初始梯度信息、所述样本量和所述客户端数量确定聚合梯度;
根据所述聚合梯度和所述历史全局模型确定联邦学习在当前通信回合的所述初始全局模型。
可选地,所述模型性能包括标准模型性能和终端模型性能;
所述根据所述初始全局模型的模型性能确定每个客户端的贡献指数,包括:
根据预设的标准测试集确定所述初始全局模型的所述标准模型性能;
根据每个所述客户端的数据集确定所述初始全局模型的所述终端模型性能;
根据所述标准模型性能和所述终端模型性能确定每个客户端的贡献指数。
可选地,所述根据贡献指数和预设的目标精度确定每个所述客户端的最优动作价值函数值,包括:
根据所述客户端数量确定动作空间,其中,每个客户端对应所述动作空间内的一个选取动作;
根据所述目标精度和所述贡献指数确定所述动作空间中每个所述选取动作对应的即时奖励;
根据所述即时奖励和预设的折扣因子确定每个所述选取动作对应的所述最优动作价值函数值。
可选地,所述根据所述动作价值函数值对预设数量个客户端进行采样,包括:
对所述动作价值函数值进行降序排列,得到选取集合;
在所述选取集合中选取前所述预设数量个目标动作价值函数值;
将所述目标价值函数值对应的客户端确定目标客户端,并对所述目标客户端进行采样。
可选地,所述根据所述目标精度和所述贡献指数确定所述动作空间中每个所述选取动作对应的即时奖励,包括:
根据所述联邦学习的通信轮次数确定所述初始全局模型在预设的验证集中的当前测试精度;
确定所述当前测试精度和所述目标精度确定精度差值;
根据所述精度差值和所述贡献指数确定所述即时奖励。
可选地,所述基于数据质量评估的强化联邦学习动态采样方法还包括:
根据每个通信轮次的即时奖励和所述折扣因子确定所述累计折扣奖励,其中,所述累计折扣奖励和通信轮次之间为反比关系;
响应于所述累计折扣奖励小于等于预设的奖励阈值,结束训练并输出聚合后的全局模型。
本申请的第二方面提供了一种基于数据质量评估的强化联邦学习动态采样装置,包括:
信息获取模块,被配置为:确定客户端的初始梯度信息;
模型重建模块,被配置为:根据所述初始梯度信息构建联邦学习在当前通信回合的初始全局模型;
贡献计算模块,被配置为:根据所述初始全局模型的模型性能确定每个客户端的贡献指数;
价值计算模块,被配置为:根据贡献指数和预设的目标精度确定每个所述客户端的最优动作价值函数值;
动态采样模块,被配置为:根据所述最优动作价值函数值对预设数量个客户端进行采样。
本申请的第三方面提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如本申请第一方面提供的所述的方法。
本申请的第四方面提供了一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令用于使计算机执行本申请第一方面提供的所述方法。
从上面所述可以看出,本申请提供的基于数据质量评估的强化联邦学习动态采样方法及设备,能够根据初始梯度信息构建初始全局模型,并根据初始全局模型的模型性能确定每个客户端的贡献指数,将贡献指数应用到联邦学习的客户端采样中,可以基于贡献指数评估每个客户端的数据质量。然后根据贡献指数和预设的目标精度确定每个强化联邦学习客户端的最优动作价值函数值,根据最优动作价值函数值对预设数量个客户端进行动态采样,因为最优动作价值函数综合考虑模型性能和模型精度,所以根据最优动作价值函数值对预设数量个客户端进行采样,能够有效地在大量参与联邦学习的客户端中筛选出高数据质量的客户端,利用具有高数据质量的客户端进行联邦学习,可以提高联邦学习得到的全局模型的质量和精度。
附图说明
为了更清楚地说明本申请或相关技术中的技术方案,下面将对实施例或相关技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例联邦学习的架构图;
图2为本申请实施例基于数据质量评估的强化联邦学习动态采样方法的流程图;
图3为本申请实施例构建初始全局模型的流程图;
图4为本申请实施例确定贡献指数的流程图;
图5为本申请实施例确定最优动作价值函数值的流程图;
图6为本申请实施例客户端选取的流程图;
图7为本申请实施例采用基于数据质量评估的强化联邦学习动态采样方法的联邦学习方法的流程图;
图8为本申请实施例基于数据质量评估的强化联邦学习动态采样装置的结构示意图;
图9为本申请实施例电子设备的结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚明白,以下结合具体实施例,并参照附图,对本申请进一步详细说明。
需要说明的是,除非另外定义,本申请实施例使用的技术术语或者科学术语应当为本申请所属领域内具有一般技能的人士所理解的通常意义。本申请实施例中使用的“第一”、“第二”以及类似的词语并不表示任何顺序、数量或者重要性,而只是用来区分不同的组成部分。“包括”或者“包含”等类似的词语意指出现该词前面的元件或者物件涵盖出现在该词后面列举的元件或者物件及其等同,而不排除其他元件或者物件。“连接”或者“相连”等类似的词语并非限定于物理的或者机械的连接,而是可以包括电性的连接,不管是直接的还是间接的。“上”、“下”、“左”、“右”等仅用于表示相对位置关系,当被描述对象的绝对位置改变后,则该相对位置关系也可能相应地改变。
在本文中,需要理解的是,附图中的任何元素数量均用于示例而非限制,以及任何命名都仅用于区分,而不具有任何限制含义。
基于上述背景技术的描述,相关技术中还存在如下的情况:
大数据时代提供了海量数据,但是由于隐私安全、法律法规、公司制度等问题,大部分行业中的数据都是以孤岛形式存在。联邦学习是面向这种数据孤岛场景而设计的机器学习范式,它可以在数据不出本地的情况下完成联邦学习模型的训练,保证客户端数据的隐私与安全。
如图1所示,联邦学习是一种机器学习范式,联邦学习的联邦服务器和多个参与客户端之间通过传递模型参数保证数据可以在不出本地的情况下完成联邦学习模型的训练,即数据不动模型动,以此保证客户端数据的隐私与安全。
对于参与联邦学习的任意一个参与客户端,联邦服务器将初始全局模型发送至客户端,客户端利用自身的数据集对初始全局模型进行训练更新后,将更新后的模型或模型参数返回联邦服务器,联邦服务器将各个客户端返回的更新后的模型进行模型聚合,完成一轮通信回合的模型训练,经过多轮通信回合的训练后,输出最后一轮通信回合得到的全局模型。
可看出客户端采样方法对于降低联邦学习的通信开销,提高联邦训练过程中的收敛速度和最终模型精度等至关重要。相关技术中基于客户端上的数据样本数量占整个训练样本的比例进行客户端采样来降低联邦学习的通信开销,但是,简单地将数据量作为评判客户端质量的指标,忽略了数据量大的客户端可能数据质量较低的可能,此时根据数据量选择客户端反而会降低模型质量并影响最终模型精度。
对于客户端数据质量的衡量,可以采用贡献指数进行衡量,贡献指数用于评估每个客户端对联邦学习训练的全局模型的贡献,在联邦学习的训练过程中记录中间结果,并使用这些中间结果来近似计算每个客户端的贡献指数,对于贡献指数可以采用以下两种方式进行计算:
第一种方法通过用不同轮次中的梯度更新联邦学习中的初始全局模型来重建模型,并通过这些重建模型的性能来计算贡献指数。第二种方法通过用本轮的梯度更新上一轮的全局模型来计算本轮客户端的贡献指数。
但是,若直接将贡献指数应用于联邦学习的客户端采样中,对贡献指数较高的客户端的选择过多可能会使全局模型“漂移”到其本地优化器,导致对全局模型更新的偏差,从而出现客户端“漂移”现象,即最后只采样几个甚至单个客户端进行训练。
本申请实施例提供的基于数据质量评估的强化联邦学习动态采样方法,能够基于客户端的数据质量评估,以及深度强化学习技术,对联邦学习客户端进行智能化动态采样,包括根据初始梯度信息构建初始全局模型,并根据初始全局模型的模型性能确定每个客户端的贡献指数,将贡献指数应用到联邦学习的客户端采样中,可以基于贡献指数评估每个客户端的数据质量。然后根据贡献指数和预设的目标精度确定每个强化联邦学习客户端的最优动作价值函数值,根据最优动作价值函数值对预设数量个客户端进行采样。因为最优动作价值函数综合考虑模型性能和模型精度,所以根据最优动作价值函数值对预设数量个客户端进行采样,能够有效地在大量参与联邦学习的客户端中筛选出高数据质量的客户端,利用具有高数据质量的客户端进行联邦学习,可以提高联邦学习得到的模型的质量和精度。
需要说明的是,由于在联邦学习过程中应用到了深度强化学习进行客户端的动态采样,所以称之为强化联邦学习。
在一些实施例中,如图2所示,基于数据质量评估的强化联邦学习动态采样方法,包括:
步骤201:确定客户端的初始梯度信息。
具体实施时,初始梯度信息为上一轮通信回合得到的历史梯度信息。当联邦学习的数据提供者包括n个客户端时,每个客户端提供的用于训练的数据集为,并为数据集的样本量,用于表示训练数据集的大小,并确定联邦学习通信回合的轮次数R-1表示最大联邦学习的通信轮次数。对于刚建立连接(t=0)时,在与参与联邦学习的客户端建立连接后,将第一轮通信回合开始前初始化的全局模型W (0)下发至每个客户端,返回子模型,对子模型进行模型聚合就得到了第0轮通信回合的全局模型W (1),就可以根据计算第一轮通信回合的的初始梯度信息。
对于连接建立完成后的通信回合(t=1、2...、R-1),将第t-1轮通信回合返回的子模型和聚合后的全局模型W (t),根据计算每个客户端的初始梯度信息。
步骤202:根据初始梯度信息构建联邦学习在当前通信回的初始全局模型。
具体实施时,通过各个客户端上样本量的加权平均来聚合每个客户端的梯度信息得到聚合梯度,然后根据各个客户端的聚合梯度进行模型重建,得到初始全局模型,其中,用表示的非空子集。根据来自客户端的初始梯度信息近似地重建每轮的初始全局模型而不是在N的所有非空子集上重新训练这些模型,避免了全局模型“漂移”到本地客户端。
步骤203:根据初始全局模型的模型性能确定每个客户端的贡献指数。
具体实施时,对每一个客户端,基于自身的数据集评估初始全局模型的模型的性能,通过评估重建后的初始全局模型的性能来确定自身的贡献指数。两个客户端i和j上的数据集Di和Dj,如果两者对于初始全局模型性能有相同的影响,即,则客户端i和客户端j具有相同的贡献指数 i = j
步骤204:根据贡献指数和预设的目标精度,通过深度强化学习确定每个客户端的最优动作价值函数值。
具体实施时,首先需要确定深度强化学习下联邦学习客户端采样动作的状态空间和动作空间。其中,由于在联邦学习的训练过程中,全局模型将在每一轮通信结束时更新,因此,在第t轮时,联邦学习服务器上的t-1轮的全局模型为,客户端上得到的t-1轮的的子模型为,则定义第t轮通信回合的状态空间为。在本申请实施例中,将强化联邦学习服务器视为基于深度Q网络的智能体,该智能体部署在联邦学习服务器上。其中,联邦学习服务器上的智能体维护着子模型列表,并且仅当客户端i在第t轮被选中参与训练初始全局模型,进而得到新的初始梯度信息时,才会更新。
进一步地,动作空间为:其中,意味着客户端i被选中参与联邦学习的训练过程。
而最优动作价值函数表示智能体在系统状态下选择特定动作可以获得的最大预期回报,所以最优动作价值函数值越大,选取对应客户端可获得的最大预期回报越大,选取该客户端训练的得到全局模型的质量和准度越高。其中,最优动作价值函数可以根据下式得到:
其中,为即时奖励,表示转态转换概率,表示反映当前奖励对未来奖励的重要性递减的折扣因子。而即时奖励,需要根据贡献指数和预设的目标精度来确定,其中,是初始全局模型在第t轮后在保留的验证集上实现的测试精度,所以在确定最优动作价值函数值时,贡献指数和测试精度可以视作为奖励项,因为即时奖励与贡献指数、测试精度均为正相关关系。
步骤205:根据最优动作价值函数值对预设数量个客户端进行采样。
具体实施时,可以在每轮通信回合中根据最优动作价值函数值大小进行降序排列,并选取前预设数量K个最优动作价值函数值对应的客户端参与本轮通信回合的联邦学习,完成客户端的智能化采样,由于每个通信回合选取的客户端存在不同,所以在整个联邦学习的过程中,采样是动态进行的。
需要说明的是,步骤201至步骤205是根据联邦学习在当前通信回合的当前最大回报进行客户端采样的,进一步地,为了避免总是注重当前回报最大化而忽略长期回报,可以通过计算累积折扣奖励来表示长期回报,其中,累积折扣奖励的表达式为:
由于折扣因子,所以可以看出累积折扣奖励随着通信回合轮数t的增加会越来越小,所以将通信轮次作为惩罚项可以加快模型的收敛速度并提高最终模型的精度。
综上所述,本申请实施例提供的基于数据质量评估的强化联邦学习动态采样方法,能够基于客户端的数据质量评估,以及深度强化学习技术,对联邦学习客户端进行智能化动态采样。在每一轮联邦学习中,首先,通过聚合上一轮次中各个客户端的梯度信息重建初始全局模型,根据模型性能评估各个客户端的贡献指数,即评估各个数据提供者的数据质量。然后,基于深度强化学习的智能体将贡献指数和模型的测试精度作为奖励项构建最优动作价值函数,最后,根据最优动作价值函数值选择前预设数量K个客户端参与到联邦学习的训练过程,完成客户端的智能化采样。并将通信轮次作为惩罚项可以加快模型的收敛速度并提高最终模型的精度。由于最优动作价值函数综合考虑模型性能和模型精度,所以根据最优动作价值函数值对预设数量个客户端进行采样,能够有效地在大量参与联邦学习的客户端中筛选出高数据质量的客户端,利用具有高数据质量的客户端进行联邦学习,可以提高联邦学习得到的模型的质量和精度。
在一些实施例中,如图3所示,根据梯度信息构建当前通信回的初始全局模型,包括:
步骤301:根据联邦学习的通信轮次数确定历史全局模型。
具体实施时,若当前通信回合为第t轮,则联邦学习的通信轮次数为t,则t-1轮通信回合中重建的初始全局模型为历史全局模型为
步骤302:确定每个客户端在联邦学习的当前通信轮次的样本量。
具体实施时,当联邦学习的数据提供者包括n个客户端时,每个客户端提供的用于训练的数据集为,并为数据集的样本量,用于表示训练数据集的大小,且样本量不会随着通信轮次的增加而改变。
步骤303:根据初始梯度信息、样本量和客户端数量确定聚合梯度。
具体实施时,通过各个客户端上样本量的加权平均来聚合每个客户端的梯度信息得到聚合梯度,使用初始梯度信息、样本量和客户端数量n,根据梯度聚合公式确定聚合梯度,其中,梯度聚合公式为:
步骤304:根据聚合梯度和历史全局模型确定联邦学习在当前通信回合的初始全局模型。
具体实施时,根据聚合梯度和历史全局模型,根据公式得到初始全局模型
在一些实施例中,模型性能包括标准模型性能和终端模型性能;如图4所示,根据初始全局模型的模型性能确定每个客户端的贡献指数,包括:
步骤401:根据预设的标准测试集确定初始全局模型的标准模型性能。
具体实施时,用B表示机器学习算法,T表示标准测试集,在标准测试集T上评估模型的标准模型性能由表示,在不存在疑意的情况下可以用,则初始全局模型在标准测试集T上的标准模型性能为
步骤402:根据每个客户端的数据集确定初始全局模型的终端模型性能。
具体实施时,当联邦学习的数据提供者包括n个客户端时,用表示每个客户端提供的用于训练的数据集,对于第i个客户端的数据集对初始全局模型进行评估,得到的终端模型性能
步骤403:根据标准模型性能和终端模型性能确定每个客户端的贡献指数。
具体实施时,用表示客户端i提供的数据集在满足,T和B条件下的贡献指数,可以简化为。其中,表示全部的数据集。
贡献指数应满足以下性质:
性质1:如果数据集对机器学习算法B在标准测试集T上的性能没有影响,则数据集的贡献指数为零。即对任意非空子集,如果有,那么=0。
性质2:若两个客户端i和j上的数据集Di和Dj,对机器学习算法B在标准测试集T上的性能影响相同,则两个客户端具有相同的贡献指数,即若,则 =
性质3:贡献指数对于标准测试集是线性相关的,对于不相交的两个标准测试集T1、T2,有
则满足上述三个性质的贡献指数的计算公式为:
其中,表示客户端i的贡献指数,C为一个常数,表示非空子集M和客户端集合N均不含有客户端i,表示非空子集M中客户端的数量。
在一些实施例中,如图5所示,根据贡献指数和预设的目标精度,通过深度强化学习确定每个强化联邦学习客户端的最优动作价值函数值,包括:
步骤501:获取联邦学习在当前通信回合的上一通信回合的历史模型数据,并根据历史模型数据构建深度强化学习的状态空间。
具体实施时,由于在联邦学习的训练过程中,联邦学习的当前通信回合为t时,在当前通信回合的上一通信回合t-1时的历史模型数据包括:联邦学习服务器上的的全局模型为,客户端上得到的子模型为,则第t轮通信回合的状态间为
步骤502:根据客户端数量确定深度强化学习的动作空间,其中,每个客户端对应动作空间内的一个选取动作。
具体实施时,当联邦学习的数据提供者包括n个客户端时,动作空间为:其中,表示客户端i被选中的选取动作。
步骤503:根据目标精度和贡献指数确定动作空间中每个选取动作对应的即时奖励。
具体实施时,根据目标精度和贡献指数确定动作空间中每个选取动作对应的即时奖励,包括:
根据联邦学习的通信轮次数确定初始全局模型在预设的验证集中的当前测试精度;
确定当前测试精度和目标精度确定精度差值;
根据精度差值和贡献指数确定即时奖励。
则即时奖励的计算公式为:
其中,为目标精度,是初始全局模型在第t轮后在保留的验证集上实现的测试精度,表示贡献指数是正向激励,可以作为奖励项,保证了即时奖励随着测试精度呈指数变化,控制了即时奖励随测试精度变化速度,是为了鼓励以更少的轮数完成联邦学习训练,因为需要的轮数越多,智能体获得的累积奖励就越少。因为,所以有,所以训练的结束条件之一为:当测试精度达到目标精度,即时联邦学习训练停止,此时达到其最大值0。
步骤504:根据即时奖励和预设的折扣因子确定每个选取动作在状态空间下的最优动作价值函数值。
具体实施时,最优动作价值函数值可以根据下式得到:
其中,为即时奖励,表示转态转换概率,表示反映当前奖励对未来奖励的重要性递减的折扣因子。基于最优动作价值函数,根据即时奖励和折扣因子确定每个选取动作对应的最优动作价值函数值。
在一些实施例中,如图6所示,根据动作价值函数值对预设数量个客户端进行采样,包括:
步骤601:对动作价值函数值进行降序排列,得到选取集合。
具体实施时,最优动作价值函数值表示智能体在特定系统状态下选择特定动作可以获得的最大预期回报,所以最优动作价值函数值越大,选取对应客户端可获得的最大预期回报越大,选取该客户端训练的得到全局模型的质量和准度越高。所以根据动作价值函数值的大小对动作价值函数值进行降序排列,得到选取集合。
步骤602:在选取集合中选取前预设数量个目标动作价值函数值。
具体实施时,在选取集合中,越靠前的动作价值函数值越大,所以根据预设数量K在选取集合中选取前K个动作价值函数值作为目标动作价值函数值。
步骤603:将目标价值函数值对应的客户端确定为目标客户端。
具体实施时,每个目标动作价值函数值都是一个客户端的贡献指数计算得到的,所以每个目标动作价值函数值均对应一个目标客户端。
步骤604:对目标客户端进行选取。
具体实施时,被选取的K个目标客户端会在当前通信回合参与到联邦学习的训练,完成采样。
在一些实施例中,在采样之后还包括:
根据每个通信轮次的即时奖励确定和折扣因子确定深度强化学习的累计折扣奖励,其中,累计折扣奖励和通信轮次之间为反比关系。
响应于累计折扣奖励小于等于预设的奖励阈值,结束联邦学习并输出聚合后的全局模型。
具体实施时,为了避免总是注重当前回报最大化而忽略长期回报,可以通过计算累积折扣奖励来表示长期回报,其中,累积折扣奖励的表达式为:
由于折扣因子,所以可以看出累积折扣奖励随着通信回合轮数t的增加会越来越小,所以将通信轮次作为惩罚项可以加快模型的收敛速度并提高最终模型的精度。所以当累积折扣奖励的值小于等于预设的奖励阈值时,也可以结束训练。需要说明的是,若累积折扣奖励和测试精度的收敛条件均未达成,当达到最大训练轮次时,结束训练输出全局模型。
在一些实施例中,如图7所示,采用基于数据质量评估的强化联邦学习动态采样方法的联邦学习方法包括以下步骤:
Step1.数据提供者中的n个客户端均与联邦学习服务器建立连接,确保每个客户端设备可用。
Step2.每个客户端都从服务器下载初始的全局模型权重,使用来表示子集合,并初始化基于不同非空子集的重建模型,从而得到初始化的贡献指数
Step3.在第轮中,其中首先计算客户端的梯度进行梯度聚合,然后,根据来自客户端的梯度近似地重建每轮的初始全局模型而不是在的所有非空子集上重新训练这些模型。再然后计算不同客户端(数据提供者)在当前通信回合中的贡献指数。再然后基于深度Q网络的智能体计算所有设备的最优动作价值函数值
Step4.基于深度Q网络的智能体根据计算的智能地选择前个客户端,被选中的个客户端会在本地执行的本地随机梯度下降并得到
Step5.被上传到服务器进行模型聚合,计算全局模型。进入第轮并重复步骤3-5。重复步骤 3-5 直到达到目标(累积折扣奖励和测试精度达到收敛条件)或达到一定的通信回合。
与上述过程对应的联邦学习的算法过程包括:
基于数据质量评估的强化联邦学习动态采样方法的联邦学习算法:
输入:本地批量大小B,本地迭代次数E,学习率η,每轮选择的客户端数量
输出:全局模型
1:Server aggregation:
2:;
3:初始化,;
4:for 每一轮执行:
5:对客户端;
6:for 每个子集并行执行:
7:;
8:;
9:end for;
10:for并行执行:
11:; /*计算贡献指数*/
12:end for;
13: 计算价值函数;
14:agent根据价值函数选择top-K设备;
15:ClientUpdate();
16:;
17: end for;
18: return;
在任一客户端中执行以下算法:
19:ClientUpdate(,):
20:将数据集划分成个批量大小;
21: for 每次本地迭代执行:
22:for 批量并行执行:
23:;
24:end for;
25: end for;
26: return给服务器。
需要说明的是,本申请实施例的方法可以由单个设备执行,例如一台计算机或服务器等。本实施例的方法也可以应用于分布式场景下,由多台设备相互配合来完成。在这种分布式场景的情况下,这多台设备中的一台设备可以只执行本申请实施例的方法中的某一个或多个步骤,这多台设备相互之间会进行交互以完成所述的方法。
需要说明的是,上述对本申请的一些实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作或步骤可以按照不同于上述实施例中的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序才能实现期望的结果。在某些实施方式中,多任务处理和并行处理也是可以的或者可能是有利的。
基于同一发明构思,与上述任意实施例方法相对应的,本申请还提供了一种基于数据质量评估的强化联邦学习动态采样装置。
参考图8,所述基于数据质量评估的强化联邦学习动态采样装置,包括:
信息获取模块10,被配置为:确定客户端的初始梯度信息;
模型重建模块20,被配置为:根据初始梯度信息构建联邦学习在当前通信回合的初始全局模型;
贡献计算模块30,被配置为:根据初始全局模型的模型性能确定每个客户端的贡献指数;
价值计算模块40,被配置为:根据贡献指数和预设的目标精度确定每个客户端的最优动作价值函数值;
动态采样模块50,被配置为:根据最优动作价值函数值对预设数量个客户端进行采样。
为了描述的方便,描述以上装置时以功能分为各种模块分别描述。当然,在实施本申请时可以把各模块的功能在同一个或多个软件和/或硬件中实现。
上述实施例的装置用于实现前述任一实施例中相应的基于数据质量评估的强化联邦学习动态采样方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
基于同一发明构思,与上述任意实施例方法相对应的,本申请还提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上任意一实施例所述的基于数据质量评估的强化联邦学习动态采样方法。
图9示出了本实施例所提供的一种更为具体的电子设备硬件结构示意图, 该设备可以包括:处理器1010、存储器1020、输入/输出接口1030、通信接口1040和总线 1050。其中处理器1010、存储器1020、输入/输出接口1030和通信接口1040通过总线1050实现彼此之间在设备内部的通信连接。
处理器1010可以采用通用的CPU(Central Processing Unit,中央处理器)、微处理器、应用专用集成电路(Application Specific Integrated Circuit,ASIC)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本说明书实施例所提供的技术方案。
存储器1020可以采用ROM(Read Only Memory,只读存储器)、RAM(Random AccessMemory,随机存取存储器)、静态存储设备,动态存储设备等形式实现。存储器1020可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器1020中,并由处理器1010来调用执行。
输入/输出接口1030用于连接输入/输出模块,以实现信息输入及输出。输入输出/模块可以作为组件配置在设备中(图中未示出),也可以外接于设备以提供相应功能。其中输入设备可以包括键盘、鼠标、触摸屏、麦克风、各类传感器等,输出设备可以包括显示器、扬声器、振动器、指示灯等。
通信接口1040用于连接通信模块(图中未示出),以实现本设备与其他设备的通信交互。其中通信模块可以通过有线方式(例如USB、网线等)实现通信,也可以通过无线方式(例如移动网络、WIFI、蓝牙等)实现通信。
总线1050包括一通路,在设备的各个组件(例如处理器1010、存储器1020、输入/输出接口1030和通信接口1040)之间传输信息。
需要说明的是,尽管上述设备仅示出了处理器1010、存储器1020、输入/输出接口1030、通信接口1040以及总线1050,但是在具体实施过程中,该设备还可以包括实现正常运行所必需的其他组件。此外,本领域的技术人员可以理解的是,上述设备中也可以仅包含实现本说明书实施例方案所必需的组件,而不必包含图中所示的全部组件。
上述实施例的电子设备用于实现前述任一实施例中相应的基于数据质量评估的强化联邦学习动态采样方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
基于同一发明构思,与上述任意实施例方法相对应的,本申请还提供了一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行如上任一实施例所述的基于数据质量评估的强化联邦学习动态采样方法。
本实施例的计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。
上述实施例的存储介质存储的计算机指令用于使所述计算机执行如上任一实施例所述的基于数据质量评估的强化联邦学习动态采样方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
所属领域的普通技术人员应当理解:以上任何实施例的讨论仅为示例性的,并非旨在暗示本申请的范围(包括权利要求)被限于这些例子;在本申请的思路下,以上实施例或者不同实施例中的技术特征之间也可以进行组合,步骤可以以任意顺序实现,并存在如上所述的本申请实施例的不同方面的许多其它变化,为了简明它们没有在细节中提供。
另外,为简化说明和讨论,并且为了不会使本申请实施例难以理解,在所提供的附图中可以示出或可以不示出与集成电路(IC)芯片和其它部件的公知的电源/接地连接。此外,可以以框图的形式示出装置,以便避免使本申请实施例难以理解,并且这也考虑了以下事实,即关于这些框图装置的实施方式的细节是高度取决于将要实施本申请实施例的平台的(即,这些细节应当完全处于本领域技术人员的理解范围内)。在阐述了具体细节(例如,电路)以描述本申请的示例性实施例的情况下,对本领域技术人员来说显而易见的是,可以在没有这些具体细节的情况下或者这些具体细节有变化的情况下实施本申请实施例。因此,这些描述应被认为是说明性的而不是限制性的。
尽管已经结合了本申请的具体实施例对本申请进行了描述,但是根据前面的描述,这些实施例的很多替换、修改和变型对本领域普通技术人员来说将是显而易见的。例如,其它存储器架构(例如,动态RAM(DRAM))可以使用所讨论的实施例。
本申请实施例旨在涵盖落入所附权利要求的宽泛范围之内的所有这样的替换、修改和变型。因此,凡在本申请实施例的精神和原则之内,所做的任何省略、修改、等同替换、改进等,均应包含在本申请的保护范围之内。

Claims (10)

1.一种基于数据质量评估的强化联邦学习动态采样方法,其特征在于,包括:
确定客户端的初始梯度信息;
根据所述初始梯度信息构建联邦学习在当前通信回合的初始全局模型;
根据所述初始全局模型的模型性能确定每个客户端的贡献指数;
根据贡献指数和预设的目标精度,通过深度强化学习确定每个所述客户端的最优动作价值函数值;
根据所述最优动作价值函数值对预设数量个客户端进行采样。
2.根据权利要求1所述的方法,其特征在于,所述根据所述梯度信息构建联邦学习在当前通信回合的初始全局模型,包括:
根据联邦学习的通信轮次数确定历史全局模型;
确定每个所述客户端在联邦学习的当前通信轮次的样本量;
根据所述初始梯度信息、所述样本量和所述客户端数量确定聚合梯度;
根据所述聚合梯度和所述历史全局模型确定联邦学习在当前通信回合的所述初始全局模型。
3.根据权利要求1所述的方法,其特征在于,所述模型性能包括标准模型性能和终端模型性能;
所述根据所述初始全局模型的模型性能确定每个客户端的贡献指数,包括:
根据预设的标准测试集确定所述初始全局模型的所述标准模型性能;
根据每个所述客户端的数据集确定所述初始全局模型的所述终端模型性能;
根据所述标准模型性能和所述终端模型性能确定每个客户端的贡献指数。
4.根据权利要求1所述的方法,其特征在于,所述根据贡献指数和预设的目标精度,通过深度强化学习确定每个所述客户端的最优动作价值函数值,包括:
获取联邦学习在所述当前通信回合的上一通信回合的历史模型数据,并根据所述历史模型数据构建深度强化学习的状态空间;
根据所述客户端数量确定深度强化学习的动作空间,其中,每个客户端对应所述动作空间内的一个选取动作;
根据所述目标精度和所述贡献指数确定所述动作空间中每个所述选取动作对应的即时奖励;
根据所述即时奖励和预设的折扣因子确定每个所述选取动作在所述状态空间下的所述最优动作价值函数值。
5.根据权利要求4所述的方法,其特征在于,所述根据所述动作价值函数值对预设数量个客户端进行采样,包括:
对所述动作价值函数值进行降序排列,得到选取集合;
在所述选取集合中选取前所述预设数量个目标动作价值函数值;
将所述目标价值函数值对应的客户端确定为目标客户端,并对所述目标客户端进行选取。
6.根据权利要求4所述的方法,其特征在于,所述根据所述目标精度和所述贡献指数确定所述动作空间中每个所述选取动作对应的即时奖励,包括:
根据所述联邦学习的通信轮次数确定所述初始全局模型在预设的验证集中的当前测试精度;
确定所述当前测试精度和所述目标精度确定精度差值;
根据所述精度差值和所述贡献指数确定所述即时奖励。
7.根据权利要求5所述的方法,其特征在于,还包括:
根据每个通信轮次的即时奖励和所述折扣因子确定深度强化学习的所述累计折扣奖励,其中,所述累计折扣奖励和通信轮次之间为反比关系;
响应于所述累计折扣奖励小于等于预设的奖励阈值,结束联邦学习并输出聚合后的全局模型。
8.一种基于数据质量评估的强化联邦学习动态采样装置,其特征在于,包括:
信息获取模块,被配置为:确定客户端的初始梯度信息;
模型重建模块,被配置为:根据所述初始梯度信息构建联邦学习在当前通信回合的初始全局模型;
贡献计算模块,被配置为:根据所述初始全局模型的模型性能确定每个客户端的贡献指数;
价值计算模块,被配置为:根据贡献指数和预设的目标精度确定每个所述客户端的最优动作价值函数值;
动态采样模块,被配置为:根据所述最优动作价值函数值对预设数量个客户端进行采样。
9.一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如权利要求1至7任意一项所述的方法。
10.一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令用于使计算机执行权利要求1至7任一所述方法。
CN202310700718.0A 2023-06-14 2023-06-14 基于数据质量评估的强化联邦学习动态采样方法及设备 Active CN116451593B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310700718.0A CN116451593B (zh) 2023-06-14 2023-06-14 基于数据质量评估的强化联邦学习动态采样方法及设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310700718.0A CN116451593B (zh) 2023-06-14 2023-06-14 基于数据质量评估的强化联邦学习动态采样方法及设备

Publications (2)

Publication Number Publication Date
CN116451593A true CN116451593A (zh) 2023-07-18
CN116451593B CN116451593B (zh) 2023-11-14

Family

ID=87130488

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310700718.0A Active CN116451593B (zh) 2023-06-14 2023-06-14 基于数据质量评估的强化联邦学习动态采样方法及设备

Country Status (1)

Country Link
CN (1) CN116451593B (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116614484A (zh) * 2023-07-19 2023-08-18 北京邮电大学 基于结构增强的异质数据联邦学习方法及相关设备
CN117521783A (zh) * 2023-11-23 2024-02-06 北京天融信网络安全技术有限公司 联邦机器学习方法、装置、存储介质及处理器

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021208720A1 (zh) * 2020-11-19 2021-10-21 平安科技(深圳)有限公司 基于强化学习的业务分配方法、装置、设备及存储介质
CN113992676A (zh) * 2021-10-27 2022-01-28 天津大学 端边云架构和完全信息下分层联邦学习的激励方法及系统
CN116187483A (zh) * 2023-02-10 2023-05-30 清华大学 模型训练方法、装置、设备、介质和程序产品

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021208720A1 (zh) * 2020-11-19 2021-10-21 平安科技(深圳)有限公司 基于强化学习的业务分配方法、装置、设备及存储介质
CN113992676A (zh) * 2021-10-27 2022-01-28 天津大学 端边云架构和完全信息下分层联邦学习的激励方法及系统
CN116187483A (zh) * 2023-02-10 2023-05-30 清华大学 模型训练方法、装置、设备、介质和程序产品

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116614484A (zh) * 2023-07-19 2023-08-18 北京邮电大学 基于结构增强的异质数据联邦学习方法及相关设备
CN116614484B (zh) * 2023-07-19 2023-11-10 北京邮电大学 基于结构增强的异质数据联邦学习方法及相关设备
CN117521783A (zh) * 2023-11-23 2024-02-06 北京天融信网络安全技术有限公司 联邦机器学习方法、装置、存储介质及处理器

Also Published As

Publication number Publication date
CN116451593B (zh) 2023-11-14

Similar Documents

Publication Publication Date Title
CN116451593B (zh) 基于数据质量评估的强化联邦学习动态采样方法及设备
CN110476172B (zh) 用于卷积神经网络的神经架构搜索
US11861474B2 (en) Dynamic placement of computation sub-graphs
CN113516248B (zh) 一种量子门测试方法、装置及电子设备
CN111898578B (zh) 人群密度的获取方法、装置、电子设备
CN108091166B (zh) 可用停车位数目变化的预测方法、装置、设备及存储介质
CN112766497B (zh) 深度强化学习模型的训练方法、装置、介质及设备
CN110633859B (zh) 一种两阶段分解集成的水文序列预测方法
CN116342172A (zh) 基于线性回归和决策树结合的油价预测方法、装置及设备
CN116050256A (zh) 移动位置的预测方法及相关设备
CN111798263A (zh) 一种交易趋势的预测方法和装置
CN117473032A (zh) 基于一致扩散的场景级多智能体轨迹生成方法及装置
CN116562650A (zh) 一种短期风电功率预测方法、装置及计算机可读存储介质
CN116466835A (zh) 笔迹预测方法、装置、电子设备及存储介质
CN116362251A (zh) 一种命名实体识别模型的训练方法、装置、设备和介质
CN116882536A (zh) 一种降雨数据预测方法、装置、电子设备和存储介质
CN116055489A (zh) 一种基于ddpg算法选择车辆的异步联邦优化方法
CN116402138A (zh) 一种多粒度历史聚合的时序知识图谱推理方法及系统
CN115668219A (zh) 生成对抗网络中的少样本域适应
CN111062468A (zh) 生成网络的训练方法和系统、以及图像生成方法及设备
CN118265053A (zh) 通信辅助感知场景下的高效数据增强优化方法及相关设备
CN114841276A (zh) 数据处理方法和装置、电子设备、计算机可读介质
CN117743859B (zh) 工业分析模型的训练方法、使用方法及介质
CN115358304A (zh) 标签生成模型的训练方法、生成标签的方法及相关设备
CN118503090A (zh) 基于业务中台微服务的智能感知方法、装置及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant