发明内容
本发明的目的在于克服现有技术的不足,提供一种基于类引导元学习的无源域适应的图像分类方法,通过对伪标签进行校正提升伪标签的质量,并通过类引导元学习来为每个样本学习一个权重,减少错误标签的噪声累计问题,同时根据类的置信度不同,减少了类的长尾分布引起的数据偏差问题,显著提升目标域模型分类的性能。
为了实现上述发明目的,本发明基于类引导元学习的无源域适应的图像分类方法包括以下步骤:
S1:根据实际需要选取已训练的源域特征提取器和源域特征分类器,采用其参数初始化目标域特征提取器和目标域特征分类器,然后将目标域特征提取器和目标域特征分类器复制一份,一组作为教师网络,一组作为学生网络,从而构成自训练师生网络;
S3:构建类别感知的元学习模块,包括损失计算模块、损失多层感知机、类伪准确率计算模块、类伪准确率多层感知机和置信度计算模块,其中:
损失计算模块用于根据学生网络对图像样本得到的预测标签计算预测标签损失L(w)并发送至损失多层感知机,其中w表示学生网络参数;
损失多层感知机用于根据图像样本x的预测标签损失L(w)生成对应的样本权重θ表示损失多层感知机的网络参数,S表示预设的权重维度,并将样本权重P(L(w),θ)发送至置信度权重计算模块;
类伪准确率计算模块用于根据教师网络对所有图像样本的预测标签计算每个图像样本在K个类别的类伪准确率pk并发送至类伪准确率多层感知机,k=1,2,…,K,K表示目标类别数量,类伪准确率pk的计算公式如下:
其中,表示教师网络对图像样本xi的预测标签,/>表示二值函数,当时/>否则/> 表示教师网络中目标域特征提取器/>对图像xi所提取得到的特征向量,/>表示目标域特征分类器/>对特征向量/>推断得到的可能性,/>表示采用softmax函数根据可能性/>所得到的概率;
类伪准确率多层感知机用于根据每个图像样本对应的K个类伪准确率pk生成该图像样本的类别权重 表示类伪准确率多层感知机的网络参数,并将类别权重/>发送至置信度权重计算模块;
置信度权重计算模块用于根据样本权重P(L(w),θ)和类别权重计算得到置信度权重/>
其中,表示两个向量之间的点积;
S3:将教师网络、学生网络和元学习模块构成元学习模型,其中教师网络和学生网络分别对输入图像进行预测得到预测标签,元学习模块根据教师网络和学生网络的预测标签计算得到置信度权重;
S4:令迭代次数t=1;
S5:将图像集合XT中每幅图像xi输入至教师网络中,得到图像xi属于类别k的概率γi,k,从而确定图像xi的分类结果并将其作为该图像的初始伪标签y′i,从而得到图像集合XT对应的伪标签集合Y′T,i=1,2,…,N,N表示图像数量,y′i∈[1,K];
S6:采用伪标签校正方法对伪标签集合Y′T进行处理,得到每幅图像xi校正后的伪标签yi,从而得到校正后的伪标签集合YT;
S7:对于元学习模型,将自训练师生网络的模型训练损失函数作为元学习的下层任务,将元学习模块的置信度权重函数作为元学习的上层任务,采用图像集合XT对元学习模型进行训练,上层任务和下层任务交替进行迭代更新,完成元学习模型的训练;
S8:判断是否t<tmax,tmax表示预测的最大迭代次数,如果是,进入步骤S9,否则进入步骤S10;
S9:令t=t+1,采用当前学生网络的参数对教师网络的参数进行更新,返回步骤S5;
S10:从最终的元学习模型中提取学生网络作为目标域模型,并利用该目标域模型对目标域数据图像进行分类。
本发明基于类引导元学习的无源域适应的图像分类方法,构建由教师网络和学生网络构成的自训练师生网络,并构建类别感知的元学习模块,将教师网络、学生网络和元学习模块构成元学习模型,采用教师网络得到图像集合中每个图像样本的伪标签并进行伪标签校正,对于元学习模型,将自训练师生网络的模型训练损失函数作为元学习的下层任务,将元学习模块的置信度权重函数作为元学习的上层任务,采用图像集合对元学习模型进行训练,然后采用学生网络的参数对教师网络的参数进行更新,重新获取伪标签,如此循环直到达到最大迭代次数,从最终的元学习模型中提取学生网络作为目标域模型,并利用该目标域模型对目标域数据图像进行分类。
本发明具有以下有益效果:
1)本发明可以在无法访问源域数据,依赖源域中的预训练模型对无标签的目标域图像进行分类,既保护了源域数据的隐私又解决了大规模源域数据计算存储资源的浪费,提升了无源域自适应效果;
2)本发明通过对伪标签进行校正提升伪标签的质量,从而提高学习效果;
3)本发明通过类引导元学习将来自类别的元知识和原始图像样本的损失信息作为补充,输出每个图像样本的伪标签置信度,减少了类的长尾分布引起的数据偏差问题,显著提升目标域模型分类的性能。
具体实施方式
下面结合附图对本发明的具体实施方式进行描述,以便本领域的技术人员更好地理解本发明。需要特别提醒注意的是,在以下的描述中,当已知功能和设计的详细描述也许会淡化本发明的主要内容时,这些描述在这里将被忽略。
实施例
图1是本发明基于类引导元学习的无源域适应的图像分类方法的具体实施方式流程图。如图1所示,本发明基于类引导元学习的无源域适应的图像分类方法的具体步骤包括:
S101:构建自训练师生网络:
根据实际需要选取已训练的源域特征提取器和源域特征分类器,采用其参数初始化目标域特征提取器和目标域特征分类器,然后将目标域特征提取器和目标域特征分类器复制一份,一组作为教师网络,一组作为学生网络,从而构成自训练师生网络。根据实际需要确定目标域的无标签图像,得到图像集合XT。在实际应用中可以对图像集合XT中的图像根据预设方法进行增强处理,以便更好地提取特征。
S102:构建类别感知的元学习模块:
虽然采用伪标签校正能够在很大程度上提升了伪标签的质量,但由于源域与目标域之间的域间隙这一客观问题,部分样本标签仍然存在错误分配问题,为了缓解伪标签错误累计导致性能下降,当模型训练时,应当抑制伪标签中具有低置信度的样本,因此本发明设置了类别感知的元学习模块,同时考虑了来自类别的元知识和样本的损失信息来学习每个样本的置信度权重。图2是本发明中类别感知的元学习模块的结构图。如图2所示,本发明中类别感知的元学习模块包括损失计算模块、损失多层感知机、类伪准确率计算模块、类伪准确率多层感知机和置信度计算模块,其中:
损失计算模块用于根据学生网络对图像样本得到的预测标签计算预测标签损失L(w)并发送至损失多层感知机,其中w表示学生网络参数。本实施例中预测标签损失采用常用的交叉熵损失,计算公式如下:
L(w)=-ylog(f(x,w))
其中,y表示图像样本x当前的标签,f(x,w)表示输入图像样本x由参数为w的学生网络得到的预测标签。
损失多层感知机用于根据图像样本的预测标签损失L(w)生成对应的样本权重θ表示损失多层感知机的网络参数,S表示预设的权重维度,并将样本权重P(L(w),θ)发送至置信度权重计算模块。
类伪准确率计算模块用于根据教师网络对所有图像样本的预测标签计算每个图像样本在K个类别的类伪准确率pk并发送至类伪准确率多层感知机,k=1,2,…,K,K表示目标类别数量,类伪准确率pk的计算公式如下:
其中,表示教师网络对图像样本xi的预测标签,/>表示二值函数,当时/>否则/> 表示教师网络中目标域特征提取器/>对图像xi所提取得到的特征向量,/>表示目标域特征分类器/>对特征向量/>推断得到的可能性(即logits),/>表示采用softmax函数根据可能性所得到的概率。
类伪准确率多层感知机用于根据每个图像样本对应的K个类伪准确率pk生成该图像样本的类别权重 表示类伪准确率多层感知机的网络参数,并将类别权重/>发送至置信度权重计算模块。例如中将S设置为3,相当于将K个类分成高置信度类、中置信度类以及低置信度类3组,使得类引导的元学习置信度模块能对不同组的类别赋予不同的类权重。
置信度权重计算模块用于根据样本权重P(L(w),θ)和类别权重计算得到置信度权重/>
其中,表示两个向量之间的点积。
S103:构建元学习模型:
将教师网络、学生网络和元学习模块构成元学习模型,其中教师网络和学生网络分别对输入图像进行预测得到预测标签,元学习模块根据教师网络和学生网络的预测标签计算得到置信度权重。
S104:令迭代次数t=1。
S105:采用教师网络获取伪标签:
将图像集合XT中每幅图像xi输入至教师网络中,得到图像xi属于类别k的概率γi,k,从而确定图像xi的分类结果并将其作为该图像的初始伪标签y′i,从而得到图像集合XT对应的伪标签集合Y′T,i=1,2,…,N,N表示图像数量,y′i∈[1,K]。
本实施例中,为了提高初始伪标签的准确性,采用加权平均的方法来确定伪标签,具体方法为:
对图像xi进行随机M次增强处理,然后将增强处理后的图像xi,m输入至教师网络得到预测标签,m=1,2,…,M,然后采用如下公式计算得到图像xi属于类别k的概率γi,k:
其中,表示教师网络的目标域特征提取器/>对图像xi,m提取得到的特征,表示教师网络的目标域特征分类器/>预测得到的图像xi,m属于类别k的概率。
选择概率γi,k最大值对应的类别序号作为图像xi的伪标签y′i。
S106:伪标签校正:
考虑到源域和目标域之间存在一定的域偏移,通过教师网络中目标域模型预测得到的伪标签集合YT存在大量的噪声,因此本发明中采用伪标签校正方法对伪标签集合Y′T进行处理,得到每幅图像xi校正后的伪标签yi,从而得到校正后的伪标签集合YT,进一步提升伪标签的质量。
伪标签校正的具体方法可以根据实际需要进行设置,例如常规伪标签校正是通过迭代更新类别的类中心来校正伪标签。经研究发现,当数据类别分布不均时,分类器常常对“大”类(属于该类的样本多)会产生一定的偏见,在分类时候倾向对“大”类分配更高的预测值,而基于机器学习中的平滑性假设在一定程度上可以缓解这种偏见,当两个样本之间的特征相似,则两个样本有较大概率分配相同的标签。因此为了提高校正后的伪标签的质量,本实施例中提出了一种基于信息约束的伪标签校正方法。图3是本实施例中基于信息约束的伪标签校正方法的流程图。如图3所示,本实施例中基于信息约束的伪标签校正方法的具体步骤包括:
S201:构建邻接矩阵:
计算图像集合XT图像的邻接矩阵A,其元素wi,j表示图像xi,xj之间的权重ai,j,计算公式如下:
其中,分别表示教师网络的目标域特征提取器/>对图像xi,xj提取得到的特征,在实际应用中如果伪标签采用加权平均的方式确定,特征可以采用M次特征的平均值。/>表示特征/>和/>的相似度,本实施例中采用将特征和/>展开得到的一维向量的余弦相似度作为特征之间的相似度。σ表示调优参数,e表示自然常数。
S202:基于熵值进行样本划分:
根据确定伪标签时每个图像xi属于类别k的概率γi,k计算得到每个图像对应的熵值entropy(xi):
根据图像的初始伪标签y′i划分得到每个类别的图像集合从图像集合/>中选择熵值排序最高的z张图像,z根据实际情况设置,将筛选出的图像作为有干净标签的样本,其余图像作为无标签样本,从而划分得到有干净标签的样本集合L,无标签样本集合U。
S203:采用标签传播算法更新标签:
采用标签传播算法更新无标签样本的标签,通过利用邻接矩阵的相邻点具有相似标签的假设将干净标签信息从L传播到U。考虑到无标签样本在步骤S105中已经具有了伪标签y′i,这些伪标签可以作为一种先验信息来约束标签的生成,因此本实施例中标签传播算法的目标函数设置如下:
其中,hi、hj表示图像xi、xj的标签构成的独热编码,yi表示来自集合L中干净的标签,μ表示约束参数。可以采用表示所有图像的标签独热编码构成的标签矩阵。
通过优化求解上述基于信息约束的标签传播目标损失函数,就可以完成各个图像标签的校正。本实施例中采用如下优化步骤得到对应的闭式解:
1)对损失函数中的第一项进一步展开为:
因为约束条件中,当i∈L时,hi=yi为已知值,所以第一项为常数,对于第二项可以进一步展开得到:
此外,令为对角矩阵,其对角线元素为/>j∈U;令矩阵的元素为aij,i∈L,j∈U,|L|,|U|分别表示集合L,集合U中样本数量。此外,根据约束条件,上式第一项也为常数,因此,上式可进一步表示为:
其中,表示由集合U中图像样本的标签独热编码构成的标签矩阵,表示由集合L中各个图像样本的标签独热编码构成的标签矩阵。
同理,对于第三项可以进一步展开得到:
此外,令为对角矩阵,其对角线元素为/>j∈U;令矩阵的元素为aij,i∈U,j∈U。因此,上式可进一步表示为:
根据上述分析,目标函数最后展开为:
其中,表示无标签样本集合U中所有图像在步骤S105过程中获得的伪标签对应的独热编码构成的标签矩阵。
对hU求偏导并令等式为0,可以得到目标函数的闭式解:
其中,I表示单位矩阵。从标签矩阵hU中提取得到无标签样本集合U各个样本校正后的伪标签,完成伪标签校正。
S107:训练元学习模型:
对于元学习模型,将自训练师生网络的模型训练损失函数作为元学习的下层任务,将元学习模块的置信度权重函数作为元学习的上层任务,采用数据集对元学习模型进行训练,上层任务和下层任务交替进行迭代更新,完成元学习模型的训练。
元学习模块的双层任务优化可以采用如下公式表示:
其中,表示干净数据集,||表示求取集合中图像样本数量,xmeta表示高质量数据集中的图像样本,ymeta表示图像样本xmeta的标签,/>表示学生网络对图像样本xmeta的预测标签。
由于本发明提出是针对无源域适应的图像分类任务,无法获取无偏估计的干净数据集,因此本实施例中采用输入扰动的预测一致性损失作为元知识来指导元学习上层任务的学习,基于预测一致性损失的元学习双层优化任务的表达式如下:
其中,表示从数据集XT中随机抽取的子数据集,xD表示数据集/>中的图像样本,A(xD)表示对图像样本/>进行数据增强处理后的图像样本,/> 分别表示学生网络对数据样本xD和增强后数据样本A(xD)的预测标签,κ()表示求取KL散度,作为度量输入扰动的预测一致性损失。/>表示图像样本xi的预测标签损失L(w),/>表示图像样本xi对应的置信度权重。
S108:判断是否t<tmax,tmax表示预测的最大迭代次数,如果是,进入步骤S109,否则进入步骤S110。
S109:令t=t+1,采用当前学生网络的参数对教师网络的参数进行更新,返回步骤S105。
S110:利用目标域模型进行图像分类:
从最终的元学习模型中提取学生网络作为目标域模型,并利用该目标域模型对目标域数据图像进行分类。
为了更好地说明本发明的技术效果,采用具体实例对本发明进行实验验证。在本实验验证中,使用无源域适应领域常规测试基础数据集Office-31,其由三个不同域组成,包括由亚马逊商家图像域(Amazon),网络摄像头收集的低分辨率图像域(Webcam)以及单反相机拍摄的高解析度图像域(DSLR),该数据集包含了31类常见的办公物体,如笔记本电脑、文件柜、键盘等,共4652张图像。为了全面体现本发明的优势,设计了六种无源域适应任务,即Amazon→DSLR,Amazon→Webcam,DSLR→Webcam,DSLR→Amazon,Webcam→Amazon,Webcam→DSLR。
本次实验验证中设置4种对比方法,分别为ResNet50,DANN(Domain-adversarialtraining of neural networks),CDAN(Conditional adversarial domain adaptation)以及SHOT(Source hypothesis transfer for unsupervised domain adaptation)。
本发明由PyTorch实现,并在NVIDIARTX3090 GPU上进行了训练,本发明使用的自训练师生网络中教师网络和学生网络采用ResNet50。
表1是本实施例中本发明和对比方法在不同任务下的分类正确率统计表。
表1
如表1所示,从表1中的结果可以看出,本发明在Office-31数据集的6种自适应任务中均取得了最好的效果,从而验证本发明的有效性。
尽管上面对本发明说明性的具体实施方式进行了描述,以便于本技术领域的技术人员理解本发明,但应该清楚,本发明不限于具体实施方式的范围,对本技术领域的普通技术人员来讲,只要各种变化在所附的权利要求限定和确定的本发明的精神和范围内,这些变化是显而易见的,一切利用本发明构思的发明创造均在保护之列。