CN117523291A - 基于联邦知识蒸馏和集成学习的图像分类方法 - Google Patents

基于联邦知识蒸馏和集成学习的图像分类方法 Download PDF

Info

Publication number
CN117523291A
CN117523291A CN202311524190.2A CN202311524190A CN117523291A CN 117523291 A CN117523291 A CN 117523291A CN 202311524190 A CN202311524190 A CN 202311524190A CN 117523291 A CN117523291 A CN 117523291A
Authority
CN
China
Prior art keywords
model
client
learning
federal
global
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202311524190.2A
Other languages
English (en)
Inventor
毛文杰
鱼滨
张琛
解宇
刘伟明
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Xidian University
Original Assignee
Xidian University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Xidian University filed Critical Xidian University
Priority to CN202311524190.2A priority Critical patent/CN117523291A/zh
Publication of CN117523291A publication Critical patent/CN117523291A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/042Knowledge-based neural networks; Logical representations of neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/098Distributed learning, e.g. federated learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Biomedical Technology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Multimedia (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开一种基于联邦知识蒸馏和集成学习的图像分类方法,其步骤为:服务器生成训练数据集和辅助数据集,构建联邦学习全局模型并进行初始化,将其下发至选择的客户端。客户端基于有监督损失和一致性约束损失训练本地模型,完成后将模型参数上传至服务器。服务器对接收到的模型进行加权聚合,利用辅助数据集进行基于集成学习的模型分段知识蒸馏过程,将客户端模型的知识融合到全局模型。本发明提升了全局模型的图像分类泛化性能,增强了客户端模型的分类精度,提高了系统对异质数据的鲁棒性。

Description

基于联邦知识蒸馏和集成学习的图像分类方法
技术领域
本发明属于图像处理技术领域,更进一步涉及数据分类技术领域中的一种基于联邦知识蒸馏和集成学习的图像分类方法。本发明可通过物联网中多个设备的协同训练,将训练好的模型用于图像分类任务。
背景技术
图像分类是计算机视觉中的一个重要任务,在实际应用中有广泛的应用,如自动驾驶、医疗诊断、人脸识别等。随着物联网的快速发展,大量的设备和传感器能够收集数据,然而,由于隐私安全的限制,这些数据无法集中存储和处理。联邦学习允许多个设备协同训练全局共享的共识模型,同时保护数据的隐私。在此模型的基础上,把未知的图像输入进去,得到该测试样本的预测类别。然而,在物联网中,每个设备节点的数据分布存在较大差异,这种分布异质性会导致设备训练出的模型漂移、难以收敛和精度降低的问题。一些研究提出在本地模型训练时使用知识蒸馏技术来约束本地模型向全局模型学习来解决此问题。然而,这种方法仍然存在不足:当设备拥有的数据较少时,难以从中学习到有效的信息;可能需要设备间共享的额外代理数据集,这违背了隐私保护要求,增加了通信开销,在实际应用中可能会带来困难。
安徽师范大学在其申请的专利文献“基于动态自适应知识蒸馏的联邦学习模型聚合方法”(专利申请号:CN202310682277.6,专利公开号:CN116681144A,公布日期2023.09.01)中提出了一种基于动态自适应知识蒸馏的联邦学习模型聚合方法。该方法主要包括如下步骤:(1)服务器初始化全局模型并将其发送至参与本轮训练的客户端;(2)客户端接收到全局模型后,确定本轮知识蒸馏中对收到的全局模型学习的比例,自适应调整学习本地数据集和全局模型的比例,并动态调整教师模型的输出,使其处于最适合学习的分布状态,训练生成本地模型,并上传给服务器;(3)对接收到的本地模型进行聚合生成新的全局模型从而完成本轮训练过程。该方法存在的不足之处是,在面对客户端数据较少和分布不平衡的情况下执行联邦学习任务时,客户端模型优化程度有限,难以在全局模型中学习到有效信息,导致模型性能较低。
中国人民解放军总医院在其申请的专利文献“针对数据异质性的个性化联邦学习方法、系统及存储介质”(专利申请号:CN202311035140.8,专利公布号:CN116933866A,公布日期2023.10.24)中提出了一种针对数据异质性的个性化联邦学习方法、系统及存储介质。该方法包括如下步骤:(1)服务端将初始学习模型发送至客户端;(2)服务端根据所有客户端的数据分布相似性对客户端进行聚类,生成客户端的相似性网络图;(3)客户端进行本地迭代训练,得到训练梯度和第一权重参数进行训练更新后的第一更新参数;(4)服务端计算所有客户端上传的第一更新参数的平均值,得到第一平均更新参数;(5)服务端自动更新迭代初始学习模型,并得到第二权重参数更新后的第二更新参数;(6)服务端更新初始学习模型;(7)重复上述步骤直至模型损失函数收敛,得到联邦学习模型。该方法存在的不足之处是:该方法中客户端只上传特征提取层参数和梯度,在服务器进行迭代更新分类层参数并发送到客户端,这使得客户端本地模型的更新方向存在较大波动,降低了本地模型的收敛效率,最终影响在数据异质性场景下客户端模型的预测性能。
发明内容
本发明的目的在于针对上述现有技术存在的问题,提出一种基于联邦知识蒸馏和集成学习的图像分类方法,旨在解决数据异质联邦学习中训练图像分类模型时模型难以收敛和泛化性能差的问题。
实现本发明目的的思路是,本发明首先提出针对联邦学习客户端本地训练的损失函数的优化。对本地模型的目标函数施加一致性约束,使得客户端在迭代更新时,平衡本地模型和全局模型的更新方向,避免了各个客户端模型在聚合时的参数方差过大,造成全局模型无法收敛的问题。其次,本发明提出全局模型的可学习聚合策略和知识迁移方法。在服务器端,先进行客户端模型的加权聚合,得到全局模型。随后使用基于集成学习的策略,通过构建可学习的集成模型,结合多个本地模型,使用分段模型蒸馏训练,挖掘并学习在各个非独立同分布数据的上训练得到的模型的潜藏知识来提高模型的泛化性能。由于全局模型吸收了各个客户端的知识,因此在差异化数据样本空间上进行优化的时候,具有较强的抵抗数据异质性的鲁棒性。因此,本发明通过在联邦学习中应用集成学习策略和蒸馏学习方法,实现了在缺乏大量监督数据、类别分布不平衡、数据呈非独立同分布情况下学习高性能模型的目标。
实现本发明目的的具体步骤如下:
步骤1,生成和分配样本集:
生成训练样本集和辅助样本集,为每个客户端分配各自的客户端样本集;
步骤2,在服务端构建一个卷积神经网络和一个多层感知机网络,初始化网络参数,分别作为联邦学习全局模型和全局集成模型;
步骤3,服务器确定参与客户端:
服务器随机选择Ns个客户端,并将其确定为下个轮次将要参与联邦学习的客户端,随后将联邦学习全局模型分发给被选择的客户端,Ns≥3;
步骤4,客户端进行本地训练:
将每个参与客户端的样本集输入到其对应的模型中,使用有监督损失和一致性损失作为本地模型学习的联合损失函数,采用随机梯度下降算法作为优化器,进行梯度的反向传播计算更新模型参数,直到本地模型达到收敛;最终每个客户端得到训练好的联邦学习客户端模型,在客户端本地暂存一个模型副本,并将训练好的模型上传到服务器;
步骤5,服务器进行客户端模型的集成和知识迁移:
步骤5.1,服务器对本轮次接收到的从客户端上传的模型参数进行加权聚合,获得聚合后的联邦学习全局模型:
步骤5.2,将辅助样本集输入到客户端上传的模型,得到每个本地模型的中间输出向量和类别预测向量;
步骤5.3,将辅助数据集输入聚合后的联邦学习全局模型,得到中间输出向量和类别预测向量;
步骤5.4,计算步骤5.2和步骤5.3中得到的模型中间输出向量分布的KL散度,并求平均,得到基于模型特征输出分布的蒸馏损失值
步骤5.5,将步骤5.2中得到的类别预测向量作为输入,送入服务器全局集成模型中,得到综合类别预测向量
步骤5.6,将综合类别预测向量和步骤5.3中得到的联邦学习全局模型类别预测向量进行KL散度计算,得到基于模型预测软分布的蒸馏损失值/>
步骤5.7,将辅助样本集中的真实标签y和步骤5.3得到的联邦学习全局模型类别预测向量计算得到联邦学习全局模型拟合辅助样本数据的损失
步骤5.8,将三个损失值作为服务端损失函数使用随机梯度下降算法和梯度反向传播,对联邦学习全局模型进行再训练;
步骤6,判断服务器最终全局模型是否满足联邦学习训练的终止条件,若是,则得到最终训练好的联邦学习全局模型后执行步骤7,否则,将当前迭代次数加1后执行步骤3;
步骤7,将待分类的图像样本输入到训练好的联邦学习全局模型中,输出分类结果。
本发明与现有技术相比,具有以下优点:
第一,本发明在服务器端进行分段模型知识蒸馏和知识迁移过程,将客户端模型所含的通用和任务相关的知识向全局模型转移,结合了多个客户端模型的知识来指导全局模型的微调。由于全局模型吸收了各个客户端的知识,在差异化数据样本空间上进行优化的时候,具有较强的抵抗数据异质性的鲁棒性,从而使得本发明在图像数据匮乏、数据呈非独立同分布的情况下可以训练得到具有较强泛化性的图像分类模型。
第二,本发明在服务器端利用模型集成学习策略对客户端模型进行聚合,通过设计一个可学习的集成模型,对各个模型的输出进行适应性地赋予权重和聚合,得到较好的指导输出信息,缓解了直接对各个客户端模型进行聚合后产生的由于参数方差过大而造成的聚合模型分类性能低下的问题。使得本发明提高了联邦学习全局模型的整体训练效率,能够较快地达到所需的全局模型分类准确率的要求。
第三,本发明在客户端应用模型参数一致性约束,对各个客户端的本地模型参数的更新方向进行限制,在个性化更新的同时考虑全局模型的优化方向,从而降低了本地模型在训练阶段独立更新时造成的模型漂移问题,在一定程度提高了联邦学习聚合得到的全局模型的分类性能。使得本发明扩展了联邦学习在数据异质性环境下的适应性,加速了联邦的训练过程,提高了本地模型的分类准确率。
第四,本发明基于传统的联邦学习过程,在服务器和客户端之间的通信只涉及模型参数信息的传输,而不泄露任何本地或服务器的私有数据。客户端之间被禁止传输和共享原始数据;客户端和服务器上的私有数据也不会被披露给第三方,该方法能够保证各个参与客户端的隐私,因此本发明具有较高的隐私安全性。
附图说明
图1是本发明的流程图。
具体实施方式
下面结合图1和实施例,对本发明的实现步骤做进一步的详细描述。
步骤1,生成和分配样本集。
本发明的实施是将图像分类领域公开的Fashion-MNIST数据集,它包含了10个类别中的几种时尚服装的图像。该数据集包括60000张用于训练的灰度图像和10000张用于测试的灰度图像。该数据集中包括从0到9的10个类别。
步骤1.1,对训练样本集中的每张图像进行随机裁剪、随机水平翻转、随机角度旋转的预处理操作,并限定最终的裁剪后的图像尺寸为28×28,得到预处理后的样本集。
步骤1.2,从预处理好的样本集的每种类别中随机选取500张图像,得到总共5000张图像,组成服务器端辅助样本集,其余55000个样本组成训练样本集。
步骤1.3,从训练样本集中随机选取至少3种类别,每种类别随机选取1000张图像,总共得到选取的3000张图像,将所选取的图像组成一个客户端的样本集,联邦学习系统中总共10个客户端,从而得到10个不同的客户端样本集,其中每个客户端对应一个客户端样本集。
步骤1.4,从每个客户端样本集中随机选取2400个样本组成训练集,将样本集中剩余的600个样本组成本地测试集。
步骤2,构建联邦学习全局模型和全局集成模型。
在服务端构建一个卷积神经网络和一个多层感知机网络,并进行网络参数的初始化,分别作为联邦学习全局模型和全局集成模型。
步骤2.1,基于步骤1分配的灰度视觉图像样本集,构建可用于该类型数据的卷积神经网络架构,作为联邦学习全局网络,使其能够通过训练而拟合图像数据集,学习并优化模型参数,从而完成测试集图像的分类预测任务。具体来说,本发明实施例中所述的卷积神经网络是由12个层串联而成的网络,其结构依次为:第一卷积层,第一批归一化层,第一激活层,第一池化层,第二卷积层,第二批归一化层,第二激活层,第二池化层,第一全连接层,第三激活层,dropout层,第二全连接层。前8层为特征提取模块,后4层为预测模块。将第一、第二卷积层的卷积核的个数分别设置为16,32,卷积核的大小均设置为5×5,步长均设置为1,填充宽度均设置为2。第一至第三激活层采用Relu激活函数,将inplace参数设置为False。单个图像样本输入网络经过第一卷积层和第二卷积层处理后的特征图维度分别为14×14,7×7。第一、第二池化层均采用最大池化方式,池化区域核的大小均设置为2×2,池化步长均设置为2。第一、第二批归一化层的eps参数设置为1×10-5,momentum参数设置为0.1,affine参数设置为True。将Dropout层的drop_rate参数设置为0.05。将第一、第二连接层的神经元的数量分别设置为512和10,其中10等于Fashion-MNIST数据集的类别总数。
步骤2.2,构建位于服务端的多层感知机网络,其结构包括3个全连接层和2个激活层,分别为第一全连接层,第一激活层,第二全连接层,第二激活层,第三全连接层;其中,第一、第二全连接层都具有128个隐藏单元,随后分别连接ReLU激活函数,并将inplace参数设置为False;第三全连接层具有C个神经元。该网络用于集成各个客户端模型的logits,自动给不同客户端模型的logits分配权重,最终获得集成后的logits,作为指导知识参与全局模型的再训练。
步骤2.3,针对卷积神经网络和多层感知机网络中的卷积层和全连接层权重参数使用Kaiming初始化,偏置参数使用0固定值初始化方法。
步骤3,服务器确定参与客户端并下发模型参数。
服务器随机选择3个客户端,并将其确定为下个轮次将要参与联邦学习的客户端,随后将联邦学习全局模型分发给被选择的客户端。
步骤4,客户端进行本地训练、模型暂存和上传模型参数到服务器。
参与客户端接收服务器下发的全局模型,作为本地训练的客户端模型。将每个参与客户端的样本集输入到其对应的模型中,使用有监督损失和一致性损失作为本地模型学习的联合损失函数,采用随机梯度下降算法作为优化器,进行梯度的反向传播计算更新模型参数,直到本地模型达到收敛。最终每个客户端得到训练好的联邦学习客户端模型,在客户端本地暂存一个模型副本,并将训练好的模型上传到服务器。
步骤4.1,接收服务器下发的联邦学习全局模型的参数,将其加载进本地模型中,作为当前轮次参与本地训练的联邦学习客户端模型的初始化参数。
步骤4.2,利用下述求差公式,计算当前本地训练过程中的联邦学习本地客户端模型的参数矩阵与暂存的上个轮次的联邦学习本地客户端模型的参数矩阵的变化量,可以被表示如下:
其中,r表示第k个联邦学习客户端当前正在训练的模型的参数矩阵中的参数值与上个本地轮次训练后相对应行和列的参数值的变化量,||·||表示计算欧式距离,wk是客户端k在训练过程中的本地模型权重参数矩阵,wl是上个全局训练轮次后第k个客户端暂存的本地模型参数矩阵。由于将联邦学习客户端模型参数矩阵中的参数值和上个轮次的参数变化量作为模型学习的约束,减小了联邦学习客户端模型参数的方差,降低由于数据异质性导致的模型漂移程度。
步骤4.3,根据计算出的当前模型相对于上个轮次的变化量r对联邦学习客户端模型计算模型的一致性约束损失,并结合在本地有标签数据上的交叉熵损失,构建本地模型学习的联合损失函数,该损失函数可以表示为;
其中,表示第k个联邦学习客户端的本地训练的联合损失,Dk表示第k个客户端样本集中样本的总数,C表示样本集中的类别总数,/>表示训练集中第i个样本对应于第c个类别的真实标签,log(·)表示以自然常数w为底的对数操作,wk表示第k个客户端模型的权重参数,p(c|xi,wk)表示训练集中第i个样本xi输入第k个客户端本地模型所输出预测值中属于第c个类别的概率,/>表示平衡一致性损失的超参数,r表示本地客户端模型的参数矩阵的变化量。
步骤4.4,对当前轮次训练好的联邦学习客户端模型在本地进行暂存,并将当前全局轮次训练完毕的联邦学习客户端本地模型的参数上传到服务器。
本发明的实施例中所有被选择的客户端的本地模型都使用上述的本地训练方式,当3个客户端全都训练完成后,获得3个当前轮次训练完成的联邦学习客户端模型。
步骤5,服务器进行客户端模型的集成和知识迁移。
步骤5.1,服务器对本轮次接收到的从客户端上传的模型参数进行加权聚合,获得聚合后的联邦学习全局模型。具体来说,将客户端模型权重参数进行加权平均之后的模型权重加载到联邦学习全局模型中,更新全局模型的参数。利用下述聚合公式,服务器对Ns个客户端模型参数进行加权聚合获得联邦学习全局模型参数:
其中,Ns表示服务器端选择的客户端数量,Dk表示分配给第k个联邦学习客户端样本集中的样本数。Ds表示被采样的客户端子集中包含的所有样本总数,wk表示联邦学习客户端k上传到服务器的模型权重参数,wg表示聚合生成的新的联邦学习全局模型。
步骤5.2,将辅助样本集输入到各个被选中的客户端上传的模型中,进行前向传播,得到每个客户端模型的中间输出向量(最后一个卷积层输出的特征展平后的向量)和类别预测向量(最后一个全连接层输出的logits);
步骤5.3,将辅助数据集输入到聚合后的联邦学习全局模型中,进行前向传播,得到全局模型的中间输出向量和类别预测向量;
步骤5.4,计算步骤5.2和步骤5.3中得到的模型中间输出向量分布的KL散度,并将这些值进行求和平均,得到基于模型特征输出分布的蒸馏损失值表示如下:
其中,Ns表示服务器选择的客户端数量,x表示辅助样本集中的一个批次的样本,τ表示用于软化模型输出分布的蒸馏温度,其值根据蒸馏效果设定,zc表示模型输出的对应于第c个类别的logits值,Φ(·)表示由温度τ软化的softmax函数计算而得到的概率分布。表示将辅助样本输入到由/>参数化的联邦学习第k个客户端模型中特征提取模块输出的中间输出向量,/>表示将辅助样本输入到由/>参数化的联邦学习全局模型中特征提取模块输出的中间输出向量,KL(·)表示Kullback–Leibler散度函数,∑表示求和操作。
步骤5.5,将步骤5.2中得到的从各个被选择的客户端上传的本地模型中提取的类别预测向量作为输入,送入到服务器全局集成模型中,得到综合类别预测向量如下:
其中,wo表示全局集成模型的参数矩阵,concat(·)表示将辅助样本输入到被选择的客户端本地模型中输出的logits向量进行连接,表示将连接的logits向量输入名为olnet的全局集成模型而得到的综合类别预测向量。
步骤5.6,将综合类别预测向量和步骤5.3中得到的联邦学习全局模型类别预测向量进行KL散度计算,得到基于模型预测软分布的蒸馏损失值/>如下:
其中,表示将辅助样本输入到联邦学习全局模型而输出的logits向量,/>表示服务器端集成模型输出的综合类别预测向量。
步骤5.7,将辅助样本集中的真实标签y和步骤5.3得到的联邦学习全局模型类别预测向量送入交叉熵损失函数,计算得到联邦学习全局模型拟合辅助样本数据的损失如下:
其中,DS表示联邦学习服务器上辅助训练集中所拥有的样本总数,ym,c表示将第m个辅助样本输入全局模型的输出中对应于第c类的真实标签one-hot值,log(·)表示以自然常数e为底的对数函数。表示联邦学习全局模型对于第m个辅助样本输出的logits向量中对应于第c个类别的值,计算如下:
其中,表示将辅助样本xm输入到由/>参数化的联邦学习全局模型中特征提取模块而输出的中间输出向量,/>表示将中间输出向量输入到由/>参数化的联邦学习全局模型中预测模块而输出的类别预测向量。
步骤5.8,利用步骤5.4,步骤5.6,步骤5.7得到的三个损失值,通过超参数平衡各个损失之间的贡献,共同作为最终的损失函数服务器将辅助训练集数据输入到联邦学习全局模型中,利用随机梯度下降法,迭代更新全局模型权重参数,该训练进行5个epoch,得到二次微调训练之后的联邦学习全局模型。服务器全局模型训练中使用的总损失函数Ls可以表示为:
其中,γ,β,δ为平衡各个损失的超参数,Lce,Lkl,Lkd分别为交叉熵损失函数,嵌入向量匹配损失和概率分布一致性损失,γ,β,δ分别表示三项的平衡超参数。
步骤6,判断服务器最终全局模型是否满足联邦学习训练的终止条件,若是,则得到最终训练好的联邦学习全局模型后执行步骤7,否则,将当前迭代次数加1后执行步骤3;
上述的训练的终止条件指的是满足下述条件之一的情形:
条件1,联邦学习全局模型性能达到指定的预期目标;
条件2,联邦学习训练轮次达到预先设置的迭代次数。
步骤7,将待分类的图像样本输入到训练好的联邦学习全局模型中,输出分类结果。
下面结合计算机仿真实验对本发明的效果做进一步的说明。
1.仿真实验条件:
本发明的仿真实验的硬件平台为:处理器为Intel(R)Xeon(R)CPU E5-2650v4,主频为2.20GHz,内存256GB。
本发明的仿真实验的软件平台为:Ubuntu 20.04.3LTS操作系统、Python 3.7.0语言,使用基于CUDA 11.8版本的pytorch 1.11.0的编程框架。
本发明仿真实验所使用的数据集为Fashion MNIST(服饰数据集),它是Zalando文章图像的类似于MNIST的数据集,包含10类时尚服装的图像。与MNIST类似,该数据集包含70000张灰度图像,其中包含60000个样本的训练集和10000个样本的测试集,每个样本都是一个28x28灰度图像。然而,Fashion-MNIST中的一些图像类别在视觉上是相似的,这使得分类任务比常规MNIST手写数据更具有挑战性。
2.仿真内容及其结果分析:
联邦学习利用中央服务器协调各个客户端利用各自的私有数据训练本地模型,以期望联合起来获得一个具有良好预测能力的机器学习模型,而不泄露客户端的隐私。具体来说,本地客户端将训练好的本地模型参数上传到中央服务器,中央参数服务器对其进行加权聚合之后获得更新的联邦学习全局模型参数。随后,服务器借助辅助数据集将本地模型的知识向全局模型转移。接着,服务器将更新后的全局模型发布给客户端,进行下一轮次的本地训练。这种联邦学习模型的更新过程会遵循一种“上传-发布”的双向通信迭代学习的方式,直到满足预定的终止条件。
本发明仿真实验是采用本发明和六个现有技术(FedAvg、FedProx、Fedproto、CDKT-FL、PerAvg、pFedMe)分别在Fashion-MNIST数据集上进行联邦学习的模型训练,最后获得六种方法对应的训练好的联邦学习模型。仿真实验所模拟的场景是10个客户端协同训练一个联邦学习全局模型。
在仿真实验中,采用的六个现有技术是指:
现有技术FedAvg是指,Mcmahan等人在“Communication-Efficient Learning ofDeep Networks from Decentralized Data.2017”中提出的联邦平均学习算法,其中服务器接收客户端上传的模型参数,进行平均后生成新的全局模型下发给客户端参与下一轮的训练,简称FedAvg。
现有技术FedProx指的是,T.Li等人在“Federated optimization inheterogeneous networks.2020.”中提出的异构网络中的联邦学习模型优化方法,通过在局部训练中添加一个近端项来限制局部更新的长度来减少偏差,简称FedProx。
现有技术Fedproto是指,Y.Tan等人在“Fedproto:Federated prototypelearning across heterogeneous clients,2022”中提出的联邦原型学习框架,其中客户端和服务器通信抽象类原型而不是梯度,简称Fedproto。
现有技术CDKT-FL是指,M.N.Nguyen等人在“Cdkt-fl:Cross-device knowledgetransfer using proxy dataset in federated learning,2022”中提出的联邦学习知识迁移框架,并在服务器和客户端上执行跨设备知识转移,简称CDKT-FL。
现有技术PerAvg是指,A.Fallah等人在“Personalized federated learningwith theoretical guarantees:A model-agnostic meta-learning approach,2020”中提出的个性化联邦学习框架,使用了一种模型不可知论的元学习方法,简称PerAvg。
现有技术pFedMe是指,C.T Dinh等人在“Personalized federated learningwith moreau envelopes,2020”中提出的个性化联邦学习算法,使用Moreau包络作为客户的正则化损失函数,助于将个性化模型优化与全局模型学习分离,简称pFedMe。
为了验证本发明仿真实验的效果,利用下述两个评价指标:客户端平均准确率、全局模型泛化准确率,分别对六种对比方法和本发明的联邦学习模型训练结果进行评价,将所有计算结果绘制成表1和表2。
两种评价指标的计算方式为:
表1.仿真实验中本发明和各现有技术的全局模型测试结果一览表
方法 准确率
FedAvg 88.155%
CDKT-FL 78.259%
FedProx 35.370%
PerAvg 34.916%
pFedMe 10.793%
本发明方法 89.229%
表1中准确率是指10个客户端协同进行联邦学习模型训练,每个轮次的客户端采样率为30%,并使全局模型最终收敛时的全局模型泛化准确率。由表1可以看出,对比方法中的最经典的FedAvg训练方法最终收敛时在non-iid数据(狄利克雷分布参数α为1)上可以获得88.155%的准确率。CDKT-FL模型训练方法由于基于跨设备知识迁移方法,非独立同分布的数据集上获得了78.259%的准确率。FedProx针对客户端模型做约束优化,但是忽略了全局模型知识的迁移,最终获得35.370%的准确率。PerAvg和pFedMe属于个性化联邦学习算法,主要关注客户端模型的个性化预测性能,在非独立同分布数据和客户端参与率较低的情况下训练得到的全局模型准确率较差。本发明方法结合了知识蒸馏和集成学习的优越性,最终全局模型可以获得超越FedAvg训练方法的准确率(89.229%)。本发明方法既提高了图像分类模型在非独立同分布数据上训练的准确率,又提升了图像分类模型应对极端异质数据和小参与率场景下的联邦学习的鲁棒性,提高了联邦学习全局模型的泛化性能。
表2.仿真实验中本发明和各现有技术客户端模型测试结果一览表
方法 准确率
FedAvg 97.855%
CDKT-FL 95.249%
Fedproto 38.341%
FedProx 97.348%
PerAvg 94.869%
pFedMe 97.145%
本发明方法 97.892%
表2中的准确率是指10个客户端协同进行联邦学习模型训练,每个轮次的客户端采样率为30%的情况下,最终训练完成之后的每个客户端模型在测试数据上的分类平均准确率。从表2可以看出,本发明的方法在Fashion-MNIST数据集上的客户端模型平均性能达到的了现有技术的最高水平,提高了客户端模型在非独立同分布数据上参与联邦学习训练时的图像分类准确率,提升了客户端参与联邦学习的积极性,增强了本地模型的个性化图像分类性能。

Claims (9)

1.一种基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于,基于集成学习和分段知识蒸馏,在服务器上对聚合后的全局图像分类模型进行本地知识的集成与迁移,增强了全局模型的泛化性,提高了训练效率;在客户端上,将融合本地知识的全局模型进行带有一致性约束的更新,使全局模型传递的广义知识更好地适应局部表示,减轻本地模型的漂移问题;最终得到一个具有对抗数据异质鲁棒性和泛化性的图像分类模型;该图像分类方法的具体步骤包括如下:
步骤1,生成和分配样本集:
生成训练样本集和辅助样本集,为每个客户端分配各自的客户端样本集;
步骤2,在服务端构建一个卷积神经网络和一个多层感知机网络,初始化网络参数,分别作为联邦学习全局模型和全局集成模型;
步骤3,服务器确定参与客户端:
服务器随机选择Ns个客户端,并将其确定为下个轮次将要参与联邦学习的客户端,随后将联邦学习全局模型分发给被选择的客户端,Ns≥3;
步骤4,客户端进行本地训练:
将每个参与客户端的样本集输入到其对应的模型中,使用有监督损失和一致性损失作为本地模型学习的联合损失函数,采用随机梯度下降算法作为优化器,进行梯度的反向传播计算更新模型参数,直到本地模型达到收敛;最终每个客户端得到训练好的联邦学习客户端模型,在客户端本地暂存一个模型副本,并将训练好的模型上传到服务器;
步骤5,服务器进行客户端模型的集成和知识迁移:
步骤5.1,服务器对本轮次接收到的从客户端上传的模型参数进行加权聚合,获得聚合后的联邦学习全局模型:
步骤5.2,将辅助样本集输入到客户端上传的模型,得到每个本地模型的中间输出向量和类别预测向量;
步骤5.3,将辅助样本集输入聚合后的联邦学习全局模型,得到中间输出向量和类别预测向量;
步骤5.4,计算步骤5.2和步骤5.3中得到的模型中间输出向量分布的KL散度,并求平均,得到基于模型特征输出分布的蒸馏损失值
步骤5.5,将步骤5.2中得到的类别预测向量作为输入,送入服务器全局集成模型中,得到综合类别预测向量
步骤5.6,将综合类别预测向量和步骤5.3中得到的联邦学习全局模型类别预测向量进行KL散度计算,得到基于模型预测软分布的蒸馏损失值/>
步骤5.7,将辅助样本集中的真实标签y和步骤5.3得到的联邦学习全局模型类别预测向量计算得到联邦学习全局模型拟合辅助样本数据的损失
步骤5.8,将三个损失值作为服务端损失函数使用随机梯度下降算法和梯度反向传播,对联邦学习全局模型进行再训练;
步骤6,判断服务器最终全局模型是否满足联邦学习训练的终止条件,若是,则得到最终训练好的联邦学习全局模型后执行步骤7,否则,将当前迭代次数加1后执行步骤3;
步骤7,将待分类的图像样本输入到训练好的联邦学习全局模型中,输出分类结果。
2.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于,步骤1中所述的训练样本集和辅助样本集指的是,生成至少包含10种类别的图像样本,其中每种类别至少6000张图像,将所选取的所有的图像组成样本集;将样本集中的每种类别随机选取至少500个样本组成辅助样本集,剩余的样本组成训练样本集;步骤1中所述的客户端样本集指的是,从训练样本集中的类别中随机选取至少3种类别,每种类别至少1000张图像,将所选取的图像组成一个客户端样本集;联邦学习系统中存在至少10个客户端,从而得到至少10个不同的客户端样本集,其中每个客户端对应一个客户端样本集。
3.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤2中所述的卷积神经网络是由12个层串联而成,其结构依次为:第一卷积层,第一批归一化层,第一激活层,第一池化层,第二卷积层,第二批归一化层,第二激活层,第二池化层,第一全连接层,第三激活层,dropout层,第二全连接层;前8层为特征提取模块,后4层为预测模块;将第一、第二卷积层的卷积核的个数分别设置为16,32,卷积核的大小均设置为5×5,步长均设置为1,填充宽度均设置为2;第一至第三激活层采用Relu激活函数,将inplace参数设置为False;单个图像样本输入网络经过第一卷积层和第二卷积层处理后的特征图维度分别为14×14,7×7;第一、第二池化层均采用最大池化方式,池化区域核的大小均设置为2×2,池化步长均设置为2;第一、第二批归一化层的eps参数设置为1×10-5,momentum参数设置为0.1,affine参数设置为True;将Dropout层的drop_rate参数设置为0.05;将第一、第二连接层的神经元的数量分别设置为512和C,其中C等于数据集样本的类别总数;步骤2中所述的多层感知机网络的结构包括3个全连接层和2个激活层,分别为第一全连接层,第一激活层,第二全连接层,第二激活层,第三全连接层;其中,第一、第二全连接层都具有128个隐藏单元,随后分别连接ReLU激活函数,并将inplace参数设置为False;第三全连接层具有C个神经元。
4.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤4中所述本地模型学习的联合损失函数如下:
其中,表示第k个联邦学习客户端的本地训练的联合损失,Dk表示第k个客户端样本集中样本的总数,C表示样本集中的类别总数,/>表示训练集中第i个样本对应于第c个类别的真实标签,log(·)表示以自然常数e为底的对数操作,wk表示第k个客户端模型的权重参数,p(c|xi,wk)表示训练集中第i个样本xi输入第k个客户端本地模型所输出预测值中属于第c个类别的概率,/>表示平衡一致性损失的超参数,wl表示第k个客户端本地暂存的模型参数。
5.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤5.4中所述的基于模型特征输出分布的蒸馏损失值如下:
其中,Ns表示服务器选择的客户端数量,τ表示用于软化模型输出分布的蒸馏温度,其值根据蒸馏效果设定,zc表示模型输出的对应于第c个类别的logits值,Φ(·)表示由温度τ软化的softmax函数计算而得到的概率分布;表示将辅助样本输入到由/>参数化的联邦学习第k个客户端模型中特征提取模块输出的中间输出向量,/>表示将辅助样本输入到由/>参数化的联邦学习全局模型中特征提取模块输出的中间输出向量,KL(·)表示Kullback–Leibler散度函数,∑表示求和操作。
6.根据权利要求5所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤5.5中所述的综合类别预测向量如下:
其中,wo表示全局集成模型的参数矩阵,concat(·)表示将辅助样本输入到被选择的客户端本地模型中输出的logits向量进行连接,表示将连接的logits向量输入名为olnet的全局集成模型而得到的综合类别预测向量。
7.根据权利要求6所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤5.6中所述的基于模型预测软分布的蒸馏损失值如下:
其中,表示将辅助样本输入到联邦学习全局模型而输出的logits向量,/>表示服务器端集成模型输出的综合类别预测向量。
8.根据权利要求5所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤5.7中所述的联邦学习全局模型拟合辅助样本数据的损失如下:
其中,DS表示联邦学习服务器上辅助训练集中所拥有的样本总数,ym,c表示将第m个辅助样本输入全局模型的输出中对应于第c类的真实标签one-hot值;表示联邦学习全局模型对于第m个辅助样本输出的logits向量中对应于第c个类别的值,计算如下:
其中,表示将辅助样本xm输入到由/>参数化的联邦学习全局模型中特征提取模块而输出的中间输出向量,/>表示将中间输出向量输入到由/>参数化的联邦学习全局模型中预测模块而输出的类别预测向量。
9.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤6所述的训练的终止条件指的是满足下述条件之一的情形:
条件1,联邦学习全局模型性能达到指定的预期目标;
条件2,联邦学习训练轮次达到预先设置的迭代次数。
CN202311524190.2A 2023-11-15 2023-11-15 基于联邦知识蒸馏和集成学习的图像分类方法 Pending CN117523291A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311524190.2A CN117523291A (zh) 2023-11-15 2023-11-15 基于联邦知识蒸馏和集成学习的图像分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311524190.2A CN117523291A (zh) 2023-11-15 2023-11-15 基于联邦知识蒸馏和集成学习的图像分类方法

Publications (1)

Publication Number Publication Date
CN117523291A true CN117523291A (zh) 2024-02-06

Family

ID=89756319

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311524190.2A Pending CN117523291A (zh) 2023-11-15 2023-11-15 基于联邦知识蒸馏和集成学习的图像分类方法

Country Status (1)

Country Link
CN (1) CN117523291A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117829320A (zh) * 2024-03-05 2024-04-05 中国海洋大学 一种基于图神经网络和双向深度知识蒸馏的联邦学习方法

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117829320A (zh) * 2024-03-05 2024-04-05 中国海洋大学 一种基于图神经网络和双向深度知识蒸馏的联邦学习方法

Similar Documents

Publication Publication Date Title
JP6625785B1 (ja) データ識別器訓練方法、データ識別器訓練装置、プログラム及び訓練方法
WO2022121289A1 (en) Methods and systems for mining minority-class data samples for training neural network
WO2022063151A1 (en) Method and system for relation learning by multi-hop attention graph neural network
CN106462724A (zh) 基于规范化图像校验面部图像的方法和系统
CN112465120A (zh) 一种基于进化方法的快速注意力神经网络架构搜索方法
CN109740734B (zh) 一种利用优化神经元空间排布的卷积神经网络的图像分类方法
WO2021051987A1 (zh) 神经网络模型训练的方法和装置
CN117523291A (zh) 基于联邦知识蒸馏和集成学习的图像分类方法
CN107220368B (zh) 图像检索方法及装置
CN113268669B (zh) 基于联合神经网络的面向关系挖掘的兴趣点推荐方法
CN114943345A (zh) 基于主动学习和模型压缩的联邦学习全局模型训练方法
CN113822315A (zh) 属性图的处理方法、装置、电子设备及可读存储介质
WO2023036184A1 (en) Methods and systems for quantifying client contribution in federated learning
CN115587633A (zh) 一种基于参数分层的个性化联邦学习方法
CN113239638A (zh) 一种基于蜻蜓算法优化多核支持向量机的逾期风险预测方法
CN113987236B (zh) 基于图卷积网络的视觉检索模型的无监督训练方法和装置
CN114997374A (zh) 一种针对数据倾斜的快速高效联邦学习方法
CN114358250A (zh) 数据处理方法、装置、计算机设备、介质及程序产品
CN112541530B (zh) 针对聚类模型的数据预处理方法及装置
CN108154165B (zh) 基于大数据与深度学习的婚恋对象匹配数据处理方法、装置、计算机设备和存储介质
CN111309923B (zh) 对象向量确定、模型训练方法、装置、设备和存储介质
CN116645130A (zh) 基于联邦学习与gru结合的汽车订单需求量预测方法
CN116259057A (zh) 基于联盟博弈解决联邦学习中数据异质性问题的方法
CN115660116A (zh) 基于稀疏适配器的联邦学习方法及系统
CN115936110A (zh) 一种缓解异构性问题的联邦学习方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination