CN116681144A - 基于动态自适应知识蒸馏的联邦学习模型聚合方法 - Google Patents

基于动态自适应知识蒸馏的联邦学习模型聚合方法 Download PDF

Info

Publication number
CN116681144A
CN116681144A CN202310682277.6A CN202310682277A CN116681144A CN 116681144 A CN116681144 A CN 116681144A CN 202310682277 A CN202310682277 A CN 202310682277A CN 116681144 A CN116681144 A CN 116681144A
Authority
CN
China
Prior art keywords
model
teacher
local
client
global
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310682277.6A
Other languages
English (en)
Inventor
吕军
马晓静
赵瑞欣
付佳韵
陈付龙
苌婉婷
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Anhui Normal University
Original Assignee
Anhui Normal University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Anhui Normal University filed Critical Anhui Normal University
Priority to CN202310682277.6A priority Critical patent/CN116681144A/zh
Publication of CN116681144A publication Critical patent/CN116681144A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F21/00Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
    • G06F21/60Protecting data
    • G06F21/62Protecting access to data via a platform, e.g. using keys or access control rules
    • G06F21/6218Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
    • G06F21/6245Protecting personal data, e.g. for financial or medical purposes
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Bioethics (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Artificial Intelligence (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computing Systems (AREA)
  • Image Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Mathematical Physics (AREA)
  • Databases & Information Systems (AREA)
  • Computer Hardware Design (AREA)
  • Computer Security & Cryptography (AREA)
  • Bioinformatics & Cheminformatics (AREA)

Abstract

本发明提供了一种基于动态自适应知识蒸馏的联邦学习模型聚合方法,可以有效缓解数据异质性带来的精度下降问题。方法包括全局模型初始化、本地模型训练、聚合生成全局模型三个阶段。本发明在本地模型训练阶段使用知识蒸馏技术促进客户端学习全局模型,动态调整知识蒸馏比例使客户端可以根据各自情况自适应学习全局模型,并且动态调整教师模型输出分布使客户端更有效地利用知识蒸馏中教师模型的知识,使得聚合后服务器能够有效生成性能更优的全局模型,同时保证不泄露聚合过程中局部模型和全局模型的额外隐私。本发明能够在保证用户隐私安全的前提下,协同多方训练生成更优的全局模型。

Description

基于动态自适应知识蒸馏的联邦学习模型聚合方法
技术领域
本发明涉及隐私保护和数据安全技术领域,特别涉及一种基于动态自适应知识蒸馏的联邦学习模型聚合方法。
背景技术
传统的集中式学习要求在手机等本地设备上收集的所有数据都要集中存储在数据中心或云服务器上。这一要求不仅引起了对隐私风险和数据泄露的担忧,而且在数据量巨大时,对服务器的存储和计算能力提出了很高的要求。
联邦学习是目前在隐私约束下最广泛采用的机器学习模型协作训练框架,旨在训练一个全局模型,可以在分布在不同设备上的数据上进行训练,同时保护数据隐私。但是联邦学习中每个客户端上的训练数据在很大程度上依赖于特定本地设备的使用情况,因此,客户端的数据分布可能彼此完全不同。这种现象被称为非独立同分布(Non-IID),它可能会导致严重的模型发散,导致精度降低,模型收敛缓慢甚至无法收敛。也就是说,由于局部数据分布的异质性,具有相同初始参数的局部模型会收敛到不同的模型。在联邦学习过程中,通过平均上传的局部模型得到的共享全局模型与理想模型(本地设备上的数据为IID时得到的模型)之间的差异持续增加,收敛速度减慢,使学习性能恶化。
虽然目前已经有一些研究提出可以在本地模型训练时使用知识蒸馏技术约束本地模型向全局模型学习来解决这一问题,但仍然存在许多问题。比如固定的知识蒸馏比例不能自主适应训练过程中的多变性,或是需要额外的辅助数据集来帮助判断合适的知识蒸馏比例,在现实应用中仍然存在诸多困难。因此,如何更好地利用知识蒸馏提升模型的准确率依然是目前亟需解决的技术难题。
发明内容
本发明的目的在于克服现有技术的不足,提供基于动态自适应知识蒸馏的联邦学习模型聚合方法,用以解决联邦学习场景下,固定知识蒸馏比例不能适应各客户端数据分布及模型训练进度不一致的实际情况,进而导致知识蒸馏效果下降,也就无法训练得到高准确度的全局模型的技术问题。
为了实现上述目的,本发明采用的技术方案为:基于动态自适应知识蒸馏的联邦学习模型聚合方法,包括如下步骤:
步骤1:服务器初始化全局模型并将其发送至参与本轮训练的客户端;
步骤2:客户端接收到全局模型后,确定本轮知识蒸馏中对收到的全局模型学习的比例,自适应调整学习本地数据集和全局模型的比例,并动态调整教师模型的输出,使其处于最适合学习的分布状态,训练生成本地模型,并上传给服务器;
步骤3:对接收到的本地模型进行聚合生成新的全局模型从而完成本轮训练过程。
步骤1中:服务器根据训练任务选择待训练的模型作为本轮全局模型M;然后选择参与本轮训练的客户端c1,c2,...,cn(1≤n≤N),将全局模型下发给参与训练的客户端;其中客户端c1,c2,...,cN为N个独立的客户端,客户端各自拥有独立的数据D1,D2,...,DN
步骤3中:在接收到全部客户端上传的全部本地模型后,采用联邦平均算法对本地模型进行聚合形成新的全局模型。
计算每个客户端数据集大小占总数据集大小的比例,按照比例对对应本地模型参数进行加权形成全局模型的参数进而得到新的全局模型。
步骤2中包括:
步骤2.1:客户端保存全局模型作为知识蒸馏的教师模型,计算当前客户端对于教师模型的知识蒸馏比例;
步骤2.2:客户端计算教师模型输出分布平缓程度;
步骤2.3:客户端把全局模型作为本地模型的初始模型,利用本地训练数据集训练本地模型;并且把本地模型作为知识蒸馏中的学生模型,使用步骤2.1、2.2中计算出的知识蒸馏比例和输出分布平缓程度约束教师模型的知识蒸馏过程;
步骤2.4:客户端将本地模型上传给服务器。
步骤2.1中:
利用本地数据集确定本轮知识蒸馏中对收到的全局模型学习的比例。客户端ci(1≤i≤n)接收服务器下发的模型M,作为本地模型mi的教师模型,参与本地模型mi优化训练过程。客户端利用本地训练数据集Di=(x,y),测试教师模型在本地训练数据集Di上的准确度Ai。已知数据样本x和对应标签y,输入教师模型,模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x)。p(x)中预测概率最高的类别即为模型预测结果,如与对应样本标签y一致,则预测成功,否则预测失败。统计所有样本的预测结果,得到预测准确度并且考虑随着训练过程进行教师模型的变化,加入时间因素,根据当前通信回合轮数t和最终通信回合数T,计算得出本轮教师模型在客户端ci的知识蒸馏比例αi,kd=Ai*t/T。
步骤2.2中:
客户端利用本地数据集Di计算教师模型输出分布平缓程度λ。对于输入x和目标y,教师模型产生logit向量zteacher(x),计算教师模型输出logit向量平缓程度其中zteacher,i(x)代表单个类别的教师模型logit向量,k是多类别分类任务的类别数目,mean(zteacher(x))是各类别预测概率的平均值,/>
步骤2.3中:
客户端ci将全局模型M,作为本地模型mi的起点,利用本地数据集Di,对于输入x和目标y,本地模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x);
使用交叉熵函数计算真实概率y和预测概率分布p(x)之间的差异;
对于输入x和目标y,教师模型产生logit向量zteacher(x);根据计算出的教师模型输出logit向量平缓程度λ,缩放logit向量对应的softmax分布k为多类别分类任务的类别数目,便于学生模型对教师模型的知识的学习;针对本地模型和教师模型在同一数据样本上不同的预测概率输出,使用Kullback-Leibler(KL)散度表示教师模型输出概率分布pteacher(x)和学生模型输出概率p(x)分布之间的差异,约束本地模型向教师模型学习;
并且,根据计算出的知识蒸馏比例αi,kd,动态调整交叉熵损失H(p,y)和KL散度损失DKL(pteacher||p)的比例,总损失函数L=(1-αi,kd)H(p,y)+αi,kdDKL(pteacher||p);最后使用总损失反向传播计算梯度更新本地模型。
通过本发明所构思的以上技术方案与现有技术相比,本发明的优点在于:
1、本发明基于联邦学习,服务器只对客户端上交的模型进行聚合操作即可得到全局模型,客户端数据不会泄露给第三方,能够对客户端身份隐私做到很好的保护,不用担心客户端数据泄露的情况,因此本发明具有很高的隐私保护安全性。
2、本发明在确保用户信息数据和隐私不被泄露的情况下实现了对客户端模型的聚合并灵活调控。使用知识蒸馏技术,客户端利用上一轮的全局模型对个性化本地模型训练过程加以约束,很好地改善Non-IID场景下模型权重发散导致聚合后模型性能灾难性下降问题。
3、每个客户端利用全局模型对本地训练数据集的预测准确度,可以自主计算出合适于当前全局模型的知识蒸馏比例,进行选择性学习。相比传统方案灵活性更高,具有很高的实用性。
4、客户端可以根据自己的教师模型输出分布的平缓程度自适应调整教师模型的输出分布,更有利于教师模型与学生模型之间知识的传输,也具有较高的灵活性。
附图说明
下面对本发明说明书各幅附图表达的内容及图中的标记作简要说明:
图1为本发明联邦学习通信过程的流程示意图;
图2为本发明客户端本地训练方法的流程示意图;
图3为本发明服务器聚合模型过程的流程示意图。
具体实施方式
下面对照附图,通过对最优实施例的描述,对本发明的具体实施方式作进一步详细的说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。此外,下面所描述的本发明各个实施方式中所涉及到的技术特征只要彼此之间未构成冲突就可以相互组合。
本发明提供了隐私保护和数据安全技术领域的一种基于动态自适应知识蒸馏的联邦学习模型聚合方法,其目的在于为客户端提供一种基于知识蒸馏的本地训练方法,客户端本地训练阶段,可以自适应地选择对全局模型进行学习,保证本地数据集学习效果的同时维护本地模型权重不过于发散,便于之后服务器直接聚合本地模型生成全局模型。从而实现Non-IID场景下,既保护客户端数据隐私信息不被泄露,又可以基于联邦学习协调多方训练得到的全局模型能够表现出更好的预测准确度,为用户提供服务。
本发明的整体思路在于,服务器初始化全局模型发送给客户端;客户端利用本地数据集确定本轮知识蒸馏中对收到的全局模型学习的比例,自适应调整学习本地数据集和全局模型的比例,并动态调整教师模型的输出,使其处于最适合学习的分布状态,灵活进行学习,训练生成本地模型,并上传给服务器;服务器对收到的本地模型进行加权平均,聚合生成新的全局模型,并进入下一次训练过程。最终训练得到一个适用于各个客户端的通用的全局模型。
本申请实施例的方案为:一种基于动态自适应知识蒸馏的联邦学习模型聚合方法,包括:
服务器初始化全局模型,将全局模型下发给客户端;
客户端接收全局模型,使用本地训练数据集训练生成本地模型,并且客户端将训练完成的本地模型上传给服务器;
服务器收集客户端发送的n个本地模型,聚合得到新的全局模型参数。
优选地,所述服务器初始化全局模型,将全局模型下发给客户端,包括:
服务器从初始模型提供者处下载模型或者随机生成初始全局模型M,并将全局模型M统一下发给本轮参与训练的n个客户端c1,c2,...cn,通信模型本身。
优选地,所述客户端接收全局模型,使用本地训练数据集训练生成本地模型,并且客户端将训练完成的本地模型上传给服务器,包括:
(a)客户端ci(1≤i≤n)接收服务器下发的模型M,作为本地模型mi的教师模型,参与本地模型mi优化训练过程。客户端利用本地训练数据集Di=(x,y),测试教师模型在本地训练数据集Di上的准确度Ai。已知数据样本x和对应标签y,输入教师模型,模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x)。p(x)中预测概率最高的类别即为模型预测结果,如与对应样本标签y一致,则预测成功,否则预测失败。统计所有样本的预测结果,得到预测准确度并且考虑随着训练过程进行教师模型的变化,加入时间因素,根据当前通信回合轮数t和最终通信回合数T,计算得出本轮教师模型在客户端ci的知识蒸馏比例αi,kd=Ai*t/T。
(b)利用本地训练数据集Di=(x,y),客户端计算教师模型输出logit向量平缓程度λ。对于输入x和目标y,教师模型产生logit向量zteacher(x),计算得到其中zteacher,i(x)代表单个类别的教师模型logit向量,k是多类别分类任务的类别数目,mean(zteacher(x))是各类别预测概率的平均值,/>
(c)客户端ci将全局模型M,作为本地模型mi的起点,利用本地数据集Di,对于输入x和目标y,本地模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x)。使用交叉熵函数计算真实概率y和预测概率分布p(x)之间的差异。
同样的,对于输入x和目标y,教师模型产生logit向量zteacher(x)。根据(b)中计算出的教师模型输出logit向量平缓程度λ,缩放logit向量对应的softmax分布其中k为多类别分类任务的类别数目,便于教师模型到学生模型的知识传输。针对本地模型和教师模型在同一数据样本上不同的预测概率输出,使用KL散度表示教师模型输出概率分布pteacher(x)和学生模型输出概率p(x)分布之间的差异,约束本地模型向教师模型学习。
并且,根据(a)中计算出的知识蒸馏比例αi,kd,动态调整交叉熵损失H(p,y)和KL散度损失DKL(pteacher||p)的比例,总损失函数L=(1-αi,kd)H(p,y)+αi,kdDKL(pteacher||p)。最后使用总损失反向传播计算梯度更新本地模型。
(d)客户端将训练得到的本地模型上传给服务器,通信模型本身。
优选地,所述服务器收集客户端发送的n个本地模型,聚合得到新的全局模型参数,包括:
服务器收齐所有客户端传输的局部模型,使用联邦平均算法对n个本地模型进行聚合,得到全局模型的M,并进入下一轮的模型训练。
下面将具体介绍各个部分:
如图1所示,为本发明实施例提供的一种方法流程示意图,主要包括联邦学习中服务器与客户端通信流程。其中各序号流程分别为:
①:服务器初始化全局模型,下发给客户端;
②:客户端训练本地模型;
③:客户端上传本地模型;
④:服务器聚合本地模型生成新一轮全局模型。
系统包括一个可信任的服务器,和N个独立的客户端c1,c2,...,cN,客户端各自拥有独立的数据D1,D2,...,DN。当训练开始,服务器根据训练任务选择合适的模型,下载预训练模型或随机初始化模型参数,作为本轮全局模型M。然后选择参与本轮训练的客户端c1,c2,...,cn(1≤n≤N),将全局模型下发给参与训练的客户端。客户端将收到的全局模型作为本地模型m1,m2,..,.mn的起点,利用本地数据集D1,D2,...,Dn进行训练得到本地模型,并上传给服务器,本地训练流程如图2所示。然后服务器收集本地模型完成后,按照本地数据集的比例加权平均收集的本地模型参数,得到新的全局模型其中D为所有数据集规模之和,Di是每个客户端i的本地数据集大小,mi是客户端i的本地模型,n为参与本轮训练的客户端数目。服务器聚合流程如图3所示。然后进行下一轮的训练。多轮训练完成后,得到一个最终的通用的全局模型M。
如图2所示,为客户端本地训练方法的流程示意图,用于客户端本地训练阶段。具体步骤包括:
第一步:客户端接收服务器下发的本轮全局模型M,保存为教师模型。利用本地训练数据集Di=(x,y),样本x输入教师模型得到logit向量z(x)=[z1(x),z2(x),...,zk(x)],然后通过softmax函数输出预测的概率p(x)=[p1(x),p2(x),...,pk(x)]。p(x)中预测概率最高的类别即为模型预测结果,如与对应样本标签y一致,则预测成功,否则预测失败。统计所有样本的预测结果,得到预测准确度因为全局模型训练效果在训练不同阶段表现不一致,为更好衡量适合于当前全局模型的知识蒸馏比例,根据当前通信轮数t、总通信轮数T和第一步得到的预测准确度Ai,计算得出本轮teacher在客户端ci的知识蒸馏比例αi,kd=Ai*t/T。
第二步:对于本地训练数据集Di=(x,y),输入x和目标y,教师模型产生logit向量zteacher(x),计算教师模型输出logit向量的平缓程度其中zteacher,i(x)代表单个类别的教师模型logit向量,k是多类别分类任务的类别数目,mean(zteacher(x))是各类别预测概率的平均值,/>
第三步:客户端将全局模型,作为本地模型mi的起点,对于输入x和目标y,本地模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x)。同样的,输入x和目标y,教师模型产生logit向量zteacher(x)。根据第二步得到的教师模型输出的logit向量平缓程度λ,缩放zteacher(x)对应的softmax分布k是多类别分类任务的类别数目,便于学生模型对教师模型学习。
第四步:使用交叉熵函数计算真实概率y和预测概率分布p(x)之间的差异,交叉熵损失为其中k是多类别分类任务的类别数目,p(xi)是本地模型对类别i的预测概率。
第五步:使用KL散度表示教师模型输出概率分布pteacher(x)和学生模型(本地模型)输出概率p(x)分布之间的差异,函数如下:
其中,k是多类别分类任务的类别数目,p(xi)是本地模型对类别i的预测概率,pteacher(xi)是教师模型对类别i的预测概率。
第六步:根据第一步中得到的知识蒸馏比例αi,kd,调整第四步计算出的交叉熵损失H(p,y)和第五步计算出的KL散度损失DKL(pteacher||p)的比例,总损失函数L=(1-αi,kd)H(p,y)+αi,kdDKL(pteacher||p)。最后使用总损失反向传播计算梯度更新本地模型,然后返回第三步进入下一轮本地训练。本地模型训练完成后,客户端将本地模型发送给服务器。
如图3所示,为服务器聚合阶段的流程示意图。具体步骤包括:
第一步:服务器收集各个客户端上传的本地模型。
第二步:根据联邦平均算法,计算每个客户端数据集大小占总数据集大小的比例,按照比例对对应本地模型参数进行加权,平均得到新的全局模型。并进入下一轮全局模型的训练。其中客户端1的样本量a1,训练完成后为本地模型m1,则其对应的参数就是a1*m1/(a1+a2+a3...+an),对应的全局模型M的参数M=(a1*m1+a2*m2...+an*mn)/(a1+a2...+an)。
显然本发明具体实现并不受上述方式的限制,只要采用了本发明的方法构思和技术方案进行的各种非实质性的改进,均在本发明的保护范围之内。

Claims (8)

1.基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:包括如下步骤:
步骤1:服务器初始化全局模型并将其发送至参与本轮训练的客户端;
步骤2:客户端接收到全局模型后,确定本轮知识蒸馏中对收到的全局模型学习的比例,自适应调整学习本地数据集和全局模型的比例,并动态调整教师模型的输出,使其处于最适合学习的分布状态,训练生成本地模型,并上传给服务器;
步骤3:对接收到的本地模型进行聚合生成新的全局模型从而完成本轮训练过程。
2.如权利要求1所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:
步骤1中:服务器根据训练任务选择待训练的模型作为本轮全局模型M;然后选择参与本轮训练的客户端c1,c2,...,cn(1≤n≤N),将全局模型下发给参与训练的客户端;其中客户端c1,c2,...,cN为N个独立的客户端,客户端各自拥有独立的数据D1,D2,...,DN
3.如权利要求1所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:
步骤3中:在接收到全部客户端上传的全部本地模型后,采用联邦平均算法对本地模型进行聚合形成新的全局模型。
4.如权利要求3所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:
计算每个客户端数据集大小占总数据集大小的比例,按照比例对对应本地模型参数进行加权形成全局模型的参数进而得到新的全局模型。
5.如权利要求1-4任一所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:
步骤2中包括:
步骤2.1:客户端保存全局模型作为知识蒸馏的教师模型,计算当前客户端对于教师模型的知识蒸馏比例;
步骤2.2:客户端计算教师模型输出分布平缓程度;
步骤2.3:客户端把全局模型作为本地模型的初始模型,利用本地训练数据集训练本地模型;并且把本地模型作为知识蒸馏中的学生模型,使用步骤2.1、2.2中计算出的知识蒸馏比例和输出分布平缓程度约束教师模型的知识蒸馏过程;
步骤2.4:客户端将本地模型上传给服务器。
6.如权利要求5所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:
步骤2.1中:
利用本地数据集确定本轮知识蒸馏中对收到的全局模型学习的比例;客户端ci(1≤i≤n)接收服务器下发的模型M,作为本地模型mi的教师模型,参与本地模型mi优化训练过程;客户端利用本地训练数据集Di=(x,y),测试教师模型在本地训练数据集Di上的准确度Ai;已知数据样本x和对应标签y,输入教师模型,模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x);p(x)中预测概率最高的类别即为模型预测结果,如与对应样本标签y一致,则预测成功,否则预测失败;统计所有样本的预测结果,得到预测准确度并且考虑随着训练过程进行教师模型的变化,加入时间因素,根据当前通信回合轮数t和最终通信回合数T,计算得出本轮教师模型在客户端ci的知识蒸馏比例αi,kd=Ai*t/T。
7.如权利要求5所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:
步骤2.2中:
客户端利用本地数据集Di计算教师模型输出分布平缓程度λ;对于输入x和目标y,教师模型产生logit向量zteacher(x),计算教师模型输出logit向量平缓程度其中zteacher,i(x)代表单个类别的教师模型logit向量,k是多类别分类任务的类别数目,mean(zteacher(x))是各类别预测概率的平均值,/>
8.如权利要求5所述的基于动态自适应知识蒸馏的联邦学习模型聚合方法,其特征在于:
步骤2.3中:
客户端ci将全局模型M,作为本地模型mi的起点,利用本地数据集Di,对于输入x和目标y,本地模型产生logit向量z(x),然后通过softmax函数输出预测的概率p(x);
使用交叉熵函数计算真实概率y和预测概率分布p(x)之间的差异;
对于输入x和目标y,教师模型产生logit向量zteacher(x);根据计算出的教师模型输出logit向量平缓程度λ,缩放logit向量对应的softmax分布k为多类别分类任务的类别数目,便于学生模型对教师模型的知识的学习;针对本地模型和教师模型在同一数据样本上不同的预测概率输出,使用Kullback-Leibler(KL)散度表示教师模型输出概率分布pteacher(x)和学生模型输出概率p(x)分布之间的差异,约束本地模型向教师模型学习;
并且,根据计算出的知识蒸馏比例αi,kd,动态调整交叉熵损失H(p,y)和KL散度损失DKL(pteacher||p)的比例,总损失函数L=(1-αi,kd)H(p,y)+αi,kdDKL(pteacher||p);最后使用总损失反向传播计算梯度更新本地模型。
CN202310682277.6A 2023-06-09 2023-06-09 基于动态自适应知识蒸馏的联邦学习模型聚合方法 Pending CN116681144A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310682277.6A CN116681144A (zh) 2023-06-09 2023-06-09 基于动态自适应知识蒸馏的联邦学习模型聚合方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310682277.6A CN116681144A (zh) 2023-06-09 2023-06-09 基于动态自适应知识蒸馏的联邦学习模型聚合方法

Publications (1)

Publication Number Publication Date
CN116681144A true CN116681144A (zh) 2023-09-01

Family

ID=87786886

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310682277.6A Pending CN116681144A (zh) 2023-06-09 2023-06-09 基于动态自适应知识蒸馏的联邦学习模型聚合方法

Country Status (1)

Country Link
CN (1) CN116681144A (zh)

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117010534A (zh) * 2023-09-27 2023-11-07 中国人民解放军总医院 一种基于环形知识蒸馏和元联邦学习的动态模型训练方法、系统及设备
CN117094355A (zh) * 2023-10-20 2023-11-21 网络通信与安全紫金山实验室 模型更新方法、非易失性存储介质及计算机设备
CN117196070A (zh) * 2023-11-08 2023-12-08 山东省计算中心(国家超级计算济南中心) 一种面向异构数据的双重联邦蒸馏学习方法及装置
CN117236421A (zh) * 2023-11-14 2023-12-15 湘江实验室 一种基于联邦知识蒸馏的大模型训练方法
CN118070876A (zh) * 2024-04-19 2024-05-24 智慧眼科技股份有限公司 一种大模型知识蒸馏低秩适应联邦学习方法、电子设备及可读存储介质
CN118101339A (zh) * 2024-04-23 2024-05-28 山东科技大学 一种应对物联网隐私保护的联邦知识蒸馏方法

Cited By (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117010534A (zh) * 2023-09-27 2023-11-07 中国人民解放军总医院 一种基于环形知识蒸馏和元联邦学习的动态模型训练方法、系统及设备
CN117010534B (zh) * 2023-09-27 2024-01-30 中国人民解放军总医院 一种基于环形知识蒸馏和元联邦学习的动态模型训练方法、系统及设备
CN117094355A (zh) * 2023-10-20 2023-11-21 网络通信与安全紫金山实验室 模型更新方法、非易失性存储介质及计算机设备
CN117094355B (zh) * 2023-10-20 2024-03-29 网络通信与安全紫金山实验室 模型更新方法、非易失性存储介质及计算机设备
CN117196070A (zh) * 2023-11-08 2023-12-08 山东省计算中心(国家超级计算济南中心) 一种面向异构数据的双重联邦蒸馏学习方法及装置
CN117196070B (zh) * 2023-11-08 2024-01-26 山东省计算中心(国家超级计算济南中心) 一种面向异构数据的双重联邦蒸馏学习方法及装置
CN117236421A (zh) * 2023-11-14 2023-12-15 湘江实验室 一种基于联邦知识蒸馏的大模型训练方法
CN117236421B (zh) * 2023-11-14 2024-03-12 湘江实验室 一种基于联邦知识蒸馏的大模型训练方法
CN118070876A (zh) * 2024-04-19 2024-05-24 智慧眼科技股份有限公司 一种大模型知识蒸馏低秩适应联邦学习方法、电子设备及可读存储介质
CN118101339A (zh) * 2024-04-23 2024-05-28 山东科技大学 一种应对物联网隐私保护的联邦知识蒸馏方法

Similar Documents

Publication Publication Date Title
CN116681144A (zh) 基于动态自适应知识蒸馏的联邦学习模型聚合方法
CN110460600B (zh) 可抵御生成对抗网络攻击的联合深度学习方法
CN113762530B (zh) 面向隐私保护的精度反馈联邦学习方法
US20240135191A1 (en) Method, apparatus, and system for generating neural network model, device, medium, and program product
CN112734032A (zh) 一种用于横向联邦学习的优化方法
CN113691594B (zh) 一种基于二阶导数解决联邦学习中数据不平衡问题的方法
CN114564746B (zh) 基于客户端权重评价的联邦学习方法和系统
CN117236421B (zh) 一种基于联邦知识蒸馏的大模型训练方法
CN116523079A (zh) 一种基于强化学习联邦学习优化方法及系统
CN115761378B (zh) 基于联邦学习的电力巡检图像分类和检测方法及系统
Shlezinger et al. Collaborative inference via ensembles on the edge
CN116645130A (zh) 基于联邦学习与gru结合的汽车订单需求量预测方法
CN116187469A (zh) 一种基于联邦蒸馏学习框架的客户端成员推理攻击方法
Yang et al. Forecasting time series with genetic programming based on least square method
Lian et al. Traffic sign recognition using optimized federated learning in internet of vehicles
CN115577797B (zh) 一种基于本地噪声感知的联邦学习优化方法及系统
CN115150288B (zh) 一种分布式通信系统和方法
CN116502709A (zh) 一种异质性联邦学习方法和装置
CN116976461A (zh) 联邦学习方法、装置、设备及介质
CN116259057A (zh) 基于联盟博弈解决联邦学习中数据异质性问题的方法
Ching et al. Dual-objective personalized federated service system with partially-labeled data over wireless networks
CN114925848A (zh) 一种基于横向联邦学习框架的目标检测方法
Ma et al. cFedDT: Cross-Domain Federated Learning in Digital Twins for Metaverse Consumer Electronic Products
Li et al. Research on intelligent volume algorithm based on improved genetic annealing algorithm
CN113743012A (zh) 一种多用户场景下的云-边缘协同模式任务卸载优化方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination