CN116227578A - 一种无源域数据的无监督领域适应方法 - Google Patents

一种无源域数据的无监督领域适应方法 Download PDF

Info

Publication number
CN116227578A
CN116227578A CN202211600631.8A CN202211600631A CN116227578A CN 116227578 A CN116227578 A CN 116227578A CN 202211600631 A CN202211600631 A CN 202211600631A CN 116227578 A CN116227578 A CN 116227578A
Authority
CN
China
Prior art keywords
domain
model
target domain
loss
target
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211600631.8A
Other languages
English (en)
Inventor
梅建萍
翁烨涛
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University of Technology ZJUT
Original Assignee
Zhejiang University of Technology ZJUT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University of Technology ZJUT filed Critical Zhejiang University of Technology ZJUT
Priority to CN202211600631.8A priority Critical patent/CN116227578A/zh
Publication of CN116227578A publication Critical patent/CN116227578A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Molecular Biology (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及一种无源域数据的无监督领域适应方法,以有标签的源域样本训练模型,得到预训练好的源域模型;以源域模型初始化目标域模型;以源域模型的BN层存储的统计信息近似源域的特征分布,与目标域样本的特征分布显式对齐,最小化分布对齐损失,尽可能拉近源域和目标域特征分布空间;基于源域模型的分类器的预测对目标域样本的特征进行模糊聚类,以聚类隶属度作为目标域样本的软标签,计算软标签与模型分类器对目标域样本的预测之间的交叉熵损失,对目标域样本计算信息最大化损失;以所有损失函数共同训练目标域模型,实现无源域数据的无监督领域适应,纠正部分最初分类器分错的目标域样本,提高分类准确度。

Description

一种无源域数据的无监督领域适应方法
技术领域
本发明涉及计算;推算或计数的技术领域,特别涉及一种机器学习领域的、基于BN层信息和软聚类的无源域数据的无监督领域适应方法。
背景技术
近年来,深度神经网络在视觉分类领域取得了非常不错的应用效果,被广泛地运用在各个行业。神经网络表现出卓越性能的一个前提是测试数据与训练数据服从独立同分布,然而,在现实世界中这个条件难以满足,理想情况下是希望模型能在标签丰富的数据集上获得的知识可以转移或者应用到其他未标记的数据上,但即使数据集之间的差异很小,深度网络也难以应用到未知的数据域中,在训练中,影响模型泛化能力的重要因素是来自不同领域数据之间的分布偏移。因此,领域适应就是针对这类问题进行的研究。
近年来,在该技术问题上取得了巨大的进展,尤其是无监督的领域适应。当我们可以直接访问源域数据集时,可以直接对齐源域和目标域的分布偏移,现有的许多领域适应方法即使对无标签的目标域数据都非常有效。然而,传统的领域适应都是基于源域数据及其标签可用的前提,在一些实际情况下,包括但不限于数据集过大存储困难、共享数据的挑战、数据隐私和其他数据集处理问题,使得源数据不容易获取,只能获取预训练好的模型,这让传统的无监督领域适应模型有了局限性,因此提出了无源领域适应。
无源领域适应与无监督领域适应的不同在于,无源即不能获取有标签的源域数据,只能用源域数据训练好的模型和无标签的目标域数据进行训练。目前无源领域适应常用的方法有两类:一类从预训练好的模型里挖掘包含源域特征的信息与目标域样本进行训练,微调预训练模型;另一类使用生成模型,利用目标域数据、预训练模型生成含有源域信息的生成样本,用生成样本和目标域样本进行领域适应;但因为没有源域数据,这些方法都没有对源域和目标域显式对齐,只是利用无监督方法进行微调或者生成的样本是类似目标域样本的伪源域样本。
发明内容
本发明解决了现有技术中存在的问题,提供了一种无源域数据的无监督领域适应方法。
本发明所采用的技术方案是,一种无源域数据的无监督领域适应方法,所述方法包括以下步骤:
步骤1:以有标签的源域样本训练模型,得到预训练好的源域模型;
步骤2:以源域模型初始化目标域模型,包括特征提取器和分类器;
步骤3:以源域模型的BN层存储的统计信息近似源域的特征分布,与目标域样本的特征分布显式对齐,计算分布对齐损失LBN
步骤4:基于目标域模型的分类器的预测,对目标域样本的特征进行模糊聚类,以聚类隶属度作为目标域样本的软标签,计算软标签与模型分类器对目标域样本的预测之间的交叉熵损失Lclu
步骤5:对目标域样本计算信息最大化损失LIM,信息最大化损失包括最小化熵损失和最大化平均熵损失,使样本预测置信度更高,并且避免模型坍塌;
步骤6:以所述分布对齐损失LBN、交叉熵损失Lclu和信息最大化损失LIM共同训练目标域模型,实现无源域数据的无监督领域适应,提高对目标域样本的识别准确率。
优选地,所述步骤1中,为防止预训练模型在源域数据上过拟合,通过标签平滑后再计算交叉熵损失,以提升模型到目标域的泛化性能,目标函数为,
Figure BDA0003997270540000021
其中,fs表示预训练好的源域模型,包括特征提取器gs和分类器hs,满足给定输入x,fs(x)=hs(gs(x));K表示类别数目,k对应任一类别,Xs为源域样本集;给定qk为源域样本xs的标签,则
Figure BDA0003997270540000031
是对qk平滑后的标签,满足
Figure BDA0003997270540000032
α是平滑系数,0<α<1,一般取0.05≤α≤0.15;
σ(·)表示对某一给定向量的softmax归一化操作,假设给定向量a和温度参数T,用σk表示对某个向量σ(·)操作后得到的第k维的值,
Figure BDA0003997270540000033
ak表示向量a第k维的值,j指向量a第j维,式(1)中T为1。
优选地,所述目标域模型的分类器固定不变。所述步骤2中,目标域模型ft,包括特征提取器gt和分类器ht,分别初始化为源域模型中的特征提取器和分类器,给定任意输入x,满足ft(x)=ht(gt(x));通过损失函数优化目标域模型的特征提取器,分类器初始化以后冻结不更新。
优选地,所述步骤3中,BN层的统计信息包括该层每个通道的均值和方差,这些统计信息可以用来近似训练样本的全局特征分布,具体地说,每个BN层的每个通道的数据分布可以用一个高斯分布N(μ,σ2)表示,其中μ、σ2表示高斯分布的均值和方差;以源域模型的每个BN层中每个通道的均值、方差表示的高斯分布与目标域样本对应BN层的当前batch样本的每个通道的均值、方差表示的高斯分布,计算它们之间KL散度的平均值,作为衡量源域和目标域样本特征分布的距离。
优选地,所述分布对齐损失LBN为,
Figure BDA0003997270540000034
其中,M表示模型中BN层的总数,Cm表示第m个BN层的通道总数,
Figure BDA0003997270540000035
Figure BDA0003997270540000036
表示源域模型中第m个BN层第cm个通道存储的均值和方差,
Figure BDA0003997270540000041
Figure BDA0003997270540000042
表示当前batch经过目标域模型第m个BN层的第cm个通道的均值和方差;DKL为KL散度;
最小化损失函数LBN,通过最小化该损失函数,用BN层中的均值方差表示的高斯分布来近似因没有源域样本而无法获取的源域特征分布,实现与目标域特征的分布对齐。
优选地,所述步骤4包括以下步骤:
步骤4.1:因为目标域和源域的数据差异,固定的源域分类器对目标域样本的预测存在噪声,使得分类器对难以分辨的目标域样本没有纠正作用,为缓解这一问题,引入基于聚类的软标签对齐损失,软标签生成过程类似模糊聚类。先以目标域模型分类器输出的概率为权重,对提取的特征进行加权平均,初始化簇中心:
Figure BDA0003997270540000043
式(3)中,δk表示第k个类的簇中心,ft表示目标模型,包括特征提取器gt和分类器ht,满足给定输入x,ft(x)=ht(gt(x));xt表示目标域样本,Bt表示当前读入的目标域样本的batch,σ(·)表示对某一给定向量的softmax归一化操作,上标T表示向量的转置;
步骤4.2:虽然直接计算样本特征到簇中心距离能得到独热编码的伪标签,但仍有部分分错样本因距离决策边界过远,靠聚类难以纠正。为降低错误伪标签的影响,将独热编码的伪标签改成平滑的软标签,能减少对错误标签过于自信的影响,最终提高模型的泛化性。基于簇中心δk计算样本到每个簇中心的余弦距离,取倒数再softmax归一化得到样本的预测分布,同时加上温度参数T调节软标签的平滑程度,
Figure BDA0003997270540000044
其中,D表示余弦距离,
Figure BDA0003997270540000045
表示聚类得到的软标签在第k个类的概率或隶属度,此时温度参数满足0.6≤T≤1.2;
步骤4.3:用该软标签和模型分类器对目标域样本的输出概率分布计算交叉熵损失,
Figure BDA0003997270540000051
能一定程度纠正由于源域分类器在最初分类错误的目标域样本预测,优化模型。
优选地,所述步骤5中,信息最大化损失使得目标域样本的预测置信度更高,同时避免所有样本被分到少数几个类而产生坍塌解,具体地满足,
LIM=Lent+Ldiv (6)
其中,Lent为最小化熵损失,
Figure BDA0003997270540000052
Ldiv为最大化平均熵损失,
Figure BDA0003997270540000053
Figure BDA0003997270540000054
为第k个类的平均隶属度,
Figure BDA0003997270540000055
优选地,所述步骤6中,完整目标函数Lgt为,
Lgt=LIM+βLBN+γLclu (7)
其中,β和γ为对应的超参数,β,γ∈[0.6,1.0]。
本发明涉及一种无源域数据的无监督领域适应方法,以有标签的源域样本训练模型,得到预训练好的源域模型;以源域模型初始化目标域模型;以源域模型的BN层存储的统计信息近似源域的特征分布,与目标域样本的特征分布显式对齐,最小化分布对齐损失,尽可能拉近源域和目标域特征分布空间;基于源域模型的分类器的预测对目标域样本的特征进行模糊聚类,以聚类隶属度作为目标域样本的软标签,计算软标签与模型分类器对目标域样本的预测之间的交叉熵损失,对目标域样本计算信息最大化损失;以所有损失函数共同训练目标域模型,实现无源域数据的无监督领域适应,纠正部分最初分类器分错的目标域样本,提高分类准确度。
本发明的有益效果在于:
(1)充分利用源域网络模型的参数中存储的统计信息即训练样本的均值方差来近似源域样本的特征分布,从而能显式地与目标域样本进行分布对齐,避免了因无法获取源域样本而无法进行分布对齐的问题;
(2)虽然直接聚类得到的伪标签相比直接模型分类器得的预测分类准确度有所提高,但仍有部分分错样本因距离决策边界过远、靠聚类难以纠正,计算样本和簇中心的余弦距离加上温度参数得到的平滑的软标签能包含更多目标域样本的信息;
(3)相比于传统的无监督领域适应方法,本发明具有更高的分类准确率,并且不需要源域数据只使用预训练模型和目标域样本,具有更广的模型适用性;
(4)有效性在SVHN、MINIST、USPS、Office-31与Office-Home数据集上得到验证,其目标域模型在Office-31数据集的三个子集两两迁移的六个迁移场景上的平均准确率能到达89.5%,在Office-Home数据集的四个子集的十二个迁移场景上的平均准确率到达72.4%。
附图说明
图1为本发明的流程图;
图2为本发明的无源域数据的无监督领域适应方法的示意图。
具体实施方式
下面结合实施例对本发明做进一步的详细描述,但本发明的保护范围并不限于此。
本发明涉及一种无源域数据的无监督领域适应方法,复用源域模型的分类器,通过源域模型中BN层存储的统计信息即该模型的训练样本的全局均值和方差来近似源域的特征分布,从而显式最小化源域和目标域之间的分布差异;由于源域分类器对目标域样本的预测存在噪声,本发明提出了基于分类器输出对目标域样本的特征进行软聚类得到平滑的标签,相较于独热编码的伪标签,软聚类得到的隶属度包含更多目标域样本信息,能一定程度上纠正源域分类器难以辨别的目标域样本;此外还采用信息最大化损失以提高样本预测的置信度、防止出现坍塌解,以进一步提高模型在目标域中的分类性能和鲁棒性。
本发明中,需要特别说明的是,下标s表示source即源域,下标t表示target即目标域。
所述方法包括以下步骤:
步骤1:选取数据集,确定好源域和目标域,以有标签的源域样本训练模型,得到预训练好的源域模型;
所述步骤1中,选取公开数据集,将数据集中的多个域两两随机组合成多个迁移场景。本发明中采用Office-31和Office-Home两个公开数据集作为实验的数据集,其中Office-31是一个小型包含办公环境中的31个类的数据集,有三个子集;Office-Home是一个中型数据集,包含65个类的照片,四个子集。我们用有标签的源域数据集训练源域模型时,在标准交叉熵损失中加入标签平滑,增加鲁棒性。
通过标签平滑后再计算交叉熵损失,目标函数为,
Figure BDA0003997270540000071
其中,fs表示预训练好的源域模型,包括特征提取器gs和分类器hs,满足给定输入x,fs(x)=hs(gs(x));K表示类别数目,k对应任一类别,Xs为源域样本集;给定qk为源域样本xs的标签,则
Figure BDA0003997270540000072
是对qk平滑后的标签,满足
Figure BDA0003997270540000073
α是平滑系数,0<α<1;
σ(·)表示对某一给定向量的softmax归一化操作,假设给定向量a和温度参数T,用σk表示对某个向量σ(·)操作后得到的第k维的值,
Figure BDA0003997270540000074
ak表示向量a第k维的值,j指向量a第j维,式(1)中T为1。
步骤2:用源域模型初始化目标域模型,包括特征提取器和分类器;目标域模型的特征提取器后续训练优化,分类器固定不变。
本发明中,在完成初始化后,gs=gt且hs=ht,其中,ht在之后不再更新。
步骤3:用源域模型中BN层里存储的统计信息近似源域样本的全局特征分布,用以和目标域样本的特征分布进行显式对齐,计算分布对齐损失LBN
所述步骤3中,用源域模型每个BN层中每个通道的均值、方差表示的高斯分布和目标域样本对应BN层当前batch样本的每个通道的均值、方差表示的高斯分布,计算它们相对熵(KL散度)的平均值作为衡量源域和目标域样本特征分布的距离,损失函数如下:
Figure BDA0003997270540000081
其中,M表示模型中BN层的总数,Cm表示第m个BN层的通道总数,
Figure BDA0003997270540000082
Figure BDA0003997270540000083
表示源域模型中第m个BN层第cm个通道存储的均值和方差,
Figure BDA0003997270540000084
Figure BDA0003997270540000085
表示当前batch经过目标域模型第m个BN层的第cm个通道的均值和方差;通过最小化该损失函数,用BN层中的均值方差表示的高斯分布来近似因没有源域样本而无法获取的源域特征分布,来和目标域的特征分布对齐。
步骤4:基于目标域模型的分类器的预测,对目标域样本的特征进行模糊聚类,以聚类隶属度作为目标域样本的软标签,计算软标签与模型分类器对目标域样本的预测之间的交叉熵损失Lclu
所述步骤4包括以下步骤:
步骤4.1:为缓解源域分类器对目标域样本的噪声问题,本发明引入软标签损失;软标签生成过程类似模糊聚类。先以目标域模型分类器输出的概率为权重,对提取的特征进行加权平均,初始化簇中心:
Figure BDA0003997270540000086
式(3)中,δk表示第k个类的簇中心,ft表示目标模型,包括特征提取器gt和分类器ht,满足给定输入x,ft(x)=ht(gt(x));xt表示目标域样本,Bt表示当前读入的目标域样本的batch,σ(·)表示对某一给定向量的softmax归一化操作,上标T表示向量的转置。
步骤4.2:虽然直接计算样本特征到簇中心距离能得到独热编码的伪标签,但仍有部分分错样本因距离决策边界过远,靠聚类难以纠正。为降低错误伪标签的影响,将独热编码的伪标签改成平滑的软标签,能减少对错误标签过于自信的影响,最终提高模型的泛化性。基于簇中心δk计算样本到每个簇中心的余弦距离,取倒数再softmax归一化得到样本的预测分布,同时加上温度参数T(0.6≤T≤1.2)调节软标签的平滑程度,
Figure BDA0003997270540000091
其中,D表示余弦距离,
Figure BDA0003997270540000092
表示聚类得到的软标签在第k个类的概率或隶属度,此时温度参数满足0.6≤T≤1.2;
步骤4.3:用该软标签和模型分类器对目标域样本的输出概率分布计算交叉熵损失,损失函数为:
Figure BDA0003997270540000093
能一定程度纠正由于起初分类错误的目标域样本预测,从而优化模型。
步骤5:采用信息最大化损失LIM,包括最小化熵损失和最大化平均熵损失,使样本预测置信度更高,并且避免坍塌解。
所述步骤5中,最大化信息损失包含了最小化熵损失Lent和最大化平均熵损失Ldiv,使得目标域样本的预测置信度更高,同时避免所有样本被分到少数几个类而产生坍塌解,具体地,
LIM=Lent+Ldiv (6)
其中:Lent为最小化熵损失,
Figure BDA0003997270540000101
Ldiv为最大化平均熵损失,
Figure BDA0003997270540000102
Figure BDA0003997270540000103
为第k个类的平均隶属度,
Figure BDA0003997270540000104
步骤6:用以上三部分的损失:分布对齐损失LBN、交叉熵损失Lclu和信息最大化损失LIM来共同训练目标域模型,提高对目标域样本的识别准确率。
所述步骤6中,联合以上三个部分损失,优化的完整目标函数为,
Figure BDA0003997270540000105
其中,β和γ为两部分损失的超参数,β,γ∈[0.6,1.0]。
本发明中,给出一个具体实施例:
步骤一:选取Office-31数据集中的子集Amazon作为源域训练集,Webcam作为目标域。Amazon包含2817张图像背景单一的在线电商图片,Webcam包含795张有噪点的低分辨率图片,都是31个类。
步骤二:用Amazon训练源域模型,选取Resnet50网络作为骨干模型,将Resnet网络最后的全连接层替换成256维的适配层,添加一个BN层在适配层之后,最后是一层31类的分类器。我们用有标签的源域数据集训练源域模型时,在标准交叉熵损失中加入标签平滑,增加鲁棒性。其中,平滑参数α为0.1,批次大小为64。
可以通过使用pytorch等深度学习框架对数据集执行以上操作,将图片输入载入DataLoader中,遍历DataLoader中的数据输入编码器中,获取它们的模型输出,计算损失,使用sgd优化器优化模型;
步骤三:用源域模型初始化目标域模型。
步骤四:用源域模型每个BN层中每个通道的均值、方差表示的高斯分布N(μ,σ2)和目标域样本对应BN层当前batch样本的每个通道的均值、方差表示的高斯分布进行对齐,模型共54个BN层,计算它们相对熵(KL散度)的平均值LBN作为衡量源域和目标域样本特征分布的距离。
步骤五:将分类器输出作为模糊隶属度,对目标域样本特征进行软聚类,加上温度参数T=0.8,得到平滑的软标签,与分类器预测对齐计算损失Lclu
步骤六:采用信息最大化损失LIM。其中包含了最小化熵损失Lent和最大化平均熵损失Ldiv,使得目标域样本的预测置信度更高,同时避免所有样本被分到少数几个类,产生崩溃解。
步骤七:联合以上三部分损失,对部分损失加上权重
Figure BDA0003997270540000111
Figure BDA0003997270540000112
优化目标域的特征提取器,分类器则冻结不更新。
优化目标域模型采用SGD优化器训练,动量为0.9,权重衰减为10-3,批次大小为64。学习率动态变化lr=lr0(1+10p-0.75),其中lr0为初始值,除适配层和分类器设置为0.01外其余设置为0.001,p随着迭代数的增加从0变化到1。训练时,聚类得到的软标签每个epoch更新一次,超参β=0.3,γ=1.0,epoch设为20。
基于本发明训练目标域模型可以实现将源域模型知识转移或者应用到对未标记目标域数据的学习上,减少不同领域数据之间分布偏移的影响;基于此方法,可以实现计算机介质及程序、设备的开发。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
尽管已描述了本发明的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本发明范围的所有变更和修改。
显然,本领域的技术人员可以对本发明进行各种改动和变型而不脱离本发明的精神和范围。这样,倘若本发明的这些修改和变型属于本发明权利要求及其等同技术的范围之内,则本发明也意图包含这些改动和变型在内。

Claims (8)

1.一种无源域数据的无监督领域适应方法,其特征在于:所述方法包括以下步骤:
步骤1:以有标签的源域样本训练模型,得到预训练好的源域模型;
步骤2:以源域模型初始化目标域模型,包括特征提取器和分类器;
步骤3:以源域模型的BN层存储的统计信息近似源域的特征分布,与目标域样本的特征分布显式对齐,计算分布对齐损失LBN
步骤4:基于目标域模型的分类器的预测,对目标域样本的特征进行模糊聚类,以聚类隶属度作为目标域样本的软标签,计算软标签与模型分类器对目标域样本的预测之间的交叉熵损失Lclu
步骤5:对目标域样本计算信息最大化损失LIM,信息最大化损失包括最小化熵损失和最大化平均熵损失;
步骤6:以所述分布对齐损失LBN、交叉熵损失Lclu和信息最大化损失LIM共同训练目标域模型,实现无源域数据的无监督领域适应。
2.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述步骤1中,通过标签平滑后再计算交叉熵损失,目标函数为,
Figure FDA0003997270530000011
其中,fs表示预训练好的源域模型,包括特征提取器gs和分类器hs,满足给定输入x,fs(x)=hs(gs(x));K表示类别数目,k对应任一类别,Xs为源域样本集;给定qk为源域样本xs的标签,则
Figure FDA0003997270530000012
是对qk平滑后的标签,满足
Figure FDA0003997270530000013
α是平滑系数,
0<α<1;
σ(·)表示对某一给定向量的softmax归一化操作,假设给定向量a和温度参数T,用σk表示对某个向量σ(·)操作后得到的第k维的值,
Figure FDA0003997270530000021
ak表示向量a第k维的值,j指向量a第j维,式(1)中T为1。
3.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述目标域模型的分类器固定不变。
4.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述步骤3中,BN层的统计信息包括均值和方差;以源域模型的每个BN层中每个通道的均值、方差表示的高斯分布与目标域样本对应BN层的当前batch样本的每个通道的均值、方差表示的高斯分布,计算KL散度的平均值,作为衡量源域和目标域样本特征分布的距离。
5.根据权利要求4所述的一种无源域数据的无监督领域适应方法,其特征在于:所述分布对齐损失LBN为,
Figure FDA0003997270530000022
其中,M表示模型中BN层的总数,Cm表示第m个BN层的通道总数,
Figure FDA0003997270530000023
Figure FDA0003997270530000024
表示源域模型中第m个BN层第cm个通道存储的均值和方差,
Figure FDA0003997270530000025
Figure FDA0003997270530000026
表示当前batch经过目标域模型第m个BN层的第cm个通道的均值和方差;DKL为KL散度;
最小化损失函数LBN
6.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述步骤4包括以下步骤:
步骤4.1:以目标域模型分类器输出的概率为权重,对提取的特征进行加权平均,初始化簇中心,
Figure FDA0003997270530000031
其中,Pk表示第k个类的簇中心,ft表示目标模型,包括特征提取器gt和分类器ht,满足给定输入x,ft(x)=ht(gt(x));xt表示目标域样本,Bt表示当前读入的目标域样本的batch,σ(·)表示对某一给定向量的softmax归一化操作,上标T表示向量的转置;
步骤4.2:根据簇中心δk计算样本到每个簇中心的余弦距离,取倒数再softmax归一化得到样本的预测分布,同时加上温度参数T调节软标签的平滑程度,
Figure FDA0003997270530000032
其中,D表示余弦距离,
Figure FDA0003997270530000033
表示聚类得到的软标签在第k个类的概率或隶属度,此时温度参数满足0.6≤T≤1.2;
步骤4.3:用该软标签和模型分类器对目标域样本的输出概率分布计算交叉熵损失,
Figure FDA0003997270530000034
纠正由于源域分类器分类错误的目标域样本预测。
7.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述步骤5中,信息最大化损失满足,
LIM=Lent+Ldiv (6)其中,Lent为最小化熵损失,
Figure FDA0003997270530000041
Ldiv为最大化平均熵损失,
Figure FDA0003997270530000042
Figure FDA0003997270530000043
为第k个类的平均隶属度,
Figure FDA0003997270530000044
8.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述步骤6中,完整目标函数Lgt为,
Lgt=LIM+βLBN+γLclu (7)其中,β和γ为对应的超参数,β,γ∈[0.6,1.0]。
CN202211600631.8A 2022-12-13 2022-12-13 一种无源域数据的无监督领域适应方法 Pending CN116227578A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211600631.8A CN116227578A (zh) 2022-12-13 2022-12-13 一种无源域数据的无监督领域适应方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211600631.8A CN116227578A (zh) 2022-12-13 2022-12-13 一种无源域数据的无监督领域适应方法

Publications (1)

Publication Number Publication Date
CN116227578A true CN116227578A (zh) 2023-06-06

Family

ID=86588110

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211600631.8A Pending CN116227578A (zh) 2022-12-13 2022-12-13 一种无源域数据的无监督领域适应方法

Country Status (1)

Country Link
CN (1) CN116227578A (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116502644A (zh) * 2023-06-27 2023-07-28 浙江大学 一种基于无源领域自适应的商品实体匹配方法及装置
CN117152563A (zh) * 2023-10-16 2023-12-01 华南师范大学 混合目标域自适应模型的训练方法、装置及计算机设备
CN117892203A (zh) * 2024-03-14 2024-04-16 江南大学 一种缺陷齿轮分类方法、装置及计算机可读存储介质

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116502644A (zh) * 2023-06-27 2023-07-28 浙江大学 一种基于无源领域自适应的商品实体匹配方法及装置
CN116502644B (zh) * 2023-06-27 2023-09-22 浙江大学 一种基于无源领域自适应的商品实体匹配方法及装置
CN117152563A (zh) * 2023-10-16 2023-12-01 华南师范大学 混合目标域自适应模型的训练方法、装置及计算机设备
CN117152563B (zh) * 2023-10-16 2024-05-14 华南师范大学 混合目标域自适应模型的训练方法、装置及计算机设备
CN117892203A (zh) * 2024-03-14 2024-04-16 江南大学 一种缺陷齿轮分类方法、装置及计算机可读存储介质
CN117892203B (zh) * 2024-03-14 2024-06-07 江南大学 一种缺陷齿轮分类方法、装置及计算机可读存储介质

Similar Documents

Publication Publication Date Title
CN110321926B (zh) 一种基于深度残差修正网络的迁移方法及系统
CN116227578A (zh) 一种无源域数据的无监督领域适应方法
CN112446423B (zh) 一种基于迁移学习的快速混合高阶注意力域对抗网络的方法
CN111259979A (zh) 一种基于标签自适应策略的深度半监督图像聚类方法
CN113469186B (zh) 一种基于少量点标注的跨域迁移图像分割方法
CN114692741B (zh) 基于域不变特征的泛化人脸伪造检测方法
US20240054345A1 (en) Framework for Learning to Transfer Learn
CN113963165B (zh) 一种基于自监督学习的小样本图像分类方法及系统
CN110443273B (zh) 一种用于自然图像跨类识别的对抗零样本学习方法
CN116824216A (zh) 一种无源无监督域适应图像分类方法
CN114973350B (zh) 一种源域数据无关的跨域人脸表情识别方法
CN118052301A (zh) 一种迭代凝聚式簇估计联邦学习方法
CN114782742A (zh) 基于教师模型分类层权重的输出正则化方法
CN112364980B (zh) 一种弱监督场景下基于强化学习的深度神经网络训练方法
CN116958548B (zh) 基于类别统计驱动的伪标签自蒸馏语义分割方法
CN117671261A (zh) 面向遥感图像的无源域噪声感知域自适应分割方法
CN112836753A (zh) 用于域自适应学习的方法、装置、设备、介质和产品
CN116882480A (zh) 一种面向隐私保护的扩散模型驱动的无监督域泛化方法
CN116486150A (zh) 一种基于不确定性感知的图像分类模型回归误差消减方法
CN114792114B (zh) 一种基于黑盒多源域通用场景下的无监督域适应方法
CN116563602A (zh) 基于类别级软目标监督的细粒度图像分类模型训练方法
CN115578593A (zh) 一种使用残差注意力模块的域适应方法
CN112800959A (zh) 一种用于人脸识别中数据拟合估计的困难样本发掘方法
CN117456309B (zh) 基于中间域引导与度量学习约束的跨域目标识别方法
CN117593215B (zh) 一种生成模型增强的大规模视觉预训练方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination