CN108647736A - 一种基于感知损失和匹配注意力机制的图像分类方法 - Google Patents
一种基于感知损失和匹配注意力机制的图像分类方法 Download PDFInfo
- Publication number
- CN108647736A CN108647736A CN201810468906.4A CN201810468906A CN108647736A CN 108647736 A CN108647736 A CN 108647736A CN 201810468906 A CN201810468906 A CN 201810468906A CN 108647736 A CN108647736 A CN 108647736A
- Authority
- CN
- China
- Prior art keywords
- feature
- network
- target
- targetcnn
- source data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 230000008447 perception Effects 0.000 title claims abstract description 35
- 230000007246 mechanism Effects 0.000 title claims abstract description 26
- 238000000034 method Methods 0.000 title claims abstract description 18
- 238000000605 extraction Methods 0.000 claims abstract description 67
- 230000006870 function Effects 0.000 claims abstract description 37
- 238000012549 training Methods 0.000 claims abstract description 26
- 238000009826 distribution Methods 0.000 claims abstract description 23
- 238000005457 optimization Methods 0.000 claims abstract description 19
- 239000000284 extract Substances 0.000 claims abstract description 7
- 238000010276 construction Methods 0.000 claims description 19
- 239000013598 vector Substances 0.000 claims description 8
- 239000012141 concentrate Substances 0.000 claims description 7
- 238000004364 calculation method Methods 0.000 claims description 6
- 239000004576 sand Substances 0.000 claims description 6
- 238000005303 weighing Methods 0.000 abstract description 4
- 239000010410 layer Substances 0.000 description 34
- 238000013528 artificial neural network Methods 0.000 description 5
- 239000002356 single layer Substances 0.000 description 4
- 238000013527 convolutional neural network Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 230000006872 improvement Effects 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 230000001537 neural effect Effects 0.000 description 2
- 238000010606 normalization Methods 0.000 description 2
- 238000013526 transfer learning Methods 0.000 description 2
- 238000013459 approach Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000004069 differentiation Effects 0.000 description 1
- 230000001617 migratory effect Effects 0.000 description 1
- 239000000203 mixture Substances 0.000 description 1
- 230000007704 transition Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2413—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on distances to training or reference patterns
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Abstract
本发明公开了一种基于感知损失和匹配注意力机制的图像分类方法,包括:步骤1,针对目标数据集和源数据集分别设计一个基于卷积核的特征提取网络;步骤2,在特征提取网络中间层之间增加匹配注意力机制。步骤3,设计一个针对提取出的特征的分类网络。步骤4利用带类标的源数据集训练步骤1~步骤3构建的整个网络,利用训练出的参数作为网络下一阶段训练的初始化。步骤5,利用训练好的分类器构造一个感知损失函数,感知函数用于衡量特征提取网络提取出特征分布的距离。步骤6,再设计判别器网络用于区分特征提取网络提取出特征分布的不同。步骤7,优化步骤1~步骤3构建的网络;步骤8,得到图像的分类结果。
Description
技术领域
本发明属于计算机视觉领域,尤其涉及一种基于感知损失和匹配注意力机制的图像分类方法。
背景技术
领域迁移方法就是尝试着去解决在目标集数据集标签是很少的情况下,缓解模型由于数据偏移造成的负面影响。具体说来,最近的迁移学习的方法期望学习出一个能将不同分布的数据集映射到同一个特征空间的深度神经网络。对抗迁移方法也被广泛用于处理两个特征分布不一致的问题。在目标数据集上特征提取器相当于生成器,旨在提取出和源数据集上分布一样的特征,判别器旨在将源数据集和目标数据集上提取的特征区分开来。判别器无法判断出一个数据是来自真实分布还是生成分布。但是通过反向传播算法优化很难达到Nash均衡。另外普遍存在的现象是生成器容易出现数据分布局部或者整体崩塌,也就是生成的数据太过单一化。因此,仅仅利用判别器思想判别两个分布是否相似会显得不稳定。
在网络结构方面,普遍认为卷积神经网络的低层提取到的是图像的低层信息,然而在中层以上能学习出抽象的特征例如物体局部复杂的纹理。所以,图像的中层特征与高层更抽象的分类特征更加信息相关,因此,利用中层的更多的细节信息,会有利于在高层判断两个特征是否相似,匹配门机制和注意力机制能够动态地选取特征关键的部分,而不是对特征中每一个维度同等看待。有利于促进相似的局部特征传递到网络高层,从而使得网络网络中层特征能产生与任务更相关的特征
发明内容
发明目的:深度学习中大部分的模型是在庞大的数据集上面训练的,无法泛化到新的数据集和新的任务中,领域迁移方法就是尝试着去解决在目标集数据集标签是很少的情况下,缓解模型由于数据偏移造成的负面影响。本文结合匹配门函数和注意力机制,提出了匹配注意力机制,有利于学习到与任务更相关的特征,增强了特征提取网络的表达能力。另外我们在损失函数设计方面,提取了利用感知损失的思想来衡量两个特征分布情况缓解了仅仅基于判别器网络判别特征分布的不稳定行为。
本发明具体包括如下步骤:
步骤1,针对源数据集和目标数据集,分别构造两个基于卷积核的特征提取器,针对源数据集XS、源数据集中的图像对应的类别标签信息集合YS构造源数据集特征提取网络SourceCNN,针对目标数据集XT构造目标特征提取网络TargetCNN,SourceCNN和TargetCNN可以为任意卷积神经网络,比如有3层卷积层后接2层全连接层组成的,SourceCNN用于提取源图像集中图像的可判别的特征,TargetCNN用于提取目标数据集中图像的可判别的特征,,用xs表示源数据集XS中第s张图像,ys表示xs对应的类别标签信息(如数字1,数字2等),ys∈YS;用xt表示目标数据集XT中第t张图像,其不包含类别标签信息;
步骤2,根据源数据集,目标数据集和提取的特征构造分类器网络Classifier(分类器网络的输入为SourceCNN和TargetCNN的输出,分类器网络的输出为源数据集的类别数目。比如可以根据输入输出维度构造一个单层神经网络),分类器网络Classifier用于对源数据集图像特征和目标数据集图像特征进行分类;
步骤3,在源数据集特征提取网络SourceCNN和目标特征提取网络TargetCNN中间层增加匹配注意力机制Matching Attention;
步骤4,利用源数据集训练分类器网络Classifier,得到训练好的分类器网络Classifier;
步骤5,根据步骤4训练得到的分类器网络Classifier,构造感知损失函数,得到感知损失;感知损失函数用于判断两个特征分布之间的距离,将源数据特征和目标数据特征分布输入到分类器网络得到一个针对类别的平均概率,两个平均概率分布之间的KL-散度即为感知损失;
步骤6,构造判别器网络,判别器网络用于判断输入特征来自源数据集特征提取网络SourceCNN的输出还是目标特征提取网络TargetCNN的输出,对于源数据集特征提取网络SourceCNN的输出,优化判别器网络让其输出概率1;对于目标特征提取网络TargetCNN的输出,优化判别器网络让其输出概率0;
步骤7,利用感知损失和判别器网络优化源数据集特征提取网络SourceCNN、目标特征提取网络TargetCNN和匹配注意力机制Matching Attention,并同时利用源数据集特征提取网络SourceCNN和目标特征提取网络TargetCNN的输出特征训练判别器网络;
步骤8,在步骤7的训练阶段完成之后,对无类别标签信息目标数据集进行分类,构造输入对,目标特征提取网络TargetCNN的输出即为对目标数据集提取特征,分类器网络Classifier的输出即为预测的图像分类类别。
匹配注意力机制旨在通过在源数据集特征提取网络SourceCNN和目标特征提取网络TargetCNN特征层中间添加了一个桥,即目标特征提取网络TargetCNN中间层的提取到的特征可以根据源数据集特征提取网络SourceCNN和目标特征提取网络TargetCNN特征差值调整所学到特征的权重,步骤3包括:
步骤3-1:记源数据集特征提取网络SouceCNN隐藏层(SourceCNN是一个多层的卷积神经网络,中间层称之为隐藏层)j-1的特征为目标特征提取网络TargetCNN隐藏层j-1的特征为对于和通过同一个特征概括网络fsumm(比如单层神经网络)对源数据集特征和目标数据集特征重新编码表示,分别得到源数据集新的特征和目标数据集新的特征
步骤3-2:根据如下公式计算步骤3-1得出的新的特征之间的相似度gj-1:
步骤3-3:对后续隐藏层j层特征不同维度,根据步骤3-2的相似度学习出不同的权重:
其中,c为特征和的维度,表示源数据集特征每一维度计算注意力的函数,所占的比重表示目标数据集特征每一维度计算注意力的函数,代表针对源数据集特征每一维度计算出的权重,代表针对目标数据集特征每一维度计算出的权重,表示根据权重重新计算出的特征,表示根据权重重新计算出的特征。
步骤4中,分类器网络包含四个组件:SourceCNN,TargetCNN,MatchingAttention,Classifier,利用源数据集XS随机取样构造训练对(源数据集图像xs,源数据集图像x's),xs和x's为源数据集中随机取样的图像,它们对应的类标为(ys,y's),训练分类器网络,最小化分类误差,损失函数定义为:
其中,数据C(fs)表示分类器网络对fs分类的结果,K为类别标签信息集合中类别标签的总数,k取值为1~K,C(f's)表示分类器网络对f's分类的结果,fs表示步骤1中源数据集特征提取网络SourceCNN对xs的提取出的特征,源数据集特征提取网络SourceCNN对xs的提取出的所有特征的集合记为FS;f's表示步骤1中目标特征提取网络TargetCNN对x's提取出的特征,目标特征提取网络TargetCNN提取出的所有特征的集合记为FT,log为对数函数,xs的存在是为了训练SourceCNN和Classifier,x's的存在是为了为TargetCNN和MatchingAttention找到比较好的初始化值。
步骤5中,对于源数据集特征提取网络SourceCNN提取出的特征集FS和目标特征提取网络TargetCNN提取出的特征集FT,利用步骤4训练好的分类器网络Classifier构造感知损失函数衡量FS与FT之间的距离,感知损失函数公式如下:
其中,Constant为一个常数,pi表示对于源数据集特征提取网络SouceCNN所提取到的所有特征经过分类器网络Classifier之后,属于类别i的平均概率;qi表示对于目标特征提取网络TargetCNN所提取到的所有特征经过分类网络Classifier之后,属于类别i的平均概率,两个平均概率分布之间的KL-散度即为感知损失。
pi、qi计算公式如下:
步骤7包括:联合步骤5提出的感知损失函数和步骤6提出的判别器网络D,对于源数据集特征提取网络SourceCNN提取到的源特征集FS,判别器网络D期望将其和目标特征提取网络TargetCNN提取到的目标特征集FT区分开,再次优化步骤1~步骤3提出的四个组件SourceCNN,TargetCNN,Matching Attention,Classifier,判别器网络D和四个组件之间的优化是交替进行,优化公式如下:
其中,k取值为1~K,表示当ys=k时,返回1,其余返回0,表示对于源特征集FS分类误差,表示判别器网络D对源特征集FS和目标特征集FT的区分度,表示目标特征集FT生成的特征欺骗判别器网络D的程度,α,β,γ分别表示优化过程中所占的比重。
步骤8包括:构造输入对(xs,xt),得到目标特征提取网络TargetCNN的输出ft,再经过分类器网络Classifier得到xt分类结果。
本发明将匹配门函数和注意力机制合成起来提出了匹配注意力机制,有利于深度神经网络学习到任务更相关的特征,同时将用于衡量两张图像之间距离的感知损失函数扩展到衡量两个复杂、不可描述的特征分布之间距离,并第一次将它们用在了迁移学习领域,在无类标目标数据集上取得到了更高的准确度。
有益效果:本发明解决了现有技术中的问题:如何进一步提高网络表达能力和缓解对抗损失不稳定的问题。本文通过将匹配门函数和注意力机制结合起来在网络的中间层增加匹配注意力机制,从卷积神经网络的中层开始,增加一个匹配门门函数用于比较两张图像中提取出来特征的相似情况。匹配门门函数的存在有利于促进相似的局部特征传递到网络高层,从而使得网络网络中层特征能产生与任务更相关的特征。同时对于对抗损失训练不稳定的问题,我们额外增加了特征级的感知损失函数用于衡量两个特征分布之间的距离。利用我们提出的两点改进,我们在领域迁移方向,分类的准确度取得了明显提升。
附图说明
下面结合附图和具体实施方式对本发明做更进一步的具体说明,本发明的上述或其他方面的优点将会变得更加清楚。
图1为TargetCNN和SourceCNN中可选的网络结构图,其中FC表示全连接层,CONV表示卷积层,Residual Attention model表示残差注意力机制。
图2为整个方法的流程图。
具体实施方式
下面结合附图及实施例对本发明做进一步说明。
本发明适用与应对带类标的源数据集和目标数据集分布存在差异时,如何利用好在源数据集上训练的分类模型,来提高带少量类标或者缺失类标的目标数据集上的分类准确度。本发明提出了新的领域迁移网络和新的优化函数。1)在源数据集特征提取网络(SourceCNN)和目标特征提取网络(TargetCNN)中间层增加匹配注意力机制(MatchingAttention)。用于比较从两张图像中提取出来特征的相似情况。匹配门门函数的存在有利于促进相似的局部特征传递到网络高层,从而使得网络网络中层特征能产生与任务更相关的特征。在深度学习中提取出的特征的好坏直接影响到后面的分类结果。2)将用于判断两个图像之间的感知损失思想,扩展到了利用感知思想判断在源数据集特征提取网络(SourceCNN)提取出的特征集fs∈FS和目标特征提取网络(TargetCNN)提取出的特征集ft∈FT之间感知损失。提出了损失函数3)在预测目标数据集图像分类准确度时候,我们从源数据集中任意选择一张源图像,联合目标图像构造输入对(xs,xt),对于一张输入图像,存在多种预测结果,集成多种预测结果有利于提高分类准确度。
如图2所示,本发明包括如下步骤:
包括如下步骤:
步骤1,针对源数据集xs∈XS和目标数据集xt∈XT,分别构造两个基于卷积核的特征提取器(SourceCNN和TargetCNN)。对于源图像集中图像,SourceCNN能提取出可判别的特征,对于目标数据集中图像,TargetCNN能提取出可判别的特征。图1为TargetCNN和SourceCNN中可选的网络结构图,其中一共提供了3个可选择的网络结构:层数比较浅的模型(base model),层数比较深的模型(deeper model),带注意力机制的模型(attentionmodel),FC表示全连接层,CONV表示卷积层,Residual Attention model表示残差注意力机制,maxpool表示最大池化层,flatten表示将多维输入一维化,用于卷积层与全连接层之间的过渡。
步骤2,根据步骤1中构造的特征提取器提取到的特征维数和源数据集,目标数据集和的维数构造分类器网络(Classifier)。分类器的目的是对源数据集图像特征和目标数据集图像特征进行准确分类。
步骤3,在源特征提取器(SourceCNN)和目标特征提起器(TargetCNN)中间层增加匹配注意力机制(Matching Attention),匹配注意力模块旨根据源数据集特征和目标数据集特征之间的差异,调整中间层特征值的权重。
步骤4,利用带标签的源数据集训练步骤1-3构造的网络,即最大化网络在源数据集上的分类准确度,在训练的过程中需要根据源数据集构造训练对(源数据集1,源数据集2),源数据集1的存在是为了训练SourceCNN和Classifier,源数据集2的存在是为了为TargetCNN和Matching Attention找到比较好的初始化值。
步骤5,根据步骤4训练得到的分类器网络,构造感知损失函数,感知损失函数用于判断两个特征分布之间的距离。将源和目标特征向量分布输入到分类器网络得到一个针对类别的平均概率。两个多元概率分布之间的KL-散度即为感知损失。
步骤6,根据生成对抗网络的思想,我们也增加了判别器网络用于判断输入特征来自SourceCNN的输出还是TargetCNN的输出。对于SourceCNN的输出,我们优化判别器让其输出概率1,然而对于TargetCNN的输出,我们优化判别器输出概率0。
步骤7,联合步骤5和步骤6,以(源数据集,目标数据集)为训练对,输入SourceCNN和TargetCNN,利用感知损失和判别器网络优化特征提取网络(TargetCNN和MatchingAttention),并同时利用SourceCNN和TargetCNN的输出特征训练判别器网络。
步骤8,在步骤7的训练阶段完成之后,模型就可以对无类别标签信息目标数据集进行分类,构造输入对(源数据集,目标数据集),TargetCNN的输出即为对目标数据集提取特征,通过Classifier的输出即为预测的图像分类类别。
匹配注意力机制旨在通过在源特征提取器网络(SourceCNN)和目标特征提取器网络(TargetCNN)特征层中间添加了一个桥,即TargetCNN中间层的提取到的特征可以根据SourceCNN和TargetCNN特征差值调整所学到特征的权重。记SouceCNN和TargetCNN隐藏层j-1的特征为和步骤3包括:
步骤3-1:特征概括,即对于和通过同一个特征概括网络(比如单层神经网络)对源数据集特征和目标数据集特征重新编码表示。
步骤3-2:计算步骤3-1得出的新的特征之间的相似度(匹配阶段)。
步骤3-3:对后续隐藏层j层特征不同维度,根据步骤3-2的相似度学习出不同的权重(注意力阶段)。后续两个步骤如下:
上述公式中,gj-1为相似度计算(也可以换成高斯核函数),表示根据相似度注意力函数,对所学到的注意力权重归一化,表示根据学到的权重对特征重新赋值。
整个网络包含四个组件(SourceCNN,TargetCNN,Matching Attention,Classifier),利用源数据集构造训练对(源数据集图像xs,源数据集图像x's),它们对应的类标为(ys,y's),训练整个网络,最小化分类误差。损失函数定义为:
C表示分类器对fs和ft分类的结果,log为对数函数。xs的存在是为了训练SourceCNN和Classifier,x's的存在是为了为TargetCNN和Matching Attention找到比较好的初始化值。
步骤5,对于两个SourceCNN和TargetCNN提取出的特征集fs∈FS和ft∈FT,为了衡量FS与FT之间的距离,我们提出了利用步骤3训练好的分类器网络构造感知损失函数公式如下:
上述式子中,Constant为一个常数,pi与qi分别代表着对于SouceCNN和TargetCNN所提取到的所有特征向量经过分类网络Classifier之后,属于类别i的平均概率。计算公式为:
c表示分类数目。
联合步骤5提出的感知损失函数和步骤6提出的判别器网络,判别器网络为D,对于SourceCNN提取到的特征集FS,D期望将其和TargetCNN提取到的特征集FT区分开,再次微调步骤1到步骤3提出的三个组件(SourceCNN,TargetCNN,Matching Attention),根据生成对抗网络的思想,判别器网络和四个组件之间的优化是交替进行,优化公式如下:
k取值为1~K,表示当ys=k时,返回1,其余返回0,表示对于源特征集分类误差,表示判别器对源特征集和目标特征集特征的区分度。表示目标特征集生成的特征欺骗判别器D的程度,α,β,γ分别表示不同损失之间优化权重比。
在步骤7微调完(SourceCNN,TargetCNN,Matching Attention和Classifier)之后,在预测目标数据集xs∈XS每张图像的分类结果方面。我们需要构造输入对(xs,xt),得到TargetCNN的输出ft,进一步经过分类器Classifier得到分类结果。第二阶段的训练如图2所示。
实施例:
本发明采用上述方案,实现了将SVHN(Street View House Numbers,门牌号数据集)训练出的分类模型迁移到MNIST(小型手写体数据集)上的领域迁移任务。
具体实现如下:首先按照步骤1到步骤3构造整个网络,然后利用带标签的SVHN数据集训练分类任务。当SourceCNN、TargetCNN、Matching Attention和Classifier训练完成之后,保持Classifier的参数不变,构造感知损失函数。联合感知损失和判别器网络在SVHN和MNIST数据集上微调TargetCNN和Matching Attention模块,最后,利用再对目标数据集上图像分类。
本实施例包括如下步骤:
步骤1,针对源数据集xs∈XS和目标数据集xt∈XT,分别构造两个基于卷积核的特征提取器(SourceCNN和TargetCNN)。对于源图像集中图像,SourceCNN能提取出可判别的特征,对于目标数据集中图像,TargetCNN能提取出可判别的特征。
步骤2,根据步骤1中构造的特征提取器提取到的特征维数和源数据集,目标数据集和的维数构造分类器网络(Classifier)。分类器的目的是对源数据集图像特征和目标数据集图像特征进行准确分类。
步骤3,在源特征提取器(SourceCNN)和目标特征提起器(TargetCNN)中间层增加匹配注意力机制(Matching Attention),匹配注意力模块旨根据源数据集特征和目标数据集特征之间的差异,调整中间层特征值的权重。
步骤4,利用带标签的源数据集训练步骤1-3构造的网络,即最大化网络在源数据集上的分类准确度,在训练的过程中需要根据源数据集构造训练对(源数据集1,源数据集2),源数据集1的存在是为了训练SourceCNN和Classifier,源数据集2的存在是为了为TargetCNN和Matching Attention找到比较好的初始化值。
步骤5,根据步骤4训练得到的分类器网络,构造感知损失函数,感知损失函数用于判断两个特征分布之间的距离。将源和目标特征向量分布输入到分类器网络得到一个针对类别的平均概率。两个多元概率分布之间的KL-散度即为感知损失。
步骤6,根据生成对抗网络的思想,我们也增加了判别器网络用于判断输入特征来自SourceCNN的输出还是TargetCNN的输出。对于SourceCNN的输出,我们优化判别器让其输出概率1,然而对于TargetCNN的输出,我们优化判别器输出概率0。
步骤7,联合步骤5和步骤6,以(源数据集,目标数据集)为训练对,输入SourceCNN和TargetCNN,利用感知损失和判别器网络优化特征提取网络(TargetCNN和MatchingAttention),并同时利用SourceCNN和TargetCNN的输出特征训练判别器网络。
步骤8,在步骤7的训练阶段完成之后,模型就可以对无类别标签信息目标数据集进行分类,构造输入对(源数据集,目标数据集),TargetCNN的输出即为对目标数据集提取特征,通过Classifier的输出即为预测的图像分类类别。
匹配注意力机制旨在通过在源特征提取器网络(SourceCNN)和目标特征提取器网络(TargetCNN)特征层中间添加了一个桥,即TargetCNN中间层的提取到的特征可以根据SourceCNN和TargetCNN特征差值调整所学到特征的权重。记SouceCNN和TargetCNN隐藏层j-1的特征为和步骤3包括:
步骤3-1:特征概括,即对于和通过同一个特征概括网络(比如单层神经网络)对源数据集特征和目标数据集特征重新编码表示。
步骤3-2:计算步骤3-1得出的新的特征之间的相似度(匹配阶段)。
步骤3-3:对后续隐藏层j层特征不同维度,根据步骤3-2的相似度学习出不同的权重(注意力阶段)。后续两个步骤如下:
上述公式中,gj-1为相似度计算(也可以换成高斯核函数),表示根据相似度注意力函数,对所学到的注意力权重归一化,表示根据学到的权重对特征重新赋值。
整个网络包含四个组件(SourceCNN,TargetCNN,Matching Attention,Classifier),我们利用源数据集构造训练对(源数据集图像xs,源数据集图像x's),,它们对应的类标为(ys,y's),训练整个网络,最小化分类误差。损失函数定义为:
C表示分类器对fs和ft分类的结果,log为对数函数。xs的存在是为了训练SourceCNN和Classifier,x's的存在是为了为TargetCNN和Matching Attention找到比较好的初始化值。
步骤5,对于两个SourceCNN和TargetCNN提取出的特征集fs∈FS和ft∈FT,为了衡量FS与FT之间的距离,我们提出了利用步骤3训练好的分类器网络构造感知损失函数公式如下:
上述式子中,Constant为一个常数,pi与qi分别代表着对于SouceCNN和TargetCNN所提取到的所有特征向量经过分类网络Classifier之后,属于类别i的平均概率。计算公式为:
c表示分类数目。
联合步骤5提出的感知损失函数和步骤6提出的判别器网络,判别器网络为D,对于SourceCNN提取到的特征集FS,D期望将其和TargetCNN提取到的特征集FT区分开,再次微调步骤1到步骤3提出的三个组件(SourceCNN,TargetCNN,Matching Attention),根据生成对抗网络的思想,判别器网络和四个组件之间的优化是交替进行,优化公式如下:
表示对于源特征集分类误差,表示判别器对源特征集和目标特征集特征的区分度。表示目标特征集生成的特征欺骗判别器D的程度,α,β,γ分别表示不同损失之间优化权重比。
在步骤7微调完(SourceCNN,TargetCNN,Matching Attention和Classifier)之后,在预测目标数据集xs∈XS每张图像的分类结果方面。需要构造输入对(xs,xt),得到TargetCNN的输出ft,进一步经过分类器Classifier得到分类结果。在本实施例中设置他们分别为0.04,10,1。
本发明提供了一种基于感知损失和匹配注意力机制的图像分类方法,具体实现该技术方案的方法和途径很多,以上所述仅是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也应视为本发明的保护范围。本实施例中未明确的各组成部分均可用现有技术加以实现。
Claims (7)
1.一种基于感知损失和匹配注意力机制的图像分类方法,其特征在于,包括如下步骤:
步骤1,针对源数据集XS、源数据集中的图像对应的类别标签信息集合YS构造源数据集特征提取网络SourceCNN,针对目标数据集XT构造目标特征提取网络TargetCNN,SourceCNN用于提取源图像集中图像的可判别的特征,TargetCNN用于提取目标数据集中图像的可判别的特征,用xs表示源数据集XS中第s张图像,ys表示xs对应的类别标签信息,ys∈YS;用xt表示目标数据集XT中第t张图像,其不包含类别标签信息;
步骤2,构造用于对源数据集图像特征和目标数据集图像特征进行分类的分类器网络Classifier;
步骤3,在源数据集特征提取网络SourceCNN和目标特征提取网络TargetCNN中间层增加匹配注意力机制;
步骤4,利用源数据集训练分类器网络Classifier,得到训练好的分类器网络Classifier;
步骤5,根据步骤4训练得到的分类器网络Classifier,构造感知损失函数,得到感知损失;
步骤6,构造判别器网络,判别器网络用于判断输入特征来自源数据集特征提取网络SourceCNN的输出还是目标特征提取网络TargetCNN的输出,对于源数据集特征提取网络SourceCNN的输出,优化判别器网络让其输出概率1;对于目标特征提取网络TargetCNN的输出,优化判别器网络让其输出概率0;
步骤7,利用感知损失和判别器网络优化源数据集特征提取网络SourceCNN、目标特征提取网络TargetCNN、分类器网络Classifier和匹配注意力机制,并同时利用源数据集特征提取网络SourceCNN和目标特征提取网络TargetCNN的输出特征训练判别器网络;
步骤8,在步骤7的训练阶段完成之后,对无类别标签信息目标数据集进行分类,构造输入对,目标特征提取网络TargetCNN的输出即为对目标数据集提取特征,分类器网络Classifier的输出即为预测的图像分类类别。
2.根据权利要求1所述的方法,其特征在于,步骤3包括:
步骤3-1:记源数据集特征提取网络SouceCNN隐藏层j-1的特征为目标特征提取网络TargetCNN隐藏层j-1的特征为对于和通过同一个特征概括网络fsumm对源数据集特征和目标数据集特征重新编码表示,分别得到源数据集新的特征和目标数据集新的特征
步骤3-2:根据如下公式计算步骤3-1得出的新的特征之间的相似度gj-1:
步骤3-3:对后续隐藏层j层特征不同维度,根据步骤3-2的相似度学习出不同的权重:
其中,c为特征和的维度,表示源数据集特征每一维度计算注意力的函数,所占的比重表示目标数据集特征每一维度计算注意力的函数,代表针对源数据集特征每一维度计算出的权重,代表针对目标数据集特征每一维度计算出的权重,表示根据权重重新计算出的特征,表示根据权重重新计算出的特征。
3.根据权利要求2所述的方法,其特征在于,步骤4中,分类器网络包含四个组件:SourceCNN,TargetCNN,Matching Attention,Classifier,利用源数据集XS随机取样构造训练对(xs,x's),xs和x's,为源数据集中随机取样的图像,它们对应的类标为(ys,y's),训练分类器网络,最小化分类误差,损失函数定义为:
其中,数据C(fs)表示分类器网络对fs分类的结果,K为类别标签信息集合中类别标签的总数,k取值为1~K,C(f's)表示分类器网络对f′s分类的结果,fs表示步骤1中源数据集特征提取网络SourceCNN对xs的提取出的特征,源数据集特征提取网络SourceCNN对xs的提取出的所有特征的集合记为FS;f′s表示步骤1中目标特征提取网络TargetCNN对x's提取出的特征,目标特征提取网络TargetCNN提取出的所有特征的集合记为FT,log为对数函数。
4.根据权利要求3所述的方法,其特征在于,步骤5中,对于源数据集特征提取网络SourceCNN提取出的特征集FS和目标特征提取网络TargetCNN提取出的特征集FT,利用步骤4训练好的分类器网络Classifier构造感知损失函数衡量FS与FT之间的距离,感知损失函数公式如下:
其中,Constant为一个常数,pi表示对于源数据集特征提取网络SouceCNN所提取到的所有特征经过分类器网络Classifier之后,属于类别i的平均概率;qi表示对于目标特征提取网络TargetCNN所提取到的所有特征经过分类网络Classifier之后,属于类别i的平均概率,i取值为1~K,两个平均概率分布之间的KL-散度即为感知损失。
5.根据权利要求4所述的方法,其特征在于,pi、qi计算公式如下:
6.根据权利要求5所述的方法,其特征在于,步骤7包括:联合步骤5提出的感知损失函数和步骤6提出的判别器网络D,对于源数据集特征提取网络SourceCNN提取到的源特征集FS,判别器网络D期望将其和目标特征提取网络TargetCNN提取到的目标特征集FT区分开,再次优化步骤1~步骤3提出的四个组件SourceCNN,TargetCNN,Matching Attention,Classifier,判别器网络D和四个组件之间的优化是交替进行,优化公式如下:
其中,k取值为1~K,表示当ys=k时,返回1,其余返回0,表示对于源特征集FS分类误差,表示判别器网络D对源特征集FS和目标特征集FT的区分度,表示目标特征集FT生成的特征欺骗判别器网络D的程度,α,β,γ分别分别表示优化过程中所占的比重。
7.根据权利要求6所述的方法,其特征在于,步骤8包括:构造输入对(xs,xt),得到目标特征提取网络TargetCNN的输出ft,再经过分类器网络Classifier得到xt分类结果。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201810468906.4A CN108647736B (zh) | 2018-05-16 | 2018-05-16 | 一种基于感知损失和匹配注意力机制的图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201810468906.4A CN108647736B (zh) | 2018-05-16 | 2018-05-16 | 一种基于感知损失和匹配注意力机制的图像分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN108647736A true CN108647736A (zh) | 2018-10-12 |
CN108647736B CN108647736B (zh) | 2021-10-12 |
Family
ID=63756126
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201810468906.4A Active CN108647736B (zh) | 2018-05-16 | 2018-05-16 | 一种基于感知损失和匹配注意力机制的图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN108647736B (zh) |
Cited By (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109376804A (zh) * | 2018-12-19 | 2019-02-22 | 中国地质大学(武汉) | 基于注意力机制和卷积神经网络高光谱遥感图像分类方法 |
CN109902602A (zh) * | 2019-02-16 | 2019-06-18 | 北京工业大学 | 一种基于对抗神经网络数据增强的机场跑道异物材料识别方法 |
CN110047076A (zh) * | 2019-03-29 | 2019-07-23 | 腾讯科技(深圳)有限公司 | 一种图像信息的处理方法、装置及存储介质 |
CN110084119A (zh) * | 2019-03-26 | 2019-08-02 | 安徽艾睿思智能科技有限公司 | 基于深度学习的低分辨率人脸图像识别方法 |
CN110110742A (zh) * | 2019-03-26 | 2019-08-09 | 北京达佳互联信息技术有限公司 | 多特征融合方法、装置、电子设备及存储介质 |
CN110599443A (zh) * | 2019-07-02 | 2019-12-20 | 山东工商学院 | 一种使用双向长短期记忆网络的视觉显著性检测方法 |
CN110704665A (zh) * | 2019-08-30 | 2020-01-17 | 北京大学 | 一种基于视觉注意力机制的图像特征表达方法及系统 |
CN111523663A (zh) * | 2020-04-22 | 2020-08-11 | 北京百度网讯科技有限公司 | 一种模型训练方法、装置以及电子设备 |
CN111832404A (zh) * | 2020-06-04 | 2020-10-27 | 中国科学院空天信息创新研究院 | 一种基于特征生成网络的小样本遥感地物分类方法及系统 |
CN112102326A (zh) * | 2020-10-26 | 2020-12-18 | 北京航星机器制造有限公司 | 一种安检ct图像目标物的提取和分割方法 |
CN113033174A (zh) * | 2021-03-23 | 2021-06-25 | 哈尔滨工业大学 | 一种基于输出型相似门的案件罪名判定方法、装置及存储介质 |
CN113392938A (zh) * | 2021-07-30 | 2021-09-14 | 广东工业大学 | 一种分类模型训练方法、阿尔茨海默病分类方法及装置 |
CN116363536A (zh) * | 2023-05-31 | 2023-06-30 | 国网湖北省电力有限公司经济技术研究院 | 一种基于无人机巡查数据的电网基建设备缺陷归档方法 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107679525A (zh) * | 2017-11-01 | 2018-02-09 | 腾讯科技(深圳)有限公司 | 图像分类方法、装置及计算机可读存储介质 |
CN107918782A (zh) * | 2016-12-29 | 2018-04-17 | 中国科学院计算技术研究所 | 一种生成描述图像内容的自然语言的方法与系统 |
-
2018
- 2018-05-16 CN CN201810468906.4A patent/CN108647736B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107918782A (zh) * | 2016-12-29 | 2018-04-17 | 中国科学院计算技术研究所 | 一种生成描述图像内容的自然语言的方法与系统 |
CN107679525A (zh) * | 2017-11-01 | 2018-02-09 | 腾讯科技(深圳)有限公司 | 图像分类方法、装置及计算机可读存储介质 |
Non-Patent Citations (1)
Title |
---|
刘路飞: "深度学习感知损失函数的设计与应用", 《中国优秀硕士学位论文全文数据库(信息科技)》 * |
Cited By (20)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109376804B (zh) * | 2018-12-19 | 2020-10-30 | 中国地质大学(武汉) | 基于注意力机制和卷积神经网络高光谱遥感图像分类方法 |
CN109376804A (zh) * | 2018-12-19 | 2019-02-22 | 中国地质大学(武汉) | 基于注意力机制和卷积神经网络高光谱遥感图像分类方法 |
CN109902602A (zh) * | 2019-02-16 | 2019-06-18 | 北京工业大学 | 一种基于对抗神经网络数据增强的机场跑道异物材料识别方法 |
CN109902602B (zh) * | 2019-02-16 | 2021-04-30 | 北京工业大学 | 一种基于对抗神经网络数据增强的机场跑道异物材料识别方法 |
CN110084119A (zh) * | 2019-03-26 | 2019-08-02 | 安徽艾睿思智能科技有限公司 | 基于深度学习的低分辨率人脸图像识别方法 |
CN110110742A (zh) * | 2019-03-26 | 2019-08-09 | 北京达佳互联信息技术有限公司 | 多特征融合方法、装置、电子设备及存储介质 |
CN110047076B (zh) * | 2019-03-29 | 2021-03-23 | 腾讯科技(深圳)有限公司 | 一种图像信息的处理方法、装置及存储介质 |
CN110047076A (zh) * | 2019-03-29 | 2019-07-23 | 腾讯科技(深圳)有限公司 | 一种图像信息的处理方法、装置及存储介质 |
CN110599443A (zh) * | 2019-07-02 | 2019-12-20 | 山东工商学院 | 一种使用双向长短期记忆网络的视觉显著性检测方法 |
CN110704665A (zh) * | 2019-08-30 | 2020-01-17 | 北京大学 | 一种基于视觉注意力机制的图像特征表达方法及系统 |
CN111523663A (zh) * | 2020-04-22 | 2020-08-11 | 北京百度网讯科技有限公司 | 一种模型训练方法、装置以及电子设备 |
CN111523663B (zh) * | 2020-04-22 | 2023-06-23 | 北京百度网讯科技有限公司 | 一种目标神经网络模型训练方法、装置以及电子设备 |
CN111832404A (zh) * | 2020-06-04 | 2020-10-27 | 中国科学院空天信息创新研究院 | 一种基于特征生成网络的小样本遥感地物分类方法及系统 |
CN111832404B (zh) * | 2020-06-04 | 2021-05-18 | 中国科学院空天信息创新研究院 | 一种基于特征生成网络的小样本遥感地物分类方法及系统 |
CN112102326A (zh) * | 2020-10-26 | 2020-12-18 | 北京航星机器制造有限公司 | 一种安检ct图像目标物的提取和分割方法 |
CN112102326B (zh) * | 2020-10-26 | 2023-11-07 | 北京航星机器制造有限公司 | 一种安检ct图像目标物的提取和分割方法 |
CN113033174A (zh) * | 2021-03-23 | 2021-06-25 | 哈尔滨工业大学 | 一种基于输出型相似门的案件罪名判定方法、装置及存储介质 |
CN113392938A (zh) * | 2021-07-30 | 2021-09-14 | 广东工业大学 | 一种分类模型训练方法、阿尔茨海默病分类方法及装置 |
CN116363536A (zh) * | 2023-05-31 | 2023-06-30 | 国网湖北省电力有限公司经济技术研究院 | 一种基于无人机巡查数据的电网基建设备缺陷归档方法 |
CN116363536B (zh) * | 2023-05-31 | 2023-08-11 | 国网湖北省电力有限公司经济技术研究院 | 一种基于无人机巡查数据的电网基建设备缺陷归档方法 |
Also Published As
Publication number | Publication date |
---|---|
CN108647736B (zh) | 2021-10-12 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108647736A (zh) | 一种基于感知损失和匹配注意力机制的图像分类方法 | |
CN108717568B (zh) | 一种基于三维卷积神经网络的图像特征提取与训练方法 | |
CN108564129B (zh) | 一种基于生成对抗网络的轨迹数据分类方法 | |
CN109815801A (zh) | 基于深度学习的人脸识别方法及装置 | |
Chacko et al. | Handwritten character recognition using wavelet energy and extreme learning machine | |
CN104239858B (zh) | 一种人脸特征验证的方法和装置 | |
CN107766418A (zh) | 一种基于融合模型的信用评估方法、电子设备和存储介质 | |
CN103262118B (zh) | 属性值估计装置和属性值估计方法 | |
EP2908268B1 (en) | Face detector training method, face detection method, and apparatus | |
CN107871100A (zh) | 人脸模型的训练方法和装置、人脸认证方法和装置 | |
CN102324038B (zh) | 一种基于数字图像的植物种类识别方法 | |
CN107609459A (zh) | 一种基于深度学习的人脸识别方法及装置 | |
CN107239514A (zh) | 一种基于卷积神经网络的植物识别方法及系统 | |
CN104866829A (zh) | 一种基于特征学习的跨年龄人脸验证方法 | |
CN106919951A (zh) | 一种基于点击与视觉融合的弱监督双线性深度学习方法 | |
CN103984953A (zh) | 基于多特征融合与Boosting决策森林的街景图像的语义分割方法 | |
CN104680179B (zh) | 基于邻域相似度的数据降维方法 | |
CN106295694A (zh) | 一种迭代重约束组稀疏表示分类的人脸识别方法 | |
CN109993100A (zh) | 基于深层特征聚类的人脸表情识别的实现方法 | |
CN104376308B (zh) | 一种基于多任务学习的人体动作识别方法 | |
CN109726918A (zh) | 基于生成式对抗网络和半监督学习的个人信用确定方法 | |
CN107239777A (zh) | 一种基于多视角图模型的餐具检测和识别方法 | |
CN109034281A (zh) | 加速基于卷积神经网络的中文手写体识别的方法 | |
CN108198324B (zh) | 一种基于图像指纹的多国纸币币种识别方法 | |
CN109815814A (zh) | 一种基于卷积神经网络的人脸检测方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |