CN113011513B - 一种基于通用域自适应的图像大数据分类方法 - Google Patents

一种基于通用域自适应的图像大数据分类方法 Download PDF

Info

Publication number
CN113011513B
CN113011513B CN202110333791.XA CN202110333791A CN113011513B CN 113011513 B CN113011513 B CN 113011513B CN 202110333791 A CN202110333791 A CN 202110333791A CN 113011513 B CN113011513 B CN 113011513B
Authority
CN
China
Prior art keywords
domain
target domain
classification
target
image
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110333791.XA
Other languages
English (en)
Other versions
CN113011513A (zh
Inventor
罗荣华
周绍煌
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
South China University of Technology SCUT
Original Assignee
South China University of Technology SCUT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by South China University of Technology SCUT filed Critical South China University of Technology SCUT
Priority to CN202110333791.XA priority Critical patent/CN113011513B/zh
Publication of CN113011513A publication Critical patent/CN113011513A/zh
Application granted granted Critical
Publication of CN113011513B publication Critical patent/CN113011513B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques
    • G06F18/254Fusion techniques of classification results, e.g. of results related to same input data
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于通用域自适应的图像大数据分类方法,本发明方法包括将获取的目标域图像数据Xt输入预先训练好的通用域自适应网络,通过特征提取器F将目标域图像数据Xt转化成目标域图像特征向量Zt,通过分类部分G对特征向量Zt进行分类输出得到分类结果Yt和余弦相似度Ct,通过域判别器D对特征向量Zt进行域判别输出得到目标域判别dt。将得到的余弦相似度Ct和域判别dt结合得到目标域权重Wt,目标域权重Wt与阈值相比较,大于阈值则输出分类结果Yt。本发明能够解决图像大数据域自适应的问题,使训练好的模型可以应用在任意图像大数据集上,极大地提高了模型的泛化能力和分类效果,减轻对模型源域数据集的依赖。

Description

一种基于通用域自适应的图像大数据分类方法
技术领域
本发明涉及图像大数据处理技术,具体涉及一种基于通用域自适应的图像大数据分类方法。
背景技术
随着神经网络的提出,在人工智能领域取得了重大发展突破,并成为深度学习等研究领域的主干网络。
随着对抗神经网络的提出,作为启发方法又产生了大量的新的学习方法,其中迁移学习根据神经网络以及对抗训练的方法提出了对抗性域自适应,并成为了域自适应的主流框架。
近年来大数据相关产业蓬勃发展,而大数据的使用也成为了一个重要的研究问题,以往的机器学习中都需要目标域标注数据,域自适应刚好是一个解决目标域无标签训练的较好方法。但由于域自适应要求源域与目标域之间的标签空间完全相同,在应用时虽然减少了目标域空间标签的工作量,在大数据的应用场景上远不能满足现实世界的要求,因作为源域的训练样本标签空间有限,且制作需要消耗大量代价,在实际应用中不能很好的满足大数据中众多类别的需求。因此本发明提出了通用域自适应研究方法,旨在已知源域而目标域类别未知的情况下训练网络,能将目标域中与源域共有的部分很好区分出来并进行任务输出。
Kaichao You等在《Universal domain adaptation.In The IEEE Conference onComputer Vision and Pattern Recognition(CVPR),June 2019》中,对于分类方法的实际应用并没有深入研究,在工业界中也未有实际应用,且分类的精确性较低。
发明内容
本发明要解决的技术问题,针对现有技术的上述问题,提出了一种基于通用域自适应的图像大数据分类方法,通过该方法,能将源域和目标域的数据映射到同一空间,并确保其共有类别在空间上分布一致,最终通过在源域和目标域伪标签训练的分类器上能有较好的分类效果,并且通过权重判断是否为共有类,解决传统域自适应的应用场景有限问题。
一种基于通用域自适应的图像大数据分类方法,步骤包括:
将目标域图像数据xt输入预先训练好的通用域自适应网络,所述通用域自适应网络包括特征提取器F、分类部分G和域判别器D,通过特征提取器F将目标域图像数据xt转化成目标域图像特征向量zt,分类部分G对特征向量zt进行分类输出得到分类结果yt和余弦相似度ct,通过域判别器D对特征向量zt进行域判别,输出得到目标域与源域的相似度dt
将得到的分类结果yt、余弦相似度ct和域判别dt结合得到目标域权重wt,目标域权重wt与阈值wα相比较,大于阈值wα则输出分类结果yt,小于阈值的目标域数据为目标域的特有类,将其视为一类输出,目标域的特有类加上源域和目标域共有的类别数做为最终的分类结果输出。
在所述将目标域图像数据xt输入训练好的通用域自适应网络之前还包括训练通用域自适应网络的步骤,所述训练通用域自适应网络的步骤如下:
其中特征提取器F由残差网络resnet-50组成,在image-net上预训练得到基础参数。将有标签的源域数据xs和无标签的目标域数据xt作为输入特征提取器F,从而将源域和目标域映射到同一空间中,得到对应输出源域图像特征向量zs和目标域图像特征向量zt,这些特征向量又作为输入,分别传给分类部分G和域判别器D进行训练。
其中所述分类部分G由两个参数不同的分类器组成,所述的两个分类器的网络结构相同,都是由2个全连接层组成且全连接层也尺寸相同,为确保两分类器的参数不同,使两个分类器的参数保持余弦距离损失函数;
所述公式如:
Figure BDA0002997358890000031
所述分类器的两层全连接层大小256,源域类别数。
所述域判别器D由3层全连接层组成,域判别器输出该数据是来自源域还是目标域,当输入为目标域时dt越大,为源域与目标域共有类的可能性越高。
所述域判别器D三层全连接层大小分别为1024,1024,1。
所述第一第二层全连接层后面加有激活函数ReLU和0.5的Dropout。
所述第三层后加有激活函数sigmoid。
所述域判别器D的训练损失函数如下所示:
Figure BDA0002997358890000032
所述Ladv(G,D)中w(x)为输入图像的权重,当输入为源域,w(x)权重越小,越有可能属于源域中共有类部分,因此乘以-1;
所述Ladv(G,D)中源域训练标签为1目标域训练标签为0;
为判别目标域输入是否为源域共有类,通过获取目标域数据于域判别器D和分类部分G的输出处理并得到权重wt,该权重wt与阈值wα进行对比,大于阈值的视为共有类,进行处理,否则小于阈值视为目标域特有类,记为unknown。
所述阈值wα在训练阶段根据不同批次,动态线性变化,当最后一批训练完成时值为w0
所述训练阶段动态阈值wα公式:
Figure BDA0002997358890000033
所述动态阈值wα公式中t为当前训练批次,T为总批次,阈值随着训练过程线性降低,w0=0.8为定值,当训练到最后一批t=T时,wα=w0
所述阈值在测试阶段大小固定为w0,所述w0与训练阶段w0相同。
所述判断为共有类进行处理包括:
在训练阶段将分批次中被判定为共有类的目标域数据打上伪标签,将拥有伪标签的目标域特征向量zt作为输入反馈训练分类部分G。
而在测试阶段则直接将分类部分G的输出作为结果输出。
所述伪标签是目标域数据在分类部分G中输出的单位向量yt的最大值所属类别作为伪标签,
Figure BDA0002997358890000043
所述权重是由域判别器D和分类部分G得到。
所述权重w(x)中的分类器部分,是通过协同训练,将两个分类器的输出向量计算相似度。
所述计算相似度是通过余弦距离公式,计算两个分类器的输出余弦 ct=cos(yt1,yt2),其中yt1和yt2分别为两个分类器的分类输出。
目标域的分类部分输出yt,yt=(yt1+yt2)/2。
权重w(x)中的域判别部分即域判别器的输出,当输入为目标域,w(x)权重越大,越有可能属于目标域中共有类部分,当输入为源域,w(x)权重越小,越有可能属于源域中共有类部分;
当输入为目标域数据时,目标域余弦相似度ct=cos(yt1,yt2),目标域的分类输出yt=(yt1+yt2)/2,其中yt1和yt2分别为两个分类器的分类输出,目标域权重 wt最终公式为:w(x)=d(x)+ct
当输入为源域数据时,源域余弦相似度cs=cos(ys1,ys2),源域的分类输出为 ys=(ys1+ys2)/2,其中ys1,ys2分别为两个分类器的源域分类输出,源域权重ws: ws=d(x)+cs
域判别器与特征提取器之间加入了梯度反转层λadv,λadv=-1。
训练的总损失函数为:
Figure BDA0002997358890000041
所述总损失函数L(G,D)中,Lseg(G)为分类部分G的损失函数,
Figure BDA0002997358890000042
为分类器参数损失函数,Ladv(G,D)为域判别器D计算损失函数,λweight为参数。
所述分类部分G训练阶段是分别通过有标签的源域数据和有伪标签的目标域数据进行训练参数,所述损失函数Lseg计算函数表达式如下:
Figure BDA0002997358890000051
所述损失函数Lseg中,x表示输入的图片,LCE为交叉熵损失函数,p为源域部分, q为目标域部分,y为源域的真实标签,
Figure BDA0002997358890000052
为源域或目标域图像的分类输出,
Figure BDA0002997358890000053
将目标域分类输出的最大值类别作为伪标签类别,
Figure BDA0002997358890000054
表示当权重大于阈值wα时为1,否则为0,
Figure BDA0002997358890000055
表示概率分布在p源域和q目标域。
与现有技术相比,本发明能够实现的有益效果如下:
(1)对于图像大数据来说,存在着不同的图像由于拍摄时间背景不同,因此即便是同一物体也存在着风格差异较大的问题,同时图像大数据还存在着图像的类别较多难以统计,在实际应用中如要对需要的类别进行提取,需要大量的人工筛选。因此本发明提出的一种基于通用域自适应的图像大数据分类方法方法,可以直接应用在图像大数据方面,通过域自适应消除由于背景原因对于分类结果的影响,并通过加入权重筛选判别出需要用到的类别图像,并取得较好的分类效果。
(2)域自适应方法就是在只有源域有训练标签而目标域没有训练标签的情况下,可以通过在源域上训练,迁移到目标域,因此增加模型的应用范围,而之前的域自适应方法都要求源域和目标的的标签空间(即类比相同),本申请的通用域自适应可以在源域和目标域的标签空间类别不完全相同的情况下进行训练,训练后的模型可对目标域中与源域共有的类别即所述权重W判别为共有类的图像数据进行分类输出,扩大了源域和目标域数据的应用场景。
(3)本发明能够解决图像大数据域自适应的问题,使训练好的模型可以应用在任意图像大数据集上,极大地提高了模型的泛化能力和分类效果,减轻对模型源域数据集的依赖。
(4)本申请的加入的协同训练的方法,可以提高分类器的分类精度准确性,并且,对于是否共有类这一判别标准,加入了两个分类器的余弦距离来增加不确定性的度量,即当两个分类器的输出余弦距离越小,说明输入图像的不确定性越大,判别为目标域特有类的可能性越高。
附图说明
图1为训练阶段基本框架流程示意图。
图2为测试阶段基本框架流程示意图。
图3为本发明实施例方法的基本流程示意图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。
如图2所示,本实施例一种基于通用域自适应的图像大数据分类方法的实施步骤包括:
步骤1:获取目标域图像数据xt
步骤2:将目标域图像数据xt输入预先训练好的通用域自适应网络,通过特征提取器F将目标域图像数据xt转化成目标域图像特征向量zt
步骤3:将特征向量zt通过分类部分G对进行分类输出得到分类结果yt和余弦相似度ct
步骤4:将特征向量zt通过域判别器D进行域判别输出得到目标域判别dt
步骤5:将得到的余弦相似度ct和域判别器dt结合得到目标域权重wt,目标域权重wt与阈值wα相比较,大于阈值则输出分类结果yt。小于阈值的目标域数据为目标域的特有类,将其视为一类输出,目标域的特有类加上源域和目标域共有的类别数做为最终的分类结果输出。
在本发明其中一个实施例中,通用域自适应网络结构包括特征提取器F,域判别器D和分类部分G三部分。
所述特征提取器F,用于对输入的目标域图像数据xt进行特征向量提取,得到目标域图像特征向量zt
所述分类部分G,将目标域图像特征向量zt输入,得到目标域分类结果yt、余弦相似度ct
所述域判别器D,将目标域图像特征向量zt输入,输出目标域判别dt
在本发明其中一个实施例中,如图1所示,将目标域图像数据Xt输入训练好的通用域自适应网络之前还包括训练通用域自适应网络的步骤,准备好有标签的源域数据xs和无标签将要应用的目标域数据xt,其中源域与目标域标签空间不同但有交集,且交集未知。
所述训练特征提取器F,域判别器D和分类部分G的步骤如下:
1)特征提取器F由残差网络resnet-50组成,在image-net上预训练得到网络基础参数;
2)将有标签的源域数据xs和无标签的目标域图像数据xt同时输入至特征提取器F,特征提取器F输出源域图像特征向量zs和目标域图像特征向量zt,通过同一个特征提取器F,使源域和目标域特征向量映射在同一特征空间中;
3)将得到的源域图像特征向量zs和目标域图像特征向量zt作为输入,分别传给分类部分G和域判别器D,分类部分G对输入的特征向量进行分类,输出源域分类结果ys、余弦相似度cs和目标域分类结果yt、余弦相似度ct,对有标签的源域数据分类结果ys计算分类标签的损失函数Lseg,域判别器D将输入的特征向量进行二分类,判别输入数据来自源域还是目标域,输出源域判别ds和目标域判别dt
4)按照源域和目标域将域判别器D输出的源域判别ds和目标域判别dt、分类部分G输出的源域余弦相似度cs和目标域余弦相似度ct结合起来,得到源域权重ws和目标域权重wt,将其中目标域权重wt与阈值wα进行对比,大于阈值wα的视为共有类,进行处理,小于阈值视为目标域特有类,记为unknown;
5)训练阶段所述阈值wα在训练阶段根据不同批次,动态线性变化,阈值wα公式:
Figure BDA0002997358890000081
其中t是当前训练批次,T是模型训练的总批次,w0是测试阶段的阈值,当训练到最后一批t=T时,wα=w0
在本发明其中一个实施例中,所述分类部分G训练阶段是分别通过有标签的源域数据和有伪标签的目标域数据进行训练参数,所述损失函数Lseg计算函数表达式如下:
Figure BDA0002997358890000082
所述损失函数Lseg中,x表示输入的图片,LCE为交叉熵损失函数,p为源域部分, q为目标域部分,y为源域的真实标签,
Figure BDA0002997358890000083
为源域或目标域图像的分类输出,
Figure BDA0002997358890000084
将目标域分类输出的最大值类别作为伪标签类别,
Figure BDA0002997358890000085
表示当权重大于阈值wα时为1,否则为0,
Figure BDA0002997358890000086
表示概率分布在p源域和q目标域。
所述目标域伪标签
Figure BDA0002997358890000087
是将目标域图像的权重大于阈值wα的部分视为共有类,并将其分类输出作为输入再训练分类部分G。而在测试阶段则直接将分类部分G输出向量最大值作为结果输出。
在本发明其中一个实施例中,所述分类部分G由两个参数不同的分类器组成,两个分类器的网络结构相同,都是由2个全连接层组成且全连接层尺寸相同,为确保两分类器的参数不同,使两个分类器的参数保持余弦距离损失函数:
Figure BDA0002997358890000088
分类器的两层全连接层大小256,源域类别数。
权重w(x)中的分类器部分,是通过协同训练,将两个分类器的分类输出计算余弦相似度,加上与判别器的输出d(x),输出值越大,该目标域图像为源域与目标域共有类的可能性越高。
在本发明其中一个实施例中,计算相似度是通过余弦距离公式,计算两个分类器的输出余弦距离ct=cos(yt1,yt2),其中yt1和yt2分别为两个分类器的分类输出。
当输入为目标域数据时,目标域余弦相似度ct=cos(yt1,yt2),目标域的分类输出yt=(yt1+yt2)/2,其中yt1和yt2分别为两个分类器的分类输出,目标域权重 wt最终公式为:w(x)=d(x)+ct
当输入为源域数据时,源域余弦相似度cs=cos(ys1,ys2),源域的分类输出为 ys=(ys1+ys2)/2,其中ys1,ys2分别为两个分类器的源域分类输出,源域权重ws, ws=d(x)+cs
在本发明其中一个实施例中,所述域判别器D由3层全连接层组成,域判别器输出该数据是来自源域还是目标域。
所述域判别器D三层全连接层大小分别为1024,1024,1。
所述第一第二层全连接层后面加有激活函数ReLU和0.5的Dropout。
所述第三层后加油激活函数sigmoid。
所述域判别器D的训练损失函数如下所示:
Figure BDA0002997358890000091
Ladv(G,D)中源域训练标签为1目标域训练标签为0;
Ladv(G,D)中w(x)为输入图像的权重,当输入为目标域时w(x)=wt,当输入为源域时w(x)=ws。当输入为源域,w(x)权重越小,越有可能属于源域中共有类部分,因此乘以-1。权重w(x)是由域判别器输出和余弦距离组合而成,对于与判别器输出来说,是通过设计源域为1目标域为0,因为共有类是源域和目标域重合的部分,因此输入属于目标域中共有类部分时,域判别器的输出比当输入属于目标域的特有部分大,同理源域,源域的共有部分因与目标域重合,因此源域的域判别器输出中共有类部分比特有类部分小。而余弦距离也是相同的理论,余弦距离是度量两个向量相似度的,因为分类器是在源域进行了分类训练,所以两个分类器在输入为源域时,输出的相似度大,目标域时小。而域判别器存在的目的就是是源域和目标域共有的数据在同一特征空间重合,因此乘以w(x)作为参数,更有利于训练数据中属于共有类部分。
综上所述方法训练的总损失函数为:
Figure BDA0002997358890000101
总损失函数L(G,D)中,Lseg(G)为分类部分G的损失函数,
Figure BDA0002997358890000102
为分类器参数损失函数,Ladv(G,D)为域判别器D计算损失函数,λweight为参数。在本发明其中一个实施例中,参数λweight=0.05。λweight控制
Figure BDA0002997358890000103
的大小,防止过分影响分类器参数而影响分类结果的准确性。
在本发明其中一个实施例中,域判别器与特征提取器之间加入了梯度反转层,与特征提取器进行对抗训练,因此域判别器损失函数乘以参数λadv,一般λadv=-1。
显然,本发明的上述实施例仅仅是为清楚地说明本发明所作的举例,而并非是对本发明的实施方式的限定。对于所属领域的普通技术人员来说,在上述说明的基础上还可以做出其它不同形式的变化或变动。这里无需也无法对所有的实施方式予以穷举。凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明权利要求的保护范围之内。

Claims (9)

1.一种基于通用域自适应的图像大数据分类方法,其特征在于,包括以下步骤:
将目标域图像数据
Figure 480902DEST_PATH_IMAGE001
输入预先训练好的通用域自适应网络,所述通用域自适应网络包 括特征提取器F、分类部分G和域判别器D,通过特征提取器F将目标域图像数据
Figure 559716DEST_PATH_IMAGE001
转化成目 标域图像特征向量
Figure 995377DEST_PATH_IMAGE002
,分类部分G对特征向量
Figure 501445DEST_PATH_IMAGE002
进行分类输出得到分类结果
Figure 655345DEST_PATH_IMAGE003
和余弦相似 度
Figure 537851DEST_PATH_IMAGE004
,通过域判别器D对特征向量
Figure 326553DEST_PATH_IMAGE002
进行域判别,输出得到目标域判别
Figure 269101DEST_PATH_IMAGE005
将得到的分类结果
Figure 644719DEST_PATH_IMAGE003
余弦相似度
Figure 65336DEST_PATH_IMAGE004
和目标域判别
Figure 475589DEST_PATH_IMAGE005
结合得到目标域权重
Figure 323459DEST_PATH_IMAGE006
,目标域 权重
Figure 186373DEST_PATH_IMAGE006
与阈值
Figure 410681DEST_PATH_IMAGE007
相比较,大于阈值
Figure 173975DEST_PATH_IMAGE007
则输出分类结果
Figure 458326DEST_PATH_IMAGE003
,而小于阈值
Figure 808536DEST_PATH_IMAGE007
时对应的的目标 域数据为目标域的特有类,将其视为一类输出,目标域的特有类加上源域和目标域共有的 类别数作为最终的分类结果输出;
其中,在所述将目标域图像数据
Figure 508639DEST_PATH_IMAGE001
输入预先训练好的通用域自适应网络之前还包括训 练通用域自适应网络的步骤,所述训练通用域自适应网络的步骤如下:
特征提取器F包括残差网络resnet-50,在image-net上预训练得到网络基础参数;
将有标签的源域数据
Figure 690221DEST_PATH_IMAGE008
和无标签的目标域图像数据
Figure 83156DEST_PATH_IMAGE001
同时输入至特征提取器F,特征 提取器F输出源域图像特征向量
Figure 982979DEST_PATH_IMAGE009
和目标域图像特征向量,通过同一个特征提取器F使源域 和目标域的图像特征向量映射在同一特征空间中;
将得到的源域图像特征向量
Figure 985308DEST_PATH_IMAGE009
和目标域图像特征向量作为输入,分别传给分类部分G 和域判别器D,分类部分G对输入的特征向量进行分类,输出源域分类结果
Figure 21397DEST_PATH_IMAGE010
、源域余弦相似 度
Figure 585234DEST_PATH_IMAGE011
和目标域分类结果、目标域余弦相似度,针对有标签的源域数据分类结果
Figure 972353DEST_PATH_IMAGE010
计算分类 标签的损失函数
Figure 14258DEST_PATH_IMAGE012
,域判别器D将输入的特征向量进行二分类,判别输入数据来自源域还 是目标域,输出源域判别
Figure 170433DEST_PATH_IMAGE013
和目标域判别;
按照源域和目标域将域判别器D输出的源域判别
Figure 905171DEST_PATH_IMAGE013
和目标域判别、分类部分G输出的源 域余弦相似度
Figure 514007DEST_PATH_IMAGE011
和目标域余弦相似度结合起来,得到源域权重
Figure 123717DEST_PATH_IMAGE014
和目标域权重,将其中目 标域权重与阈值
Figure 134399DEST_PATH_IMAGE007
进行对比,大于阈值
Figure 40038DEST_PATH_IMAGE007
时对应的目标域数据视为共有类,小于阈值时 对应的目标域数据视为目标域特有类;
其中,所述阈值
Figure 136170DEST_PATH_IMAGE007
在训练阶段根据不同批次,动态线性变化,阈值
Figure 519878DEST_PATH_IMAGE007
公式:
Figure 650645DEST_PATH_IMAGE015
其中t是当前训练批次,T是模型训练的总批次,
Figure 727185DEST_PATH_IMAGE016
是测试阶段的阈值,当训练到最后一 批t=T时,
Figure 310613DEST_PATH_IMAGE017
2.根据权利要求1所述的基于通用域自适应的图像大数据分类方法,其特征在于:
所述特征提取器F,用于对输入的目标域图像数据
Figure 19985DEST_PATH_IMAGE001
进行特征向量提取,得到目标域图 像特征向量
Figure 5258DEST_PATH_IMAGE002
所述分类部分G,输入目标域图像特征向量
Figure 252700DEST_PATH_IMAGE002
后,得到目标域分类结果
Figure 323424DEST_PATH_IMAGE003
、余弦相似度
Figure 314514DEST_PATH_IMAGE004
所述域判别器D,输入目标域图像特征向量
Figure 154294DEST_PATH_IMAGE002
后,输出目标域判别
Figure 572637DEST_PATH_IMAGE005
3.根据权利要求1所述的基于通用域自适应的图像大数据分类方法,其特征在于,分类 部分G训练阶段时的损失函数
Figure 865078DEST_PATH_IMAGE012
计算函数表达式如下:
Figure 158394DEST_PATH_IMAGE018
其中,x表示输入的图片,
Figure 852680DEST_PATH_IMAGE019
为交叉熵损失函数,p为源域部分,q为目标域部分,y为源 域的真实标签,
Figure 441925DEST_PATH_IMAGE020
为源域或目标域图像的分类输出,
Figure 487241DEST_PATH_IMAGE021
将目标域分类输出 的最大值类别作为伪标签类别,
Figure 820133DEST_PATH_IMAGE022
表示当权重大于阈值
Figure 368926DEST_PATH_IMAGE007
时为1,否则为0,
Figure 394651DEST_PATH_IMAGE023
表示概率分布在p源域,
Figure 661685DEST_PATH_IMAGE024
表示概率分布在q目标域。
4.根据权利要求3所述的基于通用域自适应的图像大数据分类方法,其特征在于,
Figure 31224DEST_PATH_IMAGE021
将目标域图像的权重大于阈值
Figure 700103DEST_PATH_IMAGE007
的部分视为共有类,并将分类输出作为输 入再训练分类部分G。
5.根据权利要求1所述的基于通用域自适应的图像大数据分类方法,其特征在于,所述 分类部分G包括两个参数不同但结构相同的分类器,两个分类器的网络结构均包括2层全连 接层,为确保分类器的参数不同,使两个分类器的参数保持余弦距离的损失函数
Figure 896729DEST_PATH_IMAGE025
公 式如下:
Figure 385479DEST_PATH_IMAGE026
所述分类部分G的输出是将两个分类器的输出向量求和并进行归一化处理,其中
Figure 325753DEST_PATH_IMAGE027
分别为两个分类器的所有参数。
6.根据权利要求1所述的基于通用域自适应的图像大数据分类方法,其特征在于,所述域判别器D包括三层全连接层,域判别器输出数据是来自源域还是目标域,源域训练标签为1目标域为0,所述域判别器D的训练损失函数如下所示:
Figure 849138DEST_PATH_IMAGE028
其中,
Figure 951086DEST_PATH_IMAGE029
为输入图像的权重
Figure 192712DEST_PATH_IMAGE023
表示概率分布在p源域,
Figure 169633DEST_PATH_IMAGE024
表示概率 分布在q目标域,
Figure 547525DEST_PATH_IMAGE019
为交叉熵损失函数,p为源域部分,q为目标域部分,
Figure 85953DEST_PATH_IMAGE030
为判别器的输 出。
7.根据权利要求1所述的基于通用域自适应的图像大数据分类方法,其特征在于,所述 权重中的分类器部分,是通过将两个分类器的分类输出计算余弦相似度,加上与判别器的 输出
Figure 549296DEST_PATH_IMAGE030
当输入为目标域数据时,目标域余弦相似度
Figure 831372DEST_PATH_IMAGE031
,目标域的分类输出
Figure 329350DEST_PATH_IMAGE032
,其中
Figure 773101DEST_PATH_IMAGE033
Figure 989318DEST_PATH_IMAGE034
分别为两个分类器的分类输出,目标域权重
Figure 308042DEST_PATH_IMAGE006
最终公 式为:
Figure 660526DEST_PATH_IMAGE035
当输入为源域数据时,源域余弦相似度
Figure 275178DEST_PATH_IMAGE036
,源域的分类输出为
Figure 713113DEST_PATH_IMAGE037
,其中
Figure 336992DEST_PATH_IMAGE038
分别为两个分类器的源域分类输出,源域权重
Figure 543982DEST_PATH_IMAGE014
最终 公式为:
Figure 595115DEST_PATH_IMAGE039
8.根据权利要求1所述的基于通用域自适应的图像大数据分类方法,其特征在于,所述域判别器与特征提取器之间加入梯度反转层,与特征提取器进行对抗训练。
9.根据权利要求1-8任一所述的基于通用域自适应的图像大数据分类方法,其特征在于,训练的总损失函数的计算公式如下所示:
Figure 520346DEST_PATH_IMAGE040
其中,
Figure 180872DEST_PATH_IMAGE041
为分类部分G的损失函数,
Figure 507948DEST_PATH_IMAGE042
为分类器参数损失函数,
Figure 729982DEST_PATH_IMAGE043
为域判别器D计算损失函数,
Figure 876930DEST_PATH_IMAGE044
Figure 842612DEST_PATH_IMAGE045
为参数。
CN202110333791.XA 2021-03-29 2021-03-29 一种基于通用域自适应的图像大数据分类方法 Active CN113011513B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110333791.XA CN113011513B (zh) 2021-03-29 2021-03-29 一种基于通用域自适应的图像大数据分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110333791.XA CN113011513B (zh) 2021-03-29 2021-03-29 一种基于通用域自适应的图像大数据分类方法

Publications (2)

Publication Number Publication Date
CN113011513A CN113011513A (zh) 2021-06-22
CN113011513B true CN113011513B (zh) 2023-03-24

Family

ID=76408707

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110333791.XA Active CN113011513B (zh) 2021-03-29 2021-03-29 一种基于通用域自适应的图像大数据分类方法

Country Status (1)

Country Link
CN (1) CN113011513B (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113392933B (zh) * 2021-07-06 2022-04-15 湖南大学 一种基于不确定性引导的自适应跨域目标检测方法
CN113537403B (zh) * 2021-08-14 2024-08-13 北京达佳互联信息技术有限公司 图像处理模型的训练方法和装置及预测方法和装置

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN101794396A (zh) * 2010-03-25 2010-08-04 西安电子科技大学 基于迁移网络学习的遥感图像目标识别系统及方法
CN107451616A (zh) * 2017-08-01 2017-12-08 西安电子科技大学 基于深度半监督迁移学习的多光谱遥感图像地物分类方法
CN109993173A (zh) * 2019-03-28 2019-07-09 华南理工大学 一种基于种子生长及边界约束的弱监督图像语义分割方法
CN110135510A (zh) * 2019-05-22 2019-08-16 电子科技大学中山学院 一种动态领域自适应方法、设备及计算机可读存储介质
CN110163286A (zh) * 2019-05-24 2019-08-23 常熟理工学院 一种基于混合池化的领域自适应图像分类方法
CN111797703A (zh) * 2020-06-11 2020-10-20 武汉大学 基于鲁棒深度语义分割网络的多源遥感影像分类方法
CN112308158A (zh) * 2020-11-05 2021-02-02 电子科技大学 一种基于部分特征对齐的多源领域自适应模型及方法

Family Cites Families (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10354199B2 (en) * 2015-12-07 2019-07-16 Xerox Corporation Transductive adaptation of classifiers without source data
US10657424B2 (en) * 2016-12-07 2020-05-19 Samsung Electronics Co., Ltd. Target detection method and apparatus
US10289909B2 (en) * 2017-03-06 2019-05-14 Xerox Corporation Conditional adaptation network for image classification
CN107392242B (zh) * 2017-07-18 2020-06-19 广东工业大学 一种基于同态神经网络的跨领域图片分类方法
US20200125928A1 (en) * 2018-10-22 2020-04-23 Ca, Inc. Real-time supervised machine learning by models configured to classify offensiveness of computer-generated natural-language text
CN110378872A (zh) * 2019-06-10 2019-10-25 河海大学 一种面向裂缝图像检测的多源自适应平衡迁移学习方法
CN110781970B (zh) * 2019-10-30 2024-04-26 腾讯科技(深圳)有限公司 分类器的生成方法、装置、设备及存储介质
CN111259941B (zh) * 2020-01-10 2023-09-26 中国科学院计算技术研究所 基于细粒度领域自适应的跨领域图像分类方法及系统
CN111832605B (zh) * 2020-05-22 2023-12-08 北京嘀嘀无限科技发展有限公司 无监督图像分类模型的训练方法、装置和电子设备
CN111738315B (zh) * 2020-06-10 2022-08-12 西安电子科技大学 基于对抗融合多源迁移学习的图像分类方法
CN111860494B (zh) * 2020-06-16 2023-07-07 北京航空航天大学 图像目标检测的优化方法、装置、电子设备和存储介质

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN101794396A (zh) * 2010-03-25 2010-08-04 西安电子科技大学 基于迁移网络学习的遥感图像目标识别系统及方法
CN107451616A (zh) * 2017-08-01 2017-12-08 西安电子科技大学 基于深度半监督迁移学习的多光谱遥感图像地物分类方法
CN109993173A (zh) * 2019-03-28 2019-07-09 华南理工大学 一种基于种子生长及边界约束的弱监督图像语义分割方法
CN110135510A (zh) * 2019-05-22 2019-08-16 电子科技大学中山学院 一种动态领域自适应方法、设备及计算机可读存储介质
CN110163286A (zh) * 2019-05-24 2019-08-23 常熟理工学院 一种基于混合池化的领域自适应图像分类方法
CN111797703A (zh) * 2020-06-11 2020-10-20 武汉大学 基于鲁棒深度语义分割网络的多源遥感影像分类方法
CN112308158A (zh) * 2020-11-05 2021-02-02 电子科技大学 一种基于部分特征对齐的多源领域自适应模型及方法

Also Published As

Publication number Publication date
CN113011513A (zh) 2021-06-22

Similar Documents

Publication Publication Date Title
CN109949317B (zh) 基于逐步对抗学习的半监督图像实例分割方法
Fang et al. A Method for Improving CNN-Based Image Recognition Using DCGAN.
CN111126482B (zh) 一种基于多分类器级联模型的遥感影像自动分类方法
CN113076994B (zh) 一种开集域自适应图像分类方法及系统
WO2020114378A1 (zh) 视频水印的识别方法、装置、设备及存储介质
CN109359541A (zh) 一种基于深度迁移学习的素描人脸识别方法
CN110633708A (zh) 一种基于全局模型和局部优化的深度网络显著性检测方法
CN113011513B (zh) 一种基于通用域自适应的图像大数据分类方法
CN110827265B (zh) 基于深度学习的图片异常检测方法
CN113870254B (zh) 目标对象的检测方法、装置、电子设备及存储介质
Jiang et al. A CNN model for semantic person part segmentation with capacity optimization
CN108520215A (zh) 基于多尺度联合特征编码器的单样本人脸识别方法
CN110633727A (zh) 基于选择性搜索的深度神经网络舰船目标细粒度识别方法
Aruleba et al. Deep learning for age estimation using EfficientNet
CN114492634A (zh) 一种细粒度装备图片分类识别方法及系统
Liu et al. Zero-Shot Object Detection by Semantics-Aware DETR with Adaptive Contrastive Loss
CN106447691A (zh) 基于加权多示例学习的加权极限学习机视频目标跟踪方法
Khashman Blood cell identification using a simple neural network
CN112750128B (zh) 图像语义分割方法、装置、终端及可读存储介质
Hsia et al. A fast face detection method for illumination variant condition
Salih et al. Deep learning for face expressions detection: Enhanced recurrent neural network with long short term memory
Pryor et al. Deepfake detection analyzing hybrid dataset utilizing CNN and SVM
CN115878896A (zh) 基于语义的真假性特征的多模态虚假新闻检测方法及装置
CN115457620A (zh) 用户表情识别方法、装置、计算机设备及存储介质
Liu et al. PGR-Net: A parallel network based on group and regression for age estimation

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant