CN115422537A - 一种抵御联邦学习标签翻转攻击的方法 - Google Patents
一种抵御联邦学习标签翻转攻击的方法 Download PDFInfo
- Publication number
- CN115422537A CN115422537A CN202210486095.7A CN202210486095A CN115422537A CN 115422537 A CN115422537 A CN 115422537A CN 202210486095 A CN202210486095 A CN 202210486095A CN 115422537 A CN115422537 A CN 115422537A
- Authority
- CN
- China
- Prior art keywords
- client
- local
- model
- benign
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 31
- 230000007306 turnover Effects 0.000 title description 7
- 238000012549 training Methods 0.000 claims abstract description 100
- 230000002776 aggregation Effects 0.000 claims abstract description 21
- 238000004220 aggregation Methods 0.000 claims abstract description 21
- 238000012797 qualification Methods 0.000 claims abstract description 16
- 230000006698 induction Effects 0.000 claims abstract description 11
- 238000013507 mapping Methods 0.000 claims abstract description 8
- 230000006870 function Effects 0.000 claims description 40
- 238000001514 detection method Methods 0.000 claims description 26
- 238000009826 distribution Methods 0.000 claims description 17
- 238000000605 extraction Methods 0.000 claims description 14
- 238000012512 characterization method Methods 0.000 claims description 10
- 238000010801 machine learning Methods 0.000 claims description 6
- 230000004931 aggregating effect Effects 0.000 claims description 4
- 230000003190 augmentative effect Effects 0.000 claims description 3
- 230000001939 inductive effect Effects 0.000 claims description 3
- 238000002372 labelling Methods 0.000 claims 1
- 238000012360 testing method Methods 0.000 description 11
- 230000007123 defense Effects 0.000 description 8
- 238000005457 optimization Methods 0.000 description 6
- 230000000694 effects Effects 0.000 description 4
- 230000008569 process Effects 0.000 description 4
- 238000005070 sampling Methods 0.000 description 4
- 230000003416 augmentation Effects 0.000 description 3
- 238000004364 calculation method Methods 0.000 description 3
- 230000004913 activation Effects 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000006116 polymerization reaction Methods 0.000 description 2
- 238000011176 pooling Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 230000002547 anomalous effect Effects 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000033228 biological regulation Effects 0.000 description 1
- 238000004140 cleaning Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000003064 k means clustering Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 238000011895 specific detection Methods 0.000 description 1
- 231100000331 toxic Toxicity 0.000 description 1
- 230000002588 toxic effect Effects 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/50—Monitoring users, programs or devices to maintain the integrity of platforms, e.g. of processors, firmware or operating systems
- G06F21/55—Detecting local intrusion or implementing counter-measures
- G06F21/56—Computer malware detection or handling, e.g. anti-virus arrangements
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Software Systems (AREA)
- Theoretical Computer Science (AREA)
- Computer Security & Cryptography (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Computer Hardware Design (AREA)
- Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Computation (AREA)
- Data Mining & Analysis (AREA)
- Mathematical Physics (AREA)
- Medical Informatics (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Virology (AREA)
- Computer And Data Communications (AREA)
- Data Exchanges In Wide-Area Networks (AREA)
Abstract
本发明涉及联邦学习技术领域,提出一种抵御联邦学习标签翻转攻击的方法,包括以下步骤:服务端训练生成网络,该网络基于服务端本地的全局模型学习标签到潜在特征空间的映射关系;客户端基于全局模型和生成网络检测其数据质量;服务端根据客户端的数据质量,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格;良性客户端构建个性化模型,并以生成网络产生的潜在特征作为归纳偏置对本地训练进行调节;每个良性客户端完成本地训练后,分别将模型参数发送给服务端进行聚合,用于对服务端本地的全局模型进行更新;得到最终的全局模型,基于最终的全局模型抵御标签翻转攻击。
Description
技术领域
本发明涉及联邦学习技术领域,更具体地,涉及一种抵御联邦学习标签翻转攻击的方法。
背景技术
在人工智能领域,传统的数据处理模式往往是一方收集数据,再转移到另一方进行处理、清洗并建模,最后把模型卖给第三方。但随着法规的完善和监控的愈加严格,如果数据离开收集方或者用户不清楚模型的具体用途,运营者可能会触犯法律问题。
一种可能的解决方案为“联邦学习”。联邦学习无需集中的数据管理器来收集和验证数据集,而是允许数据保存在节点(客户端)上,并引入一个中央协调器来构建全局模型,该模型通过每个客户端基于本地数据的更新参数进行优化,从而实现数据隐私保护。
联邦学习中存在的一个较大问题是,由于客户端数据是不可见且无法验证的,恶意客户端可以通过篡改本地数据,并将由这些数据训练得到的更新参数发送给服务端进行全局模型优化,从而实现对全局模型的攻击并破坏其性能。一种简单、高效且常见的攻击策略为“标签翻转攻击”——攻击者通过篡改部分样本的标签来注入恶意数据。简单之处在于这一攻击是任何用户都可以实施的,即实施者无需提前了解整个联邦学习系统,如系统流程、模型类型和参数等。有效之处在于,即便只有约50个中毒样本,也可以显著提高模型的误分类率,甚至达到90%。因此,在联邦学习中,标签翻转攻击是一个亟待解决的问题。
目前,联邦学习中大部分对标签翻转攻击的防御策略或需要大量的计算代价和通信代价,如重训练并验证算法在每个样本上的性能;或需要一些难以获得的先验知识,如需要提前估计异常样本数量的异常检测方案。这些方法在某些场景下体现了其有效性,但具体实践过程中还缺乏足够的鲁棒性。
发明内容
本发明为克服上述现有技术所述现有技术中代价过高、难以实践和鲁棒性不足的缺陷,提供一种抵御联邦学习标签翻转攻击的方法。
为解决上述技术问题,本发明的技术方案如下:一种抵御联邦学习标签翻转攻击的方法,包括以下步骤:
S1:服务端训练生成网络,所述生成网络基于服务端本地的全局模型学习标签到潜在特征空间的映射关系;
S2:服务端将全局模型和生成网络广播给每个客户端,每个客户端基于全局模型和生成网络检测其数据质量;
S3:服务端根据每个客户端的数据质量,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格;
S4:每个良性客户端分别构建个性化模型,并以生成网络产生的潜在特征作为归纳偏置对本地训练进行调节;
S5:每个良性客户端完成本地训练后,分别将模型参数发送给服务端进行聚合,用于对服务端本地的全局模型进行更新;
S6:重复S1~S5,直至全局模型收敛或达到预设的停止条件,得到最终的全局模型,基于最终的全局模型抵御标签翻转攻击。
优选地,所述S1步骤中,在第一轮训练开始前,服务端初始化由θ:=[θf,θp]参数化的全局模型和由ω参数化的生成网络,其中θf为特征提取模块,θp为预测模块;服务端训练生成网络的步骤包括:
将训练样本输入所述生成网络,生成网络输出训练样本的潜在特征,并结合所述潜在特征和全局模型中的预测模块θp,通过目标函数训练生成网络。
优选地,目标函数的表达式如下:
其中,h(z;θp)是预测模块θp的输出,l(·)为非负的凸损失函数;R是一个产生随机标签序列的函数,表示由生成网络Gω输出的关于随机标签序列的潜在特征,范式Ex~D表示样本x采样于数据分布D,如式中表示样本采样于函数R产生的数据分布空间,J(·)为机器学习中的代价函数。
本技术方案中,给定随机标签序列目标函数只需全局模型的预测模块θp就可以在服务端进行优化,即生成网络的训练不会给各客户端带来额外的计算或时间开销;相较于联邦系统的训练模型,该生成网络是十分轻量的,尤其是特征空间紧凑的情况下。因此完成生成网络的训练无需过多的额外时间成本。轻量的特性也使其易于训练和下载。
优选地,S2步骤中,每个客户端基于全局模型和生成网络检测其数据质量的步骤包括:
S21:计算全局模型的预测模块θp对客户端i的本地数据标签y所对应的潜在特征z~Gω(·|y)的预测值,所述预测值为第一预测值;
S22:计算全局模型对客户端的本地数据样本x的预测值,所述预测值为第二预测值;
S23:计算所述第一预测值和第二预测值中相同元素个数,根据所述相同元素个数计算得到客户端i的质量参数DQi;
S24:客户端i完成数据质量检测后,将质量参数DQi返回给服务端。
优选地,S2步骤中,获取第i个客户端的质量参数DQi的表达式如下:
其中,Di为客户端i的本地数据,|Di|为其本地数据量,z~Gω(·|y)表示由生成网络输出的关于y的潜在特征,Acc(a,b)是计算序列a和b中相同元素个数的函数,accmaxh(z;θp)表示第一预测值,argmaxh(f(x;θ))表示第二预测值。
本技术方案中,当accmaxh(z;θp)≠argmaxh(f(x;θ))时,说明其中部分标签被篡改,所以在恶意客户端的数据中所检测得到的质量参数在数值上要小于良性客户端。
优选地,S3步骤中,将客户端分类,并取消每个恶意客户端参与本轮训练的资格的步骤包括:
服务端通过聚类算法根据客户端i的质量参数DQi将客户端分为两类;
分别计算每个类中质量参数DQ的平均值;
比较两个类中质量参数DQ的平均值,将平均值较低的类中的客户端作为恶意客户端集合,将平均值较高的集合作为良性客户端集合;
取消恶意节点集中每个恶意客户端参与本轮训练的资格。
本技术方案中,每个客户端完成数据质量检测后,都将质量参数DQi返回给服务端,服务端随后将根据这些质量参数DQ:={DQ1,DQ2,...,DQN}通过质量参数累加并除以该类中的客户端数量检测恶意节点;因此服务端可以在每轮训练之前有效、及时地识别恶意客户端,无需像其他检测策略那样对客户端的本地数据进行再训练,从而限制他们参与全局聚合,以此抵御攻击。
优选地,S4步骤中,良性客户端i构建个性化模型,优化本地模型的步骤包括:
S42:计算良性客户端i的个性化模型在其本地数据Di上的经验风险L(θi);
S43:通过在本地训练阶段对良性客户端i中参与训练的本地数据标签进行计数,获得良性客户端i本地数据标签的先验分布的经验近似值p(y);
S44:每个良性客户端i从生成网络Gω中获得潜在特征z~Gω(·|y)作为增广表征,为本地训练引入归纳偏置,并根据经验近似值p(y)和经验风险L(θi)对本地模型进行优化。
本技术方案中,该个性化策略基于这样一种现象:“数据之间通常有相似的全局表征,而客户端之间的统计异质性主要集中在标签上。因此,可以让每个客户端优化一个个性化的低维预测模块并对其本地样本特征产生唯一的标签”。通过这一个性化设计,我们可以通过整合每个良性客户端的真实知识恢复潜在特征空间,并让生成网络从中学习并提取准确的潜在特征。
优选地,S42步骤中,计算良性客户端i的个性化模型在其本地数据Di上的经验风险L(θi)的表达式如下:
其中,Di为良性客户端i的本地数据,|Di|为其本地数据量,y为良性客户端i的本地数据标签,x为客户端i的本地数据样本,为良性客户端i的基础层,为良性客户端i的个性化层,l(·)为非负的凸损失函数,h(·)为预测层,为模型的特征提取模块对输入样本x的输出。
其中,范式Ex~D表示样本x采样于数据分布D,如式中Ey~p(y)表示本地数据标签y采样于经验近似值p(y)产生的数据分布空间,J(·)为机器学习中的代价函数。
一种抵御联邦学习标签翻转攻击的系统,应用于上述的抵御联邦学习标签翻转攻击的方法,包括:
生成网络训练模块,用于在服务端训练生成网络,并基于服务端本地的全局模型学习标签到潜在特征空间的映射关系;
数据质量检测模块,用于在客户端基于全局模型和生成网络检测每个客户端的数据质量;
客户端分类模块,用于在服务端根据每个客户端的数据质量,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格;
客户端个性化模块,用于在客户端,每个良性客户端分别构建个性化模型,并以生成网络产生的潜在特征作为归纳偏置对本地训练进行调节;
全局模型聚合模块,用于在服务端,对每个良性客户端完成本地训练后返回的模型参数进行聚合,根据经过聚合的模型参数对全局模型进行更新。
与现有技术相比,本发明技术方案的有益效果是:本发明提供了一种新的防御策略,在无需先验知识以及不产生过多额外代价的前提下,鲁棒地实现对标签翻转攻击实施者的有效检测;本发明仅需由服务端向客户端额外传输一个轻量级生成网络,且该生成网络在服务端就能完成训练,因此能以很小的通信和计算代价识别恶意篡改样本标签的客户端,并加快模型收敛速度以及提高模型的预测精度。
附图说明
图1为实施例1的为本实施例的抵御标签翻转攻击的流程图;
图2为实施例2检测数据质量的流程图;
图3为实施例2的恶意客户端检测流程图;
图4为实施例2检测效果示意图;
图5为实施例2优化本地模型的流程图;
图6为实施例2在FedAvg中应用前后的测试曲线对比图。
图7为实施例3的抵御联邦学习标签翻转攻击的系统的架构图。
具体实施方式
附图仅用于示例性说明,不能理解为对本专利的限制;
为了更好说明本实施例,附图某些部件会有省略、放大或缩小,并不代表实际产品的尺寸;
对于本领域技术人员来说,附图中某些公知结构及其说明可能省略是可以理解的。
下面结合附图和实施例对本发明的技术方案做进一步的说明。
实施例1
本实施例提出的一种抵御联邦学习标签翻转攻击的方法中,包括以下步骤:
S1:服务端训练生成网络,所述生成网络基于服务端本地的全局模型学习标签到潜在特征空间的映射关系;
S2:服务端将全局模型和生成网络广播给每个客户端,每个客户端基于全局模型和生成网络检测其数据质量;
S3:服务端根据每个客户端的数据质量,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格;
S4:每个良性客户端分别构建个性化模型,并以生成网络产生的潜在特征作为归纳偏置对本地训练进行调节;
S5:每个良性客户端完成本地训练后,分别将模型参数发送给服务端进行聚合,用于对服务端本地的全局模型进行更新;
S6:重复S1~S5,直至全局模型收敛或达到预设的停止条件,得到最终的全局模型,基于最终的全局模型抵御标签翻转攻击。
本实施例提出的一种抵御联邦学习标签翻转攻击的方法,如图1所示,为本实施例的抵御标签翻转流程图。
在一可选实施例中,在S1步骤中,基于全局模型训练的一个轻量级的生成网络,所述生成网络的训练不会给各客户端带来额外的计算或时间开销。
在一可选实施例中,在S2步骤中,服务端将全局模型和生成网络广播给每个客户端后,每个客户端通过式(1)检测其质量参数DQi,并将质量参数DQi返回给服务端;
在一可选实施例中,S3步骤中,将客户端分类,并取消每个恶意客户端参与本轮训练的资格的步骤包括:
服务端通过聚类算法根据客户端i的质量参数DQi将客户端分为两类;
分别计算每个类中质量参数DQ的平均值;
比较两个类中质量参数DQ的平均值,将平均值较低的类中的客户端作为恶意客户端集合,将平均值较高的集合作为良性客户端集合;
取消恶意节点集中每个恶意客户端参与本轮训练的资格。
在一可选实施例中,在S4步骤中,每个客户端base+personalization的形式构建个性化模型,并以生成网络产生的增强表征作为归纳偏置对本地训练进行调节,优化本地模型并将本地模型返回给服务端。
在一可选实施例中,每个客户端都可以从生成网络中获得潜在的特征表征z~Gω(·|y)作为增广表征,为本地训练引入归纳偏置,增强其本地模型的泛化性能。因此,本地模型的优化目标是最大限度地提高对增强样本和本地数据进行正确预测的能力,如式(3)所示:
其中,是客户端i的个性化模型在其本地数据Di上的经验风险,Di为良性客户端i的本地数据,|Di|为其本地数据量,y为良性客户端i的本地数据标签,x为客户端i的本地数据样本,为良性客户端i的基础层,为良性客户端i的个性化层,l(·)为非负的凸损失函数,为模型的特征提取模块对输入样本x的输出,h(·)是预测模块θp的输出,范式Ey~p(y)表示样本y采样于p(y),J(·)为机器学习中的代价函数。
在一可选实施例中,在S5步骤中,服务端通过聚合算法对每个良性客户端完成本地训练后返回的模型参数进行聚合,聚合方式可以有多种选择,如经典的平均聚合FedAvg。由于本发明的防御策略对聚合策略无额外要求,因此能够优先考虑将其应用于先进的联邦学习框架中。聚合后的全局模型θ同样包含两个模块,即特征提取模块θf和预测模块θp。其中θp用于指导生成网络的训练,以恢复特征空间上的聚合分布,θf则作为构建局部个性化模型的共享组件。
在具体实施过程中,服务端通过目标函数训练一个轻量级的生成网络,并基于该生成网络和全局模型输出关于客户端的随机标签的潜在特征;服务端将全局模型和生成网络广播给每个客户端,客户端基于全局模型和生成网络检测每个客户端的数据质量,每个客户端完成数据质量检测后,将质量参数返回给服务端,服务端根据这些质量参数DQ:={DQ1,DQ2,...,DQN}检测恶意节点,服务端通过聚类算法根据客户端i的质量参数DQi将客户端分为两类,分别计算每个类中质量参数DQ的平均值,比较两个类中质量参数DQ的平均值,将平均值较低的类中的客户端作为恶意客户端集合,将平均值较高的集合作为良性客户端集合;并取消每个恶意客户端参与本轮训练的资格;每个良性客户端分别构建个性化模型,同时从生成网络中获得潜在的特征表征作为增广表征,为本地训练引入归纳偏置,对本地模型进行优化,并将优化后的本地模型返回给服务端;服务端中对每个良性客户端完成本地训练后返回的模型参数进行聚合,多次重复以上步骤,达到预设条件后停止,并得到最终的全局模型,基于最终的全局模型抵御标签翻转攻击。
实施例2
本实施例提出的一种抵御联邦学习标签翻转攻击的方法中,在数据集Fashion-MNIST上测试本发明的检测效果。本实施例以FedAvg为例,展示了该算法是否应用本发明的防御策略——“MCDFL”来抵御标签翻转攻击的性能差异。以下是一些具体的设置。
本实施例中,Fashion-MNIST是一个灰度图像数据集,涵盖了服装、衬衫、包等10个类别的70,000张不同的商品正面图像。数据集预先划分为60,000个训练图像和10,000个测试图像,分别平均分配给每个客户端用于训练和测试。本实施例使用了一个具有两个卷积层的卷积神经网络。该模型在集中式场景中的测试精度为91.87%。生成网络为带有一个隐藏层的多层感知机。
本实施例中,设置客户端数量为100,并分别在恶意客户端数量为[5,10,20,30,40]的五个场景中进行测试。全局训练次数(globalepoch)为200轮。每轮训练中,客户端的局部更新(localepoch)为25次,每次更新使用数据批(batch)大小为32。生成网络以onehot形式的标签向量作为输入,并输出一个维度为d的特征表征。每轮训练中,生成网络的更新次数(epoch)为20,每次更新使用数据批(batch)大小为32。
本实施例中,对于标签翻转攻击策略,我们表述为“源标签”→“目标标签”的形式,具体分为三种情况,包括:
(1)源标签在非中毒联邦学习中经常被误分类为目标标签的情况;
(2)源标签很少被误分类为目标标签的情况;
(3)这两个极端之间的情况。
具体来说,我们针对上述三种情况设定了(1)6:衬衫→0:T恤/上衣,(2)1:裤子→3:连衣裙,(3)4:外套→6:衬衫等三种情况的标签翻转攻击。
在具体实施过程中,本实施例提出的一种抵御联邦学习标签翻转攻击的方法,具体包括以下步骤:
S1:服务端训练生成网络,所述生成网络基于服务端本地的全局模型学习标签到潜在特征空间的映射关系;
S2:服务端将全局模型和生成网络广播给每个客户端,每个客户端基于全局模型和生成网络检测其数据质量;
S3:服务端根据每个客户端的数据质量,通过k-means聚类,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格;
S4:每个良性客户端以base+personalization的形式构建个性化模型,并以生成网络产生的潜在特征作为归纳偏置来调节本地训练;
S5:每个良性客户端完成本地训练后,分别将模型参数发送给服务端进行聚合,用于对服务端本地的全局模型进行更新;本实施例中聚合方式采用FedAvg,以便与不采用本发明的检测方案的FedAvg算法作性能比较;
在本实施例中,S1步骤中,在第一轮训练开始前,服务端初始化由θ:=[θf,θp]参数化的全局模型和由ω参数化的生成网络,其中θf为特征提取模块,θp为预测模块;本实施例的全局模型的初始化结构如表1所示,本实施例的生成网络参数结构如表2所示。
表1全局模型的初始化结构参数
表2生成网络参数结构
上述表格中,Conv2D为二维卷积层,ReLu为ReLu函数,BN为批量归一化层,MaxPooling为最大池化层,FC为全连接层。
在一可选实施例中,每次将生成网络发送给各客户端之前,服务端都先对生成网络进行优化,其目标函数如式(1)所示;
其中g(·)是预测层h的逻辑输出,σ(·)是应用于该逻辑输出的非线性激活函数,即h(z;θp)=σ(g(z;θp))。损失函数l为交叉熵损失函数;R是一个产生随机标签序列的函数,这些标签序列作为生成网络的训练样本,范式Ex~D表示样本x采样于数据分布D,如式中表示样本采样于函数R产生的数据分布空间。
在本实施例中,在步骤S2中,服务端将全局模型和生成网络广播给每个客户端,以检测数据质量,如图2所示,为本实施例检测数据质量的流程图,具体的检测流程如下:
S21:计算全局模型的预测模块θp对客户端i的本地数据标签y所对应的潜在特征z~Gω(·|y)的预测值,所述预测值为第一预测值;
S22:计算全局模型对客户端的本地数据样本x的预测值,所述预测值为第二预测值;
S23:计算所述第一预测值和第二预测值中相同元素个数,根据所述相同元素个数计算得到客户端i的质量参数DQi;
S24:客户端i完成数据质量检测后,将质量参数DQi返回给服务端。
在一可选实施例中,客户端i的数据质量DQi如式(2)所示;
其中,Di为客户端i的本地数据,|Di|为其本地数据量,y为客户端i的本地数据标签,z~Gω(·|y)表示由生成网络输出的关于y的潜在特征,Acc(a,b)是计算序列a和b中相同元素个数的函数,accmaxh(z;θp)表示第一预测值,argmaxh(f(x;θ))表示第二预测值。
在本实施例中,在S3步骤中,每个客户端完成数据质量检测后,都将质量参数DQi返回给服务端,服务端随后将根据这些质量参数DQ:={DQ1,DQ2,...,DQN}检测恶意节点。具体来说,包括以下步骤:
服务端通过k-means聚类根据客户端i的质量参数DQi将客户端分为两类;
分别计算每个类中质量参数DQ的平均值;
比较两个类中质量参数DQ的平均值,将平均值较低的类中的客户端作为恶意客户端集合,将平均值较高的集合作为良性客户端集合;
取消恶意节点集中每个恶意客户端参与本轮训练的资格。
如图3所示,为本实施例的恶意客户端检测流程图;如图4(a)~图4(e)所示,为本实施例的恶意客户端数量个数(5,10,20,30,40)的检测效果图,每个检测图像所对应的攻击情况都是随机的。
在本实施例中,在步骤S4中,如图5所示,为本实施优化本地模型的流程图;具体的,良性客户端i构建个性化模型,优化本地模型的步骤包括:
S42:计算良性客户端i的个性化模型在其本地数据Di上的经验风险L(θi);
S43:通过在本地训练阶段对良性客户端i中参与训练的本地数据标签进行计数,获得良性客户端i本地数据标签的先验分布的经验近似值p(y);
S44:每个良性客户端i从生成网络Gω中获得潜在特征z~Gω(·|y)作为增广表征,为本地训练引入归纳偏置,并根据经验近似值p(y)和经验风险L(θi)对本地模型进行优化。
在一可选实施例中,本地模型的优化目标是最大限度地提高对增强样本和本地数据进行正确预测的能力,基于生成网络和良性客户端i的个性化模型在其本地数据Di上的经验风险L(θi)优化本地模型的表达式如式(3)所示:
其中是客户端i的个性化模型在其本地数据Di上的经验风险,其中,Di为良性客户端i的本地数据,|Di|为其本地数据量,y为良性客户端i的本地数据标签,x为客户端i的本地数据样本,为良性客户端i的基础层,为良性客户端i的个性化层,l(·)为非负的凸损失函数,h(·)为预测层,范式Ex~D表示样本x采样于数据分布D,如式中Ey~p(y)表示本地数据标签y采样于p(y)产生的数据分布空间,为模型的特征提取模块对输入样本x的输出,J(·)为机器学习的代价函数。
在一可选实施例中,对于模型在增强样本和本地数据的损失函数,我们分别采用相对熵损失函数DKL和交叉熵损失函数H,如式(4)和(5)所示:
其中P(x)为真实概率分布,Q(x)为预测概率分布。
在一可选实施例中,在步骤S5中,我们采用与FedAvg相同的平均聚合算法,以便与未应用本发明的抵御标签翻转攻击方法的FedAvg做性能比较。FedAvg的聚合算法如式(6)所示:
其中|Dk|是客户端k的数据量,B是良性客户端聚合。
在具体实施过程中,重复S1至S5,直至模型收敛或达到停止条件,得到最终的全局模型,基于最终的全局模型抵御标签翻转攻击。
本实施例中,如图6所示,为本实施例在FedAvg中应用前后的测试曲线对比图。MCDFL为应用本发明防御策略的FedAvg算法;MCDFL(5,10,20,30,40)或FedAvg(5,10,20,30,40)分别代表算法在5,10,20,30,40个恶意客户端的情况下的测试曲线。实验结果表明,纯FedAvg的测试准确率随着恶意客户端数量的增加而降低,且变得越来越曲折。更糟糕的是,在大约100个global epoch后,出现了“gradient-drift”现象。“gradient-drift”是构建鲁棒性防御的一个潜在挑战,其产生的原因是模型的更新参数可能来自于良性客户端,也可能来自于恶意攻击者,在此具体表现为模型的测试准确率陡然下降。而在应用了MCDFL的FedAvg中,它的测试准确率明显优于纯FedAvg。并且,在不同程度的噪声环境下,检测策略通过快速识别恶意参与者并取消其参与训练的资格,保持了测试准确率的稳定性。此外,稳定的检测效果和较平滑的预测准确率曲线也说明该防御策略可以有效解决“gradient-drift”的问题。
实施例3
本实施例提出的一种抵御联邦学习标签翻转攻击的系统,应用于上述的抵御联邦学习标签翻转攻击的方法,包括:
生成网络训练模块,用于在服务端训练生成网络,并基于服务端本地的全局模型学习标签到潜在特征空间的映射关系;
数据质量检测模块,用于在客户端基于全局模型和生成网络检测每个客户端的数据质量;
客户端分类模块,用于在服务端根据每个客户端的数据质量,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格;
客户端个性化模块,用于在客户端,每个良性客户端分别构建个性化模型,并以生成网络产生的增强表征作为归纳偏置对本地训练进行调节;
全局模型聚合模块,用于在服务端,对每个良性客户端完成本地训练后返回的模型参数进行聚合,根据经过聚合的模型参数对全局模型进行更新。
如图7所示,为本实施例的抵御联邦学习标签翻转攻击的系统的架构图。
在一可选实施例中,生成网络训练模块中,在第一轮训练开始前,服务端初由θ:=[θf,θp]参数化的全局模型和生成网络,其中θf为特征提取模块,θp为预测模块;服务端训练生成网络的步骤包括:
将训练样本输入所述生成网络,生成网络输出训练样本的潜在特征,并结合所述潜在特征和全局模型中的预测模块θp,通过目标函数训练生成网络。
所述训练生成网络的目标函数如式(1)所示;
其中,g(·)是预测层h的逻辑输出,σ(·)是应用于该逻辑输出的非线性激活函数,即h(z;θp)=σ(g(z;θp))。l为非负的凸损失函数。R是一个产生随机标签序列的函数,这些标签序列作为生成网络的训练样本。因此,给定随机标签序列式(1)只需全局模型的预测模块θp就可以在服务端进行优化,换句话说,生成网络的训练不会给各客户端带来额外的计算或时间开销。相较于联邦系统的训练模型,该生成网络是十分轻量的,尤其是特征空间紧凑的情况下。因此完成生成网络的训练无需过多的额外时间成本。轻量的特性也使其易于训练和下载
在一可选实施例中,数据质量检测模块中在客户端基于全局模型和生成网络检测每个客户端的数据质量,并返回给服务端;客户端分类模块中服务端根据每个客户端的数据质量,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格。
在一可选实施例中,所述检测策略无需如其他检测策略那样对客户端的本地数据进行再训练,因此服务端可以在每轮训练之前有效、及时地识别恶意参与者,从而限制恶意参与者参与全局聚合,以此抵御攻击。
在一可选实施例中,客户端个性化模块中每个良性客户端以base+personalization的形式构建个性化模型,并以生成网络产生的增强表征作为归纳偏置来调节本地训练。具体来讲,客户端i的本地模型θi分为两个模块,分别为基础层和个性化层其中用于提取本地特征的基础层为全局模型的特征提取层θf,即输出预测结果的个性化层则是本地模型的预测模块
在一可选实施例中,本发明的防御策略对聚合策略无额外要求,因此能够优先考虑将其应用于先进的联邦学习框架中。全局模型聚合模块中聚合后的全局模型θ同样包含两个模块,即特征提取模块θf和预测模块θp。其中θp用于指导生成网络的训练,以恢复特征空间上的聚合分布,θf则作为构建局部个性化模型的共享组件。
在一可选实施例中,经多次重复优化后得到最终的全局模型,基于最终的全局模型抵御标签翻转攻击。
在具体实施过程中,生成网络训练模块通过目标函数在服务端训练一个轻量级的生成网络,并基于该生成网络和全局模型输出关于用户的随机标签的潜在特征;数据质量检测模块中服务端将全局模型和生成网络广播给每个客户端,客户端基于全局模型和生成网络检测每个客户端的数据质量,每个客户端完成数据质量检测后,都将质量参数返回给服务端;客户端分类模块在服务端中通过聚类算法将质量参数分为两类,将平均质量较低的集合判断为恶意节点集,并取消每个恶意客户端参与本轮训练的资格;客户端个性化模块用于客户端中,每个良性客户端分别构建个性化模型,同时从生成网络中获得潜在的特征表征作为增广表征,为本地训练引入归纳偏置,对本地模型进行优化,并将优化后的本地模型返回给服务端;全局模型聚合模块通过聚合算法在服务端中对每个良性客户端完成本地训练后返回的模型参数进行聚合,经多次重复优化后得到最终的全局模型,基于最终的全局模型抵御标签翻转攻击。
相同或相似的标号对应相同或相似的部件;
附图中描述位置关系的用语仅用于示例性说明,不能理解为对本专利的限制;
显然,本发明的上述实施例仅仅是为清楚地说明本发明所作的举例,而并非是对本发明的实施方式的限定。对于所属领域的普通技术人员来说,在上述说明的基础上还可以做出其它不同形式的变化或变动。这里无需也无法对所有的实施方式予以穷举。凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明权利要求的保护范围之内。
Claims (10)
1.一种抵御联邦学习标签翻转攻击的方法,其特征在于,包括以下步骤:
S1:服务端训练生成网络,所述生成网络基于服务端本地的全局模型学习标签到潜在特征空间的映射关系;
S2:服务端将全局模型和生成网络广播给每个客户端,每个客户端基于全局模型和生成网络检测其数据质量;
S3:服务端根据每个客户端的数据质量,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格;
S4:每个良性客户端分别构建个性化模型,并以生成网络产生的潜在特征作为归纳偏置对本地训练进行调节;
S5:每个良性客户端完成本地训练后,分别将模型参数发送给服务端进行聚合,用于对服务端本地的全局模型进行更新;
S6:重复S1~S5,直至全局模型收敛或达到预设的停止条件,得到最终的全局模型,基于最终的全局模型抵御标签翻转攻击。
4.根据权利要求1所述的一种抵御联邦学习标签翻转攻击的方法,其特征在于,所述S2步骤中,每个客户端基于全局模型和生成网络检测其数据质量的步骤包括:
S21:计算全局模型的预测模块θp对客户端i的本地数据标签y所对应的潜在特征z~Gω(·|y)的预测值,所述预测值为第一预测值;
S22:计算全局模型对客户端的本地数据样本x的预测值,所述预测值为第二预测值;
S23:计算所述第一预测值和第二预测值中相同元素个数,根据所述相同元素个数计算得到客户端i的质量参数DQi;
S24:客户端i完成数据质量检测后,将质量参数DQi返回给服务端。
6.根据权利要求1所述的一种抵御联邦学习标签翻转攻击的方法,其特征在于,所述S3步骤中,将客户端分类,并取消每个恶意客户端参与本轮训练的资格的步骤包括:
服务端通过聚类算法根据客户端i的质量参数DQi将客户端分为两类;
分别计算每个类中质量参数DQ的平均值;
比较两个类中质量参数DQ的平均值,将平均值较低的类中的客户端作为恶意客户端集合,将平均值较高的集合作为良性客户端集合;
取消恶意节点集中每个恶意客户端参与本轮训练的资格。
7.根据权利要求1所述的一种抵御联邦学习标签翻转攻击的方法,其特征在于,所述S4步骤中,良性客户端i构建个性化模型,优化本地模型的步骤包括:
S42:计算良性客户端i的个性化模型在其本地数据Di上的经验风险L(θi);
S43:通过在本地训练阶段对良性客户端i中参与训练的本地数据标签进行计数,获得良性客户端i本地数据标签的先验分布的经验近似值p(y);
S44:每个良性客户端i从生成网络Gω中获得潜在特征z~Gω(·|y)作为增广表征,为本地训练引入归纳偏置,并根据经验近似值p(y)和经验风险L(θi)对本地模型进行优化。
10.一种抵御联邦学习标签翻转攻击的系统,应用于权利要求1~9任一项所述的抵御联邦学习标签翻转攻击的方法,其特征在于,包括:
生成网络训练模块,用于在服务端训练生成网络,并基于服务端本地的全局模型学习标签到潜在特征空间的映射关系;
数据质量检测模块,用于在客户端基于全局模型和生成网络检测每个客户端的数据质量;
客户端分类模块,用于在服务端根据每个客户端的数据质量,将客户端分为良性客户端集合和恶意客户端集合,并取消每个恶意客户端参与本轮训练的资格;
客户端个性化模块,用于在客户端,每个良性客户端分别构建个性化模型,并以生成网络产生的增强表征作为归纳偏置对本地训练进行调节;
全局模型聚合模块,用于在服务端,对每个良性客户端完成本地训练后返回的模型参数进行聚合,根据经过聚合的模型参数对全局模型进行更新。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210486095.7A CN115422537A (zh) | 2022-05-06 | 2022-05-06 | 一种抵御联邦学习标签翻转攻击的方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210486095.7A CN115422537A (zh) | 2022-05-06 | 2022-05-06 | 一种抵御联邦学习标签翻转攻击的方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115422537A true CN115422537A (zh) | 2022-12-02 |
Family
ID=84195727
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210486095.7A Pending CN115422537A (zh) | 2022-05-06 | 2022-05-06 | 一种抵御联邦学习标签翻转攻击的方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115422537A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116842577A (zh) * | 2023-08-28 | 2023-10-03 | 杭州海康威视数字技术股份有限公司 | 联邦学习模型投毒攻击检测及防御方法、装置及设备 |
CN117313898A (zh) * | 2023-11-03 | 2023-12-29 | 湖南恒茂信息技术有限公司 | 基于关键周期识别的联邦学习恶意模型更新检测方法 |
-
2022
- 2022-05-06 CN CN202210486095.7A patent/CN115422537A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116842577A (zh) * | 2023-08-28 | 2023-10-03 | 杭州海康威视数字技术股份有限公司 | 联邦学习模型投毒攻击检测及防御方法、装置及设备 |
CN116842577B (zh) * | 2023-08-28 | 2023-12-19 | 杭州海康威视数字技术股份有限公司 | 联邦学习模型投毒攻击检测及防御方法、装置及设备 |
CN117313898A (zh) * | 2023-11-03 | 2023-12-29 | 湖南恒茂信息技术有限公司 | 基于关键周期识别的联邦学习恶意模型更新检测方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Lin et al. | Free-riders in federated learning: Attacks and defenses | |
Hu et al. | A novel image steganography method via deep convolutional generative adversarial networks | |
Bao et al. | Threat of adversarial attacks on DL-based IoT device identification | |
CN107704877B (zh) | 一种基于深度学习的图像隐私感知方法 | |
Aïvodji et al. | Gamin: An adversarial approach to black-box model inversion | |
Din et al. | Exploiting evolving micro-clusters for data stream classification with emerging class detection | |
CN106295694B (zh) | 一种迭代重约束组稀疏表示分类的人脸识别方法 | |
CN115422537A (zh) | 一种抵御联邦学习标签翻转攻击的方法 | |
CN113645197B (zh) | 一种去中心化的联邦学习方法、装置及系统 | |
Shi et al. | Active deep learning attacks under strict rate limitations for online API calls | |
CN111767848A (zh) | 一种基于多域特征融合的辐射源个体识别方法 | |
Lin et al. | Fairgrape: Fairness-aware gradient pruning method for face attribute classification | |
CN112365005A (zh) | 基于神经元分布特征的联邦学习中毒检测方法 | |
Mareen et al. | Comprint: Image forgery detection and localization using compression fingerprints | |
Luo et al. | Beyond universal attack detection for continuous-variable quantum key distribution via deep learning | |
McClintick et al. | Countering physical eavesdropper evasion with adversarial training | |
US20220222578A1 (en) | Method of training local model of federated learning framework by implementing classification of training data | |
CN113343123B (zh) | 一种生成对抗多关系图网络的训练方法和检测方法 | |
Xian et al. | Understanding backdoor attacks through the adaptability hypothesis | |
Qu et al. | An {Input-Agnostic} Hierarchical Deep Learning Framework for Traffic Fingerprinting | |
Dou et al. | Unsupervised anomaly detection in heterogeneous network time series with mixed sampling rates | |
Xiao et al. | Privacy-Preserving Federated Class-Incremental Learning | |
CN117150321B (zh) | 设备信任度评价方法、装置、服务设备及存储介质 | |
Lu et al. | A Fine-tuning-based Adversarial Network for Member Privacy Preserving | |
Zheng et al. | CRFL: A novel federated learning scheme of client reputation assessment via local model inversion |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |