CN111738315B - 基于对抗融合多源迁移学习的图像分类方法 - Google Patents

基于对抗融合多源迁移学习的图像分类方法 Download PDF

Info

Publication number
CN111738315B
CN111738315B CN202010521228.0A CN202010521228A CN111738315B CN 111738315 B CN111738315 B CN 111738315B CN 202010521228 A CN202010521228 A CN 202010521228A CN 111738315 B CN111738315 B CN 111738315B
Authority
CN
China
Prior art keywords
domain
sample
network
classifier
source
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010521228.0A
Other languages
English (en)
Other versions
CN111738315A (zh
Inventor
方敏
徐筱
杜辉
胡心钰
李海翔
郭龙飞
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Xidian University
Original Assignee
Xidian University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Xidian University filed Critical Xidian University
Priority to CN202010521228.0A priority Critical patent/CN111738315B/zh
Publication of CN111738315A publication Critical patent/CN111738315A/zh
Application granted granted Critical
Publication of CN111738315B publication Critical patent/CN111738315B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于对抗融合多源迁移学习的图像分类方法,主要解决现有技术图像分类准确率低的问题。其实现方案是:1)建立特征提取网络,从原始图像文件中提取图像特征;2)将图像特征输入特定的域判别器及分类器,计算得到域判别损失及目标域数据的伪标记、源域数据的分类损失;3)利用目标域样本伪标记与源域样本标记,计算得到源域与目标域中所有类别的MMD距离之和;4)利用域判别损失、分类损失及MMD距离之和对特征提取网络、域判别器及分类器进行训练;5)将待测样本依次输入到训练后的特征提取网络、域判别器及分类器,输出待测样本的类别标记。本发明能有效提高各类图像的分类准确率,可用于训练数据标记缺失下的图像分类。

Description

基于对抗融合多源迁移学习的图像分类方法
技术领域
本发明属于图像识别领域,特别涉及一种图像分类方法,可用于训练数据标记缺失下的图像分类。
背景技术
迁移学习是把在一个领域中学习到的知识、经验“迁移”到另外一个不同但相关的领域,以提高模型的学习效率,而不用重新开始学。一般把待分类或待预测的领域称为“目标域”;把有大量标记数据的辅助域称为“源域”,二者是存在域差异的。利用迁移学习研究图像分类问题在国内外已取得了显著的成效。现有的迁移学习方法可分为基于样本、基于特征和基于模型的方法。
受到博弈论中二人零和博弈的启发,有学者提出生成式对抗网络GAN,其包含一对互相对抗的模块,分别是生成式模型和判别式模型,可简称为生成器和判别器。生成器可以生成数据,其原始输入是随机噪声数据,目的是尽可能逼近真实数据;判别器的目的是尽可能的区分出生成数据和真实数据。
受GAN中对抗思想的启发,有研究人员提出基于对抗思想进行迁移学习。在基于对抗思想的迁移学习方法中,生成器与GAN中生成样本这一目标不同,其不再真正生成数据,而是对原始数据进行特征提取,使得判别器无法对两个领域进行分辨,此时生成器可以称为特征提取器。基于对抗的迁移学习核心思想是训练两个神经网络:一个试图区分源域和目标域特征的判别网络,一个试图迷惑判别网络使其无法区分源域和目标域特征的特征提取网络,基于这样的领域对抗思想,特征提取器最终能够提取到判别器无法区分的域不变特征,即可迁移特征,因此基于这种可迁移特征训练的分类器可以直接用来分类目标域的数据。
多源迁移方法,是利用基分类器加权来进行多个源域的迁移,利用多个源域训练得到的多个分类器对目标域数据进行预测,结果加权得到目标域数据的最终标记。如A-SVM方法提出了一个自适应的支持向量机模型,其利用多个源域分类器集成得到一个目标域的支持向量机分类模型,但是该方法认为所有源域分类器对目标域的贡献相同,并未考虑到不同源域之间的差异。MultiSourceTrAdaBoost方法利用样本加权的方式来对多个源域进行迁移,该方法在每个源域和目标域组合上学习一个分类器,然后计算多个弱分类器的分类误差并进行权值更新。MultiSourceTrAdaBoost方法中源域样本的权值更新策略类似于TrAdaBoost算法,目标域样本的权值更新策略类似于AdaBoost算法,最终利用训练好的加权分类器对目标域样本预测。Sun等人提出两级多源迁移学习方法,分别基于边缘分布和条件分布进行加权,使得源域和目标域分布更相近。
随着深度神经网络和生成对抗网络的发展,基于深度网络的多源域适应方法近两年也得到了国内外学者们的关注,如多源域对抗网络MDAN,该方法使用对抗策略学习多个源域的具有域不变性和任务判别性的特征;Xu R等人提出把所有的源域和目标域数据映射到一个公共的特征空间,以学习可迁移的域不变特征。
上述方法虽能实现目标域数据标记缺失下的图像分类工作,但其由于使用同一网络对源域数据进行特征提取,导致源域数据丧失部分有效特征,影响最终分类效果。
发明内容
本发明的目的在于针对上述现有技术的不足,提出一种基于对抗融合多源迁移学习的图像分类方法,以提高训练数据标记缺失下的图像分类准确率。
为实现上述目的,本发明的技术方案包括如下步骤:
(1)建立由域共享子网络F与域特定子网络Fj构成的特征提取网络;
(2)使用特征提取网络从原始图像文件中提取图像特征:
2a)对于来自源域j的第i个训练样本
Figure BDA0002532171300000021
经过域共享子网络F,得到初步特征
Figure BDA0002532171300000022
其中θF表示F的网络参数,j=1...N,
Figure BDA0002532171300000023
N表示源域个数,
Figure BDA0002532171300000024
表示源域j中样本的数目;
对于来自目标域的第t个样本
Figure BDA0002532171300000025
经过域共享子网络F,得到初步特征
Figure BDA0002532171300000026
其中t=1...nT,nT表示目标域中样本的数目;
2b)将2a)中得到的初步特征输入到第j个源域特有的域特定子网络Fj中,得到原始图像的最终特征Fj(F(xq;θF);θFj),其中θFj表示Fj的网络参数,xq表示输入域特定子网络的第q个样本,
Figure BDA0002532171300000027
(3)将(2)中得到的最终特征输入到域判别器Dj中,得到输出Dj(Fj(F(xi;θF);θFj);θDj),利用该输出计算得到Dj的域判别损失LjDFFjDj),其中θDj表示Dj的网络参数;
(4)将(2)中得到的最终特征输入到分类器Cj中,得到不同的输出:
对于来自源域j的图像,只有源域分类器Cj被激活,输出
Figure BDA0002532171300000031
利用其输出计算得到Cj的分类损失LjCFFjCj),其中θCj表示Cj的网络参数;
对于来自目标域的图像,所有的分类器都被激活,输出N个P维预测向量,取每个P维向量中最大元素对应的类别标记,即可得到N个伪标记,其中,P表示目标域数据的类别总数;
(5)利用目标域样本的伪标记与源域j中的样本标记,计算源域j与目标域中同类别数据的最大均值差异MMD距离,并对所有类别的MMD距离求和得到
Figure BDA0002532171300000032
(6)根据域判别损失
Figure BDA0002532171300000033
分类损失
Figure BDA0002532171300000034
及所有类别的MMD距离之和
Figure BDA0002532171300000035
对特征提取网络、域判别器及分类器进行训练,得到训练后的特征提取网络、域判别器及分类器;
(7)将待测样本输入到训练后的特征提取网络、域判别器及分类器中,通过特征提取网络从待测样本中提取图像特征,并将该特征作为域判别器及分类器的输入进行域判别及分类,最终得到该待测样本的N个P维预测向量;
(8)计算每个P维预测向量的熵,并利用该熵值计算得到目标域样本的最终类别标记。
本发明与现有方法相比具有如下优点:
第一,本发明建立了由域共享子网络F及域特定子网络Fj构成的特征提取网络,通过域共享子网络提取各个域共有的数据特征,通过域特定子网络提取各个域特有的数据特征,使得提取到的最终特征保留了各个源域的特性。
第二,本发明通过最小化MMD距离使源域数据与目标域数据在整体分布对齐的基础上实现了条件分布对齐,提高了目标域数据的分类准确率。
第三,本发明通过熵值融合各个源域分类器的分类结果,提高了目标域数据的分类正确率。
附图说明
图1为本发明的实现流程图;
图2为本发明中训练及测试实验时使用的Office-31数据集部分示例图;
图3为本发明中训练及测试实验时使用的Office-Caltech10数据集部分示例图;
图4为本发明中训练及测试实验时使用的Office-Home数据集部分示例图。
具体实施方式
下面结合附图,对本发明的实施例和效果做进一步的详细描述。
参照图1,其中,S1...SN表示N个源域,T表示目标域,F及F1...FN分别表示域共享子网络及N个域特定子网络,D1...DN表示N个域判别器,C1...CN表示N个分类器,具体实现步骤如下:
步骤1,建立由域共享子网络F与域特定子网络Fj构成的特征提取网络。
域共享子网络F是由何恺明等人提出的残差神经网络ResNet50,该网络由卷积层后接4个残差块构成,旨在提取所有域共享的底层特征;
域特定子网络共有N个,每个子网络是由卷积层、批标准化层及relu激活函数构成的多层神经网络,该网络旨在提取与特定域相关的高层特征;
将域特定子网络Fj与域共享子网络F相连接,构成特征提取网络。
步骤2,使用特征提取网络从原始图像文件中提取图像特征。
所述原始图像文件,分别来自目标域及N个源域,其图像特征提取如下:
2.1)原始图像文件,首先经过域共享子网络F,提取得到图像初步特征,即:
对于来自源域j的第i个训练样本
Figure BDA0002532171300000041
经过域共享子网络F,得到初步特征
Figure BDA0002532171300000042
其中θF表示F的网络参数,j=1...N,
Figure BDA0002532171300000043
N表示源域个数,
Figure BDA0002532171300000044
表示源域j中样本的数目;
对于来自目标域的第t个样本
Figure BDA0002532171300000045
经过域共享子网络F,得到初步特征
Figure BDA0002532171300000046
其中t=1...nT,nT表示目标域中样本的数目;
2.2)将2.1)中得到的两类初步特征均输入到第j个源域特有的域特定子网络Fj中,得到原始图像的最终特征Fj(F(xq;θF);θFj),其中θFj表示Fj的网络参数,xq表示输入域特定子网络的第q个样本,
Figure BDA0002532171300000051
步骤3,利用最终特征获得域判别损失
Figure BDA0002532171300000052
3.1)将步骤2中得到的最终特征输入到域判别器Dj中,得到输出Dj(Fj(F(xi;θF);θFj);θDj);
所述域判别器共有N个,每个域判别器均由全连接层构成,其中判别器Dj用于区分样本来自源域j还是目标域,设定源域数据的域标签为0,目标域数据的域标签为1,对于源域数据,输出Dj(Fj(F(xi;θF);θFj);θDj)为0,对于目标域数据,输出Dj(Fj(F(xi;θF);θFj);θDj)为1;
3.2)利用输出Dj(Fj(F(xi;θF);θFj);θDj)计算得到Dj的域判别损失
Figure BDA0002532171300000053
Figure BDA0002532171300000054
其中,
Figure BDA0002532171300000055
代表第j个源域的样本数量,nT代表目标域的样本数量,dq表示样本xq的域标签;
Figure BDA0002532171300000059
是一个指示函数,当dq=l时,该指示函数取值为1,否则,为0。
步骤4,利用最终特征获取源域样本分类损失及目标域样本的伪标记。
4.1)将步骤2中得到的最终特征输入到分类器Cj中,得到不同的输出:
所述分类器共有N个,每个分类器均由全连接层后接softmax函数构成;
对于源域j中的图像,对于来自源域j的图像,只有源域分类器Cj被激活,输出P维预测向量
Figure BDA0002532171300000056
其中,θCj表示Cj的网络参数,P表示目标域数据的类别总数;
对于来自目标域的图像,所有的分类器都被激活,输出N个P维预测向量,取每个P维向量中最大元素对应的类别标记,即可得到N个伪标记;
4.2)利用输出
Figure BDA0002532171300000057
计算得到Cj的分类损失
Figure BDA0002532171300000058
Figure BDA0002532171300000061
其中,
Figure BDA0002532171300000062
代表第j个源域的样本数量,k为样本类别标签,
Figure BDA0002532171300000063
表示源域样本
Figure BDA0002532171300000064
的类别标签,P是类别总数;
Figure BDA00025321713000000617
是一个指示函数,当
Figure BDA0002532171300000065
时该指示函数值取1,否则取0。
步骤5,利用目标域样本的伪标记与源域j中的样本标记,计算源域j与目标域中所有类别的MMD距离之和。
5.1)计算源域j与目标域中同类别数据的最大均值差异MMD距离:
Figure BDA0002532171300000066
其中,
Figure BDA0002532171300000067
为源域j中标记为k的样本数据与目标域中伪标记为k的样本数据的最大均值差异,
Figure BDA0002532171300000068
是源域j中类别为k的第i个样本,
Figure BDA0002532171300000069
为该源域j中类别为k的样本数目,
Figure BDA00025321713000000610
为目标域中类别为k的第m个样本,
Figure BDA00025321713000000611
为目标域中类别伪标记为k的样本数目;
5.2)计算得到所有类别的MMD距离之和
Figure BDA00025321713000000612
Figure BDA00025321713000000613
其中,P表示类别总数。
步骤6,对特征提取网络、域判别器及分类器进行训练。
根据域判别损失
Figure BDA00025321713000000614
分类损失
Figure BDA00025321713000000615
及所有类别的MMD距离之和
Figure BDA00025321713000000616
对网络参数进行更新以完成对特征提取网络、域判别器及分类器的训练,实现如下:
6.1)将域判别损失
Figure BDA0002532171300000071
反传到特征提取网络及域判别器中,通过最大化
Figure BDA0002532171300000072
更新特征提取网络中的域共享子网络参数θF和域特定子网络参数θFj,同时通过最小化
Figure BDA0002532171300000073
更新域判别器Dj的网络参数θDj,使得特征提取网络与域判别器产生对抗,此时特征提取网络能够提取到具有域不变特性的样本数据特征;
6.2)将分类损失
Figure BDA0002532171300000074
反传到特征提取网络及分类器中,通过最小化
Figure BDA0002532171300000075
更新特征提取网络中的域共享子网络参数θF、域特定子网络参数θFj及分类器的网络参数θCj
6.3)将所有类别的MMD距离之和
Figure BDA0002532171300000076
反传到特征提取网络中,通过最小化
Figure BDA0002532171300000077
更新特征提取网络中的域共享子网络参数θF和域特定子网络参数θFj,以对齐源域j中数据与目标域数据的条件分布。
步骤7,通过训练后的网络获得待测样本的N个P维预测向量。
7.1)将待测样本输入到训练后的特征提取网络、域判别器及分类器中,通过特征提取网络从待测样本中提取出图像特征;
7.2)将7.1)中所得图像特征作为域判别器及分类器的输入进行域判别及分类,最终得到该待测样本的N个P维预测向量。
步骤8,通过待测样本的N个P维预测向量获得待测样本的最终类别标记。
8.1)计算每个P维预测向量的熵,公式如下:
Figure BDA0002532171300000078
其中,Hj为测试样本经过分类器Cj所得预测向量的熵,
Figure BDA0002532171300000079
为测试样本
Figure BDA00025321713000000710
在分类器Cj上的输出标记,
Figure BDA00025321713000000711
是分类器Cj对第i个测试样本预测结果的第k个分量;
8.2)利用熵值计算得到待测样本的最终预测向量如下:
Figure BDA00025321713000000712
其中,
Figure BDA00025321713000000713
表示测试样本
Figure BDA00025321713000000714
的最终预测向量,θF *为更新后的域共享子网络参数,θFj *为更新后的域特定子网络参数,θCj *为更新后分类器的网络参数,ωj为测试样本经过分类器Cj所得预测结果的权重,ωj的计算方式为:
Figure BDA0002532171300000081
8.3)取最终预测向量
Figure BDA0002532171300000082
中最大元素对应的类别标记,该标记即为待测样本的最终类别标记,完成对待测样本的分类。
本发明的效果可通过以下实验做进一步说明。
一.实验条件
实验环境:本实验在集成Python环境的Anaconda下基于Pytorch搭建,算法逻辑与神经网络利用Python实现。
参数设置:初始域共享子网络F参数θF为ResNet50中的参数值,域特定子网络Fj的网络参数θFj,域判别器Dj的网络参数θDj,分类器Cj的网络参数θCj通过随机初始化得到。
实验数据选取及设置:本实验在Office-31、Office-Caltech10和Office-Home三个不同规模的公开数据集上评估本发明方法的分类性能。
所述Office-31是物体识别数据集,包括3个子集,分别为Amazon、Webcam和Dslr,这三个子集分布存在差异,其区别在于,Amazon数据集中的图像是直接从因特网中下载的,分辨率中等;Webcam是使用网络摄像头采集的图像,是低分辨率的;Dslr则是在实际环境中用数字相机采集的高分辨率图像,存在有噪声。这三个子集均包括31类图像,其中,Amazon数据集包括2817幅图像,Webcam数据集包括795幅图像,Dslr数据集包括498幅图像。该数据集部分示例如图2所示,其中前两列是Amazon数据集中的部分图像示例,中间两列是Dslr数据集中的部分示例,最后两列是Webcam中的部分图像。
所述Office-Caltech10数据集是由Office-31数据集和Caltech-256数据集中10个公共类别图像组成的数据集,包括四个子集,分别是Amazon、Webcam、Dslr和Caltech,可分别缩写为Ama、Web、Dsl、Cal,其中,Ama包含958幅图像,Web包含295幅图像,Dsl包含157幅图像,Cal包含1123幅图像。Office-Caltech10数据集中的部分图像如图3所示,图3中每一行代表不同的子集,从上往下依次是Ama,Calt,Dsl和Web。
所述Office-Home数据集包含4个不同子集,分别是Artistic images,Clip Art,Product images和Real-World images,可分别缩写为Art,Cli,Pro,Rea。每个子集均包含65个类别的图像,其中Art包含2427幅图像,Cli包含4365幅图像,Pro包含4439幅图像,Rea包含4357幅图像。Office-Home数据集中的部分图像示例如图4所示,图4中每一行代表不同的子集,从上往下依次是Art子集,Cli子集,Pro子集和Rea子集。
实验开始前首先进行简单的数据预处理,原始数据集里的图像数据大小不一,本实验中将所有图像的尺寸标准化为256*256,然后随机剪裁为224*224的图像块,特征提取网络的输入为224*224*3。
实验方法设置:实验中将本发明方法与已有方法进行比较,体现本发明方法的分类性能,已有方法包括以下6种:
1.深度域混淆方法DDC,
2.深度域适应方法DAN,
3.对抗域适应方法RevGrad,
4.多特征空间适应方法MFSAN,
5.矩匹配多源域适应方法M3SDA,
6.深度混合域适应方法DCTN。
二.实验内容
实验1:在Office-31数据集上采用本发明和现有的DDC、DAN、RevGrad、DCTN、MFSAN方法分别进行三个迁移任务上的图像分类实验,结果如表1所示:
表1在Office-31数据集上的实验结果
Figure BDA0002532171300000091
表1中Amazon、Dslr→Webcam表示以Amazon、Dslr为源域,以Webcam为目标域进行实验,Average表示各分类方法在三个任务上的平均分类准确率。
实验2:在Office-Caltech10数据集上采用本发明方法和现有的DDC、DAN、DCTN、M3SDA方法分别进行四个迁移任务上的图像分类实验,结果如表2所示:
表2在Office-Caltech 10数据集上的实验结果
Figure BDA0002532171300000101
表2中Ama、Web、Dsl→Cal表示以Ama、Web、Dsl为源域,以Cal为目标域进行实验,Average表示各分类方法在四个任务上的平均分类准确率。
实验3:在Office-Home数据集上采用本发明和现有DDC、DAN、RevGrad、M3SDA方法分别在四个迁移任务上进行图像分类实验,结果如表3所示:
表3在Office-Home数据集上的实验结果
Figure BDA0002532171300000102
表3中Art、Cli、Pro→Rea表示以Art、Cli、Pro为源域,以Rea为目标域进行实验,Average表示各分类方法在四个任务上的平均分类准确率。
上述实验1,实验2,实验3的结果表明,本发明方法在Office-31、Office-Caltech10和Office-Home这三个数据集中各个迁移任务上的分类准确率均优于现有方法。验证了本发明方法在进行图像分类时由于考虑到各个域特有的数据特征、源域数据和目标域数据的条件分布及利用熵值融合各个源域分类器的分类结果,因而有助于提高目标域数据的分类准确率。

Claims (10)

1.一种基于对抗融合多源迁移学习的图像分类方法,其特征在于,包括如下:
(1)建立由域共享子网络F与域特定子网络Fj构成的特征提取网络;
(2)使用特征提取网络从原始图像文件中提取图像特征:
2a)对于来自源域j的第i个训练样本
Figure FDA0002532171290000011
经过域共享子网络F,得到初步特征
Figure FDA0002532171290000012
其中θF表示F的网络参数,j=1...N,
Figure FDA0002532171290000013
N表示源域个数,
Figure FDA0002532171290000014
表示源域j中样本的数目;
对于来自目标域的第t个样本
Figure FDA0002532171290000015
经过域共享子网络F,得到初步特征
Figure FDA0002532171290000016
其中t=1...nT,nT表示目标域中样本的数目;
2b)将2a)中得到的初步特征输入到第j个源域特有的域特定子网络Fj中,得到原始图像的最终特征Fj(F(xq;θF);θFj),其中θFj表示Fj的网络参数,xq表示输入域特定子网络的第q个样本,
Figure FDA0002532171290000017
(3)将(2)中得到的最终特征输入到域判别器Dj中,得到输出Dj(Fj(F(xi;θF);θFj);θDj),利用该输出计算得到Dj的域判别损失
Figure FDA0002532171290000018
其中θDj表示Dj的网络参数;
(4)将(2)中得到的最终特征输入到分类器Cj中,得到不同的输出:
对于来自源域j的图像,只有源域分类器Cj被激活,输出
Figure FDA0002532171290000019
利用其输出计算得到Cj的分类损失
Figure FDA00025321712900000110
其中θCj表示Cj的网络参数;
对于来自目标域的图像,所有的分类器都被激活,输出N个P维预测向量,取每个P维向量中最大元素对应的类别标记,即可得到N个伪标记,其中,P表示目标域数据的类别总数;
(5)利用目标域样本的伪标记与源域j中的样本标记,计算源域j与目标域中同类别数据的最大均值差异MMD距离,并对所有类别的MMD距离求和得到
Figure FDA0002532171290000021
(6)根据域判别损失
Figure FDA0002532171290000022
分类损失
Figure FDA0002532171290000023
及所有类别的MMD距离之和
Figure FDA0002532171290000024
对特征提取网络、域判别器及分类器进行训练,得到训练后的特征提取网络、域判别器及分类器;
(7)将待测样本输入到训练后的特征提取网络、域判别器及分类器中,通过特征提取网络从待测样本中提取图像特征,并将该特征作为域判别器及分类器的输入进行域判别及分类,最终得到该待测样本的N个P维预测向量;
(8)计算每个P维预测向量的熵,并利用该熵值计算得到目标域样本的最终类别标记。
2.根据权利要求1所述的方法,其特征在于:(1)中的域共享子网络F是由卷积层后接4个残差块构成的残差神经网络。
3.根据权利要求1所述的方法,其特征在于:(1)中的域特定子网络共有N个,每个子网络是由卷积层、批标准化层及relu激活函数构成的多层神经网络。
4.根据权利要求1所述的方法,其特征在于:(3)中域判别器共有N个,每个域判别器均由全连接层构成。
5.根据权利要求1所述的方法,其特征在于:(3)中的域判别损失函数
Figure FDA0002532171290000025
表示如下:
Figure FDA0002532171290000031
其中,
Figure FDA0002532171290000032
代表第j个源域的样本数量,nT代表目标域的样本数量,dq表示样本xq的域标签;
Figure FDA00025321712900000311
是一个指示函数,当dq=l时,该指示函数取值为1,否则,为0。
6.根据权利要求1所述的方法,其特征在于:(4)中分类器共有N个,每个分类器均由全连接层后接softmax函数构成,其输出为P维预测向量。
7.根据权利要求1所述的方法,其特征在于:(4)中的分类损失函数
Figure FDA0002532171290000033
表示如下:
Figure FDA0002532171290000034
其中,
Figure FDA0002532171290000035
代表第j个源域的样本数量,k为样本类别标签,
Figure FDA0002532171290000036
表示源域样本
Figure FDA0002532171290000037
的类别标签,P是类别总数;
Figure FDA00025321712900000312
是一个指示函数,当
Figure FDA0002532171290000038
时该指示函数值取1,否则取0。
8.根据权利要求1所述的方法,其特征在于:(5)中所有类别的MMD距离之和
Figure FDA0002532171290000039
表示如下:
Figure FDA00025321712900000310
其中,P是类别总数,
Figure FDA0002532171290000041
是源域j中类别为k的第i个样本,
Figure FDA0002532171290000042
为该源域j中类别为k的样本数目;同理,
Figure FDA0002532171290000043
为目标域中类别为k的第m个样本,
Figure FDA0002532171290000044
为目标域中类别伪标记为k的样本数目。
9.根据权利要求1所述的方法,其特征在于:(6)中对特征提取网络、域判别器及分类器训练,实现如下:
6a)将域判别损失
Figure FDA0002532171290000045
反传到特征提取网络及域判别器中,通过最大化
Figure FDA0002532171290000046
更新特征提取网络中的域共享子网络参数θF和域特定子网络参数θFj,同时通过最小化
Figure FDA0002532171290000047
更新域判别器Dj的网络参数θDj
6b)将分类损失
Figure FDA0002532171290000048
反传到特征提取网络及分类器中,通过最小化
Figure FDA0002532171290000049
更新特征提取网络中的域共享子网络参数θF、域特定子网络参数θFj及分类器的网络参数θCj
6c)将所有类别的MMD距离之和
Figure FDA00025321712900000410
反传到特征提取网络中,通过最小化
Figure FDA00025321712900000411
更新特征提取网络中的域共享子网络参数θF和域特定子网络参数θFj
10.根据权利要求1所述的方法,其特征在于:(8)中通过计算各个P维预测向量的熵,并利用该熵值计算得到待测样本的最终类别标记,实现如下:
8a)P维预测向量的熵计算如下:
Figure FDA00025321712900000412
其中,Hj为待测样本经过分类器Cj所得预测向量的熵,
Figure FDA00025321712900000413
为测试样本
Figure FDA00025321712900000414
在分类器Cj上的输出标记,
Figure FDA00025321712900000415
是分类器Cj对第t个测试样本预测结果的第k个分量;
8b)利用熵值计算得到待测样本的最终预测向量如下:
Figure FDA0002532171290000051
其中,
Figure FDA0002532171290000052
表示测试样本
Figure FDA0002532171290000053
的最终预测向量,θF *为更新后的域共享子网络参数,θFj *为更新后的域特定子网络参数,θCj *为更新后分类器的网络参数,ωj为测试样本经过分类器Cj所得预测结果的权重,ωj的计算公式为:
Figure FDA0002532171290000054
8c)取最终预测向量
Figure FDA0002532171290000055
中最大元素对应的类别标记,该标记即为待测样本的最终类别标记。
CN202010521228.0A 2020-06-10 2020-06-10 基于对抗融合多源迁移学习的图像分类方法 Active CN111738315B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010521228.0A CN111738315B (zh) 2020-06-10 2020-06-10 基于对抗融合多源迁移学习的图像分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010521228.0A CN111738315B (zh) 2020-06-10 2020-06-10 基于对抗融合多源迁移学习的图像分类方法

Publications (2)

Publication Number Publication Date
CN111738315A CN111738315A (zh) 2020-10-02
CN111738315B true CN111738315B (zh) 2022-08-12

Family

ID=72648514

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010521228.0A Active CN111738315B (zh) 2020-06-10 2020-06-10 基于对抗融合多源迁移学习的图像分类方法

Country Status (1)

Country Link
CN (1) CN111738315B (zh)

Families Citing this family (19)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112330625B (zh) * 2020-11-03 2023-03-24 杭州迪英加科技有限公司 免疫组化核染色切片细胞定位多域共适应训练方法
CN112836795B (zh) * 2021-01-27 2023-08-18 西安理工大学 一种多源非均衡域自适应方法
CN113011487B (zh) * 2021-03-16 2022-11-18 华南理工大学 一种基于联合学习与知识迁移的开放集图像分类方法
CN113011513B (zh) * 2021-03-29 2023-03-24 华南理工大学 一种基于通用域自适应的图像大数据分类方法
CN113157678B (zh) * 2021-04-19 2022-03-15 中国人民解放军91977部队 一种多源异构数据关联方法
CN113076927B (zh) * 2021-04-25 2023-02-14 华南理工大学 基于多源域迁移的指静脉识别方法及系统
CN113361566B (zh) * 2021-05-17 2022-11-15 长春工业大学 用对抗性学习和判别性学习来迁移生成式对抗网络的方法
CN113378904B (zh) * 2021-06-01 2022-06-14 电子科技大学 一种基于对抗域自适应网络的图像分类方法
CN113378981B (zh) * 2021-07-02 2022-05-13 湖南大学 基于域适应的噪音场景图像分类方法及系统
CN113591736A (zh) * 2021-08-03 2021-11-02 北京百度网讯科技有限公司 特征提取网络、活体检测模型的训练方法和活体检测方法
CN113538413B (zh) * 2021-08-12 2023-11-24 泰康保险集团股份有限公司 图像检测方法及装置、电子设备和存储介质
CN113873024B (zh) * 2021-09-23 2022-09-23 中国科学院上海微系统与信息技术研究所 一种边缘雾网络中数据差分化下载方法
CN114020879B (zh) * 2022-01-04 2022-04-01 深圳佑驾创新科技有限公司 多源跨领域的文本情感分类网络的训练方法
CN114511737B (zh) * 2022-01-24 2022-09-09 北京建筑大学 图像识别域泛化模型的训练方法
CN114783072B (zh) * 2022-03-17 2022-12-30 哈尔滨工业大学(威海) 一种基于远域迁移学习的图像识别方法
CN114694150B (zh) * 2022-05-31 2022-10-21 成都考拉悠然科技有限公司 一种提升数字图像分类模型泛化能力的方法及系统
CN115578248B (zh) * 2022-11-28 2023-03-21 南京理工大学 一种基于风格引导的泛化增强图像分类算法
CN116758353B (zh) * 2023-06-20 2024-01-23 大连理工大学 基于域特定信息滤除的遥感图像目标分类方法
CN117152563B (zh) * 2023-10-16 2024-05-14 华南师范大学 混合目标域自适应模型的训练方法、装置及计算机设备

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10832166B2 (en) * 2016-12-20 2020-11-10 Conduent Business Services, Llc Method and system for text classification based on learning of transferable feature representations from a source domain
CN109753992B (zh) * 2018-12-10 2020-09-01 南京师范大学 基于条件生成对抗网络的无监督域适应图像分类方法
CN110135579A (zh) * 2019-04-08 2019-08-16 上海交通大学 基于对抗学习的无监督领域适应方法、系统及介质
CN110837850B (zh) * 2019-10-23 2022-06-21 浙江大学 一种基于对抗学习损失函数的无监督域适应方法

Also Published As

Publication number Publication date
CN111738315A (zh) 2020-10-02

Similar Documents

Publication Publication Date Title
CN111738315B (zh) 基于对抗融合多源迁移学习的图像分类方法
CN111368896B (zh) 基于密集残差三维卷积神经网络的高光谱遥感图像分类方法
CN109949317B (zh) 基于逐步对抗学习的半监督图像实例分割方法
Hao et al. An end-to-end architecture for class-incremental object detection with knowledge distillation
Carrara et al. Adversarial examples detection in features distance spaces
CN109583322B (zh) 一种人脸识别深度网络训练方法和系统
CN108021947B (zh) 一种基于视觉的分层极限学习机目标识别方法
CN113076994B (zh) 一种开集域自适应图像分类方法及系统
CN109063649B (zh) 基于孪生行人对齐残差网络的行人重识别方法
CN108846413B (zh) 一种基于全局语义一致网络的零样本学习方法
CN109344856B (zh) 一种基于多层判别式特征学习的脱机签名鉴别方法
CN109993201A (zh) 一种图像处理方法、装置和可读存储介质
Chen et al. Automated design of neural network architectures with reinforcement learning for detection of global manipulations
CN110598759A (zh) 一种基于多模态融合的生成对抗网络的零样本分类方法
CN108052959A (zh) 一种提高深度学习图片识别算法鲁棒性的方法
CN112232395B (zh) 一种基于联合训练生成对抗网络的半监督图像分类方法
CN110569780A (zh) 一种基于深度迁移学习的高精度人脸识别方法
CN114387473A (zh) 一种基于基类样本特征合成的小样本图像分类方法
CN115690541A (zh) 提高小样本、小目标识别准确率的深度学习训练方法
Xue et al. Region comparison network for interpretable few-shot image classification
CN110414626A (zh) 一种猪只品种识别方法、装置和计算机可读存储介质
Li et al. Adversarial domain adaptation via category transfer
CN112084897A (zh) 一种gs-ssd的交通大场景车辆目标快速检测方法
CN116452862A (zh) 基于领域泛化学习的图像分类方法
CN105512675A (zh) 一种基于记忆性多点交叉引力搜索的特征选择方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant