CN116452862A - 基于领域泛化学习的图像分类方法 - Google Patents
基于领域泛化学习的图像分类方法 Download PDFInfo
- Publication number
- CN116452862A CN116452862A CN202310325449.4A CN202310325449A CN116452862A CN 116452862 A CN116452862 A CN 116452862A CN 202310325449 A CN202310325449 A CN 202310325449A CN 116452862 A CN116452862 A CN 116452862A
- Authority
- CN
- China
- Prior art keywords
- domain
- generalization
- attention
- output
- network model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 82
- 238000012549 training Methods 0.000 claims abstract description 55
- 230000035945 sensitivity Effects 0.000 claims abstract description 41
- 230000006870 function Effects 0.000 claims abstract description 40
- 238000013527 convolutional neural network Methods 0.000 claims abstract description 8
- 238000004364 calculation method Methods 0.000 claims description 29
- 239000013598 vector Substances 0.000 claims description 29
- 238000011176 pooling Methods 0.000 claims description 24
- 230000007246 mechanism Effects 0.000 claims description 22
- 230000008569 process Effects 0.000 claims description 17
- 238000009826 distribution Methods 0.000 claims description 15
- 238000005457 optimization Methods 0.000 claims description 15
- 230000009466 transformation Effects 0.000 claims description 15
- 238000007781 pre-processing Methods 0.000 claims description 12
- 238000012545 processing Methods 0.000 claims description 9
- 238000000605 extraction Methods 0.000 claims description 8
- 230000008485 antagonism Effects 0.000 claims description 6
- 238000010586 diagram Methods 0.000 claims description 5
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 238000012935 Averaging Methods 0.000 claims description 3
- 238000010276 construction Methods 0.000 claims description 3
- 238000010606 normalization Methods 0.000 claims description 3
- 230000008521 reorganization Effects 0.000 claims description 3
- 238000006243 chemical reaction Methods 0.000 claims description 2
- 230000004927 fusion Effects 0.000 claims description 2
- 230000006872 improvement Effects 0.000 claims description 2
- 230000005012 migration Effects 0.000 abstract description 6
- 238000013508 migration Methods 0.000 abstract description 6
- 230000006978 adaptation Effects 0.000 description 4
- 238000011160 research Methods 0.000 description 3
- 230000004069 differentiation Effects 0.000 description 2
- 238000013526 transfer learning Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 239000003086 colorant Substances 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/84—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using probabilistic graphical models from image or video features, e.g. Markov models or Bayesian networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Physics & Mathematics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于领域泛化学习的图像分类方法,包括:将目标域数据和源域数据与两者对应的标签信息记录在目标域数据集DT和源域数据集DS中;使用卷积神经网络Mobile V2构建领域鉴别器,生成领域分数;对领域泛化网络模型中基本ResNet网络进行改进,并构建混合域注意力;使用局部最大灵敏度代替局部随机灵敏度;构建损失函数,训练并优化网络模型;将训练好的领域泛化网络模型在目标域数据集DT上计算图像分类准确率,评判当前网络模型的图像分类能力和泛化能力。本发明缓解了深度领域泛化方法普遍出现的由于虚假的相关性带来的负迁移问题,增强网络模型泛化能力,提高了图像分类准确率,进一步可以实现更灵活和精确的下游应用。
Description
技术领域
本发明涉及图像分类和领域泛化学习的技术领域,尤其是指一种基于领域泛化学习的图像分类方法。
背景技术
近年来,深度学习在计算机视觉领域中的优越表现使图像分类受到广泛关注,图像分类方法主要是利用足够的有标注数据准确地训练出高可用度的分类器模型的,然而对于实际图像分类应用场景来说,有标注数据的获得是非常困难的。而迁移学习缓解了图像分类任务中瓶颈。迁移学习的核心思想是合理利用旧知识和新知识(如数据、任务或模型)之间的相似性,构建一种从旧知识到新知识的迁移桥梁,从而可以更快更好地学习新知识,完成当前任务。在图像分类任务中,利用迁移学习思想,寻找与目标数据(目标领域数据)相近的有标注的数据(源领域数据)构建模型,利用有标记数据和目标数据之间的相似性,更快更好地完成当前图像分类任务。
目前,迁移学习的主要研究方向是领域自适应和领域泛化,且都有大量的研究。领域自适应强调的是利用源领域丰富的有监督知识来增强目标领域的训练,解决目标领域上的学习问题。领域自适应需要大量有标记的源领域数据和大量目标领域数据进行模型训练,然而在实际场景中,目标领域数据可能非常少或我们对目标领域数据一无所知,这时领域自适应就会受到很大局限。与领域适应不同的是,领域泛化考虑的是一个更有实际意义的研究场景,其假设目标领域的样本在训练过程中是不可用的,目的是在多个源领域中学习一个领域不变的模型,使其可以直接泛化到未知的目标领域上,从而实现有效的知识重用。域泛化的重点是强调训练后的模型对任何未知目标域都有一定的学习能力。因此,领域泛化更适用于现实世界的应用场景。
现阶段的领域泛化学习在图像分类中的方法可分为非深度领域泛化方法与深度领域泛化方法两类。非深度领域泛化方法侧重于从低维的图像中提取浅层特征,提取的图像特征所表达的语义信息有限,无法得到优秀的网络模型。深度领域泛化方法可以从复杂高维的图像中提取丰富的具有代表性的特征,训练出更理想的网络模型。然而大多数深度领域泛化方法都是基于常用的几个大规模预训练卷积神经网络模型提出的,由ImageNet预训练后的卷积神经网络在图像分类中表现出对于颜色和纹理更加偏好的现象,对于像手绘线稿,卡通漫画这类非以形状表征为主的图像来说,神经网络图像分类性能会较下降;而且大多数深度领域泛化方法忽略了网络模型所提取的特征是否所有的特征都有助于提高预测的准确性,关注所有的视觉特征,很容易导致网络模型对非相关特征过度关注,导致网络模型泛化能力下降。
发明内容
本发明的目的在于克服现有技术的缺点与不足,提出了一种基于领域泛化学习的图像分类方法,能够提高神经网络对形状的偏好,使网络模型提取的特征更加鲁棒,同时在特征提取过程中,提高网络模型对任务相关特征的关注,抑制任务不相关特征,缓解了虚假的相关性带来的负迁移问题,并进一步降低模型对输入的小扰动的敏感性,增强模型泛化能力,提高图像分类准确度。
为实现上述目的,本发明所提供的技术方案为:基于领域泛化学习的图像分类方法,包括以下步骤:
1)将图像数据集中所有图像数据预处理为张量格式,选择预处理后的数据集中一个域作为目标域数据,剩余所有域作为源域数据,将目标域数据和源域数据与两者的对应真实标签信息记录在目标域数据集DT和源域数据集DS中,其中,T表示目标域,S表示源域,真实标签信息包括当前数据的领域类别信息、语义类别信息和拼图排列类别信息;
2)使用轻量级的卷积神经网络Mobile V2构建领域泛化网络模型中的领域鉴别器,依据步骤1)中得到源域数据集DS训练卷积神经网络Mobile V2,并根据源域数据集DS中所有数据的领域类别信息生成领域分数Wdc,其中,Wdc是由领域鉴别器提取得到的二维向量数据;
3)对领域泛化网络模型中的基本ResNet网络进行改进,对基本ResNet网络进行三部分改进,第一部分是在基本ResNet网络的第一层卷积层conv1前加入领域鉴别器;第二部分是构建ResNet网络的输出层,分别是语义类别输出、领域类别输出和拼图排列类别输出;第三部分是在残差模块前后分别加入混合域注意力机制,具体为改进的CBAM注意力机制;对CBAM注意力机制的改进是将领域分数分别输入到通道注意力、空间注意力两个模块中,其中,所述通道注意力和空间注意力是CBAM的两个模块;将领域分数进行不同的维度转换后分别与CBAM的两个模块中提取得到的特征进行融合操作;
4)改进泛化误差模型LGEM,使用局部最大敏感度代替局部随机敏感度;所述敏感度是一种度量,测量训练样本zj与其Q邻域范围内不可见样本zu之间的输出差异;其中Q是一个超参数,Q邻域是超立方体;zj是源域数据集DS中第j个训练样本,zu是zj加了扰动后生成的样本且zu在zj的邻域范围内;
5)基于领域鉴别器和改进ResNet网络构建一种新型的领域泛化网络模型并训练,基于改进ResNet网络的输出层构建损失函数,并在损失函数中引入步骤4)中获得的局部最大敏感度,训练并优化该领域泛化网络模型;
6)利用训练好的领域泛化网络模型在目标域数据集DT上计算图像分类准确率,用来评判当前领域泛化网络模型的图像分类能力和泛化能力。
进一步,在步骤1)中,所述预处理是使用transforms对目标域数据集DT和源域数据集DS中图像数据分别进行预处理操作,包括图像数据亮度处理、图像数据饱和度处理和图像数据像素值归一化处理,并将图像数据转化为能够直接输入神经网络的张量格式;同时,源域数据集DS还会按照超参数b的比率,选取b%的图像数据,并将每个图像数据都分割成k块补丁,移动补丁重新组装成与原始图像尺寸相同的图像数据;打乱并重组会产生k!种排列方式,为了减少计算,选择P种拼图排列方式;在数据预处理同时,也将图像数据与对应真实标签信息记录在目标域数据集DT和源域数据集DS中,其中,当前数据的领域类别信息共有N类、当前数据的语义类别信息共有Y类、当前数据的拼图排列类别信息共有P类。
进一步,在步骤2)中,使用源域数据集DS中的图像数据作为输入,使用特征提取方法对图像数据进行特征提取,将提取后的特征图输入到领域鉴别器中,映射到领域类别信息,获得领域分数;领域分数Wdc计算如下:
式中,表示实数,z是来自源域数据集DS的训练样本,Gex(·)是特征提取器,Fdc(·)是领域鉴别器,领域分数的维度是1×N。
进一步,所述步骤3)包括以下步骤:
3.1)将基本ResNet网络的第一层卷积层conv1前加入领域鉴别器,用于获得领域分数Wdc;然后构建ResNet网络的输出层,在基本ResNet网络的语义类别输出的基础上,增加领域类别输出和拼图排列类别输出;最后,在残差模块前后分别加入混合域注意力机制,具体为改进的CBAM注意力机制;
3.2)构建混合域注意力机制中的领域通道注意力模块;使用源域数据集DS中的图像数据作为输入,先由ResNet网络中的第一层卷积层conv1对图像数据进行特征提取,得到的中间特征图X,然后将中间特征图X和领域分数Wdc同时输入到领域通道注意力模块;在领域通道注意力模块中,将输入的中间特征图X,分别经过最大值池化和平均值池化压缩中间特征图的空间维度,得到通道特征向量和/>其中,中间特征图X的维度是H×W×C,H表示中间特征图的高度,W表示中间特征图的宽度,C表示中间特征图的通道数,通道特征向量/>和/>的维度都是1×1×C;由于领域分数与通道特征向量维度不同,所以需要进行维度变换操作,将Wdc与/>和/>保持相同的特征维度1×1×C,然后将Wdc、/>和/>进行逐元素求和合并,以产生领域通道注意力特征图Wca,具体计算如下所示:
式中,表示实数,Gews(·)是逐元素求和合并操作,Gsh1(·)是维度变换操作,和/>是最大值池化和平均值池化后输出的通道特征向量,领域通道注意力特征图Wca的维度是1×1×C;
然后,将领域通道注意力特征图Wca与中间特征图X相乘,完成中间特征图加权操作,以实现自适应特征优化,具体计算如下所示:
式中,表示实数,Gwt(·)是加权操作,X'是领域通道注意力加权的中间特征图,维度是H×W×C;
3.3)构建混合域注意力机制中的领域空间注意力模块;将领域通道注意力加权的中间特征图X'和领域分数Wdc同时输入领域空间注意力模块;在领域空间注意力模块中,将输入的领域通道注意力加权的中间特征图X',分别经过最大值池化和平均值池化压缩通道维度,得到空间特征向量和/>其中,领域通道注意力加权的中间特征图X'的维度是H×W×C,空间特征向量/>和/>的维度都是H×W×1;由于领域分数与空间特征向量维度不同,所以需要进行维度变换操作,将Wdc与/>和/>保持相同的特征维度H×W×1,然后将Wdc、/>和/>基于通道方向进行连接操作,再经过一个卷积操作将H×W×3的特征图降维为1个通道的领域空间注意力特征图Wsa,具体计算如下所示:
式中,表示实数,Gcov(·)是卷积操作,Gsh2(·)是维度变换操作,/>和/>是最大值池化和平均值池化后输出的空间特征向量,领域空间注意力特征图Wsa的维度是H×W×1;
然后,将领域空间注意力特征图Wsa与中间特征图X相乘,完成中间特征图加权操作,以实现自适应特征优化,具体计算如下所示:
式中,表示实数,Gwt(·)是加权操作,X”是领域空间注意力加权的中间特征图,维度是H×W×C。
进一步,在步骤4)中,所述Q是一个给定的超参数,其中扰动Δz的取值范围表示为Δzi为扰动Δz的第i个输入特征,n为输入特征的数量;源域数据集DS中第j个训练样本zj在Q邻域范围内对应的不可见样本数据集合SQ(zj)表示为SQ(zj)={zu|zu=zj+Δz;|Δzi|<Q};训练样本zj和不可见样本zu在Q邻域范围内的输出差异为局部随机敏感度ESM,计算如下:
式中,M表示Q邻域内获取的扰动的个数,m表示第m个扰动,表示训练样本的输出分布;
由于局部随机敏感度是测量多个不可见样本与训练样本之间输出差异的平均值,平均化会忽略敏感度极值情况,因此使用局部最大敏感度代替局部随机敏感度;使用对抗方法,寻找一个能够让不可见样本与训练样本之间的输出差异最大化的对抗扰动radv;所述对抗扰动对于对抗的一方来说,是指加入扰动后预测不可见样本的结果和真实的结果差别最大化,而对于模型来说,是指加入扰动后预测不可见样本的结果和真实的结果一致,差别最小化;所以对抗扰动radv计算如下:
radv:=argmaxD[p(yj|zj),p(yj|zj+r,θ)],||r||≤ε
LLMS=D[p(yj|zj),p(yj|zj+radv,θ)]
式中,对抗扰动radv的求解是通过最大化加入扰动后的不可见样本与训练样本之间的输出差异,其中r表示扰动,θ表示领域泛化网络模型的参数,ε表示对抗方向的规范约束,控制扰动选取的界限,D[p(yj|zj),p(yj|zj+r,θ)]表示两个分布之间的Kullback-Leibler误差,其中p(yj|zj)表示语义类别信息为y的第j个训练样本xj的真实分布,p(yj|zj+r,θ)表示语义类别信息为y的第j个训练样本xj在添加扰动r后由模型参数为θ的领域泛化网络模型产生的输出分布;yj表示源域数据集DS中第j个训练样本zj的语义类别信息;LLMS表示局部最大敏感度,p(yj|zj+radv,θ)表示语义类别信息为y的第j个训练样本xj在添加对抗扰动radv后由模型参数为θ的领域泛化网络模型产生的输出分布;
对于对抗扰动的计算是一个优化问题,很难找到这样一个精确的值表示对抗扰动,所以用以下的线性估计方法在正梯度方向上得到对抗扰动radv的近似解:
式中,g表示正梯度,表示对加入扰动后的不可见样本与训练样本之间的输出差异计算梯度,其中梯度是指函数的微分,代表着函数在给定点的切线的斜率,正梯度表示函数在给定点的上升最快的方向。
进一步,在步骤5)中,总损失函数包括基于改进ResNet网络输出层构建的三种损失函数和步骤4)中获得的局部最大敏感度,总损失函数Ltotal计算如下:
Ltotal=Lclf+αLjig+βLdc+γLLMS
式中,α,β,γ是三个超参数,用来平衡所有的损失函数;Lclf表示数据集语义类别输出损失函数,Ldc表示领域类别输出损失函数,Ljig表示拼图排列类别输出损失函数,LLMS表示局部最大敏感度,Lclf、Ldc和Ljig都是基于交叉熵损失计算,而LLMS是基于Kullback-Leibler计算的;其中,所述损失函数是指计算领域泛化网络模型对输入图像数据的预测标签信息与真实标签信息之间的差距;真实标签信息在步骤1)中,经过预处理时已经保存在数据集中;得到损失函数后,再通过梯度下降算法求解模型参数,最终得到训练并优化的领域泛化网络模型,该网络模型可以完成图像分类任务。
进一步,在步骤6)中,将训练好的领域泛化网络模型直接泛化到目标域数据集DT上,在目标域数据集DT上完成图像分类任务并计算图像分类准确率,不需要再进行领域泛化网络模型的训练和优化,通过图像分类准确率来评判当前领域泛化网络模型的图像分类能力和泛化能力。
本发明与现有技术相比,具有如下优点与有益效果:
1、本发明方法增强了领域泛化网络模型对形状的偏好,使网络模型提取的特征更加鲁棒,进一步提高网络模型图像分类能力。
2、本发明方法提出了混合域注意力机制,通过在特征提取过程中,提高领域泛化网络模型对任务相关特征的关注,抑制任务不相关特征,来缓解了虚假的相关性带来的负迁移问题。
3、本发明方法提出局部最大敏感度,通过使网络模型对输入的小扰动鲁棒,降低网络模型对输入的小扰动的敏感性,进而提高网络模型图像分类能力和泛化能力。
4、本发明方法在图像分类任务中具有广泛的使用空间,操作简单、适应性强,具有广阔的应用前景。
总之,本发明方法能够通过多个源域数据学习一个领域泛化网络模型,可以直接泛化到未知的目标领域上,从而实现有效的知识重用,缓解了实际图像分类应用场景中的瓶颈。本发明缓解了图像分类方法中深度领域泛化方法普遍出现的由于虚假的相关性带来的负迁移问题,同时降低网络模型对输入的小扰动的敏感性,增强网络模型泛化能力,提高了图像分类准确率,进一步可以实现更灵活和精确的下游应用。
附图说明
图1为本发明逻辑流程示意图。
图2为本发明领域泛化模型网络结构图。
图3为本发明混合域注意力机制的结构图。
具体实施方式
下面结合实施例及附图对本发明作进一步详细的描述,但本发明的实施方式不限于此。
如图1所示,本实施例公开了一种基于领域泛化学习的图像分类方法,其包括以下步骤:
1)使用领域泛化公开数据集PACS,将其中一个域作为目标域数据,剩余所有域作为源域数据,此时,目标域数据包含1个域,源域数据包含3个域,公开数据集不限于PACS;将目标域数据和源域数据与两者的对应真实标签信息记录在目标域数据集DT和源域数据集DS中,其中真实标签信息包括当前数据的领域类别信息、语义类别信息和拼图排列类别信息;对于数据集PACS,目标域数据集领域类别信息共有N=1类,源域数据集领域类别信息共有N=3类;目标域数据集和源域数据集数据语义类别信息相同,数据语义类别信息共有Y=7类;目标域数据集拼图排列类别信息共有P=0类,源域数据集拼图排列类别信息共有P类;使用transforms预处理方法对目标域数据集DT和源域数据集DS中图像数据分别进行预处理操作,包括图像数据亮度处理、图像数据饱和度处理、图像数据像素值归一化处理,并将图像数据转化为225×225的张量格式;同时,源域数据集DS还会按照超参数b的比率,选取b%的图像数据,并将每个图像数据都分割成9块补丁,移动补丁重新组装成与原始图像尺寸相同的图像数据;打乱并重组会产生9!种排列方式,为了减少计算,选择P=30种拼图排列方式。
2)如图2所示,构建领域泛化模型网络,使用轻量级的卷积神经网络Mobile V2构建领域鉴别器;使用源域数据集DS中的图像数据作为输入,使用特征提取方法对图像数据进行特征提取,将提取后的特征图输入到领域鉴别器中,映射到领域类别信息,获得领域分数;领域分数Wdc计算如下:
式中,表示实数,z是来自源域数据集DS的训练样本,Gex(·)是特征提取器,Fdc(·)是领域鉴别器,领域分数的维度是1×N。
3)如图2所示,构建领域泛化模型网络,对领域泛化网络模型中的基本ResNet网络进行改进;领域泛化模型网络使用ResNet网络作为骨干网络,由基本ResNet网络和混合域注意力机制(具体为改进的CBAM注意力机制)组成;其中骨干网络不限于ResNet网络,具体实施过程中也在AlexNet网络上进行实验验证;改进ResNet网络包括以下步骤:
3.1)将基本ResNet网络的第一层卷积层conv1前加入领域鉴别器,用于获得领域分数Wdc;然后构建ResNet网络的输出层,在基本ResNet网络的语义类别输出的基础上,增加领域类别输出和拼图排列类别输出;最后,在残差模块前后分别加入混合域注意力机制,具体为改进的CBAM注意力机制;
3.2)构建混合域注意力机制中的领域通道注意力模块;使用源域数据集DS中的图像数据作为输入,先由ResNet网络中第一层卷积层conv1对图像数据进行特征提取,得到的中间特征图X,然后将中间特征图X和领域分数Wdc同时输入到领域通道注意力模块;如图3所示,在领域通道注意力模块中,将输入的中间特征图X,分别经过最大值池化和平均值池化压缩中间特征图的空间维度,得到通道特征向量和/>其中,中间特征图X的维度是H×W×C,其中H表示中间特征图的高度,W表示中间特征图的宽度,C表示中间特征图的通道数,通道特征向量/>和/>的维度都是1×1×C;由于领域分数与通道特征向量维度不同,所以需要进行维度变换操作,将Wdc与/>和/>保持相同的特征维度1×1×C,然后将Wdc、/>和/>进行逐元素求和合并,以产生领域通道注意力特征图Wca,具体计算如下所示:
式中,表示实数,Gews(·)是逐元素求和合并操作,Gsh1(·)是维度变换操作,和/>是最大值池化和平均值池化后输出的通道特征向量,领域通道注意力特征图Wca的维度是1×1×C;
然后,将领域通道注意力特征图Wca与中间特征图X相乘,完成中间特征图加权操作,以实现自适应特征优化,具体计算如下所示:
式中,表示实数,Gwt(·)是加权操作,X'是领域通道注意力加权的中间特征图,维度是H×W×C;
3.3)构建混合域注意力机制中的领域空间注意力模块;如图3所示,将领域通道注意力加权的中间特征图X'和领域分数Wdc,同时输入领域空间注意力模块;在领域空间注意力模块中,将输入的领域通道注意力加权的中间特征图X',分别经过最大值池化和平均值池化压缩通道维度,得到空间特征向量和/>其中,领域通道注意力加权的中间特征图X'的维度是H×W×C,空间特征向量/>和/>的维度都是H×W×1;由于领域分数与空间特征向量维度不同,所以需要进行维度变换操作,将Wdc与/>和/>保持相同的特征维度H×W×1,然后将Wdc、/>和/>基于通道方向进行连接操作,再经过一个卷积操作将H×W×3的特征图降维为1个通道的领域空间注意力特征图Wsa,具体计算如下所示:
式中,表示实数,Gcov(·)是卷积操作,Gsh2(·)是维度变换操作,/>和/>是最大值池化和平均值池化后输出的空间特征向量,领域空间注意力特征图Wsa的维度是H×W×1;
然后,将领域空间注意力特征图Wsa与中间特征图X相乘,完成中间特征图加权操作,以实现自适应特征优化,具体计算如下所示:
式中,表示实数,Gwt(·)是加权操作,,X”是领域空间注意力加权的中间特征图,维度是H×W×C。
4)改进泛化误差模型LGEM,使用局部最大敏感度代替局部随机敏感度;所述敏感度是一种度量,测量训练样本zj与其Q邻域范围内不可见样本zu之间的输出差异;Q是一个给定的超参数,此时Q=1.0;其中扰动Δz的取值范围表示为Δzi为扰动Δz的第i个输入特征,n为输入特征的数量;源域数据集DS中第j个训练样本zj在Q邻域范围内对应的不可见样本数据集合SQ(zj)表示为SQ(zj)={zu|zu=zj+Δz;|Δzi|<Q};训练样本zj和不可见样本zu在Q邻域范围内的输出差异为局部随机敏感度ESM,计算如下:
式中,M表示Q邻域内获取的扰动的个数,m表示第m个扰动,表示训练样本的输出分布;
由于局部随机敏感度是测量多个不可见样本与训练样本之间输出差异的平均值,平均化会忽略敏感度极值情况,因此使用局部最大敏感度代替局部随机敏感度;使用对抗方法,寻找一个能够让不可见样本与训练样本之间的输出差异最大化的对抗扰动radv;所述对抗扰动对于对抗的一方来说,是指加入扰动后预测不可见样本的结果和真实的结果差别最大化,而对于模型来说,是指加入扰动后预测不可见样本的结果和真实的结果一致,差别最小化;所以对抗扰动radv计算如下:
radv:=argmaxD[p(yj|zj),p(yj|zj+r,θ)],||r||≤ε
LLMS=D[p(yj|zj),p(yj|zj+radv,θ)]
式中,对抗扰动radv的求解是通过最大化加入扰动后的不可见样本与训练样本之间的输出差异,其中r表示扰动,θ表示领域泛化网络模型的参数,ε表示对抗方向的规范约束,控制扰动选取的界限,此时ε=1.0,D[p(yj|zj),p(yj|zj+r,θ)]表示两个分布之间的Kullback-Leibler误差,其中p(yj|zj)表示语义类别信息为y的第j个训练样本xj的真实分布,p(yj|zj+r,θ)表示语义类别信息为y的第j个训练样本xj在添加扰动r后由模型参数为θ的领域泛化网络模型产生的输出分布;yj表示源域数据集DS中第j个训练样本zj的语义类别信息;LLMS表示局部最大敏感度,p(yj|zj+radv,θ)表示语义类别信息为y的第j个训练样本xj在添加对抗扰动radv后由模型参数为θ的领域泛化网络模型产生的输出分布;
对于对抗扰动的计算是一个优化问题,很难找到这样一个精确的值表示对抗扰动,所以用以下的线性估计方法在正梯度方向上得到对抗扰动radv的近似解:
式中,g表示正梯度,表示对加入扰动后的不可见样本与训练样本之间的输出差异计算梯度,其中梯度是指函数的微分,代表着函数在给定点的切线的斜率,正梯度表示函数在给定点的上升最快的方向。
5)领域泛化网络模型的总损失函数包括基于改进ResNet网络输出层构建的三种损失函数和步骤4)中获得的局部最大敏感度,总损失函数Ltotal计算如下:
Ltotal=Lclf+αLjig+βLdc+γLLMS
式中,α,β,γ是三个超参数,用来平衡所有的损失函数;Lclf表示数据集语义类别输出损失函数,Ldc表示领域类别输出损失函数,Ljig表示拼图排列类别输出损失函数,LLMS表示局部最大敏感度,Lclf,Ldc和Ljig都是基于交叉熵损失计算,而LLMS是基于Kullback-Leibler计算的;其中,所述损失函数是指计算领域泛化网络模型对输入图像数据的预测标签信息与真实标签信息之间的差距;得到损失函数后,再通过梯度下降算法求解模型参数,最终得到训练并优化的领域泛化网络模型,该网络模型可以完成图像分类任务。
6)将训练好的领域泛化网络模型直接泛化到目标域数据集DT上,在目标域数据集DT上完成图像分类任务并计算图像分类准确率,不需要再进行领域泛化网络模型的训练和优化,通过图像分类准确率来评判当前领域泛化网络模型的图像分类能力和泛化能力。
上述实施例为本发明较佳的实施方式,但本发明的实施方式并不受上述实施例的限制,其他的任何未背离本发明的精神实质与原理下所作的改变、修饰、替代、组合、简化,均应为等效的置换方式,都包含在本发明的保护范围之内。
Claims (7)
1.基于领域泛化学习的图像分类方法,其特征在于,包括以下步骤:
1)将图像数据集中所有图像数据预处理为张量格式,选择预处理后的数据集中一个域作为目标域数据,剩余所有域作为源域数据,将目标域数据和源域数据与两者的对应真实标签信息记录在目标域数据集DT和源域数据集DS中,其中,T表示目标域,S表示源域,真实标签信息包括当前数据的领域类别信息、语义类别信息和拼图排列类别信息;
2)使用轻量级的卷积神经网络Mobile V2构建领域泛化网络模型中的领域鉴别器,依据步骤1)中得到源域数据集DS训练卷积神经网络Mobile V2,并根据源域数据集DS中所有数据的领域类别信息生成领域分数Wdc,其中,Wdc是由领域鉴别器提取得到的二维向量数据;
3)对领域泛化网络模型中的基本ResNet网络进行改进,对基本ResNet网络进行三部分改进,第一部分是在基本ResNet网络的第一层卷积层conv1前加入领域鉴别器;第二部分是构建ResNet网络的输出层,分别是语义类别输出、领域类别输出和拼图排列类别输出;第三部分是在残差模块前后分别加入混合域注意力机制,具体为改进的CBAM注意力机制;对CBAM注意力机制的改进是将领域分数分别输入到通道注意力、空间注意力两个模块中,其中,所述通道注意力和空间注意力是CBAM的两个模块;将领域分数进行不同的维度转换后分别与CBAM的两个模块中提取得到的特征进行融合操作;
4)改进泛化误差模型LGEM,使用局部最大敏感度代替局部随机敏感度;所述敏感度是一种度量,测量训练样本zj与其Q邻域范围内不可见样本zu之间的输出差异;其中Q是一个超参数,Q邻域是超立方体;zj是源域数据集DS中第j个训练样本,zu是zj加了扰动后生成的样本且zu在zj的邻域范围内;
5)基于领域鉴别器和改进ResNet网络构建一种新型的领域泛化网络模型并训练,基于改进ResNet网络的输出层构建损失函数,并在损失函数中引入步骤4)中获得的局部最大敏感度,训练并优化该领域泛化网络模型;
6)利用训练好的领域泛化网络模型在目标域数据集DT上计算图像分类准确率,用来评判当前领域泛化网络模型的图像分类能力和泛化能力。
2.根据权利要求1所述的基于领域泛化学习的图像分类方法,其特征在于,在步骤1)中,所述预处理是使用transforms对目标域数据集DT和源域数据集DS中图像数据分别进行预处理操作,包括图像数据亮度处理、图像数据饱和度处理和图像数据像素值归一化处理,并将图像数据转化为能够直接输入神经网络的张量格式;同时,源域数据集DS还会按照超参数b的比率,选取b%的图像数据,并将每个图像数据都分割成k块补丁,移动补丁重新组装成与原始图像尺寸相同的图像数据;打乱并重组会产生k!种排列方式,为了减少计算,选择P种拼图排列方式;在数据预处理同时,也将图像数据与对应真实标签信息记录在目标域数据集DT和源域数据集DS中,其中,当前数据的领域类别信息共有N类、当前数据的语义类别信息共有Y类、当前数据的拼图排列类别信息共有P类。
3.根据权利要求2所述的基于领域泛化学习的图像分类方法,其特征在于,在步骤2)中,使用源域数据集DS中的图像数据作为输入,使用特征提取方法对图像数据进行特征提取,将提取后的特征图输入到领域鉴别器中,映射到领域类别信息,获得领域分数;领域分数Wdc计算如下:
式中,表示实数,z是来自源域数据集DS的训练样本,Gex(·)是特征提取器,Fdc(·)是领域鉴别器,领域分数的维度是1×N。
4.根据权利要求3所述的基于领域泛化学习的图像分类方法,其特征在于,所述步骤3)包括以下步骤:
3.1)将基本ResNet网络的第一层卷积层conv1前加入领域鉴别器,用于获得领域分数Wdc;然后构建ResNet网络的输出层,在基本ResNet网络的语义类别输出的基础上,增加领域类别输出和拼图排列类别输出;最后,在残差模块前后分别加入混合域注意力机制,具体为改进的CBAM注意力机制;
3.2)构建混合域注意力机制中的领域通道注意力模块;使用源域数据集DS中的图像数据作为输入,先由ResNet网络中的第一层卷积层conv1对图像数据进行特征提取,得到的中间特征图X,然后将中间特征图X和领域分数Wdc同时输入到领域通道注意力模块;在领域通道注意力模块中,将输入的中间特征图X,分别经过最大值池化和平均值池化压缩中间特征图的空间维度,得到通道特征向量和/>其中,中间特征图X的维度是H×W×C,H表示中间特征图的高度,W表示中间特征图的宽度,C表示中间特征图的通道数,通道特征向量和/>的维度都是1×1×C;由于领域分数与通道特征向量维度不同,所以需要进行维度变换操作,将Wdc与/>和/>保持相同的特征维度1×1×C,然后将Wdc、/>和/>进行逐元素求和合并,以产生领域通道注意力特征图Wca,具体计算如下所示:
式中,表示实数,Gews(·)是逐元素求和合并操作,Gsh1(·)是维度变换操作,/>和是最大值池化和平均值池化后输出的通道特征向量,领域通道注意力特征图Wca的维度是1×1×C;
然后,将领域通道注意力特征图Wca与中间特征图X相乘,完成中间特征图加权操作,以实现自适应特征优化,具体计算如下所示:
式中,表示实数,Gwt(·)是加权操作,X'是领域通道注意力加权的中间特征图,维度是H×W×C;
3.3)构建混合域注意力机制中的领域空间注意力模块;将领域通道注意力加权的中间特征图X'和领域分数Wdc同时输入领域空间注意力模块;在领域空间注意力模块中,将输入的领域通道注意力加权的中间特征图X',分别经过最大值池化和平均值池化压缩通道维度,得到空间特征向量和/>其中,领域通道注意力加权的中间特征图X'的维度是H×W×C,空间特征向量/>和/>的维度都是H×W×1;由于领域分数与空间特征向量维度不同,所以需要进行维度变换操作,将Wdc与/>和/>保持相同的特征维度H×W×1,然后将Wdc、/>和/>基于通道方向进行连接操作,再经过一个卷积操作将H×W×3的特征图降维为1个通道的领域空间注意力特征图Wsa,具体计算如下所示:
式中,表示实数,Gcov(·)是卷积操作,Gsh2(·)是维度变换操作,/>和/>是最大值池化和平均值池化后输出的空间特征向量,领域空间注意力特征图Wsa的维度是H×W×1;
然后,将领域空间注意力特征图Wsa与中间特征图X相乘,完成中间特征图加权操作,以实现自适应特征优化,具体计算如下所示:
式中,表示实数,Gwt(·)是加权操作,X”是领域空间注意力加权的中间特征图,维度是H×W×C。
5.根据权利要求4所述的基于领域泛化学习的图像分类方法,其特征在于,在步骤4)中,所述Q是一个给定的超参数,其中扰动Δz的取值范围表示为Δzi为扰动Δz的第i个输入特征,n为输入特征的数量;源域数据集DS中第j个训练样本zj在Q邻域范围内对应的不可见样本数据集合SQ(zj)表示为SQ(zj)={zu|zu=zj+Δz;|Δzi|<Q};训练样本zj和不可见样本zu在Q邻域范围内的输出差异为局部随机敏感度ESM,计算如下:
式中,M表示Q邻域内获取的扰动的个数,m表示第m个扰动,表示训练样本的输出分布;
由于局部随机敏感度是测量多个不可见样本与训练样本之间输出差异的平均值,平均化会忽略敏感度极值情况,因此使用局部最大敏感度代替局部随机敏感度;使用对抗方法,寻找一个能够让不可见样本与训练样本之间的输出差异最大化的对抗扰动radv;所述对抗扰动对于对抗的一方来说,是指加入扰动后预测不可见样本的结果和真实的结果差别最大化,而对于模型来说,是指加入扰动后预测不可见样本的结果和真实的结果一致,差别最小化;所以对抗扰动radv计算如下:
radv:=arg max D[p(yj|zj),p(yj|zj+r,θ)],||r||≤ε
LLMS=D[p(yj|zj),p(yj|zj+radv,θ)]
式中,对抗扰动radv的求解是通过最大化加入扰动后的不可见样本与训练样本之间的输出差异,其中r表示扰动,θ表示领域泛化网络模型的参数,ε表示对抗方向的规范约束,控制扰动选取的界限,D[p(yj|zj),p(yj|zj+r,θ)]表示两个分布之间的Kullback-Leibler误差,其中p(yj|zj)表示语义类别信息为y的第j个训练样本xj的真实分布,p(yj|zj+r,θ)表示语义类别信息为y的第j个训练样本xj在添加扰动r后由模型参数为θ的领域泛化网络模型产生的输出分布;yj表示源域数据集DS中第j个训练样本zj的语义类别信息;LLMS表示局部最大敏感度,p(yj|zj+radv,θ)表示语义类别信息为y的第j个训练样本xj在添加对抗扰动radv后由模型参数为θ的领域泛化网络模型产生的输出分布;
对于对抗扰动的计算是一个优化问题,很难找到这样一个精确的值表示对抗扰动,所以用以下的线性估计方法在正梯度方向上得到对抗扰动radv的近似解:
式中,g表示正梯度,表示对加入扰动后的不可见样本与训练样本之间的输出差异计算梯度,其中梯度是指函数的微分,代表着函数在给定点的切线的斜率,正梯度表示函数在给定点的上升最快的方向。
6.根据权利要求5所述的基于领域泛化学习的图像分类方法,其特征在于,在步骤5)中,总损失函数包括基于改进ResNet网络输出层构建的三种损失函数和步骤4)中获得的局部最大敏感度,总损失函数Ltotal计算如下:
Ltotal=Lclf+αLjig+βLdc+γLLMS
式中,α,β,γ是三个超参数,用来平衡所有的损失函数;Lclf表示数据集语义类别输出损失函数,Ldc表示领域类别输出损失函数,Ljig表示拼图排列类别输出损失函数,LLMS表示局部最大敏感度,Lclf、Ldc和Ljig都是基于交叉熵损失计算,而LLMS是基于Kullback-Leibler计算的;其中,所述损失函数是指计算领域泛化网络模型对输入图像数据的预测标签信息与真实标签信息之间的差距;真实标签信息在步骤1)中,经过预处理时已经保存在数据集中;得到损失函数后,再通过梯度下降算法求解模型参数,最终得到训练并优化的领域泛化网络模型,该网络模型可以完成图像分类任务。
7.根据权利要求6所述的基于领域泛化学习的图像分类方法,其特征在于:在步骤6)中,将训练好的领域泛化网络模型直接泛化到目标域数据集DT上,在目标域数据集DT上完成图像分类任务并计算图像分类准确率,不需要再进行领域泛化网络模型的训练和优化,通过图像分类准确率来评判当前领域泛化网络模型的图像分类能力和泛化能力。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310325449.4A CN116452862A (zh) | 2023-03-30 | 2023-03-30 | 基于领域泛化学习的图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310325449.4A CN116452862A (zh) | 2023-03-30 | 2023-03-30 | 基于领域泛化学习的图像分类方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116452862A true CN116452862A (zh) | 2023-07-18 |
Family
ID=87129521
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310325449.4A Pending CN116452862A (zh) | 2023-03-30 | 2023-03-30 | 基于领域泛化学习的图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116452862A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116883681A (zh) * | 2023-08-09 | 2023-10-13 | 北京航空航天大学 | 一种基于对抗生成网络的域泛化目标检测方法 |
CN117115567A (zh) * | 2023-10-23 | 2023-11-24 | 南方科技大学 | 基于特征调整的域泛化图像分类方法、系统、终端及介质 |
-
2023
- 2023-03-30 CN CN202310325449.4A patent/CN116452862A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116883681A (zh) * | 2023-08-09 | 2023-10-13 | 北京航空航天大学 | 一种基于对抗生成网络的域泛化目标检测方法 |
CN116883681B (zh) * | 2023-08-09 | 2024-01-30 | 北京航空航天大学 | 一种基于对抗生成网络的域泛化目标检测方法 |
CN117115567A (zh) * | 2023-10-23 | 2023-11-24 | 南方科技大学 | 基于特征调整的域泛化图像分类方法、系统、终端及介质 |
CN117115567B (zh) * | 2023-10-23 | 2024-03-26 | 南方科技大学 | 基于特征调整的域泛化图像分类方法、系统、终端及介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109949317B (zh) | 基于逐步对抗学习的半监督图像实例分割方法 | |
CN111639692B (zh) | 一种基于注意力机制的阴影检测方法 | |
CN111583263B (zh) | 一种基于联合动态图卷积的点云分割方法 | |
CN110717526B (zh) | 一种基于图卷积网络的无监督迁移学习方法 | |
CN116452862A (zh) | 基于领域泛化学习的图像分类方法 | |
CN112446423B (zh) | 一种基于迁移学习的快速混合高阶注意力域对抗网络的方法 | |
CN111008639B (zh) | 一种基于注意力机制的车牌字符识别方法 | |
CN113326731A (zh) | 一种基于动量网络指导的跨域行人重识别算法 | |
CN110598018B (zh) | 一种基于协同注意力的草图图像检索方法 | |
CN111898736A (zh) | 基于属性感知的高效行人重识别方法 | |
CN111639564A (zh) | 一种基于多注意力异构网络的视频行人重识别方法 | |
Han et al. | Unsupervised semantic aggregation and deformable template matching for semi-supervised learning | |
CN116311483B (zh) | 基于局部面部区域重构和记忆对比学习的微表情识别方法 | |
CN115424177A (zh) | 一种基于增量学习的孪生网络目标跟踪的方法 | |
CN112507800A (zh) | 一种基于通道注意力机制和轻型卷积神经网络的行人多属性协同识别方法 | |
CN111126155B (zh) | 一种基于语义约束生成对抗网络的行人再识别方法 | |
CN113222072A (zh) | 基于K-means聚类和GAN的肺部X光图像分类方法 | |
CN113807214B (zh) | 基于deit附属网络知识蒸馏的小目标人脸识别方法 | |
CN114780767A (zh) | 一种基于深度卷积神经网络的大规模图像检索方法及系统 | |
CN112749734B (zh) | 一种基于可迁移注意力机制的领域自适应的目标检测方法 | |
CN114495004A (zh) | 一种基于无监督跨模态的行人重识别方法 | |
CN114492581A (zh) | 基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 | |
CN112528077A (zh) | 基于视频嵌入的视频人脸检索方法及系统 | |
CN116645562A (zh) | 一种细粒度伪造图像的检测方法及其模型训练方法 | |
CN114911967B (zh) | 一种基于自适应域增强的三维模型草图检索方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |