CN116229172A - 基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备 - Google Patents
基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备 Download PDFInfo
- Publication number
- CN116229172A CN116229172A CN202310207512.4A CN202310207512A CN116229172A CN 116229172 A CN116229172 A CN 116229172A CN 202310207512 A CN202310207512 A CN 202310207512A CN 116229172 A CN116229172 A CN 116229172A
- Authority
- CN
- China
- Prior art keywords
- model
- image classification
- classification model
- sample
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 87
- 238000013145 classification model Methods 0.000 title claims abstract description 73
- 238000000034 method Methods 0.000 title claims abstract description 53
- 230000006870 function Effects 0.000 claims description 12
- 239000013598 vector Substances 0.000 claims description 10
- 230000002776 aggregation Effects 0.000 claims description 9
- 238000004220 aggregation Methods 0.000 claims description 9
- 238000003860 storage Methods 0.000 claims description 7
- 238000004364 calculation method Methods 0.000 claims description 6
- 238000011176 pooling Methods 0.000 claims description 6
- 238000004590 computer program Methods 0.000 claims description 4
- 238000013507 mapping Methods 0.000 claims description 4
- 238000004422 calculation algorithm Methods 0.000 claims description 3
- 230000008859 change Effects 0.000 claims description 3
- 238000010606 normalization Methods 0.000 claims description 3
- 238000009826 distribution Methods 0.000 description 5
- 230000007246 mechanism Effects 0.000 description 4
- 238000013528 artificial neural network Methods 0.000 description 3
- 238000013461 design Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 238000007792 addition Methods 0.000 description 2
- 238000013459 approach Methods 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000005457 optimization Methods 0.000 description 2
- 230000004931 aggregating effect Effects 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000008569 process Effects 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明提供一种基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备,包括:构建训练集和查询集,为训练集添加真实标签;获取初始模型,该初始模型包括嵌入网络和关系网络;将训练集样本和查询集样本成对输入嵌入网络,提取训练集样本特征图和查询集样本特征图并进行拼接,生成拼接特征图;将拼接特征图输入关系网络计算得到相似度分数,以得到训练集样本的类别;采用本地数据集对初始模型进行训练,并构建均方误差损失,得到初始图像分类模型;基于各客户端模型参数构建共享模型,并根据共享模型参数,采用指数移动平均更新所述初始图像分类模型,得到最终的图像分类模型。本发明提供的联邦学习模型训练简单且适用非独立同分布场景。
Description
技术领域
本发明涉及人工智能技术领域,尤其涉及一种基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备。
背景技术
联邦学习被广泛应用于如智能手机、笔记本电脑、可穿戴设备等智能设备的数据隐私保护,用于协作优化共享模型,如用户习惯预测、无线网络优化、个性化推荐等。在联邦学习框架中,多个客户端通过权重聚合而不是本地数据交换协作优化共享模型,保护客户端用户隐私。然而,现有的联邦学习方法严重依赖于高质量的标记数据,例如,用户在使用照片分类器应用程序时,通常不愿意对隐私和敏感的图片进行专门的注释。利用分散的未标记图像数据来学习在保证隐私的情况下的共享模型是一个重要但被忽视的问题。少样本学习方法可以在每个类标签数很少的情况下进行良好工作,因此很多学者开始对联邦少样本学习进行研究。
联邦少样本学习在保护数据隐私的同时,从多个客户端学习少量标记数据。标签的可用性低是机器学习中一个长期存在的问题。虽然正则化和数据增强方法可以缓解过拟合,但不能解决此类问题。因此,现有的少样本方法通过特征嵌入或表示学习来学习可转移的知识,然后通过对下游目标进行微调来学习,但需要复杂的训练机制,以及昂贵的通信协议。此外,由于用户总是有不同的习惯和使用频率,多个设备之间的数据通常是非独立同分布的,导致共享模型性能下降。因此,虽然联邦少样本学习方法是有效的,但不适用于非独立同分布场景。
发明内容
鉴于此,本发明实施例提供了一种基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备,以消除或改善现有技术中存在的一个或更多个缺陷,解决现有联邦少样本学习方法训练复杂、共享模型聚合效果较差且不适用非独立同分布场景的问题。
一方面,本发明提供一种基于对比学习的联邦少样本图像分类模型训练方法,其特征在于,所述方法在各客户端执行,包括以下步骤:
获取本地数据集,并对所述本地数据集进行强数据增强;所述本地数据集包含多个类,每个类中包含多个样本,每个样本中包含一张图像;从所述本地数据集中随机选择第一预设数量类,每类中随机选择第二预设数量个样本,构建训练集,其余作为查询集;为所述训练集添加真实类别标签;
获取初始模型;所述初始模型包括嵌入网络和关系网络;将所述训练集的单个样本和所述查询集的单个样本成对输入所述嵌入网络,提取训练集样本特征图和查询集样本特征图;将所述训练集样本特征图和所述查询集样本特征图拼接,生成拼接特征图;将所述拼接特征图输入所述关系网络,生成所述第一预设数量个相似度分数,并根据所述相似度分数输出所述训练集样本相应的类别;
利用所述本地数据集对所述初始模型进行训练,构建所述关系网络输出的类别与所述真实类别标签之间的均方误差损失,利用所述均方误差损失对所述初始模型的参数进行迭代,以得到初始图像分类模型;
将所述初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;所述共享模型由所述全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收所述共享模型的参数,并采用指数移动平均更新所述初始图像分类模型,以得到最终的图像分类模型。
在本发明的一些实施例中,获取本地数据集,并对所述本地数据集进行强数据增强,所述强数据增强至少包括空间变换裁剪、旋转、色彩抖动、改变亮度、灰度中的一种或多种组合操作。
在本发明的一些实施例中,所述嵌入网络包括3个卷积块和2个最大池化层,每个卷积块还包括1个卷积层、1个批处理归一化层和1个ReLu非线性层;
所述关系网络包括2个卷积块、2个最大池化层、第一全连接层和第二全连接层;所述第一全连接层还包括1个ReLu非线性层;所述第二全连接层还包括1个Sigmoid非线性层。
在本发明的一些实施例中,将所述拼接特征图输入所述关系网络,生成所述第一预设数量个相似度分数,所述相似度分数的计算式为:
si,j=fη(Concat(fθ(xi),fθ(xi))),i=1,2,…,C;
其中,si,j表示所述相似度分数;fη表示所述关系网络的关系函数;Concat(·)表示深度向量拼接算法;fθ表示所述嵌入网络的嵌入函数;xi表示所述训练集的第i个样本;xj表示所述查询集的第j个样本;C表示所述第一预设数量类。
在本发明的一些实施例中,根据所述相似度分数输出所述训练集样本相应的类别,还包括:
将所述相似度分数输入预设的Sigmoid非线性层,基于Sigmoid函数映射得到一组浮点数向量,获取所述浮点数向量中的最大值,将其作为所述初始模型的输出,得到所述训练集样本的类别。
在本发明的一些实施例中,构建所述关系网络输出的类别与所述真实类别标签之间的均方误差损失,所述均方误差损失计算式为:
其中,L表示所述均方误差损失;m表示输入所述初始模型的样本数量;n表示所述查询集的样本数量;si,j表示所述相似度分数;yi表示所述关系网络输出的类别;yj表示所述真实类别标签。
在本发明的一些实施例中,接收所述共享模型的参数,并采用指数移动平均更新初始图像分类模型,计算式为:
其中,θg和ηg为所述共享模型的参数;θ表示衰减率,η表示更新阈值;t表示第t轮所述共享模型参数聚合;μ表示预设阈值;ξ表示所述初始图像分类模型的参数与共享模型的参数在更新中分别占的权重。
在本发明的一些实施例中,所述方法还包括:
计算所述初始图像分类模型在训练时的模型散度,当所述模型散度大于所述预设阈值时,客户端使用所述共享模型的权重进行更新;当所述模型散度小于或等于所述预设阈值时,客户端使用其初始图像分类模型和所述共享模型的权重组合进行更新。
另一方面,本发明还提供一种基于对比学习的联邦少样本图像分类方法,其特征在于,该方法在客户端执行,包括以下步骤:
获取待分类的图像;
将所述图像输入如上中任一项所述基于对比学习的联邦少样本图像分类模型训练方法得到的图像分类模型,以得到所述图像的类别。
另一方面,本发明还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上文中提及的任意一项所述方法的步骤。
本发明的有益效果至少是:
本发明提供一种基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备,包括:获取各客户端本地数据集,并对本地数据集进行强数据增强,可以学习具有明确决策边界的鲁棒分类器;基于本地数据集构建训练集和查询集,并为训练集添加真实类别标签。获取初始模型,该初始模型包括嵌入网络和关系网络,利用嵌入网络和关系网络根据图像的特征相似度对样本进行分类,有效进行少样本学习;利用本地数据集对初始模型进行训练,并构建均方误差损失,以得到初始图像分类模型。考虑到数据的非独立同分布特征,设计动态更新机制,计算各客户端的初始图像分类模型的权重发散程度,基于权重发散程度根据共享模型参数对各客户端初始图像分类模型进行更新,以得到最终的图像分类模型。基于本发明提供的训练方法得到的图像分类模型具备较高的准确率和可扩展性,可以实现准确的图像分类,且适用于非独立同分布场景。
本发明的附加优点、目的,以及特征将在下面的描述中将部分地加以阐述,且将对于本领域普通技术人员在研究下文后部分地变得明显,或者可以根据本发明的实践而获知。本发明的目的和其它优点可以通过在说明书以及附图中具体指出的结构实现到并获得。
本领域技术人员将会理解的是,能够用本发明实现的目的和优点不限于以上具体所述,并且根据以下详细说明将更清楚地理解本发明能够实现的上述和其他目的。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,并不构成对本发明的限定。在附图中:
图1为本发明一实施例中基于对比学习的联邦少样本图像分类模型训练方法的步骤示意图。
图2为本发明一实施例中基于对比学习的联邦少样本图像分类模型训练方法的流程示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚明白,下面结合实施方式和附图,对本发明做进一步详细说明。在此,本发明的示意性实施方式及其说明用于解释本发明,但并不作为对本发明的限定。
在此,还需要说明的是,为了避免因不必要的细节而模糊了本发明,在附图中仅仅示出了与根据本发明的方案密切相关的结构和/或处理步骤,而省略了与本发明关系不大的其他细节。
应该强调,术语“包括/包含”在本文使用时指特征、要素、步骤或组件的存在,但并不排除一个或更多个其它特征、要素、步骤或组件的存在或附加。
在此,还需要说明的是,如果没有特殊说明,术语“连接”在本文不仅可以指直接连接,也可以表示存在中间物的间接连接。
在下文中,将参考附图描述本发明的实施例。在附图中,相同的附图标记代表相同或类似的部件,或者相同或类似的步骤。
这里需要强调的是,在下文中提及的各步骤标记并不是对各步骤先后顺序的限定,而应当理解为可以按照实施例中提及的顺序执行步骤,也可以不同于实施例中的顺序,或者若干步骤同时执行。
为了解决现有联邦少样本学习方法训练复杂、共享模型聚合效果较差且不适用非独立同分布场景的问题,本发明提供一种基于对比学习的联邦少样本图像分类模型训练方法,如图1所示,该方法包括以下步骤S101~S104:
步骤S101:获取本地数据集,并对本地数据集进行强数据增强。其中,本地数据集包含多个类,每个类中包含多个样本,每个样本中包含一张图像。从本地数据集中随机选择第一预设数量类,每类中随机选择第二预设数量个样本,构建训练集,其余作为查询集,为训练集添加真实类别标签。
步骤S102:获取初始模型;该初始模型包括嵌入网络和关系网络;将训练集的单个样本和查询集的单个样本成对输入嵌入网络,提取训练集样本特征图和查询集样本特征图;将训练集样本特征图和查询集样本特征图拼接,生成拼接特征图;将拼接特征图输入关系网络,生成第一预设数量个相似度分数,并根据相似度分数输出训练集样本相应的类别。
步骤S103:利用本地数据集对初始模型进行训练,构建关系网络输出的类别与真实类别标签之间的均方误差损失,利用均方误差损失对初始模型的参数进行迭代,以得到初始图像分类模型。
步骤S104:将初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;其中,共享模型由全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收共享模型的参数,并采用指数移动平均更新初始图像分类模型,以得到最终的图像分类模型。
本发明提出了一种基于对比网络的联邦少样本学习框架,依托传统的分布式学习架构和深度神经网络框架进行模型训练,其中,深度神经网络框架可以选用PyTorch、TensorFlow等。
在步骤S101中,获取各客户端的本地数据集,考虑到在背景技术提及的过拟合问题,本发明对本地数据集进行强数据增强。强数据增强相对于弱增强,包括随机调整大小、裁剪或翻转。对于联邦学习框架中的每个客户端,采用相同的增强策略,不会在不同的数据集上搜索最佳策略。
在一些实施例中,强数据增强至少包括空间变换裁剪、旋转、色彩抖动、改变亮度、灰度中的一种或多种组合操作。基于强数据增强设计,使得联邦学习具有明确决策边界的鲁棒分类器。
从本地数据集中随机选择第一预设数量类,每类中随机选择第二预设数量个样本,构建训练集,其余作为查询集,并为训练集添加真实类别标签。示例性的,第一预设数量记作C,训练集记作S,查询集记作Q;其中,训练集中第i个样本记作xi∈S;查询集中第j个样本记作xj∈Q。
在步骤S102中,各客户端从全局服务器获取初始模型,该初始模型包括嵌入网络和关系网络。其中,嵌入网络由一个权重为θ的神经网络构成,通过学习一个嵌入函数fθ,将输入的图像映射为嵌入特征向量;关系网络是通过学习一个关系函数fη来计算与真实类别标签的关系分数。
在一些实施例中,嵌入网络包括3个卷积块和2个2×2最大池化层,每个卷积块还包括1个3×3卷积层、1个批处理归一化层和1个ReLu非线性层。其中,卷积层的通道数为64。
在一些实施例中,关系网络包括2个卷积块、2个2×2最大池化层、第一全连接层和第二全连接层。其中,卷积块的结构与嵌入网络的卷积块结构相同,第一全连接层为8单元,包括1个ReLu非线性层;第二全连接层为8单元,包括1个Sigmoid非线性层。
如图2所示,为初始模型基于联邦少样本学习的整体流程,具体的:将训练集的单个样本xi和查询集的单个样本xj成对输入嵌入网络,提取训练集样本特征图fθ(xi)和查询集样本特征图fθ(xj)。将训练集样本特征图和查询集样本特征图拼接,生成拼接特征图Concat(fθ(xi),fθ(xj))。将拼接特征图输入关系网络,生成第一预设数量C个相似度分数si,j,并根据相似度分数输出训练集样本相应的类别。
在一些实施例中,相似度分数的计算式如公式(1)所示:
si,j=fη(Concat(fθ(xi),fθ(xj))),i=1,2,...,C; (1)
其中,si,j表示相似度分数;fη表示关系网络的关系函数;Concat(·)表示深度向量拼接算法;fθ表示嵌入网络的嵌入函数;xi表示训练集的第i个样本;xj表示查询集的第j个样本;C表示第一预设数量类。
在一些实施例中,根据相似度分数输出训练集样本相应的类别,还包括以下步骤:
关系网络中预设Sigmoid非线性层,将相似度分数输入预设的Sigmoid非线性层,基于Sigmoid函数映射得到一组浮点数向量,获取浮点数向量中的最大值,将其作为初始模型的输出,得到训练集样本的类别。
在步骤S103中,利用本地数据集对初始模型进行训练,同时构建关系网络输出的类别与真实类别标签之间的均方误差损失,利用均方误差损失对初始模型的参数进行迭代,以得到初始图像分类模型。
在一些实施例中,均方误差损失计算式如公式(2)所示:
其中,L表示均方误差损失;m表示输入初始模型的样本数量;n表示查询集的样本数量;si,j表示相似度分数;yi表示关系网络输出的类别;yj表示真实类别标签。
在步骤S104中,受数据非独立同分布会导致权值发散的启发,本发明设计了动态更新机制,基于权值发散动态更新初始图像分类模型的参数。
在一轮训练中,基于本地数据集训练得到各客户端的初始图像分类模型后,各客户端向全局服务器发送各自模型的参数,其中全局服务器在图2中用云服务器表示。全局服务器获取各客户端模型的参数后,进行加权聚合,构建共享模型,并将共享模型的参数发送至各客户端。各客户端接收共享模型的参数,并采用指数移动平均更新初始图像分类模型,以得到最终的图像分类模型。
在一些实施例中,采用指数移动平均更新初始图像分类模型,计算式如公式(3)和公式(4)所示:
其中,θg和ηg为共享模型的参数;θ表示衰减率,η表示更新阈值;t表示第t轮共享模型参数聚合;μ表示预设阈值;ξ表示初始图像分类模型的参数与共享模型的参数在更新中分别占的权重。
在一些实施例中,计算得到初始图像分类模型在训练时的模型散度后,当模型散度大于预设阈值时,客户端使用共享模型的权重进行更新;当模型散度小于或等于预设阈值时,客户端使用其初始图像分类模型和共享模型的权重组合进行更新。
本发明还提供一种基于对比学习的联邦少样本图像分类方法,该方法包括以下步骤S201~S202:
步骤S201:获取待分类的图像。
步骤S202:将图像输入如上文所述基于对比学习的联邦少样本图像分类模型训练方法得到的图像分类模型,以得到图像的类别。
本发明还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现基于对比学习的联邦少样本图像分类模型训练方法和基于对比学习的联邦少样本图像分类方法的步骤。
与上述方法相应地,本发明还提供了一种设备,该设备包括计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有计算机指令,所述处理器用于执行所述存储器中存储的计算机指令,当所述计算机指令被处理器执行时该设备实现如前所述方法的步骤。
本发明实施例还提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时以实现前述边缘计算服务器部署方法的步骤。该计算机可读存储介质可以是有形存储介质,诸如随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、软盘、硬盘、可移动存储盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质。
综上所述,本发明提供一种基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备,包括:获取各客户端本地数据集,并对本地数据集进行强数据增强,可以学习具有明确决策边界的鲁棒分类器;基于本地数据集构建训练集和查询集,并为训练集添加真实类别标签。获取初始模型,该初始模型包括嵌入网络和关系网络,利用嵌入网络和关系网络根据图像的特征相似度对样本进行分类,有效进行少样本学习;利用本地数据集对初始模型进行训练,并构建均方误差损失,以得到初始图像分类模型。考虑到数据的非独立同分布特征,设计动态更新机制,计算各客户端的初始图像分类模型的权重发散程度,基于权重发散程度根据共享模型参数对各客户端初始图像分类模型进行更新,以得到最终的图像分类模型。基于本发明提供的训练方法得到的图像分类模型具备较高的准确率和可扩展性,可以实现准确的图像分类,且适用于非独立同分布场景。
本领域普通技术人员应该可以明白,结合本文中所公开的实施方式描述的各示例性的组成部分、系统和方法,能够以硬件、软件或者二者的结合来实现。具体究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。当以硬件方式实现时,其可以例如是电子电路、专用集成电路(ASIC)、适当的固件、插件、功能卡等等。当以软件方式实现时,本发明的元素是被用于执行所需任务的程序或者代码段。程序或者代码段可以存储在机器可读介质中,或者通过载波中携带的数据信号在传输介质或者通信链路上传送。
需要明确的是,本发明并不局限于上文所描述并在图中示出的特定配置和处理。为了简明起见,这里省略了对已知方法的详细描述。在上述实施例中,描述和示出了若干具体的步骤作为示例。但是,本发明的方法过程并不限于所描述和示出的具体步骤,本领域的技术人员可以在领会本发明的精神后,作出各种改变、修改和添加,或者改变步骤之间的顺序。
本发明中,针对一个实施方式描述和/或例示的特征,可以在一个或更多个其它实施方式中以相同方式或以类似方式使用,和/或与其他实施方式的特征相结合或代替其他实施方式的特征。
以上所述仅为本发明的优选实施例,并不用于限制本发明,对于本领域的技术人员来说,本发明实施例可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于对比学习的联邦少样本图像分类模型训练方法,其特征在于,所述方法在各客户端执行,包括以下步骤:
获取本地数据集,并对所述本地数据集进行强数据增强;所述本地数据集包含多个类,每个类中包含多个样本,每个样本中包含一张图像;从所述本地数据集中随机选择第一预设数量类,每类中随机选择第二预设数量个样本,构建训练集,其余作为查询集;为所述训练集添加真实类别标签;
获取初始模型;所述初始模型包括嵌入网络和关系网络;将所述训练集的单个样本和所述查询集的单个样本成对输入所述嵌入网络,提取训练集样本特征图和查询集样本特征图;将所述训练集样本特征图和所述查询集样本特征图拼接,生成拼接特征图;将所述拼接特征图输入所述关系网络,生成所述第一预设数量个相似度分数,并根据所述相似度分数输出所述训练集样本相应的类别;
利用所述本地数据集对所述初始模型进行训练,构建所述关系网络输出的类别与所述真实类别标签之间的均方误差损失,利用所述均方误差损失对所述初始模型的参数进行迭代,以得到初始图像分类模型;
将所述初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;所述共享模型由所述全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收所述共享模型的参数,并采用指数移动平均更新所述初始图像分类模型,以得到最终的图像分类模型。
2.根据权利要求1所述的基于对比学习的联邦少样本图像分类模型训练方法,其特征在于,获取本地数据集,并对所述本地数据集进行强数据增强,所述强数据增强至少包括空间变换裁剪、旋转、色彩抖动、改变亮度、灰度中的一种或多种组合操作。
3.根据权利要求1所述的基于对比学习的联邦少样本图像分类模型训练方法,其特征在于,所述嵌入网络包括3个卷积块和2个最大池化层,每个卷积块还包括1个卷积层、1个批处理归一化层和1个ReLu非线性层;
所述关系网络包括2个卷积块、2个最大池化层、第一全连接层和第二全连接层;所述第一全连接层还包括1个ReLu非线性层;所述第二全连接层还包括1个Sigmoid非线性层。
4.根据权利要求1所述的基于对比学习的联邦少样本图像分类模型训练方法,其特征在于,将所述拼接特征图输入所述关系网络,生成所述第一预设数量个相似度分数,所述相似度分数的计算式为:
si,j=fη(Concat(fθ(xi),fθ(xj))),i=1,2,…,C;
其中,si,j表示所述相似度分数;fη表示所述关系网络的关系函数;Concat(·)表示深度向量拼接算法;fθ表示所述嵌入网络的嵌入函数;xi表示所述训练集的第i个样本;xj表示所述查询集的第j个样本;C表示所述第一预设数量类。
5.根据权利要求1所述的基于对比学习的联邦少样本图像分类模型训练方法,其特征在于,根据所述相似度分数输出所述训练集样本相应的类别,还包括:
将所述相似度分数输入预设的Sigmoid非线性层,基于Sigmoid函数映射得到一组浮点数向量,获取所述浮点数向量中的最大值,将其作为所述初始模型的输出,得到所述训练集样本的类别。
8.根据权利要求7所述的基于对比学习的联邦少样本图像分类模型训练方法,其特征在于,还包括:
计算所述初始图像分类模型在训练时的模型散度,当所述模型散度大于所述预设阈值时,客户端使用所述共享模型的权重进行更新;当所述模型散度小于或等于所述预设阈值时,客户端使用其初始图像分类模型和所述共享模型的权重组合进行更新。
9.一种基于对比学习的联邦少样本图像分类方法,其特征在于,该方法在客户端执行,包括以下步骤:
获取待分类的图像;
将所述图像输入如权利要求1至8中任一项所述基于对比学习的联邦少样本图像分类模型训练方法得到的图像分类模型,以得到所述图像的类别。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如权利要求1至9中任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310207512.4A CN116229172A (zh) | 2023-03-03 | 2023-03-03 | 基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310207512.4A CN116229172A (zh) | 2023-03-03 | 2023-03-03 | 基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116229172A true CN116229172A (zh) | 2023-06-06 |
Family
ID=86584133
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310207512.4A Pending CN116229172A (zh) | 2023-03-03 | 2023-03-03 | 基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116229172A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118196514A (zh) * | 2024-03-25 | 2024-06-14 | 东营市人民医院 | 医疗影像识别模型生成方法及系统 |
-
2023
- 2023-03-03 CN CN202310207512.4A patent/CN116229172A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118196514A (zh) * | 2024-03-25 | 2024-06-14 | 东营市人民医院 | 医疗影像识别模型生成方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
JP7185039B2 (ja) | 画像分類モデルの訓練方法、画像処理方法及びその装置、並びにコンピュータプログラム | |
US12100192B2 (en) | Method, apparatus, and electronic device for training place recognition model | |
CN109840531B (zh) | 训练多标签分类模型的方法和装置 | |
WO2020114378A1 (zh) | 视频水印的识别方法、装置、设备及存储介质 | |
CN108132968B (zh) | 网络文本与图像中关联语义基元的弱监督学习方法 | |
WO2019100724A1 (zh) | 训练多标签分类模型的方法和装置 | |
Yu et al. | Hybrid dual-tree complex wavelet transform and support vector machine for digital multi-focus image fusion | |
CN110765882B (zh) | 一种视频标签确定方法、装置、服务器及存储介质 | |
CN112508094A (zh) | 垃圾图片的识别方法、装置及设备 | |
WO2019146057A1 (ja) | 学習装置、実写画像分類装置の生成システム、実写画像分類装置の生成装置、学習方法及びプログラム | |
CN116310530A (zh) | 基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备 | |
CN112801107B (zh) | 一种图像分割方法和电子设备 | |
CN111223128A (zh) | 目标跟踪方法、装置、设备及存储介质 | |
CN111507406A (zh) | 一种用于优化神经网络文本识别模型的方法与设备 | |
CN116229170A (zh) | 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备 | |
CN112668482A (zh) | 人脸识别训练方法、装置、计算机设备及存储介质 | |
CN116229172A (zh) | 基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备 | |
CN113570512A (zh) | 一种图像数据处理方法、计算机及可读存储介质 | |
CN117095252A (zh) | 目标检测方法 | |
Deva Shahila et al. | Soft computing-based non-linear discriminate classifier for multimedia image quality enhancement | |
CN113486736B (zh) | 一种基于活跃子空间与低秩进化策略的黑盒对抗攻击方法 | |
CN115546554A (zh) | 敏感图像的识别方法、装置、设备和计算机可读存储介质 | |
CN115115910A (zh) | 图像处理模型的训练方法、使用方法、装置、设备及介质 | |
CN115471714A (zh) | 数据处理方法、装置、计算设备和计算机可读存储介质 | |
US20240127104A1 (en) | Information retrieval systems and methods with granularity-aware adaptors for solving multiple different tasks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |