CN113191484A - 基于深度强化学习的联邦学习客户端智能选取方法及系统 - Google Patents

基于深度强化学习的联邦学习客户端智能选取方法及系统 Download PDF

Info

Publication number
CN113191484A
CN113191484A CN202110449033.4A CN202110449033A CN113191484A CN 113191484 A CN113191484 A CN 113191484A CN 202110449033 A CN202110449033 A CN 202110449033A CN 113191484 A CN113191484 A CN 113191484A
Authority
CN
China
Prior art keywords
client
learning
federal
clients
selection
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110449033.4A
Other languages
English (en)
Other versions
CN113191484B (zh
Inventor
张尧学
邓永恒
吕丰
任炬
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tsinghua University
Central South University
Original Assignee
Tsinghua University
Central South University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tsinghua University, Central South University filed Critical Tsinghua University
Priority to CN202110449033.4A priority Critical patent/CN113191484B/zh
Publication of CN113191484A publication Critical patent/CN113191484A/zh
Application granted granted Critical
Publication of CN113191484B publication Critical patent/CN113191484B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种基于深度强化学习的联邦学习客户端智能选取方法及系统,该方法包括:联邦平台通过从联邦服务市场环境中收集客户端的状态作为输入,输入到基于策略网络的客户端选择智能体中,输出客户端选择方案;联邦平台根据当前环境状况以及客户端选择方案从多个候选客户端中选取一组最优的客户端以协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给客户端选择智能体,以奖励用于优化更新策略网络;策略网络通过强化学习方法离线训练得到。本发明可从候选移动边缘设备中选择高质量的设备参与联邦学习,以处理分布式客户端低质量数据问题,以显著提高联邦学习质量。

Description

基于深度强化学习的联邦学习客户端智能选取方法及系统
技术领域
本发明涉及大规模分布式边缘智能学习系统的性能优化技术领域,尤其涉及一种基于深度强化学习的联邦学习客户端智能选取方法及系统。
背景技术
移动边缘设备的普及使得边缘产生的数据快速增长,同时也促进了现代人工智能应用的繁荣发展。然而,由于隐私问题和高昂的数据传输成本,传统的在云端收集大量数据进行集中式模型训练的机制变得不太可取。为了在不泄露隐私的前提下充分利用数据资源,一种新的学习范式应运而生,即联邦学习(Federated Learning,FL),它可以让移动边缘设备在不共享其原始数据的情况下协同训练全局模型。在联邦学习中,分布式设备使用自己的数据在本地训练全局模型,然后将模型更新提交给服务器进行模型聚合,聚合后的模型更新用于更新全局模型,然后返回给每个设备以进行下一轮的迭代。全局模型的训练过程便可以通过这种方式以分布式和隐私保护的方式迭代完成。
联邦学习尽管在隐私保护方面具有巨大的潜力,但在实现高性能学习质量方面仍然面临着技术挑战。与在数据中心进行训练时数据充足且资源不受限制不同,参与联邦学习的分布式设备通常在硬件条件和数据资源上都受到限制,且存在异质性,这会极大地影响学习性能。例如,由于传感器的缺陷和功率的限制,移动设备难免会收集一些错误标注的低质量数据,导致设备本地学习质量参差不齐。然而,不加区分地聚合低质量的模型更新会反向恶化全局模型的质量。因此,客户端选择,尤其是从候选客户端中选择合适的移动设备参与分布式学习,成为高质量联邦学习的关键。
最近,现有的一些工作提出了一些联邦学习的客户端选择方案。例如,Nishio等人提出了一种资源感知的选择方案,根据客户端的计算和通信资源选择客户端,使得能够在有限的资源约束下最大限度地增加参与者的数量,加速联邦学习性能的提升。Mohammed等人通过选择模型测试精度较高的候选客户端参与联邦学习的训练过程,提高了联邦学习的学习精度。Huang等人提出了一种有公平性保证的客户端选择方案,可以在联邦学习的训练效率和公平性之间取得良好的权衡。为了减少联邦学习训练的延迟,Xia等人提出了一种基于多臂老虎机的在线客户端调度方案,可以显著缩短模型训练的时间开销。Wang等人提出利用强化学习智能选择联邦学习的参与客户端,以克服客户端非独立同分布的数据对学习性能的负面影响,加快模型训练过程。但是,现有的客户端选择方案并没有充分考虑客户端的数据质量对联邦学习性能的影响,如何综合考虑客户端的数据数量、数据质量、计算资源等因素对模型训练质量的影响,为联邦学习智能地选取高质量的参与节点,仍需进一步探索和研究。
发明内容
本发明提供了一种基于深度强化学习的联邦学习客户端智能选取方法(以下简称AUCTION)及系统,用以解决现有的客户端选择方案并没有充分考虑客户端的数据数量、数据质量、计算资源等因素对联邦学习性能的影响的技术问题。
为解决上述技术问题,本发明提出的技术方案为:
一种基于深度强化学习的联邦学习客户端智能选取方法,应用于联邦服务市场框架,联邦服务市场框架包括一个以一定的预算招募客户端完成联邦学习任务的联邦平台和多个愿意参与联邦学习并向联邦平台提交联邦学习任务的候选客户端;包括以下步骤:
联邦平台通过从联邦服务市场环境中收集客户端的状态作为输入,输入到基于策略网络的客户端选择智能体中,输出客户端选择方案;联邦平台根据当前环境状况以及客户端选择方案从多个候选客户端中选取一组最优的客户端以协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给客户端选择智能体,以奖励用于优化更新策略网络;策略网络通过强化学习方法离线训练得到。
作为本发明的方法的进一步改进:
客户端选择智能体,为基于编码器-解码器结构的策略网络,编码器将客户端状态映射为中间向量表示,解码器根据中间向量表示生成客户端选择方案;客户端状态包括数据大小、数据质量和价格。
优选地,策略网络的强化学习模型,包括状态、动作、奖励和策略:
状态:状态s={x1,x2,…,xn}包含给定联邦学习任务所有候选客户端的特征,每个客户端Ci的特征xi是一个三维向量,用xi={qi,di,bi}表示,其中qi和di分别是客户端Ci的数据质量和用于训练的样本数量,bi是客户端Ci完成该学习任务的价格;
动作:采用顺序动作,即客户端选择代理通过采取一系列的动作一一做出客户端选择决策;一个单独的动作只从一组最多N个候选客户端中选出一个客户端;
奖励:将执行客户端选择操作后从联邦服务市场观察到的奖励r作为训练后损失函数值的减少率,即:
Figure BDA0003038069850000021
其中,F(w)是学习任务测试数据集上的初始全局损失函数值,F(w*)是经过选定客户端的多轮协同训练后达到的测试损失函数值;
策略:将客户端选择的一个可行动作a={a1,…,ai,…}定义为候选客户端的一个子集,其中ai∈{C1,C2,…,Cn}且
Figure BDA0003038069850000031
策略网络为一个随机的客户端选择策略π(a|s,B)用于在给定状态s和学习预算B的情况下选择一个可行动作a;训练策略网络的目标是最大化累计奖励。
优选地,最大化累计奖励,表示为:
Figure BDA0003038069850000032
其中r(a|s)是在状态s执行动作a后的奖励;
使用REINFORCE算法来优化J,使用梯度下降来不断优化参数θ:
Figure BDA0003038069850000033
其中b(s)代表一个独立于a的基准函数用于加速训练过程;参数θ是编码器和解码器可学习参数的并集。
优选地,编码器包括:
客户端嵌入层首先通过线性投影把三维输入特征xi转化为初始的dh维嵌入向量
Figure BDA0003038069850000034
Figure BDA0003038069850000035
其中Wx和bx为可学习参数;
然后,嵌入向量会经过L个注意力层更新,其中,每一个注意力层l∈{1,2,…,L}输出嵌入向量
Figure BDA0003038069850000036
每个注意力层包含一个MHA层和一个FF层,每层后面都添加了一个跳跃连接和批归一化。
优选地,解码器包括:
基于编码器输出的嵌入向量和解码器在时间t′<t时间输出的客户端选择结果,解码器在每个时间点t输出一个选中的客户端at直到学习预算用尽;解码器的网络包含一个多头注意力层和一个单头注意力层。
本发明还提供一种计算机系统,包括存储器、处理器以及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现上述任一方法的步骤。
本发明具有以下有益效果:
1、本发明的基于深度强化学习的联邦学习客户端智能选取方法及系统,可以利用客户端当前的学习质量相关的监测信息和历史的模型训练记录,自动学习客户端选择策略,以在联邦学习服务市场中实时地做出客户端选择决策。
2、在优选方案中,本发明利用深度强化学习技术,将客户端选择策略编码为神经网络,将每个客户端的数据大小、数据质量和学习价格作为输入,并输出在学习预算内选择的客户端子集,然后策略网络观察所选客户端的联邦学习性能,再利用策略梯度算法逐步改进其客户端选择策略。
3、本发明的基于深度强化学习的联邦学习客户端智能选取方法及系统,为了能够适应联邦服务市场中客户端数量的动态变化并减小强化学习算法的搜索空间,本发明设计了基于编码器-解码器结构的策略网络,其中编码器采用注意力机制将客户端信息转化为嵌入向量,然后解码器再根据编码器输出的嵌入向量进行顺序的客户端选择策略。
除了上面所描述的目的、特征和优点之外,本发明还有其它的目的、特征和优点。下面将参照附图,对本发明作进一步详细的说明。
附图说明
构成本申请的一部分的附图用来提供对本发明的进一步理解,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1是本发明优选实施例的联邦服务市场的示意图;
图2是本发明优选实施例的基于深度强化学习的联邦学习客户端智能选取方法(AUCTION)的流程示意图;
图3是本发明优选实施例的基于深度强化学习的联邦学习客户端智能选取方法(AUCTION)的流程图;
图4是本发明优选实施例的训练客户端选择智能体的过程示意图;图4(a)为对于MLP MNIST学习任务;图4(b)为对于LeNet-5 FMNIST学习任务;图4(c)为对于ResNet-18CIFAR-10学习任务;
图5是10个候选客户端下本发明优选实施例(AUCTION)与其他客户端选择方案的性能对比图;图5(a)为对于MLP MNIST学习任务;图5(b)为对于LeNet-5 FMNIST学习任务;图5(c)为对于ResNet-18 CIFAR-10学习任务;
图6是50个候选客户端下本发明优选实施例(AUCTION)与其他客户端选择方案的性能对比图;图6(a)为对于MLP MNIST学习任务;图6(b)为对于LeNet-5 FMNIST学习任务;图6(c)为对于ResNet-18 CIFAR-10学习任务;
图7是学习预算对性能的影响图;图7(a)学习预算=5;图7(b)学习预算=10;图7(c)学习预算=15;图7(d)学习预算=20;
图8是LeNet-5 FMNIST学习任务在不同的客户端选择模型下的损失函数减少量,其中’ours-10’代表使用10个候选客户端训练得到的AUCTION模型。
具体实施方式
以下结合附图对本发明的实施例进行详细说明,但是本发明可以由权利要求限定和覆盖的多种不同方式实施。
图1是本实施例中所称的一个典型的联邦服务市场框架的示意图,其包括有一个联邦平台和一些愿意参与联邦学习的候选客户端,联邦平台以一定的预算来招募客户端完成任务,参与联邦学习的客户端可以向联邦平台提交联邦学习任务。对于一个给定的联邦学习任务,存在一组N个客户端
Figure BDA0003038069850000053
愿意以{b1,b2,…bn}的价格参与其中,每个客户端Ci维护着一组与该联邦学习任务相关的私有本地数据样本
Figure BDA0003038069850000051
然而,一些客户端的训练样本可能会被错误标注,这在现实中很常见但是会显著恶化联邦学习的性能。因此,为了达到令人满意的学习性能,联邦平台需要在给定的联邦学习任务预算B内从客户端集合
Figure BDA0003038069850000052
中选择一组最优的客户端。被选中的客户端将使用它们的私有数据样本协同训练联邦学习模型,然后获得它们声明的报酬。
参见图2,本发明的基于深度强化学习的联邦学习客户端智能选取方法,包括以下步骤:联邦平台通过从联邦服务市场环境中收集客户端的状态作为输入,输入到基于策略网络的客户端选择智能体中,输出客户端选择方案;联邦平台根据当前环境状况以及客户端选择方案从多个候选客户端中选取一组最优的客户端以协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给客户端选择智能体,以奖励用于优化更新策略网络;策略网络通过强化学习方法离线训练得到。
本发明实施例的客户端选择方案使用神经网络即策略网络作为客户端选择智能体,它将客户端的状态作为输入,并输出客户端选择动作。客户端状态包括对联邦学习性能产生重要影响的数据大小、数据质量和价格,而动作则决定选择哪些客户端参与联邦学习任务模型的训练过程。为了使得客户端选择方案能够适应客户端数量的动态变化并减少动作搜索空间,本发明设计了一个基于编码器-解码器结构的策略网络,编码器将客户端状态映射为中间向量表示,然后解码器根据这些中间向量表示生成客户端选择方案。本发明使用强化学习的方法离线训练策略网络,首先从联邦服务市场环境中收集客户端的状态,然后智能体根据当前环境状况做出客户端选择动作。之后,被选中的客户端协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给智能体,奖励用于更新策略网络,逐步完善客户端选择策略。
本发明实施例的强化学习模型。客户端选择问题的强化学习建模,包括状态、动作、奖励和策略。
1)状态(state):状态s={x1,x2,…,xn}包含给定联邦学习任务所有候选客户端的特征,每个客户端Ci的特征xi是一个三维向量,用xi={qi,di,bi}表示,其中qi和di分别是客户端Ci的数据质量和用于训练的样本数量,bi是客户端Ci完成该学习任务的价格(即应得的报酬)。由于出于隐私考虑,无法访问到每个客户端的原始数据,因此无法直接获取每个客户端的训练数据样本的质量(即训练数据中标签错误的样本所占的比例)。本发明使用每个客户端Ci贡献的本地模型的测试精度来代表数据质量qi作为客户端Ci的特征之一,这样便可以在不破坏数据隐私的情况下捕获客户端的数据质量特征。
2)动作(action):为了降低动作空间,本发明采用顺序动作,即客户端选择代理通过采取一系列的动作一一做出客户端选择决策。一个单独的动作只从一组最多N个候选客户端中选出一个客户端,通过这样的动作序列,动作空间可以减少到O(N)。
3)奖励(reward):客户端选择策略的目标是最小化模型训练的损失函数。因此,本发明将执行客户端选择操作后(即在所选的客户端使用其本地数据样本训练全局模型后)从联邦服务市场观察到的奖励r设置为训练后损失函数值的减少率,即:
Figure BDA0003038069850000061
其中,F(w)是学习任务测试数据集上的初始全局损失函数值,F(w*)是经过选定客户端的多轮协同训练后达到的测试损失函数值。
4)将客户端选择的一个可行动作a={a1,…,ai,…}定义为候选客户端的一个子集,其中ai∈{C1,C2,…,Cn}且
Figure BDA0003038069850000062
AUCTION的策略网络定义了一个随机的客户端选择策略π(a|s,B)用于在给定状态s和学习预算B的情况下选择一个可行动作a。
本发明实施例的客户端选择策略网络。如图3所示,AUCTION的策略网络是一个基于注意力(attention)机制的深度神经网络模型,由编码器(Encoder)网络和解码器(Decoder)网络组成。具体如下:
1)编码器:
在编码器网络中,客户端嵌入层(Client Embedding Layer)首先通过线性投影把三维输入特征xi转化为初始的dh维嵌入向量
Figure BDA0003038069850000063
其中Wx和bx为可学习参数。然后,嵌入向量会经过L个注意力(attention)层更新,其中,每一个注意力层l∈{1,2,…,L}输出嵌入向量
Figure BDA0003038069850000064
参照Transformer的编码器结构,每个注意力层包含一个multi-headattention(MHA,多头注意力)层和一个fully connected feed-forward(FF,全连接前馈)层,每层后面都添加了一个跳跃连接(skip-connection)和批归一化(BN,batchnormalization):
Figure BDA0003038069850000065
Figure BDA0003038069850000066
multi-head attention(MHA)层由M个并行运行的注意力头组成,每个客户端Ci的MHA值根据每个头的输出
Figure BDA0003038069850000067
计算得到:
Figure BDA0003038069850000068
其中
Figure BDA0003038069850000069
是一个可学习的参数矩阵。给定一个客户端嵌入向量hi
Figure BDA00030380698500000610
的值由自注意力机制计算得到:
Figure BDA0003038069850000071
Figure BDA0003038069850000072
Figure BDA0003038069850000073
其中
Figure BDA0003038069850000074
Figure BDA0003038069850000075
为可学习的参数矩阵,每个客户端的查询(query)qi、键(key)ki和值(value)vi是通过映射相同的嵌入向量hi来计算的dk是query/key向量的维度。
FF层的值由两个线性变换与ReLu激活函数计算得到:
Figure BDA0003038069850000076
2)解码器:
基于编码器输出的嵌入向量和解码器在时间t′<t时间输出的客户端选择结果,解码器在每个时间点t输出一个选中的客户端at直到学习预算用尽。解码器网络包含一个multi-head attention(多头注意力)层和一个single-head attention(单头注意力)层。Multi-head attention层的值d(0)由注意力机制计算得到。具体来说,以编码器的输出,即最终编码器输出的客户端嵌入向量
Figure BDA0003038069850000077
作为输入,解码器首先计算一个聚合嵌入向量
Figure BDA0003038069850000078
为了提高效率,我们只从聚合嵌入向量
Figure BDA0003038069850000079
中计算每个注意力头(head)的单个查询qs而从客户端嵌入向量
Figure BDA00030380698500000710
Figure BDA00030380698500000711
其中
Figure BDA00030380698500000712
Figure BDA00030380698500000713
为可学习的参数矩阵。为了确保选中的客户端不重复并且不超过学习预算,本发明在时间t为每个客户端Ci定义了一个注意力mask(标志)
Figure BDA00030380698500000714
让at-1=(a1,a2,…,at-1)代表在时间点t-1已经被选中的客户端,Bt-1代表剩余的学习预算,即
Figure BDA00030380698500000715
Figure BDA00030380698500000716
定义:
Figure BDA00030380698500000717
然后计算权重asj并且标志在时间点t不能被访问的客户端:
Figure BDA00030380698500000718
最后,multi-head attention值d(0)可基于每个head(头)的输出
Figure BDA0003038069850000081
利用公式(3)计算得到:
Figure BDA0003038069850000082
为了计算在时间点t选择客户端Ci的概率
Figure BDA0003038069850000083
multi-head attention层之后有一个single-head attention层。查询q和键ki分别由multi-head attention值d(0)和客户端嵌入向量
Figure BDA0003038069850000084
计算得到:
Figure BDA0003038069850000085
其中
Figure BDA0003038069850000086
Figure BDA0003038069850000087
为可学习的参数矩阵。然后为每个客户端计算权重ai并使用tanh运算将结果限制在范围[-C,C]内:
Figure BDA0003038069850000088
最后,在时间点t选择客户端Ci的概率
Figure BDA0003038069850000089
可使用softmax运算计算得到:
Figure BDA00030380698500000810
例如:在客户端选择策略网络的编码器网络中,设置dh=128,并使用L=3个注意力层,每一个注意力层由一个M=8个注意力头的multi-head attention层和一个有着512维隐藏子层的全连接前馈层组成;在解码器网络中,设置C=10,并且解码器中的multi-head attention层也有M=8个注意力头。
本发明实施例的策略网络的训练:
策略网络的参数θ是编码器和解码器可学习参数的并集。策略网络训练的目标为优化随机策略πθ(a|s,B)的参数θ,该策略给定一组状态为s的输入客户端,并赋予联邦学习性能高的客户端选择策略(即具有高回报的策略)以高的概率。为此,使用策略梯度法来优化策略网络的参数:对于给定的学习任务,客户端选择代理首先观察联邦服务市场环境的状态s,包括每个候选客户端Ci的特征xi={qi,di,bi},对于数据质量特征qi,每个候选客户端将使用一小部分相同数量的本地数据样本来训练全局模型,并将得到的模型更新上传到AUCTION,然后AUCTION使用测试数据集评估每个模型的精度作为数据质量特征。在模型训练过程中,可以获得数据大小特征di,而每个客户端会向平台上报价格特征bi。然后,客户端选择代理根据策略选择一个动作a,联邦服务市场执行动作a的过程即选择a中的客户端参与模型训练。具体来说,在每一轮中,每个被选中的客户端Ci利用本地di个数据样本训练全局模型,并将模型更新提交到联邦平台进行聚合,这样,全局模型便可被迭代更新。随后,客户端选择代理评估全局模型的质量并得到奖励r。之后便可以根据(state,action,reward)经验更新策略网络。
训练策略网络的目标是最大化累计奖励:
Figure BDA0003038069850000091
其中r(a|s)是在状态s执行动作a后的奖励。本发明使用REINFORCE算法来优化J,并使用梯度下降来不断优化它的参数θ:
Figure BDA0003038069850000092
其中b(s)代表一个独立于a的基准函数用于加速训练过程。本发明定义b(s)为迄今为止训练得到的最好的模型输出的客户端选择动作,也就是说,通过贪婪地选择概率最大的动作来获得b(s)的值。这样一来,如果客户端选择动作a的奖励比贪婪选择的好,则r(a|s)-b(s)为正值,导致该动作被强化,所以策略模型就会被训练得不断进步。
具体训练算法为:首先,随机生成一个训练集合
Figure BDA0003038069850000093
其中每个样本
Figure BDA0003038069850000094
代表联邦服务市场的一种状态,si中候选客户端的特征从均匀分布随机生成。训练集合
Figure BDA0003038069850000095
以及训练迭代次数E、批次大小Bs和学习预算B作为输入,算法在E次迭代后输出策略网络的更新参数θ。在每次迭代中,算法从集合
Figure BDA00030380698500000910
中抽取一批大小为Bs的样本,对于每个样本si,客户端选择代理首先从策略πθ(a|s,B)中得到一个可行的动作ai,然后贪婪地选择动作bi。之后,联邦服务市场分别执行动作ai和bi,并分别计算奖励r(ai|si)和
Figure BDA0003038069850000096
最后,算法计算梯度
Figure BDA0003038069850000097
Figure BDA0003038069850000098
并利用Adam优化器更新参数θ,
Figure BDA0003038069850000099
使用以下三个联邦学习任务来评估本发明提出的基于深度强化学习的联邦学习客户端智能选取方法及系统:1)MLP MNIST,使用Multi-layer Perceptron(MLP)模型训练MNIST数据集;2)LeNet-5 FMNIST,使用LeNet-5模型训练Fashion-MNIST(FMNIST)数据集;3)ResNet-18 CIFAR-10,使用ResNet-18模型训练CIFAR-10数据集。并将本发明提出的联邦学习客户端选择方案与以下三个可行的客户端选择方案比较:1)Greedy(贪心算法),基于数据大小与数据质量的乘机除以价格的值,即di·qi/bi贪心地选择数据规模大、数据质量高、价格低的客户端;2)Random(随机算法),在学习预算B内随机地选择客户端;3)Pricefirst(价格优先算法),优先选择学习价格低的客户端,以在有限的学习预算内选择尽可能多的客户端。
本发明提出的客户端选择方案先为每个学习任务固定一个客户端规模,然后离线训练一个客户端选择智能体,再利用训练后的智能体对每个具有可变规模候选客户端的学习任务进行在线客户端选择决策。图4展示了客户端选择智能体在3个联邦学习任务上的训练过程,其中候选客户端的数量为20,每个学习任务的预算为10。平均奖励为一个小批量(minibatch)内的平均奖励,具体而言,对于MLP MNIST任务和LeNet-5 FMNIST任务,奖励为5轮次联邦训练后在各自的测试数据集上测得的损失函数减少量,而对于ResNet-18CIFAR-10任务,奖励为20轮次联邦训练后在CIFAR-10测试数据集上测得的损失函数减少量。从图4中可以看到客户端选择智能体对每个学习任务的训练奖励在经过几十个小批量(minibatches)的训练后,可以很快收敛到一个稳定的较高值,这说明客户端选择智能体可以有效学习到如何做出最优的客户端选择策略。
模拟一个联邦服务市场,每个学习任务有10个候选客户端,其中一半客户端的训练数据样本标注错误,错误率从(0,1)随机生成。图5展示了在学习预算为10的情况下,采用不同的客户端选择策略,每个学习任务的损失函数值减少量。可以看到,对于3个学习任务,AUCTION的性能显著优于其他客户端选择方案。此外,可以发现,由于Greedy机制在客户选择过程中同时考虑了数据大小、数据质量和价格,因此,Greedy机制的性能优于Random和Price first机制。这说明数据大小、数据质量和价格对学习性能都是至关重要的,AUCTION可以在这三者之间做出更好的权衡,从而获得更优异的性能。
为了证明AUCTION的鲁棒性,接下来评估其在联邦服务市场上面对大规模候选客户端时的性能。图6展示了有50个候选客户端时每个学习任务的性能。同样,有50%的客户端拥有错误的训练数据样本,每个学习任务的预算为10.可以看到,AUCTION在大规模客户端场景下仍然表现良好,明显优于其他客户端选择方案,因此也证明了AUCTION在不同的联邦服务市场环境下对每个学习任务的高效性和鲁棒性。
图7展示了LeNet-5 FMNIST任务在不同的学习预算下的性能,其中有50个候选客户端。从实验结果中可以得出两点结论。首先,AUCTION在所有情况下的性能都优于其他的客户端选择方案。其次,当增加学习预算时,AUCTION与其他客户端选择方案,尤其是Greedy机制的性能差距变得越来越显著。这是因为随着学习预算的增加,AUCTION的性能先是增加,然后稳定在一个相对较高的值,而Greedy的性能先是增加,但随后开始下降。究其原因,当数据样本标注错误的客户端数量固定时,随着学习预算的增加,Greedy可以选择更多的客户端来完成学习任务,但选择低质量训练数据的客户端的可能性也会增加。结果,所选取的具有低质量训练数据样本的客户端对学习性能产生了负面影响,导致损失函数减少量的下降。然而,AUCTION仍然可以自适应学习预算的变化,保持相对稳定的性能。
图8展示了LeNet-5 FMNIST学习任务通过不同的客户端选择模型选择的参与客户端经过30轮联邦训练后得到的损失函数值减少量。具体来说,先分别使用10-50个候选客户端离线训练AUCTION客户端选择智能体,然后再分别使用10-50个不同数量的候选客户端在线评估其客户端选择性能,其中学习预算设置为10。可以看到,当面对不同数量的在线候选客户端时,训练后的AUCTION模型也能有很好的表现。这意味着AUCTION对于不同数量的客户端有很好的可扩展性,这在真实的联邦服务市场中更加实用。
综上所述,本发明提出的基于深度强化学习的联邦学习客户端智能选取方法及系统,可处理分布式客户端低质量数据问题,以显著提高联邦学习质量。并且可以自动学习高质量的联邦学习客户端选择方案,相比其他的客户端选择方案有更好的性能。本发明设计的基于编码器-解码器结构的客户端选择策略网络,可自适应客户端数量的动态变化,实用性强。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (7)

1.一种基于深度强化学习的联邦学习客户端智能选取方法,应用于联邦服务市场框架,所述联邦服务市场框架包括一个以一定的预算招募客户端完成联邦学习任务的联邦平台和多个愿意参与联邦学习并向联邦平台提交联邦学习任务的候选客户端;其特征在于,包括以下步骤:
联邦平台通过从联邦服务市场环境中收集客户端的状态作为输入,输入到基于策略网络的客户端选择智能体中,输出客户端选择方案;联邦平台根据当前环境状况以及所述客户端选择方案从所述多个候选客户端中选取一组最优的客户端以协同训练联邦学习模型,并将联邦学习性能作为奖励反馈给所述客户端选择智能体,以奖励用于优化更新策略网络;所述策略网络通过强化学习方法离线训练得到。
2.根据权利要求1所述的基于深度强化学习的联邦学习客户端智能选取方法,其特征在于,所述客户端选择智能体,为基于编码器-解码器结构的策略网络,编码器将客户端状态映射为中间向量表示,解码器根据所述中间向量表示生成客户端选择方案;所述客户端状态包括数据大小、数据质量和价格。
3.根据权利要求2所述的基于深度强化学习的联邦学习客户端智能选取方法,其特征在于,所述策略网络的强化学习模型,包括状态、动作、奖励和策略:
状态:状态s={x1,x2,…,xn}包含给定联邦学习任务所有候选客户端的特征,每个客户端Ci的特征xi是一个三维向量,用xi={qi,di,bi}表示,其中qi和di分别是客户端Ci的数据质量和用于训练的样本数量,bi是客户端Ci完成该学习任务的价格;
动作:采用顺序动作,即客户端选择代理通过采取一系列的动作一一做出客户端选择决策;一个单独的动作只从一组最多N个候选客户端中选出一个客户端;
奖励:将执行客户端选择操作后从联邦服务市场观察到的奖励r作为训练后损失函数值的减少率,即:
Figure FDA0003038069840000011
其中,F(w)是学习任务测试数据集上的初始全局损失函数值,F(w*)是经过选定客户端的多轮协同训练后达到的测试损失函数值;
策略:将客户端选择的一个可行动作a={a1,…,ai,…}定义为候选客户端的一个子集,其中ai∈{C1,C2,…,Cn}且
Figure FDA0003038069840000012
策略网络为一个随机的客户端选择策略π(a|s,B)用于在给定状态s和学习预算B的情况下选择一个可行动作a;训练策略网络的目标是最大化累计奖励。
4.根据权利要求3所述的基于深度强化学习的联邦学习客户端智能选取方法,其特征在于,所述最大化累计奖励,表示为:
Figure FDA0003038069840000021
其中r(a|s)是在状态s执行动作a后的奖励;
使用REINFORCE算法来优化J,使用梯度下降来不断优化参数θ:
Figure FDA0003038069840000022
其中b(s)代表一个独立于a的基准函数用于加速训练过程;参数θ是编码器和解码器可学习参数的并集。
5.根据权利要求3所述的基于深度强化学习的联邦学习客户端智能选取方法,其特征在于,所述编码器包括:
客户端嵌入层首先通过线性投影把三维输入特征xi转化为初始的dh维嵌入向量
Figure FDA0003038069840000023
其中Wx和bx为可学习参数;
然后,嵌入向量会经过L个注意力层更新,其中,每一个注意力层l∈{1,2,…,L}输出嵌入向量
Figure FDA0003038069840000024
每个注意力层包含一个MHA层和一个FF层,每层后面都添加了一个跳跃连接和批归一化。
6.根据权利要求5所述的基于深度强化学习的联邦学习客户端智能选取方法,其特征在于,所述解码器包括:
基于编码器输出的嵌入向量和解码器在时间t′<t时间输出的客户端选择结果,解码器在每个时间点t输出一个选中的客户端at直到学习预算用尽;解码器的网络包含一个多头注意力层和一个单头注意力层。
7.一种计算机系统,包括存储器、处理器以及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现上述权利要求1至6中任一所述方法的步骤。
CN202110449033.4A 2021-04-25 2021-04-25 基于深度强化学习的联邦学习客户端智能选取方法及系统 Active CN113191484B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110449033.4A CN113191484B (zh) 2021-04-25 2021-04-25 基于深度强化学习的联邦学习客户端智能选取方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110449033.4A CN113191484B (zh) 2021-04-25 2021-04-25 基于深度强化学习的联邦学习客户端智能选取方法及系统

Publications (2)

Publication Number Publication Date
CN113191484A true CN113191484A (zh) 2021-07-30
CN113191484B CN113191484B (zh) 2022-10-14

Family

ID=76978829

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110449033.4A Active CN113191484B (zh) 2021-04-25 2021-04-25 基于深度强化学习的联邦学习客户端智能选取方法及系统

Country Status (1)

Country Link
CN (1) CN113191484B (zh)

Cited By (18)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113673696A (zh) * 2021-08-20 2021-11-19 山东鲁软数字科技有限公司 一种基于强化联邦学习的电力行业起重作业违章检测方法
CN114124784A (zh) * 2022-01-27 2022-03-01 军事科学院系统工程研究院网络信息研究所 一种基于垂直联邦的智能路由决策保护方法和系统
CN114153640A (zh) * 2021-11-26 2022-03-08 哈尔滨工程大学 一种基于深度强化学习的系统容错策略方法
CN114385376A (zh) * 2021-12-09 2022-04-22 北京理工大学 一种异构数据下边缘侧联邦学习的客户端选择方法
CN114492845A (zh) * 2022-04-01 2022-05-13 中国科学技术大学 资源受限条件下提高强化学习探索效率的方法
CN114554459A (zh) * 2022-01-19 2022-05-27 苏州大学 一种近端策略优化辅助的车联网联邦学习客户端选择方法
CN114595396A (zh) * 2022-05-07 2022-06-07 浙江大学 一种基于联邦学习的序列推荐方法和系统
CN114598667A (zh) * 2022-03-04 2022-06-07 重庆邮电大学 一种基于联邦学习的高效设备选择与资源分配方法
CN114627648A (zh) * 2022-03-16 2022-06-14 中山大学·深圳 一种基于联邦学习的城市交通流诱导方法及系统
CN115018086A (zh) * 2022-06-08 2022-09-06 河海大学 一种基于联邦学习的模型训练方法及联邦学习系统
CN115130683A (zh) * 2022-07-18 2022-09-30 山东大学 一种基于多代理模型的异步联邦学习方法及系统
WO2023036184A1 (en) * 2021-09-08 2023-03-16 Huawei Cloud Computing Technologies Co., Ltd. Methods and systems for quantifying client contribution in federated learning
CN115829028A (zh) * 2023-02-14 2023-03-21 电子科技大学 一种多模态联邦学习任务处理方法及系统
WO2023109827A1 (zh) * 2021-12-15 2023-06-22 维沃移动通信有限公司 客户端筛选方法及装置、客户端及中心设备
WO2023185788A1 (zh) * 2022-03-28 2023-10-05 维沃移动通信有限公司 候选成员的确定方法、装置及设备
CN117557870A (zh) * 2024-01-08 2024-02-13 之江实验室 基于联邦学习客户端选择的分类模型训练方法及系统
WO2024032031A1 (zh) * 2022-08-09 2024-02-15 华为技术有限公司 一种数据分析方法及装置
CN114153640B (zh) * 2021-11-26 2024-05-31 哈尔滨工程大学 一种基于深度强化学习的系统容错策略方法

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2018212918A1 (en) * 2017-05-18 2018-11-22 Microsoft Technology Licensing, Llc Hybrid reward architecture for reinforcement learning
US20180357552A1 (en) * 2016-01-27 2018-12-13 Bonsai AI, Inc. Artificial Intelligence Engine Having Various Algorithms to Build Different Concepts Contained Within a Same AI Model
US20200244707A1 (en) * 2019-01-24 2020-07-30 Deepmind Technologies Limited Multi-agent reinforcement learning with matchmaking policies
CN112348204A (zh) * 2020-11-05 2021-02-09 大连理工大学 一种基于联邦学习和区块链技术的边缘计算框架下海洋物联网数据安全共享方法
CN112465151A (zh) * 2020-12-17 2021-03-09 电子科技大学长三角研究院(衢州) 一种基于深度强化学习的多智能体联邦协作方法
CN112668128A (zh) * 2020-12-21 2021-04-16 国网辽宁省电力有限公司物资分公司 联邦学习系统中终端设备节点的选择方法及装置

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180357552A1 (en) * 2016-01-27 2018-12-13 Bonsai AI, Inc. Artificial Intelligence Engine Having Various Algorithms to Build Different Concepts Contained Within a Same AI Model
WO2018212918A1 (en) * 2017-05-18 2018-11-22 Microsoft Technology Licensing, Llc Hybrid reward architecture for reinforcement learning
US20200244707A1 (en) * 2019-01-24 2020-07-30 Deepmind Technologies Limited Multi-agent reinforcement learning with matchmaking policies
CN112348204A (zh) * 2020-11-05 2021-02-09 大连理工大学 一种基于联邦学习和区块链技术的边缘计算框架下海洋物联网数据安全共享方法
CN112465151A (zh) * 2020-12-17 2021-03-09 电子科技大学长三角研究院(衢州) 一种基于深度强化学习的多智能体联邦协作方法
CN112668128A (zh) * 2020-12-21 2021-04-16 国网辽宁省电力有限公司物资分公司 联邦学习系统中终端设备节点的选择方法及装置

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
IHAB MOHAMMED等: "Budgeted Online Selection of Candidate IoT Clients to Participate in Federated Learning", 《IEEE》 *
TAKAYUKI NISHIO等: "Client Selection for Federated Learning with Heterogeneous Resources in Mobile Edge", 《IEEE》 *
YUWEI WANG等: "A Novel Reputation-aware Client Selection Scheme for Federated Learning within Mobile Environments", 《IEEE》 *

Cited By (25)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113673696A (zh) * 2021-08-20 2021-11-19 山东鲁软数字科技有限公司 一种基于强化联邦学习的电力行业起重作业违章检测方法
CN113673696B (zh) * 2021-08-20 2024-03-22 山东鲁软数字科技有限公司 一种基于强化联邦学习的电力行业起重作业违章检测方法
WO2023036184A1 (en) * 2021-09-08 2023-03-16 Huawei Cloud Computing Technologies Co., Ltd. Methods and systems for quantifying client contribution in federated learning
CN114153640B (zh) * 2021-11-26 2024-05-31 哈尔滨工程大学 一种基于深度强化学习的系统容错策略方法
CN114153640A (zh) * 2021-11-26 2022-03-08 哈尔滨工程大学 一种基于深度强化学习的系统容错策略方法
CN114385376A (zh) * 2021-12-09 2022-04-22 北京理工大学 一种异构数据下边缘侧联邦学习的客户端选择方法
CN114385376B (zh) * 2021-12-09 2024-05-31 北京理工大学 一种异构数据下边缘侧联邦学习的客户端选择方法
WO2023109827A1 (zh) * 2021-12-15 2023-06-22 维沃移动通信有限公司 客户端筛选方法及装置、客户端及中心设备
CN114554459A (zh) * 2022-01-19 2022-05-27 苏州大学 一种近端策略优化辅助的车联网联邦学习客户端选择方法
CN114124784B (zh) * 2022-01-27 2022-04-12 军事科学院系统工程研究院网络信息研究所 一种基于垂直联邦的智能路由决策保护方法和系统
CN114124784A (zh) * 2022-01-27 2022-03-01 军事科学院系统工程研究院网络信息研究所 一种基于垂直联邦的智能路由决策保护方法和系统
CN114598667A (zh) * 2022-03-04 2022-06-07 重庆邮电大学 一种基于联邦学习的高效设备选择与资源分配方法
CN114627648A (zh) * 2022-03-16 2022-06-14 中山大学·深圳 一种基于联邦学习的城市交通流诱导方法及系统
WO2023185788A1 (zh) * 2022-03-28 2023-10-05 维沃移动通信有限公司 候选成员的确定方法、装置及设备
CN114492845A (zh) * 2022-04-01 2022-05-13 中国科学技术大学 资源受限条件下提高强化学习探索效率的方法
CN114492845B (zh) * 2022-04-01 2022-07-15 中国科学技术大学 资源受限条件下提高强化学习探索效率的方法
CN114595396A (zh) * 2022-05-07 2022-06-07 浙江大学 一种基于联邦学习的序列推荐方法和系统
CN115018086A (zh) * 2022-06-08 2022-09-06 河海大学 一种基于联邦学习的模型训练方法及联邦学习系统
CN115018086B (zh) * 2022-06-08 2024-05-03 河海大学 一种基于联邦学习的模型训练方法及联邦学习系统
CN115130683A (zh) * 2022-07-18 2022-09-30 山东大学 一种基于多代理模型的异步联邦学习方法及系统
WO2024032031A1 (zh) * 2022-08-09 2024-02-15 华为技术有限公司 一种数据分析方法及装置
CN115829028B (zh) * 2023-02-14 2023-04-18 电子科技大学 一种多模态联邦学习任务处理方法及系统
CN115829028A (zh) * 2023-02-14 2023-03-21 电子科技大学 一种多模态联邦学习任务处理方法及系统
CN117557870B (zh) * 2024-01-08 2024-04-23 之江实验室 基于联邦学习客户端选择的分类模型训练方法及系统
CN117557870A (zh) * 2024-01-08 2024-02-13 之江实验室 基于联邦学习客户端选择的分类模型训练方法及系统

Also Published As

Publication number Publication date
CN113191484B (zh) 2022-10-14

Similar Documents

Publication Publication Date Title
CN113191484B (zh) 基于深度强化学习的联邦学习客户端智能选取方法及系统
CN114297722B (zh) 一种基于区块链的隐私保护异步联邦共享方法及系统
Du et al. Beyond deep reinforcement learning: A tutorial on generative diffusion models in network optimization
CN113222179A (zh) 一种基于模型稀疏化与权重量化的联邦学习模型压缩方法
CN114595396B (zh) 一种基于联邦学习的序列推荐方法和系统
CN116523079A (zh) 一种基于强化学习联邦学习优化方法及系统
CN116471286A (zh) 基于区块链及联邦学习的物联网数据共享方法
CN115271099A (zh) 一种支持异构模型的自适应个性化联邦学习方法
CN113781002B (zh) 云边协同网络中基于代理模型和多种群优化的低成本工作流应用迁移方法
Shan et al. An end-to-end deep RL framework for task arrangement in crowdsourcing platforms
CN116108919A (zh) 一种基于相似特征协作的个性化联邦学习方法和系统
Chen et al. Generative adversarial reward learning for generalized behavior tendency inference
Xiao et al. Clustered federated multi-task learning with non-iid data
Chen et al. Profit-Aware Cooperative Offloading in UAV-Enabled MEC Systems Using Lightweight Deep Reinforcement Learning
CN115577797B (zh) 一种基于本地噪声感知的联邦学习优化方法及系统
Shen et al. An optimization approach for worker selection in crowdsourcing systems
CN113743012B (zh) 一种多用户场景下的云-边缘协同模式任务卸载优化方法
Mays et al. Decentralized data allocation via local benchmarking for parallelized mobile edge learning
CN117033997A (zh) 数据切分方法、装置、电子设备和介质
Zhang et al. Optimizing federated edge learning on non-IID data via neural architecture search
CN111027709B (zh) 信息推荐方法、装置、服务器及存储介质
Zeng et al. Enhanced federated learning with adaptive block-wise regularization and knowledge distillation
Ayyadurai et al. Cloud Computing Based Workload Optimization using Long Short Term Memory Algorithm
CN117674303B (zh) 一种基于数据价值阈值的虚拟电厂并行控制方法
Wang et al. Quality-oriented federated learning on the fly

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant