CN116229170A - 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备 - Google Patents
基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备 Download PDFInfo
- Publication number
- CN116229170A CN116229170A CN202310199005.0A CN202310199005A CN116229170A CN 116229170 A CN116229170 A CN 116229170A CN 202310199005 A CN202310199005 A CN 202310199005A CN 116229170 A CN116229170 A CN 116229170A
- Authority
- CN
- China
- Prior art keywords
- loss
- model
- domain
- task
- classification
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 61
- 238000013145 classification model Methods 0.000 title claims abstract description 56
- 230000005012 migration Effects 0.000 title claims abstract description 46
- 238000013508 migration Methods 0.000 title claims abstract description 46
- 238000012549 training Methods 0.000 title claims abstract description 43
- 238000003062 neural network model Methods 0.000 claims abstract description 40
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 10
- 230000007246 mechanism Effects 0.000 claims abstract description 7
- 238000013528 artificial neural network Methods 0.000 claims description 9
- 238000005457 optimization Methods 0.000 claims description 6
- 238000004364 calculation method Methods 0.000 claims description 4
- 238000004590 computer program Methods 0.000 claims description 4
- 238000000605 extraction Methods 0.000 claims description 3
- 230000002776 aggregation Effects 0.000 claims description 2
- 238000004220 aggregation Methods 0.000 claims description 2
- 230000006870 function Effects 0.000 description 23
- 230000008901 benefit Effects 0.000 description 6
- 230000003044 adaptive effect Effects 0.000 description 5
- 230000004048 modification Effects 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 241000282326 Felis catus Species 0.000 description 2
- 238000007792 addition Methods 0.000 description 2
- 230000004931 aggregating effect Effects 0.000 description 2
- 230000005540 biological transmission Effects 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 230000008569 process Effects 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 241000282994 Cervidae Species 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 239000003550 marker Substances 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
- G06V10/765—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Image Analysis (AREA)
Abstract
本发明提供一种基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备:获取包含完整标签和部分标签的本地数据集;获取初始神经网络模型,其包括自适应增量层和深度迁移模块;自适应增量层为在初始神经网络模型的每个卷积层后添加一个全连接层;在深度迁移模块中构建域分类和域混淆的竞争机制,并采用知识蒸馏方法保存相关类别之间的信息;采用本地数据集为模型进行训练,构建域分类损失、域混淆损失和软标签损失的联合损失,并根据分类任务平均精度确定任务优先级,训练得到初始图像分类模型;基于各客户端模型参数构建共享模型,并根据共享模型参数更新初始图像分类模型。本发明提供的图像分类模型精度高且能保留个性化局部模型。
Description
技术领域
本发明涉及人工智能技术领域,尤其涉及一种基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备。
背景技术
随着物联网设备的快速增加,移动设备、智能手机等通过开放的通信网络平台进行连接,为通过数据共享提高新兴应用的服务质量开辟了新的可能性。联邦学习技术对于解决物联网场景下的隐私保护训练是有效的,如用户习惯预测、个性化推荐和无线网络优化,其成功的部分原因是对多个客户端的大量标记数据进行训练。现有的联合学习方法并不能通过对标签数据的大量训练实现泛化性能。然而,在现实的物联网场景中,由于用户习惯或没有足够的专业知识来正确标注数据,客户的数据总是伴随着很少的标签。示例性的,一手机健身应用程序可以纠正用户的身体姿势,但在这种情况下,由于用户可能无法评估自己的姿势是否合格,从而无法为应用程序提供对应的数据标签。因此,大型标记数据集所带来的性能优势是以成本和应用程序有限为代价的。
传统的联邦学习通常是依靠大量的标签数据来实现模型性能提升的,但在半监督环境下,由于训练数据标签的匮乏,使得传统的分布式模型优化算法不再适用。由此引入联邦半监督学习,联邦半监督学习方法将半监督学习方法与联邦学习框架相结合,通过局部监督训练和设备间的知识转移迭代优化共享模型。现有方法在相应应用领域取得了一定的成功,但依旧以下两大问题:一是传统的联邦半监督学习通常基于半监督算法实现设备本地数据集标签数据与无标签数据的知识迁移,忽略了设备间的迁移需求。二是由于联邦学习是以训练一个泛化的共享模型为目标,较难实现在保留每类客户端独特的任务需求的情况下,实现模型一致性和个性化间的平衡。
发明内容
鉴于此,本发明实施例提供了一种基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备,以消除或改善现有技术中存在的一个或更多个缺陷,解决现有联邦半监督学习方法在稀疏标签环境下造成模型精度低、无法实现标签数据与无标签数据间知识迁移以及无法保留个性化局部模型的问题。
一方面,本发明提供一种基于任务迁移的联邦无监督图像分类模型训练方法,其特征在于,所述方法在各客户端执行,包括以下步骤:
获取本地数据集,所述本地数据集包括含完整类别标签的源数据和含部分类别标签的目标数据,每个数据包含一张图像样本;
获取初始神经网络模型,所述初始神经网络模型包括自适应增量层和深度迁移模块;所述自适应增量层为在所述初始神经网络模型的每个卷积层后添加一个全连接层;将所述本地数据集的图像样本按批次输入所述初始神经网络模型进行特征提取,利用预设域分类器判别相应图像样本属于所述源数据或所述目标数据,并利用预设预混淆层通过域混淆对齐域,构建域混淆竞争机制;采用知识蒸馏方法利用所述源数据计算每个类别之间的关系数值,并求取各关系数值的平均值作为与相应源数据具备相关性的目标数据的软标签,以输出相应图像样本的类别;
采用所述本地数据集对所述初始神经网络进行训练,构建域分类损失、域混淆损失和软标签损失,根据所述域分类损失、所述域混淆损失和所述软标签损失构建联合损失,并计算各分类任务的平均精度,根据所述平均精度确定每类任务在损失函数中的权重,利用所述联合损失对所述初始神经网络模型的参数进行迭代,得到初始图像分类模型;
将所述初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;所述共享模型由所述全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收所述共享模型的参数,基于所述自适应增量层更新所述初始图像分类模型,以得到最终的图像分类模型。
在本发明的一些实施例中,所述域分类损失的损失函数定义为Softmax交叉熵损失函数,计算式为:
在本发明的一些实施例中,所述域混淆损失的计算式为:
其中,Lconf表示所述域混淆损失;d表示所述本地数据集Dk中的一个数据;pd表示对应网络输出的特征向量。
在本发明的一些实施例中,所述软标签损失的计算式为:
其中,Lsoft表示所述软标签损失;d表示所述本地数据集Dk中的一个数据;y表示所述初始神经网络模型判定的图像类别;ysoft表示数据d的软标签;q表示知识蒸馏后的网络输出。
在本发明的一些实施例中,所述域分类损失、所述域混淆损失和所述软标签损失采用加权组合构建联合损失,所述联合损失的计算式为:
在本发明的一些实施例中,计算各分类任务的平均精度,根据所述平均精度确定每类任务在损失函数中的权重,还包括:
计算各分类任务的平均精度,作为所述初始神经网络模型的关键性能指标,利用所述关键性能指标作为度量每类任务在损失函数中的权重的指标,每类任务的权重定义为:
其中,λ用来控制每类任务的相对优先级;κt表示任务的平均精度。
在本发明的一些实施例中,采用动态缩放的交叉熵损失代替交叉熵损失,以降低简单分类任务的权重,总的图像分类损失的计算式为:
其中,Lcla表示总的图像分类损失;t表示总分类任务T中的一个任务;d表示所述本地数据集Dk中的一个数据;Lc表示包含类别c的分类任务。
在本发明的一些实施例中,各客户端与所述全局服务器构建图像分类系统,所述系统通过最小化总损失函数以定义目标函数,所述目标函数的计算式为:
Lk(ω)=η1Lcla(Xs,Xt;ω)+η2Ltra(Xs,Xt;ω);
其中,Ltotal(ω)表示所述目标函数;D表示所有客户端本地数据集集合;k表示所有客户端K中的一个客户端;Dk表示客户端k的本地数据集;Lk(ω)表示客户端k的损失函数;η1和η2用于平衡多个优化目标;Lcla表示总的图像分类损失;Ltra表示所述联合损失;Xs表示客户端k的本地数据集中的源数据;Xt表示客户端k的本地数据集中的目标数据;ω表示所述初始神经网络模型参数。
另一方面,本发明提供一种基于任务迁移的联邦无监督图像分类方法,其特征在于,该方法在客户端执行,包括以下步骤:
获取待分类的图像;
将所述图像输入如上文中任一项所述基于任务迁移的联邦无监督图像分类模型训练方法得到的图像分类模型,以得到所述图像的类别。
另一方面,本发明还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上文中提及的任意一项所述方法的步骤。
本发明的有益效果至少是:
本发明提供一种基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备,通过获取包含完整标签和部分标签的本地数据集,构建源域和目标域。获取初始神经网络模型,包括自适应增量层和深度迁移模块。自适应增量层为在初始神经网络模型的每个卷积层后添加一个全连接层,在模型更新时保留客户端的个性化局部模型,提升模型泛化能力。在深度迁移模块中构建域分类和域混淆的竞争机制,并采用知识蒸馏方法实现完整标签数据和部分标签数据间的域混淆,降低源域与目标域间的距离,实现部分标签数据的训练,并设计软标签来调整类别间的信息,实现任务迁移,对客户端有效进行半监督学习。采用本地数据集为模型进行训练,构建域分类损失、域混淆损失和软标签损失的联合损失,并引入动态任务损失来自动调整任务之间的权重,确定任务优先级,最终训练得到初始图像分类模型。基于各客户端模型参数构建共享模型,并根据共享模型参数更新初始图像分类模型。基于本发明提供的方法训练得到的图像分类模型精度高且能保留局部个性化模型。
本发明的附加优点、目的,以及特征将在下面的描述中将部分地加以阐述,且将对于本领域普通技术人员在研究下文后部分地变得明显,或者可以根据本发明的实践而获知。本发明的目的和其它优点可以通过在说明书以及附图中具体指出的结构实现到并获得。
本领域技术人员将会理解的是,能够用本发明实现的目的和优点不限于以上具体所述,并且根据以下详细说明将更清楚地理解本发明能够实现的上述和其他目的。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,并不构成对本发明的限定。在附图中:
图1为本发明一实施例中基于任务迁移的联邦无监督图像分类模型训练方法的步骤示意图。
图2为本发明一实施例中基于任务迁移的联邦无监督图像分类模型训练方法的结构流程示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚明白,下面结合实施方式和附图,对本发明做进一步详细说明。在此,本发明的示意性实施方式及其说明用于解释本发明,但并不作为对本发明的限定。
在此,还需要说明的是,为了避免因不必要的细节而模糊了本发明,在附图中仅仅示出了与根据本发明的方案密切相关的结构和/或处理步骤,而省略了与本发明关系不大的其他细节。
应该强调,术语“包括/包含”在本文使用时指特征、要素、步骤或组件的存在,但并不排除一个或更多个其它特征、要素、步骤或组件的存在或附加。
在此,还需要说明的是,如果没有特殊说明,术语“连接”在本文不仅可以指直接连接,也可以表示存在中间物的间接连接。
在下文中,将参考附图描述本发明的实施例。在附图中,相同的附图标记代表相同或类似的部件,或者相同或类似的步骤。
这里需要强调的是,在下文中提及的各步骤标记并不是对各步骤先后顺序的限定,而应当理解为可以按照实施例中提及的顺序执行步骤,也可以不同于实施例中的顺序,或者若干步骤同时执行。
为了解决现有联邦半监督学习方法在稀疏标签环境下造成模型精度低、无法实现标签数据与无标签数据间知识迁移以及无法保留个性化局部模型的问题,本发明提供一种基于任务迁移的联邦无监督图像分类模型训练方法,如图1所示,该方法在各客户端执行,包括以下步骤S101~S104:
步骤S101:获取本地数据集,其中,本地数据集包括含完整类别标签的源数据和含部分类别标签的目标数据,每个数据包含一张图像样本。
步骤S102:获取初始神经网络模型,其中,初始神经网络模型包括自适应增量层和深度迁移模块;自适应增量层为在初始神经网络模型的每个卷积层后添加一个全连接层;将本地数据集的图像样本按批次输入初始神经网络模型进行特征提取,利用预设域分类器判别相应图像样本属于源数据或目标数据,并利用预设预混淆层通过域混淆对齐域,构建域混淆竞争机制;采用知识蒸馏方法利用源数据计算每个类别之间的关系数值,并求取各关系数值的平均值作为与相应源数据具备相关性的目标数据的软标签,以输出相应图像样本的类别。
步骤S103:采用本地数据集对初始神经网络进行训练,构建域分类损失、域混淆损失和软标签损失,根据域分类损失、域混淆损失和软标签损失构建联合损失,并计算各分类任务的平均精度,根据平均精度确定每类任务在损失函数中的权重,利用联合损失对初始神经网络模型的参数进行迭代,得到初始图像分类模型。
步骤S104:将初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;其中,共享模型由全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收共享模型的参数,基于自适应增量层更新初始图像分类模型,以得到最终的图像分类模型。
本发明提出了一种基于域和任务迁移的联邦无监督学习框架,依托传统的分布式学习架构和深度神经网络框架进行模型训练。其中,深度神经网络框架可以选用PyTorch、TensorFlow等。
在步骤S101中,示例性的,假设一组客户端C={C1,C2,...,Ck}和一个全局服务器G,每个客户端都拥有一个本地数据集Dk={Xs,Xt},其中,Xs表示含完整类别标签的源数据,Xt表示只含部分类别标签的目标数据。
获取客户端Ck对应的本地数据集Dk,用于训练客户端Ck的局部模型。
在步骤S102中,获取初始神经网络模型,示例性的,选用VGG-Net深度神经网络。初始神经网络模型包括自适应增量层和深度迁移模块,其中,自适应增量层用于根据共享模型参数更新客户端局部模型时,保留客户端个性化局部模型;深度迁移模块用于域混淆和任务迁移。
具体的,将客户端Ck的本地数据集Dk输入初始神经网络模型,对各图像样本进行特征提取。本发明在初始神经网络最后一个全连接层之前添加一个域混淆层,用于通过域混淆对齐域,其中,源域是指含完整类别标签的源数据,目标域是指仅含部分类别标签的目标数据。利用预设域分类器分类各图像样本对应的域,即判断各图像样本数据是源数据还是目标数据,以从标记完整的源域学习得到标记稀疏的目标域的表示。
在一些实施例中,对于本地数据集中任一数据(图像样本),域分类损失的损失函数可以定义为Softmax交叉熵损失函数,计算式如公式(1)所示:
再进行域混淆,在一些实施例中,域混淆损失的计算式如公式(2)所示:
公式(2)中,Lconf表示域混淆损失;d表示本地数据集Dk中的一个数据;pd表示对应网络输出的特征向量。
基于域分类损失和域混淆损失,训练初始神经网络模型既能分类源数据和目标数据,又能混淆源数据和目标数据,域分类损失和域混淆损失两者共同构成一个域混淆竞争机制,域分类损失用来更好的分类,域混淆损失用来最大程度的混淆源数据和目标数据。
为了更好的对齐源数据和目标数据,采用知识蒸馏方法来保存相关类别之间的信息。利用源数据计算每个类别之间的关系数值,并求取各关系数值的平均值作为与相应源数据具备相关性的目标数据的软标签。示例性的,对于马的图像,在软标签中是一个概率向量,在这个向量中,马和鹿的值更接近,而与鸟类图像的值差别较大。
在一些实施例中,采用软标签损失而不是标准的Softmax损失作为任务传输损失,软标签损失的计算式如公式(3)所示:
公式(3)中,Lsoft表示软标签损失;d表示本地数据集Dk中的一个数据;y表示初始神经网络模型判定的图像类别;ysoft表示数据d的软标签;q表示知识蒸馏后的网络输出。
在步骤S103中,采用本地数据集对初始神经网络进行训练,构建域分类损失、域混淆损失和软标签损失,根据域分类损失、域混淆损失和软标签损失构建联合损失,利用联合损失对初始神经网络模型的参数进行迭代,得到初始图像分类模型。
在一些实施例中,域分类损失、域混淆损失和软标签损失采用加权组合构建联合损失,联合损失的计算式如公式(4)所示:
通过最小化联合损失,可以同时对齐域和分类任务,实现含完整类别标签的源数据和含部分类别标签的目标数据之间的知识转移。
同时,由于本地数据集中缺乏相当一部分的数据标签,初始神经网络会产生混乱的结果。受动态任务优先级的启发,本发明做了进一步改进,利用关键性能指标作为度量每类任务在损失函数中的权重的指标,以优化网络性能。
在本发明中,计算各分类任务的平均精度,作为初始神经网络模型的关键性能指标,利用关键性能指标作为度量每类任务在损失函数中的权重的指标,每类任务的权重可定义如公式(5)所示:
公式(5)中,λ用来控制每类任务的相对优先级;κt表示任务的平均精度。
采用动态缩放的交叉熵损失代替交叉熵损失,以降低简单分类任务的权重,示例性的,马的分类任务的平均精度低于猫的分类任务的平均精度,则降低猫的分类任务权重,总的图像分类损失的计算式如公式(6)所示:
其中,Lcla表示总的图像分类损失;t表示总分类任务T中的一个任务;d表示本地数据集Dk中的一个数据;Lc表示包含类别c的分类任务。
在步骤S104中,将各客户端的初始图像分类模型参数发送至全局服务器,生成共享模型。其中,共享模型由全局服务器根据各客户端初始图像分类模型参数加权聚合得到。各客户端接收共享模型的参数,基于自适应增量层更新初始图像分类模型,以得到最终的图像分类模型。
具体的,自适应增量层是指在初始神经网络模型的每个卷积层后添加一个全连接层,该层元素由0和1构成,在各客户端初始神经网络模型训练过程中,自适应增量层的参数保持固定,在给定交叉熵损失的情况下反向传播学习增量层的权值。在全局服务器更新共享模型时,将全连接层添加至卷积层后,以保留各客户端本地个性化模型。
在一些实施例中,基于各客户端与全局服务器构建图像分类系统,图像分类系统通过最小化总损失函数以定义目标函数,该目标函数的计算式如公式(7)所示:
公式(7)中,Ltotal(ω)表示目标函数;D表示所有客户端本地数据集集合;k表示所有客户端K中的一个客户端;Dk表示客户端k的本地数据集;Lk(ω)表示客户端k的损失函数。
其中,对于每个客户端本地的图像分类模型,其损失函数如公式(8)所示:
Lk(ω)=η1Lcla(Xs,Xt;ω)+η2Ltra(Xs,Xt;ω); (8)
公式(8)中,Lk(ω)表示客户端k的损失函数;η1和η2用于平衡多个优化目标;Lcla表示总的图像分类损失;Ltra表示联合损失;Xs表示客户端k的本地数据集中的源数据;Xt表示客户端k的本地数据集中的目标数据;ω表示初始神经网络模型参数。
其中,Lcla用于提高模型分类器的性能,Ltra用于混淆域以更好的学习目标数据的表示,η1和η2用于平衡多个优化目标。
本发明还提供一种基于任务迁移的联邦无监督图像分类方法,该方法在客户端执行,包括以下步骤S201~S202:
步骤S201:获取待分类的图像。
步骤S202:将图像输入如上文中所述基于任务迁移的联邦无监督图像分类模型训练方法得到的图像分类模型,以得到图像的类别。
本发明还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现基于任务迁移的联邦无监督图像分类模型训练方法和基于任务迁移的联邦无监督图像分类方法的步骤。
与上述方法相应地,本发明还提供了一种设备,该设备包括计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有计算机指令,所述处理器用于执行所述存储器中存储的计算机指令,当所述计算机指令被处理器执行时该设备实现如前所述方法的步骤。
本发明实施例还提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时以实现前述边缘计算服务器部署方法的步骤。该计算机可读存储介质可以是有形存储介质,诸如随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、软盘、硬盘、可移动存储盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质。
综上所述,本发明提供一种基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备,包括:获取包含完整标签和部分标签的本地数据集,构建源域和目标域。获取初始神经网络模型,包括自适应增量层和深度迁移模块。自适应增量层为在初始神经网络模型的每个卷积层后添加一个全连接层,在模型更新时保留客户端的个性化局部模型,提升模型泛化能力。在深度迁移模块中构建域分类和域混淆的竞争机制,并采用知识蒸馏方法实现完整标签数据和部分标签数据间的域混淆,降低源域与目标域间的距离,实现部分标签数据的训练,并设计软标签来调整类别间的信息,实现任务迁移,对客户端有效进行半监督学习。采用本地数据集为模型进行训练,构建域分类损失、域混淆损失和软标签损失的联合损失,并引入动态任务损失来自动调整任务之间的权重,确定任务优先级,最终训练得到初始图像分类模型。基于各客户端模型参数构建共享模型,并根据共享模型参数更新初始图像分类模型。基于本发明提供的方法训练得到的图像分类模型精度高且能保留个性化局部模型。
本领域普通技术人员应该可以明白,结合本文中所公开的实施方式描述的各示例性的组成部分、系统和方法,能够以硬件、软件或者二者的结合来实现。具体究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。当以硬件方式实现时,其可以例如是电子电路、专用集成电路(ASIC)、适当的固件、插件、功能卡等等。当以软件方式实现时,本发明的元素是被用于执行所需任务的程序或者代码段。程序或者代码段可以存储在机器可读介质中,或者通过载波中携带的数据信号在传输介质或者通信链路上传送。
需要明确的是,本发明并不局限于上文所描述并在图中示出的特定配置和处理。为了简明起见,这里省略了对已知方法的详细描述。在上述实施例中,描述和示出了若干具体的步骤作为示例。但是,本发明的方法过程并不限于所描述和示出的具体步骤,本领域的技术人员可以在领会本发明的精神后,作出各种改变、修改和添加,或者改变步骤之间的顺序。
本发明中,针对一个实施方式描述和/或例示的特征,可以在一个或更多个其它实施方式中以相同方式或以类似方式使用,和/或与其他实施方式的特征相结合或代替其他实施方式的特征。
以上所述仅为本发明的优选实施例,并不用于限制本发明,对于本领域的技术人员来说,本发明实施例可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于任务迁移的联邦无监督图像分类模型训练方法,其特征在于,所述方法在各客户端执行,包括以下步骤:
获取本地数据集,所述本地数据集包括含完整类别标签的源数据和含部分类别标签的目标数据,每个数据包含一张图像样本;
获取初始神经网络模型,所述初始神经网络模型包括自适应增量层和深度迁移模块;所述自适应增量层为在所述初始神经网络模型的每个卷积层后添加一个全连接层;将所述本地数据集的图像样本按批次输入所述初始神经网络模型进行特征提取,利用预设域分类器判别相应图像样本属于所述源数据或所述目标数据,并利用预设预混淆层通过域混淆对齐域,构建域混淆竞争机制;采用知识蒸馏方法利用所述源数据计算每个类别之间的关系数值,并求取各关系数值的平均值作为与相应源数据具备相关性的目标数据的软标签,以输出相应图像样本的类别;
采用所述本地数据集对所述初始神经网络进行训练,构建域分类损失、域混淆损失和软标签损失,根据所述域分类损失、所述域混淆损失和所述软标签损失构建联合损失,并计算各分类任务的平均精度,根据所述平均精度确定每类任务在损失函数中的权重,利用所述联合损失对所述初始神经网络模型的参数进行迭代,得到初始图像分类模型;
将所述初始图像分类模型的模型参数发送至全局服务器,以生成共享模型;所述共享模型由所述全局服务器根据各客户端初始图像分类模型参数加权聚合得到;接收所述共享模型的参数,基于所述自适应增量层更新所述初始图像分类模型,以得到最终的图像分类模型。
8.根据权利要求7所述的基于任务迁移的联邦无监督图像分类模型训练方法,其特征在于,各客户端与所述全局服务器构建图像分类系统,所述系统通过最小化总损失函数以定义目标函数,所述目标函数的计算式为:
Lk(ω)=η1Lcla(Xs,Xt;ω)+η2Ltra(Xs,Xt;ω);
其中,Ltotal(ω)表示所述目标函数;D表示所有客户端本地数据集集合;k表示所有客户端K中的一个客户端;Dk表示客户端k的本地数据集;Lk(ω)表示客户端k的损失函数;η1和η2用于平衡多个优化目标;Lcla表示总的图像分类损失;Ltra表示所述联合损失;Xs表示客户端k的本地数据集中的源数据;Xt表示客户端k的本地数据集中的目标数据;ω表示所述初始神经网络模型参数。
9.一种基于任务迁移的联邦无监督图像分类方法,其特征在于,该方法在客户端执行,包括以下步骤:
获取待分类的图像;
将所述图像输入如权利要求1至8中任一项所述基于任务迁移的联邦无监督图像分类模型训练方法得到的图像分类模型,以得到所述图像的类别。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如权利要求1至9中任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310199005.0A CN116229170A (zh) | 2023-03-03 | 2023-03-03 | 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310199005.0A CN116229170A (zh) | 2023-03-03 | 2023-03-03 | 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116229170A true CN116229170A (zh) | 2023-06-06 |
Family
ID=86582159
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310199005.0A Pending CN116229170A (zh) | 2023-03-03 | 2023-03-03 | 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116229170A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117130790A (zh) * | 2023-10-23 | 2023-11-28 | 云南蓝队云计算有限公司 | 一种云计算资源池动态调度方法 |
CN117811846A (zh) * | 2024-02-29 | 2024-04-02 | 浪潮电子信息产业股份有限公司 | 基于分布式系统的网络安全检测方法、系统、设备及介质 |
-
2023
- 2023-03-03 CN CN202310199005.0A patent/CN116229170A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117130790A (zh) * | 2023-10-23 | 2023-11-28 | 云南蓝队云计算有限公司 | 一种云计算资源池动态调度方法 |
CN117130790B (zh) * | 2023-10-23 | 2023-12-29 | 云南蓝队云计算有限公司 | 一种云计算资源池动态调度方法 |
CN117811846A (zh) * | 2024-02-29 | 2024-04-02 | 浪潮电子信息产业股份有限公司 | 基于分布式系统的网络安全检测方法、系统、设备及介质 |
CN117811846B (zh) * | 2024-02-29 | 2024-05-28 | 浪潮电子信息产业股份有限公司 | 基于分布式系统的网络安全检测方法、系统、设备及介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109840531B (zh) | 训练多标签分类模型的方法和装置 | |
CN116229170A (zh) | 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备 | |
CN109919183B (zh) | 一种基于小样本的图像识别方法、装置、设备及存储介质 | |
CN111444966A (zh) | 媒体信息分类方法及装置 | |
CN112115967B (zh) | 一种基于数据保护的图像增量学习方法 | |
CN115731424B (zh) | 基于强化联邦域泛化的图像分类模型训练方法及系统 | |
WO2022252458A1 (zh) | 一种分类模型训练方法、装置、设备及介质 | |
CN112288572B (zh) | 业务数据处理方法及计算机设备 | |
CN110598869B (zh) | 基于序列模型的分类方法、装置、电子设备 | |
CN116310530A (zh) | 基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备 | |
Dinov et al. | Black box machine-learning methods: Neural networks and support vector machines | |
CN112270334B (zh) | 一种基于异常点暴露的少样本图像分类方法及系统 | |
KR102093079B1 (ko) | 레이블 데이터를 이용한 생성적 적대 신경망 기반의 분류 시스템 및 방법 | |
CN113591892A (zh) | 一种训练数据的处理方法及装置 | |
CN113570512A (zh) | 一种图像数据处理方法、计算机及可读存储介质 | |
CN115063374A (zh) | 模型训练、人脸图像质量评分方法、电子设备及存储介质 | |
CN114676755A (zh) | 基于图卷积网络的无监督域自适应的分类方法 | |
CN114861936A (zh) | 一种基于特征原型的联邦增量学习方法 | |
CN114255381A (zh) | 图像识别模型的训练方法、图像识别方法、装置及介质 | |
Suyal et al. | An Agile Review of Machine Learning Technique | |
CN113420879A (zh) | 多任务学习模型的预测方法及装置 | |
CN112861601A (zh) | 生成对抗样本的方法及相关设备 | |
CN116229172A (zh) | 基于对比学习的联邦少样本图像分类模型训练方法、分类方法及设备 | |
CN115081626B (zh) | 基于表征学习的个性化联邦少样本学习系统及方法 | |
KR102093090B1 (ko) | 레이블 데이터를 이용한 생성적 적대 신경망 기반의 분류 시스템 및 방법 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |