CN112308158B - 一种基于部分特征对齐的多源领域自适应模型及方法 - Google Patents
一种基于部分特征对齐的多源领域自适应模型及方法 Download PDFInfo
- Publication number
- CN112308158B CN112308158B CN202011223578.5A CN202011223578A CN112308158B CN 112308158 B CN112308158 B CN 112308158B CN 202011223578 A CN202011223578 A CN 202011223578A CN 112308158 B CN112308158 B CN 112308158B
- Authority
- CN
- China
- Prior art keywords
- feature
- domain
- partial
- alignment
- loss function
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 27
- 230000006870 function Effects 0.000 claims abstract description 79
- 238000000605 extraction Methods 0.000 claims abstract description 72
- 239000013598 vector Substances 0.000 claims abstract description 29
- 238000013527 convolutional neural network Methods 0.000 claims abstract description 19
- 238000013528 artificial neural network Methods 0.000 claims abstract description 8
- 238000012549 training Methods 0.000 claims description 38
- 230000003044 adaptive effect Effects 0.000 claims description 17
- 238000010586 diagram Methods 0.000 claims description 17
- 238000012360 testing method Methods 0.000 claims description 8
- 238000000746 purification Methods 0.000 claims description 7
- 238000004364 calculation method Methods 0.000 claims description 5
- 239000006185 dispersion Substances 0.000 claims description 5
- 238000007781 pre-processing Methods 0.000 claims description 5
- 230000008569 process Effects 0.000 claims description 5
- 230000004913 activation Effects 0.000 claims description 2
- 238000002474 experimental method Methods 0.000 claims description 2
- 238000005520 cutting process Methods 0.000 claims 1
- 230000000694 effects Effects 0.000 abstract description 4
- 230000000875 corresponding effect Effects 0.000 description 16
- 230000002776 aggregation Effects 0.000 description 3
- 238000004220 aggregation Methods 0.000 description 3
- 238000009826 distribution Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 230000006978 adaptation Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000002596 correlated effect Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000004821 distillation Methods 0.000 description 1
- 238000001914 filtration Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 239000004576 sand Substances 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/213—Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/44—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
- G06V10/443—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components by matching or filtering
- G06V10/449—Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters
- G06V10/451—Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters with interaction between the filter responses, e.g. cortical complex cells
- G06V10/454—Integrating the filters into a hierarchical structure, e.g. convolutional neural networks [CNN]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/771—Feature selection, e.g. selecting representative features from a multi-dimensional feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/7715—Feature extraction, e.g. by transforming the feature space, e.g. multi-dimensional scaling [MDS]; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
Abstract
本发明公开了一种基于部分特征对齐的多源领域自适应模型及方法,其中部分特征提取的特征选择模块在常规卷积神经网络或残差神经网络特征提取器的基础上,根据源域与目标域各个特征维度的相似性,生成特征层面的选择向量,该选择向量作用于初始的特征图后,可以筛选出源域中与目标域高度相关的部分特征。在此基础上,本发明进一步提出了三种分别针对类别内、域间和类别间的部分特征对齐损失函数,使得提纯后的特征图对于分类器的可分辨性更好,源域与目标域相关的部分特征被凸显出来。本发明用于多源领域自适应分类数据集,与现有多源领域自适应模型相比,其分类正确率更高,特征选择的效果更好。
Description
技术领域
本发明属于计算机视觉与迁移学习中的多源领域自适应分支,具体涉及一种通过高阶矩匹配对部分特征进行对齐的多源领域自适应模型结构,以及在特征图上设计对齐损失函数以达到特征对分类器可分辨的目的。
背景技术
在机器学习中,利用深度神经网络的监督和半监督学习已经取得了较为显著的成果,依赖于众多公开数据集,监督和半监督学习在多种任务如图像分类、人脸识别、语义分析上具有广泛的应用。但是,现实世界的数据标签收集十分困难,往往需要大量人力来完成,由于域间偏移的存在,在其它数据集上训练的模型并不能直接应用在实际生产中,因此提出了无监督领域自适应的技术。
传统的无监督领域自适应技术使用一个带标签的源域作为信息来源,将其与无标签的目标域共同训练,通过减小源域与目标域数据之间的分布差异来实现领域自适应,从而在目标域上完成所期望的任务,而单源域的领域自适应所获取的信息对于目标任务来说是有限的,因此引入了多源领域自适应的方法,即使用多个源域数据集作为模型学习的依据,将多个源域的数据与目标域样本进行对齐,从而达到更好的性能提升。
在现有的多源领域自适应模型中,研究者们提出了多种前沿技术基于深度学习提升模型在目标域上的表现,例如:基于对抗学习、基于知识传递与聚集、基于数据蒸馏、基于数据动态选择、基于特征提取与对齐、基于子空间学习等。其中基于特征提取与对齐的技术使用较为广泛,可以使得源域与目标域的数据在特征空间上相互对齐,在目标域的任务上取得较好的结果。
通常在现有的基于特征提取与对齐的多源域领域自适应策略中,使用一个共享权重的特征提取器提取多个源域以及目标域的图像特征,在所获得的特征图上设计关注于不同维度的对齐和匹配损失函数以减小源域特征与目标域特征之间的分布差异。
尽管研究者们提出的诸多特征对齐方法已经在大量公开数据集中取得了较高的正确率,但现有的特征提取与对齐方法中,还存在一些不足。例如,现有的一些方法在进行特征对齐时,使用的是通过特征提取器获得的全部图像特征,这些特征中不仅包含了源域和目标域之间相关的特征,还包含了只在某些源域中出现的域特有的特征,如果这些特征也参与到特征对齐中,将造成对齐效果的下降。
综上,在多源域领域自适应中,在所有源域特征的并集包含目标域的所有特征的基本假设上,有必要在特征图上进一步讨论各个特征维度对于目标域任务的重要性,以及源域和目标域在各个特征维度上的相似性,以此提出一种针对特征图中的一部分与目标域相关的特征的对齐策略,以在目标域任务上取得更好的结果。
发明内容
本发明的发明目的在于:提供一种在图像特征层面上利用高阶矩差异进行部分特征的对齐的策略,以获得多源域之间与目标域紧密相关的部分特征,并在部分特征上应用三种不同维度的部分特征对齐损失函数,实现更好的分类性能。
本发明是一种基于部分特征对齐的多源领域自适应模型,包括特征提取模块、部分特征提取的特征选择模块及其对应的损失函数、三种部分特征对齐损失函数和两个对抗训练的分类器,其中,三种部分特征对齐损失函数包括类别内部分特征对齐损失、域间部分特征对齐损失和类别间部分特征对齐损失,经过特征选择模块提取出的部分特征再通过三种不同的部分特征对齐损失以达到同类别聚集,不同类别分散的目的;
其中,所述特征提取模块使用常规的卷积神经网络CNN或预训练的残差神经网络ResN et-101,分别针对简单图像和复杂图像,用于提取图像的初始特征图,多个源域和目标域的初始特征图之间分别计算L1距离后作为部分特征提取的特征选择模块的输入;
所述卷积神经网络CNN使用三层卷积层和两层全连接层,最终得到特征维度为2048维,所述残差神经网络ResNet-101去掉最后一层全连接层后,获得特征维度为2048维;
所述部分特征提取的特征选择模块通过两层全连接层,使用如上所述的L1距离作为输入,将该距离中数值较小的维度看作是源域与目标域相关的特征维度,特征选择模块的输出为初始特征图的特征选择向量。将特征选择向量以点乘的方式作用于特征提取模块提取出的初始特征图,获得提纯后的特征图;
进一步地,本发明在提纯后的特征图上设计部分特征对齐的特征选择模块的损失函数以及三种部分特征对齐的损失函数,在提纯的特征图上分别计算部分特征提取的特征选择模块对应的损失、类别内的部分特征对齐损失、域间的部分特征对齐损失和类别间的部分特征对齐损失,取上述所有损失函数的加权和作为除分类损失以外的所有损失函数;
所述部分特征提取的特征选择模块的损失函数如下:
其中分别表示提纯后源域和目标域的特征图,k表示矩的阶数,λreg是规范化的权重参数,N为一次批量训练中的样本数量,vi表示源域i的特征选择向量,为各个源域特征选择向量的平均值,表示某个源域的一批样本,表示目标域的一批样本,表示求期望值,G是常规的特征提取器,是提取出来得到的源域i的初始特征图,是提取出来得到的目标域的初始特征图;
所述三种部分特征对齐损失建立在提纯特征图的类别中心点的基础上,其定义如下:
其中fc表示某一个类别的中心点,F表示上述部分特征提取的过程,即经过F获得提纯后的特征图,n为一次批量训练中对应类别的样本数量;
为了保留所述基于部分特征对齐的多源领域自适应模型在前面训练所获得的信息,使用指数级流动平均值在每一次批训练中更新所有中心点:
所述类别内部分特征对齐损失如下:
其中fc表示某个域中某一类的中心点,fs表示对应域的对应类的部分特征样本点,k为矩的阶数。
所述域间部分特征对齐损失如下:
所述类别间部分特征对齐损失如下:
所述基于部分特征对齐的多源领域自适应模型整体的损失函数为:
L=Ls+λpLp+λcLc+λdomLdom+λdiscLdisc
其中Ls为源域的交叉熵分类损失,由两个分类器的交叉熵损失之和得到,λp,λc,λdom,λdisc分别为Lp,Lc,Ldom,Ldisc的权重参数。
一种基于部分特征对齐的多源领域自适应方法,具体包括以下步骤:
步骤1:数据预处理,得到预处理后的数据;
步骤2:对预处理后的数据使用特征提取模块提取图像基本特征;
步骤3:将源域与目标域特征之间的绝对值差异输入到部分特征对齐的特征选择模块,得到特征选择向量,将其作用于初始特征图,获得提纯后特征图,在提纯后的特征图上计算上述损失函数L并更新各个模块的参数,包括特征提取模块、部分特征提取的特征选择模块以及两个分类器;
步骤4:重复步骤2-3,获得提纯特征图和两个分类器的分类概率,计算两个分类器的源域分类交叉熵损失之和Ls以及两个分类器在目标域上的分类概率的绝对值差异Ldis。固定除分类器以外的其他模块的参数,用分类损失减去绝对值差异Ls-Ldis作为损失函数去更新两个分类器的参数,再固定除特征提取模块以外的其它模块的参数,重新计算两分类器在目标域上的分类概率的绝对值差异Ldis作为损失函数去更新特征提取模块的参数,进行特征提取模块与两个分类器的对抗训练;
步骤5:训练完成后,在测试数据集上只经过特征提取和分类步骤,获得分类预测结果,验证模型有效性。
与现有技术相比,本发明具有如下优点和有益效果:
(1)本发明提出了一个适用于特征提纯的部分特征提取和选择策略,在常规特征提取模块的基础上,使用该策略可以获得源域中与目标域高度相关的特征,同时过滤掉源域数据中与目标域不相关或关系较弱的特征,从而避免了不相关特征对特征对齐造成的负面影响。
(2)在提纯特征的基础上,本发明关注于类别内的特征对齐问题,提出类别内的特征对齐损失函数,使得源域和目标域中属于同一类别的图像在提纯后的特征图上相互聚集,有益于分类器进行分类预测。
(3)在领域自适应方面,本发明关注于不同域的图像之间的相关性,提出域间的特征对齐损失函数,将源域与目标域中对应类别的中心点联系起来,降低源域与目标域的分布差异,实现领域自适应,使得在源域上训练获得的知识可以迁移到目标域中。
(4)针对与提纯特征中不同类别的可分辨性,本发明提出了类别间的特征对齐损失函数,使得源域和目标域中,不同类别的数据中心点相互分离,在提纯后的特征图上体现为多个相互分离的数据聚集区域,提升数据点对于分类器的可分辨性。
附图说明
图1为本发明进行部分特征对齐的实现流程图;
图2为本发明的部分特征提取框架示意图;
图3为本发明的部分特征对齐损失示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,下面结合实施方式和附图,对本发明作进一步地详细描述,以便相关领域的技术人员能更好地理解本发明。需要特别注意的是,所描述的实施例是本发明一部分实施例,而不是全部的实施例,也非旨在限制要求保护的本发明的范围。本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
考虑到现有的多源域领域自适应模型在进行特征提取和对齐时,往往将所有的特征一起进行对齐,忽略了某些特征属于源域所特有的特征而不在目标域中表现这一事实,本发明提出了一种基于部分特征对齐的多源领域自适应方法。本发明在图像的特征层面上对源域和目标域的特征进行筛选,使得剩下的特征与目标域任务高度相关,并提出三种关注于不同层面的部分特征对齐的损失函数,提升了在目标域上图像分类的效果。下面结合具体实例,对本发明进行详细完整的说明。
如图2所示,本发明提出的一种基于部分特征对齐的多源领域自适应模型包括特征提取模块、部分特征提取的特征选择模块及其对应的损失函数、三种部分特征对齐损失函数和两个对抗训练的分类器,其中,三种部分特征对齐损失函数包括类别内部分特征对齐损失、域间部分特征对齐损失和类别间部分特征对齐损失;
其中,所述特征提取模块使用常规的卷积神经网络CNN或预训练的残差神经网络ResNet-101,分别针对简单图像和复杂图像,用于提取图像的初始特征图,多个源域和目标域的初始特征图之间分别计算L1距离后作为部分特征提取的特征选择模块的输入;所述卷积神经网络CNN使用三层卷积层和两层全连接层,最终得到特征维度为2048维,所述残差神经网络ResNet-101去掉最后一层全连接层后,获得特征维度为2048维;
所述部分特征提取的特征选择模块通过两层全连接层,使用所述L1距离作为输入,将该距离中数值较小的维度看作是源域与目标域相关的特征维度,特征选择模块的输出为初始特征图的特征选择向量,将特征选择向量以点乘的方式作用于特征提取模块提取出的初始特征图,获得提纯后的特征图;
在提纯后的特征图上设计部分特征提取的特征选择模块的损失函数以及三种部分特征对齐损失函数,在提纯的特征图上分别计算部分特征提取的特征选择模块对应的损失、类别内部分特征对齐损失、域间部分特征对齐损失和类别间部分特征对齐损失,取上述所有损失函数的加权和作为除分类损失以外的所有损失函数;
所述部分特征提取的特征选择模块的损失函数如下:
其中分别表示提纯后源域和目标域的特征图,k表示矩的阶数,λreg是规范化的权重参数,N为一次批量训练中的样本数量,vi表示源域i的特征选择向量,为各个源域特征选择向量的平均值,表示某个源域的一批样本,表示目标域的一批样本,表示求期望值,G是常规的特征提取器,是提取出来得到的源域i的初始特征图,是提取出来得到的目标域的初始特征图;
所述三种部分特征对齐损失建立在提纯特征图的类别中心点的基础上,其定义如下:
其中fc表示某一个类别的中心点,F表示上述部分特征提取的过程,即经过F获得提纯后的特征图,n为一次批量训练中对应类别的样本数量;
为了保留所述基于部分特征对齐的多源领域自适应模型在前面训练所获得的信息,使用指数级流动平均值在每一次批训练中更新所有中心点:
如图3所示,所述类别内部分特征对齐损失如下:
其中fc表示某个域中某一类的中心点,fs表示对应域的对应类的部分特征样本点,k为矩的阶数;
所述域间部分特征对齐损失如下:
所述类别间部分特征对齐损失如下:
所述基于部分特征对齐的多源领域自适应模型整体的损失函数为:
L=Ls+λpLp+λcLc+λdomLdom+λdiscLdisc
其中Ls为源域的交叉熵分类损失,由所述两个对抗训练的分类器的交叉熵损失之和得到,λp,λc,λdom,λdisc分别为Lp,Lc,Ldom,Ldisc的预设权重参数。
如图1所示,本发明中基于部分特征对齐的多源领域自适应方法包含如下步骤:
步骤1:数据预处理。本实例选取三个本领域使用较多的公开数据集进行实验,包括Digit-Five、Office31和DomainNet。
Digit-Five中收集了五个不同种类的手写数字识别数据集的子集,分别为MNIST-M、MNIST、USPS、SVHN和Synthetic Digits,其中USPS含有9298张图片,其余数据集均含有25000张训练图片和9000张测试图片。
Office31是一个传统的多源领域自适应数据集,包含4652张图片,31个类别。图像收集于办公室的环境,展现在三个域中:Amazon、Webcam和DSLR。
DomainNet是近年才提出的新数据集,并且是目前为止数量最大、最具挑战性的多源领域自适应数据集,总共包含六个域的数据:clipart、infograph、painting、quickdraw、real和sketch,其中每个域都包含有345个类别的图像。
选择以上三种数据集用于验证模型在不同类型和数量的数据集上的鲁棒性,三个数据集中类别的个数依次递增,域间差异也越来越大,挑战性也随之增加,可以比较好地反映模型的性能。数据的预处理过程包含简单的图像缩放以及随机翻转和裁剪等操作。
步骤2:使用特征提取模块提取图像基本特征。
对于Digit-Five数据集,图像大小缩放为32×32,特征提取模块选用三层卷积层和两层全连接层的卷积神经网络,卷积核大小均为5,全连接层的输出为2048维的特征向量;对于Office31数据集,图像大小为252×252,特征提取模块选用预训练的AlexNet,输出为4096维的特征向量;对于DomainNet数据集,图像大小为224×224,特征提取模块选用预训练的ResNet-101,输出为2048维的特征向量,用fT分别表示第i个源域和目标域的特征图。
在训练中,除Digit-Five的一批数据量为128张图片以外,其余数据集均使用一批16张图片的方式训练,因此训练时Digit-Five的特征图维度为128×2048,Office31的特征图维度为16×4096,DomainNet的特征图维度为16×2048。训练时,Digit-Five和Office31均训练100个epoch,DomainNet由于数据量非常大,只训练20个epoch。
步骤3:特征选择以及部分特征对齐损失函数计算。
计算源域特征图与目标域特征图的绝对值差异将其作为特征选择模块的输入,其输出为特征选择向量vi,然后计算源域提纯的特征图对于目标域,可以选择使用所有源域选择向量的均值作为选择向量,或者不进行目标域特征提纯,因为目标域的特征与自身已经是高度相关的,最终得到目标域的特征图用FT表示。
为计算部分特征对齐的损失函数,需要维护提纯特征每个域中每个类别的中心点,如前所述,中心点的维护方式如下:
在此基础上,计算所述类别内、域间和类别间的部分特征对齐损失函数。
如图3所示,类别内部分特征对齐损失具体形式如下:
其中fc表示某个域中某一类的中心点,fs表示对应域的对应类的部分特征样本点,在源域和目标域上均进行计算,目标域的标签使用当前分类器预测的伪标签来替代。
域间部分特征对齐损失具体形式如下:
类别间部分特征对齐损失具体形式如下:
将提纯后的特征图输入给两个分类器,获得分类概率,计算所有源域上的交叉熵损失Ls。
将前述所有损失函数按折衷参数求合,获得模型的整体损失函数:
L=Ls+λpLp+λcLc+λdomLdom+λdiscLdisc
其中Ls的计算使用两个分类器交叉熵损失之和。按此损失函数对整个模型所有模块参数进行更新,包括特征提取模块、部分特征提取的特征选择模块以及两个分类器的参数。
步骤4:特征提取模块与两个分类器对抗训练。
重复步骤2-3,获得源域和目标域提纯后的特征而不计算相关特征对齐损失函数,在此特征图上,使用两个分类器分别生成目标域的预测概率,计算两者经过softmax激活函数后的绝对值差异Ldis以及两个分类器在源域上的交叉熵损失Ls。此时固定特征提取模块的参数,使用Ls-Ldis作为损失函数去更新两个分类器,增大目标域分类概率差异,然后固定分类器参数,重新计算Ldis作为损失函数去更新特征提取模块,减小分类差异,实现特征提取器与两个分类器的对抗训练。其中,对于特征提取模块,损失函数Ldis的计算和参数的更新可以重复1~4次,该重复次数用于在特征提取器和分类器之间进行折衷。
步骤5:预测测试数据集的分类结果。
本发明提出的模型在上述步骤1中所提到的三种数据集上均进行了如步骤2-4所述的训练,并进行测试集的测试。实验结果显示,本发明提出的基于部分特征对齐的多源领域自适应模型在Digit-Five、Office31和DomainNet上的平均分类正确率分别为92.7%、84.6%和48%,其中Digit-Five和DomainNet的结果优于已有的多源领域自适应方法,Office31的结果也达到了前沿水平,说明本发明提出的模型能够有效地在原始特征图上进一步提取出源域中与目标域高度相关的部分特征,并通过类别内、域间和类别间的部分特征对齐损失实现部分特征图中相同类别的数据点相互聚集、源域与目标域对应同类别的数据中心点相互靠近、目标域中不同类别的数据点相互分散的目的。
为进一步验证本发明提出各个模块以及损失函数的有效性,在Digit-Five数据集上进行了剔除部分模块的实验,其中去除部分特征提取的特征选择模块后,分类正确率为90.9%;去除类别内部分特征对齐损失后,分类正确率为90.8%;去除域间部分特征对齐损失后,分类正确率为89.4%;去除类别间部分特征对齐损失后,分类正确率为90.8%;即去掉本发明提出的模型的不同模块后,正确率均有不同程度的下降。由此说明,本发明提出的部分特征提取的特征选择模块以及三种部分对齐损失函数在多个层面上对当前的针对全部特征对齐的多源领域自适应方法的改进是有效的。
以上所述,仅为本发明的具体实施方式,本说明书中所公开的任一特征,除非特别叙述,均可被其他等效或具有类似目的的替代特征加以替换;所公开的所有特征、或所有方法或过程中的步骤,除了互相排斥的特征和/或步骤以外,均可以任何方式组合。
Claims (2)
1.一种基于部分特征对齐的多源领域自适应模型,其特征在于,该多源领域自适应模型包括特征提取模块、部分特征提取的特征选择模块及其对应的损失函数、三种部分特征对齐损失函数和两个对抗训练的分类器,其中,三种部分特征对齐损失函数包括类别内部分特征对齐损失函数、域间部分特征对齐损失函数和类别间部分特征对齐损失函数;
其中,所述特征提取模块使用常规的卷积神经网络CNN或预训练的残差神经网络ResNet-101,分别针对简单图像和复杂图像,用于提取图像的初始特征图,多个源域和目标域的初始特征图之间分别计算L1距离后作为部分特征提取的特征选择模块的输入;所述卷积神经网络CNN使用三层卷积层和两层全连接层,最终得到特征维度为2048维,所述残差神经网络ResNet-101去掉最后一层全连接层后,获得特征维度为2048维;
所述部分特征提取的特征选择模块通过两层全连接层,使用所述L1距离作为输入,根据该距离分析源域与目标域相关的特征维度,所述部分特征提取的特征选择模块的输出为初始特征图的特征选择向量,将特征选择向量以点乘的方式作用于特征提取模块提取出的初始特征图,获得提纯后的特征图;
在提纯后的特征图上设计部分特征提取的特征选择模块的损失函数以及三种部分特征对齐损失函数,即在提纯后的特征图上分别计算部分特征提取的特征选择模块对应的损失函数、类别内部分特征对齐损失函数、域间部分特征对齐损失函数和类别间部分特征对齐损失函数,取上述所有损失函数的加权和作为除分类损失以外的所有损失函数;
所述部分特征提取的特征选择模块的损失函数如下:
其中分别表示提纯后源域和目标域的特征图,k表示矩的阶数,λreg是规范化的权重参数,N为一次批量训练中的样本数量,vi表示源域i的特征选择向量,为各个源域特征选择向量的平均值,表示某个源域的一批样本,表示目标域的一批样本,表示求期望值,G是常规的特征提取器,是提取出来得到的源域i的初始特征图,是提取出来得到的目标域的初始特征图;
所述三种部分特征对齐损失函数建立在提纯特征图的类别中心点的基础上,其定义如下:
其中fc表示某一个类别的中心点,F表示部分特征提取的过程,即经过F获得提纯后的特征图,n为一次批量训练中对应类别的样本数量;
为了保留所述基于部分特征对齐的多源领域自适应模型在前面训练所获得的信息,使用指数级流动平均值在每一次批训练中更新所有中心点:
所述类别内部分特征对齐损失函数如下:
其中fc表示某个域中某一类的中心点,fs表示对应域的对应类的部分特征样本点,k为矩的阶数;
所述域间部分特征对齐损失函数如下:
所述类别间部分特征对齐损失函数如下:
所述基于部分特征对齐的多源领域自适应模型整体的损失函数为:
L=Ls+λpLp+λcLc+λdomLdom+λdiscLdisc
其中Ls为源域的交叉熵分类损失,由所述两个对抗训练的分类器的交叉熵损失之和得到,λp,λc,λdom,λdisc分别为Lp,Lc,Ldom,Ldisc的预设权重参数。
2.一种基于部分特征对齐的多源领域自适应方法,采用如权利要求1所述的基于部分特征对齐的多源领域自适应模型实现,其特征在于,所述多源领域自适应方法包括如下步骤:
步骤1:数据预处理,选取三个公开数据集进行实验,包括Digit-Five、Office31和DomainNet;
Digit-Five中收集了五个不同种类的手写数字识别数据集的子集,分别为MNIST-M、MNIST、USPS、SVHN和Synthetic Digits,其中USPS含有9298张图片,其余数据集均含有25000张训练图片和9000张测试图片;
Office31是一个传统的多源领域自适应数据集,包含4652张图片,31个类别,图像收集于办公室的环境,展现在三个域中:Amazon、Webcam和DSLR;
DomainNet是近年才提出的新数据集,并且是目前为止数量最大、最具挑战性的多源领域自适应数据集,总共包含六个域的数据:clipart、infograph、painting、quickdraw、real和sketch,其中每个域都包含有345个类别的图像;
数据的预处理过程包含简单的图像缩放以及随机翻转和裁剪的操作;
步骤2:对预处理后的数据使用特征提取模块提取图像基本特征,对于Digit-Five数据集,图像大小缩放为32×32,特征提取模块选用三层卷积层和两层全连接层的卷积神经网络,卷积核大小均为5,全连接层的输出为2048维的特征向量;对于Office31数据集,图像大小为252×252,特征提取模块选用预训练的AlexNet,输出为4096维的特征向量;对于DomainNet数据集,图像大小为224×224,特征提取模块选用预训练的ResNet-101,输出为2048维的特征向量,用fT分别表示第i个源域和目标域的特征图;
在训练中,除Digit-Five的一批数据量为128张图片以外,其余数据集均使用一批16张图片的方式训练,因此训练时Digit-Five的特征图维度为128×2048,Office31的特征图维度为16×4096,DomainNet的特征图维度为16×2048;训练时,Digit-Five和Office31均训练100个epoch,DomainNet由于数据量非常大,只训练20个epoch;
步骤3:特征选择以及部分特征对齐损失函数计算,计算源域特征图与目标域特征图的绝对值差异将其作为部分特征提取的特征选择模块的输入,其输出为特征选择向量vi,然后计算源域提纯的特征图对于目标域,使用所有源域选择向量的均值作为选择向量,最终得到目标域的特征图用FT表示;
为计算部分特征对齐的损失函数,需要维护提纯特征每个域中每个类别的中心点,中心点的维护方式如下:
在此基础上,计算所述类别内、域间和类别间部分特征对齐损失函数,
类别内部分特征对齐损失函数具体形式如下:
其中fc表示某个域中某一类的中心点,fs表示对应域的对应类的部分特征样本点,在源域和目标域上均进行计算,目标域的标签使用当前分类器预测的伪标签来替代;
域间部分特征对齐损失函数具体形式如下:
类别间部分特征对齐损失函数具体形式如下:
将提纯后的特征图输入给所述两个对抗训练的分类器,获得分类概率,计算所有源域上的交叉熵损失Ls,
将前述所有损失函数按折衷参数求合,获得模型的整体损失函数:
L=Ls+λpLp+λcLc+λdomLdom+λdiscLdisc
其中Ls的计算使用所述两个对抗训练的分类器交叉熵损失之和,λp,λc,λdom,λdisc分别为Lp,Lc,Ldom,Ldisc的预设权重参数,按损失函数L对整个基于部分特征对齐的多源领域自适应模型所有模块参数进行更新,包括特征提取模块、部分特征提取的特征选择模块以及两个对抗训练的分类器的参数;
步骤4:特征提取模块与两个分类器对抗训练,重复步骤2-步骤3,获得源域和目标域提纯后的特征而不计算相关特征对齐损失函数,在此特征图上,使用两个分类器分别生成目标域的预测概率,计算两者经过softmax激活函数后的绝对值差异Ldis以及两个分类器在源域上的交叉熵损失Ls,此时固定特征提取模块的参数,使用Ls-Ldis作为损失函数去更新两个分类器,增大目标域分类概率差异,然后固定两个分类器参数,重新计算Ldis作为损失函数去更新特征提取模块,减小分类差异,实现特征提取器与两个分类器的对抗训练,其中,对于特征提取模块,损失函数Ldis的计算和参数的更新重复4次,该重复次数用于在特征提取器和分类器之间进行折衷;
步骤5:预测测试数据集的分类结果,使用所述基于部分特征对齐的多源领域自适应模型在所述步骤1中所提到的三种数据集上均进行了如步骤2-4所述的训练,并进行测试集的测试。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011223578.5A CN112308158B (zh) | 2020-11-05 | 2020-11-05 | 一种基于部分特征对齐的多源领域自适应模型及方法 |
US17/519,604 US11960568B2 (en) | 2020-11-05 | 2021-11-05 | Model and method for multi-source domain adaptation by aligning partial features |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011223578.5A CN112308158B (zh) | 2020-11-05 | 2020-11-05 | 一种基于部分特征对齐的多源领域自适应模型及方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112308158A CN112308158A (zh) | 2021-02-02 |
CN112308158B true CN112308158B (zh) | 2021-09-24 |
Family
ID=74325116
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011223578.5A Active CN112308158B (zh) | 2020-11-05 | 2020-11-05 | 一种基于部分特征对齐的多源领域自适应模型及方法 |
Country Status (2)
Country | Link |
---|---|
US (1) | US11960568B2 (zh) |
CN (1) | CN112308158B (zh) |
Families Citing this family (37)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113011513B (zh) * | 2021-03-29 | 2023-03-24 | 华南理工大学 | 一种基于通用域自适应的图像大数据分类方法 |
CN113420775B (zh) * | 2021-03-31 | 2024-03-29 | 中国矿业大学 | 基于非线性度自适应子域领域适应的极少量训练样本下图片分类方法 |
CN113436197B (zh) * | 2021-06-07 | 2022-10-04 | 华东师范大学 | 基于生成对抗和类特征分布的域适应无监督图像分割方法 |
CN113221848B (zh) * | 2021-06-09 | 2022-07-19 | 中国人民解放军国防科技大学 | 基于多分类器域对抗网络的高光谱开放集领域自适应方法 |
CN113313255A (zh) * | 2021-06-18 | 2021-08-27 | 东南大学 | 一种基于神经网络架构搜索的无监督领域自适应方法 |
CN113537292B (zh) * | 2021-06-18 | 2024-02-09 | 杭州电子科技大学 | 一种基于张量化高阶互注意力机制的多源域适应方法 |
CN113610105A (zh) * | 2021-07-01 | 2021-11-05 | 南京信息工程大学 | 基于动态加权学习和元学习的无监督域适应图像分类方法 |
CN113688867B (zh) * | 2021-07-20 | 2023-04-28 | 广东工业大学 | 一种跨域图像分类方法 |
CN114065852B (zh) * | 2021-11-11 | 2024-04-16 | 合肥工业大学 | 基于动态权重的多源联合自适应和内聚性特征提取方法 |
CN114693972B (zh) * | 2022-03-29 | 2023-08-29 | 电子科技大学 | 一种基于重建的中间域领域自适应方法 |
CN114926877B (zh) * | 2022-05-10 | 2024-02-20 | 西北工业大学 | 一种基于对比域差异的跨域人脸表情识别方法 |
CN115082725B (zh) * | 2022-05-17 | 2024-02-23 | 西北工业大学 | 基于可靠样本选择和双分支动态网络的多源域自适应方法 |
CN114882220B (zh) * | 2022-05-20 | 2023-02-28 | 山东力聚机器人科技股份有限公司 | 基于域自适应先验知识引导gan的图像生成方法及系统 |
CN114972920B (zh) * | 2022-05-30 | 2024-03-12 | 西北工业大学 | 一种多层次无监督领域自适应目标检测识别方法 |
CN114694150B (zh) * | 2022-05-31 | 2022-10-21 | 成都考拉悠然科技有限公司 | 一种提升数字图像分类模型泛化能力的方法及系统 |
CN115081480B (zh) * | 2022-06-23 | 2024-03-29 | 中国科学技术大学 | 一种多源共迁移跨用户的肌电模式分类方法 |
CN114818996B (zh) * | 2022-06-28 | 2022-10-11 | 山东大学 | 基于联邦域泛化的机械故障诊断方法及系统 |
CN115329853B (zh) * | 2022-08-04 | 2023-04-21 | 西南交通大学 | 一种基于多源域迁移的装备参数预测与知识转移方法 |
CN115131590B (zh) * | 2022-09-01 | 2022-12-06 | 浙江大华技术股份有限公司 | 目标检测模型的训练方法、目标检测方法及相关设备 |
CN115186773B (zh) * | 2022-09-13 | 2022-12-09 | 杭州涿溪脑与智能研究所 | 一种无源的主动领域自适应模型训练方法及装置 |
CN115410088B (zh) * | 2022-10-10 | 2023-10-31 | 中国矿业大学 | 一种基于虚拟分类器的高光谱图像领域自适应方法 |
CN115310727B (zh) * | 2022-10-11 | 2023-02-03 | 山东建筑大学 | 一种基于迁移学习的建筑冷热电负荷预测方法及系统 |
CN115601560A (zh) * | 2022-10-28 | 2023-01-13 | 广东石油化工学院(Cn) | 一种基于自适应网络的参数更新方法 |
CN115791174B (zh) * | 2022-12-29 | 2023-11-21 | 南京航空航天大学 | 一种滚动轴承异常诊断方法、系统、电子设备及存储介质 |
CN116312860B (zh) * | 2023-03-24 | 2023-09-12 | 江南大学 | 基于监督迁移学习的农产品可溶性固形物预测方法 |
CN116229080B (zh) * | 2023-05-08 | 2023-08-29 | 中国科学技术大学 | 半监督域适应图像语义分割方法、系统、设备及存储介质 |
CN116340833B (zh) * | 2023-05-25 | 2023-10-13 | 中国人民解放军海军工程大学 | 基于改进领域对抗式迁移网络的故障诊断方法 |
CN116340852B (zh) * | 2023-05-30 | 2023-09-15 | 支付宝(杭州)信息技术有限公司 | 一种模型训练、业务风控的方法及装置 |
CN116484905B (zh) * | 2023-06-20 | 2023-08-29 | 合肥高维数据技术有限公司 | 针对非对齐样本的深度神经网络模型训练方法 |
CN116502644B (zh) * | 2023-06-27 | 2023-09-22 | 浙江大学 | 一种基于无源领域自适应的商品实体匹配方法及装置 |
CN116883735B (zh) * | 2023-07-05 | 2024-03-08 | 江南大学 | 基于公有特征和私有特征的域自适应小麦种子分类方法 |
CN116543269B (zh) * | 2023-07-07 | 2023-09-05 | 江西师范大学 | 基于自监督的跨域小样本细粒度图像识别方法及其模型 |
CN116883681B (zh) * | 2023-08-09 | 2024-01-30 | 北京航空航天大学 | 一种基于对抗生成网络的域泛化目标检测方法 |
CN117132841B (zh) * | 2023-10-26 | 2024-03-29 | 之江实验室 | 一种保守渐进的领域自适应图像分类方法和装置 |
CN117253097B (zh) * | 2023-11-20 | 2024-02-23 | 中国科学技术大学 | 半监督域适应图像分类方法、系统、设备及存储介质 |
CN117435916B (zh) * | 2023-12-18 | 2024-03-12 | 四川云实信息技术有限公司 | 航片ai解译中的自适应迁移学习方法 |
CN117576765B (zh) * | 2024-01-15 | 2024-03-29 | 华中科技大学 | 一种基于分层特征对齐的面部动作单元检测模型构建方法 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110135579A (zh) * | 2019-04-08 | 2019-08-16 | 上海交通大学 | 基于对抗学习的无监督领域适应方法、系统及介质 |
US20200151457A1 (en) * | 2018-11-13 | 2020-05-14 | Nec Laboratories America, Inc. | Attention and warping based domain adaptation for videos |
CN111340021A (zh) * | 2020-02-20 | 2020-06-26 | 中国科学技术大学 | 基于中心对齐和关系显著性的无监督域适应目标检测方法 |
-
2020
- 2020-11-05 CN CN202011223578.5A patent/CN112308158B/zh active Active
-
2021
- 2021-11-05 US US17/519,604 patent/US11960568B2/en active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20200151457A1 (en) * | 2018-11-13 | 2020-05-14 | Nec Laboratories America, Inc. | Attention and warping based domain adaptation for videos |
CN110135579A (zh) * | 2019-04-08 | 2019-08-16 | 上海交通大学 | 基于对抗学习的无监督领域适应方法、系统及介质 |
CN111340021A (zh) * | 2020-02-20 | 2020-06-26 | 中国科学技术大学 | 基于中心对齐和关系显著性的无监督域适应目标检测方法 |
Non-Patent Citations (3)
Title |
---|
aligning domain-specific distribution and classifier for cross-domain classification from multiple sources;yongchun zhu et al;《the thirty-third AAAI conference on artificial intelligence》;20191231;5989-5996 * |
learning domain-invariant and discriminative features for homogeneous unsupervised domain adaptation;zhang yun et al;《Chinese Journal of Electronics》;20201101;第29卷(第6期);1119-1125 * |
具有特征选择的多源自适应分类框架;黄学雨 等;《计算机应用》;20200910;第40卷(第9期);2499-2506 * |
Also Published As
Publication number | Publication date |
---|---|
US11960568B2 (en) | 2024-04-16 |
US20220138495A1 (en) | 2022-05-05 |
CN112308158A (zh) | 2021-02-02 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112308158B (zh) | 一种基于部分特征对齐的多源领域自适应模型及方法 | |
CN109993100B (zh) | 基于深层特征聚类的人脸表情识别的实现方法 | |
CN111126386B (zh) | 场景文本识别中基于对抗学习的序列领域适应方法 | |
CN106682694A (zh) | 一种基于深度学习的敏感图像识别方法 | |
CN111696101A (zh) | 一种基于SE-Inception的轻量级茄科病害识别方法 | |
CN109344759A (zh) | 一种基于角度损失神经网络的亲属识别方法 | |
CN111414461A (zh) | 一种融合知识库与用户建模的智能问答方法及系统 | |
CN109741341A (zh) | 一种基于超像素和长短时记忆网络的图像分割方法 | |
CN109886161A (zh) | 一种基于可能性聚类和卷积神经网络的道路交通标识识别方法 | |
CN110414587A (zh) | 基于渐进学习的深度卷积神经网络训练方法与系统 | |
CN112784921A (zh) | 任务注意力引导的小样本图像互补学习分类算法 | |
CN112784929A (zh) | 一种基于双元组扩充的小样本图像分类方法及装置 | |
CN111126155B (zh) | 一种基于语义约束生成对抗网络的行人再识别方法 | |
CN110991554B (zh) | 一种基于改进pca的深度网络图像分类方法 | |
CN116152554A (zh) | 基于知识引导的小样本图像识别系统 | |
CN112084895A (zh) | 一种基于深度学习的行人重识别方法 | |
CN111310820A (zh) | 基于交叉验证深度cnn特征集成的地基气象云图分类方法 | |
Hou et al. | A face detection algorithm based on two information flow block and retinal receptive field block | |
López-Cifuentes et al. | Attention-based knowledge distillation in scene recognition: the impact of a dct-driven loss | |
Sun et al. | Deep learning based pedestrian detection | |
Rahman et al. | A CNN Model-based ensemble approach for Fruit identification using seed | |
CN109583406B (zh) | 基于特征关注机制的人脸表情识别方法 | |
CN115439791A (zh) | 跨域视频动作识别方法、装置、设备和计算机可存储介质 | |
CN115546474A (zh) | 一种基于学习者集成策略的少样本语义分割方法 | |
CN115100694A (zh) | 一种基于自监督神经网络的指纹快速检索方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |