CN113486987A - 基于特征解耦的多源域适应方法 - Google Patents

基于特征解耦的多源域适应方法 Download PDF

Info

Publication number
CN113486987A
CN113486987A CN202110890031.9A CN202110890031A CN113486987A CN 113486987 A CN113486987 A CN 113486987A CN 202110890031 A CN202110890031 A CN 202110890031A CN 113486987 A CN113486987 A CN 113486987A
Authority
CN
China
Prior art keywords
domain
source
feature
sharing
private
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110890031.9A
Other languages
English (en)
Inventor
徐行
张明
邵杰
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
University of Electronic Science and Technology of China
Original Assignee
University of Electronic Science and Technology of China
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by University of Electronic Science and Technology of China filed Critical University of Electronic Science and Technology of China
Priority to CN202110890031.9A priority Critical patent/CN113486987A/zh
Publication of CN113486987A publication Critical patent/CN113486987A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F17/00Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F17/10Complex mathematical operations
    • G06F17/16Matrix or vector computation, e.g. matrix-matrix or matrix-vector multiplication, matrix factorization
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • Health & Medical Sciences (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computational Mathematics (AREA)
  • Mathematical Analysis (AREA)
  • Mathematical Optimization (AREA)
  • Pure & Applied Mathematics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Algebra (AREA)
  • Databases & Information Systems (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明属于领域自适应领域,提出一种基于特征解耦的多源域适应方法,包括:分别提取各个域图像的域共享特征和域私有特征,获取提取损失函数并进行正交解耦,使得域私有特征和域共享特征各个维度的表示互相独立;搭建并训练分类器,通过第一分类损失函数对源域的域共享特征正确分类,通过第二分类损失函数对源域的域私有特征无法正确分类;搭建并训域判别器,通过域判别器交叉熵损失函数无法正确区分域共享特征来自哪个域;搭建解码器,通过解码器的回归损失函数对域共享特征和域私有特征的组合进行解码,使域共享特征和域私有特征耦合,重构回原始表示;计算所有损失函数的总损失函数;对组成的模型进行整体训练并对目标域图像进行分类。

Description

基于特征解耦的多源域适应方法
技术领域
本发明涉及领域自适应领域,尤其涉及一种基于特征解耦的多源域适应方法。
背景技术
域适应是迁移学习的一个分支,特征是源域和目标域的数据分布不同,任务相同。当训练数据集和测试数据集分布不一致的情况下,通过在训练数据集上按经验误差最小准则训练得到的模型在测试数据集上性能不佳。为了在拥有不同分布的数据集上有较好的表现,引入域适应。目标是在带标签的源数据集上训练一个神经网络,并确保在显著不同于源数据集的无标签的目标数据集上也有良好的准确性。
区别于一般的域适应问题,多源域适应涉及多于一个的源域,同时将多个源域的知识迁移到目标域中辅助目标域的学习。由于多源域数据不但和目标域数据分布不同,彼此之间也不同,这种场景就更具挑战性。
现有方法的关键解决思路在于确定源域和目标域之间共享的不变表示是什么以及如何找到这些表示,目的是将每个源的特征对齐到一个统一的空间来获得领域共享的特征。典型方法包括直接特征对齐方法,该方法通过最小化深度前馈网络内的特征分布之间的差异来提取域共享表示,以及通过欺骗域判别器来提取表征的对抗性学习方法。最近,为了在考虑语义信息的情况下提取领域共享表示,研究者们还提出了细粒度的语义对齐方法。然而,它们中的大多数都需要为目标域样本生成伪标签,以尽量减少同一标签内跨域的差异。由于伪标签存在不确定性,容易导致误差累积。同时,由于域共享语义特征和域私有的特征紧密耦合在一起,以上方法很难在保留域共享语义特征的同时排除每个域私有特征的噪声影响,从而很难为所有的域提取到同样的域共享特征,进而会影响在目标域上的准确度表现。
发明内容
本发明的目的是提供一种基于特征解耦的多源域适应方法,通过假设域共享特征和域私有特征之间存在独立性,显式地将这两种特征在潜在空间解耦,让域共享特征覆盖尽量多的类别语义信息,进行目标域上的分类任务。
本发明解决其技术问题,采用的技术方案是:
本发明提出一种基于特征解耦的多源域适应方法,包括如下步骤:
步骤1.分别提取各个域图像的域共享特征和域私有特征,获取提取损失函数,通过正交解耦,使得域私有特征和域共享特征各个维度的表示互相独立,所述域包括一个目标域和多个源域;
步骤2.搭建并训练分类器,使分类器通过第一分类损失函数对源域的域共享特征正确分类,并使分类器通过第二分类损失函数对源域的域私有特征无法正确分类;
步骤3.搭建并训域判别器,使域判别器通过域判别器交叉熵损失函数无法正确区分域共享特征来自源域还是目标域;
步骤4.搭建解码器,通过解码器的回归损失函数对域共享特征和域私有特征的组合进行解码,使域共享特征和域私有特征耦合,重构回原始表示;
步骤5.计算所有损失函数的总损失函数;
步骤6.对步骤1-5组成的模型进行整体训练;
步骤7.利用所述进行整体训练好后的模型对不带标签的需要分类的目标域图像进行分类。
进一步的是,步骤1具体包括如下步骤:
步骤101.获取包含多个域的图像数据,将其中一个域作为目标域
Figure DEST_PATH_IMAGE001
,目标域样本记为
Figure 619287DEST_PATH_IMAGE002
,其余m个域作为源域
Figure DEST_PATH_IMAGE003
Figure 631849DEST_PATH_IMAGE004
,源域的样本记为
Figure DEST_PATH_IMAGE005
步骤102.所有域用一个公共的域共享编码器提取域共享特征
Figure 853883DEST_PATH_IMAGE006
,其中
Figure DEST_PATH_IMAGE007
,各个域用各自的一个域私有编码器提取域私有特征
Figure 831DEST_PATH_IMAGE008
,其中
Figure 28830DEST_PATH_IMAGE007
,图像特征经过特征提取后均以向量的形式存在;
步骤103.构建矩阵
Figure DEST_PATH_IMAGE009
Figure 397363DEST_PATH_IMAGE010
,矩阵
Figure 852615DEST_PATH_IMAGE009
Figure 752438DEST_PATH_IMAGE010
的行向量分别为源域的域共享特征
Figure DEST_PATH_IMAGE011
和目标域的域共享特征
Figure 256232DEST_PATH_IMAGE012
步骤104.构建矩阵
Figure DEST_PATH_IMAGE013
Figure 292321DEST_PATH_IMAGE014
Figure 918474DEST_PATH_IMAGE013
Figure 728430DEST_PATH_IMAGE014
的行向量分别为源域的域私有特征
Figure DEST_PATH_IMAGE015
和目标域的域私有特征
Figure 98231DEST_PATH_IMAGE016
步骤105.将
Figure 254406DEST_PATH_IMAGE009
矩阵的转置矩阵与
Figure 989144DEST_PATH_IMAGE013
矩阵相乘,
Figure 597980DEST_PATH_IMAGE010
矩阵的转置矩阵与
Figure 505893DEST_PATH_IMAGE014
矩阵相乘,两个乘积相加,获取提取损失函数,表示为:
Figure DEST_PATH_IMAGE017
其中,
Figure 703525DEST_PATH_IMAGE018
是Frobenius范数的平方;
步骤106.最小化提取损失函数,即促使域私有特征向量和域共享特征向量正交,以使得域私有特征和域共享特征向量的每一个维度都彼此独立,完成域私有特征和域共享特征的解耦。
进一步的是,步骤2中,搭建分类器C,并通过源域中的图像数据训练,所述分类器包括多个人工神经网络的全连接层,全连接层后接softmax激活函数,所述源域中的图像数据含有类别标签。
进一步的是,步骤2中,搭建好分类器后,训练分类器,使分类器通过第一分类损失函数对源域的域共享特征正确分类,具体是指:
输入源域的域共享特征
Figure DEST_PATH_IMAGE019
,通过最小化分类器输出
Figure 671481DEST_PATH_IMAGE020
和源域特征的类别标签
Figure DEST_PATH_IMAGE021
的交叉熵损失,使分类器获得对域共享特征的多分类能力,能够对样本的域共享特征进行正确分类;
其中,第一分类损失函数定义为一个batch内源域样本分类交叉熵损失的均值,表示为:
Figure 705296DEST_PATH_IMAGE022
其中,
Figure DEST_PATH_IMAGE023
为一个batch输入源域样本总的数量,
Figure 416900DEST_PATH_IMAGE024
为单个源域样本
Figure DEST_PATH_IMAGE025
的one-hot形式的类别标签,
Figure 547667DEST_PATH_IMAGE026
为经过softmax激活函数之后的分类器的对域共享特征类别预测结果,具体定义为:
Figure 372010DEST_PATH_IMAGE028
其中,
Figure DEST_PATH_IMAGE029
表示softmax激活函数。
进一步的是,步骤2中,搭建好分类器后,训练分类器,使分类器通过第二分类损失函数对源域的域私有特征无法正确分类,具体是指:
将源域私有特征输入经过源域共享特征预训练好的分类器,最小化分类器输出
Figure 689859DEST_PATH_IMAGE030
和源域特征的类别标签
Figure 205154DEST_PATH_IMAGE021
的交叉熵损失,通过梯度反转层,让梯度反向传播到域私有编码器的参数之前自动取反,使得域私有特征提取器朝着最大化交叉熵损失的方向优化;
其中,第二分类损失函数的定义如下:
Figure DEST_PATH_IMAGE031
其中,
Figure 128111DEST_PATH_IMAGE032
为经过softmax激活函数之后的分类器的对域私有特征的类别预测结果,
Figure 437869DEST_PATH_IMAGE032
表示为:
Figure 243014DEST_PATH_IMAGE034
其中,R代表梯度反转层,其数学表达为:
Figure 483371DEST_PATH_IMAGE036
Figure DEST_PATH_IMAGE037
其中,
Figure 323152DEST_PATH_IMAGE038
代表单位矩阵,在梯度反转层中,参数
Figure DEST_PATH_IMAGE039
是动态变化的,其变化表达式如下式所示:
Figure 741495DEST_PATH_IMAGE040
其中,
Figure DEST_PATH_IMAGE041
代表训练进程相对值,即当前迭代次数与总迭代次数的比值,
Figure 299515DEST_PATH_IMAGE042
为常数10。
进一步的是,步骤3中,搭建一个域判别器D,通过梯度反转层,梯度反向传播到域共享特征提取器的参数之前自动取反,使得域共享编码器朝着最大化域判别器交叉熵损失的方向优化,即让提取得到的源域和目标域共享特征在特征空间中混淆。
进一步的是,步骤3中,所述域判别器为一个二分类器,输出为0或1,表示输入样本来自源域或者目标域,域判别器交叉熵损失函数定义如下:
Figure DEST_PATH_IMAGE043
其中,
Figure 579449DEST_PATH_IMAGE044
为一个batch输入源域样本和目标域样本总的数量,
Figure DEST_PATH_IMAGE045
为单个样本
Figure 539314DEST_PATH_IMAGE025
来自哪个域的标签,0代表来自源域,1代表来自目标域,
Figure 190876DEST_PATH_IMAGE046
为经过softmax激活函数之后的域判别器的对域共享特征来自哪个域的预测结果,具体定义为:
Figure 908296DEST_PATH_IMAGE048
其中,
Figure 303505DEST_PATH_IMAGE029
表示softmax激活函数,R代表梯度反转层。
进一步的是,步骤4中,所述回归损失函数定义为:
Figure DEST_PATH_IMAGE049
其中,De为解码器。
进一步的是,步骤5具体为:
通过将所有损失函数分别乘上一个权重系数,加起来作为总损失函数,并利用随机梯度下降法对总损失函数进行优化,表示为:
Figure DEST_PATH_IMAGE051
其中,
Figure 117877DEST_PATH_IMAGE052
为各损失函数的权重系数值。
本发明的有益效果是,通过上述基于特征解耦的多源域适应方法,为每个域构建两个特征提取器,显式地解耦得到域共享特征和域私有特征,通过对抗学习训练类别分类器对域共享特征正确分类,对域私有特征无法正确分类,从而让域共享特征覆盖尽量多的类别语义信息,域私有特征则包含较少类别信息,训练域判别器无法正确区分共享特征来自哪个域以让各个域的共享特征在特征空间中尽量靠近,重构损失的约束使得重新耦合的特征仍能正确表征原始数据,尽量减少解耦过程中语义信息的损失。
附图说明
图1为本发明实施例基于特征解耦的多源域适应方法的流程图;
图2为本发明实施例中基于特征解耦的多源域适应方法的整体流程图;
图3为本发明实施例中特征解耦示意图;
图4为本发明实施例中类别分类器对抗学习示意图;
图5为本发明实施例中域判别器对抗学习示意图;
图6为本发明实施例中重构示意图。
具体实施方式
下面结合附图及实施例,详细描述本发明的技术方案。
实施例
本发明提出的一种基于特征解耦的多源域适应方法,其流程图见图1,其中,该方法包括如下步骤:
S1.分别提取各个域图像的域共享特征和域私有特征,获取提取损失函数,通过正交解耦,使得域私有特征和域共享特征各个维度的表示互相独立,所述域包括一个目标域和多个源域;
S2.搭建并训练分类器,使分类器通过第一分类损失函数对源域的域共享特征正确分类,并使分类器通过第二分类损失函数对源域的域私有特征无法正确分类;
S3.搭建并训域判别器,使域判别器通过域判别器交叉熵损失函数无法正确区分域共享特征来自源域还是目标域;
S4.搭建解码器,通过解码器的回归损失函数对域共享特征和域私有特征的组合进行解码,使域共享特征和域私有特征耦合,重构回原始表示;
S5.计算所有损失函数的总损失函数;
S6.对步骤1-5组成的模型进行整体训练;
S7.利用所述进行整体训练好后的模型对不带标签的需要分类的目标域图像进行分类。
通过本实施例提出的基于特征解耦的多源域适应方法,可以通过假设域共享特征和域私有特征之间存在独立性,显式地将这两种特征在潜在空间解耦,让域共享特征覆盖尽量多的类别语义信息,进行目标域上的分类任务。
为了进一步说明本方法,需要指出的是,本方法具体可以包括以下步骤:
步骤S1:提取各个域图像特征并解耦
这里,图像来自不同的域,如网页、照片、简笔画等,所有域的图像都分为K个类别。将其中一个域作为目标域
Figure 127291DEST_PATH_IMAGE001
,目标域样本记为
Figure 394324DEST_PATH_IMAGE002
,剩余的各个域当做源域
Figure 593224DEST_PATH_IMAGE003
Figure 262103DEST_PATH_IMAGE004
,源域的样本记为
Figure 193150DEST_PATH_IMAGE005
。利用卷积神经网络作为编码器
Figure DEST_PATH_IMAGE053
提取图像特征得到特征向量。如图2所示,所有域用一个公共的特征提取器
Figure 947479DEST_PATH_IMAGE054
来提取域共享特征
Figure 950070DEST_PATH_IMAGE006
Figure 158941DEST_PATH_IMAGE007
。同时各个域用各自的一个特征提取器提取域私有特征
Figure 323206DEST_PATH_IMAGE008
Figure 564832DEST_PATH_IMAGE007
。这些数据经过特征提取之后都以向量的形式存在。然后利用正交约束来促使域私有特征和域共享特征各个维度的表示互相独立。
具体而言,可以构建矩阵
Figure 105535DEST_PATH_IMAGE009
Figure 421109DEST_PATH_IMAGE010
,它们的行向量分别为源域的域共享特征
Figure DEST_PATH_IMAGE055
和目标域的域共享特征
Figure 21855DEST_PATH_IMAGE056
。同样地,构建矩阵
Figure 485197DEST_PATH_IMAGE013
Figure 16542DEST_PATH_IMAGE014
,它们的行向量分别为源域的域私有特征
Figure DEST_PATH_IMAGE057
和目标域的域私有特征
Figure 248940DEST_PATH_IMAGE058
。将
Figure 20587DEST_PATH_IMAGE009
矩阵的转置矩阵与
Figure 908908DEST_PATH_IMAGE013
矩阵相乘,
Figure 791414DEST_PATH_IMAGE010
矩阵的转置矩阵与
Figure 143898DEST_PATH_IMAGE014
矩阵相乘,两个乘积相加,提取损失函数的函数形式为:
Figure 774861DEST_PATH_IMAGE017
(1)
最小化此函数,即促使其正交,以使得向量的每一个维度都彼此独立。这样,提取的特征向量每一个维度就是对样本不同角度的表征,其中,部分维度是对域之间不变的语义特征的表征,部分维度是对各个域独有特征的表征,这样就完成了对特征的解耦。
步骤S2:训练分类器对源域共享特征正确分类
域适应的目标是对目标域的图像进行正确分类,因此首先需要训练一个正确的图像分类器。分类器由几个人工神经网络的全连接层组成,全连接层后接softmax激活函数。由于只有源域图像有类别标签,可以为分类器的训练提供指导,所以使用源域数据来训练分类器。如图3所示,本实施例输入源域的域共享特征
Figure 212796DEST_PATH_IMAGE019
,通过最小化分类器输出
Figure 898992DEST_PATH_IMAGE020
和源域特征的类别标签
Figure 105983DEST_PATH_IMAGE021
的交叉熵损失,使分类器获得对域共享特征的多分类能力,能够对样本的域共享特征进行正确分类。第一分类损失函数定义为一个batch内源域样本分类交叉熵损失的均值:
Figure DEST_PATH_IMAGE059
(2)
其中,
Figure 157115DEST_PATH_IMAGE023
为一个batch输入源域样本总的数量,
Figure 82346DEST_PATH_IMAGE024
为单个源域样本
Figure 306654DEST_PATH_IMAGE025
的one-hot形式的类别标签。
Figure 820681DEST_PATH_IMAGE026
为经过softmax激活函数之后的分类器的对域共享特征类别预测结果,具体定义为:
Figure 839452DEST_PATH_IMAGE028
Figure 251979DEST_PATH_IMAGE029
表示softmax激活函数。通过对域共享特征的正确分类,促使类别信息包含进域共享特征之中,即让域共享特征提取器获得在不同源域上提取类别语义信息的能力。
步骤S3:训练模型,让分类器对域私有特征无法进行正确分类
如图3所示,将源域私有特征输入经过源域共享特征预训练好的分类器,由于训练目的是让类别语义信息大多包含在域共享特征中,而域私有特征中包含尽量少的类别信息,因此分类器对域私有特征无法进行正确分类,应该最大化分类器输出
Figure 279978DEST_PATH_IMAGE030
和源域特征的类别标签
Figure 399244DEST_PATH_IMAGE021
的交叉熵损失,这样就和步骤S2中优化目标相反,构成了对抗训练。为了避免固定参数分阶段训练带来的麻烦和风险,引入梯度反转层(Gradient ReversalLayer)。梯度反转在反向传播过程中梯度方向自动取反,在前向传播过程中实现恒等变换。相关数学表示如下式所示:
Figure 854496DEST_PATH_IMAGE060
Figure 488740DEST_PATH_IMAGE037
其中,
Figure 54850DEST_PATH_IMAGE038
代表单位矩阵。
通过梯度反转层,让梯度反向传播到域私有特征提取器的参数之前自动取反,使得域私有特征提取器朝着最大化交叉熵损失的方向优化,即让分类器对域私有特征的分类结果尽量偏离标签值,让域私有特征尽量少包含类别信息。由于经过公式(1)使得域共享特征和域私有特征得到解耦,这样间接式地让类别信息尽可能多地包含在域共享特征中。第二分类损失函数的定义如下:
Figure DEST_PATH_IMAGE061
(3)
形式和(2)中一致,
Figure 42004DEST_PATH_IMAGE032
为经过softmax激活函数之后的分类器的对域私有特征的类别预测结果,定义与(2)中区别在于输入为域私有特征,且添加了梯度反转层:
Figure 668158DEST_PATH_IMAGE034
其中,R代表梯度反转层。在梯度反转层中,参数
Figure 789697DEST_PATH_IMAGE039
并不是固定值,而是动态变化的。其变化表达式如下式所示:
Figure 97182DEST_PATH_IMAGE040
其中,
Figure 253357DEST_PATH_IMAGE041
代表训练进程相对值,即当前迭代次数与总迭代次数的比值,
Figure 50412DEST_PATH_IMAGE042
为常数10。
步骤S4:训练模型,让域判别器无法正确区分共享特征来自哪个域
如图4所示,搭建一个域判别器
Figure 659247DEST_PATH_IMAGE062
,用来判断输入的共享特征是来自源域还是目标域。域判别器实质是一个二分类器,输出为0或1,表示输入样本来自源域或者目标域。由人工神经网络的几个全连接层组成,全连接层后接激活函数。损失函数为各个样本的域标签与域判别器的判断之间的交叉熵。同样通过梯度反转层,让梯度反向传播到域共享特征提取器的参数之前自动取反,使得域共享特征提取器朝着最大化域判别器交叉熵损失的方向优化。域判别器交叉熵损失函数定义如下:
Figure DEST_PATH_IMAGE063
(4)
其中,
Figure 754111DEST_PATH_IMAGE044
为一个batch输入源域样本和目标域样本总的数量,
Figure 764793DEST_PATH_IMAGE064
为单个样本
Figure 732749DEST_PATH_IMAGE025
来自哪个域的标签,0代表来自源域,1代表来自目标域。
Figure DEST_PATH_IMAGE065
为经过softmax激活函数之后的域判别器的对域共享特征来自哪个域的预测结果,具体定义为:
Figure DEST_PATH_IMAGE067
其中,
Figure 766564DEST_PATH_IMAGE029
表示softmax激活函数,R代表梯度反转层。
域共享特征提取器从各个不同的域中提取的特征让域分类器无法正确区分特征来自源域还是目标域,即让提取得到的源域和目标域共享特征在特征空间中混淆。由于公式(2)中的训练已经让域共享特征提取器获得了可以用来正确分类的源域共享特征提取能力,当前步骤的混淆让域共享特征提取器同样可以提取目标域中的含有类别信息的域共享特征。
步骤S5:特征耦合重构回原始表示
如图5所示,为了保证解耦之后的域共享特征和域私有特征没有损失样本特征的信息,搭建解码器
Figure 478168DEST_PATH_IMAGE068
,输入为域共享特征和域私有特征的组合,通过
Figure 31771DEST_PATH_IMAGE068
的解码,重构回原始特征。通过最小化重构特征与原始特征的回归损失函数施加约束,优化模型,让特征解耦过程中减少样本原始语义信息的损失。解码器的回归损失函数定义为:
Figure DEST_PATH_IMAGE069
(5)
其中,De为
Figure 170628DEST_PATH_IMAGE068
。即通过缩小解耦前的特征与重新耦合得到的特征之间的差异,促使网络在解耦过程中尽量少损失原始信息,以免由于类别语义信息的减少给分类器的分类带来困难。
步骤S6:计算总损失函数
将步骤S1~S5中所有损失函数分别乘上一个权重系数,加起来作为总损失函数,再利用随机梯度下降法对总损失函数进行优化,对整个网络进行训练:
Figure 754057DEST_PATH_IMAGE051
其中,
Figure 207035DEST_PATH_IMAGE052
为各损失函数的权重系数值。这个损失函数值反映得到结果和实际结果之间的差异,用来衡量模型好坏,损失函数值越小,模型的效果越好.
步骤S7:模型整体训练
重复步骤S1~S6,处理完所有训练数据,并在验证数据集上进行测试。重复训练、验证,在设定的轮次里将验证结果最好的一组训练参数保存下来作为最终的模型参数。
步骤S8:对目标域图像进行分类
将训练完成的模型参数加载到模型中,输入不带标签的需要分类的目标域图像,执行测试程序,分类器的输出即为分类结果,模型完成多分类任务。

Claims (10)

1.基于特征解耦的多源域适应方法,其特征在于,包括如下步骤:
步骤1.分别提取各个域图像的域共享特征和域私有特征,获取提取损失函数,通过正交解耦,使得域私有特征和域共享特征各个维度的表示互相独立,所述域包括一个目标域和多个源域;
步骤2.搭建并训练分类器,使分类器通过第一分类损失函数对源域的域共享特征正确分类,并使分类器通过第二分类损失函数对源域的域私有特征无法正确分类;
步骤3.搭建并训域判别器,使域判别器通过域判别器交叉熵损失函数无法正确区分域共享特征来自源域还是目标域;
步骤4.搭建解码器,通过解码器的回归损失函数对域共享特征和域私有特征的组合进行解码,使域共享特征和域私有特征耦合,重构回原始表示;
步骤5.计算所有损失函数的总损失函数;
步骤6.对步骤1-5组成的模型进行整体训练;
步骤7.利用所述进行整体训练好后的模型对不带标签的需要分类的目标域图像进行分类。
2.根据权利要求1所述的基于特征解耦的多源域适应方法,其特征在于,步骤1具体包括如下步骤:
步骤101.获取包含多个域的图像数据,将其中一个域作为目标域
Figure 709708DEST_PATH_IMAGE001
,目标域样本记为
Figure 260382DEST_PATH_IMAGE002
,其余m个域作为源域
Figure 399239DEST_PATH_IMAGE003
Figure 982668DEST_PATH_IMAGE004
,源域的样本记为
Figure 170066DEST_PATH_IMAGE005
步骤102.所有域用一个公共的域共享编码器提取域共享特征
Figure 155340DEST_PATH_IMAGE006
,其中
Figure 465099DEST_PATH_IMAGE007
,各个域用各自的一个域私有编码器提取域私有特征
Figure 722773DEST_PATH_IMAGE008
,其中
Figure 776180DEST_PATH_IMAGE007
,图像特征经过特征提取后均以向量的形式存在;
步骤103.构建矩阵
Figure 615960DEST_PATH_IMAGE009
Figure 96620DEST_PATH_IMAGE010
,矩阵
Figure 326744DEST_PATH_IMAGE009
Figure 183842DEST_PATH_IMAGE010
的行向量分别为源域的域共享特征
Figure 878128DEST_PATH_IMAGE011
和目标域的域共享特征
Figure 218105DEST_PATH_IMAGE012
步骤104.构建矩阵
Figure 263421DEST_PATH_IMAGE013
Figure 658631DEST_PATH_IMAGE014
Figure 207424DEST_PATH_IMAGE013
Figure 233148DEST_PATH_IMAGE014
的行向量分别为源域的域私有特征
Figure 500182DEST_PATH_IMAGE015
和目标域的域私有特征
Figure 699082DEST_PATH_IMAGE016
步骤105.将
Figure 289332DEST_PATH_IMAGE009
矩阵的转置矩阵与
Figure 548275DEST_PATH_IMAGE013
矩阵相乘,
Figure 302604DEST_PATH_IMAGE010
矩阵的转置矩阵与
Figure 39616DEST_PATH_IMAGE014
矩阵相乘,两个乘积相加,获取提取损失函数,表示为:
Figure 500685DEST_PATH_IMAGE017
其中,
Figure 930529DEST_PATH_IMAGE018
是Frobenius范数的平方;
步骤106.最小化提取损失函数,即促使域私有特征向量和域共享特征向量正交,以使得域私有特征和域共享特征向量的每一个维度都彼此独立,完成域私有特征和域共享特征的解耦。
3.根据权利要求1所述的基于特征解耦的多源域适应方法,其特征在于,步骤2中,搭建分类器C,并通过源域中的图像数据训练,所述分类器包括多个人工神经网络的全连接层,全连接层后接softmax激活函数,所述源域中的图像数据含有类别标签。
4.根据权利要求3所述的基于特征解耦的多源域适应方法,其特征在于,步骤2中,搭建好分类器后,训练分类器,使分类器通过第一分类损失函数对源域的域共享特征正确分类,具体是指:
输入源域的域共享特征
Figure 906575DEST_PATH_IMAGE019
,通过最小化分类器输出
Figure 447278DEST_PATH_IMAGE020
和源域特征的类别标签
Figure 776235DEST_PATH_IMAGE021
的交叉熵损失,使分类器获得对域共享特征的多分类能力,能够对样本的域共享特征进行正确分类;
其中,第一分类损失函数定义为一个batch内源域样本分类交叉熵损失的均值,表示为:
Figure 111401DEST_PATH_IMAGE022
其中,
Figure 840323DEST_PATH_IMAGE023
为一个batch输入源域样本总的数量,
Figure 856820DEST_PATH_IMAGE024
为单个源域样本
Figure 354798DEST_PATH_IMAGE025
的one-hot形式的类别标签,
Figure 860865DEST_PATH_IMAGE026
为经过softmax激活函数之后的分类器的对域共享特征类别预测结果,具体定义为:
Figure DEST_PATH_IMAGE027
其中,
Figure 264034DEST_PATH_IMAGE028
表示softmax激活函数。
5.根据权利要求4所述的基于特征解耦的多源域适应方法,其特征在于,步骤2中,搭建好分类器后,训练分类器,使分类器通过第二分类损失函数对源域的域私有特征无法正确分类,具体是指:
将源域私有特征输入经过源域共享特征预训练好的分类器,最小化分类器输出
Figure 146539DEST_PATH_IMAGE029
和源域特征的类别标签
Figure 499023DEST_PATH_IMAGE021
的交叉熵损失,通过梯度反转层,让梯度反向传播到域私有编码器的参数之前自动取反,使得域私有特征提取器朝着最大化交叉熵损失的方向优化;
其中,第二分类损失函数的定义如下:
Figure 441571DEST_PATH_IMAGE030
其中,
Figure 817189DEST_PATH_IMAGE031
为经过softmax激活函数之后的分类器的对域私有特征的类别预测结果,
Figure 237806DEST_PATH_IMAGE031
表示为:
Figure 710376DEST_PATH_IMAGE032
其中,R代表梯度反转层,其数学表达为:
Figure DEST_PATH_IMAGE033
Figure 246662DEST_PATH_IMAGE034
其中,
Figure DEST_PATH_IMAGE035
代表单位矩阵,在梯度反转层中,参数
Figure 171892DEST_PATH_IMAGE036
是动态变化的,其变化表达式如下式所示:
Figure 396200DEST_PATH_IMAGE037
其中,
Figure 660959DEST_PATH_IMAGE038
代表训练进程相对值,即当前迭代次数与总迭代次数的比值,
Figure 945310DEST_PATH_IMAGE039
为常数10。
6.根据权利要求1所述的基于特征解耦的多源域适应方法,其特征在于,步骤3中,搭建一个域判别器D,通过梯度反转层,梯度反向传播到域共享特征提取器的参数之前自动取反,使得域共享编码器朝着最大化域判别器交叉熵损失的方向优化,即让提取得到的源域和目标域共享特征在特征空间中混淆。
7.根据权利要求6所述的基于特征解耦的多源域适应方法,其特征在于,步骤3中,所述域判别器为一个二分类器,输出为0或1,表示输入样本来自源域或者目标域,域判别器交叉熵损失函数定义如下:
Figure 357837DEST_PATH_IMAGE040
其中,
Figure 120257DEST_PATH_IMAGE041
为一个batch输入源域样本和目标域样本总的数量,
Figure 488790DEST_PATH_IMAGE042
为单个样本
Figure 944042DEST_PATH_IMAGE025
来自哪个域的标签,0代表来自源域,1代表来自目标域,
Figure 843865DEST_PATH_IMAGE043
为经过softmax激活函数之后的域判别器的对域共享特征来自哪个域的预测结果,具体定义为:
Figure 409976DEST_PATH_IMAGE044
其中,
Figure 446065DEST_PATH_IMAGE028
表示softmax激活函数,R代表梯度反转层。
8.根据权利要求1所述的基于特征解耦的多源域适应方法,其特征在于,步骤4中,所述回归损失函数定义为:
Figure 9901DEST_PATH_IMAGE045
其中,De为解码器。
9.根据权利要求1-8任意一项所述的基于特征解耦的多源域适应方法,其特征在于,步骤5具体为:
通过将所有损失函数分别乘上一个权重系数,加起来作为总损失函数,并利用随机梯度下降法对总损失函数进行优化,表示为:
Figure 397020DEST_PATH_IMAGE046
其中,
Figure DEST_PATH_IMAGE047
为各损失函数的权重系数值。
10.根据权利要求9所述的基于特征解耦的多源域适应方法,其特征在于,步骤6-7具体是指:
重复步骤1-5,处理完所有训练数据,并在验证数据集上进行测试,重复训练及验证,在设定的轮次里将验证结果最好的一组训练参数保存下来作为最终的模型参数;
将所述最终的模型参数加载到模型中,输入不带标签的需要分类的目标域图像,执行测试程序,分类器的输出即为分类结果,模型完成多分类任务。
CN202110890031.9A 2021-08-04 2021-08-04 基于特征解耦的多源域适应方法 Pending CN113486987A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110890031.9A CN113486987A (zh) 2021-08-04 2021-08-04 基于特征解耦的多源域适应方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110890031.9A CN113486987A (zh) 2021-08-04 2021-08-04 基于特征解耦的多源域适应方法

Publications (1)

Publication Number Publication Date
CN113486987A true CN113486987A (zh) 2021-10-08

Family

ID=77945612

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110890031.9A Pending CN113486987A (zh) 2021-08-04 2021-08-04 基于特征解耦的多源域适应方法

Country Status (1)

Country Link
CN (1) CN113486987A (zh)

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114627196A (zh) * 2022-01-06 2022-06-14 福州大学 基于变分自动编码器的潜变量空间解耦方法
CN114694150A (zh) * 2022-05-31 2022-07-01 成都考拉悠然科技有限公司 一种提升数字图像分类模型泛化能力的方法及系统
CN114693972A (zh) * 2022-03-29 2022-07-01 电子科技大学 一种基于重建的中间域领域自适应方法
CN115050032A (zh) * 2022-05-02 2022-09-13 清华大学 一种基于特征对齐和熵正则化的域适应文本图像识别方法
CN115357710A (zh) * 2022-08-18 2022-11-18 百度在线网络技术(北京)有限公司 表格描述文本生成模型的训练方法、装置及电子设备
CN116912593A (zh) * 2023-07-31 2023-10-20 大连理工大学 域对抗的遥感图像目标分类方法
CN117664567A (zh) * 2024-01-30 2024-03-08 东北大学 一种面向多源域不平衡数据的滚动轴承跨域故障诊断方法

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111371830A (zh) * 2019-11-26 2020-07-03 航天科工网络信息发展有限公司 一种万网融合场景下基于数据驱动的智能协同云架构
US10839269B1 (en) * 2020-03-20 2020-11-17 King Abdulaziz University System for fast and accurate visual domain adaptation
CN112633071A (zh) * 2020-11-30 2021-04-09 之江实验室 基于数据风格解耦内容迁移的行人重识别数据域适应方法

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111371830A (zh) * 2019-11-26 2020-07-03 航天科工网络信息发展有限公司 一种万网融合场景下基于数据驱动的智能协同云架构
US10839269B1 (en) * 2020-03-20 2020-11-17 King Abdulaziz University System for fast and accurate visual domain adaptation
CN112633071A (zh) * 2020-11-30 2021-04-09 之江实验室 基于数据风格解耦内容迁移的行人重识别数据域适应方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
XIAOYI ZHANG 等: "Scene graph generation via multi-relation classification and cross-modal attention coordinator" *
盛一堃: "基于深度学习的迁移学习方法研究与应用" *

Cited By (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114627196A (zh) * 2022-01-06 2022-06-14 福州大学 基于变分自动编码器的潜变量空间解耦方法
CN114693972A (zh) * 2022-03-29 2022-07-01 电子科技大学 一种基于重建的中间域领域自适应方法
CN114693972B (zh) * 2022-03-29 2023-08-29 电子科技大学 一种基于重建的中间域领域自适应方法
CN115050032A (zh) * 2022-05-02 2022-09-13 清华大学 一种基于特征对齐和熵正则化的域适应文本图像识别方法
CN114694150A (zh) * 2022-05-31 2022-07-01 成都考拉悠然科技有限公司 一种提升数字图像分类模型泛化能力的方法及系统
CN114694150B (zh) * 2022-05-31 2022-10-21 成都考拉悠然科技有限公司 一种提升数字图像分类模型泛化能力的方法及系统
CN115357710A (zh) * 2022-08-18 2022-11-18 百度在线网络技术(北京)有限公司 表格描述文本生成模型的训练方法、装置及电子设备
CN116912593A (zh) * 2023-07-31 2023-10-20 大连理工大学 域对抗的遥感图像目标分类方法
CN116912593B (zh) * 2023-07-31 2024-01-23 大连理工大学 域对抗的遥感图像目标分类方法
CN117664567A (zh) * 2024-01-30 2024-03-08 东北大学 一种面向多源域不平衡数据的滚动轴承跨域故障诊断方法
CN117664567B (zh) * 2024-01-30 2024-04-02 东北大学 一种面向多源域不平衡数据的滚动轴承跨域故障诊断方法

Similar Documents

Publication Publication Date Title
CN113486987A (zh) 基于特征解耦的多源域适应方法
CN109214452B (zh) 基于注意深度双向循环神经网络的hrrp目标识别方法
CN109086658B (zh) 一种基于生成对抗网络的传感器数据生成方法与系统
Audebert et al. Generative adversarial networks for realistic synthesis of hyperspectral samples
CN112784913B (zh) 一种基于图神经网络融合多视图信息的miRNA-疾病关联预测方法及装置
CN113657561B (zh) 一种基于多任务解耦学习的半监督夜间图像分类方法
CN112465120A (zh) 一种基于进化方法的快速注意力神经网络架构搜索方法
CN112765352A (zh) 基于具有自注意力机制的图卷积神经网络文本分类方法
Chen Model reprogramming: Resource-efficient cross-domain machine learning
CN111127146A (zh) 基于卷积神经网络与降噪自编码器的信息推荐方法及系统
CN114998602B (zh) 基于低置信度样本对比损失的域适应学习方法及系统
CN113822125B (zh) 唇语识别模型的处理方法、装置、计算机设备和存储介质
CN113806494A (zh) 一种基于预训练语言模型的命名实体识别方法
Lopes et al. Efficient guided evolution for neural architecture search
Batson et al. Molecular cross-validation for single-cell RNA-seq
CN111612133B (zh) 基于人脸图像多阶段关系学习的内脏器官特征编码方法
He et al. Exploring the gap between collapsed & whitened features in self-supervised learning
Caucheteux et al. Long-range and hierarchical language predictions in brains and algorithms
Lin et al. PS-mixer: A polar-vector and strength-vector mixer model for multimodal sentiment analysis
CN112786160A (zh) 基于图神经网络的多图片输入的多标签胃镜图片分类方法
CN112241741A (zh) 基于分类对抗网的自适应图像属性编辑模型和编辑方法
CN116206327A (zh) 一种基于在线知识蒸馏的图像分类方法
CN113469338B (zh) 模型训练方法、模型训练装置、终端设备及存储介质
Hu et al. Ranknas: Efficient neural architecture search by pairwise ranking
CN115032602A (zh) 一种基于多尺度卷积胶囊网络的雷达目标识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
WD01 Invention patent application deemed withdrawn after publication

Application publication date: 20211008

WD01 Invention patent application deemed withdrawn after publication