CN114693972B - 一种基于重建的中间域领域自适应方法 - Google Patents

一种基于重建的中间域领域自适应方法 Download PDF

Info

Publication number
CN114693972B
CN114693972B CN202210324083.4A CN202210324083A CN114693972B CN 114693972 B CN114693972 B CN 114693972B CN 202210324083 A CN202210324083 A CN 202210324083A CN 114693972 B CN114693972 B CN 114693972B
Authority
CN
China
Prior art keywords
domain
module
feature
source
target
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210324083.4A
Other languages
English (en)
Other versions
CN114693972A (zh
Inventor
张可
吴圣娜
刘杰彦
黄乐天
贾宇明
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
University of Electronic Science and Technology of China
Yangtze River Delta Research Institute of UESTC Huzhou
Original Assignee
University of Electronic Science and Technology of China
Yangtze River Delta Research Institute of UESTC Huzhou
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by University of Electronic Science and Technology of China, Yangtze River Delta Research Institute of UESTC Huzhou filed Critical University of Electronic Science and Technology of China
Priority to CN202210324083.4A priority Critical patent/CN114693972B/zh
Publication of CN114693972A publication Critical patent/CN114693972A/zh
Application granted granted Critical
Publication of CN114693972B publication Critical patent/CN114693972B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)
  • Image Processing (AREA)

Abstract

本发明公开了一种基于重建的中间域领域自适应方法,属于计算机视觉、智能频谱数据分析等领域自适应技术领域,具体涉及一种基于重建的中间域领域自适应方法。本发明针对现有领域自适应方法领域特征对齐困难等不足之处,提出一种基于重建的中间域领域自适应方法,并且能够实现更好的分类性能。本发明使用重建的方法对源域数据和目标域数据的特征进行提取,这样提取到特征将包含更多的数据信息,具有更强的可辨别性。同时,针对实际场景中两域之间直接对域差异最小化实现困难的问题,本发明通过在中间域对两域特征进行对齐,从而达到减轻特征对齐难度的目的,最终实现目标域数据的有效分类。

Description

一种基于重建的中间域领域自适应方法
技术领域
本发明属于计算机视觉、智能频谱数据分析等领域自适应技术领域,具体涉及一种基于重建的中间域领域自适应方法。
背景技术
在计算机视觉等领域广泛使用的深度学习模型要求训练集和测试集的数据分布相同,但是在军事测控,遥感成像,医疗健康等数据差异大,视觉技术要求高的领域使用的训练集和测试集的数据分布经常存在较大偏差,这对于深度学习模型的训练和更新是一个巨大的挑战。领域自适应问题是迁移学习的研究内容之一,它侧重于解决特征空间一致、类别空间一致,仅特征分布不一致的问题,其目的就是利用有标签的领域知识来辅助目标领域的知识获取和学习。
领域自适应的基本方法可以分为数据分布自适应法,特征选择法和子空间学习法三大类,而随着深度学习方法的广泛应用,出现了更多基于深度神经网络进行领域自适应的方法,比如通过增加自适应层,选择不同的度量准则实现领域对齐,或者通过深度对抗,在博弈对抗中实现领域对齐。
目前大量领域自适应方法通过对域差异进行直接最小化实现领域对齐,使用增加自适应层或者深度对抗的方法,提取两域中的域不变特征,将两域特征进行直接对齐,这往往忽略了在实际场景中两域之间的差异可能过大,进而导致域差异最小化的实现很困难,另外深度神经网络提取的领域不变特征中可能仍然存在残留的领域特有特征,这将对特征对齐带来影响。
因此,针对领域自适应问题,有必要提出一种能够有效提取领域不变特征,实现领域特征对齐的领域自适应方法,利用源域知识在目标域获得更好的识别结果。
发明内容
本发明要解决的技术问题是:针对现有领域自适应方法领域特征对齐困难等不足之处,提出一种基于重建的中间域领域自适应方法,并且能够实现更好的数据分类性能。
本发明提供的一种基于重建的中间域领域自适应方法,包括下列步骤:
步骤S1:获取有标签的源域数据集Ds和无标签的目标域数据集Dt,其中,有标签源域数据集Ds的数据数量为n,每个数据定义为数据/>的类别标签定义为/>目标域数据集Dt的数据数量为m,且n、m为正整数;
步骤S2:构建深度网络模型,所述深度网络模型包括重建特征提取模块,中间域特征提取模块F,中间域分类模块C,中间域对抗模块AD和中间域特征对齐模块D;
所述重建特征提取模块包括源域重建特征提取模块和目标域重建特征提取模块,且源域重建特征提取模块和目标域重建特征提取模块的损失分别为源域重建特征损失Ls-recon和目标域重建特征损失Lt-recon
所述源域重建特征提取模块包括源域特征编码器Es和源域特征解码器Ds,其输入数据为源域数据;
所述目标域重建特征提取模块包括目标域特征编码器Et和目标域特征解码器Dt,其输入数据为目标域数据;
其中,源域特征编码器Es包括多个交替的卷积层与最大池化层,且源域特征解码器Ds的网络结构与源域特征编码器Es镜像对称,即网络结构设置完全相反;标域特征编码器Et与源域特征编码器Es的网络结构相同,目标域特征解码器Dt与源域特征解码器Ds的网络结构相同;
且所述源域特征编码器Es和目标域特征编码器Et的输入还输入中间域特征对齐模块D;
所述中间域特征提取模块F的输入为源域数据和目标域数据,中间域特征提取模块F用于提取两域数据的数据特征,得到源域特征和目标域特征,并将两域特征同时输入中间域对抗模块AD和中间域特征对齐模块D,以及将源域特征输入中间域分类模块C,通过与中间域分类模块C,中间域对抗模块AD和中间域特征对齐模块D的配合,完成中间域特征对齐;
所述中间域分类模块C根据源域数据标签对输入的源域特征进行分类处理,中间域分类模块C的损失为中间域分类损失LC
所述中间域对抗模块AD,用于对中间域特征提取模块F混淆的两域特征(源域特征和目标域特征)进行辨别,且所述中间域对抗模块AD的训练目的为:中间域对抗模块AD不能区分两域特征,并在反向传播时对梯度进行翻转,反向更新中间域对抗模块AD的网络参数;其中,中间域对抗模块AD的损失为中间域域对抗损失LAD
所述中间域特征对齐模块D为一个域分类器,用于对输入数据(中间域特征提取模块F输出的源域特征和目标域特征,以及源域特征编码器Es和目标域特征编码器Et的输出)进行域分类,所包括的域类别有:源域,中间域和目标域;从而实现源域特征和目标域特征在中间域进行对齐;其中,中间域特征对齐模块D的损失为中间域域判别损失LD
步骤S3:将源域数据和目标域数据分别输入源域重建特征提取模块和目标域重建特征提取模块,以及将源域数据和目标域数据同时输入中间域特征提取模块F;
并通过迭代训练使得源域重建特征损失Ls-recon,目标域重建特征损失Lt-recon收敛,中间域分类损失LC,中间域域对抗损失LAD和中间域域判别损失LD收敛,得到训练好的深度网络模型;
步骤S4:基于训练好的深度网络模型的中间域特征提取模块F和中间域分类模块C组成分类网络;将目标域的待分类数据输入所述分类网络,基于其前向传播的输出得到分类结果。
进一步的,所述源域特征编码器Es的依次包括:卷积层1、最大池化层1、卷积层2、最大池化层2、卷积层3和最大池化层3。
进一步的,所述中间域特征提取模块F包括至少三层卷积层,且每层卷积层后依次设置有批归一化层与最大池化层,每层卷积层采用非线性激活函数,并在第二层卷积层之后采用dropout防止过拟合。
进一步的,所述中间域分类模块C包括多层全连接层,在倒数第二层的全连接层后加入批归一化层,并通过dropout防止过拟合,最后一层全连接层采用Softmax函数进行分类输出。
进一步的,所述中间域对抗模块AD包括两层全连接层,在每一层全连接层后加入批归一化层,采用非线性激活函数作为激活函数,最后一层全连接层采用Softmax函数进行判别输出。
进一步的,所述中间域特征对齐模块D的网络结构与中间域分类模块C的网络结构相同。
本发明提供的技术方案至少带来如下有益效果:
(1)本发明提出了一种基于重建的中间域领域自适应方法,使用重建的方法对源域数据和目标域数据的特征进行提取,这样提取到的特征将包含更多的数据信息,具有更强的可辨别性。
(2)针对实际场景中两域之间直接对域差异最小化实现困难的问题,本发明通过在中间域对两域特征进行对齐,从而达到减轻特征对齐难度的目的,最终实现目标域数据的有效分类。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1是本发明实施例提供的一种基于重建的中间域领域自适应方法的实现流程图。
图2是本发明实施例提供的一种基于重建的中间域领域自适应方法采用的深度网络模型的结构示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,下面将结合附图对本发明实施方式作进一步地详细描述。
本发明实施例通过重建和对抗的方法将源域数据和目标域数据在中间域进行特征对齐,实现分类器对两域数据的准确分类。
参见图1,以处理图像分类为例,本发明实施例提供的一种基于重建的中间域领域自适应方法的具体处理过程包括:
步骤1:获得n个有标签源域数据集和m个无标签目标域数据集
即获取n个有标签的源域图像数据,m个无标签目标域数据,其中,定义表示源域的第i个数据(图像),/>表示/>的图像分类标签,/>表示第j个目标域图像数据。
步骤2:构建深度网络模型,所述深度网络模型包括重建特征提取模块,中间域特征提取模块F,中间域分类模块C,中间域对抗模块AD和中间域特征对齐模块D,如图2所示;
步骤3:将源域数据和目标域数据分别输入源域重建特征提取模块和目标域重建特征提取模块,提取到源域数据和目标域数据的可辨别特征作为基准特征(即基准图像特征),并且将源域数据和目标域数据同时输入中间域特征提取模块F,与中间域分类模块C,中间域对抗模块AD和中间域特征对齐模块D配合,提取到中间域特征(即中间域图像特征);
步骤4:通过迭代训练使得源域重建特征损失Ls-recon,目标域重建特征损失Lt-recon收敛,中间域分类损失LC,中间域域对抗损失LAD和中间域域判别损失LD收敛,从而在能对源域数据进行准确分类(图像分类)的前提下在中间域对齐领域特征;
步骤5:经过训练之后,将目标域数据输入由中间域特征提取模块F和中间域分类模块C组成的分类网络得到分类结果,以验证方法的有效性。
即,基于训练后的深度网络模型的中间域特征提取模块F和中间域分类模块C组成分类网络,对目标域的待分类的图像(无标签的图像数据),将其输入到所述分类网络,基于其前向传播的输出得到分类结果。
本发明实施例中,所述深度网络模型包括重建特征提取模块,中间域特征提取模块F,中间域分类模块C,中间域对抗模块AD和中间域特征对齐模块D具有对应的损失函数,其中重建特征提取模块包括源域重建特征损失Ls-recon和目标域重建特征损失Lt-recon,通过重建特征提取模块可以提取到源域数据和目标域数据的可辨别信息,中间域分类模块C具有中间域分类损失LC,通过中间域分类模块C可以对转移到中间域的源域数据进行分类,中间域对抗模块AD具有中间域域对抗损失LAD,通过梯度翻转层与中间域特征提取模块F进行对抗训练,对两域数据进行混淆,而中间域特征对齐模块D具有中间域域判别损失LD,通过中间域特征对齐模块D对源域数据,中间域数据和目标域数据进行判别实现中间域特征的对齐。而对于获取的源域Ds和目标域Dt,签源域数据目标域数据/>源域和目标域具有相同的特征空间,即Xs=Xt,并且具有相同的类别空间,即Ys=Yt,但是两域的边缘分布不同,即Ps(xs)≠Pt(xt),方法训练分类器将利用源域数据知识对目标域数据进行有效分类。其中,Xs、Xt分别表示源域和目标域的图像特征空间,Ys、Yt分别表示源域和目标域的图像类别空间,Ps(xs)、Pt(xt)分别表示源域和目标域的边缘分布。
作为一种可能的实现方式,本发明实施例的深度网络模型的各部分具体设置为:
(1)重建特征提取模块。重建特征提取模块包括源域重建特征提取模块和目标域重建特征提取模块,而二者分别由源域特征编码器Es,源域特征解码器Ds和目标域特征编码器Et,目标域特征解码器Dt组成。
例如对应于手写数字数据集(图像类别取决于对应的数字,比如0~9分别作为一个识别类别),网络结构设置选择使用浅层网络,当数据集更加复杂,可以选择使用深层网络。以源域特征编码器Es为例具体网络结构为:卷积层1-最大池化层1-卷积层2-最大池化层2-卷积层3-最大池化层3,而源域特征解码器Ds与源域特征编码器Es的网络结构设置完全相反,目标域部分则与源域部分完全相同。源域重建特征提取模块和目标域重建特征提取模块将维护源域重建特征损失Ls-recon和目标域重建特征损失Lt-recon,通过降低重建损失,即提取特征能尽可能重现输入数据特征,进而提取到源域数据和目标域数据的可辨别特征,本发明实施例将提取特征作为源域和目标域的基准特征。源域重建特征损失Ls-recon和目标域重建特征损失Lt-recon表达式分别为公式(1)和公式(2)。
其中,xs、xt分别表示源域数据和目标域数据,分别表示源域特征编码器Es和源域特征解码器Ds的网络参数,/>分别表示目标域特征编码器Et和目标域特征解码器Dt的网络参数,lr()代表平方损失函数,fs()和ft()分别代表源域特征解码器Ds和目标域特征解码器Dt的输出。
(2)中间域特征提取模块F。中间域特征提取模块F负责将源域数据特征和目标域数据特征进行提取,得到源域特征F(Xs)和目标域特征F(Xt),由于源域数据分布和目标域数据分布可能差别很大,直接对齐两域特征难度较大,本发明实施例选择将两域数据特征在中间域进行对齐,减小对齐的难度。将源域数据和目标域数据输入中间域特征提取模块F,模块在提取两域特征的同时将对两域特征进行混淆,与中间域分类模块C,中间域对抗模块AD和中间域特征对齐模块D配合,进而达到中间域特征对齐的目的。中间域特征提取模块F的深度网络结构采用三层卷积,三层卷积的卷积核尺寸分别为5×5,3×3,3×3,通道数分别为64,64,50,同时为了加快训练速度,在每层卷积之后使用Batch Normalization与最大池化层MaxPool进行归一化和特征提取,这样也能有效的减少参数量,并且在第二层卷积之后使用dropout以防止过拟合,使用ReLU函数作为激活函数进行非线性激活。
(3)中间域分类模块C。将中间域特征提取模块F提取到转移到中间域的源域特征输入中间域分类模块C,根据源域数据标签对转移到中间域的源域数据进行分类。中间域分类模块C的深度网络结构使用了三层全连接层,每层的通道数分别为100,100,10,即最后一层全连接层的通道数与预置的图像类别数对应,同样的在前两层全连接层之后加入BatchNormalization和dropout加快训练速度和防止过拟合,使用ReLU函数作为激活函数,最后使用Softmax函数进行分类,在使用Pytorch编程时直接使用CrossEntropyLoss损失,则可以不用显式调用Softmax函数。中间域分类损失LC表达式为公式(3)。
其中,θF、θC分别表示中间域特征提取模块F和中间域分类模块C的网络参数,lc()代表交叉熵损失函数,fc()代表中间域分类模块C的输出,即属于各类别的概率,本实施中,具体指分类器Softmax输出。
(4)中间域对抗模块AD。中间域对抗模块AD与中间域特征提取模块F相配合,对中间域特征提取模块F混淆的两域特征进行辨别,而训练的目的就是使得中间域对抗模块AD不能区分两域特征,通过梯度翻转层,可以在反向传播时对(深度学习的)梯度进行翻转,进而反向更新对抗参数。中间域对抗模块AD的深度网络结构使用了两层全连接层,每层的通道数分别为100和2,并且在每层全连接层之后加入Batch Normalization,使用ReLU函数作为激活函数,同样使用交叉熵损失函数(CrossEntropyLoss),最后使用Softmax函数进行判别。中间域域对抗损失LAD表达式为公式(4)。
其中,x表示中间域对抗模块AD的输入数据,θF、θAD分别表示中间域特征提取模块F和中间域对抗模块AD的网络参数,fad()代表中间域对抗模块AD域判别器Softmax输出,和/>分别代表两域数据的领域标签,即源域和目标域。
(5)中间域特征对齐模块D。中间域特征对齐模块D负责把源域数据和目标域数据在中间域进行对齐,其本质上是一个域分类器。将源域重建特征Es(xs),转移到中间域的源域特征F(xs),转移到中间域的目标域特征F(xt)和目标域重建特征Et(xt)四类特征输入中间域特征对齐模块D,而该模块会将四类特征分为三个领域,即源域,中间域和目标域,通过中间域特征提取模块F对F(xs)和F(xt)进行混淆,中间域特征对齐模块D对领域进行判别,从而实现源域数据特征和目标域数据特征在中间域进行对齐。中间域特征对齐模块D的深度网络结构和中间域分类模块C的结构相似,只需把最后一层的输出种类改为3即可。中间域域判别损失LD表达式为公式(5)。
其中,x表示中间域特征对齐模块D的输入数据,θF、θD分别表示中间域特征提取模块F和中间域特征对齐模块D的网络参数,fd()代表中间域特征对齐模块D分类器Softmax输出,表示第个i数据xi(图像)的数据标签,一共是k类,即三类,即用于输出当前数据的领域标签。
总体损失LRMDAN表达式为公式(6):
LRMDAN=Ls-recon+Lt-recon+LC-LAD+LD (6)
作为一种可能的实现方式,本发明实施例将训练的迭代次数设为200,学习率μ=1e3,学习率的衰减系数设为0.90,每两个epoch衰减一次,使用Adam优化器对模型进行更新训练。当满足预置的收敛条件时,得到训练好的深度网络模型,再基于此时的中间域特征提取模块F和中间域分类模块C组成分类网络,用于目标图像的图像分类处理。
本发明实施例提供的一种基于重建的中间域领域自适应方法,能用于计算机视觉、智能频谱数据分析等领域自适应领域。即本发明实施例提供的深度网络模型的输入数据除了图像数据,还可以是频谱数据,可以实现对频谱数据的分类处理。其处理过程与图像分类的处理过程相同,源域数据与目标数据均为采集的频谱数据,训练时,将对应的源域数据和频谱数据输入即可,再基于训练后的中间域特征提取模块F和中间域分类模块C构成频谱数据分类的分类网络,以便于将目标域的频谱数据输入该分类网络获取分类结果。
最后应说明的是:以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
以上所述的仅是本发明的一些实施方式。对于本领域的普通技术人员来说,在不脱离本发明创造构思的前提下,还可以做出若干变形和改进,这些都属于本发明的保护范围。

Claims (10)

1.一种处理图像分类的基于重建的中间域领域自适应方法,其特征在于,包括下列步骤:
步骤S1:获取有标签的源域图像数据集Ds和无标签的目标域图像数据集Dt,其中,有标签源域图像数据集Ds的数据数量为n,每个图像数据定义为图像数据/>的图像类别标签定义为/>目标域数据集Dt的数据数量为m,且n、m为正整数;
步骤S2:构建深度网络模型,所述深度网络模型包括重建特征提取模块,中间域特征提取模块F,中间域分类模块C,中间域对抗模块AD和中间域特征对齐模块D;
所述重建特征提取模块包括源域重建特征提取模块和目标域重建特征提取模块,且源域重建特征提取模块和目标域重建特征提取模块的损失分别为源域重建特征损失Ls-recon和目标域重建特征损失Lt-recon
所述源域重建特征提取模块包括源域特征编码器Es和源域特征解码器Ds,其输入数据为源域数据;
所述目标域重建特征提取模块包括目标域特征编码器Et和目标域特征解码器Dt,其输入数据为目标域图像数据;
其中,源域特征编码器Es包括多个交替的卷积层与最大池化层,且源域特征解码器Ds的网络结构与源域特征编码器Es镜像对称;标域特征编码器Et与源域特征编码器Es的网络结构相同,目标域特征解码器Dt与源域特征解码器Ds的网络结构相同;
且所述源域特征编码器Es和目标域特征编码器Et的输入还输入中间域特征对齐模块D;
所述中间域特征提取模块F的输入为源域图像数据和目标域图像数据,中间域特征提取模块F用于提取两域图像数据的数据特征,得到源域特征和目标域特征,并将两域特征同时输入中间域对抗模块AD和中间域特征对齐模块D,以及将源域特征输入中间域分类模块C,通过与中间域分类模块C,中间域对抗模块AD和中间域特征对齐模块D的配合,完成中间域特征对齐;
所述中间域分类模块C根据源域图像数据标签对输入的源域特征进行分类处理,中间域分类模块C的损失为中间域分类损失LC
所述中间域对抗模块AD,用于对中间域特征提取模块F混淆的两域特征进行辨别,且所述中间域对抗模块AD的训练目的为:中间域对抗模块AD不能区分两域特征,并在反向传播时对梯度进行翻转,反向更新中间域对抗模块AD的网络参数;其中,中间域对抗模块AD的损失为中间域域对抗损失LAD
所述中间域特征对齐模块D为一个域分类器,用于对输入数据进行域分类,中间域特征对齐模块D所包括的领域类别有:源域,中间域和目标域;其中,中间域特征对齐模块D的损失为中间域域判别损失LD
步骤S3:将源域图像数据和目标域图像数据分别输入源域重建特征提取模块和目标域重建特征提取模块,以及将源域图像数据和目标域图像数据同时输入中间域特征提取模块F;
并通过迭代训练使得源域重建特征损失Ls-recon,目标域重建特征损失Lt-recon收敛,中间域分类损失LC,中间域域对抗损失LAD和中间域域判别损失LD收敛,得到训练好的深度网络模型;
步骤S4:基于训练好的深度网络模型的中间域特征提取模块F和中间域分类模块C组成分类网络,用于目标图像的图像分类处理;将目标域的待分类数据输入所述分类网络,基于其前向传播的输出得到分类结果。
2.如权利要求1所述的方法,其特征在于,所述源域重建特征损失Ls-recon和目标域重建特征损失Lt-recon的表达式分别为:
其中,xs、xt分别表示源域和目标域图像数据,分别表示源域特征编码器Es和源域特征解码器Ds的网络参数,/>分别表示目标域特征编码器Et和目标域特征解码器Dt的网络参数,lr()表示平方损失函数,fs()和ft()分别表示源域特征解码器Ds和目标域特征解码器Dt的输出。
3.如权利要求1或2所述的方法,其特征在于,所述源域特征编码器Es的依次包括:卷积层1、最大池化层1、卷积层2、最大池化层2、卷积层3和最大池化层3。
4.如权利要求1所述的方法,其特征在于,所述中间域特征提取模块F包括至少三层卷积层,且每层卷积层后依次设置有批归一化层与最大池化层,每层卷积层采用非线性激活函数,并在第二层卷积层之后采用dropout防止过拟合。
5.如权利要求1所述的方法,其特征在于,所述中间域分类损失LC的表达式为:
其中,xs表示源域图像数据,θF、θC分别表示中间域特征提取模块F和中间域分类模块C的网络参数,lc()表示交叉熵损失函数,fc()表示中间域分类模块C的输出。
6.如权利要求1或5所述的方法,其特征在于,所述中间域分类模块C包括多层全连接层,在倒数第二层的全连接层后加入批归一化层,并通过dropout防止过拟合,最后一层全连接层采用Softmax函数进行分类输出。
7.如权利要求1所述的方法,其特征在于,所述中间域域对抗损失LAD的表达式为:
其中,x表示中间域对抗模块AD的输入数据,θF、θAD分别表示中间域特征提取模块F和中间域对抗模块AD的网络参数,lc()表示交叉熵损失函数,fad()表示中间域对抗模块AD的输出,和/>分别源域和目标域图像数据的领域标签。
8.如权利要求1或7所述的方法,其特征在于,所述中间域对抗模块AD包括两层全连接层,在每一层全连接层后加入批归一化层,采用非线性激活函数作为激活函数,最后一层全连接层采用Softmax函数进行判别输出。
9.如权利要求1所述的方法,其特征在于,所述中间域域判别损失LD的表达式为:
其中,x表示中间域特征对齐模块D的输入数据,θF、θD分别表示中间域特征提取模块F和中间域特征对齐模块D的网络参数,lc()表示交叉熵损失函数,fd()表示中间域特征对齐模块D的输出,表示第i个图像数据xi的领域标签。
10.如权利要求1或9所述的方法,其特征在于,所述中间域特征对齐模块D的网络结构与中间域分类模块C的网络结构相同。
CN202210324083.4A 2022-03-29 2022-03-29 一种基于重建的中间域领域自适应方法 Active CN114693972B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210324083.4A CN114693972B (zh) 2022-03-29 2022-03-29 一种基于重建的中间域领域自适应方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210324083.4A CN114693972B (zh) 2022-03-29 2022-03-29 一种基于重建的中间域领域自适应方法

Publications (2)

Publication Number Publication Date
CN114693972A CN114693972A (zh) 2022-07-01
CN114693972B true CN114693972B (zh) 2023-08-29

Family

ID=82140149

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210324083.4A Active CN114693972B (zh) 2022-03-29 2022-03-29 一种基于重建的中间域领域自适应方法

Country Status (1)

Country Link
CN (1) CN114693972B (zh)

Citations (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN104537362A (zh) * 2015-01-16 2015-04-22 中国科学院自动化研究所 一种基于域自适应的英文场景文字识别方法
CN111126386A (zh) * 2019-12-20 2020-05-08 复旦大学 场景文本识别中基于对抗学习的序列领域适应方法
CN111932444A (zh) * 2020-07-16 2020-11-13 中国石油大学(华东) 基于生成对抗网络的人脸属性编辑方法及信息处理终端
CN112308158A (zh) * 2020-11-05 2021-02-02 电子科技大学 一种基于部分特征对齐的多源领域自适应模型及方法
CN112990359A (zh) * 2021-04-19 2021-06-18 深圳市深光粟科技有限公司 一种影像数据处理方法、装置、计算机及存储介质
CN113283444A (zh) * 2021-03-30 2021-08-20 电子科技大学 一种基于生成对抗网络的异源图像迁移方法
CN113378904A (zh) * 2021-06-01 2021-09-10 电子科技大学 一种基于对抗域自适应网络的图像分类方法
CN113469273A (zh) * 2021-07-20 2021-10-01 南京信息工程大学 基于双向生成及中间域对齐的无监督域适应图像分类方法
CN113486987A (zh) * 2021-08-04 2021-10-08 电子科技大学 基于特征解耦的多源域适应方法
CN113807371A (zh) * 2021-10-08 2021-12-17 中国人民解放军国防科技大学 一种类条件下的有益特征对齐的无监督域自适应方法
WO2022001489A1 (zh) * 2020-06-28 2022-01-06 北京交通大学 一种无监督域适应的目标重识别方法
CN113936318A (zh) * 2021-10-20 2022-01-14 成都信息工程大学 一种基于gan人脸先验信息预测和融合的人脸图像修复方法
CN114065861A (zh) * 2021-11-17 2022-02-18 北京工业大学 基于对比对抗学习的领域自适应方法及装置
CN114187263A (zh) * 2021-12-10 2022-03-15 西安交通大学 一种融合先验引导和域适应的磨损表面朗伯反射分离方法

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
EP3616130B1 (en) * 2017-09-20 2024-04-10 Google LLC Using simulation and domain adaptation for robotic control
US20190147854A1 (en) * 2017-11-16 2019-05-16 Microsoft Technology Licensing, Llc Speech Recognition Source to Target Domain Adaptation

Patent Citations (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN104537362A (zh) * 2015-01-16 2015-04-22 中国科学院自动化研究所 一种基于域自适应的英文场景文字识别方法
CN111126386A (zh) * 2019-12-20 2020-05-08 复旦大学 场景文本识别中基于对抗学习的序列领域适应方法
WO2022001489A1 (zh) * 2020-06-28 2022-01-06 北京交通大学 一种无监督域适应的目标重识别方法
CN111932444A (zh) * 2020-07-16 2020-11-13 中国石油大学(华东) 基于生成对抗网络的人脸属性编辑方法及信息处理终端
CN112308158A (zh) * 2020-11-05 2021-02-02 电子科技大学 一种基于部分特征对齐的多源领域自适应模型及方法
CN113283444A (zh) * 2021-03-30 2021-08-20 电子科技大学 一种基于生成对抗网络的异源图像迁移方法
CN112990359A (zh) * 2021-04-19 2021-06-18 深圳市深光粟科技有限公司 一种影像数据处理方法、装置、计算机及存储介质
CN113378904A (zh) * 2021-06-01 2021-09-10 电子科技大学 一种基于对抗域自适应网络的图像分类方法
CN113469273A (zh) * 2021-07-20 2021-10-01 南京信息工程大学 基于双向生成及中间域对齐的无监督域适应图像分类方法
CN113486987A (zh) * 2021-08-04 2021-10-08 电子科技大学 基于特征解耦的多源域适应方法
CN113807371A (zh) * 2021-10-08 2021-12-17 中国人民解放军国防科技大学 一种类条件下的有益特征对齐的无监督域自适应方法
CN113936318A (zh) * 2021-10-20 2022-01-14 成都信息工程大学 一种基于gan人脸先验信息预测和融合的人脸图像修复方法
CN114065861A (zh) * 2021-11-17 2022-02-18 北京工业大学 基于对比对抗学习的领域自适应方法及装置
CN114187263A (zh) * 2021-12-10 2022-03-15 西安交通大学 一种融合先验引导和域适应的磨损表面朗伯反射分离方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
"领域自适应 研究综述";李晶晶等;《计算机工程》;第47卷(第6期);第1-13页 *

Also Published As

Publication number Publication date
CN114693972A (zh) 2022-07-01

Similar Documents

Publication Publication Date Title
Mascarenhas et al. A comparison between VGG16, VGG19 and ResNet50 architecture frameworks for Image Classification
Ghosal et al. Rice leaf diseases classification using CNN with transfer learning
CN111368896B (zh) 基于密集残差三维卷积神经网络的高光谱遥感图像分类方法
CN111126386B (zh) 场景文本识别中基于对抗学习的序列领域适应方法
CN112446423B (zh) 一种基于迁移学习的快速混合高阶注意力域对抗网络的方法
CN110555060B (zh) 基于成对样本匹配的迁移学习方法
Liu et al. Global pixel transformers for virtual staining of microscopy images
Bani-Hani et al. Classification of leucocytes using convolutional neural network optimized through genetic algorithm
CN111178120A (zh) 一种基于作物识别级联技术的害虫图像检测方法
CN110942825A (zh) 基于卷积神经网络和循环神经网络结合的心电诊断方法
Zhang et al. Classification of canker on small datasets using improved deep convolutional generative adversarial networks
Ma et al. Maize leaf disease identification using deep transfer convolutional neural networks
Singh et al. Performance Analysis of CNN Models with Data Augmentation in Rice Diseases
Nga et al. Combining binary particle swarm optimization with support vector machine for enhancing rice varieties classification accuracy
CN113011436A (zh) 一种基于卷积神经网络的中医舌色苔色协同分类方法
CN113221913A (zh) 一种基于高斯概率决策级融合的农林病虫害细粒度识别方法及装置
CN114693972B (zh) 一种基于重建的中间域领域自适应方法
Zheng et al. Fruit tree disease recognition based on convolutional neural networks
Hu et al. Learning salient features for flower classification using convolutional neural network
CN108960275A (zh) 一种基于深度玻尔兹曼机的图像识别方法及系统
CN114998973A (zh) 一种基于域自适应的微表情识别方法
Jayaram et al. A brief study on rice diseases recognition and image classification: fusion deep belief network and S-particle swarm optimization algorithm
Adaïmé et al. Deep learning approaches to the phylogenetic placement of extinct pollen morphotypes
Guzzi et al. Distillation of a CNN for a high accuracy mobile face recognition system
Sahu et al. Adaptive fusion of K-means region growing with optimized deep features for enhanced LSTM-based multi-disease classification of plant leaves

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant