CN117523295A - 基于类引导元学习的无源域适应的图像分类方法 - Google Patents
基于类引导元学习的无源域适应的图像分类方法 Download PDFInfo
- Publication number
- CN117523295A CN117523295A CN202311543424.8A CN202311543424A CN117523295A CN 117523295 A CN117523295 A CN 117523295A CN 202311543424 A CN202311543424 A CN 202311543424A CN 117523295 A CN117523295 A CN 117523295A
- Authority
- CN
- China
- Prior art keywords
- image
- pseudo
- representing
- network
- tag
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 47
- 230000003044 adaptive effect Effects 0.000 title claims abstract description 8
- 238000012549 training Methods 0.000 claims abstract description 37
- 230000006870 function Effects 0.000 claims abstract description 21
- 238000004364 calculation method Methods 0.000 claims description 29
- 239000011159 matrix material Substances 0.000 claims description 21
- 238000012937 correction Methods 0.000 claims description 13
- 230000008447 perception Effects 0.000 claims description 9
- 238000012545 processing Methods 0.000 claims description 9
- 238000000605 extraction Methods 0.000 claims description 4
- 238000005457 optimization Methods 0.000 claims description 4
- 239000013598 vector Substances 0.000 claims description 4
- 230000004931 aggregating effect Effects 0.000 claims description 3
- 230000001419 dependent effect Effects 0.000 claims description 3
- 239000002994 raw material Substances 0.000 claims description 2
- 230000006978 adaptation Effects 0.000 abstract description 10
- 238000009826 distribution Methods 0.000 description 7
- 238000009825 accumulation Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 241000022852 Letis Species 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 2
- 230000015556 catabolic process Effects 0.000 description 2
- 238000006731 degradation reaction Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000012733 comparative method Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000001902 propagating effect Effects 0.000 description 1
- 238000000638 solvent extraction Methods 0.000 description 1
- 230000001502 supplementing effect Effects 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 239000002699 waste material Substances 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
- G06V10/765—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0985—Hyperparameter optimisation; Meta-learning; Learning-to-learn
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于类引导元学习的无源域适应的图像分类方法,构建由教师网络和学生网络构成的自训练师生网络,并构建类别感知的元学习模块,将教师网络、学生网络和元学习模块构成元学习模型,采用教师网络得到图像集合中图像样本的伪标签并进行伪标签校正,将自训练师生网络的模型训练损失函数作为元学习的下层任务,将元学习模块的置信度权重函数作为元学习的上层任务,采用图像集合对元学习模型进行训练,采用学生网络的参数对教师网络的参数进行更新,重新获取伪标签,如此循环直到达到最大迭代次数,从最终的元学习模型中提取学生网络作为目标域模型对目标域数据图像进行分类。本发明可在无源域适应情况下有效提升目标域模型分类的性能。
Description
技术领域
本发明属于图像分类技术领域,更为具体地讲,涉及一种基于类引导元学习的无源域适应的图像分类方法。
背景技术
深度神经网络已成功地在各种应用中展示了高性能。然而,如果训练和测试数据的分布不同,则会发生显著的性能下降,这被称为域偏移。无监督域适应图像分类在假设两个域中的数据分布不同的情况下,利用完全标注的源数据图像和未标注的目标图像数据来缓解域移位问题,所有传统的无监督域适应图像分类方法都假设源数据和对应标签两者的可用性。然而,这在一些情况下可能是不切实际的。首先,对数据隐私和安全的日益担忧迫使公司只发布目标图像数据,无法获取源数据图像。第二,当源数据图像比目标数据图像大得多时,需要许多资源来训练模型。无源域自适应旨在使预训练的源模型适应未标记的目标域,而无需访问标记良好的源数据,这种应用场景较传统的无监督域适应更为普遍。
现有的无源域自适应图像分类方法主要分为两大类,一类是数据生成的方式,基于数据生成方法的目的是重建源域,以补偿缺失的源域数据,从而使无监督域适应方法可以扩展到无源域自适应图像分类方法。而生成模型的训练通常是复杂的且生成模型容易出现模型崩塌问题,即生成的图像样本在特征空间中聚集在一些局部模式附近,而忽略其他潜在的类别和样本分布。另一类是基于自训练的方法,这类方法假设源预训练模型由于源和目标域的相似性而在目标域上具有一定程度的泛化。目前基于自训练的方法占据无源域自适应图像分类方法的主流,其主要通过源域模型对目标域图像进行标签预测来指导模型自训练,但现有的基于自训练方法主要存在以下问题:1)将获取的伪标签按相同权重分配给样本用于指导模型训练容易导致噪声累计,导致模型性能下降;2)现有方法没有考虑到现实收集的数据呈现显著的长尾分布现象,即一些常见类的样本数量大,罕见类样本数量少,这往往导致训练的模型对常见类的预测相对稳定,而后者容易预测错误,现有方法往往忽略了这种数据偏差问题导致模型预测结果较差。
发明内容
本发明的目的在于克服现有技术的不足,提供一种基于类引导元学习的无源域适应的图像分类方法,通过对伪标签进行校正提升伪标签的质量,并通过类引导元学习来为每个样本学习一个权重,减少错误标签的噪声累计问题,同时根据类的置信度不同,减少了类的长尾分布引起的数据偏差问题,显著提升目标域模型分类的性能。
为了实现上述发明目的,本发明基于类引导元学习的无源域适应的图像分类方法包括以下步骤:
S1:根据实际需要选取已训练的源域特征提取器和源域特征分类器,采用其参数初始化目标域特征提取器和目标域特征分类器,然后将目标域特征提取器和目标域特征分类器复制一份,一组作为教师网络,一组作为学生网络,从而构成自训练师生网络;
S3:构建类别感知的元学习模块,包括损失计算模块、损失多层感知机、类伪准确率计算模块、类伪准确率多层感知机和置信度计算模块,其中:
损失计算模块用于根据学生网络对图像样本得到的预测标签计算预测标签损失L(w)并发送至损失多层感知机,其中w表示学生网络参数;
损失多层感知机用于根据图像样本x的预测标签损失L(w)生成对应的样本权重θ表示损失多层感知机的网络参数,S表示预设的权重维度,并将样本权重P(L(w),θ)发送至置信度权重计算模块;
类伪准确率计算模块用于根据教师网络对所有图像样本的预测标签计算每个图像样本在K个类别的类伪准确率pk并发送至类伪准确率多层感知机,k=1,2,…,K,K表示目标类别数量,类伪准确率pk的计算公式如下:
其中,表示教师网络对图像样本xi的预测标签,/>表示二值函数,当时/>否则/> 表示教师网络中目标域特征提取器/>对图像xi所提取得到的特征向量,/>表示目标域特征分类器/>对特征向量/>推断得到的可能性,/>表示采用softmax函数根据可能性/>所得到的概率;
类伪准确率多层感知机用于根据每个图像样本对应的K个类伪准确率pk生成该图像样本的类别权重 表示类伪准确率多层感知机的网络参数,并将类别权重/>发送至置信度权重计算模块;
置信度权重计算模块用于根据样本权重P(L(w),θ)和类别权重计算得到置信度权重/>
其中,表示两个向量之间的点积;
S3:将教师网络、学生网络和元学习模块构成元学习模型,其中教师网络和学生网络分别对输入图像进行预测得到预测标签,元学习模块根据教师网络和学生网络的预测标签计算得到置信度权重;
S4:令迭代次数t=1;
S5:将图像集合XT中每幅图像xi输入至教师网络中,得到图像xi属于类别k的概率γi,k,从而确定图像xi的分类结果并将其作为该图像的初始伪标签y′i,从而得到图像集合XT对应的伪标签集合Y′T,i=1,2,…,N,N表示图像数量,y′i∈[1,K];
S6:采用伪标签校正方法对伪标签集合Y′T进行处理,得到每幅图像xi校正后的伪标签yi,从而得到校正后的伪标签集合YT;
S7:对于元学习模型,将自训练师生网络的模型训练损失函数作为元学习的下层任务,将元学习模块的置信度权重函数作为元学习的上层任务,采用图像集合XT对元学习模型进行训练,上层任务和下层任务交替进行迭代更新,完成元学习模型的训练;
S8:判断是否t<tmax,tmax表示预测的最大迭代次数,如果是,进入步骤S9,否则进入步骤S10;
S9:令t=t+1,采用当前学生网络的参数对教师网络的参数进行更新,返回步骤S5;
S10:从最终的元学习模型中提取学生网络作为目标域模型,并利用该目标域模型对目标域数据图像进行分类。
本发明基于类引导元学习的无源域适应的图像分类方法,构建由教师网络和学生网络构成的自训练师生网络,并构建类别感知的元学习模块,将教师网络、学生网络和元学习模块构成元学习模型,采用教师网络得到图像集合中每个图像样本的伪标签并进行伪标签校正,对于元学习模型,将自训练师生网络的模型训练损失函数作为元学习的下层任务,将元学习模块的置信度权重函数作为元学习的上层任务,采用图像集合对元学习模型进行训练,然后采用学生网络的参数对教师网络的参数进行更新,重新获取伪标签,如此循环直到达到最大迭代次数,从最终的元学习模型中提取学生网络作为目标域模型,并利用该目标域模型对目标域数据图像进行分类。
本发明具有以下有益效果:
1)本发明可以在无法访问源域数据,依赖源域中的预训练模型对无标签的目标域图像进行分类,既保护了源域数据的隐私又解决了大规模源域数据计算存储资源的浪费,提升了无源域自适应效果;
2)本发明通过对伪标签进行校正提升伪标签的质量,从而提高学习效果;
3)本发明通过类引导元学习将来自类别的元知识和原始图像样本的损失信息作为补充,输出每个图像样本的伪标签置信度,减少了类的长尾分布引起的数据偏差问题,显著提升目标域模型分类的性能。
附图说明
图1是本发明基于类引导元学习的无源域适应的图像分类方法的具体实施方式流程图;
图2是本发明中类别感知的元学习模块的结构图;
图3是本实施例中基于信息约束的伪标签校正方法的流程图。
具体实施方式
下面结合附图对本发明的具体实施方式进行描述,以便本领域的技术人员更好地理解本发明。需要特别提醒注意的是,在以下的描述中,当已知功能和设计的详细描述也许会淡化本发明的主要内容时,这些描述在这里将被忽略。
实施例
图1是本发明基于类引导元学习的无源域适应的图像分类方法的具体实施方式流程图。如图1所示,本发明基于类引导元学习的无源域适应的图像分类方法的具体步骤包括:
S101:构建自训练师生网络:
根据实际需要选取已训练的源域特征提取器和源域特征分类器,采用其参数初始化目标域特征提取器和目标域特征分类器,然后将目标域特征提取器和目标域特征分类器复制一份,一组作为教师网络,一组作为学生网络,从而构成自训练师生网络。根据实际需要确定目标域的无标签图像,得到图像集合XT。在实际应用中可以对图像集合XT中的图像根据预设方法进行增强处理,以便更好地提取特征。
S102:构建类别感知的元学习模块:
虽然采用伪标签校正能够在很大程度上提升了伪标签的质量,但由于源域与目标域之间的域间隙这一客观问题,部分样本标签仍然存在错误分配问题,为了缓解伪标签错误累计导致性能下降,当模型训练时,应当抑制伪标签中具有低置信度的样本,因此本发明设置了类别感知的元学习模块,同时考虑了来自类别的元知识和样本的损失信息来学习每个样本的置信度权重。图2是本发明中类别感知的元学习模块的结构图。如图2所示,本发明中类别感知的元学习模块包括损失计算模块、损失多层感知机、类伪准确率计算模块、类伪准确率多层感知机和置信度计算模块,其中:
损失计算模块用于根据学生网络对图像样本得到的预测标签计算预测标签损失L(w)并发送至损失多层感知机,其中w表示学生网络参数。本实施例中预测标签损失采用常用的交叉熵损失,计算公式如下:
L(w)=-ylog(f(x,w))
其中,y表示图像样本x当前的标签,f(x,w)表示输入图像样本x由参数为w的学生网络得到的预测标签。
损失多层感知机用于根据图像样本的预测标签损失L(w)生成对应的样本权重θ表示损失多层感知机的网络参数,S表示预设的权重维度,并将样本权重P(L(w),θ)发送至置信度权重计算模块。
类伪准确率计算模块用于根据教师网络对所有图像样本的预测标签计算每个图像样本在K个类别的类伪准确率pk并发送至类伪准确率多层感知机,k=1,2,…,K,K表示目标类别数量,类伪准确率pk的计算公式如下:
其中,表示教师网络对图像样本xi的预测标签,/>表示二值函数,当时/>否则/> 表示教师网络中目标域特征提取器/>对图像xi所提取得到的特征向量,/>表示目标域特征分类器/>对特征向量/>推断得到的可能性(即logits),/>表示采用softmax函数根据可能性所得到的概率。
类伪准确率多层感知机用于根据每个图像样本对应的K个类伪准确率pk生成该图像样本的类别权重 表示类伪准确率多层感知机的网络参数,并将类别权重/>发送至置信度权重计算模块。例如中将S设置为3,相当于将K个类分成高置信度类、中置信度类以及低置信度类3组,使得类引导的元学习置信度模块能对不同组的类别赋予不同的类权重。
置信度权重计算模块用于根据样本权重P(L(w),θ)和类别权重计算得到置信度权重/>
其中,表示两个向量之间的点积。
S103:构建元学习模型:
将教师网络、学生网络和元学习模块构成元学习模型,其中教师网络和学生网络分别对输入图像进行预测得到预测标签,元学习模块根据教师网络和学生网络的预测标签计算得到置信度权重。
S104:令迭代次数t=1。
S105:采用教师网络获取伪标签:
将图像集合XT中每幅图像xi输入至教师网络中,得到图像xi属于类别k的概率γi,k,从而确定图像xi的分类结果并将其作为该图像的初始伪标签y′i,从而得到图像集合XT对应的伪标签集合Y′T,i=1,2,…,N,N表示图像数量,y′i∈[1,K]。
本实施例中,为了提高初始伪标签的准确性,采用加权平均的方法来确定伪标签,具体方法为:
对图像xi进行随机M次增强处理,然后将增强处理后的图像xi,m输入至教师网络得到预测标签,m=1,2,…,M,然后采用如下公式计算得到图像xi属于类别k的概率γi,k:
其中,表示教师网络的目标域特征提取器/>对图像xi,m提取得到的特征,表示教师网络的目标域特征分类器/>预测得到的图像xi,m属于类别k的概率。
选择概率γi,k最大值对应的类别序号作为图像xi的伪标签y′i。
S106:伪标签校正:
考虑到源域和目标域之间存在一定的域偏移,通过教师网络中目标域模型预测得到的伪标签集合YT存在大量的噪声,因此本发明中采用伪标签校正方法对伪标签集合Y′T进行处理,得到每幅图像xi校正后的伪标签yi,从而得到校正后的伪标签集合YT,进一步提升伪标签的质量。
伪标签校正的具体方法可以根据实际需要进行设置,例如常规伪标签校正是通过迭代更新类别的类中心来校正伪标签。经研究发现,当数据类别分布不均时,分类器常常对“大”类(属于该类的样本多)会产生一定的偏见,在分类时候倾向对“大”类分配更高的预测值,而基于机器学习中的平滑性假设在一定程度上可以缓解这种偏见,当两个样本之间的特征相似,则两个样本有较大概率分配相同的标签。因此为了提高校正后的伪标签的质量,本实施例中提出了一种基于信息约束的伪标签校正方法。图3是本实施例中基于信息约束的伪标签校正方法的流程图。如图3所示,本实施例中基于信息约束的伪标签校正方法的具体步骤包括:
S201:构建邻接矩阵:
计算图像集合XT图像的邻接矩阵A,其元素wi,j表示图像xi,xj之间的权重ai,j,计算公式如下:
其中,分别表示教师网络的目标域特征提取器/>对图像xi,xj提取得到的特征,在实际应用中如果伪标签采用加权平均的方式确定,特征可以采用M次特征的平均值。/>表示特征/>和/>的相似度,本实施例中采用将特征和/>展开得到的一维向量的余弦相似度作为特征之间的相似度。σ表示调优参数,e表示自然常数。
S202:基于熵值进行样本划分:
根据确定伪标签时每个图像xi属于类别k的概率γi,k计算得到每个图像对应的熵值entropy(xi):
根据图像的初始伪标签y′i划分得到每个类别的图像集合从图像集合/>中选择熵值排序最高的z张图像,z根据实际情况设置,将筛选出的图像作为有干净标签的样本,其余图像作为无标签样本,从而划分得到有干净标签的样本集合L,无标签样本集合U。
S203:采用标签传播算法更新标签:
采用标签传播算法更新无标签样本的标签,通过利用邻接矩阵的相邻点具有相似标签的假设将干净标签信息从L传播到U。考虑到无标签样本在步骤S105中已经具有了伪标签y′i,这些伪标签可以作为一种先验信息来约束标签的生成,因此本实施例中标签传播算法的目标函数设置如下:
其中,hi、hj表示图像xi、xj的标签构成的独热编码,yi表示来自集合L中干净的标签,μ表示约束参数。可以采用表示所有图像的标签独热编码构成的标签矩阵。
通过优化求解上述基于信息约束的标签传播目标损失函数,就可以完成各个图像标签的校正。本实施例中采用如下优化步骤得到对应的闭式解:
1)对损失函数中的第一项进一步展开为:
因为约束条件中,当i∈L时,hi=yi为已知值,所以第一项为常数,对于第二项可以进一步展开得到:
此外,令为对角矩阵,其对角线元素为/>j∈U;令矩阵的元素为aij,i∈L,j∈U,|L|,|U|分别表示集合L,集合U中样本数量。此外,根据约束条件,上式第一项也为常数,因此,上式可进一步表示为:
其中,表示由集合U中图像样本的标签独热编码构成的标签矩阵,表示由集合L中各个图像样本的标签独热编码构成的标签矩阵。
同理,对于第三项可以进一步展开得到:
此外,令为对角矩阵,其对角线元素为/>j∈U;令矩阵的元素为aij,i∈U,j∈U。因此,上式可进一步表示为:
根据上述分析,目标函数最后展开为:
其中,表示无标签样本集合U中所有图像在步骤S105过程中获得的伪标签对应的独热编码构成的标签矩阵。
对hU求偏导并令等式为0,可以得到目标函数的闭式解:
其中,I表示单位矩阵。从标签矩阵hU中提取得到无标签样本集合U各个样本校正后的伪标签,完成伪标签校正。
S107:训练元学习模型:
对于元学习模型,将自训练师生网络的模型训练损失函数作为元学习的下层任务,将元学习模块的置信度权重函数作为元学习的上层任务,采用数据集对元学习模型进行训练,上层任务和下层任务交替进行迭代更新,完成元学习模型的训练。
元学习模块的双层任务优化可以采用如下公式表示:
其中,表示干净数据集,||表示求取集合中图像样本数量,xmeta表示高质量数据集中的图像样本,ymeta表示图像样本xmeta的标签,/>表示学生网络对图像样本xmeta的预测标签。
由于本发明提出是针对无源域适应的图像分类任务,无法获取无偏估计的干净数据集,因此本实施例中采用输入扰动的预测一致性损失作为元知识来指导元学习上层任务的学习,基于预测一致性损失的元学习双层优化任务的表达式如下:
其中,表示从数据集XT中随机抽取的子数据集,xD表示数据集/>中的图像样本,A(xD)表示对图像样本/>进行数据增强处理后的图像样本,/> 分别表示学生网络对数据样本xD和增强后数据样本A(xD)的预测标签,κ()表示求取KL散度,作为度量输入扰动的预测一致性损失。/>表示图像样本xi的预测标签损失L(w),/>表示图像样本xi对应的置信度权重。
S108:判断是否t<tmax,tmax表示预测的最大迭代次数,如果是,进入步骤S109,否则进入步骤S110。
S109:令t=t+1,采用当前学生网络的参数对教师网络的参数进行更新,返回步骤S105。
S110:利用目标域模型进行图像分类:
从最终的元学习模型中提取学生网络作为目标域模型,并利用该目标域模型对目标域数据图像进行分类。
为了更好地说明本发明的技术效果,采用具体实例对本发明进行实验验证。在本实验验证中,使用无源域适应领域常规测试基础数据集Office-31,其由三个不同域组成,包括由亚马逊商家图像域(Amazon),网络摄像头收集的低分辨率图像域(Webcam)以及单反相机拍摄的高解析度图像域(DSLR),该数据集包含了31类常见的办公物体,如笔记本电脑、文件柜、键盘等,共4652张图像。为了全面体现本发明的优势,设计了六种无源域适应任务,即Amazon→DSLR,Amazon→Webcam,DSLR→Webcam,DSLR→Amazon,Webcam→Amazon,Webcam→DSLR。
本次实验验证中设置4种对比方法,分别为ResNet50,DANN(Domain-adversarialtraining of neural networks),CDAN(Conditional adversarial domain adaptation)以及SHOT(Source hypothesis transfer for unsupervised domain adaptation)。
本发明由PyTorch实现,并在NVIDIARTX3090 GPU上进行了训练,本发明使用的自训练师生网络中教师网络和学生网络采用ResNet50。
表1是本实施例中本发明和对比方法在不同任务下的分类正确率统计表。
表1
如表1所示,从表1中的结果可以看出,本发明在Office-31数据集的6种自适应任务中均取得了最好的效果,从而验证本发明的有效性。
尽管上面对本发明说明性的具体实施方式进行了描述,以便于本技术领域的技术人员理解本发明,但应该清楚,本发明不限于具体实施方式的范围,对本技术领域的普通技术人员来讲,只要各种变化在所附的权利要求限定和确定的本发明的精神和范围内,这些变化是显而易见的,一切利用本发明构思的发明创造均在保护之列。
Claims (4)
1.一种基于类引导元学习的无源域适应的图像分类方法,其特征在于,包括以下步骤:
S1:根据实际需要选取已训练的源域特征提取器和源域特征分类器,采用其参数初始化目标域特征提取器和目标域特征分类器,然后将目标域特征提取器和目标域特征分类器复制一份,一组作为教师网络,一组作为学生网络,从而构成自训练师生网络;
S3:构建类别感知的元学习模块,包括损失计算模块、损失多层感知机、类伪准确率计算模块、类伪准确率多层感知机和置信度计算模块,其中:
损失计算模块用于根据学生网络对图像样本得到的预测标签计算预测标签损失L(w)并发送至损失多层感知机,其中w表示学生网络参数;
损失多层感知机用于根据图像样本x的预测标签损失L(w)生成对应的样本权重θ表示损失多层感知机的网络参数,S表示预设的权重维度,并将样本权重P(L(w),θ)发送至置信度权重计算模块;
类伪准确率计算模块用于根据教师网络对所有图像样本的预测标签计算每个图像样本在K个类别的类伪准确率pk并发送至类伪准确率多层感知机,k=1,2,…,K,K表示目标类别数量,类伪准确率pk的计算公式如下:
其中,表示教师网络对图像样本xi的预测标签,/>表示二值函数,当/>时否则/> 表示教师网络中目标域特征提取器/>对图像xi所提取得到的特征向量,/>表示目标域特征分类器/>对特征向量/>推断得到的可能性,/>表示采用softmax函数根据可能性/>所得到的概率;
类伪准确率多层感知机用于根据每个图像样本对应的K个类伪准确率pk生成该图像样本的类别权重 表示类伪准确率多层感知机的网络参数,并将类别权重发送至置信度权重计算模块;
置信度权重计算模块用于根据样本权重P(L(w),θ)和类别权重计算得到置信度权重/>
其中,表示两个向量之间的点积;
S3:将教师网络、学生网络和元学习模块构成元学习模型,其中教师网络和学生网络分别对输入图像进行预测得到预测标签,元学习模块根据教师网络和学生网络的预测标签计算得到置信度权重;
S4:令迭代次数t=1;
S5:将图像集合XT中每幅图像xi输入至教师网络中,得到图像xi属于类别k的概率γi,k,从而确定图像xi的分类结果并将其作为该图像的初始伪标签y′i,从而得到图像集合XT对应的伪标签集合Y′T,i=1,2,…,N,N表示图像数量,y′i∈[1,K];
S6:采用伪标签校正方法对伪标签集合Y′T进行处理,得到每幅图像xi校正后的伪标签yi,从而得到校正后的伪标签集合YT;
S7:对于元学习模型,将自训练师生网络的模型训练损失函数作为元学习的下层任务,将元学习模块的置信度权重函数作为元学习的上层任务,采用图像集合XT对元学习模型进行训练,上层任务和下层任务交替进行迭代更新,完成元学习模型的训练;
S8:判断是否t<tmax,tmax表示预测的最大迭代次数,如果是,进入步骤S9,否则进入步骤S10;
S9:令t=t+1,采用当前学生网络的参数对教师网络的参数进行更新,返回步骤S5;
S10:从最终的元学习模型中提取学生网络作为目标域模型,并利用该目标域模型对目标域数据图像进行分类。
2.根据权利要求1所述的图像分类方法,其特征在于,所述步骤S5中伪标签采用如下方法确定:
对图像xn进行随机M次增强处理,然后将增强处理后的图像xi,m输入至教师网络得到预测标签,m=1,2,…,M,然后采用如下公式计算得到图像xi属于类别k的概率γi,k:
其中,表示教师网络的目标域特征提取器/>对图像xi,m提取得到的特征,表示教师网络的目标域特征分类器/>预测得到的图像xi,m属于类别k的概率;
选择概率γi,k最大值对应的类别序号作为图像xi的伪标签y′i。
3.根据权利要求1所述的图像分类方法,其特征在于,所述步骤S6中伪标签校正方法的具体步骤如下:
1)计算图像集合XT图像的邻接矩阵A,其元素ai,j表示图像xi,xj之间的权重ai,j,计算公式如下:
其中,分别表示教师网络的目标域特征提取器/>对图像xi,xj提取得到的特征,/>表示特征/>和/>的相似度;σ表示调优参数,e表示自然常数;
2)根据确定伪标签时每个图像xi属于类别k的概率γi,k计算得到每个图像对应的熵值entropy(xi):
根据图像的初始伪标签y′i划分得到每个类别的图像集合从图像集合/>中选择熵值排序最高的z张图像,z根据实际情况设置,将筛选出的图像作为有干净标签的样本,其余图像作为无标签样本,从而划分得到有干净标签的样本集合L,无标签样本集合U;
3)采用如下方法计算得到无标签样本集合U中图像样本经校正后的标签:
采用如下公式计算集合U中图像样本的标签独热编码构成的标签矩阵
其中,表示对角矩阵,其对角线元素为/> 为对角矩阵,其对角线元素为/>矩阵/>的元素为aij,i∈U,j∈U;表示由集合U中图像样本的标签独热编码构成的标签矩阵,/>表示无标签样本集合U中所有图像在步骤S5过程中获得的伪标签对应的独热编码构成的标签矩阵,|L|,|U|分别表示集合L,集合U中样本数量,I表示单位矩阵,μ表示约束参数;
从标签矩阵hU中提取得到无标签样本集合U各个样本校正后的伪标签,完成伪标签校正。
4.根据权利要求1所述的图像分类方法,其特征在于,所述步骤S7中元学习模型训练过程中双层优化任务的表示式如下:
其中,表示从数据集XT中随机抽取的子数据集,xD表示数据集/>中的图像样本,A(xD)表示对图像样本/>进行数据增强处理后的图像样本,分别表示学生网络对数据样本xD和增强后数据样本A(xD)的预测标签,κ()表示求取KL散度,/>表示图像样本xi的预测标签损失L(w),表示图像样本xi对应的置信度权重。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311543424.8A CN117523295B (zh) | 2023-11-17 | 2023-11-17 | 基于类引导元学习的无源域适应的图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311543424.8A CN117523295B (zh) | 2023-11-17 | 2023-11-17 | 基于类引导元学习的无源域适应的图像分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117523295A true CN117523295A (zh) | 2024-02-06 |
CN117523295B CN117523295B (zh) | 2024-09-24 |
Family
ID=89760346
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311543424.8A Active CN117523295B (zh) | 2023-11-17 | 2023-11-17 | 基于类引导元学习的无源域适应的图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117523295B (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117892183A (zh) * | 2024-03-14 | 2024-04-16 | 南京邮电大学 | 一种基于可靠迁移学习的脑电信号识别方法及系统 |
CN118334062A (zh) * | 2024-06-13 | 2024-07-12 | 江西师范大学 | 无源域自适应眼底图像分割方法和设备 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021068180A1 (en) * | 2019-10-11 | 2021-04-15 | Beijing Didi Infinity Technology And Development Co., Ltd. | Method and system for continual meta-learning |
CN115578568A (zh) * | 2022-11-15 | 2023-01-06 | 南京码极客科技有限公司 | 一种小规模可靠数据集驱动的噪声修正算法 |
CN116977731A (zh) * | 2023-07-31 | 2023-10-31 | 厦门大学 | 面向目标物分类的模型自增强方法、介质和设备 |
-
2023
- 2023-11-17 CN CN202311543424.8A patent/CN117523295B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021068180A1 (en) * | 2019-10-11 | 2021-04-15 | Beijing Didi Infinity Technology And Development Co., Ltd. | Method and system for continual meta-learning |
CN115578568A (zh) * | 2022-11-15 | 2023-01-06 | 南京码极客科技有限公司 | 一种小规模可靠数据集驱动的噪声修正算法 |
CN116977731A (zh) * | 2023-07-31 | 2023-10-31 | 厦门大学 | 面向目标物分类的模型自增强方法、介质和设备 |
Non-Patent Citations (2)
Title |
---|
MAHSA GHORBANI: "GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference", 《ARXIV》, 8 April 2021 (2021-04-08), pages 1 - 12 * |
张玉清;董颖;柳彩云;雷柯楠;孙鸿宇;: "深度学习应用于网络空间安全的现状、趋势与展望", 计算机研究与发展, no. 06, 12 January 2018 (2018-01-12), pages 3 - 28 * |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117892183A (zh) * | 2024-03-14 | 2024-04-16 | 南京邮电大学 | 一种基于可靠迁移学习的脑电信号识别方法及系统 |
CN117892183B (zh) * | 2024-03-14 | 2024-06-04 | 南京邮电大学 | 一种基于可靠迁移学习的脑电信号识别方法及系统 |
CN118334062A (zh) * | 2024-06-13 | 2024-07-12 | 江西师范大学 | 无源域自适应眼底图像分割方法和设备 |
Also Published As
Publication number | Publication date |
---|---|
CN117523295B (zh) | 2024-09-24 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111814854B (zh) | 一种无监督域适应的目标重识别方法 | |
CN109308318B (zh) | 跨领域文本情感分类模型的训练方法、装置、设备及介质 | |
CN110321926B (zh) | 一种基于深度残差修正网络的迁移方法及系统 | |
CN117523295B (zh) | 基于类引导元学习的无源域适应的图像分类方法 | |
CN112446423B (zh) | 一种基于迁移学习的快速混合高阶注意力域对抗网络的方法 | |
CN113469186B (zh) | 一种基于少量点标注的跨域迁移图像分割方法 | |
CN108921342B (zh) | 一种物流客户流失预测方法、介质和系统 | |
CN113312505B (zh) | 一种基于离散在线哈希学习的跨模态检索方法及系统 | |
Dai et al. | Hybrid deep model for human behavior understanding on industrial internet of video things | |
Cholakov et al. | Transformers predicting the future. Applying attention in next-frame and time series forecasting | |
CN105701516B (zh) | 一种基于属性判别的自动图像标注方法 | |
Liang et al. | Deep multi-label learning for image distortion identification | |
Ahn et al. | Accurate online tensor factorization for temporal tensor streams with missing values | |
Chen et al. | Label-retrieval-augmented diffusion models for learning from noisy labels | |
Liu et al. | Modal-regression-based broad learning system for robust regression and classification | |
CN116561591B (zh) | 科技文献语义特征提取模型训练方法、特征提取方法及装置 | |
CN107993311B (zh) | 一种用于半监督人脸识别门禁系统的代价敏感隐语义回归方法 | |
Liu et al. | TCD-CF: Triple cross-domain collaborative filtering recommendation | |
CN116543237B (zh) | 无源域无监督域适应的图像分类方法、系统、设备及介质 | |
CN117333717A (zh) | 基于网络信息技术的安全监控方法及系统 | |
Guangyu | Analysis of sports video intelligent classification technology based on neural network algorithm and transfer Learning | |
CN115797642A (zh) | 基于一致性正则化与半监督领域自适应图像语义分割算法 | |
CN109284375A (zh) | 一种基于原始数据信息保留的域自适应降维方法 | |
Mi et al. | Visual relationship forecasting in videos | |
CN116756676A (zh) | 一种摘要生成方法及相关装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |