CN116310530A - 基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备 - Google Patents
基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备 Download PDFInfo
- Publication number
- CN116310530A CN116310530A CN202310205865.0A CN202310205865A CN116310530A CN 116310530 A CN116310530 A CN 116310530A CN 202310205865 A CN202310205865 A CN 202310205865A CN 116310530 A CN116310530 A CN 116310530A
- Authority
- CN
- China
- Prior art keywords
- model
- image classification
- initial
- classification model
- data set
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000013145 classification model Methods 0.000 title claims abstract description 75
- 238000000034 method Methods 0.000 title claims abstract description 55
- 238000012549 training Methods 0.000 title claims abstract description 49
- 239000013598 vector Substances 0.000 claims abstract description 47
- 238000002372 labelling Methods 0.000 claims abstract description 8
- 230000002708 enhancing effect Effects 0.000 claims abstract description 6
- 230000002776 aggregation Effects 0.000 claims description 9
- 238000004220 aggregation Methods 0.000 claims description 9
- 238000004364 calculation method Methods 0.000 claims description 9
- 238000003860 storage Methods 0.000 claims description 7
- 238000013528 artificial neural network Methods 0.000 claims description 6
- 238000004590 computer program Methods 0.000 claims description 4
- 238000011524 similarity measure Methods 0.000 claims description 3
- 230000006870 function Effects 0.000 description 12
- 238000009826 distribution Methods 0.000 description 11
- 230000008569 process Effects 0.000 description 5
- 230000007246 mechanism Effects 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 238000005070 sampling Methods 0.000 description 3
- 238000007792 addition Methods 0.000 description 2
- 230000004931 aggregating effect Effects 0.000 description 2
- 230000008859 change Effects 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 230000007786 learning performance Effects 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 238000004138 cluster model Methods 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 230000001131 transforming effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/762—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using clustering, e.g. of similar faces in social networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明提供一种基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备:获取客户端本地数据集;获取初始模型,包括语义聚类模型和预训练得到的编码器网络;将本地数据集随机增强两次生成两个视图,输入初始编码器网络,提取特征向量并构建对比损失,训练得到编码器网络;将本地数据集输入编码器网络提取特征向量,并提取样本的Top‑K近邻样本,利用预设函数计算样本所属于不同集群的向量值,得到样本的类别;采用本地数据集对初始模型训练,构建聚类损失,以得到初始图像分类模型;基于各客户端模型参数构建设有自标记模块的共享模型,并根据共享模型参数更新初始图像分类模型。本发明提供的图像分类模型准确率高且适用非独立同分布场景。
Description
技术领域
本发明涉及人工智能技术领域,尤其涉及一种基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备。
背景技术
随着智能设备的普及,联邦学习已经成为最常用的一种隐私保护模型共享方法,并在用户习惯预测、个性化推荐和无线网络优化等许多场景中得到了广泛的应用。现有的联邦学习方法通常只考虑有监督的训练设置,其中客户端数据被完全标记。然而,包含复杂注释的本地数据对于物联网应用来说是不现实的,因为用户总是有不同的习惯和使用频率。示例性的,假设有一个照片分类器应用程序,可以实现自动对相册中的图片进行分类。在这种情况下,应用程序的用户若不愿意自己注释这些隐私和敏感的图片,则会导致服务提供商只能在中央服务器上使用有限的公共图片。因此,在许多现实的物联网场景中,客户端数据可能没有完全标记,只有少量标记数据在服务器上可用。
现有的联邦学习方法在缺乏标签数据场景下主要是采用联邦半监督学习。联邦半监督学习的目标是学习多个客户端之间的一致性。部分工作通过客户端间一致性损失,用于对标记数据和未标记数据进行分布式训练,或是考虑参数更新多样性的半监督训练多样性缩放聚合算法,在移动设备之间交换局部模型的输出,而不是典型框架中使用的模型参数交换。但是,当下游任务没有可用的标签时,这些方法的性能不佳。同时,与理想化的分布条件不同,由于用户的使用习惯和使用频率不同,物联网设备之间的数据通常是非独立的同分布,也会导致共享模型性能下降。
发明内容
鉴于此,本发明实施例提供了一种基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备,以消除或改善现有技术中存在的一个或更多个缺陷,解决现有联邦无监督学习性能较差且不适用非独立同分布场景的问题。
一方面,本发明提供一种基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,所述方法在各客户端执行,包括以下步骤:
获取本地数据集,所述本地数据集包含多个样本,每个样本包含一张图像;
获取初始模型,所述初始模型包括语义聚类模型和预训练得到的编码器网络;其中,将所述本地数据集的样本进行两次随机增强,生成第一视图和第二视图;将所述第一视图和所述第二视图一同输入初始编码器网络,提取第一特征向量和第二特征向量;采用所述本地数据集对所述初始编码器网络进行训练,并构建第一特征向量和第二特征向量之间的对比损失,以得到训练好的编码器网络;将所述本地数据集按批输入所述语义聚类模型,利用所述编码器网络提取对应样本的特征向量;基于预设神经网络从所述特征向量中提取对应样本的Top-K近邻样本,通过预设Softmax函数计算对应样本所属于不同集群的向量值,以得到对应样本的类别;
采用所述本地数据集对所述初始模型进行训练,并构建聚类损失,利用所述聚类损失对所述初始模型的参数进行迭代,以得到初始图像分类模型;
将所述初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;所述共享模型由所述全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收所述共享模型的参数,并采用指数移动平均更新所述初始图像分类模型,以得到最终的图像分类模型;其中,所述共享模型还包括自标记模块,所述自标记模块为基于所述共享模型得到的类别设置伪标签,并构建基于所述共享模型得到的类别与相应伪标签之间的交叉熵损失,利用所述交叉熵损失更新所述共享模型参数。
在本发明的一些实施例中,将所述本地数据集的样本进行两次随机增强,所述随机增强至少包括空间变换裁剪、旋转、调节饱和度、调节对比度、调节色调、调节颜色、调节亮度和调节灰度中的一种或多种组合操作。
在本发明的一些实施例中,构建第一特征向量和第二特征向量之间的对比损失,所述对比损失采用归一化温度交叉熵损失。
在本发明的一些实施例中,所述对比损失的计算式为:
其中,表示所述对比损失;i,j分别表示所述第一视图和所述第二视图;zi,zj分别表示所述第一特征向量和所述第二特征向量;sim(zi,zj)表示所述第一视图和所述第二视图的相似度度量;τ是温度因子;M表示所述本地数据集中样本的数量;m表示所述本地数据集中样本的序号。
在本发明的一些实施例中,采用所述本地数据集对所述初始模型进行训练,并构建聚类损失;所述聚类损失的计算式为:
其中,表示所述聚类损失;x表示所述本地数据集xc中的单个样本;/>表示x的相邻样本集Nx中的单个近邻图像样本;q(·)表示预设函数;<·>表示点积运算符号;λ表示权重;k表示集群;pk表示被分配到集群k的概率。
在本发明的一些实施例中,接收所述共享模型的参数,并采用指数移动平均更新所述初始图像分类模型,计算式为:
其中,qc表示客户端c的初始图像分类模型参数;qg表示所述共享模型参数;t表示第t轮所述共享模型参数聚合;μ表示预设阈值;ξ表示所述初始图像分类模型参数与所述共享模型参数在更新中分别占的权重。
在本发明的一些实施例中,还包括:
计算所述初始图像分类模型在训练时的模型散度,当所述模型散度大于所述预设阈值时,客户端使用所述共享模型的权重进行更新;当所述模型散度小于或等于所述预设阈值时,客户端使用其初始图像分类模型和所述共享模型的权重组合进行更新。
在本发明的一些实施例中,基于预设置信阈值选择置信度大于所述预设置信阈值的样本,并为相应样本基于所述共享模型得到的类别设置伪标签,构建基于所述共享模型得到的类别与相应伪标签之间的交叉熵损失,所述交叉熵损失计算式为:
其中,Lself表示所述交叉熵损失;x表示所述全局服务器的数据集xg中的单个样本;σ表示所述预设置信阈值;p(x)表示所述共享模型的输出;表示样本x的伪标签;H(·)表示所述伪标签上的标准交叉熵损失。
另一方面,本发明提供一种基于语义聚类的联邦无监督图像分类方法,其特征在于,该方法在客户端执行,包括以下步骤:
获取待分类的图像;
将所述图像输入如上文中任一项所述基于语义聚类的联邦无监督图像分类模型训练方法得到的图像分类模型,以得到所述图像的类别。
另一方面,本发明还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上文中提及的任意一项所述方法的步骤。
本发明的有益效果至少是:
本发明提供一种基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备,包括:获取客户端本地数据集,以构建模型的训练集。获取初始模型,包括语义聚类模型和预训练得到的编码器网络。在编码器网络训练中,采用随机数据增强作为对比学习的前置任务,将本地数据集随机增强两次生成两个视图,输入初始编码器网络,提取特征向量;并采用归一化温度交叉熵损失函数进行对比学习,以训练得到编码器网络。利用训练好的编码器网络提取本地数据集的特征向量,采用最近邻语义聚类方法,将对比学习获得的先验知识集成到聚类损失函数中,根据特征相似度对样本进行分类,实现无监督学习。基于各客户端模型参数加权聚合构建共享模型,考虑到各客户端数据的非独立同分布特性,设计动态更新机制,根据客户端模型的权重发散程度利用共享模型参数更新初始图像分类模型;针对聚类过程中存在的聚类错误问题,在共享模型中设置自标记模块,利用高度自信的预测样本来纠正聚类错误,提升模型的分类性能。基于本发明提供的无监督学习方法训练得到的图像分类模型具备高准确率、可扩展性且适用非独立同分布场景。
本发明的附加优点、目的,以及特征将在下面的描述中将部分地加以阐述,且将对于本领域普通技术人员在研究下文后部分地变得明显,或者可以根据本发明的实践而获知。本发明的目的和其它优点可以通过在说明书以及附图中具体指出的结构实现到并获得。
本领域技术人员将会理解的是,能够用本发明实现的目的和优点不限于以上具体所述,并且根据以下详细说明将更清楚地理解本发明能够实现的上述和其他目的。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,并不构成对本发明的限定。在附图中:
图1为本发明一实施例中基于语义聚类的联邦无监督图像分类模型训练方法的步骤示意图。
图2为本发明一实施例中基于语义聚类的联邦无监督图像分类模型训练方法的结构流程示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚明白,下面结合实施方式和附图,对本发明做进一步详细说明。在此,本发明的示意性实施方式及其说明用于解释本发明,但并不作为对本发明的限定。
在此,还需要说明的是,为了避免因不必要的细节而模糊了本发明,在附图中仅仅示出了与根据本发明的方案密切相关的结构和/或处理步骤,而省略了与本发明关系不大的其他细节。
应该强调,术语“包括/包含”在本文使用时指特征、要素、步骤或组件的存在,但并不排除一个或更多个其它特征、要素、步骤或组件的存在或附加。
在此,还需要说明的是,如果没有特殊说明,术语“连接”在本文不仅可以指直接连接,也可以表示存在中间物的间接连接。
在下文中,将参考附图描述本发明的实施例。在附图中,相同的附图标记代表相同或类似的部件,或者相同或类似的步骤。
这里需要强调的是,在下文中提及的各步骤标记并不是对各步骤先后顺序的限定,而应当理解为可以按照实施例中提及的顺序执行步骤,也可以不同于实施例中的顺序,或者若干步骤同时执行。
为了解决现有联邦无监督学习性能较差且不适用非独立同分布场景的问题,本发明提供一种基于语义聚类的联邦无监督图像分类模型训练方法,如图1所示,该方法在各客户端执行,包括以下步骤S101~S104:
步骤S101:获取本地数据集。其中,本地数据集包含多个样本,每个样本包含一张图像。
步骤S102:获取初始模型,初始模型包括语义聚类模型和预训练得到的编码器网络。其中,将本地数据集的样本进行两次随机增强,生成第一视图和第二视图;将第一视图和第二视图一同输入初始编码器网络,提取第一特征向量和第二特征向量;采用本地数据集对所述初始编码器网络进行训练,并构建第一特征向量和第二特征向量之间的对比损失,以得到训练好的编码器网络。将本地数据集按批输入语义聚类模型,利用编码器网络提取对应样本的特征向量;基于预设神经网络从特征向量中提取对应样本的Top-K近邻样本,通过预设Softmax函数计算对应样本所属于不同集群的向量值,以得到对应样本的类别。
步骤S103:采用所本地数据集对初始模型进行训练,并构建聚类损失,利用聚类损失对初始模型的参数进行迭代,以得到初始图像分类模型。
步骤S104:将初始图像分类模型的模型参数发送至全局服务器,以生成共享模型。其中,共享模型由全局服务器根据各客户端初始图像分类模型参数加权聚合得到。接收共享模型的参数,并采用指数移动平均更新初始图像分类模型,以得到最终的图像分类模型。其中,共享模型还包括自标记模块,自标记模块为基于共享模型得到的类别设置伪标签,并构建基于共享模型得到的类别与相应伪标签之间的交叉熵损失,利用交叉熵损失更新共享模型参数。
本发明提出了一种基于语义聚类的联邦无监督学习框架,依托传统的分布式学习架构和深度神经网络框架进行模型训练。其中,深度神经网络框架可以选用PyTorch、TensorFlow等。
如图2所示,为基于语义聚类的联邦无监督图像分类模型训练方法的整体流程图。
在步骤S101中,获取各客户端的本地数据集,在各客户端中进行局部训练。
在步骤S102中,各客户端从全局服务器获取初始模型,该初始模型包括语义聚类模型和预训练得到的编码器网络。
首先对初始编码器网络进行训练。在无监督学习的场景中,由于没有可用的标签,必须定义一个前置条件来确定哪些样本是相同的类。因此,使用数据增强来约束模型预测不受噪声影响,同时基于对比学习,通过在每个客户端上最大化增强图像样本之间的一致性来学习通用模型表示。采用本地数据集或随机抽样小批量样本,对样本进行两次随机增强,生成第一视图和第二视图。
在一些实施例中,本发明采用的随机增强至少包括空间变换裁剪、旋转、调节饱和度、调节对比度、调节色调、调节颜色、调节亮度和调节灰度中的一种或多种组合操作。同时,对于联邦学习框架中的每个客户端,采用相同的增强策略,不会在不同的数据集上搜索最佳策略。
将第一视图和第二视图一同输入初始编码器网络,生成两个视图的语义特征表示,再通过预设全连接层的非线性函数对两个视图的语义特征表示进行变换,分别生成第一特征向量和第二特征向量。其中,第一特征向量和第二特征向量用于计算对比损失。
根据观察发现,具有相似高阶特征的图像样本会更加接近,因此,在一些实施例中,采用归一化温度交叉熵作为本发明的对比损失,每个客户端的对比损失计算式可如公式(1)表示为:
其中,sim(zi,zj的计算式如公式(2)所示:
公式(1)和(2)中,表示对比损失;i,j分别表示第一视图和第二视图;zi,zj分别表示第一特征向量和第二特征向量;sim(zi,zj)表示第一视图和第二视图的相似度度量;τ是温度因子;M表示本地数据集或随机抽样小批量样本中的样本数量;m表示本地数据集或随机抽样小批量样本中的样本的序号;[·]T表示向量转置。
采用本地数据集或随机抽样小批量样本对初始编码器网络进行训练,并利用对比损失对初始编码器网络的参数进行迭代,以得到训练好的编码器网络。
将训练好的编码器网络应用于语义聚类模型中,将样本图像映射为特征表示。具体的,将本地数据集按批输入语义聚类模型,利用编码器网络提取对应样本的特征向量,基于预设神经网络从特征向量中提取对应样本的Top-K近邻样本,利用预设函数将输入的图像样本通过预设Softmax函数计算其所属于不同集群的向量值,根据向量值最终确定对应样本的类别。
在步骤S103中,利用本地数据集对初始模型进行训练,同时构建聚类损失,利用聚类损失对初始模型的参数进行迭代,以得到初始图像分类模型。
在一些实施例中,定义Nx作为输入图像样本的相邻样本集,图像样本被分配到相应集群的概率记作p,由此,每个客户端的聚类损失的计算式如公式(3)表示为:
其中,pk的计算式如公式(4)所示:
公式(3)和(4)中,表示聚类损失;x表示本地数据集xc中的单个样本;/>表示x的相邻样本集Nx中的单个近邻图像样本;q(·)表示预设函数;<·>表示点积运算符号;λ表示权重;k表示集群;pk表示x被分配到集群k的概率。
在步骤S104中,受数据非独立同分布会导致权重发散的启发,本发明设计了动态更新机制,基于权重发散动态更新初始图像分类模型的参数。
在一轮训练中,基于本地数据集训练得到各客户端的初始图像分类模型后,各客户端向全局服务器发送各自模型的参数,其中全局服务器在图2中用云服务器表示。全局服务器获取各客户端模型的参数后,进行加权聚合,构建共享模型,并将共享模型的参数发送至各客户端。各客户端接收共享模型的参数,并采用指数移动平均更新初始图像分类模型,以得到最终的图像分类模型。
在一些实施例中,采用指数移动平均更新初始图像分类模型,计算式如公式(5)所示:
其中,qc表示客户端c的初始图像分类模型参数;qg表示共享模型参数;t表示第t轮共享模型参数聚合;μ表示预设阈值;ξ表示初始图像分类模型参数与共享模型参数在更新中分别占的权重。
利用来度量初始图像分类模型在训练时的模型散度。由于各客户端的数据是不平衡且非独立同分布的,因此,每个客户端的初始图像分类模型参数qc可能存在较大的方差。在本发明中,定义一个预设阈值,并计算初始图像分类模型在训练时的模型散度,当模型散度大于预设阈值时,客户端使用共享模型的权重进行更新;当模型散度小于或等于预设阈值时,客户端使用其初始图像分类模型和共享模型的权重组合进行更新。
在语义聚类过程中,每个图像样本都有K邻居,因此不可避免地会将一些样本分配至错误的聚类中。在本发明中,在共享模型中设计自标记模块,用于将具有高度自信预测的样本倾向于被分配到正确的簇。由此,利用这些高度自信的预测样本,可靠地纠正聚类过程中的错误,进一步提高图像分类模型的性能。
具体的,全局服务器将各客户端的初始图像分类模型参数加权聚合后,定义一个预设置信阈值,选择置信度大于预设置信阈值的样本,对于每个可信样本,为其预测的聚类分配一个伪标签,并构建基于共享模型得到的类别与相应伪标签之间的交叉熵损失,以更新共享模型参数。
在一些实施例中,交叉熵损失的计算式如公式(6)所示:
其中,Lself表示交叉熵损失;x表示全局服务器的数据集xg中的单个样本;σ表示预设置信阈值;p(x)表示共享模型的输出;表示样本x的伪标签;/>使用argmax函数将概率分布变成一个单热分布;H(·)表示伪标签上的标准交叉熵损失。
自标记模块允许共享模型进行自我修正,提高共享模型的性能。
本发明还提供一种基于语义聚类的联邦无监督图像分类方法,该方法包括以下步骤S201~S202:
步骤S201:获取待分类的图像。
步骤S202:将图像输入如上文所述基于语义聚类的联邦无监督图像分类模型训练方法得到的图像分类模型,以得到图像的类别。
本发明还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现基于语义聚类的联邦无监督图像分类模型训练方法和基于语义聚类的联邦无监督图像分类方法的步骤。
与上述方法相应地,本发明还提供了一种设备,该设备包括计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有计算机指令,所述处理器用于执行所述存储器中存储的计算机指令,当所述计算机指令被处理器执行时该设备实现如前所述方法的步骤。
本发明实施例还提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时以实现前述边缘计算服务器部署方法的步骤。该计算机可读存储介质可以是有形存储介质,诸如随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、软盘、硬盘、可移动存储盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质。
综上所述,本发明提供一种基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备,包括:获取客户端本地数据集,以构建模型的训练集。获取初始模型,包括语义聚类模型和预训练得到的编码器网络。在编码器网络训练中,采用随机数据增强作为对比学习的前置任务,将本地数据集随机增强两次生成两个视图,输入初始编码器网络,提取特征向量;并采用归一化温度交叉熵损失函数进行对比学习,以训练得到编码器网络。利用训练好的编码器网络提取本地数据集的特征向量,采用最近邻语义聚类方法,将对比学习获得的先验知识集成到聚类损失函数中,根据特征相似度对样本进行分类,实现无监督学习。基于各客户端模型参数加权聚合构建共享模型,考虑到各客户端数据的非独立同分布特性,设计动态更新机制,根据客户端模型的权重发散程度利用共享模型参数更新初始图像分类模型;针对聚类过程中存在的聚类错误问题,在共享模型中设置自标记模块,利用高度自信的预测样本来纠正聚类错误,提升模型的分类性能。基于本发明提供的无监督学习方法训练得到的图像分类模型具备高准确率、可扩展性且适用非独立同分布场景。
本领域普通技术人员应该可以明白,结合本文中所公开的实施方式描述的各示例性的组成部分、系统和方法,能够以硬件、软件或者二者的结合来实现。具体究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。当以硬件方式实现时,其可以例如是电子电路、专用集成电路(ASIC)、适当的固件、插件、功能卡等等。当以软件方式实现时,本发明的元素是被用于执行所需任务的程序或者代码段。程序或者代码段可以存储在机器可读介质中,或者通过载波中携带的数据信号在传输介质或者通信链路上传送。
需要明确的是,本发明并不局限于上文所描述并在图中示出的特定配置和处理。为了简明起见,这里省略了对已知方法的详细描述。在上述实施例中,描述和示出了若干具体的步骤作为示例。但是,本发明的方法过程并不限于所描述和示出的具体步骤,本领域的技术人员可以在领会本发明的精神后,作出各种改变、修改和添加,或者改变步骤之间的顺序。
本发明中,针对一个实施方式描述和/或例示的特征,可以在一个或更多个其它实施方式中以相同方式或以类似方式使用,和/或与其他实施方式的特征相结合或代替其他实施方式的特征。
以上所述仅为本发明的优选实施例,并不用于限制本发明,对于本领域的技术人员来说,本发明实施例可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,所述方法在各客户端执行,包括以下步骤:
获取本地数据集,所述本地数据集包含多个样本,每个样本包含一张图像;
获取初始模型,所述初始模型包括语义聚类模型和预训练得到的编码器网络;其中,将所述本地数据集的样本进行两次随机增强,生成第一视图和第二视图;将所述第一视图和所述第二视图一同输入初始编码器网络,提取第一特征向量和第二特征向量;采用所述本地数据集对所述初始编码器网络进行训练,并构建第一特征向量和第二特征向量之间的对比损失,以得到训练好的编码器网络;将所述本地数据集按批输入所述语义聚类模型,利用所述编码器网络提取对应样本的特征向量;基于预设神经网络从所述特征向量中提取对应样本的Top-K近邻样本,通过预设Softmax函数计算对应样本所属于不同集群的向量值,以得到对应样本的类别;
采用所述本地数据集对所述初始模型进行训练,并构建聚类损失,利用所述聚类损失对所述初始模型的参数进行迭代,以得到初始图像分类模型;
将所述初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;所述共享模型由所述全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收所述共享模型的参数,并采用指数移动平均更新所述初始图像分类模型,以得到最终的图像分类模型;其中,所述共享模型还包括自标记模块,所述自标记模块为基于所述共享模型得到的类别设置伪标签,并构建基于所述共享模型得到的类别与相应伪标签之间的交叉熵损失,利用所述交叉熵损失更新所述共享模型参数。
2.根据权利要求1所述的基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,将所述本地数据集的样本进行两次随机增强,所述随机增强至少包括空间变换裁剪、旋转、调节饱和度、调节对比度、调节色调、调节颜色、调节亮度和调节灰度中的一种或多种组合操作。
3.根据权利要求1所述的基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,构建第一特征向量和第二特征向量之间的对比损失,所述对比损失采用归一化温度交叉熵损失。
7.根据权利要求6所述的基于语义聚类的联邦无监督图像分类模型训练方法,其特征在于,还包括:
计算所述初始图像分类模型在训练时的模型散度,当所述模型散度大于所述预设阈值时,客户端使用所述共享模型的权重进行更新;当所述模型散度小于或等于所述预设阈值时,客户端使用其初始图像分类模型和所述共享模型的权重组合进行更新。
9.一种基于语义聚类的联邦无监督图像分类方法,其特征在于,该方法在客户端执行,包括以下步骤:
获取待分类的图像;
将所述图像输入如权利要求1至8中任一项所述基于语义聚类的联邦无监督图像分类模型训练方法得到的图像分类模型,以得到所述图像的类别。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如权利要求1至9中任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310205865.0A CN116310530A (zh) | 2023-03-03 | 2023-03-03 | 基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310205865.0A CN116310530A (zh) | 2023-03-03 | 2023-03-03 | 基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116310530A true CN116310530A (zh) | 2023-06-23 |
Family
ID=86779136
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310205865.0A Pending CN116310530A (zh) | 2023-03-03 | 2023-03-03 | 基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116310530A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117274778A (zh) * | 2023-11-21 | 2023-12-22 | 浙江啄云智能科技有限公司 | 基于无监督和半监督的图像搜索模型训练方法和电子设备 |
CN117392483A (zh) * | 2023-12-06 | 2024-01-12 | 山东大学 | 基于增强学习的相册分类模型训练加速方法、系统及介质 |
-
2023
- 2023-03-03 CN CN202310205865.0A patent/CN116310530A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117274778A (zh) * | 2023-11-21 | 2023-12-22 | 浙江啄云智能科技有限公司 | 基于无监督和半监督的图像搜索模型训练方法和电子设备 |
CN117274778B (zh) * | 2023-11-21 | 2024-03-01 | 浙江啄云智能科技有限公司 | 基于无监督和半监督的图像搜索模型训练方法和电子设备 |
CN117392483A (zh) * | 2023-12-06 | 2024-01-12 | 山东大学 | 基于增强学习的相册分类模型训练加速方法、系统及介质 |
CN117392483B (zh) * | 2023-12-06 | 2024-02-23 | 山东大学 | 基于增强学习的相册分类模型训练加速方法、系统及介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109840531B (zh) | 训练多标签分类模型的方法和装置 | |
CN116310530A (zh) | 基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备 | |
US20180225549A1 (en) | Media content analysis system and method | |
CN111444966A (zh) | 媒体信息分类方法及装置 | |
US10204090B2 (en) | Visual recognition using social links | |
CN115552429A (zh) | 使用非iid数据的横向联邦学习方法和系统 | |
CN116261731A (zh) | 基于多跳注意力图神经网络的关系学习方法与系统 | |
WO2022252458A1 (zh) | 一种分类模型训练方法、装置、设备及介质 | |
JP6158882B2 (ja) | 生成装置、生成方法、及び生成プログラム | |
Adam et al. | Toward smart traffic management with 3D placement optimization in UAV-assisted NOMA IIoT networks | |
CN116229170A (zh) | 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备 | |
CN111507406A (zh) | 一种用于优化神经网络文本识别模型的方法与设备 | |
CN115049076A (zh) | 基于原型网络的迭代聚类式联邦学习方法 | |
Chen et al. | Patch selection denoiser: An effective approach defending against one-pixel attacks | |
Ding et al. | Full‐reference image quality assessment using statistical local correlation | |
WO2017188048A1 (ja) | 作成装置、作成プログラム、および作成方法 | |
Chamoso et al. | Social computing for image matching | |
Jiao et al. | [Retracted] An Improved Cuckoo Search Algorithm for Multithreshold Image Segmentation | |
Shono | Application of support vector regression to CPUE analysis for southern bluefin tuna Thunnus maccoyii, and its comparison with conventional methods | |
CN117095252A (zh) | 目标检测方法 | |
CN116883786A (zh) | 图数据增广方法、装置、计算机设备及可读存储介质 | |
CN113486736B (zh) | 一种基于活跃子空间与低秩进化策略的黑盒对抗攻击方法 | |
CN115797642A (zh) | 基于一致性正则化与半监督领域自适应图像语义分割算法 | |
JP6214073B2 (ja) | 生成装置、生成方法、及び生成プログラム | |
CN112329692B (zh) | 一种有限样本条件下跨场景人体行为无线感知方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |