CN114492574A - 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 - Google Patents
基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 Download PDFInfo
- Publication number
- CN114492574A CN114492574A CN202111579071.8A CN202111579071A CN114492574A CN 114492574 A CN114492574 A CN 114492574A CN 202111579071 A CN202111579071 A CN 202111579071A CN 114492574 A CN114492574 A CN 114492574A
- Authority
- CN
- China
- Prior art keywords
- domain
- target
- sample
- feature extractor
- classifier
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,通过迁移学习或领域适应的方法来利用相关源域的大量可用的注释数据跨域迁移知识到相关目标域获得带标签的目标数据;提出的域适应方法融合了高斯均匀混合模型检测离群值与深度神经网络进行图像分类,利用高斯均匀混合模型对每一类的目标样本特征到类均值的余弦距离进行建模,得到目标样本后验概率作为估计目标样本伪标签的重要程度;基于训练过程产生的目标样本伪标签提出的辅助伪标签损失加入神经网络的训练;同时最小化条件熵损失使得学习到的特征远离决策边界;综合大量实验证明了所提方法能够提高深度网络模型的图片分类准确度。
Description
技术领域
本发明涉及深度学习图片分类技术领域,主要涉及一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法。
背景技术
目前现有技术在图像分类或动作识别的大规模数据集上取得了令人印象深刻的结果。然而,获取如此大的带注释数据集是非常昂贵的,并且需要将知识从现有的带注释数据集转移到与特定的未带标签数据。如果有标签的和没有标签的数据具有不同的特征,那么它们是从两个不同的域采样的。特别是从互联网上收集的数据集,例如从共享视频或图像的平台上收集的数据集,与应用程序需要处理的数据有很大的不同。由于计算机视觉中常见的域转移,源域和目标域之间的数据分布可能会有很大的差异。因此,用足够多的标记图像训练的深度神经网络可能无法很好地适应目标域。为了解决带标签数据集(源域)与目标域未带标签数据之间的域转移问题,提出了各种无监督域适应方法。如果来自目标源的数据被部分标记,这个问题被称为半监督域适应。在这项工作中,解决了无监督领域适应的背景下的图像分类。
对抗适应是解决视觉域适应问题的领域适应方法中的重要组成部分。它的目的是通过一个领域鉴别器的两个对抗性目标来减少跨领域分布的差异。这些方法的思想主要来自生成对抗网络,它由两个神经网络组成:一个生成器和一个判别器。前者的目的是生成新的图像,从而使后者混淆,从而试图将它们与真实图像区分开来。现有的许多视觉域自适应方法都利用了这一思想,以保证网络无法区分源域和目标域的图像。
域适应主要试图减少源域和目标域之间的域转移。以往的浅域自适应方法或者学习源域数据的不变特征表示,或者通过源域数据的监督式学习目标域的预测模型,忽略了目标域的伪标签,目标域数据的伪标签已被证明对域自适应学习有提升。由于伪标记不可避免地含有噪声,因此在使用伪标记指导域自适应任务时,如何正确选择标记数据至关重要。
目前的域自适应方法虽然在实际应用中表现出了良好的性能,但并不能有效利用伪标签提升图片分类精准度,仍然面临着巨大的挑战。
发明内容
发明目的:针对上述背景技术中存在的问题,本发明提出了一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,以分步训练学习网络模型来利用训练中产生的伪标签。利用高斯均匀混合模型对每一类的目标样本特征到类均值的余弦距离进行建模,得到目标样本后验概率作为估计目标样本伪标签的重要程度。利用目标样本的后验概率筛选出置信度高的样本伪标签加入深度神经网络的学习。基于训练过程产生的目标样本伪标签提出的辅助伪标签损失加入神经网络的训练。最小化条件熵损失使得学习到的特征远离决策边界。
技术方案:为实现上述目的,本发明采用的技术方案为:
一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,包括以下步骤:
步骤S1、分别获取源域和目标域的图像,其中源域为已知域,图像带有标签,目标域为待分类域,图像无标签;基于多操作拼接对获取的图像进行预处理和增强;
步骤S2、搭建预训练的深度神经网络;基于步骤S1中源域监督式的分类损失和对抗损失学习预训练神经网络,获得预训练后的特征提取器和分类器;
步骤S3、建立高斯均匀混合模型;基于步骤S2获取的预训练后的特征提取器和分类器,首先将目标域的样本图像输入至预训练后的特征提取器进行特征提取;将提取到的特征输入分类器中获取目标样本伪标签;所述特征提取器提取的特征为目标域的样本特征;计算每一类目标样本到目标域的类特征均值的余弦距离;对每一类的目标样本计算的余弦距离进行建模,并通过期望最大化的方法更新高斯均匀混合模型参数,获取目标样本的后验概率,作为衡量伪标签置信度的指标;
步骤S4、搭建对抗域适应网络并进行训练;搭建特征提取器、分类器、判别器网络,特征提取器为在常规ImageNet数据集上预训练的特征提取器网络;输入源域和目标域图像,通过对抗域适应网络的前向传播分别获取源域和目标域图像分类器和判别器的输出;对于有标签的源域采用标准的监督式交叉熵损失训练特征提取器和分类器;基于源域和目标域的域标签提出域对抗损失,学习特征提取器和判别器;基于步骤S3获取的目标样本的后验概率,筛选步骤S3的目标样本后验概率高的伪标签结合对抗域适应网络分类器的目标样本预测计算平均绝对误差作为伪标签损失,通过伪标签损失学习对抗域适应网络;采用最小化条件熵损失学习特征远离决策边界;将最终学习好的特征提取器和分类器作为步骤S3的预训练特征提取器和分类器;
步骤S5、重复步骤S3-S4,通过迭代分别训练对抗域适应网络和高斯均匀混合模型;
在初次训练高斯均匀混合模型时,使用初始化的神经网络为步骤S2预训练的特征提取器、分类器;固定预训练的特征提取器、分类器参数,训练高斯均匀混合模型,得到目标样本的后验概率;在初次训练对抗域适应网络时,使用特征提取器为在常规 ImageNet数据集上预训练的特征提取器网络,利用目标样本的后验概率的伪标签损失对域适应任务进行学习;在训练一轮对抗域适应网络后,将对抗域适应学习到的特征提取器、分类器保存,并应用到下一步的高斯均匀混合模型学习当中;
步骤S6、基于训练好的对抗域适应网络,输入目标域图像样本,由特征提取器提取特征后,将目标样本特征输入到分类器,并利用softmax方法计算分类器输出的目标样本所属的概率最大的维度,即为目标样本图像的预测标签类别。
进一步地,所述步骤S1还包括对源域和目标域图像的预处理和增强步骤;具体地,
调整获取源域和目标域图像的尺寸,在固定长宽比的条件下,将图像调整至256×256 个像素值大小;将调整后的图像随机裁剪转换为224×224个像素值大小;以翻转概率0.5 对裁剪后的图像进行随机水平翻转;
将图像由[0,255]取值范围的PIL.image转化为Tensor,并将取值范围归一化到[0,1] 区间;对归一化处理后的图像Tensor依次进行标准化,减去均值,除以标准差,以拼接操作将各操作拼接起来。
进一步地,所述步骤S2中获得预训练后的特征提取器和分类器具体步骤包括:
步骤S2.1、搭建预训练的深度神经网络,包括特征提取器G、分类器C和判别器 D;预训练的特征提取器G为ResNet残差网络模型和256个神经元的全连接神经网络;对特征提取器G输出的256维特征进行标准化处理如下:
其中f为输入的图像经过特征提取器得到的样本特征,r为系数,||f||2为求l2范数;
所述分类器D为一层的全连接神经网络,神经元数量与图像类别数量相等;判别器D为两层的全连接神经网络,最后一层含有一个神经元,为输出层;在特征提取器F和判别器D之间还包括梯度反转层,在网络的前向传播中实现恒等变换;在网络的反向传播中,当梯度反向经过判别器之后传播到特征提取器时,梯度反转层通过对梯度取反乘系数的操作之后,继续反向传播到特征提取器;
步骤S2.3、学习预训练的特征提取器G、分类器C、判别器D;对于由卷积和全连接层组成的深度神经网络特征提取器G和分类器C,将经过如步骤S1所述的数据预处理和增强步骤的源域图像样本和目标域图像样本分别输入至特征提取器G,得到256维的源域样本特征和目标域样本特征,将源域样本特征和目标域样本特征进行步骤S2.1中的标准化处理;将标准化后的源域样本特征输入到分类器;源域样本特征的分类损失最小化基于如下标准监督式分类损失实现:
其中G表示特征提取器,C表示分类器,nS为源域中的源样本数量,为源域中的第i个样本,为源域中第i个样本标签,为源域中的第i个样本的特征,为源域中的第i个样本分类器的预测输出,J(·,·)是交叉熵损失;通过交叉熵损失学习特征提取器G和分类器C,最小化源分类损失;
将标准化处理后的源域样本特征和目标域样本特征分别输入到领域判别器D,领域判别器D的输出通过Sigmod激活函数,值域为(0,1),函数定义如下:
域对抗损失最小化基于如下的标准监督式二分类损失实现的:
其中G表示特征提取器,D表示领域判别器,nS为源域中的源样本数量,nt为目标域中的目标样本数量,xm为所有样本的第m个样本,dm为所有样本中第m个样本标签,G(xm)为所有样本中的第m个样本的特征,D(G(xm))为所有样本中的第m个样本领域判别器的预测输出,BCE(·,·)是二分类交叉熵损失;通过二分类交叉熵损失学习特征提取器G和领域判别器D,最小化源分类损失;
在前向传播时计算二分类交叉熵损失,反向传播时当梯度经过领域判别器之后的梯度反转层时,对梯度取负同时对梯度乘系数,然后继续反向传播到特征提取器。
步骤S2.4、对于搭建的预训练深度神经网络模型进行500轮学习后,保存学习到的特征提取器G和分类器C作为学习高斯均匀混合模型的预训练模型。
进一步地,所述步骤S3中高斯均匀混合模型建立具体包括:
步骤S3.1、采用如步骤S1所述的预处理和增强的目标域图像;
步骤S3.2、建立高斯均匀混合模型,其中特征提取器G为步骤S2中获取的预训练后的ResNet残差网络模型和256个神经元的全连接神经网络,分类器C为步骤S2中获取的预训练后的层的全连接神经网络;在训练高斯均匀混合模型时预训练的特征提取器 G、分类器C只有前向传播没有反向传播,即神经网络部分参数是固定的;将步骤S3.1 中获取的目标域图像样本输入到特征提取器G,得到256维的目标样本特征,对于目标样本的256维特征进行标准化处理如下:
基于得到的目标域图像样本的伪标签,将目标域分类为C类,计算C类中的每一类目标样本到类中心的余弦距离,余弦距离定义如下:
其中表示由目标样本输入特征提取器G得到的目标域的第j个目标样本的特征到目标类别均值的余弦距离,表示目标域第k类的所有的目标样本特征的均值,k 为目标域样本的类别数,||·||表示模,表示目标样本输入特征提取器G得到的目标域的第j个目标样本的特征;
其中πk表示内点的先验概率,σk表示高斯分布的标准差,δk表示均匀分布参数;表示高斯均匀混合模型的所有参数,对每个伪标签的目标样本引入一个随机变量zj∈{0,1},当目标样本的伪标签正确标记则记为1,否则记为0;当目标样本的伪标签被正确标记时,概率表示定义如下:
高斯分量对正确标记的目标数据建模,均匀分量对错误标记的数据建模,采用最大期望算法更新混合模型参数;首先固定特征提取器G、分类器C参数,初始化模型参数θ的初始值,设定迭代次数,首先开始E步,计算联合分布的条件概率期望,第j个目标数据正确标注的后验概率为:
其中l表示EM算法的迭代次数,E步利用对隐变量的现有估计值,计算每一个样本来自高斯分量的概率;M步最大化在E步上求得的极大似然值计算参数的值如下:
迭代进行E步和M步,当极大似然值达到最大时退出循环,固定高斯均匀混合模型参数,得到每一个目标数据正确标注的后验概率,作为目标样本筛选的依据。
进一步地,所述步骤S4中搭建对抗域适应网络并进行训练的具体步骤包括:
步骤S4.1、采用步骤S1所述的图像预处理和增强的目标域图像和源域图像;
步骤S4.2、训练对抗域适应网络;无监督对抗域适应学习包括一个带标签的源域和不带标签的目标域,其中带标签的源域和不带标签的目标域的图像都作为训练集使用,不带标签的目标域作为测试集;所述对抗域适应网络中包括特征提取器G、分类器C、判别器D,其中特征提取器G为常规ImageNet数据集上预训练的残差网络构件;将经过预处理后的源域图像样本和目标域图像样本分别输入至特征提取器G,重复步骤S2.3;
步骤S4.3、利用目标样本的后验概率的伪标签损失对域适应任务进行学习;设定辅助的伪标签损失如下:
H(·)是熵损失;基于上述损失和步骤S2.3的源分类损失和域对抗损失训练对抗域适应网络。
附图说明
图1是本发明提供的伪标签损失对抗域适应方法网络结构图。
具体实施方式
下面结合附图对本发明作更进一步的说明。显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明提供了一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,根本动机在于待分类的目标域没有可用于训练的图片。具体步骤如下:
步骤S1、分别获取源域和目标域的图像,其中源域为已知域,图像带有标签,目标域为待分类域,图像无标签;基于多操作拼接对获取的图像进行预处理和增强。具体地,
调整获取源域和目标域图像的尺寸,在固定长宽比的条件下,将图像调整至256×256 个像素值大小;将调整后的图像随机裁剪转换为224×224个像素值大小;以翻转概率0.5 对裁剪后的图像进行随机水平翻转;
将图像由[0,255]取值范围的PIL.image转化为Tensor,并将取值范围归一化到[0,1] 区间;对归一化处理后的图像Tensor依次进行标准化,减去均值,除以标准差,以拼接操作将各操作拼接起来。
步骤S2、搭建预训练的深度神经网络;基于步骤S1中源域监督式的分类损失和对抗损失学习预训练神经网络,获得预训练后的特征提取器和分类器。具体地,
步骤S2.1、搭建预训练的深度神经网络,包括特征提取器G、分类器C和判别器 D;预训练的特征提取器G为ResNet残差网络模型和256个神经元的全连接神经网络;对特征提取器G输出的256维特征进行标准化处理如下:
其中f为输入的图像经过特征提取器得到的样本特征,r为系数,||f||2为求l2范数;
所述分类器D为一层的全连接神经网络,神经元数量与图像类别数量相等;判别器D为两层的全连接神经网络,最后一层含有一个神经元,为输出层;在特征提取器F和判别器D之间还包括梯度反转层,在网络的前向传播中实现恒等变换;在网络的反向传播中,当梯度反向经过判别器之后传播到特征提取器时,梯度反转层通过对梯度取反乘系数的操作之后,继续反向传播到特征提取器;
步骤S2.2、学习预训练的特征提取器G、分类器C、判别器D;对于由卷积和全连接层组成的深度神经网络特征提取器G和分类器C,将经过如步骤S1所述的数据预处理和增强步骤的源域图像样本和目标域图像样本分别输入至特征提取器G,得到256维的源域样本特征和目标域样本特征,将源域样本特征和目标域样本特征进行步骤S2.1中的标准化处理;将标准化后的源域样本特征输入到分类器;源域样本特征的分类损失最小化基于如下标准监督式分类损失实现:
其中G表示特征提取器,C表示分类器,nS为源域中的源样本数量,为源域中的第i个样本,为源域中第i个样本标签,为源域中的第i个样本的特征,为源域中的第i个样本分类器的预测输出,J(·,·)是交叉熵损失;通过交叉熵损失学习特征提取器G和分类器C,最小化源分类损失;
将标准化处理后的源域样本特征和目标域样本特征分别输入到领域判别器D,领域判别器D的输出通过Sigmod激活函数,值域为(0,1),函数定义如下:
域对抗损失最小化基于如下的标准监督式二分类损失实现的:
其中G表示特征提取器,D表示领域判别器,nS为源域中的源样本数量,nt为目标域中的目标样本数量,xm为所有样本的第m个样本,dm为所有样本中第m个样本标签,G(xm)为所有样本中的第m个样本的特征,D(G(xm))为所有样本中的第m个样本领域判别器的预测输出,BCE(·,·)是二分类交叉熵损失;通过二分类交叉熵损失学习特征提取器G和领域判别器D,最小化源分类损失;
在前向传播时计算二分类交叉熵损失,反向传播时当梯度经过领域判别器之后的梯度反转层时,对梯度取负同时对梯度乘系数,然后继续反向传播到特征提取器。
步骤S2.3、对于搭建的预训练深度神经网络模型进行500轮学习后,保存学习到的特征提取器G和分类器C作为学习高斯均匀混合模型的预训练模型。
步骤S3、建立高斯均匀混合模型;具体地,
步骤S3.1、采用步骤S1所述的图像预处理和增强的目标域图像;
步骤S3.2、建立高斯均匀混合模型,其中特征提取器G为步骤S2中获取的预训练后的ResNet残差网络模型和256个神经元的全连接神经网络,分类器C为步骤S2中获取的预训练后的层的全连接神经网络;在训练高斯均匀混合模型时预训练的特征提取器 G、分类器C只有前向传播没有反向传播,即神经网络部分参数是固定的;将步骤S3.1 中获取的目标域图像样本输入到特征提取器G,得到256维的目标样本特征,对于目标样本的256维特征进行标准化处理如下:
基于得到的目标域图像样本的伪标签,将目标域分类为C类,计算C类中的每一类目标样本到类中心的余弦距离,余弦距离定义如下:
其中表示由目标样本输入特征提取器G得到的目标域的第j个目标样本的特征到目标类别均值的余弦距离,表示目标域第k类的所有的目标样本特征的均值,k 为目标域样本的类别数,||·||表示模,表示目标样本输入特征提取器G得到的目标域的第j个目标样本的特征;
目标是训练一个模型来检测离群值并降低它们在网络输出预测中的作用,由于没有关于离群值的百分比和分布的先验信息,不妨假设内点服从高斯分布,而离群点服从均匀分布。对于在目标域上计算第k类目标样本的进行高斯均匀混合模型建模:
其中πk表示内点的先验概率,σk表示高斯分布的标准差,δk表示均匀分布参数;表示高斯均匀混合模型的所有参数,对每个伪标签的目标样本引入一个随机变量zj∈{0,1},当目标样本的伪标签正确标记则记为1,否则记为0;当目标样本的伪标签被正确标记时,概率表示定义如下:
高斯分量对正确标记的目标数据建模,均匀分量对错误标记的数据建模,采用最大期望算法更新混合模型参数;首先固定特征提取器G、分类器C参数,初始化模型参数θ的初始值,设定迭代次数,首先开始E步,计算联合分布的条件概率期望,第j个目标数据正确标注的后验概率为:
其中l表示EM算法的迭代次数,E步利用对隐变量的现有估计值,计算每一个样本来自高斯分量的概率;M步最大化在E步上求得的极大似然值计算参数的值如下:
迭代进行E步和M步,当极大似然值达到最大时退出循环,固定高斯均匀混合模型参数,得到每一个目标数据正确标注的后验概率,作为目标样本筛选的依据。
步骤S4、搭建对抗域适应网络并进行训练;
步骤S4.1、采用如步骤S1所述的预处理和增强的目标域和源域图像;
步骤S4.2、训练对抗域适应网络;无监督对抗域适应学习包括一个带标签的源域和不带标签的目标域,其中带标签的源域和不带标签的目标域的图像都作为训练集使用,不带标签的目标域作为测试集;所述对抗域适应网络中包括特征提取器G、分类器C、判别器D,其中特征提取器G为常规ImageNet数据集上预训练的残差网络构件;将经过预处理后的源域图像样本和目标域图像样本分别输入至特征提取器G,重复步骤S2.3;
步骤S4.3、利用目标样本的后验概率的伪标签损失对域适应任务进行学习;设定辅助的伪标签损失如下:
H(·)是熵损失;基于上述损失和步骤2.3中的源分类损失和域对抗损失训练对抗域适应网络。
步骤S5、重复步骤S3-S4,通过迭代分别训练对抗域适应网络和高斯均匀混合模型;
在初次训练高斯均匀混合模型时,使用初始化的神经网络为步骤S2预训练的特征提取器、分类器;固定预训练的特征提取器、分类器参数,训练高斯均匀混合模型,得到目标样本的后验概率;在初次训练对抗域适应网络时,使用特征提取器为在常规 ImageNet数据集上预训练的特征提取器网络,利用目标样本的后验概率的伪标签损失对域适应任务进行学习;在训练一轮对抗域适应网络后,将对抗域适应学习到的特征提取器、分类器保存,并应用到下一步的高斯均匀混合模型学习当中;
步骤S6、基于训练好的对抗域适应网络,保存对抗域适应学习到的特征提取器、分类器。对目标域数据采用步骤S1中的预处理和增强步骤,输入目标域图像样本,由特征提取器提取特征后,将目标样本特征输入到分类器,并利用softmax方法计算分类器输出的目标样本所属的概率最大的维度,即为目标样本图像的预测标签类别。
下面根据实验结果进一步说明本发明提供的伪标签损失对抗域适应方法。
为了验证本发明的有效性,分别在Office31,ImageCLEF-DA,Office-Caltech10上做了实验。
在Office31数据集上,含有三个域:Amazon、Webcam、Dslr。选择其中的一个域作为带标签的源域,一个域作为待分类的目标域。如下表1所示,本方法在多个域和平均精度上均有较大的提升。
表1 office31数据集上的识别精度
在ImageCLEF-DA数据集上,含有三个域:I、P、C,如下表2所示,可以明显看出,相比于其他方法本专利提出的方法有较大提升。
表2 ImageCLEF-DA数据集上的识别精度
在Office-Caltech10数据集上含有四个不同的域,分别以A、C、D、W代替,如下表3所示,本文所提方法在平均精度上达到了最优。
表3 Office-Caltech10数据集上的识别精度
综上所述,本发明的基于高斯均匀混合模型的伪标签损失对抗域适应方法,一方面提出了的一种高斯均匀混合模型无监督离群值检测和对抗域适应无监督图像分类的迭代方法,以分步训练学习网络模型来利用训练中产生的伪标签,有效的利用训练中目标域伪标签。另一方面基于训练过程产生的目标样本伪标签提出的辅助伪标签损失加入神经网络的训练。最小化条件熵损失使得学习到的特征远离决策边界,提高分类准确度。方法提高了目标域伪标签的利用率,更好的辅助域适应任务和学习分类网络。
以上所述仅是本发明的优选实施方式,应当指出:对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也应视为本发明的保护范围。
Claims (5)
1.一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,其特征在于,包括以下步骤:
步骤S1、分别获取源域和目标域的图像,其中源域为已知域,图像带有标签,目标域为待分类域,图像无标签;基于多操作拼接对获取的图像进行预处理和增强;
步骤S2、搭建预训练的深度神经网络;基于步骤S1中源域监督式的分类损失和对抗损失学习预训练神经网络,获得预训练后的特征提取器和分类器;
步骤S3、建立高斯均匀混合模型;基于步骤S2获取的预训练后的特征提取器和分类器,首先将目标域的样本图像输入至预训练后的特征提取器进行特征提取;将提取到的特征输入至预训练后的分类器中获取目标样本伪标签;所述特征提取器提取的特征为目标域的样本特征;计算每一类目标样本到目标域的类特征均值的余弦距离;对每一类的目标样本计算的余弦距离进行建模,固定预训练的特征提取器和分类器参数,并通过期望最大化的方法更新高斯均匀混合模型参数检测离群值,获取目标样本的后验概率,作为衡量伪标签置信度的指标;
步骤S4、搭建对抗域适应网络并进行训练;搭建特征提取器、分类器、判别器网络,特征提取器为在常规ImageNet数据集上预训练的特征提取器网络;输入源域和目标域图像,通过对抗域适应网络的前向传播分别获取源域和目标域图像分类器和判别器的输出;对于有标签的源域采用标准的监督式交叉熵损失训练特征提取器和分类器;基于源域和目标域的域标签提出域对抗损失,学习特征提取器和判别器;基于步骤S3获取的目标样本的后验概率,筛选步骤S3的目标样本后验概率高的伪标签结合对抗域适应网络分类器的目标样本预测计算平均绝对误差作为伪标签损失,通过伪标签损失学习对抗域适应网络;采用最小化条件熵损失学习特征远离决策边界;将最终学习好的特征提取器和分类器作为步骤S3的预训练特征提取器和分类器;
步骤S5、重复步骤S3-S4,通过迭代分别训练对抗域适应网络和高斯均匀混合模型;
在初次训练高斯均匀混合模型时,使用初始化的神经网络为步骤S2预训练的特征提取器、分类器;固定预训练的特征提取器、分类器参数,训练高斯均匀混合模型,得到目标样本的后验概率;在初次训练对抗域适应网络时,使用特征提取器为在常规ImageNet数据集上预训练的特征提取器网络,利用目标样本的后验概率的伪标签损失对域适应任务进行学习;在训练一轮对抗域适应网络后,将对抗域适应学习到的特征提取器、分类器保存,并应用到下一步的高斯均匀混合模型学习当中;
步骤S6、基于训练好的对抗域适应网络,输入目标域图像样本,由特征提取器提取特征后,将目标样本特征输入到分类器,并利用softmax方法计算分类器输出的目标样本所属的概率最大的维度,即为目标样本图像的预测标签类别。
2.根据权利要求1所述的一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,其特征在于,所述步骤S1还包括对源域和目标域图像的预处理和增强步骤;具体地,
调整获取源域和目标域图像的尺寸,在固定长宽比的条件下,将图像调整至256×256个像素值大小;将调整后的图像随机裁剪转换为224×224个像素值大小;以翻转概率0.5对裁剪后的图像进行随机水平翻转;
将图像由[0,255]取值范围的PIL.image转化为Tensor,并将取值范围归一化到[0,1]区间;对归一化处理后的图像Tensor依次进行标准化,减去均值,除以标准差,以拼接操作将各操作拼接起来。
3.根据权利要求2所述的一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,其特征在于,所述步骤S2中获得预训练后的特征提取器和分类器具体步骤包括:
步骤S2.1、搭建预训练的深度神经网络,包括特征提取器G、分类器C和判别器D;预训练的特征提取器G为ResNet残差网络模型和256个神经元的全连接神经网络;对特征提取器G输出的256维特征进行标准化处理如下:
其中f为输入的图像经过特征提取器得到的样本特征,r为系数,||f||2为求l2范数;
所述分类器D为一层的全连接神经网络,神经元数量与图像类别数量相等;判别器D为两层的全连接神经网络,最后一层含有一个神经元,为输出层;在特征提取器F和判别器D之间还包括梯度反转层,在网络的前向传播中实现恒等变换;在网络的反向传播中,当梯度反向经过判别器之后传播到特征提取器时,梯度反转层通过对梯度取反乘系数的操作之后,继续反向传播到特征提取器;
步骤S2.2、学习预训练的特征提取器G、分类器C、判别器D;对于由卷积和全连接层组成的深度神经网络特征提取器G和分类器C,将经过如步骤S1所述的数据预处理和增强步骤的源域图像样本和目标域图像样本分别输入至特征提取器G,得到256维的源域样本特征和目标域样本特征,将源域样本特征和目标域样本特征进行步骤S2.1中的标准化处理;将标准化后的源域样本特征输入到分类器;源域样本特征的分类损失最小化基于如下标准监督式分类损失实现:
其中G表示特征提取器,C表示分类器,nS为源域中的源样本数量,为源域中的第i个样本,为源域中第i个样本标签,为源域中的第i个样本的特征,为源域中的第i个样本分类器的预测输出,J(·,·)是交叉熵损失;通过交叉熵损失学习特征提取器G和分类器C,最小化源分类损失;
将标准化处理后的源域样本特征和目标域样本特征分别输入到领域判别器D,领域判别器D的输出通过Sigmod激活函数,值域为(0,1),函数定义如下:
域对抗损失最小化基于如下的标准监督式二分类损失实现的:
其中G表示特征提取器,D表示领域判别器,nS为源域中的源样本数量,nt为目标域中的目标样本数量,xm为所有样本的第m个样本,dm为所有样本中第m个样本标签,G(xm)为所有样本中的第m个样本的特征,D(G(xm))为所有样本中的第m个样本领域判别器的预测输出,BCE(·,·)是二分类交叉熵损失;通过二分类交叉熵损失学习特征提取器G和领域判别器D,最小化源分类损失;
在前向传播时计算二分类交叉熵损失,反向传播时当梯度经过领域判别器之后的梯度反转层时,对梯度取负同时对梯度乘系数,然后继续反向传播到特征提取器;
步骤S2.3、对于搭建的预训练深度神经网络模型进行500轮学习后,保存学习到的特征提取器G和分类器C作为学习高斯均匀混合模型的预训练模型。
4.根据权利要求3所述的一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,其特征在于,所述步骤S3中高斯均匀混合模型建立具体包括:
步骤S3.1、采用步骤S1所述的图像预处理和增强的目标域图像;
步骤S3.2、建立高斯均匀混合模型,其中特征提取器G为步骤S2中获取的预训练后的ResNet残差网络模型和256个神经元的全连接神经网络,分类器C为步骤S2中获取的预训练后的层的全连接神经网络;在训练高斯均匀混合模型时预训练的特征提取器G、分类器C只有前向传播没有反向传播,即神经网络全部参数是固定的;将步骤S3.1中获取的目标域图像样本输入到特征提取器G,得到256维的目标样本特征,对于目标样本的256维特征进行标准化处理如下:
基于得到的目标域图像样本的伪标签,将目标域分类为C类,计算C类中的每一类目标样本到类中心的余弦距离,余弦距离定义如下:
其中表示由目标样本输入特征提取器G得到的目标域的第j个目标样本的特征到目标类别均值的余弦距离,表示目标域第k类的所有的目标样本特征的均值,k为目标域样本的类别数,||·||表示模,表示目标样本输入特征提取器G得到的目标域的第j个目标样本的特征;
其中πk表示内点的先验概率,σk表示高斯分布的标准差,δk表示均匀分布参数;表示高斯均匀混合模型的所有参数,对每个伪标签的目标样本引入一个随机变量zj∈{0,1},当目标样本的伪标签正确标记则记为1,否则记为0;当目标样本的伪标签被正确标记时,概率表示定义如下:
高斯分量对正确标记的目标数据建模,均匀分量对错误标记的数据建模,采用最大期望算法更新混合模型参数;首先固定特征提取器G、分类器C参数,初始化模型参数θ的初始值,设定迭代次数,首先开始E步,计算联合分布的条件概率期望,第j个目标数据正确标注的后验概率为:
其中l表示EM算法的迭代次数,E步利用对隐变量的现有估计值,计算每一个样本来自高斯分量的概率;M步最大化在E步上求得的极大似然值计算参数的值如下:
迭代进行E步和M步,当极大似然值达到最大时退出循环,固定高斯均匀混合模型参数,得到每一个目标数据正确标注的后验概率,作为目标样本筛选的依据。
5.根据权利要求4所述的一种基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法,其特征在于,所述步骤S4中搭建对抗域适应网络并进行训练的具体步骤包括:
步骤S4.1、采用步骤S1所述的图像预处理和增强的目标域图像和源域图像;
步骤S4.2、训练对抗域适应网络;无监督对抗域适应学习包括一个带标签的源域和不带标签的目标域,其中带标签的源域和不带标签的目标域的图像都作为训练集使用,不带标签的目标域作为测试集;所述对抗域适应网络中包括特征提取器G、分类器C、判别器D,其中特征提取器G为常规ImageNet数据集上预训练的残差网络构件;将经过预处理后的源域图像样本和目标域图像样本分别输入至特征提取器G,重复步骤S2.3;
步骤S4.3、利用目标样本的后验概率的伪标签损失对域适应任务进行学习;设定辅助的伪标签损失如下:
H(·)是熵损失;基于上述损失与步骤S2.3的源分类损失和域对抗损失训练对抗域适应网络。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111579071.8A CN114492574A (zh) | 2021-12-22 | 2021-12-22 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111579071.8A CN114492574A (zh) | 2021-12-22 | 2021-12-22 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114492574A true CN114492574A (zh) | 2022-05-13 |
Family
ID=81493379
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111579071.8A Pending CN114492574A (zh) | 2021-12-22 | 2021-12-22 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114492574A (zh) |
Cited By (15)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114821200A (zh) * | 2022-06-28 | 2022-07-29 | 苏州立创致恒电子科技有限公司 | 一种应用于工业视觉检测领域的图像检测模型及方法 |
CN114821282A (zh) * | 2022-06-28 | 2022-07-29 | 苏州立创致恒电子科技有限公司 | 一种基于域对抗神经网络的图像检测模型及方法 |
CN114943879A (zh) * | 2022-07-22 | 2022-08-26 | 中国科学院空天信息创新研究院 | 基于域适应半监督学习的sar目标识别方法 |
CN114998602A (zh) * | 2022-08-08 | 2022-09-02 | 中国科学技术大学 | 基于低置信度样本对比损失的域适应学习方法及系统 |
CN115082725A (zh) * | 2022-05-17 | 2022-09-20 | 西北工业大学 | 基于可靠样本选择和双分支动态网络的多源域自适应方法 |
CN115115892A (zh) * | 2022-07-20 | 2022-09-27 | 南通大学 | 一种基于生成模型与判别分类模型的图像半监督分类方法 |
CN115564960A (zh) * | 2022-11-10 | 2023-01-03 | 南京码极客科技有限公司 | 一种样本选择与标签校正结合的网络图像标签去噪方法 |
CN116070796A (zh) * | 2023-03-29 | 2023-05-05 | 中国科学技术大学 | 柴油车排放等级评估方法及系统 |
CN116128047A (zh) * | 2022-12-08 | 2023-05-16 | 西南民族大学 | 一种基于对抗网络的迁移学习方法 |
CN116152229A (zh) * | 2023-04-14 | 2023-05-23 | 吉林大学 | 一种糖尿病视网膜病变诊断模型的构建方法及诊断模型 |
CN116563612A (zh) * | 2023-04-13 | 2023-08-08 | 广东技术师范大学 | 一种基于联合高斯过程的多源图像自适应识别技术 |
TWI815762B (zh) * | 2022-12-16 | 2023-09-11 | 大陸商環旭電子股份有限公司 | 影像識別深度學習模型的訓練方法 |
CN117036869A (zh) * | 2023-10-08 | 2023-11-10 | 之江实验室 | 一种基于多样性和随机策略的模型训练方法及装置 |
CN117152563A (zh) * | 2023-10-16 | 2023-12-01 | 华南师范大学 | 混合目标域自适应模型的训练方法、装置及计算机设备 |
CN117456312A (zh) * | 2023-12-22 | 2024-01-26 | 华侨大学 | 一种面向无监督图像检索的模拟抗污伪标签增强方法 |
-
2021
- 2021-12-22 CN CN202111579071.8A patent/CN114492574A/zh active Pending
Cited By (24)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115082725A (zh) * | 2022-05-17 | 2022-09-20 | 西北工业大学 | 基于可靠样本选择和双分支动态网络的多源域自适应方法 |
CN115082725B (zh) * | 2022-05-17 | 2024-02-23 | 西北工业大学 | 基于可靠样本选择和双分支动态网络的多源域自适应方法 |
CN114821200A (zh) * | 2022-06-28 | 2022-07-29 | 苏州立创致恒电子科技有限公司 | 一种应用于工业视觉检测领域的图像检测模型及方法 |
CN114821200B (zh) * | 2022-06-28 | 2022-09-13 | 苏州立创致恒电子科技有限公司 | 一种应用于工业视觉检测领域的图像检测模型及方法 |
CN114821282A (zh) * | 2022-06-28 | 2022-07-29 | 苏州立创致恒电子科技有限公司 | 一种基于域对抗神经网络的图像检测模型及方法 |
CN115115892B (zh) * | 2022-07-20 | 2024-09-20 | 南通大学 | 一种基于生成模型与判别分类模型的图像半监督分类方法 |
CN115115892A (zh) * | 2022-07-20 | 2022-09-27 | 南通大学 | 一种基于生成模型与判别分类模型的图像半监督分类方法 |
CN114943879A (zh) * | 2022-07-22 | 2022-08-26 | 中国科学院空天信息创新研究院 | 基于域适应半监督学习的sar目标识别方法 |
CN114943879B (zh) * | 2022-07-22 | 2022-10-04 | 中国科学院空天信息创新研究院 | 基于域适应半监督学习的sar目标识别方法 |
CN114998602A (zh) * | 2022-08-08 | 2022-09-02 | 中国科学技术大学 | 基于低置信度样本对比损失的域适应学习方法及系统 |
CN115564960A (zh) * | 2022-11-10 | 2023-01-03 | 南京码极客科技有限公司 | 一种样本选择与标签校正结合的网络图像标签去噪方法 |
CN116128047B (zh) * | 2022-12-08 | 2023-11-14 | 西南民族大学 | 一种基于对抗网络的迁移学习方法 |
CN116128047A (zh) * | 2022-12-08 | 2023-05-16 | 西南民族大学 | 一种基于对抗网络的迁移学习方法 |
TWI815762B (zh) * | 2022-12-16 | 2023-09-11 | 大陸商環旭電子股份有限公司 | 影像識別深度學習模型的訓練方法 |
CN116070796A (zh) * | 2023-03-29 | 2023-05-05 | 中国科学技术大学 | 柴油车排放等级评估方法及系统 |
CN116563612A (zh) * | 2023-04-13 | 2023-08-08 | 广东技术师范大学 | 一种基于联合高斯过程的多源图像自适应识别技术 |
CN116152229B (zh) * | 2023-04-14 | 2023-07-11 | 吉林大学 | 一种糖尿病视网膜病变诊断模型的构建方法及诊断模型 |
CN116152229A (zh) * | 2023-04-14 | 2023-05-23 | 吉林大学 | 一种糖尿病视网膜病变诊断模型的构建方法及诊断模型 |
CN117036869B (zh) * | 2023-10-08 | 2024-01-09 | 之江实验室 | 一种基于多样性和随机策略的模型训练方法及装置 |
CN117036869A (zh) * | 2023-10-08 | 2023-11-10 | 之江实验室 | 一种基于多样性和随机策略的模型训练方法及装置 |
CN117152563A (zh) * | 2023-10-16 | 2023-12-01 | 华南师范大学 | 混合目标域自适应模型的训练方法、装置及计算机设备 |
CN117152563B (zh) * | 2023-10-16 | 2024-05-14 | 华南师范大学 | 混合目标域自适应模型的训练方法、装置及计算机设备 |
CN117456312A (zh) * | 2023-12-22 | 2024-01-26 | 华侨大学 | 一种面向无监督图像检索的模拟抗污伪标签增强方法 |
CN117456312B (zh) * | 2023-12-22 | 2024-03-12 | 华侨大学 | 一种面向无监督图像检索的模拟抗污伪标签增强方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114492574A (zh) | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 | |
CN111583263B (zh) | 一种基于联合动态图卷积的点云分割方法 | |
CN111079847B (zh) | 一种基于深度学习的遥感影像自动标注方法 | |
CN111127364B (zh) | 图像数据增强策略选择方法及人脸识别图像数据增强方法 | |
CN111709311A (zh) | 一种基于多尺度卷积特征融合的行人重识别方法 | |
CN110728694B (zh) | 一种基于持续学习的长时视觉目标跟踪方法 | |
CN113326731A (zh) | 一种基于动量网络指导的跨域行人重识别算法 | |
CN112085055A (zh) | 一种基于迁移模型雅克比阵特征向量扰动的黑盒攻击方法 | |
CN112733965A (zh) | 一种基于小样本学习的无标签图像分类方法 | |
Xu et al. | Learning representations that support robust transfer of predictors | |
CN114882534B (zh) | 基于反事实注意力学习的行人再识别方法、系统、介质 | |
CN114972904B (zh) | 一种基于对抗三元组损失的零样本知识蒸馏方法及系统 | |
CN116910571A (zh) | 一种基于原型对比学习的开集域适应方法及系统 | |
CN115147864A (zh) | 一种基于协同异质深度学习网络的红外人体行为识别方法 | |
CN105787045B (zh) | 一种用于可视媒体语义索引的精度增强方法 | |
CN116935125A (zh) | 通过弱监督实现的噪声数据集目标检测方法 | |
Yang et al. | NAM net: meta-network with normalization-based attention for few-shot learning | |
CN113032612B (zh) | 一种多目标图像检索模型的构建方法及检索方法和装置 | |
CN115100694A (zh) | 一种基于自监督神经网络的指纹快速检索方法 | |
Pang et al. | Target tracking based on siamese convolution neural networks | |
Joshi et al. | Video object segmentation with self-supervised framework for an autonomous vehicle | |
CN112232398A (zh) | 一种半监督的多类别Boosting分类方法 | |
CN118799645A (zh) | 一种基于知识蒸馏的深度融合多跨域少样本分类方法 | |
Wei et al. | Entropy-minimization Mean Teacher for Source-Free Domain Adaptive Object Detection | |
Mai et al. | An Uncertainty-Adaptive Consistency Training Algorithm for Semi-Supervised Object Counting Networks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |