CN113379067A - 提升联邦学习在Non-IID和Mismatched场景下性能的方法 - Google Patents
提升联邦学习在Non-IID和Mismatched场景下性能的方法 Download PDFInfo
- Publication number
- CN113379067A CN113379067A CN202110717717.8A CN202110717717A CN113379067A CN 113379067 A CN113379067 A CN 113379067A CN 202110717717 A CN202110717717 A CN 202110717717A CN 113379067 A CN113379067 A CN 113379067A
- Authority
- CN
- China
- Prior art keywords
- model
- client
- iid
- local
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 49
- 238000012549 training Methods 0.000 claims abstract description 30
- 238000012545 processing Methods 0.000 claims description 6
- 230000003042 antagnostic effect Effects 0.000 claims description 2
- 238000004891 communication Methods 0.000 claims description 2
- 238000012216 screening Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000002360 preparation method Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- General Health & Medical Sciences (AREA)
- Physics & Mathematics (AREA)
- Bioethics (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computing Systems (AREA)
- Data Mining & Analysis (AREA)
- Databases & Information Systems (AREA)
- Computer Hardware Design (AREA)
- Computer Security & Cryptography (AREA)
- Electrically Operated Instructional Devices (AREA)
Abstract
本发明公开了一种提升联邦学习在Non‑IID和Mismatched场景下性能的方法,包括:步骤1,每个客户端在每轮训练中更新本地模型,客户端中设有生成对抗网络;步骤2,服务端接收到来自各客户端的本地模型后,聚合成一个中间态模型后传回至各客户端;步骤3,各客户端用接收到的中间态模型预测假数据并与其标签对比挑选出需要上传的假数据形成假数据子集合,上传至服务端;步骤4,服务端对接收自各客户端的假数据子集合进行处理,得到一个近似IID分布的假数据集,利用该假数据集对中间态模型进行训练得到该轮的最终全局模型。该方法在尽可能不影响正常训练的同时,提升了联邦学习在Non‑IID和Mismatched场景下的性能。
Description
技术领域
本发明涉及大数据处理领域,尤其涉及一种提升联邦学习在Non-IID和Mismatched场景下性能的方法。
背景技术
在大数据的时代背景下,人工智能和机器学习迅速发展。但考虑到实际情况,数据离散分布在众多公司企业内部,因数据隐私保护需求无法直接共享使用,联邦学习便应运而生。联邦学习保护隐私的特性一方面有利地扩展了机器学习的适用范围,另一方面也带来了诸多新的挑战。其中最为重要的一个挑战便是Non-IID场景。对于公司企业来说,无法要求他们所拥有的数据都是独立同分布的(IID),相反,他们的数据通常是Non-IID甚至是缺少某些类别的(Mismatched)。
FedAvg作为最经典的联邦学习框架,它在每一轮训练中将各个客户端的模型取平均得到全局模型。这种做法在IID场景中效果较好,然而,在Non-IID场景中,其表现开始逐渐变差且不稳定。针对存在的问题目前的方法主要分为以下两种解决方式:(1)改良的FedAvg;(2)个性化联邦学习。前者在原有的FedAvg基础上通过某些修改使其能适应Non-IID场景,后者试图在客户端的协作下为每个客户端生成一个更加优秀的个性化模型。目前的方法通常只利用到了模型参数而不涉及其他参数,虽然尽可能保护了隐私,但其性能也相对的明显下降。此外,现有方法鲜有提及如何改善Mismatched场景下联邦学习方法性能的问题。
发明内容
针对现有技术所存在的问题,本发明的目的是提供一种提升联邦学习在Non-IID和Mismatched场景下性能的方法,能解决现有联邦学习方法,由于只利用到了模型参数而不涉及其他参数,虽然尽可能保护了隐私,但存在性能明显下降以及没有改善Mismatched场景下联邦学习方法性能的问题。
本发明的目的是通过以下技术方案实现的:
本发明实施方式提供一种提升联邦学习在Non-IID和Mismatched场景下性能的方法,用于应用联邦学习处理数据的由多个客户端与至少一个服务端通信的网络系统中,包括:
步骤1,每个客户端在每轮训练中更新本地模型,所述客户端中设有预先训练好的本地的生成对抗网络;
步骤2,服务端接收到来自各客户端的本地模型后,将各客户端的本地模型聚合成一个中间态模型,将该中间态模型分别传回至各客户端;
步骤3,各客户端接收到所述中间态模型后,用该中间态模型预测假数据,并与所述假数据的标签对比挑选出需要上传的假数据形成假数据子集合,将所述假数据子集合上传至所述服务端;
步骤4,所述服务端对接收自各客户端的假数据子集合进行处理,得到一个近似IID分布的假数据集,利用该假数据集对所述中间态模型进行训练得到该轮的最终全局模型。
由上述本发明提供的技术方案可以看出,本发明实施例提供的提升联邦学习在Non-IID和Mismatched场景下性能的方法,其有益效果为:
通过在各客户端引入本地的生成对抗网络,各客户端可以产生少量假数据供服务端使用,而服务端使用这些假数据能获得更多有效信息从而使得学习效果大大提升;由于对假数据的严格限制和监管,客户端可以自主选择上传数据的数量、分布情况等,使得服务端难以从假数据中推断出真实信息,不影响客户端数据的应试保护。该方法在传统联邦学习的基础上解决了现实中经常遇到的Non-IID和Mismatched场景性能不佳的问题,在尽可能不影响正常训练的同时,提升了联邦学习在Non-IID和Mismatched场景下的性能。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
图1为本发明实施例提供的提升联邦学习在Non-IID和Mismatched场景下性能的方法流程图。
具体实施方式
下面结合本发明的具体内容,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。本发明实施例中未作详细描述的内容属于本领域专业技术人员公知的现有技术。
参见图1,本发明实施例提供一种提升联邦学习在Non-IID和Mismatched场景下性能的方法,其特征在于,用于应用联邦学习处理数据的由多个客户端与至少一个服务端通信的网络系统中,包括:
步骤1,每个客户端在每轮训练中更新本地模型,所述客户端中设有预先训练好的本地的生成对抗网络;
步骤2,服务端接收到来自各客户端的本地模型后,将各客户端的本地模型聚合成一个中间态模型,将该中间态模型分别传回至各客户端;
步骤3,各客户端接收到所述中间态模型后,用该中间态模型预测各客户端本地的生成对抗网络生成的假数据,并与所述假数据的标签对比挑选出需要上传的假数据形成假数据子集合,将所述假数据子集合上传至所述服务端;
步骤4,所述服务端对接收自各客户端的假数据子集合进行处理,得到一个近似IID分布的假数据集,利用该假数据集对所述中间态模型进行训练得到该轮的最终全局模型。
上述方法还包括:
步骤5,所述服务端将得到的最终全局模型发送至各客户端,由各客户端利用最终全局模型对本地数据进行处理。
上述方法,其特征在于,所述步骤1中,按以下方式各客户端维护一个本地的生成对抗网络,包括:
在每轮训练中,对生成对抗网络迭代一轮;
训练完成后使用该生成对抗网络生成为本地数据量5%的无标签假数据,通过本地模型对生成的无标签假数据进行标注。
上述方法的步骤2中,以FedAVG直接取加权平均方式将各客户端的本地模型聚合成一个中间态模型。
上述方法的步骤3中,需要上传的假数据为:在本地模型与中间态模型的预测结果分类不同的假数据。
上述方法中,客户端为能与服务端持续保持通信的智能移动终端、计算机中的任一种;
服务端为服务器。
本发明的方法通过在各客户端引入本地的生成对抗网络,各客户端可以产生少量假数据供服务端使用,而服务端使用这些假数据能获得更多有效信息从而使得学习效果大大提升;由于对假数据的严格限制和监管,客户端可以自主选择上传数据的数量、分布情况等,使得服务端难以从假数据中推断出真实信息,不影响客户端数据的应试保护。该方法在传统联邦学习的基础上解决了现实中经常遇到的Non-IID和Mismatched场景性能不佳的问题,在尽可能不影响正常训练的同时,提升了联邦学习在Non-IID和Mismatched场景下的性能。
下面对本发明实施例具体作进一步地详细描述。
参见图1,本发明实施例提供一种提升联邦学习在Non-IID和Mismatched场景下性能的方法,主要包括以下步骤:
步骤1,客户端在每轮训练中在更新本地模型的同时,维护一个本地的生成对抗网络;
上述步骤1中,生成对抗网络可以让客户端在开始联邦学习之前预先训练好以节省时间,在每轮训练中只需对该生成对抗网络迭代一轮以保证生成数据的多样性即可。训练完成后使用该生成对抗网络生成少量无标签假数据(生成的少量无标签假数据量为本地假数据量的5%),生成对抗网络本身是一个无监督训练过程,生成的假数据是无标签的,因此使用已训练好的本地模型对其标注;
步骤2,服务端接收到来自各个客户端的本地模型后先聚合起来,形成一个中间态模型,之后将该中间态模型分别再传回各客户端;
步骤3,各客户端接收到中间态模型后,用该中间态模型预测各客户端本地的生成对抗网络生成的假数据并与假数据标签对比,挑选出需要上传的假数据形成假数据子集合,将假数据子集合上传给服务端做进一步处理;
步骤4,服务端接收到来自各个客户端的假数据子集合,经过处理后得到一个近似IID分布的假数据集;将中间态模型在该假数据集上进一步训练得到该轮的最终全局模型。由于假数据集规模小,同时经过客户端的严格筛选,真实数据仍然保留于各个客户端本地,可以在一定程度上保证数据隐私的安全性。
本发明的方法解决了现实中联邦学习经常遇到的IID和Mismatched场景的问题,在尽可能不影响正常训练的同时,提升了联邦学习的性能。
实施例
本实施例提供的提升联邦学习在Non-IID和Mismatched场景下性能的方法,可用于语音识别或图像识别等场景中,包括:
步骤S1,前期准备:
参与联邦学习的各方客户端确认联邦学习任务并确定需要使用的本地数据集,客户端需要确认各自的数据隐私保护情况,制定严格的数据筛选机制,以判断哪些假数据可以上传用于服务端的训练,之后客户端各自训练好生成对抗网络后,等待联邦学习训练开始;
步骤S2,联邦学习训练:
基本训练步骤按照服务端接收到来自各个客户端的本地模型后先聚合起来,形成一个中间态模型,之后将该中间态模型分别再传回各客户端,由各客户端利用中间态模型对假数据进行预测筛选出符合各客户端各自指定的筛选规则的假数据形成假数据集合,上传至服务端用于训练中间态模型。本步骤在联邦学习过程中,服务端与各个客户端保持联系以确保客户端在线,客户端严格依照各自指定的筛选规则选择合适的假数据上传至服务端以训练中间态模型,训练结束标志以模型性能达到一定程度或以联邦学习训练管理员的判断为准;
步骤S3,后续处理:
服务端将最终的全局模型发回至所有的客户端,客户端验证该全局模型的性能并投入使用。
本领域普通技术人员可以理解:实现上述实施例方法中的全部或部分流程是可以通过程序来指令相关的硬件来完成,所述的程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,所述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)或随机存储记忆体(Random Access Memory,RAM)等。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。
Claims (6)
1.一种提升联邦学习在Non-IID和Mismatched场景下性能的方法,其特征在于,用于应用联邦学习处理数据的由多个客户端与至少一个服务端通信的网络系统中,包括:
步骤1,每个客户端在每轮训练中更新本地模型,所述客户端中设有预先训练好的本地的生成对抗网络;
步骤2,服务端接收到来自各客户端的本地模型后,将各客户端的本地模型聚合成一个中间态模型,将该中间态模型分别传回至各客户端;
步骤3,各客户端接收到所述中间态模型后,用该中间态模型预测各客户端本地的生成对抗网络生成的假数据,并与所述假数据的标签对比挑选出需要上传的假数据形成假数据子集合,将所述假数据子集合上传至所述服务端;
步骤4,所述服务端对接收自各客户端的假数据子集合进行处理,得到一个近似IID分布的假数据集,利用该假数据集对所述中间态模型进行训练得到该轮的最终全局模型。
2.根据权利要求1所述的提升联邦学习在Non-IID和Mismatched场景下性能的方法,其特征在于,还包括:
步骤5,所述服务端将得到的最终全局模型发送至各客户端,由各客户端利用最终全局模型对本地数据进行处理。
3.根据权利要求1或2所述的提升联邦学习在Non-IID和Mismatched场景下性能的方法,其特征在于,所述步骤1中,按以下方式各客户端维护一个本地的生成对抗网络,包括:
在每轮训练中,对生成对抗网络迭代一轮;
训练完成后使用该生成对抗网络生成占本地假数据量5%的无标签假数据,通过本地模型对生成的无标签假数据进行标注。
4.根据权利要求1或2所述的提升联邦学习在Non-IID和Mismatched场景下性能的方法,其特征在于,所述步骤2中,以FedAVG直接取加权平均方式将各客户端的本地模型聚合成一个中间态模型。
5.根据权利要求1或2所述的提升联邦学习在Non-IID和Mismatched场景下性能的方法,其特征在于,所述步骤3中,需要上传的假数据为:在本地模型与中间态模型的预测结果分类不同的假数据。
6.根据权利要求1或2所述的提升联邦学习在Non-IID和Mismatched场景下性能的方法,其特征在于,所述方法中,客户端为能与服务端持续保持通信的智能移动终端、计算机中的任一种;
服务端为服务器。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110717717.8A CN113379067A (zh) | 2021-06-28 | 2021-06-28 | 提升联邦学习在Non-IID和Mismatched场景下性能的方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110717717.8A CN113379067A (zh) | 2021-06-28 | 2021-06-28 | 提升联邦学习在Non-IID和Mismatched场景下性能的方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113379067A true CN113379067A (zh) | 2021-09-10 |
Family
ID=77579486
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110717717.8A Pending CN113379067A (zh) | 2021-06-28 | 2021-06-28 | 提升联邦学习在Non-IID和Mismatched场景下性能的方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113379067A (zh) |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109816044A (zh) * | 2019-02-11 | 2019-05-28 | 中南大学 | 一种基于wgan-gp和过采样的不平衡学习方法 |
CN112488147A (zh) * | 2020-11-02 | 2021-03-12 | 东北林业大学 | 一种基于对抗网络的冗余去除主动学习方法 |
-
2021
- 2021-06-28 CN CN202110717717.8A patent/CN113379067A/zh active Pending
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109816044A (zh) * | 2019-02-11 | 2019-05-28 | 中南大学 | 一种基于wgan-gp和过采样的不平衡学习方法 |
CN112488147A (zh) * | 2020-11-02 | 2021-03-12 | 东北林业大学 | 一种基于对抗网络的冗余去除主动学习方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112181666B (zh) | 一种基于边缘智能的设备评估和联邦学习重要性聚合方法 | |
CN111310932A (zh) | 横向联邦学习系统优化方法、装置、设备及可读存储介质 | |
EP3885966B1 (en) | Method and device for generating natural language description information | |
US20240135191A1 (en) | Method, apparatus, and system for generating neural network model, device, medium, and program product | |
US20220318412A1 (en) | Privacy-aware pruning in machine learning | |
CN113988314A (zh) | 一种选择客户端的分簇联邦学习方法及系统 | |
CN107087160A (zh) | 一种基于BP‑Adaboost神经网络的用户体验质量的预测方法 | |
CN114884832B (zh) | 端云协同系统、分布式处理集群及移动端设备 | |
CN114529765B (zh) | 一种数据处理方法、设备以及计算机可读存储介质 | |
CN112346936A (zh) | 应用故障根因定位方法及系统 | |
CN111105016A (zh) | 一种数据处理方法、装置、电子设备及可读存储介质 | |
US10972703B2 (en) | Method, device, and storage medium for processing webcam data | |
CN114358307A (zh) | 基于差分隐私法的联邦学习方法及装置 | |
CN115841133A (zh) | 一种联邦学习方法、装置、设备及存储介质 | |
CN115189908B (zh) | 一种基于网络数字孪生体的随机攻击生存性评估方法 | |
CN114936377A (zh) | 模型训练和身份匿名化方法、装置、设备及存储介质 | |
JP2023001926A (ja) | 画像融合方法及び装置、画像融合モデルのトレーニング方法及び装置、電子機器、記憶媒体、並びにコンピュータプログラム | |
CN114492854A (zh) | 训练模型的方法、装置、电子设备以及存储介质 | |
CN116797346A (zh) | 基于联邦学习的金融欺诈行为检测方法及系统 | |
CN114510615A (zh) | 一种基于图注意力池化网络的细粒度加密网站指纹分类方法和装置 | |
CN110826867B (zh) | 车辆管理方法、装置、计算机设备和存储介质 | |
CN113379067A (zh) | 提升联邦学习在Non-IID和Mismatched场景下性能的方法 | |
CN110175283B (zh) | 一种推荐模型的生成方法及装置 | |
Grassucci et al. | Enhancing Semantic Communication with Deep Generative Models: An Overview | |
CN115146292A (zh) | 一种树模型构建方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20210910 |