CN111368886B - 一种基于样本筛选的无标注车辆图片分类方法 - Google Patents

一种基于样本筛选的无标注车辆图片分类方法 Download PDF

Info

Publication number
CN111368886B
CN111368886B CN202010114792.0A CN202010114792A CN111368886B CN 111368886 B CN111368886 B CN 111368886B CN 202010114792 A CN202010114792 A CN 202010114792A CN 111368886 B CN111368886 B CN 111368886B
Authority
CN
China
Prior art keywords
network
domain
data set
model
feature
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010114792.0A
Other languages
English (en)
Other versions
CN111368886A (zh
Inventor
贺海
徐雪妙
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
South China University of Technology SCUT
Original Assignee
South China University of Technology SCUT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by South China University of Technology SCUT filed Critical South China University of Technology SCUT
Priority to CN202010114792.0A priority Critical patent/CN111368886B/zh
Publication of CN111368886A publication Critical patent/CN111368886A/zh
Application granted granted Critical
Publication of CN111368886B publication Critical patent/CN111368886B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V2201/00Indexing scheme relating to image or video recognition or understanding
    • G06V2201/08Detecting or categorising vehicles
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于样本筛选的无标注车辆图片分类方法,包括步骤:1)数据获取;2)数据处理;3)模型构建;4)定义损失函数;5)模型训练;6)模型验证;7)模型应用。本发明减缓了现有车辆图片分类技术数据匮乏的缺点,通过结合特征提取网络提取图像高层语义信息的能力,对抗领域自适应网络对齐拉近两个域数据分布的能力,样本筛选损失函数从特征级别和标注级别筛选重要样本和异常样本并选择性增强的能力,以及通用分类器网络的精准分类能力,更准确高效地完成无标注车辆图片分类任务。

Description

一种基于样本筛选的无标注车辆图片分类方法
技术领域
本发明涉及计算机图像处理的技术领域,尤其是指一种基于样本筛选的无标注车辆图片分类方法。
背景技术
随着现代化进程的发展和国民消费水平的不断提高,交通车辆的数目日益增长,对于车辆的实时监控和管理仅靠人力过于繁琐,而借助计算机和深度学习完成图片分类和分析为智能交通管理带来了新的发展。
在计算机图像分类领域,一个性能优异的深度学习分类模型往往是海量数据驱动的。但在某些特定场景下,比如城市交通车辆数据或者高速公路车辆数据,是需要通过专门的相关部门获取并且需要有经验的人士标注的;同时交通路况错综复杂,在一个例如城市道路场景下标注的数据集训练好了一个模型,应用到高速公路上做车辆分类性能却大打折扣。为了节约目标数据集(目标域,比如高速公路)的标注成本,常见的做法是,借助一个不同但相关的有标注的数据集(源域,比如城市道路),通过一个深度神经网络分类模型统一地拉近源域和目标域的数据分布来抽取这两个域的域不变特征,从而在目标域没有标注的情况下,将从源域学到的知识迁移到目标域。这种方法称为领域自适应分类。但这种基于度量学习的方法,其核心在于估计域的真实分布。不幸的是,一些非预期的噪声样本可能会严重影响源域和目标域数据分布的估计,比如糟糕的成像条件和错误标注。如果对所有样本平等对待,这些异常样本在作全局统计数据分布时造成的影响,很容易使得源域和目标域数据分布在拉近过程中造成错位或者负迁移,影响分类模型的泛化能力。
发明内容
本发明的目的在于克服现有车辆图片分类技术数据匮乏的缺点与不足,提出了一种基于样本筛选的无标注车辆图片分类方法,该方法可以很好的区分数据集中的异常样本和主流样本,减缓了之前方法统计样本分布时候造成的偏差,通过拉近被正确估计的数据分布而学习到更有效的域不变特征。
为实现上述目的,本发明所提供的技术方案为:一种基于样本筛选的无标注车辆图片分类方法,包括以下步骤:
1)数据获取
鉴于测试目标域数据集的标注成本昂贵,即高速公路的目标域数据集的标注成本昂贵,需要借助一个不同但相关的有标注的源域数据集,即城市道路的源域数据集,其中,高速公路的目标域数据集必须和城市道路的源域数据集中有相同类别的待分类的车辆,但是两个域车辆的角度和拍摄环境有区别;然后划分目标域数据集为训练数据集、验证数据集和测试数据集,源域数据集全为训练数据集;
2)数据处理
将源域数据集的图像、域标注和类别标注数据及目标域数据集的图像和域标注,通过预处理转化为训练领域自适应分类模型所需要的格式,然后成对地输入到车辆图片分类网络模型中;
3)模型构建
根据训练目标以及模型的输入输出形式,构造一个能够学习域不变特征的对抗深度神经网络模型,其由特征提取网络、对抗领域自适应网络和通用分类网络组成;
4)定义损失函数
根据训练目标以及模型的架构,除了必需的分类和域对抗度量损失函数,额外提出了特征层面和标注层面两个按样本重要性增强损失函数;
5)模型训练
初始化模型各网络层的参数,不断迭代输入成对的源域和目标域训练样本,根据损失函数计算得到模型各网络层的损失值,再通过反向传播计算出各网络层参数的梯度,通过随机梯度下降法对各网络层的参数进行更新;
6)模型验证
使用目标域数据集中的验证数据集对训练得到的模型进行验证,测试模型的泛化性能,调整超参数;
7)模型应用
使用目标域数据集中的测试数据集测试训练得到的模型,并应用到目标域车辆图片的分类任务中。
所述步骤2)包括以下步骤:
2.1)将源域和目标域数据集中的图像缩放到长和宽为256×256像素大小;
2.2)在缩放后的图像上,随机裁剪得到224×224像素大小的矩形图像;
2.3)以0.5的概率随机水平翻转裁剪后的图像;
2.4)将随机翻转后的图像从[0,255]转换到[-1,1]的范围内;
2.5)将源域数据集中的类别标注数据转换为One-Hot向量,源域和目标域的域标注分别设为1和0。
所述步骤3)包括以下步骤:
3.1)构造特征提取网络
特征提取网络相当于一个编码器,将步骤2)处理后的图像输入网络,能够提取其高层的域不变的语义信息并输出为一个低维的特征向量;特征提取网络是由一系列的残差模块、全连接层、批量归一化层、非线性激活层以及随机失活层级联而成,残差模块是由卷积层、批量归一化层、非线性激活层、池化层级联而成;残差模块能够防止梯度消失,提高网络学习能力;全连接层能够统筹全局信息;批量归一化层能够归一化特征,加速网络收敛;非线性激活层能够带来更多的非线性;随机失活层能够稀疏网络,防止过拟合,其随机失活概率为0.5;
3.2)构造对抗领域自适应网络
对抗领域自适应网络主要负责拉近源域数据高维特征分布和目标域数据高维特征分布之间的距离,迫使特征提取网络学习到两个域共有的域不变特征;对抗领域自适应网络主要由三层神经元块级联而成,而该神经元块由全连接层、批量归一化层、非线性激活层以及随机失活层构成;
对抗领域自适应网络的输入为特征提取网络所编码的低维向量,输出为预测出的域标注;如果特征向量来自源域数据集,则预期预测为1;如果特征向量来自目标域数据集,则预期预测为0;但是当网络在进行反向转播的时候,梯度通过对抗领域自适应网络而准备传播到特征提取网络,中间会经过一个梯度翻转层,该层能够将通过它的梯度取反,这样特征提取网络误认为自己抽取到了错误的特征,实现两个网络对抗更新,对抗学习迫使源域数据特征和目标域数据特征映射至同一隐空间,使得对抗领域自适应网络无法分辨出特征来自哪个域;
3.3)构造通用分类网络
通用分类网络主要负责对特征提取网络输出的低维向量作类别分类,其由一层全连接层构成,将低维向量经过矩阵运算得到一个类别长度的向量,该向量数值最大的位置所对应的类别,即为预测的类别。
所述步骤4)包括以下步骤:
4.1)定义特征级调控损失函数
基于模型压缩领域的特征模长越小则信息量越小的原则,能够推测出特征提取网络的输出特征模长,反映了样本的重要程度;对于重要的样本,在对抗领域自适应网络中要被相应地增强,即特征提取网络中特征模越大,在对抗领域自适应网络中特征模长也相应越大,反之对于异常值在特征提取网络中特征模长偏小,这样在作对抗训练的时候其对抗领域自适应网络特征模长也会相应越小,因此,为了实现样本筛选和对抗特征选择性增强目的,特征级调控损失函数能够定义为对抗特征模长和提取特征模长的最小二乘损失(Least Square Loss),公式如下所示:
Figure BDA0002391144950000051
式中,n表示样本总数,α表示缩减因子,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,fd(xi)表示对抗领域自适应网络输出的特征向量,fg(xi)表示特征提取网络输出的特征向量;
4.2)定义标注级调控损失函数
除了在特征级筛选样本,标注级的样本筛选依然值得考虑,熵最小化原则表明,分类器类别预测结果的熵能够有效帮助低密度类别间的分离,基于此能够推测出,对于重要的样本都集中在输出向量的熵小的地方,而异常值的熵比较大,即难以被分类器区分;为了实现对抗性域对齐过程中真实数据分布的预测,定义标注级调控损失函数来减弱异常值的影响,即通用分类器预测结果的熵作为筛选指标,通用分类器输出向量的熵越小,越不可能是异常值,对抗领域自适应网络在预测源域和目标域输出向量的熵也应该越小,基于此,标注级调控损失函数定义为通用分类网络输出向量的熵作为权重,加权到对抗领域自适应网络输出向量的熵的最小化过程中,公式如下所示:
Figure BDA0002391144950000061
式中:n表示样本总数,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,H(g)=-∑glog(g)表示预测类别向量g的熵,即H(pd(xi))表示对抗领域自适应网络输出向量的熵,H(pg(xi))表示通用分类网络输出向量的熵;
4.3)定义对抗领域自适应损失函数
领域自适应的目的是希望将源域和目标域数据映射到同一个高维语义空间中,通过在该空间中将两个域数据分布对齐拉近;对抗领域自适应使用对抗学习的方式,设置一个对抗领域自适应网络和一个特征提取网络,对抗领域自适应网络尽可能分辨样本的输入是否来自源域,而特征提取网络尽可能去欺骗对抗领域自适应网络,通过两者的博弈来增强特征提取网络抽取域不变特征的能力,这样的对抗方式已被证明为最小化源域和目标域特征之间的相对熵距离,基于此,定义对抗领域自适应损失函数为:
Figure BDA0002391144950000062
式中,ns和nt分别表示源域和目标域样本数,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,pd(xi)表示对抗领域自适应网络的输出向量;
4.4)定义通用分类网络的损失函数
通用分类网络将源域数据经过特征提取网络提取到特征作为输入,输出类别数量的一维向量,该向量用于与该输入对应的One-Hot标注作交叉熵,其分类任务损失函数定义为交叉熵损失,公式如下:
Figure BDA0002391144950000071
式中,ns表示源域样本数,Ds表示源域数据集,xi表示输入的来自源域的图片数据,pg(xi)表示通用分类网络的输出向量;yi表示标注的One-Hot向量;
4.5)定义总损失函数
步骤4.1)和步骤4.2)中的两个样本筛选调控损失函数搭配对抗领域自适应损失函数,能够实现源域和目标域真实分布之间的拉近对齐,之后搭配上交叉熵分类损失函数使得网络具有类别鉴别能力,总损失函数定义为:
Figure BDA0002391144950000072
式中,w1和w2分别为用来权衡特征级调控损失和标注级调控损失的参数。
所述步骤5)包括以下步骤:
5.1)初始化模型各层参数
各层参数的初始化采用的是深度卷积神经网络中使用到的方法,具体是:对于特征提取网络参数采用在ImageNet数据集上预训练好的ResNet-50网络模型参数作为初始值;对于通用分类器和对抗领域自适应网络中的全连接层采用均值为0,标准差为0.02的高斯分布进行初始化;对所有的批量归一化层参数采用均值为1,标准差为0.02的高斯分布进行初始化;
5.2)训练模型
随机将步骤2)处理过的成对图像处理,经过特征提取网络得到相应的低维特征向量,该部分特征向量划分出源域部分经过通用分类网络,再通过计算源域数据的分类损失值,同时该部分的源域和目标域特征向量也会经过对抗领域自适应网络预测对应的特征向量来自源域还是目标域,并通过分别计算相应的特征级调控损失值和标注级调控损失值,通过反向传播该误差值,计算各个网络的各层参数的梯度,再通过随机梯度下降算法根据梯度对各层参数进行优化,实现每轮网络的训练;
5.3)重复步骤5.2)直至模型能够对目标域数据集中的测试集数据鲁棒地分类。
在步骤6)中,随机从目标域数据集中的验证数据集中取出一些原始图像,经过步骤2)处理后,输入到步骤5)训练好的网络模型,让该网络模型去预测其类别,通过输出的结果与对应的标注数据进行比对,从而判断该训练好的网络模型对目标域数据的泛化能力,并对网络超参数进行调整。
在步骤7)中,随机从目标域数据集中的测试数据集取出一些原始图像,经过步骤2)处理后,输入到步骤5)训练好的网络模型,让该网络模型去预测其类别,而后再应用到目标域车辆图片的分类任务中。
本发明与现有技术相比,具有如下优点与有益效果:
1、提出了一种按重要性样本筛选机制,在特征提取网络和通用分类网络的指导下优化对抗领域自适应网络中每个样本的训练梯度,为此,引入了特征级调控损失和标注级调控损失来按样本重要程度选择性增强网络学习能力。
2、研究了数据采样和异常值处理问题,并证明了特征模长和输出向量的熵在指示数据重要性方面的有效性,即特征模长越长,或者输出向量的熵越小,越不可能是异常值。
3、本发明方法不仅在标准领域自适应条件下达到最优,而且在部分领域自适应上都优于最新结果,此外,该方法简单易行,只需几行代码即可实现,且所提出的按重要性采样机制不会引入任何其它参数。
附图说明
图1为本发明方法流程图。
图2为本发明整体网络示意图。
图3为对抗领域自适应网络示意图。
图4为通用分类网络示意图。
具体实施方式
下面结合具体实施例对本发明作进一步说明。
如图1所示,本实施例所提供的基于样本筛选的无标注车辆图片分类方法方法,其具体情况如下:
步骤1,从两个不同场景下获取同类别集合的两个图片数据集,对其中一个数据量大的源域数据集进行人工标注,另外一个目标域数据集不标注,由于测试目标域数据集高昂的标注成本,因此需要借助一个不同但相关的有标注的大规模的源域数据集,比如高速公路目标域数据集必须和城市道路源域数据集中有相同类别的待分类的车辆,但是两个域车辆的角度和拍摄环境有较大区别。然后划分目标域数据集为训练数据集、验证数据集和测试数据集,源域数据集全为训练数据集。
步骤2,将两个域的图像数据集的图像和标注数据通过预处理转化为训练深度对抗领域自适应分类网络所需要的格式,包括以下步骤:
步骤2.1,将源域和目标域数据集中的图像缩放到长和宽为256×256像素大小;
步骤2.2,在缩放后的图像上,随机裁剪得到224×224像素大小的矩形图像;
步骤2.3,以0.5的概率随机水平翻转裁剪后的图像;
步骤2.4,将随机翻转后的图像从[0,255]转换到[-1,1]的范围内;
步骤2.5,将源域数据集中的类别标注数据转换为One-Hot向量,源域和目标域的域标注分别设为1和0。
步骤3,根据训练目标以及模型的输入输出形式,构造一个可学习域不变特征的对抗深度神经网络模型,如图2所示,由特征提取网络、对抗领域自适应网络和通用分类网络组成,包括以下步骤:
步骤3.1,构造特征提取网络。特征提取网络相当于一个编码器,将步骤2)处理后的图像输入网络,能够提取其高层的域不变的语义信息并输出为一个低维的特征向量。特征提取网络的输入为3×224×224的图像,输出为一系列低维编码特征向量(1024×1×1)。该网络由一系列的残差模块、全连接层、批量归一化层、非线性激活层以及随机失活层级联而成,其结构与Resnet-50网络层一致。输入图像首先经过Resnet-50的16个残差模块得到2048×1×1特征向量,然后输入到一层全连接模块(全连接层、批量归一化层、非线性激活层、随机失活层)编码得到1024×1×1大小的特征向量。残差模块是由卷积层、批量归一化层、非线性激活层、池化层级联而成,残差模块能够防止梯度消失,提高网络学习能力;全连接层能够统筹图像全局信息;批量归一化层能够归一化特征,加速网络收敛;非线性激活层能够带来更多的非线性;随机失活层能够稀疏网络,防止过拟合,其随机失活概率为0.5;
步骤3.2,构造对抗领域自适应网络。对抗领域自适应网络主要负责拉近源域数据高维特征分布和目标域数据高维特征分布之间的距离,迫使特征提取网络学习到两个域共有的域不变特征;对抗领域自适应网络的输入为特征提取网络输出的1024×1×1的特征向量,输出为2×1×1的领域二分类预测向量。该网络包括3个串联的全连接模块(全连接层、批量归一化层、非线性激活层、随机失活层),如图3所示。如果特征向量来自源域数据集,则预期预测为1;如果特征向量来自目标域数据集,则预期预测为0;通过二元交叉熵损失函数计算所得到的初始梯度,在进行反向转播的时候,经过对抗领域自适应网络是正常反向更新对抗领域自适应网络参数。而准备传播到特征提取网络时,中间会经过一个梯度翻转层,该层能够将通过它的梯度取反,取反率设置为-0.3,将乘到梯度上。这样特征提取网络误认为自己抽取到了错误的特征,实现两个网络对抗更新,对抗学习迫使源域数据特征和目标域数据特征映射至同一隐空间,使得对抗领域自适应网络无法分辨出特征来自哪个域;
步骤3.3,构造通用分类网络。通用分类网络主要负责对特征提取网络输出的低维向量作类别分类。通用分类网络的输入同样为特征提取网络输出的1024×1×1的向量,输出为数据集类别数长度的一维向量。该网络只包含一层全连接层,如图4所示。该向量数值最大的位置所对应的类别,即为预测的类别;
步骤4,定义对抗领域自适应网络和通用分类网络的损失函数,包括以下步骤:
步骤4.1,定义特征级调控损失函数从特征模长层面筛选重要样本和异常值,选择性增强各个样本使得对抗领域自适应网络可以更好地学习数据的真实分布并拉近。基于模型压缩领域的特征模长越小则信息量越小的原则,能够推测出特征提取网络的输出特征模长,反映了样本的重要程度;对于重要的样本,在对抗领域自适应网络中要被相应地增强,即特征提取网络中特征模越大,在对抗领域自适应网络中特征模长也相应越大,反之对于异常值在特征提取网络中特征模长偏小,这样在作对抗训练的时候其对抗领域自适应网络特征模长也会相应越小,因此,为了实现样本筛选和对抗特征选择性增强目的,特征级调控损失函数能够定义为对抗特征模长和提取特征模长的最小二乘损失(Least SquareLoss),公式如下所示:
Figure BDA0002391144950000121
式中,n表示样本总数,α表示缩减因子,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,fd(xi)表示对抗领域自适应网络特征,fg(xi)表示特征提取网络特征;
步骤4.2,定义标注级调控损失函数从输出向量的熵层面筛选重要样本和异常值,选择性增强各个样本使得对抗领域自适应网络可以更好地学习数据的真实分布并拉近。熵最小化原则表明,分类器类别预测结果的熵能够有效帮助低密度类别间的分离,基于此能够推测出,对于重要的样本都集中在输出向量的熵小的地方,而异常值的熵比较大,即难以被分类器区分;为了实现对抗性域对齐过程中真实数据分布的预测,定义标注级调控损失函数来减弱异常值的影响,即通用分类器预测结果的熵作为筛选指标,通用分类器输出向量的熵越小,越不可能是异常值,对抗领域自适应网络在预测源域和目标域输出向量的熵也应该越小,基于此,标注级调控损失函数定义为通用分类网络输出向量的熵作为权重,加权到对抗领域自适应网络输出向量的熵的最小化过程中,公式如下所示:
Figure BDA0002391144950000122
式中:n表示样本总数,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,H(g)=-∑glog(g)表示预测类别向量g的熵,即H(pd(xi))表示对抗领域自适应网络输出向量的熵,H(pg(xi))表示通用分类网络输出向量的熵;
步骤4.3,定义对抗领域自适应网络的损失函数。定义对抗损失函数使对抗领域自适应网络可以尽可能地预测出输入特征向量来自源域还是目标域,使得特征提取网络尽可能提取到源域和目标域的域不变特征迷惑对抗领域自适应网络。对抗领域自适应使用对抗学习的方式,设置一个对抗领域自适应网络和一个特征提取网络,对抗领域自适应网络尽可能分辨样本的输入是否来自源域,而特征提取网络尽可能去欺骗对抗领域自适应网络,通过两者的博弈来增强特征提取网络抽取域不变特征的能力,这样的对抗方式已被证明为最小化源域和目标域特征之间的相对熵距离,基于此,定义对抗领域自适应损失函数为:
Figure BDA0002391144950000131
式中,ns和nt分别表示源域和目标域样本数,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,pd(xi)表示对抗领域自适应网络的输出向量;
步骤4.4,定义通用分类网络的损失函数。定义损失函数使输出的向量所预测的类别分数尽可能的与标注数据接近,类别数与数据集类别数一致。具体实现为该输出向量用于与该输入对应的One-Hot标注作交叉熵,其分类任务损失函数定义为交叉熵损失,公式如下:
Figure BDA0002391144950000132
式中,ns表示源域样本数,Ds表示源域数据集,xi表示输入的来自源域的图片数据,pg(xi)表示通用分类网络的输出向量;yi表示标注的One-Hot向量;
步骤4.5,定义总损失函数。步骤4.1)和步骤4.2)中的两个样本筛选调控损失函数搭配对抗领域自适应损失函数,能够实现源域和目标域真实分布之间的拉近对齐,之后搭配上交叉熵分类损失函数使得网络具有类别鉴别能力。对以上4个损失进行加权求和。用公式表示如下:
Figure BDA0002391144950000141
其中,L为总损失值,其中,w1和w2分别用来权衡标注级调控损失和标注级调控损失;
步骤5,训练网络模型,包括以下步骤:
步骤5.1,各层参数的初始化采用的是传统的深度卷积神经网络中使用到的方法,对于特征提取网络参数采用在ImageNet数据集上预训练好的ResNet-50网络模型参数作为初始值;对于通用分类器和对抗领域自适应网络中的全连接层采用均值为0,标准差为0.02的高斯分布进行初始化;对所有的批量归一化层参数采用均值为1,标准差为0.02的高斯分布进行初始化;
步骤5.2,随机将步骤2处理过的成对图像处理,经过步骤3.1的特征提取网络得到相应的低维特征向量,该部分特征向量划分出源域部分经过步骤3.3的通用分类网络,再通过步骤4.4计算源域数据的分类损失值;同时该部分的源域和目标域特征向量也会同时经过步骤3.2的对抗领域自适应网络预测对应的特征向量来自源域还是目标域,并通过步骤4.1分别计算相应的特征级调控损失值和步骤4.2分别计算相应的标注级调控损失值。通过反向传播该误差值,计算各个网络的各层参数的梯度,再通过随机梯度下降算法根据梯度对各层参数进行优化,实现每轮网络模型的训练;
步骤5.3,重复步骤5.2直到网络能够对目标域测试集数据鲁棒地分类;
步骤6,使用目标域验证数据集对训练得到的模型进行验证,调整网络超参数。
具体做法是随机从目标域验证数据集中取出一些原始图像,经过步骤2处理后,输入到步骤5训练好的网络模型,让该网络模型去预测其类别,通过输出的结果与对应的标注数据进行比对,从而判断该训练好的网络模型的对目标域数据的泛化能力,并对网络超参数进行调整。
步骤7,使用目标域测试数据集对训练得到的模型进行测试,其具体做法是随机从目标域测试数据集取出一些原始图像,经过步骤2)处理后,输入到步骤5)训练好的网络模型,让该网络模型去预测其类别,而后再应用到目标域车辆图片的分类任务中。
以上所述实施例只为本发明之较佳实施例,并非以此限制本发明的实施范围,故凡依本发明之形状、原理所作的变化,均应涵盖在本发明的保护范围内。

Claims (6)

1.一种基于样本筛选的无标注车辆图片分类方法,其特征在于,包括以下步骤:
1)数据获取
鉴于测试目标域数据集的标注成本昂贵,即高速公路的目标域数据集的标注成本昂贵,需要借助一个不同但相关的有标注的源域数据集,即城市道路的源域数据集,其中,高速公路的目标域数据集必须和城市道路的源域数据集中有相同类别的待分类的车辆,但是两个域车辆的角度和拍摄环境有区别;然后划分目标域数据集为训练数据集、验证数据集和测试数据集,源域数据集全为训练数据集;
2)数据处理
将源域数据集的图像、域标注和类别标注数据及目标域数据集的图像和域标注,通过预处理转化为训练车辆图片分类网络模型所需要的格式,然后成对地输入到车辆图片分类网络模型中;
3)模型构建
根据训练目标以及模型的输入输出形式,构造一个能够学习域不变特征的对抗深度神经网络模型,其由特征提取网络、对抗领域自适应网络和通用分类网络组成;
4)定义损失函数
根据训练目标以及模型的架构,除了必需的分类和域对抗度量损失函数,额外提出了特征层面和标注层面两个按样本重要性增强型损失函数,包括以下步骤:
4.1)定义特征级调控损失函数
基于模型压缩领域的特征模长越小则信息量越小的原则,能够推测出特征提取网络的输出特征模长,反映了样本的重要程度;对于重要的样本,在对抗领域自适应网络中要被相应地增强,即特征提取网络中特征模越大,在对抗领域自适应网络中特征模长也相应越大,反之对于异常值在特征提取网络中特征模长偏小,这样在作对抗训练的时候其对抗领域自适应网络特征模长也会相应越小,因此,为了实现样本筛选和对抗特征选择性增强目的,特征级调控损失函数能够定义为对抗特征模长和提取特征模长的最小二乘损失,公式如下所示:
Figure FDA0004000189840000021
式中,n表示样本总数,α表示缩减因子,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,fd(xi)表示对抗领域自适应网络输出特征向量,fg(xi)表示特征提取网络输出特征向量;
4.2)定义标注级调控损失函数
除了在特征级筛选样本,标注级的样本筛选依然值得考虑,熵最小化原则表明,分类器类别预测结果的熵能够有效帮助低密度类别间的分离,基于此能够推测出,对于重要的样本都集中在输出向量的熵小的地方,而异常值的熵大,即难以被分类器区分;为了实现对抗性域对齐过程中真实数据分布的预测,定义标注级调控损失函数来减弱异常值的影响,即通用分类器输出向量的熵作为筛选指标,通用分类器输出向量的熵越小,越不是异常值,对抗领域自适应网络在预测源域和目标域的输出向量的熵也应该越小,基于此,标注级调控损失函数定义为通用分类网络输出向量的熵作为权重,加权到对抗领域自适应网络输出的熵的最小化过程中,公式如下所示:
Figure FDA0004000189840000031
式中:n表示样本总数,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,H(g)=-∑glog(g)表示预测类别向量g的熵,即H(pd(xi))表示对抗领域自适应网络输出向量的熵,H(pg(xi))表示通用分类网络输出向量的熵;
4.3)定义对抗领域自适应损失函数
领域自适应的目的是将源域和目标域数据映射到同一个高维语义空间中,通过在该空间中将两个域数据分布对齐拉近;对抗领域自适应使用对抗学习的方式,设置一个对抗领域自适应网络和一个特征提取网络,对抗领域自适应网络分辨样本的输入是否来自源域,而特征提取网络去欺骗对抗领域自适应网络,通过两者的博弈来增强特征提取网络抽取域不变特征的能力,这样的对抗方式已被证明为最小化源域和目标域特征之间的相对熵距离,基于此,定义对抗领域自适应损失函数为:
Figure FDA0004000189840000032
式中,ns和nt分别表示源域和目标域样本数,Ds和Dt分别表示源域数据集和目标域数据集,xi表示输入的来自源域或目标域数据集的图片数据,pd(xi)表示对抗领域自适应网络的输出向量;
4.4)定义通用分类网络的损失函数
通用分类网络将源域数据经过特征提取网络提取到特征作为输入,输出类别数量的一维向量,该向量用于与该输入对应的One-Hot标注作交叉熵,其分类任务损失函数定义为交叉熵损失,公式如下:
Figure FDA0004000189840000041
式中,ns表示源域样本数,Ds表示源域数据集,xi表示输入的来自源域的图片数据,pg(xi)表示通用分类网络的输出向量;yi表示标注的One-Hot向量;
4.5)定义总损失函数
步骤4.1)和步骤4.2)中的两个样本筛选调控损失函数搭配对抗领域自适应损失函数,能够实现源域和目标域真实分布之间的拉近对齐,之后搭配上交叉熵分类损失函数使得网络具有类别鉴别能力,总损失函数定义为:
Figure FDA0004000189840000042
式中,w1和w2分别为用来权衡特征级调控损失和标注级调控损失的参数;
5)模型训练
初始化模型各网络层的参数,不断迭代输入成对的源域和目标域训练样本,根据损失函数计算得到模型各网络层的损失值,再通过反向传播计算出各网络层参数的梯度,通过随机梯度下降法对各网络层的参数进行更新;
6)模型验证
使用目标域数据集中的验证数据集对训练得到的模型进行验证,测试模型的泛化性能,调整超参数;
7)模型应用
使用目标域数据集中的测试数据集测试训练得到的模型,并应用到目标域车辆图片的分类任务中。
2.根据权利要求1所述的一种基于样本筛选的无标注车辆图片分类方法,其特征在于,所述步骤2)包括以下步骤:
2.1)将源域和目标域数据集中的图像缩放到长和宽为256×256像素大小;
2.2)在缩放后的图像上,随机裁剪得到224×224像素大小的矩形图像;
2.3)以0.5的概率随机水平翻转裁剪后的图像;
2.4)将随机翻转后的图像从[0,255]转换到[-1,1]的范围内;
2.5)将源域数据集中的类别标注数据转换为One-Hot向量,源域和目标域的域标注分别设为1和0。
3.根据权利要求1所述的一种基于样本筛选的无标注车辆图片分类方法,其特征在于,所述步骤3)包括以下步骤:
3.1)构造特征提取网络
特征提取网络相当于一个编码器,将步骤2)处理后的图像输入网络,能够提取其高层的域不变的语义信息并输出为一个低维的特征向量;特征提取网络是由一系列的残差模块、全连接层、批量归一化层、非线性激活层以及随机失活层级联而成,残差模块是由卷积层、批量归一化层、非线性激活层、池化层级联而成;残差模块能够防止梯度消失,提高网络学习能力;全连接层能够统筹全局信息;批量归一化层能够归一化特征,加速网络收敛;非线性激活层能够带来更多的非线性;随机失活层能够稀疏网络,防止过拟合,其随机失活概率为0.5;
3.2)构造对抗领域自适应网络
对抗领域自适应网络主要负责拉近源域数据高维特征分布和目标域数据高维特征分布之间的距离,迫使特征提取网络学习到两个域共有的域不变特征;对抗领域自适应网络由三层神经元块级联而成,而该神经元块由全连接层、批量归一化层、非线性激活层以及随机失活层构成;
对抗领域自适应网络的输入为特征提取网络所编码的低维向量,输出为预测出的域标注;如果特征向量来自源域数据集,则预期预测为1;如果特征向量来自目标域数据集,则预期预测为0;但是当网络在进行反向转播的时候,梯度通过对抗领域自适应网络而准备传播到特征提取网络,中间会经过一个梯度翻转层,该层能够将通过它的梯度取反,这样特征提取网络误认为自己抽取到了错误的特征,实现两个网络对抗更新,对抗学习迫使源域数据特征和目标域数据特征映射至同一隐空间,使得对抗领域自适应网络无法分辨出特征来自哪个域;
3.3)构造通用分类网络
通用分类网络主要负责对特征提取网络输出的低维向量作类别分类,其由一层全连接层构成,将低维向量经过矩阵运算得到一个类别长度的向量,该向量数值最大的位置所对应的类别,即为预测的类别。
4.根据权利要求1所述的一种基于样本筛选的无标注车辆图片分类方法,其特征在于,所述步骤5)包括以下步骤:
5.1)初始化模型各层参数
各层参数的初始化采用的是深度卷积神经网络中使用到的方法,具体是:对于特征提取网络参数采用在ImageNet数据集上预训练好的ResNet-50网络模型参数作为初始值;对于通用分类器和对抗领域自适应网络中的全连接层采用均值为0,标准差为0.02的高斯分布进行初始化;对所有的批量归一化层参数采用均值为1,标准差为0.02的高斯分布进行初始化;
5.2)训练模型
随机将步骤2)处理过的成对图像处理,经过特征提取网络得到相应的低维特征向量,该低维特征向量划分出源域部分经过通用分类网络,再通过计算源域数据的分类损失值,同时该部分的源域和目标域特征向量也会经过对抗领域自适应网络预测对应的特征向量来自源域还是目标域,并通过分别计算相应的特征级调控损失值和标注级调控损失值,通过反向传播误差值,计算各个网络的各层参数的梯度,再通过随机梯度下降算法根据梯度对各层参数进行优化,实现每轮网络的训练;
5.3)重复步骤5.2)直至模型能够对目标域数据集中的测试集数据鲁棒地分类。
5.根据权利要求1所述的一种基于样本筛选的无标注车辆图片分类方法,其特征在于:在步骤6)中,随机从目标域数据集中的验证数据集中取出一些原始图像,经过步骤2)处理后,输入到步骤5)训练好的网络模型,让该网络模型去预测其类别,通过输出的结果与对应的标注数据进行比对,从而判断该训练好的网络模型对目标域数据的泛化能力,并对网络超参数进行调整。
6.根据权利要求1所述的一种基于样本筛选的无标注车辆图片分类方法,其特征在于:在步骤7)中,随机从目标域数据集中的测试数据集取出一些原始图像,经过步骤2)处理后,输入到步骤5)训练好的网络模型,让该网络模型去预测其类别,而后再应用到目标域车辆图片的分类任务中。
CN202010114792.0A 2020-02-25 2020-02-25 一种基于样本筛选的无标注车辆图片分类方法 Active CN111368886B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010114792.0A CN111368886B (zh) 2020-02-25 2020-02-25 一种基于样本筛选的无标注车辆图片分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010114792.0A CN111368886B (zh) 2020-02-25 2020-02-25 一种基于样本筛选的无标注车辆图片分类方法

Publications (2)

Publication Number Publication Date
CN111368886A CN111368886A (zh) 2020-07-03
CN111368886B true CN111368886B (zh) 2023-03-21

Family

ID=71212088

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010114792.0A Active CN111368886B (zh) 2020-02-25 2020-02-25 一种基于样本筛选的无标注车辆图片分类方法

Country Status (1)

Country Link
CN (1) CN111368886B (zh)

Families Citing this family (18)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111860670B (zh) * 2020-07-28 2022-05-17 平安科技(深圳)有限公司 域自适应模型训练、图像检测方法、装置、设备及介质
CN112183663A (zh) * 2020-10-26 2021-01-05 北京达佳互联信息技术有限公司 一种图像分类方法、装置、电子设备及存储介质
CN113762005A (zh) * 2020-11-09 2021-12-07 北京沃东天骏信息技术有限公司 特征选择模型的训练、对象分类方法、装置、设备及介质
CN112561080B (zh) * 2020-12-18 2023-03-03 Oppo(重庆)智能科技有限公司 样本筛选方法、样本筛选装置及终端设备
CN113128565B (zh) * 2021-03-25 2022-05-06 之江实验室 面向预训练标注数据不可知的图像自动标注系统和装置
CN113096080B (zh) * 2021-03-30 2024-01-16 四川大学华西第二医院 图像分析方法及系统
CN113378904B (zh) * 2021-06-01 2022-06-14 电子科技大学 一种基于对抗域自适应网络的图像分类方法
CN113449781B (zh) * 2021-06-17 2023-04-07 上海深至信息科技有限公司 一种甲状腺结节分类模型的生成方法及系统
CN113420824A (zh) * 2021-07-03 2021-09-21 上海理想信息产业(集团)有限公司 针对工业视觉应用的预训练数据筛选及训练方法、系统
CN113688867B (zh) * 2021-07-20 2023-04-28 广东工业大学 一种跨域图像分类方法
CN113537403A (zh) * 2021-08-14 2021-10-22 北京达佳互联信息技术有限公司 图像处理模型的训练方法和装置及预测方法和装置
CN113780468B (zh) * 2021-09-28 2022-08-09 中国人民解放军国防科技大学 一种基于少量神经元连接的健壮图像分类模型训练方法
CN113989627B (zh) * 2021-12-29 2022-05-27 深圳市万物云科技有限公司 一种基于异步联邦学习的城市防控图像检测方法和系统
CN114610933B (zh) * 2022-03-17 2024-02-13 西安理工大学 基于零样本域适应的图像分类方法
CN115578593B (zh) * 2022-10-19 2023-07-18 北京建筑大学 一种使用残差注意力模块的域适应方法
CN116778376B (zh) * 2023-05-11 2024-03-22 中国科学院自动化研究所 内容安全检测模型训练方法、检测方法和装置
CN117372416A (zh) * 2023-11-13 2024-01-09 北京透彻未来科技有限公司 一种对抗训练的高鲁棒性数字病理切片诊断系统及方法
CN117593594B (zh) * 2024-01-18 2024-04-23 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 基于一致性对齐的脑部mri图像分类方法、设备和介质

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109003253A (zh) * 2017-05-24 2018-12-14 通用电气公司 神经网络点云生成系统
CN110533066A (zh) * 2019-07-19 2019-12-03 浙江工业大学 一种基于深度神经网络的图像数据集自动构建方法

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11127157B2 (en) * 2017-10-24 2021-09-21 Nike, Inc. Image recognition system
US11386328B2 (en) * 2018-05-30 2022-07-12 Robert Bosch Gmbh Method, apparatus and computer program for generating robust automated learning systems and testing trained automated learning systems

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109003253A (zh) * 2017-05-24 2018-12-14 通用电气公司 神经网络点云生成系统
CN110533066A (zh) * 2019-07-19 2019-12-03 浙江工业大学 一种基于深度神经网络的图像数据集自动构建方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
Synthetic aperture radar ship discrimination, generation and latent variable extraction using information maximizing generative adversarial networks;C. P. Schwegmann 等;《2017 IEEE International Geoscience and Remote Sensing Symposium (IGARSS)》;20171231;第2263-2266页 *

Also Published As

Publication number Publication date
CN111368886A (zh) 2020-07-03

Similar Documents

Publication Publication Date Title
CN111368886B (zh) 一种基于样本筛选的无标注车辆图片分类方法
CN109949317B (zh) 基于逐步对抗学习的半监督图像实例分割方法
CN108830188B (zh) 基于深度学习的车辆检测方法
CN109558823B (zh) 一种以图搜图的车辆识别方法及系统
CN111259930A (zh) 自适应注意力指导机制的一般性目标检测方法
EP3690741A2 (en) Method for automatically evaluating labeling reliability of training images for use in deep learning network to analyze images, and reliability-evaluating device using the same
CN103984959A (zh) 一种基于数据与任务驱动的图像分类方法
CN107943856A (zh) 一种基于扩充标记样本的文本分类方法及系统
CN111079847B (zh) 一种基于深度学习的遥感影像自动标注方法
CN113011357A (zh) 基于时空融合的深度伪造人脸视频定位方法
CN111178451A (zh) 一种基于YOLOv3网络的车牌检测方法
CN111597340A (zh) 一种文本分类方法及装置、可读存储介质
CN113486886B (zh) 一种自然场景下的车牌识别方法和装置
CN112712052A (zh) 一种机场全景视频中微弱目标的检测识别方法
CN111860106A (zh) 一种无监督的桥梁裂缝识别方法
CN110852358A (zh) 一种基于深度学习的车辆类型判别方法
CN114926693A (zh) 基于加权距离的sar图像小样本识别方法及装置
CN115546196A (zh) 一种基于知识蒸馏的轻量级遥感影像变化检测方法
CN115131313A (zh) 基于Transformer的高光谱图像变化检测方法及装置
CN115546553A (zh) 一种基于动态特征抽取和属性修正的零样本分类方法
CN117237733A (zh) 一种结合自监督和弱监督学习的乳腺癌全切片图像分类方法
CN111242028A (zh) 基于U-Net的遥感图像地物分割方法
CN110751005B (zh) 融合深度感知特征和核极限学习机的行人检测方法
CN110909645B (zh) 一种基于半监督流形嵌入的人群计数方法
CN111832463A (zh) 一种基于深度学习的交通标志检测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant