CN110837850A - 一种基于对抗学习损失函数的无监督域适应方法 - Google Patents

一种基于对抗学习损失函数的无监督域适应方法 Download PDF

Info

Publication number
CN110837850A
CN110837850A CN201911012806.1A CN201911012806A CN110837850A CN 110837850 A CN110837850 A CN 110837850A CN 201911012806 A CN201911012806 A CN 201911012806A CN 110837850 A CN110837850 A CN 110837850A
Authority
CN
China
Prior art keywords
domain
label
loss function
target
source
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN201911012806.1A
Other languages
English (en)
Other versions
CN110837850B (zh
Inventor
陈铭浩
蔡登�
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU filed Critical Zhejiang University ZJU
Priority to CN201911012806.1A priority Critical patent/CN110837850B/zh
Publication of CN110837850A publication Critical patent/CN110837850A/zh
Application granted granted Critical
Publication of CN110837850B publication Critical patent/CN110837850B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2155Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques

Abstract

本发明公开了一种基于对抗学习损失函数的无监督域适应方法,包括:(1)源域图像经过特征提取网络G生成高层特征,经过分类器C与真实标签做交叉熵损失,另一方面经过域判别器D生成混淆矩阵,将伪标签纠正为真实标签。(2)目标域图像经过特征提取网络G生成高层特征,经过分类器C生成伪标签,另一方面高层特征经过域判别器D生成混淆矩阵,将伪标签纠正为相反分布。(3)让特征生成器与判别器对抗优化上述损失函数;此外对于目标域上的混淆矩阵,生成纠正标签,并作为目标域的标签,优化分类器。利用本发明,使无监督域适应中,能够对纠正伪标签的噪声,同时匹配域之间分布差异,从而提高目标域的分类精度。

Description

一种基于对抗学习损失函数的无监督域适应方法
技术领域
本发明属于迁移学习的无监督域适应领域,尤其是涉及一种基于对抗学习损失函数的无监督域适应方法。
背景技术
近年来,深度学习在分类任务上取得了令人瞩目的进展。深层神经网络的成功是建立在具有大量标记样本的大规模数据集的基础上的。然而,在许多实际情况下,大量的标记样本是不可获取的。在已有数据集上预训练的深层神经网络不能很好地泛化到具有不同外观特征的新数据。从本质上讲,域之间数据分布的差异使得将知识从源域传输到目标域变得困难。这种转移问题被称为域迁移问题。
无监督域适应旨在解决上述域迁移问题,在将模型从有标记的源域传输到未标记的目标域过程中。已经被证明:目标域上的分类器的精度受源域上的精度和域之间差异性的限制。因此,当前UDA研究的主要思路是对齐源域和目标域之间的分布。域之间的分布差异可以通过最大均值误差(MMD)或二阶统计量来衡量。
基于域对抗学习的方法通过对抗学习过程来让提取的特征在域之间相似。2016年JMLR期刊上收录的文章《Domain-adversarial training of neural networks》中提出了域对抗学习方法,它训练域鉴别器以区分特征是来自源域还是目标域。为了愚弄鉴别器,特征生成器必须输出相似的源和目标特征分布。在匹配特征分布后,他们假设了在源特征上训练的分类器能够对目标样本的特征进行分类。然而,对于这类无监督域适应方法来说,学习目标域上的判别性特征是一个挑战。这是因为它们忽略了生成的目标特征是否可以被分类器识别。
最近,基于自我训练的方法成为了无监督域适应的另一个解决方案,并在多项任务中表现出最先进的性能。这类方法最早起源于半监督学习,一个与无监督域适应相似的任务,后来被有效地用于了无监督域适应。基于自训练的方法使用网络预测来训练网络本身,这是在无监督的方式下完成的。一种典型的方法是在2018 European Conference onComputer Vision会议上收录的《Unsupervised domain adaptation for semanticsegmentation via class-balanced self-training》中提出的伪标签方法:生成与目标样本的较大预测概率相对应的伪标签,并用这些伪标签对模型进行训练。这样,有助于目标分类的特征就得到了增强。然而,源特征分布和目标特征分布之间的对齐是隐式的,而且没有理论保证。在不匹配域分布的情况下,基于自训练的方法会导致在浅网络的情况下性能下降。
发明内容
为解决现有技术存在的问题,本发明提供了一种基于对抗学习损失函数的无监督域适应方法,使无监督域适应中,能够对纠正伪标签的噪声,同时匹配域之间分布差异,从而提高目标域的分类精度。
一种基于对抗学习损失函数的无监督域适应方法,其特征在于,包括:
(1)对于有标注的源域和无标注的目标域,训练一个特征提取器网络G,分别提取源域和目标域的特征;
(2)使用分类器网络C对提取后的特征进行预测,分别生成源域和目标域的伪标签;确定源域上分类器网络C输出的概率向量与真实标签的损失函数;
(3)使用域判别器D分别生成源域和目标域的混淆矩阵,将混淆矩阵与伪标签相乘得到源域的纠正标签和目标域的纠正标签;
分别确定源域的纠正标签与真实标签的损失函数,以及目标域的纠正标签与伪标签的相反分布的损失函数;
(4)让特征提取器网络G与域判别器D进行对抗学习,优化上述损失函数;
(5)对于目标域上的混淆矩阵,生成纠正标签,并作为目标域的标签,优化分类器网络C。
具体而言,步骤(1)中,对于一个有标注的源域:
Figure BDA0002244712950000031
和一个无标注的目标域:
Figure BDA0002244712950000032
本发明训练一个特征提取器网络G来提取数据xs或xt的高层特征。
步骤(2)中,使用一个分类器网络C来对于提取后的特征进行K-类分类。分类器C将输出预测概率向量ps
Figure BDA00022447129500000311
代表着xs,xt属于每类的预测概率。所述伪标签的生成公式为:
Figure BDA0002244712950000033
Figure BDA0002244712950000034
其中,
Figure BDA0002244712950000035
为源域的伪标签,
Figure BDA0002244712950000036
为目标域的伪标签,δ是一个阈值,
Figure BDA0002244712950000037
表示ps向量中的最大值,表示pt向量中的最大值。
本发明考虑在目标域上提供适当的损失函数。从理论上讲,理想损失函数是具有正确标签yt的损失:
其中,
Figure BDA00022447129500000312
是一个基础损失函数,例如交叉熵损失(CE),均方损失(MAE),然而在无监督域适应情形下,正确标签yt在目标域上是不可得的。典型的伪标签方法就会使用伪标签
Figure BDA00022447129500000310
来代替正确标签。
本发明使用混淆矩阵来分析伪标签的损失函数与正确标签的差距。分析公式如下所示:
Figure BDA0002244712950000041
其中,
Figure BDA0002244712950000042
是混淆矩阵。对于未标记的目标域,混淆矩阵是未知的。为表达的简单起见,定义
Figure BDA0002244712950000043
称之为纠正标签。
在以前研究噪声标签的工作中,通常假设混淆矩阵条件独立于输入和对于噪声率α是均匀的。之前理论证明,这种unhinge损失函数对均匀噪声具有鲁棒性:
Figure BDA0002244712950000046
然而这些工作中的噪声假设在伪标签情形下是不成立的,这使得纠正伪标签问题更加难以解决。
本发明的总体思路是,如果我们可以充分估计混淆矩阵,则伪标签中的噪声将得到校正,并且可以在目标域上近似优化理想损失函数。
步骤(2)中,源域上分类器网络C输出的概率向量与真实标签的损失函数为:
Figure BDA0002244712950000044
其中,ps为源域中提取的特征经过分类器网络和softmax函数后得到的预测概率向量,ys为源域上数据的真实标签。
步骤(3)中,本发明提出使用一个多层的多分类网络D,称为域判别器,来生成混淆矩阵
Figure BDA0002244712950000045
然而想要直接生成一个矩阵是很困难的任务,因此对混淆矩阵做了一个均匀假设:
定义1:噪声矩阵
Figure BDA0002244712950000051
对于噪声向量
Figure BDA0002244712950000052
如果当k=l,而
Figure BDA0002244712950000054
对于k≠l。
在这种均匀假设下,只需要生成噪声向量
Figure BDA0002244712950000055
即可生成混淆矩阵。因此域判别器D将深度特征G(x)作为输入,输出一个多维的向量D(G(x))∈RK,然后经过一个sigmoid层,得到噪声向量ξ(x)=σ(D(G(x)))。每个ξ(x)的分量代表着伪标签与正确标签一致的概率:
Figure BDA0002244712950000056
Figure BDA0002244712950000057
本发明采用域对抗学习的思想,使域判别器D和特征提取器网络G玩一个minimax游戏。相比于域对抗学习让鉴别器执行域分类任务,本发明让鉴别器为源域和目标域生成不同的噪声矢量。
对抗学习的详细过程为:对于源特征G(xs),域判别器D的目标是使校正后的标签矢量
Figure BDA0002244712950000058
与真实标签ys差异最小化。
源域的纠正标签与真实标签的损失函数为:
Figure BDA0002244712950000059
其中,
Figure BDA00022447129500000510
由域判别器的输出
Figure BDA00022447129500000511
根据定义1生成,ps k为分类器的预测向量ps的第k分量,ys为ys的ont-hot向量:当k=ys
Figure BDA00022447129500000513
当k≠ys
至于目标特征,域判别器D做刚好相反的事情。域判别器D将伪标签
Figure BDA00022447129500000514
校正为相反的分布
Figure BDA00022447129500000515
其中
Figure BDA00022447129500000516
对于
Figure BDA00022447129500000517
Figure BDA00022447129500000518
对于
Figure BDA00022447129500000519
目标域的纠正标签与伪标签的相反分布的损失函数为:
Figure BDA00022447129500000520
其中,
Figure BDA00022447129500000521
由域判别器的输出
Figure BDA00022447129500000522
根据定义1生成。
步骤(4)中,进行对抗学习时,优化损失函数时,优化的损失函数为总的对抗性损失函数为:
Figure BDA0002244712950000067
其中,
Figure BDA0002244712950000068
为源域的纠正标签与真实标签的损失函数,
Figure BDA0002244712950000069
为目标域的纠正标签与伪标签的相反分布的损失函数;域判别器D通过最小化总的对抗性损失函数以区分源特征和目标特征,特征提取器网络G通过最大化总的对抗性损失函数来欺骗域判别器。与普通域对抗学习相比,这种对抗损失将分类器预测和标签信息考虑在内。这样的噪声校正域鉴别器就可以实现逐类特征对齐。
然而正如生成对抗网络(GANs)的工作所揭示的那样,对抗学习的训练过程可能是不稳定的。因此,在进行对抗学习时,在源域上向分类器网络C添加分类任务,以使其训练更加稳定。那么域判别器D不仅必须区分源域和目标域,还用于正确分类源样本,域判别器D的输出经过softmax函数的概率向量与真实标签做交叉熵损失,作为正则项:
其中,所述域判别器D的优化目标为:
Figure BDA0002244712950000064
其中,
Figure BDA0002244712950000065
表示计算分布下的均值。
步骤(5)中,在对混淆矩阵进行对抗学习之后,我们可以对伪标签中的噪声进行校正,并为目标样本构建适当的损失函数。由于unhinge损失对噪声的均匀部分具有鲁棒性,因此我们选择unhinge损耗。对伪标签与混淆矩阵进行矩阵乘法得到纠正标签
Figure BDA0002244712950000066
对于目标域的数据,使用纠正标签作为目标域上的损失函数:
Figure BDA0002244712950000071
其中,
Figure BDA0002244712950000072
由域判别器的输出
Figure BDA0002244712950000073
根据定义1生成;
Figure BDA0002244712950000076
为unhinge损耗;
结合源域上分类器网络C输出的概率向量与真实标签的损失函数分类器网络C和特征提取器网络G的优化目标为:
Figure BDA0002244712950000074
其中,λ∈[0,1]为平衡两种损失的超参数,使用随机梯度下降算法SGD优化上述损失函数。
与现有技术相比,本发明具有以下有益效果:
1、本发明提出的域对抗学习能够有效地匹配两个域的特征分布,而且由于将标签信息考虑在内,比起一般的域对抗学习,它能够做到逐类匹配分布。
2、本发明矫正伪标签噪声的方法,一方面能够学习到有判别性的目标特征,另一方面能防止伪标签的噪声误导模型训练。
3、本发明将域对抗学习与自学习两种方法有机结合到一个整体框架下,这种结合方法效果比这两个各自的方法都要好。
附图说明
图1为本发明方法的示意图;
图2为本发明方法的域对抗学习的流程示意图;
图3为本发明实施例中ALDA方法在office-31数据集上特征可视化示意图;
图4为本发明实施例在数字集上采用的网络结构。
具体实施方式
下面结合附图和实施例对本发明做进一步详细描述,需要指出的是,以下所述实施例旨在便于对本发明的理解,而对其不起任何限定作用。
如图1和图2所示,一种基于对抗学习损失函数的无监督域适应方法,本发明框架主要分为两条分支分别处理两个域的图像(参见图2):(1)(虚线)源域图像经过特征提取网络G生成高层特征,经过分类器C与真实标签做交叉熵损失,另一方面经过域判别器D生成混淆矩阵,将伪标签纠正为真实标签。(2)(实线)目标域图像经过特征提取网络G生成高层特征,经过分类器C生成伪标签,另一方面高层特征经过域判别器D生成混淆矩阵,将伪标签纠正为相反分布。(3)让特征生成器与判别器对抗优化上述损失函数。此外对于目标域上的混淆矩阵,生成纠正标签,并作为目标域的标签,优化分类器。具体步骤分别阐述如下:
(1)源域图像经过特征提取网络G生成高层特征,经过分类器C与真实标签做交叉熵损失,另一方面经过域判别器D生成混淆矩阵,将伪标签纠正为真实标签。基本步骤如下:
1-1.源域图像经过特征提取器网络(如ResNet-50去掉最后一层分类层)提取得到高层特征G(xs)。G(xs)经过分类器C,得到其属于各类的分值,再经过一个softmax函数,得到预测概率向量ps∈RK。由于源域上的数据有正确标签ys,令预测的ps与真实标签做交叉熵损失:
Figure BDA0002244712950000081
1-2.G(xs)还经过域判别器D生成混淆向量
Figure BDA0002244712950000082
根据定义1来生成相应的混淆矩阵
Figure BDA0002244712950000083
混淆矩阵与伪标签
Figure BDA0002244712950000084
Figure BDA0002244712950000085
相乘得到纠正标签最后令纠正标签与真实标签做交叉熵损失:
Figure BDA0002244712950000091
1-3.域判别器D的输出经过softmax函数:
Figure BDA0002244712950000092
其与真实标签做交叉熵损失,作为正则项:
Figure BDA0002244712950000093
(2)目标域图像经过特征提取网络G生成高层特征,经过分类器C生成伪标签,另一方面高层特征经过域判别器D生成混淆矩阵,将伪标签纠正为相反分布。基本步骤如下:
2-1.类似源域过程,目标域图像经过特征提取器网络(如ResNet-50去掉最后一层分类层)提取得到高层特征G(xt)。G(xt)经过分类器C以及softmax函数,得到预测概率向量pt∈RK。取伪标签:
Figure BDA0002244712950000094
2-2.G(xt)还经过域判别器D生成混淆向量
Figure BDA0002244712950000095
然后根据定义1.来生成相应的混淆矩阵
Figure BDA0002244712950000096
混淆矩阵与伪标签相乘得到纠正标签
Figure BDA0002244712950000097
最后令纠正标签与相反分布
Figure BDA0002244712950000098
(其中对于
Figure BDA00022447129500000910
Figure BDA00022447129500000911
对于
Figure BDA00022447129500000912
)做交叉熵损失:
Figure BDA00022447129500000913
(3)让特征生成器与判别器对抗优化上述损失函数。此外对于目标域上的混淆矩阵,生成纠正标签,并作为目标域的标签,优化分类器。基本步骤如下:
令域判别器D优化:
Figure BDA00022447129500000914
另一方面对于目标数据,生成纠正标签,并作为目标损失:
分类器和特征提取器的优化目标为:
Figure BDA0002244712950000102
Figure BDA0002244712950000103
使用随机梯度下降算法SGD优化上述损失函数。
为了验证本发明方法的效果,本发明在4种迁移情景:数字集,office-31,office-home,VisDA-2017迁移上与其他目前最前沿的无监督域适应方法进行对比。
数字集:我们在三种迁移情形上进行了实验:USPS到MNIST(U→M),MNIST到USPS(M→U),SVHN到MNIST(S→M)。MNIST包含60000个手写数字图像,USPS包含7 291个图像。街景门牌号码(SVHN)包含73257张在自然场景中带有数字和数字的图像。我们在MNIST和USPS的测试集上报告评估结果。
Office-31是用于无监督域适应的常用数据集,其中包含从以下三个领域收集的4652个图像和13个类别:亚马逊(A),网络摄像头(W)和DSLR(D)。我们评估了所有方法在六个域适应任务中:A→W,D→W,W→D,A→D,D→A和W→A。
Office-Home是一个比office-31更困难的领域适应数据集,其中包括来自四个不同领域的15500张图像:艺术图像(Ar),剪贴画(Cl),产品图像(Pr)和真实世界(Rw)。每个域都包含办公室和家庭场景中常见的65种对象类别的图像。我们评估了所有方法在12种域适应方案中的表现。
VisDA-2017是一个大规模的数据集,对于无监督域从模拟到真实的适应性提出了挑战。数据集包含152397个合成图像作为源域,并包含55388个真实图像作为目标域。这两个域共享12个对象类别。我们评估了所有方法在VisDA验证集上。
对于数字数据集,我们采用图4中使用的特征生成器和分类器网络,并使用Adam梯度下降对模型进行优化,学习率为1×10-3。伪标签的阈值δ设置为0.6
对于其他三个数据集,我们采用ResNet-50作为特征提取器网络。ResNet-50是在ImageNet上进行预训练。判别器由三个完全连接的层加上dropout组成。我们用动量为0.9的随机梯度下降(SGD)优化器训练模型。伪标签的阈值δ设置为0.9。在所有实验中,超参数λ从0逐渐增加到1。
整体对比结果分别如表1、表2、表3和表4所示:
表1
Figure BDA0002244712950000111
表1总结了在Office-31上的结果。本发明的ALDA明显优于其他方法。因为ALDA与自训练方法相结合来学习判别特征,所以与基于领对抗学习的方法(例如DANN,JAN,MADA)相比,ALDA可获得更好的结果。与ALDA相似,CDAN+E也将分类预测纳入判别,并使用预测的熵作为重要权重。但是,由于采用了噪声校正域区分的设计,在难迁移任务(例如A→W,A→D,D→A与W→A)上,ALDA的性能优于CDAN+E。出色的结果表明,将领对抗学习和自训练的方法进行组合非常重要。
在图3中,我们使用t-SNE分别可视化了使用ResNet-50,自训练self-training,DANN和ALDA的A→W适应任务(31类)中提取的特征。仅使用ResNet-50时,目标特征分布与来源。尽管自训练和DANN可以对齐源域和目标域的分布,但它们的目标域聚类团与源域聚类团并不完全匹配。对于自训练,某些源域聚类团没有目标样本在周围,这表明某些目标样本被错误分类。对于DANN,目标域聚类团的边界是模糊的,因为DANN无法实现按类别的特征对齐。对于ALDA,目标聚类与相应的源聚类紧密匹配,这表明ALDA提取的目标特征具有良好的一致性和可区分性。
表2
Figure BDA0002244712950000121
表2总结了在Office-home上的结果。对于这种更困难的域适应数据集,本发明的ALDA仍然超越了最先进的方法。与Office-31相比,Office-Home的类别更多,并且域之间的外观差距更大。类别的数量越多,表明ALDA中鉴别器输出的成分越多,这导致了更强的分类域鉴别能力。因此,ALDA可以大大胜过类别无关的方法,例如DANN,JAN。
表3
Figure BDA0002244712950000122
表3总结了在VisDA数据集中的定量结果。尽管仅基于ResNet-50网络,但我们的ALDA的性能比其他领域的适应方法要好得多。由于合成域和真实域之间存在很大差异,这些结果表明ALDA可以更好地应对困难的适应性任务。
表4
表4总结了与最新方法相比较的数字集迁移的实验结果。为了进行公平的比较,我们仅调整图像的大小并对其进行归一化,并且不应用任何其他数据增强功能。我们将每个实验进行三次,并报告其平均结果和差异。如表所示,ALDA优于最先进的分布对齐方法(例如DANN,MCD,CDAN)和基于自训练的方法(例如具有确定阈值(MT+CT)的Mean Teacher)。ALDA还大大减少了无监督域适应和目标域上的监督学习之间的性能差距。
本发明使用域判别器网络生成混淆矩阵,并以对抗学习方式对其进行训练。具体来说,对于源样本,鉴别器经过训练以生成可以将伪标签校正为正确标签的混淆矩阵。对于目标样本,鉴别器经过训练以生成将伪标签校正为相反分布的混淆矩阵。特征生成器旨在混淆域判别器,以使其无法区分源样本和目标样本。对抗过程最终导致目标域上出现合适的混淆矩阵。然后,我们采用混淆矩阵构造校正后的目标损失,并使用损失对目标域上的分类器进行优化。
本发明的主要贡献为:(1)本发明提出了对抗学习损失函数的无监督域适应方法用于无标签目标样本的无监督训练。本发明的域对抗学习能够实现逐类特征分布对齐,同时还能提高特征的分类判别性。(2)本发明的判别器能够生成合适的混淆矩阵,来纠正伪标签(3)本发明将域对抗学习与自学习两种域适应方法有机结合到了一起,拥有两者的优点。(4)本发明在四种标准的域适应数据集上都达到了最好效果,相比于之前的域适应方法。
以上所述的实施例对本发明的技术方案和有益效果进行了详细说明,应理解的是以上所述仅为本发明的具体实施例,并不用于限制本发明,凡在本发明的原则范围内所做的任何修改、补充和等同替换,均应包含在本发明的保护范围之内。

Claims (8)

1.一种基于对抗学习损失函数的无监督域适应方法,其特征在于,包括:
(1)对于有标注的源域和无标注的目标域,训练一个特征提取器网络G,分别提取源域和目标域的特征;
(2)使用分类器网络C对提取后的特征进行预测,分别生成源域和目标域的伪标签;确定源域上分类器网络C输出的概率向量与真实标签的损失函数;
(3)使用域判别器D分别生成源域和目标域的混淆矩阵,将混淆矩阵与伪标签相乘得到源域的纠正标签和目标域的纠正标签;
分别确定源域的纠正标签与真实标签的损失函数,以及目标域的纠正标签与伪标签的相反分布的损失函数;
(4)让特征提取器网络G与域判别器D进行对抗学习,优化上述损失函数;
(5)对于目标域上的混淆矩阵,生成纠正标签,并作为目标域的标签,优化分类器网络C。
2.根据权利要求1所述的基于对抗学习损失函数的无监督域适应方法,其特征在于,步骤(2)中,所述伪标签的生成公式为:
Figure FDA0002244712940000011
Figure FDA0002244712940000012
其中,
Figure FDA0002244712940000013
为源域的伪标签,
Figure FDA0002244712940000014
为目标域的伪标签,δ是一个阈值,
Figure FDA0002244712940000015
表示ps向量中的最大值,
Figure FDA0002244712940000016
表示pt向量中的最大值。
3.根据权利要求1所述的基于对抗学习损失函数的无监督域适应方法,其特征在于,步骤(2)中,源域上分类器网络C输出的概率向量与真实标签的损失函数为:
Figure FDA0002244712940000021
其中,ps为源域中提取的特征经过分类器网络和softmax函数后得到的预测概率向量,ys为源域上数据的真实标签。
4.根据权利要求1所述的基于对抗学习损失函数的无监督域适应方法,其特征在于,步骤(3)中,源域的纠正标签与真实标签的损失函数为:
Figure FDA0002244712940000022
其中,
Figure FDA0002244712940000023
Figure FDA0002244712940000024
由域判别器的输出
Figure FDA0002244712940000025
根据定义生成,ps k为分类器的预测向量ps的第k分量,ys为ys的ont-hot向量:
Figure FDA0002244712940000026
当k=ys
Figure FDA0002244712940000027
当k≠ys
5.根据权利要求4所述的基于对抗学习损失函数的无监督域适应方法,其特征在于,步骤(3)中,目标域的纠正标签与伪标签的相反分布的损失函数为:
Figure FDA0002244712940000028
其中,
Figure FDA0002244712940000029
Figure FDA00022447129400000210
由域判别器的输出
Figure FDA00022447129400000211
根据定义生成;伪标签的相反分布
Figure FDA00022447129400000212
对于
Figure FDA00022447129400000213
对于
6.根据权利要求5所述的基于对抗学习损失函数的无监督域适应方法,其特征在于,步骤(4)中,进行对抗学习时,优化损失函数时,优化的损失函数为总的对抗性损失函数为:
Figure FDA00022447129400000215
其中,
Figure FDA00022447129400000216
为源域的纠正标签与真实标签的损失函数,
Figure FDA00022447129400000217
为目标域的纠正标签与伪标签的相反分布的损失函数;域判别器D通过最小化总的对抗性损失函数以区分源特征和目标特征,特征提取器网络G通过最大化总的对抗性损失函数来欺骗域判别器。
7.根据权利要求6所述的基于对抗学习损失函数的无监督域适应方法,其特征在于,在进行对抗学习时,在源域上向分类器网络C添加分类任务,所述的域判别器D还用于正确分类源样本,域判别器D的输出经过softmax函数的概率向量与真实标签做交叉熵损失,作为正则项:
Figure FDA0002244712940000032
其中,所述域判别器D的优化目标为:
Figure FDA0002244712940000034
其中,
Figure FDA0002244712940000035
表示计算
Figure FDA0002244712940000036
分布下的均值。
8.根据权利要求6所述的基于对抗学习损失函数的无监督域适应方法,其特征在于,步骤(5)中,对于目标域的数据,使用纠正标签作为目标域上的损失函数:
Figure FDA0002244712940000037
其中,
Figure FDA0002244712940000038
由域判别器的输出
Figure FDA00022447129400000310
根据定义生成;
Figure FDA00022447129400000311
为unhinge损耗;
结合源域上分类器网络C输出的概率向量与真实标签的损失函数
Figure FDA00022447129400000312
分类器网络C和特征提取器网络G的优化目标为:
Figure FDA00022447129400000313
Figure FDA00022447129400000314
其中,λ∈[0,1]为平衡两种损失的超参数,使用随机梯度下降算法SGD优化上述损失函数。
CN201911012806.1A 2019-10-23 2019-10-23 一种基于对抗学习损失函数的无监督域适应方法 Active CN110837850B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201911012806.1A CN110837850B (zh) 2019-10-23 2019-10-23 一种基于对抗学习损失函数的无监督域适应方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201911012806.1A CN110837850B (zh) 2019-10-23 2019-10-23 一种基于对抗学习损失函数的无监督域适应方法

Publications (2)

Publication Number Publication Date
CN110837850A true CN110837850A (zh) 2020-02-25
CN110837850B CN110837850B (zh) 2022-06-21

Family

ID=69575774

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201911012806.1A Active CN110837850B (zh) 2019-10-23 2019-10-23 一种基于对抗学习损失函数的无监督域适应方法

Country Status (1)

Country Link
CN (1) CN110837850B (zh)

Cited By (27)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111738455A (zh) * 2020-06-02 2020-10-02 山东大学 一种基于集成域自适应的故障诊断方法及系统
CN111738315A (zh) * 2020-06-10 2020-10-02 西安电子科技大学 基于对抗融合多源迁移学习的图像分类方法
CN111738172A (zh) * 2020-06-24 2020-10-02 中国科学院自动化研究所 基于特征对抗学习和自相似性聚类的跨域目标重识别方法
CN111753918A (zh) * 2020-06-30 2020-10-09 浙江工业大学 一种基于对抗学习的去性别偏见的图像识别模型及应用
CN111797814A (zh) * 2020-07-21 2020-10-20 天津理工大学 基于通道融合和分类器对抗的无监督跨域动作识别方法
CN111832605A (zh) * 2020-05-22 2020-10-27 北京嘀嘀无限科技发展有限公司 无监督图像分类模型的训练方法、装置和电子设备
CN112131967A (zh) * 2020-09-01 2020-12-25 河海大学 基于多分类器对抗迁移学习的遥感场景分类方法
CN112149722A (zh) * 2020-09-11 2020-12-29 南京大学 一种基于无监督域适应的图像自动标注方法
CN112163486A (zh) * 2020-09-18 2021-01-01 杭州电子科技大学 一种基于稀疏学习和域对抗网络的脑电通道优化方法
CN112215795A (zh) * 2020-09-02 2021-01-12 苏州超集信息科技有限公司 一种基于深度学习的服务器部件智能检测方法
CN112232252A (zh) * 2020-10-23 2021-01-15 湖南科技大学 基于最优输运的传动链无监督域适应故障诊断方法
CN112767328A (zh) * 2021-01-08 2021-05-07 厦门大学 基于对抗学习和适应性分析的医学图像病灶跨域检测方法
CN112801179A (zh) * 2021-01-27 2021-05-14 北京理工大学 面向跨领域复杂视觉任务的孪生分类器确定性最大化方法
CN112820301A (zh) * 2021-03-15 2021-05-18 中国科学院声学研究所 一种融合分布对齐和对抗学习的无监督跨域声纹识别方法
CN113326848A (zh) * 2021-06-17 2021-08-31 中山大学 半监督领域自适应方法、系统、设备及存储介质
CN113392967A (zh) * 2020-03-11 2021-09-14 富士通株式会社 领域对抗神经网络的训练方法
CN113435394A (zh) * 2021-07-13 2021-09-24 郑州大学 一种基于标签概率序列的高鲁棒性深度道路提取方法
CN113469273A (zh) * 2021-07-20 2021-10-01 南京信息工程大学 基于双向生成及中间域对齐的无监督域适应图像分类方法
CN113537466A (zh) * 2021-07-12 2021-10-22 广州杰纳医药科技发展有限公司 实时生成对抗样本的深度学习训练数据增广方法、装置、电子设备及介质
CN113673347A (zh) * 2021-07-20 2021-11-19 杭州电子科技大学 一种基于Wasserstein距离的表征相似对抗网络
CN114283287A (zh) * 2022-03-09 2022-04-05 南京航空航天大学 基于自训练噪声标签纠正的鲁棒领域自适应图像学习方法
CN115546567A (zh) * 2022-12-01 2022-12-30 成都考拉悠然科技有限公司 一种无监督领域适应分类方法、系统、设备及存储介质
CN115577797A (zh) * 2022-10-18 2023-01-06 东南大学 一种基于本地噪声感知的联邦学习优化方法及系统
CN115795313A (zh) * 2023-01-16 2023-03-14 中国科学院合肥物质科学研究院 核主泵故障诊断模型的训练方法、故障诊断方法和系统
CN116029394A (zh) * 2023-03-29 2023-04-28 季华实验室 自适应文本情感识别模型训练方法、电子设备及存储介质
CN117746079A (zh) * 2023-11-15 2024-03-22 中国地质大学(武汉) 一种高光谱图像的聚类预测方法、系统、存储介质及设备
CN117746079B (zh) * 2023-11-15 2024-05-14 中国地质大学(武汉) 一种高光谱图像的聚类预测方法、系统、存储介质及设备

Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107506800A (zh) * 2017-09-21 2017-12-22 深圳市唯特视科技有限公司 一种基于无监督域适应的无标签视频人脸识别方法
US20180307947A1 (en) * 2017-04-25 2018-10-25 Nec Laboratories America, Inc. Cyclic generative adversarial network for unsupervised cross-domain image generation
CN109753992A (zh) * 2018-12-10 2019-05-14 南京师范大学 基于条件生成对抗网络的无监督域适应图像分类方法
US20190147320A1 (en) * 2017-11-15 2019-05-16 Uber Technologies, Inc. "Matching Adversarial Networks"
US20190147854A1 (en) * 2017-11-16 2019-05-16 Microsoft Technology Licensing, Llc Speech Recognition Source to Target Domain Adaptation
CN109948741A (zh) * 2019-03-04 2019-06-28 北京邮电大学 一种迁移学习方法及装置
CN109948648A (zh) * 2019-01-31 2019-06-28 中山大学 一种基于元对抗学习的多目标域适应迁移方法及系统
CN110135579A (zh) * 2019-04-08 2019-08-16 上海交通大学 基于对抗学习的无监督领域适应方法、系统及介质
CN110186680A (zh) * 2019-05-30 2019-08-30 盐城工学院 一种对抗判别域适应一维卷积神经网络智能故障诊断方法
CN110210486A (zh) * 2019-05-15 2019-09-06 西安电子科技大学 一种基于素描标注信息的生成对抗迁移学习方法
CN110222690A (zh) * 2019-04-29 2019-09-10 浙江大学 一种基于最大二乘损失的无监督域适应语义分割方法

Patent Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180307947A1 (en) * 2017-04-25 2018-10-25 Nec Laboratories America, Inc. Cyclic generative adversarial network for unsupervised cross-domain image generation
CN107506800A (zh) * 2017-09-21 2017-12-22 深圳市唯特视科技有限公司 一种基于无监督域适应的无标签视频人脸识别方法
US20190147320A1 (en) * 2017-11-15 2019-05-16 Uber Technologies, Inc. "Matching Adversarial Networks"
US20190147854A1 (en) * 2017-11-16 2019-05-16 Microsoft Technology Licensing, Llc Speech Recognition Source to Target Domain Adaptation
CN109753992A (zh) * 2018-12-10 2019-05-14 南京师范大学 基于条件生成对抗网络的无监督域适应图像分类方法
CN109948648A (zh) * 2019-01-31 2019-06-28 中山大学 一种基于元对抗学习的多目标域适应迁移方法及系统
CN109948741A (zh) * 2019-03-04 2019-06-28 北京邮电大学 一种迁移学习方法及装置
CN110135579A (zh) * 2019-04-08 2019-08-16 上海交通大学 基于对抗学习的无监督领域适应方法、系统及介质
CN110222690A (zh) * 2019-04-29 2019-09-10 浙江大学 一种基于最大二乘损失的无监督域适应语义分割方法
CN110210486A (zh) * 2019-05-15 2019-09-06 西安电子科技大学 一种基于素描标注信息的生成对抗迁移学习方法
CN110186680A (zh) * 2019-05-30 2019-08-30 盐城工学院 一种对抗判别域适应一维卷积神经网络智能故障诊断方法

Non-Patent Citations (6)

* Cited by examiner, † Cited by third party
Title
ARITRA GHOSH ET AL.: "《Robust loss functions under label noise for deep neural networks》", 《HTTPS://ARXIV.ORG/ABS/1712.09482》, 27 December 2017 (2017-12-27) *
TU DINH NGUYEN ET AL.: "《Dual Discriminator Generative Adversarial Nets》", 《HTTPS://ARXIV,ORG/ABS/1709.03831》, 12 December 2017 (2017-12-12) *
WENQING CHU ET AL.: "《Weakly-Supervised Caricature Face Parsing Through Domain Adaptation》", 《2019 IEEE INTERNATIONAL CONFERENCE ON IMAGE PROCESSING (ICIP)》, 25 September 2019 (2019-09-25) *
毛潇锋: "《基于对抗学习的深度视觉域适应方法研究》", 《中国优秀博硕士学位论文全文数据库(硕士)信息科技辑》, no. 5, 15 May 2019 (2019-05-15) *
许浩 等: "《多层面的分步领域适应图像分类算法》", 《小型微型计算机系统》, vol. 40, no. 9, 30 September 2019 (2019-09-30) *
许浩 等: "《带有双判别器的对抗性领域适应图像分类算法》", 《计算机工程与科学》, vol. 41, no. 9, 30 September 2019 (2019-09-30) *

Cited By (41)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113392967A (zh) * 2020-03-11 2021-09-14 富士通株式会社 领域对抗神经网络的训练方法
CN111832605A (zh) * 2020-05-22 2020-10-27 北京嘀嘀无限科技发展有限公司 无监督图像分类模型的训练方法、装置和电子设备
CN111832605B (zh) * 2020-05-22 2023-12-08 北京嘀嘀无限科技发展有限公司 无监督图像分类模型的训练方法、装置和电子设备
CN111738455B (zh) * 2020-06-02 2021-05-11 山东大学 一种基于集成域自适应的故障诊断方法及系统
CN111738455A (zh) * 2020-06-02 2020-10-02 山东大学 一种基于集成域自适应的故障诊断方法及系统
CN111738315A (zh) * 2020-06-10 2020-10-02 西安电子科技大学 基于对抗融合多源迁移学习的图像分类方法
CN111738172B (zh) * 2020-06-24 2021-02-12 中国科学院自动化研究所 基于特征对抗学习和自相似性聚类的跨域目标重识别方法
CN111738172A (zh) * 2020-06-24 2020-10-02 中国科学院自动化研究所 基于特征对抗学习和自相似性聚类的跨域目标重识别方法
CN111753918A (zh) * 2020-06-30 2020-10-09 浙江工业大学 一种基于对抗学习的去性别偏见的图像识别模型及应用
CN111753918B (zh) * 2020-06-30 2024-02-23 浙江工业大学 一种基于对抗学习的去性别偏见的图像识别模型及应用
CN111797814A (zh) * 2020-07-21 2020-10-20 天津理工大学 基于通道融合和分类器对抗的无监督跨域动作识别方法
CN112131967A (zh) * 2020-09-01 2020-12-25 河海大学 基于多分类器对抗迁移学习的遥感场景分类方法
CN112131967B (zh) * 2020-09-01 2022-08-19 河海大学 基于多分类器对抗迁移学习的遥感场景分类方法
CN112215795A (zh) * 2020-09-02 2021-01-12 苏州超集信息科技有限公司 一种基于深度学习的服务器部件智能检测方法
CN112215795B (zh) * 2020-09-02 2024-04-09 苏州超集信息科技有限公司 一种基于深度学习的服务器部件智能检测方法
CN112149722A (zh) * 2020-09-11 2020-12-29 南京大学 一种基于无监督域适应的图像自动标注方法
CN112149722B (zh) * 2020-09-11 2024-01-16 南京大学 一种基于无监督域适应的图像自动标注方法
CN112163486B (zh) * 2020-09-18 2024-03-12 杭州电子科技大学 一种基于稀疏学习和域对抗网络的脑电通道优化方法
CN112163486A (zh) * 2020-09-18 2021-01-01 杭州电子科技大学 一种基于稀疏学习和域对抗网络的脑电通道优化方法
CN112232252A (zh) * 2020-10-23 2021-01-15 湖南科技大学 基于最优输运的传动链无监督域适应故障诊断方法
CN112232252B (zh) * 2020-10-23 2023-12-01 湖南科技大学 基于最优输运的传动链无监督域适应故障诊断方法
CN112767328A (zh) * 2021-01-08 2021-05-07 厦门大学 基于对抗学习和适应性分析的医学图像病灶跨域检测方法
CN112767328B (zh) * 2021-01-08 2022-06-14 厦门大学 基于对抗学习和适应性分析的医学图像病灶跨域检测方法
CN112801179A (zh) * 2021-01-27 2021-05-14 北京理工大学 面向跨领域复杂视觉任务的孪生分类器确定性最大化方法
CN112820301A (zh) * 2021-03-15 2021-05-18 中国科学院声学研究所 一种融合分布对齐和对抗学习的无监督跨域声纹识别方法
CN112820301B (zh) * 2021-03-15 2023-01-20 中国科学院声学研究所 一种融合分布对齐和对抗学习的无监督跨域声纹识别方法
CN113326848A (zh) * 2021-06-17 2021-08-31 中山大学 半监督领域自适应方法、系统、设备及存储介质
CN113326848B (zh) * 2021-06-17 2023-04-18 中山大学 半监督领域自适应方法、系统、设备及存储介质
CN113537466A (zh) * 2021-07-12 2021-10-22 广州杰纳医药科技发展有限公司 实时生成对抗样本的深度学习训练数据增广方法、装置、电子设备及介质
CN113435394A (zh) * 2021-07-13 2021-09-24 郑州大学 一种基于标签概率序列的高鲁棒性深度道路提取方法
CN113469273B (zh) * 2021-07-20 2023-12-05 南京信息工程大学 基于双向生成及中间域对齐的无监督域适应图像分类方法
CN113673347A (zh) * 2021-07-20 2021-11-19 杭州电子科技大学 一种基于Wasserstein距离的表征相似对抗网络
CN113469273A (zh) * 2021-07-20 2021-10-01 南京信息工程大学 基于双向生成及中间域对齐的无监督域适应图像分类方法
CN114283287A (zh) * 2022-03-09 2022-04-05 南京航空航天大学 基于自训练噪声标签纠正的鲁棒领域自适应图像学习方法
CN115577797B (zh) * 2022-10-18 2023-09-26 东南大学 一种基于本地噪声感知的联邦学习优化方法及系统
CN115577797A (zh) * 2022-10-18 2023-01-06 东南大学 一种基于本地噪声感知的联邦学习优化方法及系统
CN115546567A (zh) * 2022-12-01 2022-12-30 成都考拉悠然科技有限公司 一种无监督领域适应分类方法、系统、设备及存储介质
CN115795313A (zh) * 2023-01-16 2023-03-14 中国科学院合肥物质科学研究院 核主泵故障诊断模型的训练方法、故障诊断方法和系统
CN116029394A (zh) * 2023-03-29 2023-04-28 季华实验室 自适应文本情感识别模型训练方法、电子设备及存储介质
CN117746079A (zh) * 2023-11-15 2024-03-22 中国地质大学(武汉) 一种高光谱图像的聚类预测方法、系统、存储介质及设备
CN117746079B (zh) * 2023-11-15 2024-05-14 中国地质大学(武汉) 一种高光谱图像的聚类预测方法、系统、存储介质及设备

Also Published As

Publication number Publication date
CN110837850B (zh) 2022-06-21

Similar Documents

Publication Publication Date Title
CN110837850B (zh) 一种基于对抗学习损失函数的无监督域适应方法
CN108717568B (zh) 一种基于三维卷积神经网络的图像特征提取与训练方法
EP3767536A1 (en) Latent code for unsupervised domain adaptation
Li et al. Adversarial open-world person re-identification
CN107341463B (zh) 一种结合图像质量分析与度量学习的人脸特征识别方法
CN107437077A (zh) 一种基于生成对抗网络的旋转面部表示学习的方法
CN108230291B (zh) 物体识别系统训练方法、物体识别方法、装置和电子设备
CN113076994B (zh) 一种开集域自适应图像分类方法及系统
CN112801297B (zh) 一种基于条件变分自编码器的机器学习模型对抗性样本生成方法
CN110532862B (zh) 基于门控融合单元的特征融合组群识别方法
CN110827265A (zh) 基于深度学习的图片异常检测方法
CN113627543A (zh) 一种对抗攻击检测方法
CN111611909A (zh) 多子空间域自适应人脸识别方法
Ye et al. Reducing bias to source samples for unsupervised domain adaptation
CN114863176A (zh) 基于目标域移动机制的多源域自适应方法
CN113011513B (zh) 一种基于通用域自适应的图像大数据分类方法
CN113239926B (zh) 基于对抗的多模态虚假信息检测模型系统
CN112990357B (zh) 一种基于稀疏扰动的黑盒视频对抗样本生成方法
Pang et al. Federated Learning for Crowd Counting in Smart Surveillance Systems
CN111914617B (zh) 一种基于平衡栈式生成式对抗网络的人脸属性编辑方法
CN113011523A (zh) 一种基于分布对抗的无监督深度领域适应方法
Zhang et al. Black-box based limited query membership inference attack
Ren et al. Student behavior detection based on YOLOv4-Bi
Zhou et al. Network unknown‐threat detection based on a generative adversarial network and evolutionary algorithm
Yu et al. A multi-scale feature selection method for steganalytic feature GFR

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant