CN113673555B - 一种基于记忆体的无监督域适应图片分类方法 - Google Patents

一种基于记忆体的无监督域适应图片分类方法 Download PDF

Info

Publication number
CN113673555B
CN113673555B CN202110776679.3A CN202110776679A CN113673555B CN 113673555 B CN113673555 B CN 113673555B CN 202110776679 A CN202110776679 A CN 202110776679A CN 113673555 B CN113673555 B CN 113673555B
Authority
CN
China
Prior art keywords
domain
target domain
memory
sample
class
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110776679.3A
Other languages
English (en)
Other versions
CN113673555A (zh
Inventor
李玺
郑良立
汪慧
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU filed Critical Zhejiang University ZJU
Priority to CN202110776679.3A priority Critical patent/CN113673555B/zh
Publication of CN113673555A publication Critical patent/CN113673555A/zh
Application granted granted Critical
Publication of CN113673555B publication Critical patent/CN113673555B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/23Clustering techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computational Linguistics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种基于记忆体的无监督域适应图片分类方法,用于在给定有标签的源域数据集和无标签的目标域数据集上,通过记忆体对齐源域和目标域的分布,将源域数据集的知识迁移到目标域数据集上,在目标域数据集上获得较高的图像分类准确率。具体包括如下步骤:获取源域数据集和目标域数据集;用神经网络模型提取数据集中图片的特征,使用聚类算法辅助记忆体逐类别地存储源域和目标域的特征;训练神经网络,以源域与目标域记忆体的分布的相似性作为条件约束神经网络;不断迭代,得到训练好的网络模型;将模型应用在目标域数据集上,进行图像分类任务。本发明适用于无监督域适应领域中的知识迁移,面对各类复杂的情况具有较佳的效果和鲁棒性。

Description

一种基于记忆体的无监督域适应图片分类方法
技术领域
本发明属于无监督域适应领域,特别地涉及一种基于记忆体的无监督域适应图片分类方法。
背景技术
无监督域适应被定义为如下问题:在给定有标签的源域数据集和无标签的目标域数据集的情况下,将有标签的源域数据集的知识迁移到无标签的目标域数据集上。这类任务可以有效地减轻深度学习训练过程中对有标签数据的需求,从而减少可以减少标签的手工标注的成本。该任务主要有两个关键点:第一是如何将源域数据集的知识迁移到目标域数据集上;第二是如何对目标域的无标签数据集的内在关系进行建模从而更好的利用迁移过来的知识。针对第一点,本发明认为在迁移过程中,无监督域适应任务不仅需要将源域的知识迁移到目标域上,而且应该在迁移的过程中减少知识受到的外部干扰;针对第二点,本发明认为即使是在没有标签的困难场景,数据集内部仍然存在着固有的内部关系,这种关系对更好地利用源域迁移过来是必要的。传统的方法一般关注的是知识的迁移,而没有更深层次的考虑迁移过程中的知识的抗干扰性以及对目标域数据集的运用,这在本任务中是非常重要的。
由于对抗学习的成功,目前基于对抗的方法逐渐被应用到无监督域适应领域中。现有的对抗方法主要是分别输入源域的一组图片和目标域的一组图片,得到两者的特征,并用对抗的方式让两组特征对抗,从而使得源域特征逼近目标域特征。然而,这类方法没有考虑对抗过程中,特征的内在分布会受到干扰,影响最终的迁移效果。
发明内容
为解决上述问题,本发明的目的在于提供一种基于记忆体的无监督域适应图片分类方法。该方法基于神经网络,目标是在无监督域适应的迁移过程中保证类内结构的紧致性。在无监督域适应中,类内的图片存在相互的关系,例如同类的图片由于有着相似的属性,颜色,形状,对比度等关联信息,其对应特征与同类图片的距离一般小于其与异类图片的特征的距离。针对这个发现,我们的工作设计了一个统一的端到端的深度学习框架对目标域的特征的类内结构进行的建模,并以此作为约束保持了无监督域适应中迁移的类内结构的紧致性,从而使得到的模型更具准确性和鲁棒性,最终在目标域数据集的图片分类任务上获取较高的准确率。
为实现上述目的,本发明的技术方案为:
一种基于记忆体的无监督域适应图片分类方法,其包括以下步骤:
S1、获取用于训练的有标签的源域数据集以及无标签的目标域数据集,所述源域数据集和目标域数据集均为图片数据集;
S2、用神经网络提取数据集中每张图片的特征,并根据聚类算法构建提取到的特征的类内结构;
S3、将源域和目标域每个类别的特征分别储存到源域和目标域对应类别的记忆体中;
S4、训练神经网络,并在训练过程中以源域和目标域记忆体的分布相似性作为条件约束神经网络;
S5、在完成由S2~S4构成的一轮更新训练后,利用训练过的模型重新提取每张图片的特征以及特征的类内结构,再根据新提取的特征更新记忆体,并以此源域和目标域记忆体的分布相似性作为条件约束进一步训练神经网络,完成新一轮的更新训练;
S6、不断重复步骤S5对神经网络进行迭代更新训练,直至网络收敛,得到最终的训练好的神经网络;
S7、在获得训练好的神经网络后,使用训练好的神经网络在目标域数据集上进行图像分类。
进一步地,步骤S1的具体实现步骤包括:
S11、获取包含ns个图片样本xs以及它们对应的标签ys的源域数据集
其中,表示源域数据集的第i个图片样本,/>表示样本/>的标签, {1,2,...,K}是源域数据集中样本所属的标签空间,共包含K类标签,/> 且i∈{1,2,...,ns};
S12、获取包含nt个图片样本xt但不含标签的目标域数据集
其中,表示目标域数据集的第j个图片样本,j∈{1,2,...,nt};目标域数据集中样本所属的标签空间和源域数据集的标签空间一致,即/>
进一步地,步骤S2的具体实现步骤包括:
S21、用一个神经网络的特征提取模块提取数据集每张图片的特征:
其中,是神经网络的特征提取模块,/>是特征提取模块随机初始化后的参数,/>是源域数据集的第i个图片样本的特征,/>是目标域数据集的第j个图片样本的特征;
S22、计算出源域中每个类别所有图片的特征的均值并用其初始化目标域每个类团/>的中心/>
其中,是源域第k类样本的数量,/>是源域第k类样本的特征的中心,/>是初始目标域特征的第k个类团/>的中心;
S23、计算每个目标类团的中心/>与每个目标样本特征的球面空间距离:
其中,‖·‖代表内部变量的模,<·,·>代表两个变量的向量点积;
S24、针对每个目标域图片样本特征将其按照距离/>进行排序后归类于距离最近的类团,所有目标域图片样本特征均归类后重新计算出每个类团/>的中心
其中,代表重新归类后属于类团/>的特征数量;
S25、不断交替迭代S23和S24的聚类算法,收敛后得到K个类团这K个类团代表目标域数据集的类内结构。
进一步地,步骤S3的具体实现步骤包括:
S31、将每个类团的类别k作为属于该类团的样本/>的标签/>
S32、从源域和目标域每个类团中各自抽取一部分特征分别装入源域和目标域的记忆体中:
其中,N为记忆体对应的长度,i∈{1,2,...,N},为源域记忆体第k类的第i个特征,/>为目标域记忆体第k类的第i个特征,/>为源域的第k个类团/>中的第i个特征,/>为目标域的第k个类团/>中的第i个特征。
进一步地,步骤S4中的具体实现步骤包括:
S41、通过优化第一loss函数l1(·,·),得到神经网络的特征提取模块和源域分类器模块/>在源域数据集/>上的最优参数/>和/>
S42、对于每一个目标域图片样本求得其特征:
其中r为限制特征ft的系数;通过目标域记忆体Mt重新预测的类别:
其中指目标域的第k类记忆体,d(·,·)表示计算L2距离;对于每一个样本/>若/>与此样本对应的类团的类别相同,则将此样本视为可靠样本,按照先进后出的原则将此样本的特征加入目标域的第k类记忆体/>中;
S43、通过优化第二loss函数l2(·,·,·)使得一个可反向传播的神经网络替代分类器模块学习到不可反向传播的聚类得到的类内结构:
其中,是目标域数据集/>的第k类可靠样本;/>是目标域记忆体第k类样本的中心:
是目标域记忆体第k类以外第y类样本的中心:
其中,代表目标域的第k类记忆体/>中的一个样本特征,/>代表目标域的第y类记忆体/>中的一个样本特征,/>为目标域的第k类记忆体/>中样本特征的数量,为标域的第y类记忆体/>中样本特征的数量;
S44、通过优化第三loss函数l3(·,·)提升源域和目标域记忆体分布的相似性:
其中,l3(·,·)为衡量分布差异的函数,计算公式为:
其中,分别为源域记忆体第i,j类特征的集合,/>为目标域记忆体第i,j类特征的集合,kernel为核函数。
进一步地,所述第一loss函数l1(·,·)为交叉熵损失函数,所述第二loss函数 l2(·,·,·)为Triple损失。
进一步地,所述核函数kernel的计算公式为:
其中N′为使用的核函数的数量,γn为:
进一步地,步骤S5中,第n轮更新训练的具体实现步骤包括:
S51、以第n-1轮更新训练得到的神经网络特征提取模块为基础,按照S2步骤的操作,重新提取特征并得到对应的源域数据集的类团/>和目标域数据集的类团/>完成第n轮的特征类内结构构造;
S52、按照S3步骤的操作,将第n轮特征类内结构构造得到的源域和目标域类团中的特征分别装入源域和目标域的记忆体中,完成第n轮的记忆体初始化;
S53、以第n-1轮更新训练得到的神经网络特征提取模块以及分类器模块/>为基础,按照S4步骤的操作进行第n轮的以源域和目标域记忆体的分布相似性为条件约束的神经网络训练,得到第n轮更新训练后的神经网络参数/>和/>
本发明的基于记忆体的无监督域适应图片分类方法,相比于现有的无监督域适应图片分类方法,具有以下有益效果:
首先,本发明的无监督域适应图片分类方法定义了无监督域适应中两个重要的问题:1.训练过程中的批是通过随机采样得到的,数据的类别分布会不平衡;2.聚类等无监督方法会引入噪声,影响模型的学习效果。通过寻求这两个方向的解决方法,可以有效地提高的无监督域适应的优化效果,提升了目标域数据集上的图片分类准确度。
其次,本发明的基于记忆体的无监督域适应图片分类方法基于无监督域适应的特点建立优化流程。在基于记忆体的无监督域适应图片分类方法的优化体系中,使用记忆体对源域和目标域的特征分布进行建模,并逐类别地对齐源域和目标域的分布,充分利用了目标域不同图片对应特征的内在联系,有效地提高神经网络模型迁移后的效果,提升了目标域数据集上的图片分类准确度。
最后,本发明的基于记忆体的无监督域适应图片分类方法使用了自步机制选择样本,提高了目标域数据集上图片分类任务的鲁棒性。
本发明的基于记忆体的无监督域适应图片分类方法,实现简单,适用范围广,具有良好的应用价值。本发明的基于记忆体的无监督域适应图片分类方法,能够有效减少神经网络模型的迁移时收到的干扰和提高神经网络模型迁移后在目标域数据集上的图片分类任务效果。
附图说明
图1为本发明的流程示意图;
图2为本发明提出的更新训练过程的框架的示意图。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
相反,本发明涵盖任何由权利要求定义的在本发明的精髓和范围上做的替代、修改、等效方法以及方案。进一步,为了使公众对本发明有更好的了解,在下文对本发明的细节描述中,详尽描述了一些特定的细节部分。对本领域技术人员来说没有这些细节部分的描述也可以完全理解本发明。
参考图1,在本发明的较佳实施例中,提供了一种基于记忆体的无监督域适应图片分类方法,该方法用于在给定有标签的源域数据集和无标签的目标域数据集的情况下,将有标签的源域数据集的知识迁移到无标签的目标域数据集上,并在迁移的过程中保证源域和目标域上每个类的分布对齐,以达到提高目标域上图片分类准确度的目的。该方法包括以下步骤:
S1、获取用于训练的有标签的源域数据集以及无标签的目标域数据集,所述源域数据集和目标域数据集均为图片数据集。本步骤的具体实现步骤包括:
S11、获取包含ns个图片样本xs以及它们对应的标签ys的源域数据集
其中,表示源域数据集的第i个图片样本,/>表示样本/>的标签, {1,2,...,K}是源域数据集中样本所属的标签空间,共包含K类标签,/> 且i∈{1,2,...,ns};
S12、获取包含nt个图片样本xt但不含标签的目标域数据集
其中,表示目标域数据集的第j个图片样本,j∈{1,2,...,nt};目标域数据集中样本所属的标签空间和源域数据集的标签空间一致,即/>
本发明的算法目标为:通过算法训练神经网络,使其能够为无标签的目标域数据集的每个样本预测对应的标签。
本发明中的神经网络含有特征提取模块g(·,θg)以及分类器模块f(·,θf),特征提取模块提取图片的特征后送入分类器模块进行分类,θg和θf分别为特征提取模块以及分类器模块的模块参数,其初始值为和/>神经网络的具体形式不限,在后续的实施例中采用了两种复杂神经网络,分别是ResNet-50,和 ResNet-101。当然也可以采用其他具有特征提取模块以及分类器模块的神经网络。
S2、用神经网络提取数据集中每张图片的特征,并根据聚类算法构建提取到的特征的类内结构。本步骤的具体实现步骤包括:
S21、用一个神经网络的特征提取模块提取数据集每张图片的特征:
其中,是神经网络的特征提取模块,/>是特征提取模块随机初始化后的参数,/>是源域数据集的第i个图片样本的特征,/>是目标域数据集的第j个图片样本的特征;
S22、计算出源域中每个类别所有图片的特征的均值并用其初始化目标域每个类团/>的中心/>
其中,是源域第k类样本的数量,/>是源域第k类样本的特征的中心,/>是初始目标域特征的第k个类团/>的中心;
S23、计算每个目标类团的中心/>与每个目标样本特征的球面空间距离:
其中,‖·‖代表内部变量的模,<·,·>代表两个变量的向量点积;
S24、针对每个目标域图片样本特征将其按照距离/>进行排序后归类于距离最近的类团,所有目标域图片样本特征均归类后重新计算出每个类团/>的中心
其中,代表重新归类后属于类团/>的特征数量;
S25、不断交替迭代S23和S24的聚类算法,收敛后得到K个类团这K个类团代表目标域数据集的类内结构。
S3、将源域和目标域每个类别的特征分别储存到源域和目标域对应类别的记忆体中。本步骤的具体实现步骤包括:
S31、将每个类团的类别k作为属于该类团的样本/>的标签/>
S32、从源域和目标域每个类团中各自抽取一部分特征分别装入源域和目标域的记忆体中:
其中,N为记忆体对应的长度,i∈{1,2,...,N},为源域记忆体第k类的第i个特征,/>为目标域记忆体第k类的第i个特征,/>为源域的第k个类团/>中的第i个特征,/>为目标域的第k个类团/>中的第i个特征。
S4、训练神经网络,并在训练过程中以源域和目标域记忆体的分布相似性作为条件约束神经网络。本步骤的具体实现步骤包括:
S41、通过优化第一loss函数l1(·,·),得到神经网络的特征提取模块和源域分类器模块/>在源域数据集/>上的最优参数/>和/>
式中:第一loss函数l1(·,·)为交叉熵损失函数;
S42、对于每一个目标域图片样本求得其特征:
其中r为限制特征ft的系数;通过目标域记忆体Mt重新预测的类别:
其中指目标域的第k类记忆体,d(·,·)表示计算L2距离;对于每一个样本/>若/>与此样本对应的类团的类别相同,则将此样本视为可靠样本,按照先进后出的原则将此样本的特征加入目标域的第k类记忆体/>中;
S43、通过优化第二loss函数l2(·,·,·)使得一个可反向传播的神经网络替代分类器模块学习到不可反向传播的聚类得到的类内结构:
其中,第二loss函数l2(·,·,·)为Triple损失,是目标域数据集/>的第k类可靠样本;/>是目标域记忆体第k类样本的中心:
是目标域记忆体第k类以外第y类样本的中心:
其中,代表目标域的第k类记忆体/>中的一个样本特征,/>代表目标域的第y类记忆体/>中的一个样本特征,/>为目标域的第k类记忆体/>中样本特征的数量,为标域的第y类记忆体/>中样本特征的数量;
S44、通过优化第三loss函数l3(·,·)提升源域和目标域记忆体分布的相似性:
其中,l3(·,·)为衡量分布差异的函数,计算公式为:
其中,分别为源域记忆体第i,j类特征的集合,/>为目标域记忆体第i,j类特征的集合,kernel为核函数。
上述核函数kernel的计算公式为:
其中N′为使用的核函数的数量,γn为:
S5、在完成由S2~S4构成的一轮更新训练后,利用训练过的模型重新提取每张图片的特征以及特征的类内结构,再根据新提取的特征更新记忆体,并以此源域和目标域记忆体的分布相似性作为条件约束进一步训练神经网络,完成新一轮的更新训练。本步骤的具体实现步骤包括:
当前的更新训练轮数记为n,n≥2,则第n轮更新训练的过程如下:
S51、以第n-1轮更新训练得到的神经网络特征提取模块为基础,按照S2步骤的操作,重新提取特征并得到对应的源域数据集的类团/>和目标域数据集的类团/>完成第n轮的特征类内结构构造;
S52、按照S3步骤的操作,将第n轮特征类内结构构造得到的源域和目标域类团中的特征分别装入源域和目标域的记忆体中,完成第n轮的记忆体初始化;
S53、以第n-1轮更新训练得到的神经网络特征提取模块以及分类器模块/>为基础,按照S4步骤的操作进行第n轮的以源域和目标域记忆体的分布相似性为条件约束的神经网络训练,得到第n轮更新训练后的神经网络参数/>和/>
上述第n轮更新训练的过程本质上就是利用前一轮更新参数和/>后的神经网络去重复S2~S4,每一轮的具体做法与前述S2~S4基本相同,仅更新模型参数即可。上述更新训练过程的框架如图2所示。
S6、不断重复步骤S5对神经网络进行迭代更新训练,每一轮更新训练均需要如前所述进行特征提取,特征类内结构构造、更新记忆体以及以两个域上记忆体的相似性为条件训练神经网络,直至网络收敛后停止迭代,得到最终的训练好的神经网络模型。
在该迭代过程中,在最后一轮训练中,以上一个阶段的最优的神经网络的特征提取模块和分类器模块/>为训练的基础,重复S5步骤的操作,得到最优的神经网络的/>和分类器模块/>及他们对应的最优参数/>和/>
S7、在获得训练好的神经网络模型后,使用训练好的神经网络模型在目标域的图像数据集上进行图像分类。在本步骤中,步骤S6完成后即得到了最优的神经网络的特征提取模块和分类器模块/>及他们对应的最优参数/>和/>使用最优参数下的特征提取模块/>和分类器模块/>在目标域上/>进行分类任务。
为了判断本发明方法的分类准确性,通过以下公式计算分类准确率:
其中,1[·]表示其中的条件成立时函数值取1,否则取0;表示目标域数据集中的第i张图片,/>表示目标域数据集中的第i张图片对应的标签;/>表示使用最优参数下的特征提取模块/>和分类器模块/>对/>进行分类任务得到的标签。
下面将上述方法应用于具体数据集中,以展示其技术效果。
实施例
下面基于上述方法进行仿真实验,本实施例的实现方法如前S1~S7所述,不再详细阐述具体的步骤,下面仅针对实验结果展示其结果。
本实施例使用了两种复杂网络,分别是ResNet-50,和ResNet-101。并在无监督域适应任务的三大数据集Office-31、Office-Home、VisDA-2017数据集上实施多次重复训练实验,证明了本方法可以有效地提高无监督域适应的效果。同时本发明还设置了无监督域适应图片分类传统方法作为对照。
表1本发明方法在Office-31、Office-Home、VisDA-2017数据集上的实施效果
上述实施例中,本发明的基于记忆体的无监督域适应图片分类方法首先使用记忆体对源域和目标域的特征分布进行建模,并逐类别地对齐源域和目标域的分布,充分利用了目标域不同图片对应特征的内在联系;在此基础上,使用了自步机制选择样本,提高了无监督域适应过程的鲁棒性。从结构看出,本发明的优化方法相比于传统方法能够明显提高无监督域适应效果,本方法优化后神经网络对于目标域数据集的图片分类任务的测试准确率进一步提升。
通过以上技术方案,本发明实施例基于记忆体的无监督域适应图片分类方法。本发明可以将原本的无监督域适应转化为以源域和目标域记忆体的分布的相似性为约束条件的无监督域适应,从而提高神经网络优化效果,提升目标域上图片分类任务的准确率。本发明适用于无监督域适应中的从有标签的源域数据集迁移到无标签的目标域数据集的迁移学习任务,面对各类复杂的情况具有较佳的效果和鲁棒性。
以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明的保护范围之内。

Claims (4)

1.一种基于记忆体的无监督域适应图片分类方法,其特征在于,包括以下步骤:
S1、获取用于训练的有标签的源域数据集以及无标签的目标域数据集,所述源域数据集和目标域数据集均为图片数据集;
S2、用神经网络提取数据集中每张图片的特征,并根据聚类算法构建提取到的特征的类内结构;
S3、将源域和目标域每个类别的特征分别储存到源域和目标域对应类别的记忆体中;
S4、训练神经网络,并在训练过程中以源域和目标域记忆体的分布相似性作为条件约束神经网络;
S5、在完成由S2~S4构成的一轮更新训练后,利用训练过的模型重新提取每张图片的特征以及特征的类内结构,再根据新提取的特征更新记忆体,并以此源域和目标域记忆体的分布相似性作为条件约束进一步训练神经网络,完成新一轮的更新训练;
S6、不断重复步骤S5对神经网络进行迭代更新训练,直至网络收敛,得到最终的训练好的神经网络;
S7、在获得训练好的神经网络后,使用训练好的神经网络在目标域数据集上进行图像分类;
步骤S1的具体实现步骤包括:
S11、获取包含ns个图片样本xs以及它们对应的标签ys的源域数据集
其中,表示源域数据集的第i个图片样本,/>表示样本/>的标签,{1,2,…,K}是源域数据集中样本所属的标签空间,共包含K类标签,/>且i∈{1,2,…,ns};
S12、获取包含nt个图片样本xt但不含标签的目标域数据集
其中,表示目标域数据集的第j个图片样本,j∈{1,2,…,nt};目标域数据集中样本所属的标签空间和源域数据集的标签空间一致,即/>
步骤S2的具体实现步骤包括:
S21、用一个神经网络的特征提取模块提取数据集每张图片的特征:
其中,是神经网络的特征提取模块,/>是特征提取模块随机初始化后的参数,是源域数据集的第i个图片样本的特征,/>是目标域数据集的第j个图片样本的特征;
S22、计算出源域中每个类别所有图片的特征的均值并用其初始化目标域每个类团的中心/>
其中,是源域第k类样本的数量,/>是源域第k类样本的特征的中心,/>是初始目标域特征的第k个类团/>的中心;
S23、计算每个目标类团的中心/>与每个目标样本特征的球面空间距离:
其中,‖·‖代表内部变量的模,<·,·>代表两个变量的向量点积;
S24、针对每个目标域图片样本特征将其按照距离/>进行排序后归类于距离最近的类团,所有目标域图片样本特征均归类后重新计算出每个类团/>的中心/>
其中,代表重新归类后属于类团/>的特征数量;
S25、不断交替迭代S23和S24的聚类算法,收敛后得到K个类团这K个类团代表目标域数据集的类内结构;
步骤S3的具体实现步骤包括:
S31、将每个类团的类别k作为属于该类团的样本/>的标签/>
S32、从源域和目标域每个类团中各自抽取一部分特征分别装入源域和目标域的记忆体中:
其中,N为记忆体对应的长度,i∈{1,2,…,N},为源域记忆体第k类的第i个特征,为目标域记忆体第k类的第i个特征,/>为源域的第k个类团/>中的第i个特征,/>为目标域的第k个类团/>中的第i个特征;
步骤S4中的具体实现步骤包括:
S41、通过优化第一loss函数l1(·,·),得到神经网络的特征提取模块和源域分类器模块/>在源域数据集/>上的最优参数/>和/>
S42、对于每一个目标域图片样本求得其特征:
其中r为限制特征ft的系数;通过目标域记忆体Mt重新预测的类别:
其中指目标域的第k类记忆体,d(·,·)表示计算L2距离;对于每一个样本/>与此样本对应的类团的类别相同,则将此样本视为可靠样本,按照先进后出的原则将此样本的特征加入目标域的第k类记忆体/>中;
S43、通过优化第二loss函数l2(·,·,·)使得一个可反向传播的神经网络替代分类器模块学习到不可反向传播的聚类得到的类内结构:
其中,是目标域数据集/>的第k类可靠样本;/>是目标域记忆体第k类样本的中心:
是目标域记忆体第k类以外第y类样本的中心:
其中,代表目标域的第k类记忆体/>中的一个样本特征,/>代表目标域的第y类记忆体/>中的一个样本特征,/>为目标域的第k类记忆体/>中样本特征的数量,/>为标域的第y类记忆体/>中样本特征的数量;
S44、通过优化第三loss函数l3(·,·)提升源域和目标域记忆体分布的相似性:
其中,l3(·,·)为衡量分布差异的函数,计算公式为:
其中,分别为源域记忆体第i,j类特征的集合,/>为目标域记忆体第i,j类特征的集合,kernel为核函数。
2.如权利要求1所述的基于记忆体的无监督域适应图片分类方法,其特征在于,所述第一loss函数l1(·,·)为交叉熵损失函数,所述第二loss函数l2(·,·,·)为Triple损失。
3.如权利要求1所述的基于记忆体的无监督域适应图片分类方法,其特征在于,所述核函数kernel的计算公式为:
其中N′为使用的核函数的数量,γn为:
4.如权利要求1所述的基于记忆体的无监督域适应图片分类方法,其特征在于,步骤S5中,第n轮更新训练的具体实现步骤包括:
S51、以第n-1轮更新训练得到的神经网络特征提取模块为基础,按照S2步骤的操作,重新提取特征并得到对应的源域数据集的类团/>和目标域数据集的类团/>完成第n轮的特征类内结构构造;
S52、按照S3步骤的操作,将第n轮特征类内结构构造得到的源域和目标域类团中的特征分别装入源域和目标域的记忆体中,完成第n轮的记忆体初始化;
S53、以第n-1轮更新训练得到的神经网络特征提取模块以及分类器模块为基础,按照S4步骤的操作进行第n轮的以源域和目标域记忆体的分布相似性为条件约束的神经网络训练,得到第n轮更新训练后的神经网络参数/>和/>
CN202110776679.3A 2021-07-09 2021-07-09 一种基于记忆体的无监督域适应图片分类方法 Active CN113673555B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110776679.3A CN113673555B (zh) 2021-07-09 2021-07-09 一种基于记忆体的无监督域适应图片分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110776679.3A CN113673555B (zh) 2021-07-09 2021-07-09 一种基于记忆体的无监督域适应图片分类方法

Publications (2)

Publication Number Publication Date
CN113673555A CN113673555A (zh) 2021-11-19
CN113673555B true CN113673555B (zh) 2023-12-12

Family

ID=78539034

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110776679.3A Active CN113673555B (zh) 2021-07-09 2021-07-09 一种基于记忆体的无监督域适应图片分类方法

Country Status (1)

Country Link
CN (1) CN113673555B (zh)

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN105404902A (zh) * 2015-10-27 2016-03-16 清华大学 基于脉冲神经网络的图像特征描述和记忆方法
CN111832605A (zh) * 2020-05-22 2020-10-27 北京嘀嘀无限科技发展有限公司 无监督图像分类模型的训练方法、装置和电子设备
CN111931814A (zh) * 2020-07-03 2020-11-13 浙江大学 一种基于类内结构紧致约束的无监督对抗域适应方法
CN112396078A (zh) * 2019-08-16 2021-02-23 中国移动通信有限公司研究院 一种服务分类方法、装置、设备及计算机可读存储介质
WO2021057427A1 (zh) * 2019-09-25 2021-04-01 西安交通大学 一种基于PU learning的跨区域企业偷漏税识别方法及系统
CN113011456A (zh) * 2021-02-05 2021-06-22 中国科学技术大学 用于图像分类的基于类别自适应模型的无监督域适应方法

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN105404902A (zh) * 2015-10-27 2016-03-16 清华大学 基于脉冲神经网络的图像特征描述和记忆方法
CN112396078A (zh) * 2019-08-16 2021-02-23 中国移动通信有限公司研究院 一种服务分类方法、装置、设备及计算机可读存储介质
WO2021057427A1 (zh) * 2019-09-25 2021-04-01 西安交通大学 一种基于PU learning的跨区域企业偷漏税识别方法及系统
CN111832605A (zh) * 2020-05-22 2020-10-27 北京嘀嘀无限科技发展有限公司 无监督图像分类模型的训练方法、装置和电子设备
CN111931814A (zh) * 2020-07-03 2020-11-13 浙江大学 一种基于类内结构紧致约束的无监督对抗域适应方法
CN113011456A (zh) * 2021-02-05 2021-06-22 中国科学技术大学 用于图像分类的基于类别自适应模型的无监督域适应方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
Unsupervised Learning using Pretrained CNN and Associative Memory Bank;Qun Liu 等;《IJCNN》;第1-8页 *

Also Published As

Publication number Publication date
CN113673555A (zh) 2021-11-19

Similar Documents

Publication Publication Date Title
CN110321926B (zh) 一种基于深度残差修正网络的迁移方法及系统
CN108399428B (zh) 一种基于迹比准则的三元组损失函数设计方法
CN109117793B (zh) 基于深度迁移学习的直推式雷达高分辨距离像识别方法
CN111931814B (zh) 一种基于类内结构紧致约束的无监督对抗域适应方法
CN114241282A (zh) 一种基于知识蒸馏的边缘设备场景识别方法及装置
CN108985268B (zh) 基于深度迁移学习的归纳式雷达高分辨距离像识别方法
CN108875933B (zh) 一种无监督稀疏参数学习的超限学习机分类方法及系统
CN108171318B (zh) 一种基于模拟退火—高斯函数的卷积神经网络集成方法
CN111079847B (zh) 一种基于深度学习的遥感影像自动标注方法
CN108399406A (zh) 基于深度学习的弱监督显著性物体检测的方法及系统
Yu et al. Multi-target unsupervised domain adaptation without exactly shared categories
CN110210468B (zh) 一种基于卷积神经网络特征融合迁移的文字识别方法
CN111275092A (zh) 一种基于无监督域适应的图像分类方法
CN107491729B (zh) 基于余弦相似度激活的卷积神经网络的手写数字识别方法
CN111967325A (zh) 一种基于增量优化的无监督跨域行人重识别方法
CN113095229B (zh) 一种无监督域自适应行人重识别系统及方法
CN112465226B (zh) 一种基于特征交互和图神经网络的用户行为预测方法
CN113313179B (zh) 一种基于l2p范数鲁棒最小二乘法的噪声图像分类方法
CN117993282A (zh) 一种面向智能制造故障诊断的域适应性信息瓶颈联邦学习方法
CN111783688B (zh) 一种基于卷积神经网络的遥感图像场景分类方法
CN112949590A (zh) 一种跨域行人重识别模型构建方法及构建系统
CN113673555B (zh) 一种基于记忆体的无监督域适应图片分类方法
CN112750128A (zh) 图像语义分割方法、装置、终端及可读存储介质
CN109284375A (zh) 一种基于原始数据信息保留的域自适应降维方法
CN114037866B (zh) 一种基于可辨伪特征合成的广义零样本图像分类方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant