CN117611932B - 基于双重伪标签细化和样本重加权的图像分类方法及系统 - Google Patents
基于双重伪标签细化和样本重加权的图像分类方法及系统 Download PDFInfo
- Publication number
- CN117611932B CN117611932B CN202410094841.7A CN202410094841A CN117611932B CN 117611932 B CN117611932 B CN 117611932B CN 202410094841 A CN202410094841 A CN 202410094841A CN 117611932 B CN117611932 B CN 117611932B
- Authority
- CN
- China
- Prior art keywords
- sample
- data
- probability
- class
- network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 60
- 238000012549 training Methods 0.000 claims abstract description 29
- 238000002372 labelling Methods 0.000 claims abstract description 5
- 238000004364 calculation method Methods 0.000 claims description 19
- 238000007781 pre-processing Methods 0.000 claims description 12
- 230000009977 dual effect Effects 0.000 claims description 11
- 230000004927 fusion Effects 0.000 claims description 8
- 230000006870 function Effects 0.000 claims description 7
- 238000000605 extraction Methods 0.000 claims description 5
- 238000011176 pooling Methods 0.000 claims description 5
- 238000011478 gradient descent method Methods 0.000 claims description 3
- 238000013507 mapping Methods 0.000 claims description 3
- 230000001131 transforming effect Effects 0.000 claims 1
- 238000013145 classification model Methods 0.000 abstract description 5
- 230000008447 perception Effects 0.000 description 16
- 238000013135 deep learning Methods 0.000 description 6
- 238000005457 optimization Methods 0.000 description 6
- 230000008569 process Effects 0.000 description 4
- 230000003044 adaptive effect Effects 0.000 description 2
- 238000013459 approach Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000005259 measurement Methods 0.000 description 2
- 238000005065 mining Methods 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000003190 augmentative effect Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000007635 classification algorithm Methods 0.000 description 1
- 238000012790 confirmation Methods 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000002950 deficient Effects 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000010191 image analysis Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000004807 localization Effects 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
- G06V10/765—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/762—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using clustering, e.g. of similar faces in social networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/778—Active pattern-learning, e.g. online learning of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/80—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Databases & Information Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
Abstract
本发明提出基于双重伪标签细化和样本重加权的图像分类方法及系统,涉及图像处理技术领域。包括样本数据集依次输入至学生网络和教师网络;将学生网络提取的样本特征输入至样本难度概率预测器,获得样本学习难度概率;教师网络中未标记数据分别通过聚类和预测得到最终预测概率;基于教师网络的学习状态进行类级阈值的动态调整,利用样本学习难度概率实现对各类阈值的自适应调整,得到对应类的阈值;若未标记样本的最终预测概率高于对应类的阈值,则用该类对未标记样本进行标记;更新网络参数,迭代直到达到收敛条件,完成模型的训练;利用训练好的模型对待分类图像进行分类。本发明提升了图像分类模型的准确性和鲁棒性。
Description
技术领域
本发明属于图像分类技术领域,尤其涉及基于双重伪标签细化和样本重加权的图像分类方法及系统。
背景技术
图像分类是计算机视觉领域中一个基础且重要的研究方向,旨在将输入图像分配到预定义的类别中,同时也是如目标检测、目标定位及语义分割等其他计算机视觉任务的基础。目前,已有众多学者对图像分类方法进行了深入广泛的研究,并在图像识别、自动驾驶、医学图像分析等领域取得了广泛的应用。随着深度学习的兴起,尤其是深度神经网络的成功应用,图像分类任务取得了革命性的进展。深度学习充分展现了其强大的特征学习能力,目前已成为图像分类的主流方法。
尽管基于深度学习的图像分类方法已经取得了显著的进展,但也面临着一系列的挑战和难点。基于深度学习的方法通常需要大量的标注数据来训练模型以保持良好的性能,这限制了它们在实际应用领域中的广泛适用性。例如在医学图像领域,由于医学专业知识的要求以及图像标注的复杂性,因此标注数据的数量相对较少且存在一定的局限性。这对于现有的基于大规模数据驱动的深度学习方法来说,无疑是严峻的挑战。
在面对标注数据稀缺的挑战时,半监督学习方法成为了一种有力的解决方案。它能够通过同时利用有限的标记数据和大量的未标记数据,更充分地利用可用信息,并实现与部分监督学习方法比肩的学习效果。然而,发明人发现,半监督学习方法中,仍然存在以下技术问题:
(1)数据集中通常存在一些难以学习或分类的样本,即难样本,它们通常位于不同类别相似特征交汇的边界区域,位于决策边界附近,这使得模型难以对其进行准确分类。简单来说,难样本为模型在分类时容易分错的样本,而简单样本是模型容易进行正确分类的样本。
难样本中包含更有价值的信息,而现有技术中对难样本的利用并不充分,导致模型无法学习到更准确的决策边界,模型分类性能的精度不够。
(2)当数据集中存在类别不平衡问题时,模型生成的伪标签会偏向于多数类,而远离少数类,严重影响多数类伪标签的质量以及少数类伪标签的数量,加剧模型的确认偏差,限制了图像分类性能的提升。
发明内容
为克服上述现有技术的不足,本发明提供了基于双重伪标签细化和样本重加权的图像分类方法及系统,引入难度感知自适应加权模块、双重约束伪标签细化模块和样本感知置信度阈值调整模块,提升了图像分类模型的准确性和鲁棒性,有效解决了图像数据集中存在的难样本难以识别、类别不平衡、标记数据稀缺等问题。
为实现上述目的,本发明的一个或多个实施例提供了如下技术方案:
本发明第一方面提供了基于双重伪标签细化和样本重加权的图像分类方法。
基于双重伪标签细化和样本重加权的图像分类方法,包括以下步骤:
获取包含标记数据和未标记数据的样本数据集,进行预处理;
将预处理后的数据集随机加噪并依次输入至学生网络和教师网络中,分别提取样本特征;
将学生网络提取的样本特征输入至样本难度概率预测器中,获得样本学习难度概率;
将教师网络中的未标记数据分别通过聚类和预测得到聚类伪标签属于每一类的概率、预测伪标签属于每一类的概率,两者相融合得到最终预测概率;基于教师网络的学习状态进行类级阈值的动态调整,并利用前述的样本学习难度概率实现对每个样本各类阈值的自适应调整,得到对应类的阈值;若未标记样本的最终预测概率高于对应类的阈值,则用该类对未标记样本进行标记,并加入到标记数据中,用于下一轮的模型训练;
计算模型的总损失;
更新学生网络、教师网络参数,迭代直到达到收敛条件,保存最小损失值时的网络模型,完成模型的训练;
利用训练好的模型对待分类图像进行分类。
本发明第二方面提供了基于双重伪标签细化和样本重加权的图像分类系统。
基于双重伪标签细化和样本重加权的图像分类系统,包括:
预处理模块,被配置为:获取包含标记数据和未标记数据的样本数据集,进行预处理;
特征提取模块,被配置为:将预处理后的数据集随机加噪并依次输入至学生网络和教师网络中,分别提取样本特征;
样本学习难度概率计算模块,被配置为:将学生网络提取的样本特征输入至样本难度概率预测器中,获得样本学习难度概率;
未标记样本标记模块,被配置为:将教师网络中的未标记数据分别通过聚类和预测得到聚类伪标签属于每一类的概率、预测伪标签属于每一类的概率,两者相融合得到最终预测概率;基于教师网络的学习状态进行类级阈值的动态调整,并利用前述的样本学习难度概率实现对每个样本各类阈值的自适应调整,得到对应类的阈值;若未标记样本的最终预测概率高于对应类的阈值,则用该类对未标记样本进行标记,并加入到标记数据中,用于下一轮的模型训练;
总损失计算模块,被配置为:计算模型的总损失;
迭代模块,被配置为:更新学生网络、教师网络参数,迭代直到达到收敛条件,保存最小损失值时的网络模型,完成模型的训练;
分类模块,被配置为:利用训练好的模型对待分类图像进行分类。
以上一个或多个技术方案存在以下有益效果:
本发明提出一种基于双重伪标签细化和样本重加权的图像分类方法及系统,对比之前类似的方法,在图像分类效果上显示出更优异的表现。一方面,本发明提出了一种难度感知自适应加权模块来挖掘难样本知识,并作为样本权重融合在损失中,帮助模型从难样本中学习具有判别性的特征,从而提升模型分类性能;另一方面,本发明引入双重约束伪标签细化模块实现在数据的局部结构和全局任务相关信息之间的平衡。
引入样本感知置信度阈值调整模块对每个样本在类级阈值基础上生成更加个性化的特定样本阈值,从而充分挖掘各类样本中的难样本知识并生成更可靠的伪标签,有效提升图像分类性能。
本发明附加方面的优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本发明的实践了解到。
附图说明
构成本发明的一部分的说明书附图用来提供对本发明的进一步理解,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。
图1为第一个实施例的方法流程图。
图2为第二个实施例的系统结构图。
具体实施方式
应该指出,以下详细说明都是示例性的,旨在对本发明提供进一步的说明。除非另有指明,本文使用的所有技术和科学术语具有与本发明所属技术领域的普通技术人员通常理解的相同含义。
需要注意的是,这里所使用的术语仅是为了描述具体实施方式,而非意图限制根据本发明的示例性实施方式。
在不冲突的情况下,本发明中的实施例及实施例中的特征可以相互组合。
本发明提出的总体思路:
为了有效地解决图像数据集中存在的难样本难以识别、类别不平衡、标记数据稀缺等问题,本发明提出了一种基于双重伪标签细化和样本重加权的图像分类方法与系统,来提升图像分类模型的准确性和鲁棒性。
本发明引入难度感知自适应加权模块,该模块通过对样本学习难度进行显式建模来自适应分配样本权重,并最终作用于样本损失中,以提升模型对不确定性知识的学习能力。
引入双重约束伪标签细化模块,其中的聚类伪标签能够关注数据本身的局部结构信息,而预测伪标签主要关注于与全局任务相关的信息,通过综合利用聚类的灵活性和预测的准确性,提升伪标签的质量,从而增强图像分类模型的性能。
引入样本感知置信度阈值调整模块,通过动态调整每个样本中各类别的置信度阈值,来充分挖掘各类别难样本知识并生成更可靠的伪标签,进一步提升图像分类模型的准确性和鲁棒性。
数据集预处理:首先将现有数据集中的图像尺度变换为统一大小,然后对所有标记和未标记数据利用随机数据增强方法进行扩充,主要的数据增强有水平或垂直翻转、随机裁剪、缩放和平移、高斯噪声、随机擦除等。最后随机将每类标记数据按照预定的比例进行分配,构建为训练集和测试集;
提取样本特征:首先将训练集中的标记和未标记数据进行随机加噪,随后依次将标记和未标记数据输入到学生网络和教师网络中,经过多次卷积操作后,然后进入平均池化层,来提取样本特征,并为之后挖掘样本学习难度知识、生成聚类伪标签及损失函数的输入信息做准备工作;
难度感知自适应加权:该方法主要用于对样本的学习难度进行显式建模。在本发明中,引入了样本难度概率预测器,由带有一个隐藏层的多层感知器网络构成,其输入为在学生网络中提取的样本特征,输出为样本的学习难度概率,并将其作为样本权重与计算的分类损失及一致性损失进行融合。对于样本难度概率预测器设计了专门的优化目标,并以回归的方式对其进行优化;
双重约束伪标签细化:该方法主要目的是实现对预测伪标签和聚类伪标签的细化。其中聚类伪标签首先根据标记样本知识计算各类原型,然后将教师网络中提取的未标记样本特征映射到构建的嵌入空间,计算未标记样本特征与各类原型的相似度,并将其转换为未标记样本属于每一类的概率。预测伪标签为教师网络对未标记样本的预测概率,然后将预测伪标签与聚类伪标签根据设定的融合比例进行融合,作为样本的最终预测概率;
样本感知置信度阈值调整:该方法用于调整每个样本生成伪标签的置信度阈值。它可以分为两阶段,首先根据模型对每个类的学习能力强弱以及上一轮生成伪标签的类分布进行类级阈值的动态调整,然后利用在难度感知自适应加权方法中获得的样本学习难度概率实现对每个样本各类阈值的自适应调整;
生成伪标签数据:对于未标记数据,如果教师网络对未标记样本的最大置信度高于在样本感知置信度阈值调整方法设置的该样本对应类的阈值,则对其进行标记,并加入到标记数据中,用于下一轮的模型训练;
计算交叉熵损失:将学生网络对标记数据的预测标签与真实标签类别进行交叉熵损失的计算;
计算一致性损失:将学生网络与教师网络对未标记数据的预测标签进行均方差损失的计算;
计算难易预测损失:利用难样本的分类不一致性和分类不确定性两个特性,设计了针对样本难度概率预测器的优化目标。以标记样本为例,将学生网络对标记数据的预测标签与真实标签分布进行Jensen-Shannon散度的计算衡量分类不一致性,采用学生网络对样本预测概率分布计算的信息熵衡量分类的不确定性,最终结合以上计算的两个指标与样本难度概率预测器输出的样本学习难度概率计算平均绝对误差,实现对样本难度概率预测器输出结果的约束。其中未标记样本与标记样本在计算难易预测损失的区别在于没有真实标签的强约束信息,因此采用学生网络和教师网络对未标记样本预测概率分布的Jensen-Shannon散度来衡量分类不一致性;
计算网络的总损失:将交叉熵损失与一致性损失相加结果与样本的学习难度权重进行融合,并与难易预测损失相加作为网络的总损失。在进行样本学习难度权重与交叉熵损失、一致性损失融合时,采用两阶段策略,在模型训练初期,采用课程学习思想,为学习难度小的样本施加更大的学习权重,帮助模型快速收敛并具备初步的学习能力;在模型训练中后期,为难以学习的样本施加更大的学习权重,帮助模型学习其中的不确定性知识,进一步提升模型性能;
网络训练:通过随机梯度下降法更新学生网络的参数。在更新完学生网络参数后,利用学生网络的参数通过指数滑动平均(EMA)方式进行教师网络参数更新。迭代上述过程,直到达到收敛条件,保存其最小损失值时的网络模型;
预测阶段:利用训练好的网络模型对输入图像进行预测得到对应类别概率得分,然后选取概率最大的类别作为最终该图像的预测结果。
实施例一
本实施例公开了基于双重伪标签细化和样本重加权的图像分类方法。
如图1所示,基于双重伪标签细化和样本重加权的图像分类方法,包括以下步骤:
获取包含标记数据和未标记数据的样本数据集,进行预处理;
将预处理后的数据集随机加噪并依次输入至学生网络和教师网络中,分别提取样本特征;
将学生网络提取的样本特征输入至样本难度概率预测器中,获得样本学习难度概率;
将教师网络中的未标记数据分别通过聚类和预测得到聚类伪标签属于每一类的概率、预测伪标签属于每一类的概率,两者相融合得到最终预测概率;基于教师网络的学习状态进行类级阈值的动态调整,并利用前述的样本学习难度概率实现对每个样本各类阈值的自适应调整,得到对应类的阈值;若未标记样本的最终预测概率高于对应类的阈值,则用该类对未标记样本进行标记,并加入到标记数据中,用于下一轮的模型训练;
计算模型的总损失;
更新学生网络、教师网络参数,迭代直到达到收敛条件,保存最小损失值时的网络模型,完成模型的训练;
利用训练好的模型对待分类图像进行分类。
具体包括如下步骤:
S1: 数据集预处理
在数据收集过程中,原始数据样本的图像尺寸可能存在不一致的情况,为了促进深度网络模型更有效的学习,因此需要对原始数据样本集进行图像尺寸的统一调整。具体来说,利用Pytorch深度学习框架中的transforms类对现有数据集进行统一的尺度变换。其次,由于数据集中标记数据匮乏,因此采用随机数据增强方法对原始数据集进行扩充。最后将数据集中每类标记数据按照3:1的比例划分为训练集和测试集。特别注意,扩充后的标记数据集和未标记数据集标签不变。
S2:提取样本特征
首先在训练集中的标记和未标记数据在输入到学生网络与教师网络之前需要经过随机加噪处理,其中加噪的质量在很大程度上决定了分类算法的性能。在本发明中,根据图像数据集的特点设计了合理的加噪方式,其中加噪方式包括随机翻转、颜色抖动、噪声添加三种方式的随机组合,并且这些加噪方式的改变值全部采用一定范围内的随机数。
然后将处理后的标记数据(其中/>和/>分别表示第/>个数据及其标签,/>表示标记数据数量)和未标记数据/>(其中/>表示第/>个未标记数据,/>表示未标记数据数量)依次输入到学生网络/>和教师网络中(其中/>是输入数据,/>和/>表示模型参数),经过多次卷积操作后,然后进入平均池化层,来提取样本特征,并获取对应的特征图。为之后挖掘样本学习难度知识、生成聚类伪标签及损失函数的输入信息做准备工作。
S3:难度感知自适应加权
首先引入样本难度概率预测器(其中/>表示模型参数),它由带有一个隐藏层的多层感知器网络组成,该网络的输入为经过学生网络特征提取器提取的样本特征(其中/>表示样本特征的维度),输出为样本的学习难度概率,并将其作为样本权重与后面计算的交叉熵损失及一致性损失进行融合。
S4:双重约束伪标签细化
首先将标记数据经过教师网络提取样本特征,然后采用K-means等聚类算法将提取的样本特征聚类为/>簇(其中/>表示类别数量),获取各数据所属的簇标签以及/>个聚类中心(原型)/>,然后采用投票机制计算每一簇中数据的真实标签数量,并选择获得最多票数的标签作为该簇的标签。在获取各类原型后,将教师网络提取的未标记数据的样本特征映射到相同的嵌入空间中,以第/>个未标记数据为例,采用余弦相似度等度量方式计算其与各类原型之间的相似度,其中经教师网络提取的未标记数据样本特征/>与第/>类的原型/>计算相似度可表示为/>,并通过Softmax函数将与各类原型的相似度得分压缩到[0,1]之间,作为该未标记数据聚类伪标签属于每一类的概率。
其次,以第个未标记数据为例,将教师网络提取的样本特征经过平均池化层、全连接层和Softmax函数处理,获得该未标记数据预测伪标签属于每一类的概率。以多次实验调参得到的参数/>作为聚类伪标签与预测伪标签融合的比例系数,最终实现伪标签的细化/>。
S5:样本感知置信度阈值调整
首先根据教师网络的学习状态进行类级阈值的动态调整,实现降低少数类的阈值从而保留预测概率较低的少数类样本,提升多数类样本阈值从而滤除存在噪声的数据,只保留高质量的多数类样本。其中教师网络的学习状态可以通过对每个类的学习能力强弱以及上一轮生成伪标签的类分布两个角度进行估计。
前述的个聚类中心,即表示该分类任务的总类别数量;其中包含第/>类、第/>类等。
接下来以预测为类的数据为例,教师网络对/>类的学习能力强弱可以通过计算所有预测为/>类数据数量反映/>,其中学习能力强弱综合考虑标记和未标记数据,即/>表示标记或未标记数据在教师网络中的预测结果,/>。然后对/>执行归一化/>,使其取值范围在0到1之间。教师网络上一轮生成伪标签的类分布表示为/>,其中用于计算上一轮生成的伪标签为/>类别的数量,/>表示第/>个未标记数据在教师网络中的预测结果,在开始训练时,将伪标签分布赋值为类别数量的倒数。因此类级阈值可表示为/>,其中/>表示预先设定的最大置信度阈值。
然后利用在难度感知自适应加权方法中获得的样本学习难度概率实现对每个未标记样本各类阈值的自适应调整,以第个未标记样本的/>类为例可表示为,其中/>用于约束伪标签阈值不为负值。
S6:生成伪标签数据
对于未标记数据,如果教师网络对未标记样本的分到某类的概率高于在样本感知置信度阈值调整方法设置的该样本对应类的阈值,则对其进行标记,并加入到标记数据中,用于下一轮的模型训练。
S7:计算交叉熵损失
将学生网络对标记数据的预测标签与真实标签进行交叉熵损失计算。
S8: 计算一致性损失
将学生网络和教师网络对未标记数据的预测标签进行均方差损失计算;
S9:计算难易预测损失
针对难度感知自适应加权方法中引入的样本难度概率预测器,在本发明中设计了专门的优化目标,并以回归的方式对其进行优化。首先,根据样本学习难度将样本划分为简单样本和难样本。难样本一般具有以下特性:1)分类不一致性。模型对难样本的不同加噪组合难以保持一致的预测结果。2)分类不确定性。模型对难样本的预测通常是不确定的,通常会给出较低的置信度分数。
该优化目标可以分别从标记数据和未标记数据两个角度表示,对于第个标记数据,将学生网络对标记数据的预测标签与真实标签分布进行Jensen-Shannon散度计算来衡量分类不一致性/>,其中/>表示KL散度计算。采用学生网络对标记数据预测概率分布计算的信息熵衡量分类的不确定性,其中/>表示学生网络对第/>个标记数据第/>类的预测概率。
最终结合以上计算的两个指标与样本难度概率预测器输出的样本学习难度概率计算平均绝对误差/>,实现对样本难度概率预测器输出结果的约束。其中未标记样本与标记样本在计算难易预测损失的区别在于没有真实标签的强约束信息,因此采用学生网络和教师网络对未标记样本预测概率分布的Jensen-Shannon散度来衡量分类不一致性。
S10: 计算网络的总损失
在进行样本学习难度权重与交叉熵损失/>、一致性损失/>融合时,采用两阶段策略,在模型训练初期,采用课程学习思想,为学习难度小的样本施加更大的学习权重,帮助模型快速收敛并具备初步的学习能力;在模型训练中后期,为难以学习的样本施加更大的学习权重/>,帮助模型学习其中的不确定性知识,进一步提升模型性能。
样本学习的难易通过步骤3中样本难度概率预测器生成的样本学习难度概率进行估计。
因此,分别将交叉熵损失与一致性损失/>与根据两阶段策略调整的样本的学习难度权重/>进行融合,然后与难易预测损失/>相加作为网络的总损失。
S11: 网络训练
通过随机梯度下降法更新学生网络的参数。在更新完学生网络参数后,利用学生网络的参数通过指数滑动平均(EMA)方式进行教师网络参数更新。迭代上述过程,直到达到收敛条件,保存其最小损失值时的网络模型;
S12: 预测阶段
输入待测图像到训练好的网络模型中进行预测,最终网络模型输出计算的最大值所对应的类别做为最终的预测结果。
实施例二
本实施例公开了基于双重伪标签细化和样本重加权的图像分类系统。
如图2所示,基于双重伪标签细化和样本重加权的图像分类系统,包括:
预处理模块,被配置为:获取包含标记数据和未标记数据的样本数据集,进行预处理;
特征提取模块,被配置为:将预处理后的数据集随机加噪并依次输入至学生网络和教师网络中,分别提取样本特征;
样本学习难度概率计算模块,被配置为:将学生网络提取的样本特征输入至样本难度概率预测器中,获得样本学习难度概率;
未标记样本标记模块,被配置为:将教师网络中的未标记数据分别通过聚类和预测得到聚类伪标签属于每一类的概率、预测伪标签属于每一类的概率,两者相融合得到最终预测概率;基于教师网络的学习状态进行类级阈值的动态调整,并利用前述的样本学习难度概率实现对每个样本各类阈值的自适应调整,得到对应类的阈值;若未标记样本的最终预测概率高于对应类的阈值,则用该类对未标记样本进行标记,并加入到标记数据中,用于下一轮的模型训练;
总损失计算模块,被配置为:计算模型的总损失;
迭代模块,被配置为:更新学生网络、教师网络参数,迭代直到达到收敛条件,保存最小损失值时的网络模型,完成模型的训练;
分类模块,被配置为:利用训练好的模型对待分类图像进行分类。
如图2所示,对整个图像分类系统模型框架做解释说明。
图2中所对应的虚线框内系统为主要执行分类功能的系统模块,其中特征向量提取模块利用S2中所述教师网络模型提取的特征图,然后计算预测类别与用户进行交互。其中采用的教师网络模型为经过训练后确定的最优模型。
用户输入待测试图像数据进入分类系统,分类系统内部自动进行特征向量提取和计算预测类别两个过程,最后输出预测类别与用户进行交互。
本领域技术人员应该明白,上述本发明的各模块或各步骤可以用通用的计算机装置来实现,可选地,它们可以用计算装置可执行的程序代码来实现,从而,可以将它们存储在存储装置中由计算装置来执行,或者将它们分别制作成各个集成电路模块,或者将它们中的多个模块或步骤制作成单个集成电路模块来实现。本发明不限制于任何特定的硬件和软件的结合。
上述虽然结合附图对本发明的具体实施方式进行了描述,但并非对本发明保护范围的限制,所属领域技术人员应该明白,在本发明的技术方案的基础上,本领域技术人员不需要付出创造性劳动即可做出的各种修改或变形仍在本发明的保护范围以内。
Claims (8)
1.基于双重伪标签细化和样本重加权的图像分类方法,其特征在于,包括以下步骤:
获取包含标记数据和未标记数据的样本数据集,进行预处理;
将预处理后的数据集随机加噪并依次输入至学生网络和教师网络中,分别提取样本特征;
将学生网络提取的样本特征输入至样本难度概率预测器中,获得样本学习难度概率;
将教师网络中的未标记数据分别通过聚类和预测得到聚类伪标签属于每一类的概率、预测伪标签属于每一类的概率,两者相融合得到最终预测概率;基于教师网络的学习状态进行类级阈值的动态调整,并利用前述的样本学习难度概率实现对每个样本各类阈值的自适应调整,得到对应类的阈值;若未标记样本的最终预测概率高于对应类的阈值,则用该类对未标记样本进行标记,并加入到标记数据中,用于下一轮的模型训练;
最终预测概率的计算方法具体为:
将教师网络提取的标记数据的样本特征聚类到多个簇,获取多个簇的簇标签及原型,将教师网络提取的未标记数据的样本特征映射到与标记数据相同的嵌入空间中,计算未标记样本特征与各类原型的相似度,得到未标记样本属于每一类的概率;
将教师网络提取的样本特征经过平均池化层、全连接层和Softmax函数处理,获得该未标记数据预测伪标签属于每一类的概率;
通过以下公式得到最终预测概率:
;
为聚类伪标签与预测伪标签融合的比例系数,/>为第/>个未标记数据;
通过对每个类的学习能力强弱以及上一轮生成伪标签的类分布两个角度对教师网络的学习状态进行估计:
教师网络学习能力强弱表示为:
,
其中为所有预测为/>类数据的数量;/>表示标记或未标记数据在教师网络中的预测结果;/>;/>表示标记数据数量;/>表示未标记数据数量;/>为第/>个标记或未标记数据;
教师网络上一轮生成伪标签的类分布表示为:
其中用于计算上一轮生成的伪标签为/>类别的数量;/>为/>类中的第/>类;
类级阈值表示为:
,
其中表示预先设定的最大置信度阈值;/>为C类的类级阈值;
样本学习难度概率实现对每个未标记样本各类阈值的自适应调整:
,
其中用于约束伪标签阈值不为负值;/>为第/>个未标记数据的学习难度概率;
计算模型的总损失;
更新学生网络、教师网络参数,迭代直到达到收敛条件,保存最小损失值时的网络模型,完成模型的训练;
利用训练好的模型对待分类图像进行分类。
2.如权利要求1所述的基于双重伪标签细化和样本重加权的图像分类方法,其特征在于:
数据预处理包括:将数据集中的图像尺度变换为统一大小,对所有标记和未标记数据利用随机数据增强方法进行扩充;
数据集的随机加噪方式包括随机翻转、颜色抖动、噪声添加三种方式的随机组合。
3.如权利要求1所述的基于双重伪标签细化和样本重加权的图像分类方法,其特征在于,样本难度概率预测器由带有一个隐藏层的多层感知器网络组成,其中/>表示模型参数,输入为经过学生网络提取的样本特征/>,输出为样本的学习难度概率。
4.如权利要求1所述的基于双重伪标签细化和样本重加权的图像分类方法,其特征在于:
计算交叉熵损失,具体为:
将学生网络对标记数据的预测标签与真实标签进行交叉熵损失计算:
;
计算一致性损失,具体为:
将学生网络和教师网络对未标记数据的预测标签进行均方差损失计算:
;
其中,学生网络,/>是输入数据,/>表示模型参数;/>为真实标签;教师网络,/>表示模型参数;标记数据/>;/>和/>分别表示第/>个数据及其标签,/>表示标记数据数量;未标记数据/>,其中/>表示第/>个未标记数据,/>表示未标记数据数量。
5.如权利要求4所述的基于双重伪标签细化和样本重加权的图像分类方法,其特征在于,计算难易预测损失,具体为:
对于第个标记数据,将学生网络对标记数据的预测标签与真实标签分布进行Jensen-Shannon散度计算,来衡量分类不一致性:
,
其中表示KL散度计算;/>为学生网络对第/>个标记数据的预测概率;
采用学生网络对标记数据预测概率分布计算的信息熵,衡量分类的不确定性:
,
其中表示学生网络对第/>个标记数据第/>类的预测概率;
结合以上计算的两个指标:
;
将与样本学习难度概率计算平均绝对误差,实现对样本难度概率预测器输出结果的约束:
。
6.如权利要求5所述的基于双重伪标签细化和样本重加权的图像分类方法,其特征在于:
分别将交叉熵损失与一致性损失/>与样本的学习难度权重/>进行融合,然后与难易预测损失/>相加作为网络的总损失:
;
为标记数据权重;/>为未标记数据权重;
其中,在进行样本学习难度权重与交叉熵损失、一致性损失融合时,采用两阶段策略:
第一阶段即模型训练初期,更注重简单样本学习,对简单样本施加更大的学习权重;第二阶段即模型训练中后期,更加注重难样本,对难样本施加更大的学习权重,/>为根据两阶段策略调整的样本的学习难度权重。
7.如权利要求1所述的基于双重伪标签细化和样本重加权的图像分类方法,其特征在于:
通过随机梯度下降法更新学生网络的参数;
在更新完学生网络参数后,利用学生网络的参数通过指数滑动平均方式进行教师网络参数更新。
8.基于双重伪标签细化和样本重加权的图像分类系统,其特征在于:包括:
预处理模块,被配置为:获取包含标记数据和未标记数据的样本数据集,进行预处理;
特征提取模块,被配置为:将预处理后的数据集随机加噪并依次输入至学生网络和教师网络中,分别提取样本特征;
样本学习难度概率计算模块,被配置为:将学生网络提取的样本特征输入至样本难度概率预测器中,获得样本学习难度概率;
未标记样本标记模块,被配置为:将教师网络中的未标记数据分别通过聚类和预测得到聚类伪标签属于每一类的概率、预测伪标签属于每一类的概率,两者相融合得到最终预测概率;基于教师网络的学习状态进行类级阈值的动态调整,并利用前述的样本学习难度概率实现对每个样本各类阈值的自适应调整,得到对应类的阈值;若未标记样本的最终预测概率高于对应类的阈值,则用该类对未标记样本进行标记,并加入到标记数据中,用于下一轮的模型训练;
最终预测概率的计算方法具体为:
将教师网络提取的标记数据的样本特征聚类到多个簇,获取多个簇的簇标签及原型,将教师网络提取的未标记数据的样本特征映射到与标记数据相同的嵌入空间中,计算未标记样本特征与各类原型的相似度,得到未标记样本属于每一类的概率;
将教师网络提取的样本特征经过平均池化层、全连接层和Softmax函数处理,获得该未标记数据预测伪标签属于每一类的概率;
通过以下公式得到最终预测概率:
;
为聚类伪标签与预测伪标签融合的比例系数,/>为第/>个未标记数据;
通过对每个类的学习能力强弱以及上一轮生成伪标签的类分布两个角度对教师网络的学习状态进行估计:
教师网络学习能力强弱表示为:
,
其中为所有预测为/>类数据的数量;/>表示标记或未标记数据在教师网络中的预测结果;/>;/>表示标记数据数量;/>表示未标记数据数量;/>为第/>个标记或未标记数据;
教师网络上一轮生成伪标签的类分布表示为:
其中用于计算上一轮生成的伪标签为/>类别的数量;/>为/>类中的第/>类;
类级阈值表示为:
,
其中表示预先设定的最大置信度阈值;/>为C类的类级阈值;
样本学习难度概率实现对每个未标记样本各类阈值的自适应调整:
,
其中用于约束伪标签阈值不为负值;/>为第/>个未标记数据的学习难度概率;
总损失计算模块,被配置为:计算模型的总损失;
迭代模块,被配置为:更新学生网络、教师网络参数,迭代直到达到收敛条件,保存最小损失值时的网络模型,完成模型的训练;
分类模块,被配置为:利用训练好的模型对待分类图像进行分类。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410094841.7A CN117611932B (zh) | 2024-01-24 | 2024-01-24 | 基于双重伪标签细化和样本重加权的图像分类方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410094841.7A CN117611932B (zh) | 2024-01-24 | 2024-01-24 | 基于双重伪标签细化和样本重加权的图像分类方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117611932A CN117611932A (zh) | 2024-02-27 |
CN117611932B true CN117611932B (zh) | 2024-04-26 |
Family
ID=89958362
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410094841.7A Active CN117611932B (zh) | 2024-01-24 | 2024-01-24 | 基于双重伪标签细化和样本重加权的图像分类方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117611932B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118038185B (zh) * | 2024-03-26 | 2024-08-09 | 山东建筑大学 | 一种基于多样化自适应知识蒸馏的图像分类方法及系统 |
CN118262181B (zh) * | 2024-05-29 | 2024-08-13 | 山东鲁能控制工程有限公司 | 一种基于大数据的自动化数据处理系统 |
CN118279700B (zh) * | 2024-05-30 | 2024-08-09 | 广东工业大学 | 一种工业质检网络训练方法及装置 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111340105A (zh) * | 2020-02-25 | 2020-06-26 | 腾讯科技(深圳)有限公司 | 一种图像分类模型训练方法、图像分类方法、装置及计算设备 |
EP3975062A1 (en) * | 2020-09-24 | 2022-03-30 | Toyota Jidosha Kabushiki Kaisha | Method and system for selecting data to train a model |
CN114419363A (zh) * | 2021-12-23 | 2022-04-29 | 北京三快在线科技有限公司 | 基于无标注样本数据的目标分类模型训练方法及装置 |
CN114821204A (zh) * | 2022-06-30 | 2022-07-29 | 山东建筑大学 | 一种基于元学习嵌入半监督学习图像分类方法与系统 |
CN115660101A (zh) * | 2022-09-27 | 2023-01-31 | 上海淇玥信息技术有限公司 | 一种基于业务节点信息的数据服务提供方法及装置 |
CN116310475A (zh) * | 2022-11-21 | 2023-06-23 | 天津大学 | 一种基于半监督学习的垃圾图像分类算法 |
-
2024
- 2024-01-24 CN CN202410094841.7A patent/CN117611932B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111340105A (zh) * | 2020-02-25 | 2020-06-26 | 腾讯科技(深圳)有限公司 | 一种图像分类模型训练方法、图像分类方法、装置及计算设备 |
EP3975062A1 (en) * | 2020-09-24 | 2022-03-30 | Toyota Jidosha Kabushiki Kaisha | Method and system for selecting data to train a model |
CN114419363A (zh) * | 2021-12-23 | 2022-04-29 | 北京三快在线科技有限公司 | 基于无标注样本数据的目标分类模型训练方法及装置 |
CN114821204A (zh) * | 2022-06-30 | 2022-07-29 | 山东建筑大学 | 一种基于元学习嵌入半监督学习图像分类方法与系统 |
CN115660101A (zh) * | 2022-09-27 | 2023-01-31 | 上海淇玥信息技术有限公司 | 一种基于业务节点信息的数据服务提供方法及装置 |
CN116310475A (zh) * | 2022-11-21 | 2023-06-23 | 天津大学 | 一种基于半监督学习的垃圾图像分类算法 |
Non-Patent Citations (3)
Title |
---|
Home Medical Image Computing and Computer Assisted Intervention – MICCAI 2022 Conference paper Semi-supervised Medical Image Classification with Temporal Knowledge-Aware Regularization;QIUSHI YANG;《Medical Image Computing and Computer Assisted Intervention – MICCAI 2022》;20220916;全文 * |
半监督图像分类的伪标签预测方法研究;熊巍钰;《中国优秀硕士学位论文全文数据库》;20240115;全文 * |
基于层次化双重注意力网络的乳腺多模态图像分类;杨霄;《山东大学学报》;20220630;全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN117611932A (zh) | 2024-02-27 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN117611932B (zh) | 基于双重伪标签细化和样本重加权的图像分类方法及系统 | |
CN109583501B (zh) | 图片分类、分类识别模型的生成方法、装置、设备及介质 | |
CN109993102B (zh) | 相似人脸检索方法、装置及存储介质 | |
CN102314614B (zh) | 一种基于类共享多核学习的图像语义分类方法 | |
CN110796199B (zh) | 一种图像处理方法、装置以及电子医疗设备 | |
JP2022141931A (ja) | 生体検出モデルのトレーニング方法及び装置、生体検出の方法及び装置、電子機器、記憶媒体、並びにコンピュータプログラム | |
CN106951825A (zh) | 一种人脸图像质量评估系统以及实现方法 | |
CN108563624A (zh) | 一种基于深度学习的自然语言生成方法 | |
CN111723674A (zh) | 基于马尔科夫链蒙特卡洛与变分推断的半贝叶斯深度学习的遥感图像场景分类方法 | |
CN108537168B (zh) | 基于迁移学习技术的面部表情识别方法 | |
CN109492750B (zh) | 基于卷积神经网络和因素空间的零样本图像分类方法 | |
CN114494718A (zh) | 一种图像分类方法、装置、存储介质及终端 | |
CN112800876A (zh) | 一种用于重识别的超球面特征嵌入方法及系统 | |
CN114692732B (zh) | 一种在线标签更新的方法、系统、装置及存储介质 | |
CN112364791B (zh) | 一种基于生成对抗网络的行人重识别方法和系统 | |
JPWO2018203555A1 (ja) | 信号検索装置、方法、及びプログラム | |
CN116910571B (zh) | 一种基于原型对比学习的开集域适应方法及系统 | |
CN111639540A (zh) | 基于相机风格和人体姿态适应的半监督人物重识别方法 | |
CN114998220A (zh) | 一种基于改进的Tiny-YOLO v4自然环境下舌像检测定位方法 | |
CN113222149A (zh) | 模型训练方法、装置、设备和存储介质 | |
CN112232395B (zh) | 一种基于联合训练生成对抗网络的半监督图像分类方法 | |
CN117152459B (zh) | 图像检测方法、装置、计算机可读介质及电子设备 | |
CN107330448A (zh) | 一种基于标记协方差和多标记分类的联合学习方法 | |
CN114842238A (zh) | 一种嵌入式乳腺超声影像的识别方法 | |
CN117671261A (zh) | 面向遥感图像的无源域噪声感知域自适应分割方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |