CN114065861A - 基于对比对抗学习的领域自适应方法及装置 - Google Patents

基于对比对抗学习的领域自适应方法及装置 Download PDF

Info

Publication number
CN114065861A
CN114065861A CN202111363731.9A CN202111363731A CN114065861A CN 114065861 A CN114065861 A CN 114065861A CN 202111363731 A CN202111363731 A CN 202111363731A CN 114065861 A CN114065861 A CN 114065861A
Authority
CN
China
Prior art keywords
classifier
domain
loss
data
parameters
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111363731.9A
Other languages
English (en)
Inventor
孙艳丰
陈亮
王少帆
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing University of Technology
Original Assignee
Beijing University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing University of Technology filed Critical Beijing University of Technology
Priority to CN202111363731.9A priority Critical patent/CN114065861A/zh
Publication of CN114065861A publication Critical patent/CN114065861A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

基于对比对抗学习的领域自适应方法及装置,在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,固定分类器C1和C2中的参数,使用Ldis更新特征提取器中的参数,在这一步骤中保留了自适应损失项。分类器C1和C2分别使用不同的数据增强方式的特征,因此保障了分类器的多样性,使得双分类器能够更高效的找出处于分类边界的样本,使得模型学习到的特征含有更多有效信息,从而较好地解决无监督领域自适应问题,在传统基于双分类器对抗方法的基础上,不仅考虑分类器在目标域上的决策边界,同时也进一步关注域间差异。

Description

基于对比对抗学习的领域自适应方法及装置
技术领域
本发明涉及计算机视觉的技术领域,尤其涉及一种基于对比对抗学习的领域自适应方法,以及基于对比对抗学习的领域自适应装置。
背景技术
本发明重点解决无监督领域自适应的图像分类问题,将深度网络和领域自适应问题相结合。通过在特征空间对样本的特征进行处理,减小源域和目标域分布的差异,使得源域上学到的知识同样可以作用于目标域。深度无监督领域自适应作为一个备受关注的研究领域,已有大量的学者参与了这个领域的研究工作。目前,深度无监督领域自适应方法主要可以分为三类,分别是:基于分布距离度量的方法,基于域对抗的方法以及基于重构误差的方法。
(1)基于分布距离度量的方法:
这类方法的核心思想是通过最小化不同域之间的分布距离来实现不同域数据之间的对齐。目前最大平均差异(MMD)、相关性对齐(CORAL)和Wasserstein度量这三类度量分布距离的方式被广泛应用在这类算法中。最大平均差异(MMD)最初用于检验两个分布是否相同,现在一般用于度量两个分布之间的相似性。该度量通过寻找在样本空间上的一个连续函数,求不同分布的样本在该函数上的函数值的均值来求解两个分布对应于这一函数的平均差异。如果平均值不同,那么样本很可能不是来自同一个分布,因为可以确定当且仅当两个分布相等时,他们之间的MMD距离为零。将这种思想用于深度自适应方法时,可以通过对深度网络进行最小化特征分布之间的MMD距离的约束,来实现减小特征分布之间的MMD距离,并达到最小化深度网络对应层之间的输出特征之间的分布差异的目的。考虑到深度网络具有多层结构,其中每一层都会输出相应的特征。为了探索MMD距离约束与不同层间输出之间的适配性,Longetal.提出了多项式核MMD(MK-MMD)算法。此外,Bousmalis等人在探索了基于MMD的领域自适应方法的有效性。基于相关性对齐(CORAL)的自适应方法与MK-MMD方法相似,更多的考虑了整个域的空间结构,从数据波动性的角度考虑域间分布的对齐。这类方法通过减小源特征和目标特征的二阶统计量(协方差)之间的距离,使得数据的波动性相似,通过这种方式使位于两个域的不同分布的结构较为一致,从而减小两个分布之间的差异。由最优运输问题定义的Wasserstein距离也叫做推土机距离(EMD),用于度量不同分布之间的距离。这类方法通过求解最优运输规划问题最小化域分布差异。
(2)基于域对抗的方法:
学者们将对抗的思想引入DA,得到了基于域对抗方法的领域自适应方法。基于对抗的领域自适应方法引入域判别器,成功的将领域自适应问题和对抗网络结合在一起。域判别器最大化域分布差异,特征提取器最小化域分布差异。这类方法通过特征提取器和域判别器之间的对抗,将来自源域和目标域的数据投影到了一个公共空间中,得到了不同域数据在该空间中的域不变表示,并利用这些域不变表示特征实现域间数据差异的消除。这种基于域对抗的方法通常需要通过交替迭代的方式寻找最优解。为了减少模型的时间复杂度,通过引入梯度反转层,使得网络成为一个端到端的模型,来实现模型复杂度的降低。此外,Shen等人使用Wasserstein距离约束域判别器进行域对抗学习,减小数据分布之间的差异,取得了良好的成果。
(3)基于误差重构的方法
相对于减小数据分布之间的差异,该方法假设可以获得样本分类信息的特征,同时该特征可用于重构原始数据。这类方法需要编码器和解码器。编码器对源域数据编码,而后分类器对该特征进行分类,这样使得编码器生成的特征能够区分源域的样本(即是一个比较好的特征),对于目标域特征用解码器解码,尽量还原目标域的样本。这样得到的特征所在的特征空间在源域和目标域样本上相近。
上述方法都有其各自的优势,但也普遍存在待解决的难题。即如何在自适应过程中充分挖掘源域样本有效信息,有效防止误匹配以及如何在自适应过程中适配底层的特征。
发明内容
为克服现有技术的缺陷,本发明要解决的技术问题是提供了一种基于对比对抗学习的领域自适应方法,其保障了分类器的多样性,使得双分类器能够更高效的找出处于分类边界的样本,使得模型学习到的特征含有更多有效信息,从而较好地解决无监督领域自适应问题,在传统基于双分类器对抗方法的基础上,不仅考虑分类器在目标域上的决策边界,同时也进一步关注域间差异。
本发明的技术方案是:这种基于对比对抗学习的领域自适应方法,该方法包括以下步骤:
(1)在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,优化过程定义为公式(1):
Figure BDA0003360178590000031
其中,Lce(·,·)是交叉熵损失,θg,θc1θc2分别是特征网络G,C1,C2中的参数;
(2)固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,损失函数为公式
(2):
Figure BDA0003360178590000032
其中,Ldis(·,·)表示双分类器对目标域样本判别差异仅更新分类器中的参数,同时模型加入分布对齐损失并最小化特征相似度,定义为公式(7):
Figure BDA0003360178590000041
其中,θc1,θc2分别代表了分类器C1和C2中的参数,λ和η分别代表损失函数中的平衡参数;
(3)固定分类器C1和C2中的参数,使用Ldis更新特征提取器中的参数,在这一步骤中保留了自适应损失项,定义为公式(8):
Figure BDA0003360178590000042
本发明在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,固定分类器C1和C2中的参数,使用Ldis更新特征提取器中的参数,在这一步骤中保留了自适应损失项,因此保障了分类器的多样性,使得双分类器能够更高效的找出处于分类边界的样本,使得模型学习到的特征含有更多有效信息,从而较好地解决无监督领域自适应问题,在传统基于双分类器对抗方法的基础上,不仅考虑分类器在目标域上的决策边界,同时也进一步关注域间差异。
还提供了一种基于对比对抗学习的领域自适应装置,该装置包括:
训练模块,其配置来在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,优化过程定义为公式(1):
Figure BDA0003360178590000043
其中,Lce(·,·)是交叉熵损失,θg,θc1θc2分别是特征网络G,C1,C2中的参数;分类器更新模块,其配置来固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,损失函数为公式(2):
Figure BDA0003360178590000044
其中,Ldis(·,·)表示双分类器对目标域样本判别差异仅更新分类器中的参数,同时模型加入分布对齐损失并最小化特征相似度,定义为公式(7):
Figure BDA0003360178590000051
其中,θc1,θc2分别代表了分类器C1和C2中的参数,λ和η分别代表损失函数中的平衡参数;
特征提取器更新模块,其配置来固定分类器C1和C2中的参数,使用Ldis更新特征提取器中的参数,在这一步骤中保留了自适应损失项,定义为公式(8):
Figure BDA0003360178590000052
附图说明
图1示出了传统的双分类器对抗训练方法的步骤二、三。
图2是本发明的模型结构流程图。
图3是四种方法的数据分布结构图。
图4是分类器差异图。
图5是根据本发明的基于对比对抗学习的领域自适应方法的流程图。
具体实施方式
如图5所示,这种基于对比对抗学习的领域自适应方法,该方法包括以下步骤:
(1)在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,优化过程定义为公式(1):
Figure BDA0003360178590000053
其中,Lce(·,·)是交叉熵损失,θg,θc1θc2分别是特征网络G,C1,C2中的参数;
(2)固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,损失函数为公式(2):
Figure BDA0003360178590000061
其中,Ldis(·,·)表示双分类器对目标域样本判别差异仅更新分类器中的参数,同时模型加入分布对齐损失并最小化特征相似度,定义为公式(7):
Figure BDA0003360178590000062
其中,θc1,θc2分别代表了分类器C1和C2中的参数,λ和η分别代表损失函数中的平衡参数;
(3)固定分类器C1和C2中的参数,使用Ldis更新特征提取器中的参数,在这一步骤中保留了自适应损失项,定义为公式(8):
Figure BDA0003360178590000063
本发明在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,固定分类器C1和C2中的参数,使用Ldis更新特征提取器中的参数,在这一步骤中保留了自适应损失项,因此保障了分类器的多样性,使得双分类器能够更高效的找出处于分类边界的样本,使得模型学习到的特征含有更多有效信息,从而较好地解决无监督领域自适应问题,在传统基于双分类器对抗方法的基础上,不仅考虑分类器在目标域上的决策边界,同时也进一步关注域间差异。
优选地,所述步骤(1)中,首先采用随机数据增强的方式将xi增强两次分别获得
Figure BDA0003360178590000064
Figure BDA0003360178590000065
然后两个视角的数据同时送入特征提取器,经过分类器中的隐藏层映射后获得不同视角的特征
Figure BDA0003360178590000066
Figure BDA0003360178590000067
优选地,所述步骤(2)中,使用余弦相似度度量不同视角特征之间的差异,为公式(4):
Figure BDA0003360178590000068
其中,
Figure BDA0003360178590000071
表示
Figure BDA0003360178590000072
Figure BDA0003360178590000073
之间的余弦相似度,两个分类器期望最小化
Figure BDA0003360178590000074
Figure BDA0003360178590000075
之间的余弦相似度,而特征提取器期望最大化两者的相似度。
优选地,所述步骤(2)中,使用MLP以及梯度停止技巧防止模式崩塌,带有对称关系分类器差异损失为公式(5):
Figure BDA0003360178590000076
其中,模块M的输入和输出维度一致以满足向量余弦相似度计算的需要。
优选地,所述步骤(3)中,使用Sliced Wasserstein距离减小源域和目标域数据标签分布的差异,使得目标域数据向着正确的方向移动,自适应损失为公式(6):
Figure BDA0003360178590000077
其中,fs和ft分别表示源域和目标域特征。
优选地,该方法应用到Image-CLEF-DA数据集,该数据集由三个子域构成,分别是Caltech-256(C),ImageNet ILSVRC 2012(I)以及Pascalvoc 2012(P),整个数据集中含有1800张图片样本,每一个子域中分别含有600张图片样本并包含有12个类别。
还提供了一种基于对比对抗学习的领域自适应装置,该装置包括:
还提供了一种基于对比对抗学习的领域自适应装置,该装置包括:
训练模块,其配置来在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,优化过程定义为公式(1):
Figure BDA0003360178590000078
其中,Lce(·,·)是交叉熵损失,θg,θc1θc2分别是特征网络G,C1,C2中的参数;分类器更新模块,其配置来固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,损失函数为公式(2):
Figure BDA0003360178590000081
其中,Ldis(·,·)表示双分类器对目标域样本判别差异仅更新分类器中的参数,同时模型加入分布对齐损失并最小化特征相似度,定义为公式(7):
Figure BDA0003360178590000082
其中,θc1,θc2分别代表了分类器C1和C2中的参数,λ和η分别代表损失函数中的平衡参数;
特征提取器更新模块,其配置来固定分类器C1和C2中的参数,使用Ldis
更新特征提取器中的参数,在这一步骤中保留了自适应损失项,定义为公式(8):
Figure BDA0003360178590000083
以下更详细地说明本发明的内容。
本发明主要研究无监督领域自适应分类问题。给定源域数据
Figure BDA0003360178590000084
和目标域数据
Figure BDA0003360178590000085
xs和xt分别是源域和数据,ys是源域数据的标签,共有C个类别。目标域数据类别与源域一致但缺乏真实的样本标签。为便于理解,本部分首先介绍传统的双分类器对抗方法,然后介绍了本发明的创新以及改进点。
1传统双分类器对抗学习方法
传统基于双分类器对抗的模型主要由两个部分构成,分别是特征提取器G以及两个分类器C1和C2。首先源域和目标域数据送入到特征提取器分别获得特征fs和ft,然后将特征同时送入不同的分类器。分类器输出对目标域的概率分布p1(ys|xs),p2(ys|xs),p1(yt|xt),p2(yt|xt)。对抗过程发生在特征提取器和两个分类器之间,分类器最大化p1(yt|xt)和p2(yt|xt)之间的差异,而特征提取器期望通过提取共性特征最小化两个分类器的输出差异。该模型的目的是找到处于分类器决策边界的目标域样本,并通过限制分类器输出一致使得提高边界点的置信度。关于模型处理目标域数据的整体结构示意图如图1所示。
传统双分类器对抗训练方法的三个步骤如下:
步骤1:在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,优化过程定义如下:
Figure BDA0003360178590000091
上式中Lce(·,·)是交叉熵损失,θg,θc1θc2分别是特征网络G,C1,C2中的参数。
步骤2:固定特征提取器中参数,仅更新分类器C1和C2。最小化分类器分类损失以及最大化分类器对目标域样本判别差异。损失函数定义如下:
Figure BDA0003360178590000092
上式中,Ldis(·,·)表示双分类器对目标域样本判别差异。
步骤3:固定分类器C1和C2中的参数,通过最小化Ldis(·,·)更新特征提取器中的参数,损失函数定义如下:
Figure BDA0003360178590000093
重复以上三个步骤直到模型收敛。模型可以有效找到目标域数据的决策边界样本点以及利用双分类器的多样性改善模型分类性能。
2基于对比对抗学习的双分类器方法
如上所述,双分类器对抗方法可根据不同分类器之间的差异找出处于分类边界的样本。然而分类器C1和C2共享特征提取器中的特征,仅依靠分类器中的参数不足以保证分类器的多样性。为提高不同分类器的多样性,本发明参考Simsiam模型,采用不同的数据增强方式获得不同视角的特征,同时将双分类器对抗模型紧密结合双分类器思想。首先采用随机数据增强的方式将xi增强两次分别获得
Figure BDA0003360178590000094
Figure BDA0003360178590000095
然后两个视角的数据同时送入特征提取器,经过分类器中的隐藏层映射后获得不同视角的特征
Figure BDA0003360178590000096
Figure BDA0003360178590000097
由于输入不同分类器中的特征存在差异,因此从输入数据的角度保证了分类器输出的多样性。
传统的双分类器对抗模型通常使用L1范数度量分类器输出差异。但仅考虑预测概率的差异会令模型无法关注到特征中有效信息。因此将双分类器模型结合了对比学习思想,使用不同分类器中的特征差异定义分类器的差异。由对比学习思想可知,通过减小不同视角特征差异可以增强数据据的表征能力,从而提升模型的性能。本文使用余弦相似度度量不同视角特征之间的差异,公式定义如下:
Figure BDA0003360178590000101
上式
Figure BDA0003360178590000102
表示
Figure BDA0003360178590000103
Figure BDA0003360178590000104
之间的余弦相似度。两个分类器期望最小化
Figure BDA0003360178590000105
Figure BDA0003360178590000106
之间的余弦相似度,而特征提取器期望最大化两者的相似度。在对比学习中,仅优化上式会造成模式崩塌等问题,因此使用了MLP以及梯度停止技巧防止该问题。综上所述,带有对称关系分类器差异损失定义如下:
Figure BDA0003360178590000107
上式中,模块M的输入和输出维度一致以满足向量余弦相似度计算的需要。
尽管基于双分类器对抗模型可以高效的找出处于分类边界的样本,但却无法保证模型向正确的方向收敛。例如对于一个三分类任务,分类器对当前的目标样本预测概率分别是[0.98,0.01,0.01]和[0.97,0.01,0.02],但该样本的真实标签可能是[0,1,0]。尽管此时两个分类器的输出差异小,但两个分类器的输出依旧是错误的。这是因为在训练过程中,模型仅关注两个分类器在目标域数据上的预测差异,但最小化此差异并不能有效指导边界点向对应的类别中心收敛。由现有的领域自适应理论可知,模型在目标域上的分类误差主要由源域数据的判别误差以及两域之间的差异所界定。忽略域分布匹配会极大限制模型的性能。
为解决上述问题,使用Sliced Wasserstein距离减小源域和目标域数据标签分布的差异,使得目标域数据向着正确的方向移动。最终的自适应损失定义如下:
Figure BDA0003360178590000111
上式中,fs和ft分别表示源域和目标域特征。模型整体的优化训练过程如下:
在步骤一中,模型和传统的双分类器对抗模型一样,使用源域数据交叉熵损失更新整个模型的参数。
步骤二中固定特征提取器中参数,仅更新分类器中的参数,同时模型加入分布对齐损失并最小化特征相似度。公式定义如下:
Figure BDA0003360178590000112
上式中,θc1,θc2分别代表了分类器C1和C2中的参数,λ和η分别代表损失函数中的平衡参数。
步骤三中,模型固定住两个分类器中的参数,并且使用Ldis更新特征提取器中的参数。在这一步骤中,保留了自适应损失项,公式定义如下:
Figure BDA0003360178590000113
图2中xs和xt分别代表原始的源域和目标域数据,
Figure BDA0003360178590000114
Figure BDA0003360178590000115
是源域数据经过两次随机数据增强得到的图片样本。
Figure BDA0003360178590000116
Figure BDA0003360178590000117
是目标域数据经过两次随机数据增强得到的图片样本。fv1和fv2分别是经过特征提取器G以及全连接层FC映射后的不同视角的特征。
Figure BDA0003360178590000118
Figure BDA0003360178590000119
是分类器C1对源域和目标域的预测概率,
Figure BDA00033601785900001110
Figure BDA00033601785900001111
是分类器C2对源域和目标域的预测概率。上图中的MLP模块各层中均包含有BN层,主要由三层全连接层构成并且输出层的维度为1024,隐含层的节点数量为512,这使得MLP类似于一个瓶颈的结构。通过MLP模块映射可以预防特征模式崩塌问题。
本发明对上述方法进行了实验验证,并使用平均的分类精度衡量模型的性能,并取得了明显的效果。实验部分所使用的样本为有标注的源域样本以及未标注的目标域样本。
本实验选择Image-CLEF验证模型的效果。其中,Image-CLEF-DA数据集由三个子域构成,分别是Caltech-256(C),ImageNet ILSVRC 2012(I)以及Pascalvoc 2012(P)。整个数据集中含有1800张图片样本,每一个子域中分别含有600张图片样本并包含有12个类别。为直观显示模型的效果,进行数据可视化实验。为证明模型中各模块的作用,同样在该数据集上开展了消融实验。所有的代码均使用Python以及Pytorch,使用的显卡为RTX 3090。
为了突出模型的性能,实验选取了三种主流对比方法:
Source-only:此方法使用Resnet-50作为模型的主干网络。该方法使用源域数据训练好分类器,而后将该模型直接对目标域数据分类。此方法为所有对比方法的基准线。
经典方法:为证明基于双分类器对抗模型的性能优势,本文分别选择了较为经典的基于MMD距离度量的方法以及基于域判别器对抗的模型方法。Deep Adaptation Network(DAN)方法使用MMD距离减小分布之间的差异。Domain Adversarial Neural Network(DANN)使用判别器辨别当前数据来自于源域还是目标域。在DANN的基础上,ConditionalDomain Adversarial Network(CDAN)使用多重线性映射进行特征融合,使得源域和目标域分布结构更加清晰,从而实现了域分布结构匹配,提升了模型的性能。
双分类器对抗方法:为证明本模型的性能优势,本文同样选择与基于双分类器对抗方法比较。MCD是首次使用双分类器对抗训练的模型,该方法使用L1范数度量两个分类器中的差异。在MCD的基础上,SWD使用了Wasserstein距离度量两个分类器的输出差异。最近JADA模型在MCD的基础上结合了判别器对抗方法的思想,成功的使用判别器对齐两个数据分布。通过对比上述方法能体现本模型的有效性。
本部分将讨论模型Image-CLEF数据集上的实验结果。所有实验数据均为模型30次迭代之后的输出。所有实验结果记录于表1中,其中识别任务记作S→T,S表示有标注数据所在的源域,T表示无标注数据所在的目域,S→T表示利用源域的标注数据解决目标域数据的分类任务:
Method I→P P→I I→C C→I C→P P→C Avg
Source-only 74.8 83.9 91.5 78.0 65.5 91.2 80.7
DAN 74.5 82.2 92.8 86.3 69.2 89.8 82.5
DANN 75.0 86.0 96.2 87.0 74.3 91.5 85.0
CDAN 76.7 90.6 97.0 90.5 74.5 93.5 87.1
MCD 77.3 89.2 92.7 88.2 71.0 92.3 85.1
SWD 76.9 90.7 93.8 88.3 74.2 93.8 86.3
JADA 78.2 90.1 95.9 90.8 76.8 94.1 87.7
Ours 79.9 92.5 95.4 92.7 78.8 94.2 88.8
表1
由表1可知,本文所选取的模型达到了最优平均精度。特别是对于迁移较为困难的任务C→P,本模型提升效果最为突出。实验结果证明对比学习能充分学习到复杂样本中的有效信息。
为更直观展现模型优点,本文使用T-sne对迁移任务P→C进行数据降维可视化。同时选取Source-only、MCD以及SWD作为对比方法。数据可视化结果如图3所示。图中“ο”代表源域数据,“×”代表目标域数据。如图3所示,由于存在域差异,Source-only方法无法有效提高模型在目标域上的分类精度。相比于Source-only方法,MCD以及SWD都能减少目标域上处在分类边界的样本。但当目标域数据分布结构较复杂时,不同类别样本的距离依旧不清晰。由图可知,本模型中数据分布结构较为清晰,不同类别之间的距离较大。不同域中相同类别之间也对齐的更为紧密,从而证明了标签分布匹配的有效性。
由图4中可知,在模型训练的开始阶段分类器的输出准确率差异较大,并且两个分类器在模型的初始阶段性能均次于双分类器的综合性能。这表明两个视角的数据能够在训练初期提供互补的信息,因此分类器综合输出性能会提高。随着模型训练次数增加,模型性能逐渐趋向于一致,三条曲线趋向一致。这表明双分类器的输出差异较小,证明了特征相似性对分类器输出的影响。以上实验结果证明了结合域对齐以及对比学习可以有效提升模型的性能。
以上所述,仅是本发明的较佳实施例,并非对本发明作任何形式上的限制,凡是依据本发明的技术实质对以上实施例所作的任何简单修改、等同变化与修饰,均仍属本发明技术方案的保护范围。

Claims (7)

1.基于对比对抗学习的领域自适应方法,其特征在于:该方法包括以下步骤:
(1)在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,优化过程定义为公式(1):
Figure FDA0003360178580000011
其中,Lce(·,·)是交叉熵损失,θg,θc1θc2分别是特征网络G,C1,C2中的参数;
(2)固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,损失函数为公式(2):
Figure FDA0003360178580000012
其中,Ldis(·,·)表示双分类器对目标域样本判别差异仅更新分类器中的参数,同时模型加入分布对齐损失并最小化特征相似度,定义为公式(7):
Figure FDA0003360178580000013
其中,θc1,θc2分别代表了分类器C1和C2中的参数,λ和η分别代表损失函数中的平衡参数;
(3)固定分类器C1和C2中的参数,使用Ldis更新特征提取器中的参数,在这一步骤中保留了自适应损失项,定义为公式(8):
Figure FDA0003360178580000014
2.根据权利要求1所述的基于对比对抗学习的领域自适应方法,其特征在于:所述步骤(1)中,首先采用随机数据增强的方式将xi增强两次分别获得
Figure FDA0003360178580000015
Figure FDA0003360178580000016
然后两个视角的数据同时送入特征提取器,经过分类器中的隐藏层映射后获得不同视角的特征
Figure FDA0003360178580000017
Figure FDA0003360178580000018
3.根据权利要求2所述的基于对比对抗学习的领域自适应方法,其特征在于:所述步骤(2)中,使用余弦相似度度量不同视角特征之间的差异,为公式(4):
Figure FDA0003360178580000021
其中,
Figure FDA0003360178580000024
表示
Figure FDA0003360178580000025
Figure FDA0003360178580000026
之间的余弦相似度,两个分类器期望最小化
Figure FDA0003360178580000027
Figure FDA0003360178580000028
之间的余弦相似度,而特征提取器期望最大化两者的相似度。
4.根据权利要求3所述的基于对比对抗学习的领域自适应方法,其特征在于:所述步骤(2)中,使用MLP以及梯度停止技巧防止模式崩塌,带有对称关系分类器差异损失为公式(5):
Figure FDA0003360178580000022
其中,模块M的输入和输出维度一致以满足向量余弦相似度计算的需要。
5.根据权利要求4所述的基于对比对抗学习的领域自适应方法,其特征在于:所述步骤(3)中,使用Sliced Wasserstein距离减小源域和目标域数据标签分布的差异,使得目标域数据向着正确的方向移动,自适应损失为公式(6):
Figure FDA0003360178580000023
其中,fs和ft分别表示源域和目标域特征。
6.根据权利要求5所述的基于对比对抗学习的领域自适应方法,其特征在于:该方法应用到Image-CLEF-DA数据集,该数据集由三个子域构成,分别是Caltech-256(C),ImageNetILSVRC 2012(I)以及Pascalvoc 2012(P),整个数据集中含有1800张图片样本,每一个子域中分别含有600张图片样本并包含有12个类别。
7.基于对比对抗学习的领域自适应装置,其特征在于:该装置包括:训练模块,其配置来在源域数据上使用损失函数Lcls(xs,ys)训练整个网络模型,优化过程定义为公式(1):
Figure FDA0003360178580000031
其中,Lce(·,·)是交叉熵损失,θg,θc1θc2分别是特征网络G,C1,C2中的参数;分类器更新模块,其配置来固定特征提取器中参数,仅更新分类器C1和C2,最小化分类器分类损失以及最大化分类器对目标域样本判别差异,损失函数为公式(2):
Figure FDA0003360178580000032
其中,Ldis(·,·)表示双分类器对目标域样本判别差异仅更新分类器中的参数,同时模型加入分布对齐损失并最小化特征相似度,定义为公式(7):
Figure FDA0003360178580000033
其中,θc1,θc2分别代表了分类器C1和C2中的参数,λ和η分别代表损失函数中的平衡参数;
特征提取器更新模块,其配置来固定分类器C1和C2中的参数,使用Ldis更新特征提取器中的参数,在这一步骤中保留了自适应损失项,定义为公式(8):
Figure FDA0003360178580000034
CN202111363731.9A 2021-11-17 2021-11-17 基于对比对抗学习的领域自适应方法及装置 Pending CN114065861A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111363731.9A CN114065861A (zh) 2021-11-17 2021-11-17 基于对比对抗学习的领域自适应方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111363731.9A CN114065861A (zh) 2021-11-17 2021-11-17 基于对比对抗学习的领域自适应方法及装置

Publications (1)

Publication Number Publication Date
CN114065861A true CN114065861A (zh) 2022-02-18

Family

ID=80277362

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111363731.9A Pending CN114065861A (zh) 2021-11-17 2021-11-17 基于对比对抗学习的领域自适应方法及装置

Country Status (1)

Country Link
CN (1) CN114065861A (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114693972A (zh) * 2022-03-29 2022-07-01 电子科技大学 一种基于重建的中间域领域自适应方法
CN114723994A (zh) * 2022-04-18 2022-07-08 中国矿业大学 一种基于双分类器对抗增强网络的高光谱图像分类方法
CN114782697A (zh) * 2022-04-29 2022-07-22 四川大学 一种对抗子领域自适应隐写检测方法
CN117456309A (zh) * 2023-12-20 2024-01-26 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 基于中间域引导与度量学习约束的跨域目标识别方法

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114693972A (zh) * 2022-03-29 2022-07-01 电子科技大学 一种基于重建的中间域领域自适应方法
CN114693972B (zh) * 2022-03-29 2023-08-29 电子科技大学 一种基于重建的中间域领域自适应方法
CN114723994A (zh) * 2022-04-18 2022-07-08 中国矿业大学 一种基于双分类器对抗增强网络的高光谱图像分类方法
CN114782697A (zh) * 2022-04-29 2022-07-22 四川大学 一种对抗子领域自适应隐写检测方法
CN114782697B (zh) * 2022-04-29 2023-05-23 四川大学 一种对抗子领域自适应隐写检测方法
CN117456309A (zh) * 2023-12-20 2024-01-26 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 基于中间域引导与度量学习约束的跨域目标识别方法
CN117456309B (zh) * 2023-12-20 2024-03-15 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 基于中间域引导与度量学习约束的跨域目标识别方法

Similar Documents

Publication Publication Date Title
CN114065861A (zh) 基于对比对抗学习的领域自适应方法及装置
Xie et al. Differentiable top-k with optimal transport
CN113378632B (zh) 一种基于伪标签优化的无监督域适应行人重识别方法
CN113326731B (zh) 一种基于动量网络指导的跨域行人重识别方法
CN111444955B (zh) 一种基于类意识领域自适应的水下声纳图像无监督分类方法
CN114241569B (zh) 人脸识别攻击样本的生成方法、模型训练方法及相关设备
WO2021088365A1 (zh) 确定神经网络的方法和装置
US10614343B2 (en) Pattern recognition apparatus, method, and program using domain adaptation
CN113887480B (zh) 基于多解码器联合学习的缅甸语图像文本识别方法及装置
CN112861616A (zh) 一种无源领域自适应目标检测方法
CN114842267A (zh) 基于标签噪声域自适应的图像分类方法及系统
CN114821152B (zh) 基于前景-类别感知对齐的域自适应目标检测方法及系统
CN115410088A (zh) 一种基于虚拟分类器的高光谱图像领域自适应方法
CN112488229A (zh) 一种基于特征分离和对齐的域自适应无监督目标检测方法
CN114692732A (zh) 一种在线标签更新的方法、系统、装置及存储介质
CN116486483A (zh) 基于高斯建模的跨视角行人重识别方法及装置
Kanatani Model selection criteria for geometric inference
Li et al. Adaptive pseudo labeling for source-free domain adaptation in medical image segmentation
CN117115547A (zh) 基于自监督学习与自训练机制的跨域长尾图像分类方法
CN114139631B (zh) 一种面向多目标训练对象可选择的灰盒的对抗样本生成方法
Wang et al. Source data-free cross-domain semantic segmentation: Align, teach and propagate
CN114387642A (zh) 图像分割方法、装置、设备和存储介质
Long et al. Video domain adaptation based on optimal transport in grassmann manifolds
Csaba et al. Multilevel knowledge transfer for cross-domain object detection
CN115640418B (zh) 基于残差语义一致性跨域多视角目标网站检索方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination