CN115641480A - 一种基于样本筛选与标签矫正的噪声数据集训练方法 - Google Patents
一种基于样本筛选与标签矫正的噪声数据集训练方法 Download PDFInfo
- Publication number
- CN115641480A CN115641480A CN202211392565.XA CN202211392565A CN115641480A CN 115641480 A CN115641480 A CN 115641480A CN 202211392565 A CN202211392565 A CN 202211392565A CN 115641480 A CN115641480 A CN 115641480A
- Authority
- CN
- China
- Prior art keywords
- sample
- label
- samples
- training
- feature
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Image Analysis (AREA)
Abstract
本发明公开了一种基于样本筛选与标签矫正的噪声数据集训练方法,属于深度学习、卷积神经网络和噪声数据集训练领域,具体地说是一种基于样本筛选与标签矫正的噪声数据集训练方法。提出了一种新的样本筛选策略,其利用基于类级特征的聚类过程来识别在特征空间中靠近其对应类中心的干净样本。该方法还提出了采用一种新的标签矫正策略,将训练标签重建为当前训练标签、特征标签以及预测标签结果的加权组合,随着训练的进行,可以将噪声标签逐渐纠正为正确的标签,这增加了干净的可用于训练的样本数量,降低了训练集的噪声强度,提高了模型的精度与鲁棒性。
Description
技术领域
本发明的技术方案涉及深度学习、卷积神经网络和噪声数据集训练领域,具体地说是一种基于样本筛选与标签矫正的噪声数据集训练方法。
背景技术
随着最近大规模数据集的出现,深度神经网络(Deep Neural Network,DNN)在许多机器学习任务中显著的成功,如计算机视觉、信息检索和自然语言处理。然而,收集大量的高质量数据是相当昂贵和耗时的。为了缓解这一问题,通常会采用一些低成本的替代方案,如网页爬取(web-crawling)、在线查询(online queries)、众包(crowdsourcing)等。虽然这些方法显著降低了图像标注的获取成本,但其无法保证所获取标注的质量,导致数据集中含有噪声标签。许多研究已经表明,在存在噪声标签的情况下,深度神经网络的训练容易受到噪声标签的影响。因为大量的模型参数使得网络具有学习任何复杂函数的能力,使其可以很容易地用任何比例的损坏标签来学习拟合整个训练数据集,最终导致测试数据集的泛化性较差。因此,在带有标签噪声的数据集上训练鲁棒的深度神经网络是实际应用中一个具有挑战性的问题。
如今,常用的处理标签噪声数据集方法可以大致概括为三种类型:鲁棒损失函数、损失校正与样本筛选方法。鲁棒损失函数方法是指是设计一个不受噪声标签过度影响的损失函数,常用的鲁棒损失函数包括:平均绝对误差(Mean Absolute Error,MAE)、互补损失函数、主动-被动损失(Active Passive Loss)等。然而,噪声鲁棒性损失的条件非常严格,它们经常遇到欠拟合的问题。损失矫正是将真实标签视为观测变量,然后使用噪声转移矩阵来校正训练阶段的风险,与此相关的先前工作集中于从给定的噪声标签估计噪声转移矩阵。许多方法通过额外的网络重新加权训练数据来估计噪声转移矩阵。然而,训练附加网络通常需要具有干净标签的小训练集,这在许多应用中并不总是可行的。此外,如何准确估计噪声传递矩阵是一个重大挑战。上述方法可能错误地校正干净数据的损失,并在训练数据中引入新的噪声标签。因此,人们后续提出了样本筛选方法。样本筛选方法试图根据特定的划分标准将训练数据划分为“干净”和“噪声”子集。然后应用不同的策略分别在两个子集上训练模型。现有的划分准则主要基于训练损失,例如小损失准则。小损失准则通过设置阈值选择具有小损失的训练样本作为干净样本。
现有的噪声数据集训练算法存在着一些明显的不足主要有:
(1)大部分算法为了减轻标签噪声样本带来的负面影响而采用了样本筛选方法将训练数据划分为“干净”和“噪声”子集,只用干净样本训练网络,而将噪声样本丢弃。这种策略减少了训练样本的总数,不可避免地会降低网络性能。而且,虽然噪声样本的标签是错误的,但其包含的图像信息如果合理运用则也可以提高DNN的鲁棒性和泛化能力。
(2)大多数样本筛选方法采用小损失准则,然而小损失准则倾向于假设所有小批次的噪声比相同,而且往往需要根基估计的噪声率来缺点在每个小批次中选择样本的数量。但这种假设在现实情况下可能不成立,噪声率也难以准确估计,而且其忽略了不同批次数据中噪声比的波动。
发明内容
针对现有技术的不足,本发明所要解决的技术问题是:提供一种基于样本筛选与标签矫正的噪声数据集训练方法。该方法提出了一种新的样本筛选策略,其利用基于类级特征的聚类过程来识别在特征空间中靠近其对应类中心的干净样本。该方法还提出了采用一种新的标签矫正策略,将训练标签重建为当前训练标签、特征标签以及预测标签结果的加权组合,随着训练的进行,可以将噪声标签逐渐纠正为正确的标签,这增加了干净的可用于训练的样本数量,降低了训练集的噪声强度,提高了模型的精度与鲁棒性。
本发明解决该技术问题所采用的技术方案是:提供一种基于样本筛选与标签矫正的噪声数据集训练方法,该方法的步骤如下:
第一步:初始化网络模型,模型分为特征提取器与分类器两部分,并初始化伪标签为数据集原始噪声标签。用数据集中所有数据对网络进行训练一定轮数,使其获得一定的特征提取能力与分类能力,为后续步骤做好预热准备。
第二步:利用网络模型提取所有样本的特征,并根据类别对特征进行归一化处理,随后计算每个类的特征中心。
第三步:计算每个样本与其对应特征中心的余弦相似度,将高斯混合模型(Gaussian Mixture Model,GMM)应用于每类样本的相似度,并进行二值分类,根据分类结果将训练集分为干净子集与噪声子集。
第四步:计算每个样本与各类特征中心的余弦相似度,并取与其相似度最大的类特征中心对应的类别编号作为特征标签。
第五步:选取一个mini-batch的数据,将其特征输入分类器得到预测结果,根据预测结果与特征标签来更新伪标签,利用伪标签与预测结果的交叉熵更新整个网络的参数。重复该步骤直至数据集所有数据参与训练。
第六步:输出并保存训练好的网络模型。
上述第一步中,噪声数据集表示为其中N表示数据集样本总数量,xi表示第i个样本,对应的噪声标签,C表示数据集的类别数。网络模型分为两部分:以图像为输入并提取其特征的特征提取器F,以及基于由F提取的图像特征输出分类概率的分类器G。fi=F(xi)表示第i个样本经过提取器F提取的特征,pi=G(fi)表示第i个样本的预测标签。伪标签si初始化为数据集原始标签即因为模型参数是随机初始化的,其开始并不具备正确筛选样本以及矫正标签的能力,所以本专利首先用数据集中所有数据对网络进行训练Ew个轮数(epoch),使其获得一定的特征提取能力与分类能力,为后续步骤做好预热准备。Ew需要根据不同的数据集进行适当调整。
预热训练过程中采用的损失函数为传统交叉熵(CrossEntropy,CE)损失函数,因为使用的数据集原始标签,所以其可以参数化为:
随后根据si计算每个类的特征中心Oc:
上述第三步中,本发明接着计算每个样本与其对应特征中心的余弦相似度SIMi:
随后,本发明根据类别,将高斯混合模型(GMM)应用于每类样本的相似度以进行二值分类得到每个样本的置信分数Scorei:
紧接着,我们根据每个样本的置信分数Scorei将样本分为干净样本与噪声样本。具体表示为:如果Scorei≥Ti,则将样本xi视为干净样本,即标签正确的样本,如果Scorei<Ti,样本xi视为噪声样本,即标签错误的样本。
该步骤的目的样本筛选。对于每一类样本,其都有一个特征中心,干净样本的特征相比噪声样本距离特征中心更近些,所以其余弦相似度更大,经过GMM得到的分数也就更高,利用这个特性,本发明可以将每个类别中的样本分为干净样本以及噪声样本,进而为后续训练提供纯净度更高,质量更高的训练样本。
上述第四步中,本发明需要计算每个样本与各类特征中心的余弦相似度COSij:
其中j∈{1,2,3,…,C}表示对应类别编号,C表示类别总数,COSij表示第i个样本与第j类的特征中心Oj之间的余弦相似度。
随后我们根据每个样本与各类特征中心的余弦相似度确定样本的特征标签ui,规则定义为:
ui=argmax(COSi1,COSi2,COSi3,…,COSiC)
其中argmax表示返回其操作集合中最大值的序号索引。例如,如果COSi3的值最大,则返回值为3,即ui=3。
对于每个样本,本发明计算其与各类特征中心的余弦相似度,其与某中心的相似度越大,证明其越有可能属于该类样本,进而可以得到一个关于特征的标签用于后续的标签矫正工作。与第三步不同的是,第三步是针对每一类计算其类内所有样本的余弦相似度,目的是为了样本筛选,确定该类中哪些样本的特征余弦相似度较低,即哪些样本更倾向于是噪声样本,其筛选计算比较范围是该类中的所有样本。而在本步骤中,计算的是每个样本与各个特征中心的余弦相似度,其比较范围只是样本自身的一些数据特征,与其他样本未进行比较。
上述第五步中,本发明选取一小批次(mini-batch)的数据,将其特征f输入分类器得到预测结果p,随后根据预测结果与特征标签来更新伪标签si:
si=αsi+βpi+γui
其中α,β,γ为超参,且α+β+γ=1。这里进行伪标签的更新是为了将很多样本的噪声标签校正为正确的可以使用的噪声标签。
随后利用干净样本伪标签与预测结果的交叉熵更新整个网络的参数,其参数表示为:
其中nc表示该批次样本中干净样本的数量,第三步骤中判定为干净的样本正常参与网络的更新,而判定为噪声的样本,不参与网络的更新,目的是减少网络对噪声样本的学习,增强网络的性能。
重复该步骤直至数据集中所有数据参与训练,即为跑完一轮(一个epoch)。随后重复第二到五步直至跑完设定好的轮数。
上述第六步中,训练结束后,输出并保存训练好的网络模型参数,其保存格式为.pth文件。
与现有技术相比,本发明提出了一种新的样本选择策略,利用基于类级特征的聚类过程来识别在特征空间中接近其对应类中心的干净样本。该样本选择策略不需要数据集的相关先验知识(如噪声率),而且样本选择的数量不受某个比例或数值的约束,这使得它成为一种自然全局选择度量,可以根据不同小批次的噪声比,自适应的选择干净样本,缓解了不同小批次内噪声比不平衡引起的问题。此外本发明提出了的新的标签矫正策略,将训练标签重建为当前训练标签、特征标签以及预测标签结果的加权组合,随着训练的进行,可以将噪声标签逐渐纠正为正确的标签,这增加了干净的可用于训练的样本数量,降低了训练集的噪声强度。此外对于伪标签的更新本发明采用了移动平均方案,该方案使用模型预测逐步纠正有问题的标签,缓解了模型预测的不稳定性问题,使训练过程中更加平滑,并能够在必要时完全更改训练样本的标签。
附图说明
下面结合附图和实施例对本发明进一步说明。
图1为本发明所提样本筛选策略流程示意图。
图2为本发明特征标签产生流程示意图。
图3为本发明整体网络结构示意图。
具体实施方式
计算后生成图1中的标准化特征空间,其中篮色与红色小球分别表示干净样本的特征标准化特征与噪声样本的标准化特征,
随后根据si计算每个类的特征中心Oc:
其中c∈{1,2,3,…,C}表示对应类别编号,C表示类别总数,Nc表示si=c的样本总数,即伪标签类别为c的样本总数。表示对应类别样本中第i个样本的标准化特征。计算得到的特征中心由图1中的绿色小球表示。
接着计算每个样本与其对应特征中心的余弦相似度SIMi:
随后本发明根据类别,将高斯混合模型(Gaussian Mixture Model,GMM)应用于每类样本的相似度以进行二值分类得到每个样本的置信分数Scorei:
Scorei=GMMsi(SIMi)
其中GMMsi表示第i个样本的伪标签si对应的高斯混合模型。
最后根据每个样本的置信分数Scorei将将样本分为干净样本与噪声样本。具体表示为:如果Scorei≥Ti,则将样本xi视为干净样本,即标签正确的样本,如果Scorei<Ti,样本xi视为噪声样本,即标签错误的样本。
其中i∈{1,2,3,…,C}表示对应类别编号,C表示类别总数,COSi表示该样本与第i类的特征中心Oi之间的余弦相似度。随后根据该样本与各类特征中心的余弦相似度确定样本的特征标签ui,规则定义为:
ui=argmax(COS1,COS2,COS3,…,COSC)
其中argmax表示返回其操作集合中最大值的序号索引。
图3为本专利网络模型的整体结构示意图、模型的主干为特征提取器-分类器结构,特征提取器由一个九层的CNN网络构成,中间穿插着池化与Dropout等操作,每次卷积后也都有采用LReLu激活函数。输入的图像经过特征提取器F或会产生一个1*128的特征f,最后特征f经过由全连接层构成的分类器G会产生一个分类结果p。p既是图像分类的结果,又是后面进行伪标签更新的一个参考分量。
实施例1
本实施例采用基于样本筛选与标签矫正的噪声数据集训练方法。在这里,本实施例对含噪声的CIFAR10数据集进行分类训练。
CIFAR10数据集是图像分类的常用数据集。它一共包含10个类别的RGB彩色图片:飞机、汽车、鸟类、猫、鹿、狗、蛙类、马、船和卡车,其有50000张彩色图像用于训练,10000张图像用于测试,分辨率为32×32。
第一步中,首先用含噪声的CIFAR10数据集数据集中所有数据对网络进行训练12轮(epoch),使其获得一定的特征提取能力与分类能力,为后续步骤做好预热准备。
训练过程中采用的损失函数为传统交叉熵损失函数,因为使用的数据集原始标签,所以其可以参数化为:
其中n表示batchsize大小(本次实施n=128),即每次参与训练的样本数,C表示类别数(本次实施C=10),表示第i个样本的标签在第j类上的分量,pij表示第i个样本的预测标签pi在第j类上的分量。
其中c∈{1,2,3,…,10}表示对应类别编号,Nc表示si=c的样本总数,即伪标签类别为c的样本总数。表示对应类别样本中第i个样本的标准化特征。本次实施中Nc=5000,即每类样本有5000个,会产生10个特征中心,特征中心Oc形状的为128*1。
第三步,计算每个样本与其对应特征中心的余弦相似度SIMi:
其中,表示样本的伪标签si对应的特征中心。本次实施实际运算将样本分为了10组,每组的5000个标准化特征合并为一个5000*128的张量,随后与其组对应的128*1的类特征进行矩阵乘法运算得到一个5000*1的张量,对应5000个样本与其对应特征中心的余弦相似度SIM。随后,应用高斯混合模型(GMM)对每组样本的相似度以进行二值分类得到每个样本的置信分数Scorei:
最后根据每个样本的置信分数Scorei将将样本分为干净样本与噪声样本。具体表示为:如果Scorei≥Ti,则将样本xi视为干净样本,即标签正确的样本,如果Scorei<Ti,样本xi视为噪声样本,即标签错误的样本。至此将数据集划分为了两个子集:干净子集与噪声子集。本次实施Ti=0.3。
第四步中,计算每个样本与各类特征中心的余弦相似度COSij:
其中j∈{1,2,3,…,C}表示对应类别编号,C表示类别总数,COSij表示第i个样本与第j类的特征中心Oj之间的余弦相似度。本次实施中是将所有50000个标准化特征进行拼接得到50000*128的张量,再将所有特征中心拼接为一个128*10的张量,矩阵运算的到50000*10的张量结果,每个1*10都表示一个与各类特征中心的余弦相似度结果。随后根据每个样本与各类特征中心的余弦相似度确定样本的特征标签ui,规则定义为:
ui=argmax(COSi1,COSi2,COSi3,…,COSiC)
其中argmax表示返回其操作集合中最大值的序号索引。
第五步中,每次从数据集中选取128条数据,将其特征f输入分类器得到预测结果p,随后根据预测结果与特征标签来更新伪标签si:
si=αsi+βpi+γui
本次实施α=0.8,β=0.1,γ=0.1,随后只利用干净样本伪标签与预测结果的交叉熵来更新整个网络的参数,其参数表示为:
其中nc表示该批次样本中干净样本的数量,之前步骤中判定为干净的样本正常参与网络的更新,而判定为噪声的样本,不参与网络的更新。重复该步骤直至数据集中所有数据参与训练。随后重复第二到五步直至跑完设定好的轮数。
第六步,训练完成,输出并保存训练好的网络模型参数,其保存格式为.pth文件。
本方法模型训练环境基于Python3.6和Pytorch1.8.0搭建,训练集大小为50000,测试集大小为10000,验证实验采用了三种有代表性的噪声类型:(1)对称噪声(Symmetric);(2)成对翻转噪声(Pairflip);(3)三对角噪声(Tridiagonal)以及两种常用噪声比率:20%与40%。网络模型采用一个随机参数初始化的九层卷积神经网络作为特征提取器,Adam(momentum=0.9)作为优化器,初始学习率(Learning Rate)为0.001,每个批次的大小设置为128,一共训练200轮(epoch),预热轮数Ew=12,伪标签更新的超参α,β,γ分别设置为0.8,0.1,0.1。模型训练所使用的硬件环境为Ubuntu18.04操作系统、Intel6140CPU、Nvidia RTX3090 GPU。
本方法采用测试准确率作为评价指标,即:
为评估所提出方法的有效性,本方法选取部分杰出算法并在相同的实验条件下与本方法中的模型性能进行对比,为测试模型实验结果的稳定性,每次实验在相同条件下进行三次,并计算标准差,
如表1所示,本文的方法在各项指标下均取得了最好的结果。
表1不同算法对比实验结果
本发明所涉及的样本筛选方法与伪标签矫正方法均为基于现有方法的改进。
需要说明的是,本发明并不局限于上述具体实施方式中。在不脱离本发明原理的情况下,凡是本领域技术人员在本发明的启示下获得的其它实施方式,均视为在本发明的保护之内。
Claims (4)
1.一种基于样本筛选与标签矫正的噪声数据集训练方法,其特征在于:包括以下步骤:
第一步:初始化网络模型,模型分为特征提取器与分类器两部分,并初始化伪标签为数据集原始噪声标签,用数据集中所有数据对网络进行训练一定轮数,使其获得一定的特征提取能力与分类能力,为后续步骤做好预热准备;
第二步:利用网络模型提取所有样本的特征,并根据类别对特征进行归一化处理,随后计算每个类的特征中心;
第三步∶计算每个样本与其对应特征中心的余弦相似度,将高斯混合模型(GaussianMixture Model,GMM)应用于每类样本的相似度,并进行二值分类,根据分类结果将训练集分为干净子集与噪声子集;
第四步:计算每个样本与各类特征中心的余弦相似度,并取与其相似度最大的类特征中心对应的类别编号作为特征标签;
第五步:选取一个mini-batch的数据,将其特征输入分类器得到预测结果,根据预测结果与特征标签来更新伪标签,利用伪标签与预测结果的交叉熵来更新整个网络的参数,重复该步骤直至数据集所有数据参与训练,
第六步∶输出并保存训练好的网络模型。
2.根据权利要求1所述的一种基于样本筛选与标签矫正的噪声数据集训练方法,其特征在于:第三步为计算每个样本与其对应特征中心的余弦相似度,随后应用高斯混合模型进行二值分类,根据分类结果将训练集分为干净子集与噪声子集,
其中每个样本与其对应特征中心的余弦相似度SIMi计算公式为:
随后将高斯混合模型(GMM)应用于每类样本的相似度以进行二值分类得到每个样本的置信分数Scorei:
紧接着,我们根据每个样本的置信分数Scorei将样本分为干净样本与噪声样本,具体表示为:如果Scorei≥Ti,则将样本xi视为干净样本,即标签正确的样本,如果Scorei<Ti,样本xi视为噪声样本,即标签错误的样本。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211392565.XA CN115641480A (zh) | 2022-11-08 | 2022-11-08 | 一种基于样本筛选与标签矫正的噪声数据集训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211392565.XA CN115641480A (zh) | 2022-11-08 | 2022-11-08 | 一种基于样本筛选与标签矫正的噪声数据集训练方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115641480A true CN115641480A (zh) | 2023-01-24 |
Family
ID=84949685
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211392565.XA Pending CN115641480A (zh) | 2022-11-08 | 2022-11-08 | 一种基于样本筛选与标签矫正的噪声数据集训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115641480A (zh) |
-
2022
- 2022-11-08 CN CN202211392565.XA patent/CN115641480A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113378632B (zh) | 一种基于伪标签优化的无监督域适应行人重识别方法 | |
CN107122809B (zh) | 基于图像自编码的神经网络特征学习方法 | |
US20180341862A1 (en) | Integrating a memory layer in a neural network for one-shot learning | |
CN109063719B (zh) | 一种联合结构相似性和类信息的图像分类方法 | |
CN110866530A (zh) | 一种字符图像识别方法、装置及电子设备 | |
CN113326731A (zh) | 一种基于动量网络指导的跨域行人重识别算法 | |
CN110619059B (zh) | 一种基于迁移学习的建筑物标定方法 | |
CN112733866A (zh) | 一种提高可控图像文本描述正确性的网络构建方法 | |
CN110347857B (zh) | 基于强化学习的遥感影像的语义标注方法 | |
CN107491729B (zh) | 基于余弦相似度激活的卷积神经网络的手写数字识别方法 | |
CN112232395B (zh) | 一种基于联合训练生成对抗网络的半监督图像分类方法 | |
CN110598022B (zh) | 一种基于鲁棒深度哈希网络的图像检索系统与方法 | |
CN108052959A (zh) | 一种提高深度学习图片识别算法鲁棒性的方法 | |
CN113076927A (zh) | 基于多源域迁移的指静脉识别方法及系统 | |
Chen et al. | Military image scene recognition based on CNN and semantic information | |
CN113065409A (zh) | 一种基于摄像分头布差异对齐约束的无监督行人重识别方法 | |
CN113723572B (zh) | 船只目标识别方法、计算机系统及程序产品、存储介质 | |
CN111310820A (zh) | 基于交叉验证深度cnn特征集成的地基气象云图分类方法 | |
CN113095229B (zh) | 一种无监督域自适应行人重识别系统及方法 | |
CN114267060A (zh) | 基于不确定抑制网络模型的人脸年龄识别方法及系统 | |
CN111461229B (zh) | 一种基于目标传递和线搜索的深层神经网络优化及图像分类方法 | |
CN111144469B (zh) | 基于多维关联时序分类神经网络的端到端多序列文本识别方法 | |
CN111783688A (zh) | 一种基于卷积神经网络的遥感图像场景分类方法 | |
CN115641480A (zh) | 一种基于样本筛选与标签矫正的噪声数据集训练方法 | |
CN113963235A (zh) | 一种跨类别图像识别模型重用方法和系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |