CN117994635B - 一种噪声鲁棒性增强的联邦元学习图像识别方法及系统 - Google Patents
一种噪声鲁棒性增强的联邦元学习图像识别方法及系统 Download PDFInfo
- Publication number
- CN117994635B CN117994635B CN202410396190.7A CN202410396190A CN117994635B CN 117994635 B CN117994635 B CN 117994635B CN 202410396190 A CN202410396190 A CN 202410396190A CN 117994635 B CN117994635 B CN 117994635B
- Authority
- CN
- China
- Prior art keywords
- client
- model
- local
- global
- updating
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 57
- 238000012549 training Methods 0.000 claims abstract description 47
- 230000006870 function Effects 0.000 claims abstract description 34
- 230000007246 mechanism Effects 0.000 claims abstract description 12
- 230000008569 process Effects 0.000 claims description 27
- 238000012360 testing method Methods 0.000 claims description 18
- 238000009499 grossing Methods 0.000 claims description 9
- 101100455978 Arabidopsis thaliana MAM1 gene Proteins 0.000 claims description 7
- 230000004931 aggregating effect Effects 0.000 claims description 4
- 238000012804 iterative process Methods 0.000 claims 1
- 238000010801 machine learning Methods 0.000 abstract description 7
- 238000012545 processing Methods 0.000 abstract description 2
- 238000009826 distribution Methods 0.000 description 13
- 238000004220 aggregation Methods 0.000 description 6
- 238000004891 communication Methods 0.000 description 6
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 5
- 230000002776 aggregation Effects 0.000 description 5
- 230000000694 effects Effects 0.000 description 3
- 230000006872 improvement Effects 0.000 description 3
- 238000005192 partition Methods 0.000 description 3
- 238000012795 verification Methods 0.000 description 3
- 238000013459 approach Methods 0.000 description 2
- 238000012937 correction Methods 0.000 description 2
- 238000005457 optimization Methods 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 238000010200 validation analysis Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 230000004927 fusion Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000008520 organization Effects 0.000 description 1
- 238000011176 pooling Methods 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 238000003860 storage Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/098—Distributed learning, e.g. federated learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/44—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
- G06V10/443—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components by matching or filtering
- G06V10/449—Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters
- G06V10/451—Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters with interaction between the filter responses, e.g. cortical complex cells
- G06V10/454—Integrating the filters into a hierarchical structure, e.g. convolutional neural networks [CNN]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Biophysics (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Data Mining & Analysis (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Biodiversity & Conservation Biology (AREA)
- Image Analysis (AREA)
Abstract
本发明属于机器学习技术处理,为了解决现有联邦学习中客户端模型训练的稳定性差,以及噪声负面影响的问题,提出了一种噪声鲁棒性增强的联邦元学习图像识别方法及系统,通过引入AdaBelief优化器和SCAFFOLD算法中的动量和控制变量机制对全局模型参数和全局控制变量进行本地更新,能够提高客户端本地训练稳定性并加快收敛速度;在客户端本地更新中,通过动态权重参数对损失函数进行改进,并结合平滑标签策略,个性化的学习策略在减少局部噪声的负面影响的同时并提高模型的泛化能力。
Description
技术领域
本发明属于机器学习技术领域,尤其涉及一种噪声鲁棒性增强的联邦元学习图像识别方法及系统。
背景技术
机器学习在人们的日常生活中越来越常见,它被应用到人们生活中的方方面面,给人们的生活带来了极大的便利。但是,为满足用户隐私、数据安全和政府规定,机构之间无法直接共享数据。为了能够在不侵犯隐私和保证数据安全的前提下进行机器学习建模。联邦学习框架应运而生。联邦学习是一种高效的通信和保护隐私的替代方法,可让一组组织或同一组织内的群组以协作和迭代的方式训练和改进共享的全局机器学习模型,参与联邦学习的组织之间不会交换唯一数据。从而可以既保证数据隐私安全,又能完成机器学习建模任务。
联邦学习是一种新兴的人工智能基础技术,其设计目标是在保障大数据交换时的信息安全、保护终端数据和个人数据隐私、保证合法合规的前提下,在多参与方或多计算结点之间开展高效率的机器学习。然而,联邦学习面临的其中一大挑战是数据异质性,联邦客户端中的数据通常是非独立同分布的。数据中也可能存在大量的噪声。设备之间的数据可能有一种潜在的统计结构表示不同设备的关系和分布。这可能会严重影响联邦学习的全局模型的性能。同时,由于数据分布各不相同,从而可能导致全局模型在某些联邦客户端的本地性能较好,在其他客户端表现很差。一些研究方法针对解决非独立同分布数据下全局模型效果差,但是缺乏对局部模型的个性化考虑,现有的方案通常在局部客户端的表现各有不同。因此,联邦学习中的数据异质性、噪声处理和个性化局部客户端问题具有重要意义。
对于处理数据异质性和噪声的方法,通常可以分为客户端方法和中央服务器端方法。在中央服务器端,可以通过改变模型聚合策略来提升全局模型的性能,例如使用加权聚合来对不同客户端的贡献进行调整。客户端方法主要侧重于通过调整全局模型的参数,构建个性化模型以提升局部模型性能,同时可以采用一些对抗噪声的方法,如对输入数据进行预处理或者在训练过程中引入噪声以提高模型的鲁棒性。尽管这些方法在一定程度上减缓了数据异质性和噪声对联邦学习的影响,但通常仅关注最终全局模型的表现,可能导致某些客户端的性能较差。
综上,如何在联邦学习中,提高客户端模型训练的稳定性以及减少噪声负面影响,是目前需要解决的技术问题。
发明内容
为克服上述现有技术的不足,本发明提供了一种噪声鲁棒性增强的联邦元学习图像识别方法及系统,通过引入AdaBelief优化器和SCAFFOLD算法中的动量和控制变量机制对全局模型参数和全局控制变量进行本地更新,能够提高客户端本地训练稳定性并加快收敛速度;在客户端本地更新中,通过动态权重参数对损失函数进行改进,并结合平滑标签策略,个性化的学习策略在减少局部噪声的负面影响的同时并提高模型的泛化能力。
为实现上述目的,本发明的第一个方面提供一种噪声鲁棒性增强的联邦元学习图像识别方法,包括:
中央服务器将当前的全局模型参数和全局控制变量发送给各客户端;
各客户端根据所接收的当前的全局模型参数和全局控制变量,引入AdaBelief优化器使全局模型适应本地数据集,通过SCAFFOLD算法中的动量和控制变量机制对全局模型参数和全局控制变量进行本地更新,得到更新后的本地模型参数以及本地控制变量;其中,在本地更新过程中,通过动态权重参数对损失函数进行改进,以改进后的损失函数和平滑标签策略进行本地更新;
各客户端将更新后的本地模型以及控制变量更新差异上传给中央服务器;
中央服务器根据所接收的各客户端更新后的本地模型以及控制变量更新进行聚合,得到更新后的全局模型参数和全局控制变量,迭代更新,直至全局模型收敛,利用训练好的全局模型进行图像识别。
本发明的第二个方面提供一种噪声鲁棒性增强的联邦元学习图像识别系统,包括:中央服务器和各客户端;
所述中央服务器,用于将当前的全局模型参数和全局控制变量发送给各客户端;
所述各客户端,用于根据所接收的当前的全局模型参数和全局控制变量,引入AdaBelief优化器使全局模型适应本地数据集,通过SCAFFOLD算法中的动量和控制变量机制对全局模型参数和全局控制变量进行本地更新,得到更新后的本地模型以及本地控制变量;其中,在本地更新过程中,通过动态权重参数对损失函数进行改进,以改进后的损失函数和平滑标签策略进行本地更新;
所述各客户端,用于将更新后的本地模型以及控制变量更新差异上传给中央服务器;
所述中央服务器,用于根据所接收的各客户端更新后的本地模型以及控制变量更新进行聚合,得到更新后的全局模型参数和全局控制变量,迭代更新,直至全局模型收敛,利用训练好的全局模型进行图像识别。
以上一个或多个技术方案存在以下有益效果:
在本发明中,通过引入AdaBelief优化器和SCAFFOLD算法中的动量和控制变量机制对全局模型参数和全局控制变量进行本地更新,能够提高客户端本地训练稳定性并加快收敛速度;在客户端本地更新中,通过动态权重参数对损失函数进行改进,并结合平滑标签策略,个性化的学习策略在减少局部噪声的负面影响的同时并提高模型的泛化能力。
本发明附加方面的优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本发明的实践了解到。
附图说明
构成本发明的一部分的说明书附图用来提供对本发明的进一步理解,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。
图1为本发明实施例一中一种噪声鲁棒性增强的联邦元学习图像识别方法流程图。
具体实施方式
应该指出,以下详细说明都是示例性的,旨在对本发明提供进一步的说明。除非另有指明,本文使用的所有技术和科学术语具有与本发明所属技术领域的普通技术人员通常理解的相同含义。
需要注意的是,这里所使用的术语仅是为了描述具体实施方式,而非意图限制根据本发明的示例性实施方式。
在不冲突的情况下,本发明中的实施例及实施例中的特征可以相互组合。
实施例一
如图1所示,本实施例公开了一种噪声鲁棒性增强的联邦元学习图像识别方法,包括:
中央服务器将当前的全局模型参数和全局控制变量发送给各客户端;
各客户端根据所接收的当前的全局模型参数和全局控制变量,采用元学习方式进行训练更新,在训练更新过程中,基于AdaBelief优化器使全局模型适应本地数据集;基于SCAFFOLD算法中的动量和控制变量机制,根据全局控制变量进行本地更新,得到更新后的本地模型以及本地控制变量;通过动态权重参数对损失函数进行改进,以改进后的损失函数和平滑标签策略进行本地更新;
各客户端将更新后的本地模型以及控制变量更新差异上传给中央服务器;
中央服务器根据所接收的各客户端更新后的本地模型以及控制变量更新进行聚合,得到更新后的全局模型参数和全局控制变量,迭代更新,直至全局模型收敛,利用训练好的全局模型进行图像识别。
在本实施例中,将非独立同分布数据集应用于联邦元学习设置,搭建联邦元学习环境并初始化模型,对客户端模型进行更新并上传到中央服务器中,在中央服务器聚合所有客户端的模型并更新全局模型。
将基准数据集CIFAR-10和CIFAR-100适应于非独立同分布数据下联邦元学习设置。
CIFAR-10数据集包含了60000个32x32的彩色图像,涵盖10个类别,每个类别含有6000张图像,其中训练图像数量为50000,测试图像数量为10000。数据集被划分为五个训练批次和一个测试批次,每个批次包含10000张图像。测试批次中包括每个类别随机选择的1000张图像。训练批次中的图像以随机顺序排列,但某些批次可能包含某个类别的图像数量多于其他类别。总体而言,各训练批次均包含5000张来自各个类别的图像。
CIFAR-100数据集涵盖了100个类别,每个类别包含600张图像。每个类别分别包含500张训练图像和100张测试图像。CIFAR-100中的这100个类别被划分成了20个超类。每张图像都附带有一个“精细”标签,用以表示其所属的具体类别,以及一个“粗糙”标签,用以表示其所属的超类。
使用Dirichlet分布来划分不同客户端之间的非独立同分布数据分区,并且将CIFAR-10和CIFAR-100上的异构程度值设置为0.5。同时,为生成有噪音的数据集,使用数据增强为训练集添加噪音,噪音率/>,并将训练集中20%的标签翻转到错误的标签上作为每个联邦客户端本地数据集。
为了在接下来的步骤中,保证准确率的前提下更快地收敛,采用包括三个全连接层和一个分类器的CNN网络架构参与训练。该网络架构是带有参数的神经网络,它由特征提取器/>和分类器两部分/>组成,其中,分类器是从全局平均池化层获取输入。客户端/>的本地模型的参数表示为/>。默认环境下,总共运行300轮全局通讯,总共有20个客户端。设置局部训练批次大小为64,内学习率为0.001,外学习率为0.1,动量衰减为0.9,初始控制变量为0.01,优化器为AdaBelief。
在本实施例步骤2中,在第次通信时,中央服务器从/>个客户端中选择/>个客户端形成集合/>。在首轮通信开始时,中央服务器随机初始化全局模型参数/>和全局控制变量,同时也为每个客户端分配一个初始的个人控制变量/>,然后将全局模型与全局控制变量发送给被选中参与本轮的客户端。
客户端接收到全局模型和控制变量后,使用其本地数据集进行模型的训练和更新 ,其中/>。在此过程中,为了提高训练的稳定性并加快收敛速度,特别在处理复杂联邦学习场景时,采用了 AdaBelief优化器,并结合带动量的SCAFFOLD算法的策略来进行更新。
每个客户端的更新过程中,引入了AdaBelief优化器以适应其本地数据分布,并使用带动量的SCAFFOLD算法中定义的动量和控制变量机制来进一步指导更新过程,具体为:在向联邦中央服务器发送模型及参数时同时发送一个控制变量c,随后这个控制变量c也会从联邦端发送回中央服务器,进行聚合。
AdaBelief优化器是基于adam优化器改进,其主要改进为在于对梯度二阶矩估计的处理。
具体而言:Adam优化器结合了Momentum和RMSProp的特点,关键步骤如下:
计算梯度的一阶矩估计:
计算梯度的二阶矩估计:
对一阶矩和二阶矩进行偏差修正:
更新参数:
其中,是在时间步/>处的梯度,/>和/>分别是梯度的一阶矩 (均值) 和二阶矩(未中心化的方差) 的估计,/>和/>是衰减率,/>是学习率,/>是一个很小的数以避免除以零。/>表示模式在时间步/>的参数。/>为修正后一阶矩,/>为修正后二阶矩。
AdaBelief优化器:
AdaBelief优化器在二阶矩的计算上做了关键的改进,以期更好地捕捉梯度的不确定性,并据此调整学习率:
计算梯度的一阶矩估计:
与Adam的主要区别,计算梯度的修正二阶矩估计:
对一阶矩和修正二阶矩进行偏差修正,类似于Adam的步骤,但应用于:
使用的是修正的二阶矩更新参数:
其中,是学习率,/>是一个很小的数以避免除以零,/>表示模式在时间步/>的参数。
AdaBelief计算的是梯度与其一阶矩估计之差的平方,而Adam计算的是梯度的平方。这种差异使得AdaBelief在调整学习率时,更加侧重于梯度变化的不确定性,旨在提供一种更稳定和效率更高的参数更新策略。
引入AdaBelief优化器目的是稳定训练,提高训练速度和模型的最终性能 ,核心思想为:在更新模型参数时,不仅考虑梯度的一阶矩估计即均值,通过二阶矩估计的方式还考虑到梯度的不确定性。AdaBelief通过对比预测的梯度和实际观察到的梯度之间的差异,来调整学习速率,这样做的目的是使优化过程更加稳定,尤其是在训练的初期阶段。
具体而言,每个客户端的模型更新考虑了如下形式:
其中,是经AdaBelief优化器调整后的学习率,/>是动量衰减参数,/>是客户端在上一次迭代中的动量项,/>是客户端的控制变量,/>是全局控制变量,/>是在当前全局模型参数/>下,客户端/>的本地数据集/>上计算得到的梯度,/>为更新前客户端本地模型参数。
更新后的客户端本地模型参数和控制变量的调整即/>,随后被上传到中央服务器。
中央服务器聚合所有参与本轮的客户端上传的模型更新和控制变量的调整,以计算下一轮通信的全局模型和更新全局控制变量:
其中,为联邦聚合后全局模型参数,具体为,假设当前为第t次训练,训练后的客户端本地模型/>全部上传到中央服务器进行聚合,聚合后为t+1时间步聚合的全局模型参数/>,在t+1次训练的时候发送给联邦客户端。/>为联邦端数据集,/>为中央服务器聚合各个联邦客户端上传的控制变量参数/>,/>,/>为选中的中央服务器集合。上标k表示其中一个中央服务器,上标t为联邦学习轮次。
通过这种方式,结合了AdaBelief优化器和引入动量的SCAFFOLD算法策略的联邦学习过程不仅能够有效应对非IID数据分布的挑战,还能在复杂的联邦学习环境中提高模型训练的稳定性和加速收敛。
客户端首先在本地通过中央服务器发送的全局模型进行训练,将本地数据分为训练集、验证集和测试集,训练集分为支撑集和查询集。
具体的,验证集和测试集80%、10%和10%的比例划分,以确保训练、验证和测试之间的独立性,因此每一个客户端上的数据集分为三个部分,训练集,验证集/>,测试集。同时对于训练集中分为支撑集/>和查询集/>,比例为70%和30%。
在每个客户端中训练集用于在每个任务上进行快速学习,验证集用于选择最佳的学习策略,测试集用于评估模型在不同任务上的表现。元学习的流程涉及在内循环中反复迭代,每次迭代都会根据任务来更新模型,然后在外循环中通过测试集来验证其泛化能力。最后在本地通过个性化方案,获得适合本地的个性化模型。通过这种方式,模型可以逐步改进其在各种任务上的性能,实现更好的泛化能力。
为了应对客户端之间数据分布的潜在差异,本实施例引入个性化学习策略,通过在局部更新过程中对全局模型进行微调,以更好地适应各个客户端的特定数据分布。个性化学习的目标是最小化每个客户端的本地损失,同时保留从全局模型学到的通用知识,以更好地适应其本地数据。这意味着对于联邦学习框架中的每个客户端,希望找到最适合其本地数据/>的模型参数/>。考虑到这一点,个性化学习可以通过优化以下目标函数来实现,该函数考虑了本地数据对模型的影响:
其中,代表客户端/>的本地损失函数,/>是在模型/>参数为时,对于客户端/>的本地数据集/>中每个数据点/>的损失。
个性化学习的过程中,进一步采用元学习框架,MAML和Meta-SGD,这时个性化模型的更新可以从全局模型的参数开始,通过一定数量的梯度更新步骤来适应每个客户端的本地数据。具体地,MAML方法中的个性化更新可以表示为:
对于Meta-SGD,考虑到学习率本身也是可学习的,个性化更新步骤变为:
其中,表示针对客户端/>优化的学习率向量,○表示元素级乘法即Hadamard乘积,θ指的是元学习方式中的可训练参数的初始化值。通过这样的个性化更新步骤,可以使模型更好地适应每个客户端的特定数据分布,同时保留从全局模型学习到的知识。
在第次通信中,每个客户端使用本地数据通过几次 (例如一次) 梯度下降更新,就能得到适合其本地数据集的模型。这种基于元学习的方法旨在通过少数几步更新快速适应新任务,增强模型在新环境中的表现和泛化能力。MAML算法和Meta-SGD是此过程中采用的两种元学习策略,其中MetaSGD进一步学习了内部学习率/>,提供了更为灵活的更新机制。
采用MAML算法时,初始化模型的参数为/>,对于每个任务/>,算法使用/>作为起始参数,并在支撑集/>上执行一步或多步梯度下降来更新参数至/>。这个过程称为内部更新,使用的训练损失函数是:
其中,是动态权重对称交叉熵,以降低噪声的干扰;/>: 表示模型/>, 其参数为/>, 作用于输入/>;/>来自支撑集/>的特定任务/>的训练数据集,用于模型的内部更新;是与输入/>相对应的真实标签或输出。然后,在查询集/>上测试经过内部更新的模型/>并计算测试损失:
其中,: 表示经过一系列内部更新后的模型/>, 其参数更新为/>,作用于输入/>。/>来自查询集/>,这是同一任务/>的另一部分数据,用于测试经过内部更新的模型性能。/>是与输入/>相对应的真实标签或输出,用于评估模型/>在查询集/>上的性能。
这反映了模型在新任务上的泛化能力。外部更新的目的是通过最小化测试损失来优化元学习算法的参数或模型的初始参数/>。
Meta-SGD进一步发展了MAML的思想,不仅学习模型参数的初始值,而且同时学习每个参数的内部学习率/>。这使得学习率/>可以针对每个参数进行优化,为不同的模型参数提供了个性化的更新步长。Meta-SGD的优化目标可以表达为:
这里,○表示Hadamard乘积 (元素乘),意味着每个参数的更新都是其梯度乘以对应的学习率。
在联邦学习环境中,通过这种方式,不仅可以利用元学习快速适应新客户端或新任务,还可以通过个性化学习策略和先进的优化器来提高整个系统的效率和效果。个性化学习策略确保每个客户端的模型能够更好地适应其特定的数据分布,而先进的优化器AdaBelief则为联邦学习系统中的模型训练提供了更稳定和高效的方法。通过这些策略的结合,能够在保持高准确率的同时加快模型的收敛速度,进而提高联邦学习系统的整体性能和实用性。
为了减少局部噪声的负面影响并提高模型的泛化能力,采用了动态权重对称交叉熵学习方法,该方法通过动态权重参数和标签平滑进行优化。动态权重对称交叉摘损失函数结合了标准交叉摘损失(CE) 和反向交叉摘损失 (RCE),同时引入了动态调整机制和标签平滑策略,以应对标签噪声问题并促进难以学习类别的有效学习。动态权重对称损失函数定义如下:
其中,是标准交叉熵损失,是反向交叉熵损失,其中,/>为数据表集中第i个样本的特征向量。不同于固定的权重参数,在这里,/>是一个随训练过程动态调整的权重参数,旨在根据模型在训练过程中的表现或某些评估指标动态平衡CE损失和RCE损失的贡献。/>和分别代表真实概率分布和模型预测的概率分布。
此外,为了进一步提高模型对噪声标签的鲁棒性和泛化能力,采用标签平滑策略。在标签平滑中,真实标签被替换为一个更加平滑的标签分布,这减少了模型对硬标签的依赖,并鼓励模型学习更加稳健的特征表示。标签平滑后的真实标签表示为:
其中,是一个小的平滑参数,/>是类别的总数。这样,每个类别都会被分配一个基础的概率,减少了模型对可能噪声标签的过拟合。在训练迭代过程中,为了避免因标签噪声导致的模型误更新,采用改进的SL损失来计算模型预测的伪标签与相应的平滑标签之间的损失。
局部更新策略调整为:
其中,代表经过标签平滑处理的标签,/>是学习率,/>表示梯度运算符。/>:在时间步/>的客户端/>的模型参数。/>:更新后,在时间步/>的客户端/>的模型参数。/>:学习率,控制参数更新的步长。/>关于模型参数/>的损失函数/>的梯度。:客户端/>在参数/>下的模型预测。/>:客户端/>的目标或真实值。: 客户端/>在时间步/>的损失函数,量化了模型预测/>与真实值/>之间的差异。
本实施例的方法不仅能够在存在标签噪声的情况下更有效地学习难以学习的类别,而且还能避免模型在易于学习的类别上过度拟合噪声标签,从而提高整体性能。
本实施例研究了联邦学习中客户端之间的数据异构性和数据噪声,提出了一种新的联邦元学习框架。该框架可以在保障联邦隐私安全的前提下,利用各联邦客户端的本地数据进行训练,元学习实现在每个联邦客户端的都有适合本地的个性化模型,在训练过程中引用动量概念,加速在复杂场景下的收敛和提高稳定性以及对抗数据异质性。随后通过动态权重对称交叉熵损失函数,还能对抗样本中的噪声问题,还从而解决了联邦学习中数据异质性和噪声对模型性能的影响。本实施例的应用能够提升本地客户端的性能,进而通过全局模型融合获得性能提升,同时能够在单个客户端上进一步提升个性化模型的性能,且优于当前面对数据异质性和噪声的联邦元学习方法。
实施例二
本实施例的目的是提供一种噪声鲁棒性增强的联邦元学习图像识别系统,包括:中央服务器和各客户端;
所述中央服务器,用于将当前的全局模型参数和全局控制变量发送给各客户端;
所述各客户端,用于各客户端根据所接收的当前的全局模型参数和全局控制变量,采用元学习方式进行训练更新,在训练更新过程中,基于AdaBelief优化器使全局模型适应本地数据集;基于SCAFFOLD算法中的动量和控制变量机制,根据全局控制变量进行本地更新,得到更新后的本地模型以及本地控制变量;通过动态权重参数对损失函数进行改进,以改进后的损失函数和平滑标签策略进行本地更新;
所述各客户端,用于将更新后的本地模型以及控制变量更新差异上传给中央服务器;
所述中央服务器,用于根据所接收的各客户端更新后的本地模型以及控制变量更新进行聚合,得到更新后的全局模型参数和全局控制变量,迭代更新,直至全局模型收敛,利用训练好的全局模型进行图像识别。
本领域技术人员应该明白,上述本发明的各模块或各步骤可以用通用的计算机装置来实现,可选地,它们可以用计算装置可执行的程序代码来实现,从而,可以将它们存储在存储装置中由计算装置来执行,或者将它们分别制作成各个集成电路模块,或者将它们中的多个模块或步骤制作成单个集成电路模块来实现。本发明不限制于任何特定的硬件和软件的结合。
上述虽然结合附图对本发明的具体实施方式进行了描述,但并非对本发明保护范围的限制,所属领域技术人员应该明白,在本发明的技术方案的基础上,本领域技术人员不需要付出创造性劳动即可做出的各种修改或变形仍在本发明的保护范围以内。
Claims (6)
1.一种噪声鲁棒性增强的联邦元学习图像识别方法,其特征在于,包括:
中央服务器将当前的全局模型参数和全局控制变量发送给各客户端;
各客户端根据所接收的当前的全局模型参数和全局控制变量,采用元学习方式进行训练更新,在训练更新过程中,基于AdaBelief优化器使全局模型适应本地数据集,具体为:利用AdaBelief优化器对比预测的梯度和实际观察到的梯度之间的差异,来调整学习率;将调整后的学习率应用在客户端本地模型更新中;基于动量衰减参数、客户端在上一次迭代更新中的动量项、客户端的控制变量、全局控制变量、所述调整后的学习率,以及客户端当前本地模型参数下,基于本地数据集所计算的梯度,对客户端本地模型进行更新;基于SCAFFOLD算法中的动量和控制变量机制,根据全局控制变量进行本地更新,得到更新后的本地模型以及本地控制变量;通过动态权重参数对损失函数进行改进,以改进后的损失函数和平滑标签策略进行本地更新,具体的,通过随训练过程动态调整的权重参数,以及标准交叉摘损失和反向交叉摘损失构建动态权重对称损失函数;在客户端本地更新过程中,采用MAML方法进行模型参数内部更新和外部更新;并计算基于所述动态权重对称损失函数的内部更新损失函数和基于所述动态权重对称损失函数的测试损失;根据所述内部更新损失函数和所述测试损失,使用Meta-SGD方法进行优化更新;各客户端将更新后的本地模型以及控制变量更新差异上传给中央服务器;
中央服务器根据所接收的各客户端更新后的本地模型以及控制变量更新进行聚合,得到更新后的全局模型参数和全局控制变量,迭代更新,直至全局模型收敛,利用训练好的全局模型进行图像识别。
2.如权利要求1所述的一种噪声鲁棒性增强的联邦元学习图像识别方法,其特征在于,各客户端将更新后的本地模型参数以及控制变量的调整量上传至中央服务器。
3.如权利要求1所述的一种噪声鲁棒性增强的联邦元学习图像识别方法,其特征在于,在标签平滑中,通过平滑参数和类别总数对真实标签进行平滑处理,得到标签平滑后的真实标签。
4.如权利要求1所述的一种噪声鲁棒性增强的联邦元学习图像识别方法,其特征在于,在客户端更新迭代过程中,利用动态权重对称损失函数计算模型预测的伪标签与相应的平滑标签之间的损失,根据所计算的损失对本地模型进行局部更新。
5.一种噪声鲁棒性增强的联邦元学习图像识别系统,其特征在于,包括:中央服务器和各客户端;
所述中央服务器,用于将当前的全局模型参数和全局控制变量发送给各客户端;
所述各客户端,用于各客户端根据所接收的当前的全局模型参数和全局控制变量,采用元学习方式进行训练更新,在训练更新过程中,基于AdaBelief优化器使全局模型适应本地数据集,具体为:利用AdaBelief优化器对比预测的梯度和实际观察到的梯度之间的差异,来调整学习率;将调整后的学习率应用在客户端本地模型更新中;基于动量衰减参数、客户端在上一次迭代更新中的动量项、客户端的控制变量、全局控制变量、所述调整后的学习率,以及客户端当前本地模型参数下,基于本地数据集所计算的梯度,对客户端本地模型进行更新;基于SCAFFOLD算法中的动量和控制变量机制,根据全局控制变量进行本地更新,得到更新后的本地模型以及本地控制变量;通过动态权重参数对损失函数进行改进,以改进后的损失函数和平滑标签策略进行本地更新,具体的,通过随训练过程动态调整的权重参数,以及标准交叉摘损失和反向交叉摘损失构建动态权重对称损失函数;在客户端本地更新过程中,采用MAML方法进行模型参数内部更新和外部更新;并计算基于所述动态权重对称损失函数的内部更新损失函数和基于所述动态权重对称损失函数的测试损失;根据所述内部更新损失函数和所述测试损失,使用Meta-SGD方法进行优化更新;
所述各客户端,用于将更新后的本地模型以及控制变量更新差异上传给中央服务器;
所述中央服务器,用于根据所接收的各客户端更新后的本地模型以及控制变量更新进行聚合,得到更新后的全局模型参数和全局控制变量,迭代更新,直至全局模型收敛,利用训练好的全局模型进行图像识别。
6.如权利要求5所述的一种噪声鲁棒性增强的联邦元学习图像识别系统,其特征在于,在所述各客户端更新过程中,通过随训练过程动态调整的权重参数,以及标准交叉摘损失和反向交叉摘损失构建动态权重对称损失函数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410396190.7A CN117994635B (zh) | 2024-04-03 | 2024-04-03 | 一种噪声鲁棒性增强的联邦元学习图像识别方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410396190.7A CN117994635B (zh) | 2024-04-03 | 2024-04-03 | 一种噪声鲁棒性增强的联邦元学习图像识别方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117994635A CN117994635A (zh) | 2024-05-07 |
CN117994635B true CN117994635B (zh) | 2024-06-11 |
Family
ID=90893675
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410396190.7A Active CN117994635B (zh) | 2024-04-03 | 2024-04-03 | 一种噪声鲁棒性增强的联邦元学习图像识别方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117994635B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118468928B (zh) * | 2024-07-12 | 2024-09-20 | 中国电子科技集团公司第三十研究所 | 一种安全领域大模型微调方法、装置及可读储存介质 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021120676A1 (zh) * | 2020-06-30 | 2021-06-24 | 平安科技(深圳)有限公司 | 联邦学习网络下的模型训练方法及其相关设备 |
CN113557576A (zh) * | 2018-12-26 | 2021-10-26 | 生命解析公司 | 在表征生理系统时配置和使用神经网络的方法和系统 |
CN116543229A (zh) * | 2023-05-24 | 2023-08-04 | 北京理工大学 | 基于自适应累积系数的深度学习优化器的图像分类方法 |
CN117253072A (zh) * | 2023-08-03 | 2023-12-19 | 长春大学 | 一种基于个性化联邦学习的图像分类方法 |
CN117292221A (zh) * | 2023-09-26 | 2023-12-26 | 山东省计算中心(国家超级计算济南中心) | 基于联邦元学习的图像识别方法及系统 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11556716B2 (en) * | 2020-08-24 | 2023-01-17 | Intuit Inc. | Intent prediction by machine learning with word and sentence features for routing user requests |
US20230308465A1 (en) * | 2023-04-12 | 2023-09-28 | Roobaea Alroobaea | System and method for dnn-based cyber-security using federated learning-based generative adversarial network |
-
2024
- 2024-04-03 CN CN202410396190.7A patent/CN117994635B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113557576A (zh) * | 2018-12-26 | 2021-10-26 | 生命解析公司 | 在表征生理系统时配置和使用神经网络的方法和系统 |
WO2021120676A1 (zh) * | 2020-06-30 | 2021-06-24 | 平安科技(深圳)有限公司 | 联邦学习网络下的模型训练方法及其相关设备 |
CN116543229A (zh) * | 2023-05-24 | 2023-08-04 | 北京理工大学 | 基于自适应累积系数的深度学习优化器的图像分类方法 |
CN117253072A (zh) * | 2023-08-03 | 2023-12-19 | 长春大学 | 一种基于个性化联邦学习的图像分类方法 |
CN117292221A (zh) * | 2023-09-26 | 2023-12-26 | 山东省计算中心(国家超级计算济南中心) | 基于联邦元学习的图像识别方法及系统 |
Non-Patent Citations (3)
Title |
---|
Juntang Zhuang等.AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients.ARXIV.2020,全文. * |
Karimireddy等.SCAFFOLD: Stochastic Controlled Averaging for Federated Learning.ARXIV.2019,全文. * |
冯芬玲 ; 阎美好 ; 刘承光 ; 李万 ; .基于IPSO-Capsule-NN模型的中欧班列出口需求量预测.中国铁道科学.2020,(第02期),全文. * |
Also Published As
Publication number | Publication date |
---|---|
CN117994635A (zh) | 2024-05-07 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN117994635B (zh) | 一种噪声鲁棒性增强的联邦元学习图像识别方法及系统 | |
CN113191484A (zh) | 基于深度强化学习的联邦学习客户端智能选取方法及系统 | |
CN112949837A (zh) | 一种基于可信网络的目标识别联邦深度学习方法 | |
US20160012330A1 (en) | Neural network and method of neural network training | |
CN109983480A (zh) | 使用聚类损失训练神经网络 | |
CN117523291A (zh) | 基于联邦知识蒸馏和集成学习的图像分类方法 | |
CN115840900A (zh) | 一种基于自适应聚类分层的个性化联邦学习方法及系统 | |
CN115344883A (zh) | 一种用于处理不平衡数据的个性化联邦学习方法和装置 | |
CN114510652A (zh) | 一种基于联邦学习的社交协同过滤推荐方法 | |
CN117236421B (zh) | 一种基于联邦知识蒸馏的大模型训练方法 | |
CN117253072A (zh) | 一种基于个性化联邦学习的图像分类方法 | |
CN115409155A (zh) | 基于Transformer增强霍克斯过程的信息级联预测系统及方法 | |
CN107578101B (zh) | 一种数据流负载预测方法 | |
CN117290721A (zh) | 数字孪生建模方法、装置、设备及介质 | |
CN117113274A (zh) | 基于联邦蒸馏的异构网络无数据融合方法、系统 | |
CN117454330A (zh) | 一种对抗模型中毒攻击的个性化联邦学习方法 | |
CN114330464A (zh) | 一种融合元学习的多终端协同训练算法及系统 | |
CN116719607A (zh) | 基于联邦学习的模型更新方法及系统 | |
Singhal et al. | Greedy Shapley Client Selection for Communication-Efficient Federated Learning | |
CN116259057A (zh) | 基于联盟博弈解决联邦学习中数据异质性问题的方法 | |
CN115695429A (zh) | 面向Non-IID场景的联邦学习客户端选择方法 | |
CN115019359A (zh) | 一种云用户身份识别任务分配与并行处理方法 | |
CN114170338A (zh) | 一种差分隐私保护下基于自适应梯度裁剪的图像生成方法 | |
Yin et al. | SynCPFL: Synthetic Distribution Aware Clustered Framework for Personalized Federated Learning | |
CN117173750B (zh) | 生物信息处理方法、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |