CN117408330B - 面向非独立同分布数据的联邦知识蒸馏方法及装置 - Google Patents

面向非独立同分布数据的联邦知识蒸馏方法及装置 Download PDF

Info

Publication number
CN117408330B
CN117408330B CN202311714820.2A CN202311714820A CN117408330B CN 117408330 B CN117408330 B CN 117408330B CN 202311714820 A CN202311714820 A CN 202311714820A CN 117408330 B CN117408330 B CN 117408330B
Authority
CN
China
Prior art keywords
data
model
fusion
client
global
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202311714820.2A
Other languages
English (en)
Other versions
CN117408330A (zh
Inventor
田辉
王欢
郭玉刚
张志翔
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hefei High Dimensional Data Technology Co ltd
Original Assignee
Hefei High Dimensional Data Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Hefei High Dimensional Data Technology Co ltd filed Critical Hefei High Dimensional Data Technology Co ltd
Priority to CN202311714820.2A priority Critical patent/CN117408330B/zh
Publication of CN117408330A publication Critical patent/CN117408330A/zh
Application granted granted Critical
Publication of CN117408330B publication Critical patent/CN117408330B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/096Transfer learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0475Generative networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/094Adversarial learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/098Distributed learning, e.g. federated learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/776Validation; Performance evaluation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/80Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/94Hardware or software architectures specially adapted for image or video understanding
    • G06V10/95Hardware or software architectures specially adapted for image or video understanding structured as a network, e.g. client-server architectures

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • General Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Artificial Intelligence (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Multimedia (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Molecular Biology (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本申请涉及一种面向非独立同分布数据的联邦知识蒸馏方法及装置,其包括根据公共数据集进行随机采样,获取辅助数据集;基于预设的优化函数以及辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;将生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入生成网络模型,得到生成网络数据;控制客户端基于预设的数据融合算法、生成网络数据以及预设的本地数据进行数据融合,获取融合数据;控制客户端根据预设的局部模型蒸馏算法以及融合数据对深度学习模型进行优化训练,得到全局模型,本申请通过生成网络模型和局部模型蒸馏算法对客户端的深度学习模型进行优化,减少深度学习模型的优化目标与全局优化目标的偏差。

Description

面向非独立同分布数据的联邦知识蒸馏方法及装置
技术领域
本申请涉及数据安全技术领域,尤其是涉及一种面向非独立同分布数据的联邦知识蒸馏方法及装置。
背景技术
随着互联网、物联网、云计算和大数据等各种技术的快速发展,企业面临海量的数据处理与分析,数据的搜集、共享、发布和分析过程中可能导致用户隐私信息的泄露,给用户带来巨大损失。同时,全球数据保护法规越来越严格,企业在使用数据过程中面临隐私泄露和数据违规风险。因此,隐私计算技术变得越发重要。
联邦学习是一种新兴的人工智能技术,最初由谷歌在2016年提出,旨在解决个人数据在安卓手机端的隐私问题。该技术的设计动机是保护手机或平板计算机中用户的隐私数据,因此提出了一种数据不动模型动的新型分布式机器学习范式。联邦学习可以看成是一种分布式机器学习框架,与传统的分布式机器学习框架不同,其使用了加密技术,并且各方数据保存在本地。在联邦学习中,各个参与方(例如手机、平板计算机等设备)将本地数据进行计算和更新,然后将结果发送回中央服务器进行聚合。联邦学习体现了集中数据收集和最小化的原则,可以减轻传统集中式机器学习和数据挖掘方法带来的系统和统计层面上的隐私风险和通信效率开销。
针对上述中的相关技术,由于联邦学习系统中各个客户端通过不同的硬件或软件设备收集并处理数据,因此客户端之间的数据分布往往是差异极其大的,并进一步导致各客户端深度学习模型的参数不一致。各客户端深度学习模型的优化目标与全局优化目标存在偏差,在模型训练时会远离最优点,从而导致模型在效率、效果、隐私保护层面上都不能达到一个很好的效果。
发明内容
为了改善各客户端深度学习模型的优化目标与全局优化目标存在偏差,在模型训练时会远离最优点,从而导致模型在效率、效果、隐私保护层面上都不能达到一个很好的效果的问题,本申请提供一种面向非独立同分布数据的联邦知识蒸馏方法及装置。
第一方面,本申请提供的一种面向非独立同分布数据的联邦知识蒸馏方法,采用如下的技术方案:包括:
根据预设的公共数据集进行随机采样,获取辅助数据集;
基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;
将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;
控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取融合数据;
控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户的深度学习模型进行优化训练,得到全局模型。
可选的,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数。
可选的,所述对抗目标损失函数的计算公式为:
其中,为所述辅助数据集中的数据样本,/>为所述噪声向量,/>为所述生成网络,/>和/>则分别代表所述生成网络/>和所述鉴别网络/>的模型参数。
可选的,所述互信息平滑损失函数的计算公式为:
其中,代表一次批处理过程中所述噪声向量/>的数量。
可选的,所述相似度惩罚损失函数的计算公式为:
其中,和/>代表重复采样过程中不同的噪声向量。
可选的,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:
基于所述生成网络模型生成的所述生成网络数据/>和客户端的所述本地数据/>通过所述数据融合算法进行融合,得到所述融合数据/>
其中,所述数据融合算法的计算公式为:
其中,为基于随迭代次数从最小值0增加到最大值0.5的动量参数,/>为样本/>的伪标签,/>和/>为合成后的数据样本和标签。
可选的,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型,包括:
计算所述生成网络数据与所述本地数据之间的数量比例;
控制客户端基于所述局部模型蒸馏算法、所述数量比例以及所述融合数据对生成网络进行优化训练,得到所述全局模型;
其中,所述局部模型蒸馏算法的计算公式为:
其中,其中为所述本地数据的样本数量,/>为所述生成网络数据的样本数量,是代表客户端本地的深度学习模型/>在所述生成网络数据/>和所述融合数据/>之间Kullback-Leibler距离,/>为用于调整知识蒸馏强度的参数,/>为所述生成网络数据中标签为/>的样本数量,/>则代表归一化指数函数。
可选的,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型之后,还包括:
若存在多个客户端,则控制每个客户端通过所述局部模型蒸馏算法、所述数据融合算法对所述全局模型进行迭代优化,获取全部客户端的优化模型;
接收所有客户端的所述优化模型,并根据所述优化模型进行平均加权处理,得到所述全局模型。
可选的,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型之后,还包括:
接收全体客户端深度学习模型的模型参数;
基于每个客户端的所述模型参数通过可学习参数进行加权处理,得到集成模型;
基于所述生成网络模型批量生成的所述生成网络数据,得到虚拟数据集;
基于全局聚合蒸馏算法和集成模型,通过解耦所述生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;
将所述全局微调模型重新分发给各个客户端,控制每个客户端根据所述局部模型蒸馏算法以及所述融合数据、所述全局聚合蒸馏算法和所述集成模型对所述全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度;
其中,所述集成模型的计算公式为:
其中,是一个可学习参数并处于0到1之间,/>则是用于控制权重参数正则化的程度,/>代表客户端上的所述模型参数;
所述全局聚合蒸馏算法的定义如下:
其中代表所述全局模型,/>代表所述集成模型,/>为所述虚拟数据集中的数据样本。
第二方面,本申请还提供一种面向非独立同分布数据的联邦知识蒸馏装置,采用如下技术方案,包括:
数据采样模块,用于根据预设的公共数据集进行随机采样,获取辅助数据集;
生成网络模块,用于基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;
数据生成模块,用于将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;
数据融合模块,用于控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据;
模型优化模块,用于控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型。
综上所述,本申请通过采用上述技术方案,服务器根据公共数据集进行随机采样,并根据辅助数据集和优化函数对生成网络进行预训练,获取生成网络模型,服务器再将生成网络模型发送至客户端,客户端根据噪声向量输出对应的生成网络数据,客户端根据将本地数据和生成网络数据通过数据融合算法进行动量融合,并根据局部蒸馏算法以及融合数据对生深度学习模型进行优化训练,直至所有客户端依次对全局模型进行优化迭代后,客户端将全局模型发送至服务器,服务器再将全局模型进行平均加权处理后下发至所有客户端,从而减少深度学习模型训练出现偏差的问题,以减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大幅提升了深度学习模型图像分类任务的准确率。
附图说明
图1是本申请实施例中一种面向非独立同分布数据的联邦知识蒸馏方法的流程示意图。
图2是本申请实施例中一种面向非独立同分布数据的联邦知识蒸馏装置的结构框图。
附图标记说明:310、数据采样模块;320、生成网络模块;330、数据生成模块;340、数据融合模块;350、模型优化模块。
具体实施方式
以下结合附图1-2对本申请作进一步详细说明。
本申请实施例公开一种面向非独立同分布数据的联邦知识蒸馏方法,知识蒸馏是获取高效小规模网络的一种新兴方法,其主要思想是将学习能力强的模型中的信息迁移到简单的模型中去,可以有效提取出数据中的潜在信息。
本申请主要通过优化函数对生成网络进行预训练,得到生成网络模型,所有客户端基于生成网络模型以及本地数据对深度学习模型进行优化,得到全局模型,最后服务器将全局模型再下发至所有客户端,减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大大提高了深度学习模型图像分类任务的准确率。
其中,客户端的深度学习模型可以是ResNet深度神经网络模型,ResNet深度神经网络模型是指:论文“Deep Residual Learning for Image Recognition”中提出的基于ResNet深度神经网络模型进行图像识别的方法,简称ResNet深度神经网络模型。
参照图1,本申请实施例至少包括步骤S10至步骤S50。
S10,根据预设的公共数据集进行随机采样,获取辅助数据集。
其中,本申请实施例中所采用的公共数据集为CIFAR-10和CIFAR-100数据集,也可以使用其他数据集。
应当理解的是,由于参与模型训练的公共数据集都是符合独立同分布的数据集,但是这并不满足联邦学习系统中跨客户端本地数据之间非独立同分布的假设。因此,本申请基于狄利克雷分布来划分公共数据集,以满足跨客户端本地数据之间非独立同分布的要求。并且,由于是从公共数据集中随机采样,因此不会泄露各参与客户端的私有数据信息。
本申请实施例在CIFAR-10数据集上测试基于狄利克雷分布的非独立同分布数据划分算法,并进行可视化呈现,其中规定客户端数量,狄利克雷分布的参数向量/>满足/>,其中/>
S20,基于预设的优化函数以及辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型。
其中,生成网络和鉴别网络是生成对抗网络的组成部分,生成对抗网络由IanGoodfellow 等人在2014年提出,它是一种深度神经网络架构,由一个生成网络和一个鉴别网络组成。生成网络产生『假』数据,并试图欺骗鉴别网络;鉴别网络对生成数据进行真伪鉴别,试图正确识别所有假数据。
S30,将生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入生成网络模型,得到生成网络数据。
S40,控制客户端基于预设的数据融合算法、生成网络数据以及预设的本地数据进行数据融合,获取融合数据。
其中,本申请实施例中的本地数据是基于狄利克雷分布对公共数据集进行划分后平均分配至每个客户端的,每个客户端的本地数据数量一致,但内容和类别不一致。
S50,控制客户端根据预设的局部模型蒸馏算法以及融合数据对深度学习模型进行优化训练,得到全局模型。
具体来说,服务器根据公共数据集进行随机采样,并根据辅助数据集和优化函数对生成网络进行预训练,获取生成网络模型,服务器再将生成网络模型发送至客户端,客户端根据噪声向量输出对应的生成网络数据,客户端根据将本地数据和生成网络数据通过数据融合算法进行动量融合,并根据局部蒸馏算法以及融合数据对生成网络模型进行优化训练,直至所有客户端依次对全局模型进行优化迭代后,客户端将全局模型发送至服务器,服务器再将平均加权处理后全局模型下发至所有客户端,从而所基于生成网络模型对本地的深度学习模型进行迭代优化,进而减少深度学习模型训练出现偏差的问题,以减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大幅提升了深度学习模型图像分类任务的准确率。
实际来说,对于联邦学习中客户端的深度学习模型而言,其定义为/>(其模型参数为/>),对于辅助数据集/>而言,其中每个样本/>是从初始的公共数据集/>中随机采样得到。需要注意的是,客户端/>上的本地数据集/>是符合非独立同分布的,同时客户端总数量为/>且全局模型被定义为/>(其模型参数为/>)。
在一些实施例中,针对中央服务器而言,基于辅助数据集使用数据样本/>和基于高斯噪声初始化的噪声向量/>,通过对抗目标损失函数/>来训练出一个轻量级生成网络模型。对抗目标损失函数/>的计算公式为:
其中和/>分别代表生成网络/>和鉴别网络/>的模型参数,注意在训练生成器模型的过程中输入样本/>可以是真实数据/>或者是之前生成器所生成的旧数据/>
在一些实施例中,考虑到随机抽样的辅助数据集的样本数量过少,为了减少在生成网络模型的训练过程中出现模式崩溃等问题,本申请实施例从互信息的角度出发,将鉴别网络/>视为一个分类模型,然后通过互信息平滑损失函数/>来最大化生成网络数据的平均信息熵,以此达到平衡生成网络模型类分布的目的。互信息平滑损失函数/>的计算公式为:
其中代表一次批处理过程中噪声向量/>的数量,通过互信息平滑损失函数/>可以使得基于生成器/>生成的数据的类别信息更加平衡。
在一些实施例中,为了进一步增强生成网络模型生成的生成网络数据的多样性,本申请实施例从重采样的角度提出了相似度惩罚损失函数,即考虑到不同的噪声向量/>和/>,基于相似度惩罚损失函数/>在生成相似类别的同时扩大/>和/>之间的距离。的相似度惩罚损失函数/>的计算公式为:
通过相似度惩罚损失函数可以使得生成器有效地生成同一类别的不同样本。
进一步的,基于对抗目标损失函数、互信息平滑损失函数/>、以及相似度惩罚损失函数/>可以得到生成网络的优化函数/>,基于此优化目标使得生成网络可以生成更多样化且更清晰的数据样本。优化函数/>的计算公式为:
通过优化函数基于辅助数据集/>训练生成网络,从而得到生成网络模型。
在一些实施例中,服务器将预训练的生成网络模型发送给参与训练的各客户端,对于客户端而言,基于生成网络模型生成的生成网络数据/>和客户端的本地数据/>通过动量数据融合算法进行融合,可以得到融合数据/>。动量数据融合算法的计算公式为:
其中,为基于随迭代次数从最小值0增加到最大值0.5的动量参数,为样本/>的伪标签,/>和/>为合成后的数据样本和标签,其有效保留了生成网络数据/>和本地数据/>的类别信息。
接着,客户端计算生成网络数据与本地数据在融合数据中所占的比例,并在客户端模型局部训练中应用该比例对损失进行加权计算。然后,客户端将合成数据和/>视为一种先验信息,基于局部模型蒸馏算法并设计优化目标,从而对客户端/>的本地模型/>进行优化。局部模型蒸馏算法的计算公式/>为:
其中,其中为本地数据的样本数量,/>为生成网络数据的样本数量,/>是代表客户端本地的深度学习模型/>在生成网络数据/>和融合数据/>之间Kullback-Leibler距离,/>为调整知识蒸馏强度的参数,/>为生成网络数据中标签为/>的样本数量,/>代表归一化指数函数。例如:生成网络数据有20个样本,本地数据有80个样本,那么计算损失的时候,/>目标函数需要乘上80/(20+80)=0.8。
通过数据融合算法和局部模型蒸馏算法对深度学习模型进行优化,大大增加了深度学习模型对本地数据的拟合度。
进一步的,若存在多个客户端,则控制每个客户端通过局部模型蒸馏算法、数据融合算法对深度学习模型进行迭代优化,获取全部客户端的优化模型;接收所有客户端的优化模型,并根据优化模型进行平均加权处理,得到全局模型,从而减少深度学习模型训练出现偏差的问题,以减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大幅提升了深度学习模型图像分类任务的准确率。
在一些实施例中,服务器接收全体客户端深度学习模型的模型参数,基于每个客户端的模型参数通过可学习参数进行加权处理,得到集成模型,集成模型/>的定义如下:
其中,是一个可学习的参数并处于0到1之间,/>则是用于控制权重参数正则化的程度,/>代表客户端/>上的模型参数。
接着,服务器基于生成网络模型批量生成的生成网络数据,获取一个虚拟数据集,并基于全局聚合蒸馏算法和集成模型,通过解耦数据中的类别信息从而对全局模型进行微调。全局聚合蒸馏算法/>的定义如下:
其中代表全局模型,/>代表客户端的集成模型。
最后,基于虚拟数据集,通过全局聚合蒸馏算法微调全局模型/>,重复上述的步骤,控制每个客户端根据局部模型蒸馏算法以及融合数据、全局聚合蒸馏算法和集成模型对全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度,可以有效消除由于全局更新引入的模型聚合漂移问题。
本申请实施例一种面向非独立同分布数据的联邦知识蒸馏方法的实施原理为:服务器根据公共数据集进行随机采样,并根据辅助数据集和优化函数对生成网络进行预训练,获取生成网络模型,服务器再将生成网络模型发送至客户端,客户端根据噪声向量输出对应的生成网络数据,客户端根据将本地数据和生成网络数据通过数据融合算法进行动量融合,并根据局部蒸馏算法以及融合数据对深度学习模型进行优化训练,同时客户端通过全局聚合蒸馏算法对全局模型进行微调,直至所有客户端依次对深度学习模型进行优化迭代后,获取优化模型,客户端将优化模型发送至服务器,服务器再将全部的优化模型进行平均加权处理,得到全局模型,最后将全局模型下发至所有客户端,从而便于所有的客户端基于全局模型对本地的深度学习模型进行迭代优化,进而减少深度学习模型训练出现偏差的问题,以减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大幅提升了深度学习模型图像分类任务的准确率。
下面结合仿真实验对本申请的效果做进一步的说明:
仿真实验条件:
本申请仿真实验的硬件平台为:一个中心服务器计算机,处理器为Intel至强E3-1231V3,主频为3.6GHz,内存64GB,英伟达GeForce RTX 3090显卡。三台客户端计算机,处理器为Intel(R) Core(TM) i7-9700F,主频为3.0GHz,内存16GB,英伟达GeForce RTX 2060显卡。
本申请仿真实验的软件平台为:Ubuntu 16.04 LTS,64位操作系统、Python 3.8、PyTorch深度学习框架(版本1.11.0)以及PyCharm代码编写软件。
仿真实验内容及其结果分析:
本申请仿真实验是采用本申请和一个现有技术(ResNet神经网络)分别对两种常见的图像分类数据集(CIFAR-10数据集和CIFAR-100数据集)进行图像预测任务,并获得分类预测结果。其中,在本申请实验中,划分的训练集和测试集的比例为7:3。
为了验证本申请实验的效果,采用全局模型在测试数据集上的预测分类准确率作为定量评价指标,对经过本方法和其他方法训练的模型进行评价。
在本方法的仿真实验中,其他方法分别为联邦平均聚合算法(FedAvg)、联邦优化算法(FedProx)、联邦归一化平均算法(FedNova)、联邦终身学习算法(FedCurv)、联邦融合集成算法(FedDF)和联邦无数据知识蒸馏算法(FedGEN)。
在本方法的仿真实验中,代表基于狄利克雷分布划分后的数据集的非独立同分布程度的大小,其中/>如果越小,则数据非独立同分布程度越大。
从表1中可以看出,本申请方法与其他方法相比,通过本方法训练后的模型在不同的数据集和数据不平衡程度上实现了更高的分类预测准确率,特别是在CIFAR-100数据集上,虽然其训练数据复杂且严重不平衡,但通过本申请方法训练后的全局模型仍然取得了优异的预测精度。
以上仿真实验表明:本申请提出了一种面向非独立同分布数据的联邦知识蒸馏方法,分别在本地客户端和中央服务器上通过局部模型蒸馏和全局聚合蒸馏算法,解决了现有技术中处理非独立同分布数据时可能存在的模型训练偏差问题,以及中央服务器上存在的模型聚合漂移问题。
图1为一个实施例中面向非独立同分布数据的联邦知识蒸馏方法的流程示意图。应该理解的是,虽然图1的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行;除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行;并且图1中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些子步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。
基于相同的技术构思,参照图2,本申请实例还提供了一种面向非独立同分布数据的联邦知识蒸馏装置,采用如下技术方案,该装置包括:
数据采样模块310,用于根据预设的公共数据集进行随机采样,获取辅助数据集;
生成网络模块320,用于基于预设的优化函数以及辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;
数据生成模块330,用于将生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入生成网络模型,得到生成网络数据;
数据融合模块340,用于控制客户端基于预设的数据融合算法、生成网络数据以及预设的本地数据进行数据融合,获取融合数据;
模型优化模块350,用于控制客户端根据预设的局部模型蒸馏算法以及融合数据对客户端的深度学习模型进行优化训练,得到全局模型。
在一些实施例中,优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数。
在一些实施例中,对抗目标损失函数的计算公式为:
其中,为辅助数据集中的数据样本,/>为噪声向量,/>为生成网络,/>和/>分别代表生成网络/>和鉴别网络/>的模型参数。
在一些实施例中,互信息平滑损失函数的计算公式为:
其中,代表一次批处理过程中噪声向量/>的数量。
在一些实施例中,相似度惩罚损失函数的计算公式为:
其中,和/>代表重复采样过程中不同的噪声向量。
在一些实施例中,数据融合模块340具体用于基于生成网络模型生成的生成网络数据/>和客户端的本地数据/>通过数据融合算法进行融合,得到融合数据/>
其中,数据融合算法的计算公式为:
其中,为基于随迭代次数从最小值0增加到最大值0.5的动量参数,/>为样本的伪标签,/>和/>为合成后的数据样本和标签。
在一些实施例中,数据融合模块340还用于计算生成网络数据与本地数据之间的数量比例;
控制客户端基于局部模型蒸馏算法、数量比例以及融合数据对生成网络进行优化训练,得到全局模型;
其中,局部模型蒸馏算法的计算公式为:
其中,其中为本地数据的样本数量,/>为生成网络数据的样本数量,/>是代表客户端本地的深度学习模型/>在生成网络数据/>和融合数据/>之间Kullback-Leibler距离,/>为用于调整知识蒸馏强度的参数,/>为生成网络数据中标签为/>的样本数量,代表归一化指数函数。
在一些实施例中,模型优化模块350还用于若存在多个客户端,则控制每个客户端通过所述局部模型蒸馏算法、所述数据融合算法对所述深度学习模型进行迭代优化,获取全部客户端的优化模型;
接收所有客户端的所述优化模型,并根据所述优化模型进行平均加权处理,得到所述全局模型。
在一些实施例中,模型优化模块350还用于接收全体客户端深度学习模型的模型参数;
基于每个客户端的模型参数通过可学习参数进行加权处理,得到集成模型;
基于生成网络模型批量生成的生成网络数据,得到虚拟数据集;
基于全局聚合蒸馏算法,通过解耦生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;
将全局微调模型重新分发给各个客户端依次进行迭代优化,直至全局微调模型收敛或者达到指定精度;
其中,集成模型的计算公式为:
其中,是一个可学习参数并处于0到1之间,/>则是用于控制权重参数正则化的程度,/>代表客户端上的模型参数;
全局聚合蒸馏算法的定义如下:
其中代表全局模型,/>代表集成模型,/>为虚拟数据集中的数据样本。
本申请实例还公开一种控制设备。
具体来说,该控制设备包括存储器和处理器,存储器上存储有能够被处理器加载并执行上述面向非独立同分布数据的联邦知识蒸馏方法的计算机程序。
本申请实例还公开一种计算机可读存储介质。
具体来说,该计算机可读存储介质,其存储有能够被处理器加载并执行如上述面向非独立同分布数据的联邦知识蒸馏方法的计算机程序,该计算机可读存储介质例如包括:U盘、移动硬盘、只读存储器(Read-OnlyMemory,ROM)、随机存取存储器(RandomAccessMemory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
以上均为本申请的较佳实施例,并非依此限制本申请的保护范围,故:凡依本申请的结构、形状、原理所做的等效变化,均应涵盖于本申请的保护范围之内。

Claims (3)

1.一种面向非独立同分布数据的联邦知识蒸馏方法,其特征在于,所述方法包括:
根据预设的公共数据集进行随机采样,获取辅助数据集;
基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;
将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;
控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取融合数据;
控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型;
其中,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数;
所述对抗目标损失函数的计算公式为:
其中,为所述辅助数据集中的数据样本,/>为所述噪声向量,/>为所述生成网络,和/>分别代表所述生成网络/>和所述鉴别网络/>的模型参数;
所述互信息平滑损失函数的计算公式为:
其中,代表一次批处理过程中所述噪声向量/>的数量;
所述相似度惩罚损失函数的计算公式为:
其中,和/>代表重复采样过程中不同的噪声向量;
其中,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:
基于所述生成网络模型生成的所述生成网络数据/>和客户端的所述本地数据/>通过所述数据融合算法进行融合,得到所述融合数据/>
其中,所述数据融合算法的计算公式为:
其中,为基于随迭代次数从最小值0增加到最大值0.5的动量参数,/>为样本的伪标签,/>和/>为合成后的数据样本和标签;
其中,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:
计算所述生成网络数据与所述本地数据之间的数量比例;
控制客户端基于所述局部模型蒸馏算法、所述数量比例以及所述融合数据对所述深度学习模型进行优化训练,得到所述全局模型;
其中,所述局部模型蒸馏算法的计算公式为:
其中,其中为所述本地数据的样本数量,/>为所述生成网络数据的样本数量,/>是代表客户端本地的深度学习模型/>在所述生成网络数据/>和所述融合数据/>之间Kullback-Leibler距离,/>为用于调整知识蒸馏强度的参数,/>为所述生成网络数据中标签为/>的样本数量,/>则代表归一化指数函数;
其中,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型之后,还包括:
接收全体客户端深度学习模型的模型参数;
基于每个客户端的所述模型参数通过可学习参数进行加权处理,得到集成模型;
基于所述生成网络模型批量生成的所述生成网络数据,得到虚拟数据集;
基于全局聚合蒸馏算法和集成模型,通过解耦所述生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;
将所述全局微调模型重新分发给各个客户端,控制每个客户端根据所述局部模型蒸馏算法以及所述融合数据、所述全局聚合蒸馏算法和所述集成模型对所述全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度;
其中,所述集成模型的计算公式为:
其中,是一个可学习参数并处于0到1之间,/>则是用于控制权重参数正则化的程度,/>代表客户端上的所述模型参数;
所述全局聚合蒸馏算法的定义如下:
其中代表所述全局模型,/>代表所述集成模型,/>为所述虚拟数据集中的数据样本。
2.根据权利要求1所述的方法,其特征在于,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:
若存在多个客户端,则控制每个客户端通过所述局部模型蒸馏算法、所述数据融合算法对所述深度学习模型进行迭代优化,获取全部客户端的优化模型;
接收所有客户端的所述优化模型,并根据所述优化模型进行平均加权处理,得到所述全局模型。
3.一种面向非独立同分布数据的联邦知识蒸馏装置,其特征在于,所述装置包括:
数据采样模块,用于根据预设的公共数据集进行随机采样,获取辅助数据集;
生成网络模块,用于基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;
数据生成模块,用于将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;
数据融合模块,用于控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取融合数据;
模型优化模块,用于控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型;
其中,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数;
所述对抗目标损失函数的计算公式为:
其中,为所述辅助数据集中的数据样本,/>为所述噪声向量,/>为所述生成网络,和/>分别代表所述生成网络/>和所述鉴别网络/>的模型参数;
所述互信息平滑损失函数的计算公式为:
其中,代表一次批处理过程中所述噪声向量/>的数量;
所述相似度惩罚损失函数的计算公式为:
其中,和/>代表重复采样过程中不同的噪声向量;
其中,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:
基于所述生成网络模型生成的所述生成网络数据/>和客户端的所述本地数据/>通过所述数据融合算法进行融合,得到所述融合数据/>
其中,所述数据融合算法的计算公式为:
其中,为基于随迭代次数从最小值0增加到最大值0.5的动量参数,/>为样本的伪标签,/>和/>为合成后的数据样本和标签;
其中,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:
计算所述生成网络数据与所述本地数据之间的数量比例;
控制客户端基于所述局部模型蒸馏算法、所述数量比例以及所述融合数据对所述深度学习模型进行优化训练,得到所述全局模型;
其中,所述局部模型蒸馏算法的计算公式为:
其中,其中为所述本地数据的样本数量,/>为所述生成网络数据的样本数量,/>是代表客户端本地的深度学习模型/>在所述生成网络数据/>和所述融合数据/>之间Kullback-Leibler距离,/>为用于调整知识蒸馏强度的参数,/>为所述生成网络数据中标签为/>的样本数量,/>则代表归一化指数函数;
其中,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型之后,还包括:
接收全体客户端深度学习模型的模型参数;
基于每个客户端的所述模型参数通过可学习参数进行加权处理,得到集成模型;
基于所述生成网络模型批量生成的所述生成网络数据,得到虚拟数据集;
基于全局聚合蒸馏算法和集成模型,通过解耦所述生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;
将所述全局微调模型重新分发给各个客户端,控制每个客户端根据所述局部模型蒸馏算法以及所述融合数据、所述全局聚合蒸馏算法和所述集成模型对所述全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度;
其中,所述集成模型的计算公式为:
其中,是一个可学习参数并处于0到1之间,/>则是用于控制权重参数正则化的程度,/>代表客户端上的所述模型参数;
所述全局聚合蒸馏算法的定义如下:
其中代表所述全局模型,/>代表所述集成模型,/>为所述虚拟数据集中的数据样本。
CN202311714820.2A 2023-12-14 2023-12-14 面向非独立同分布数据的联邦知识蒸馏方法及装置 Active CN117408330B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311714820.2A CN117408330B (zh) 2023-12-14 2023-12-14 面向非独立同分布数据的联邦知识蒸馏方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311714820.2A CN117408330B (zh) 2023-12-14 2023-12-14 面向非独立同分布数据的联邦知识蒸馏方法及装置

Publications (2)

Publication Number Publication Date
CN117408330A CN117408330A (zh) 2024-01-16
CN117408330B true CN117408330B (zh) 2024-03-15

Family

ID=89492865

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311714820.2A Active CN117408330B (zh) 2023-12-14 2023-12-14 面向非独立同分布数据的联邦知识蒸馏方法及装置

Country Status (1)

Country Link
CN (1) CN117408330B (zh)

Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2019241659A1 (en) * 2018-06-15 2019-12-19 Subtle Medical, Inc. Systems and methods for magnetic resonance imaging standardization using deep learning
WO2021022752A1 (zh) * 2019-08-07 2021-02-11 深圳先进技术研究院 一种多模态三维医学影像融合方法、系统及电子设备
CN113421318A (zh) * 2021-06-30 2021-09-21 合肥高维数据技术有限公司 一种基于多任务生成对抗网络的字体风格迁移方法和系统
CN115858675A (zh) * 2022-12-05 2023-03-28 西安电子科技大学 基于联邦学习框架的非独立同分布数据处理方法
CN116311323A (zh) * 2023-01-17 2023-06-23 北京荣大科技股份有限公司 基于对比学习的预训练文档模型对齐优化方法
WO2023124296A1 (zh) * 2021-12-29 2023-07-06 新智我来网络科技有限公司 基于知识蒸馏的联合学习训练方法、装置、设备及介质
CN116629376A (zh) * 2023-04-26 2023-08-22 浙江大学 一种基于无数据蒸馏的联邦学习聚合方法和系统
CN116883751A (zh) * 2023-07-18 2023-10-13 安徽大学 基于原型网络对比学习的无监督领域自适应图像识别方法
CN116910571A (zh) * 2023-09-13 2023-10-20 南京大数据集团有限公司 一种基于原型对比学习的开集域适应方法及系统
CN117115547A (zh) * 2023-09-05 2023-11-24 云南大学 基于自监督学习与自训练机制的跨域长尾图像分类方法

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11710300B2 (en) * 2017-11-06 2023-07-25 Google Llc Computing systems with modularized infrastructure for training generative adversarial networks
EP4073714A1 (en) * 2019-12-13 2022-10-19 Qualcomm Technologies, Inc. Federated mixture models

Patent Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2019241659A1 (en) * 2018-06-15 2019-12-19 Subtle Medical, Inc. Systems and methods for magnetic resonance imaging standardization using deep learning
WO2021022752A1 (zh) * 2019-08-07 2021-02-11 深圳先进技术研究院 一种多模态三维医学影像融合方法、系统及电子设备
CN113421318A (zh) * 2021-06-30 2021-09-21 合肥高维数据技术有限公司 一种基于多任务生成对抗网络的字体风格迁移方法和系统
WO2023124296A1 (zh) * 2021-12-29 2023-07-06 新智我来网络科技有限公司 基于知识蒸馏的联合学习训练方法、装置、设备及介质
CN115858675A (zh) * 2022-12-05 2023-03-28 西安电子科技大学 基于联邦学习框架的非独立同分布数据处理方法
CN116311323A (zh) * 2023-01-17 2023-06-23 北京荣大科技股份有限公司 基于对比学习的预训练文档模型对齐优化方法
CN116629376A (zh) * 2023-04-26 2023-08-22 浙江大学 一种基于无数据蒸馏的联邦学习聚合方法和系统
CN116883751A (zh) * 2023-07-18 2023-10-13 安徽大学 基于原型网络对比学习的无监督领域自适应图像识别方法
CN117115547A (zh) * 2023-09-05 2023-11-24 云南大学 基于自监督学习与自训练机制的跨域长尾图像分类方法
CN116910571A (zh) * 2023-09-13 2023-10-20 南京大数据集团有限公司 一种基于原型对比学习的开集域适应方法及系统

Non-Patent Citations (7)

* Cited by examiner, † Cited by third party
Title
Logit Calibration for Non-IID and Long-Tailed Data in Federated Learning;Huan Wang等;2022 IEEE Intl Conf on Parallel & Distributed Processing with Applications, Big Data&Cloud Computing, Sustainable Computing & Communications, Social Computing&Networking(ISPA/BDCloud/SocialCom/SustainCom);20230323;正文第783页第一栏第4段-第789页第一栏第1段 *
刘天.面向数据异构的联邦学习的性能优化研究.中国博士学位论文全文数据库 信息科技辑.2022,(第12期),第I138-4页. *
孙季丰等.基于DeblurGAN和低秩分解的去运动模糊.华南理工大学学报(自然科学版).2020,第48卷(第01期),第32-42页. *
李剑.非独立同分布数据下的联邦学习算法研究.中国优秀硕士学位论文全文数据库信息科技辑.2022,(第01期),正文第6页第4段-第57页第4段,图3-1,图4-1. *
王欢等.联合多任务学习的人脸超分辨率重建.中国图象图形学报.2020,第25卷(第02期),第229-240页. *
赵子平等.基于联邦学习的智能助老服务研究.信号处理.2023,第39卷(第04期),第667-677页. *
非独立同分布数据下的联邦学习算法研究;李剑;中国优秀硕士学位论文全文数据库信息科技辑;20220115(第01期);正文第6页第4段-第57页第4段,图3-1,图4-1 *

Also Published As

Publication number Publication date
CN117408330A (zh) 2024-01-16

Similar Documents

Publication Publication Date Title
CN110210560B (zh) 分类网络的增量训练方法、分类方法及装置、设备及介质
CN110852447B (zh) 元学习方法和装置、初始化方法、计算设备和存储介质
CN111461226A (zh) 对抗样本生成方法、装置、终端及可读存储介质
US11315032B2 (en) Method and system for recommending content items to a user based on tensor factorization
TW202123052A (zh) 防止隱私資料洩漏的編碼模型訓練方法及裝置
TW202026984A (zh) 伺服器、客戶端、用戶核身方法及系統
CN110751291A (zh) 实现安全防御的多方联合训练神经网络的方法及装置
WO2022037541A1 (zh) 图像处理模型训练方法、装置、设备及存储介质
CN110298240B (zh) 一种汽车用户识别方法、装置、系统及存储介质
JP2023535140A (ja) ターゲット・ドメインに対する転移学習プロセスに適合するソース・データセットを識別すること
CN107958247A (zh) 用于人脸图像识别的方法和装置
CN110276243A (zh) 分数映射方法、人脸比对方法、装置、设备及存储介质
Valery et al. CPU/GPU collaboration techniques for transfer learning on mobile devices
CN109165654A (zh) 一种目标定位模型的训练方法和目标定位方法及装置
Kawa et al. A note on deepfake detection with low-resources
WO2020051232A1 (en) Decentralized biometric identification and authentication network
CN112052865A (zh) 用于生成神经网络模型的方法和装置
CN116152938A (zh) 身份识别模型训练和电子资源转移方法、装置及设备
CN117408330B (zh) 面向非独立同分布数据的联邦知识蒸馏方法及装置
CN116151965B (zh) 一种风险特征提取方法、装置、电子设备及存储介质
Vashishtha et al. An Ensemble approach for advance malware memory analysis using Image classification techniques
CN110738227B (zh) 模型训练方法及装置、识别方法、存储介质及电子设备
US20240037995A1 (en) Detecting wrapped attacks on face recognition
Abdukhamidov et al. Hardening Interpretable Deep Learning Systems: Investigating Adversarial Threats and Defenses
CN113343898B (zh) 基于知识蒸馏网络的口罩遮挡人脸识别方法、装置及设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant