CN116757261A - 基于带有闭集噪声和开集噪声标签的鲁棒学习方法 - Google Patents

基于带有闭集噪声和开集噪声标签的鲁棒学习方法 Download PDF

Info

Publication number
CN116757261A
CN116757261A CN202311031130.7A CN202311031130A CN116757261A CN 116757261 A CN116757261 A CN 116757261A CN 202311031130 A CN202311031130 A CN 202311031130A CN 116757261 A CN116757261 A CN 116757261A
Authority
CN
China
Prior art keywords
sample
training
samples
clean
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202311031130.7A
Other languages
English (en)
Inventor
李绍园
万文海
陈松灿
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing University of Aeronautics and Astronautics
Original Assignee
Nanjing University of Aeronautics and Astronautics
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing University of Aeronautics and Astronautics filed Critical Nanjing University of Aeronautics and Astronautics
Priority to CN202311031130.7A priority Critical patent/CN116757261A/zh
Publication of CN116757261A publication Critical patent/CN116757261A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Multimedia (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Molecular Biology (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于带有闭集噪声和开集噪声标签的鲁棒学习方法,该方法旨在利用有用的开集示例,同时最大限度地减少闭集错误标记示例的负面影响。本发明分为两个阶段,第一阶段中,利用干净样本选择策略做训练初始化,并记录下来样本修正标记以及标记修正记录供第二阶段优化;在第二阶段中,利用Class Expansion的思想,将部分开集样本融入已知类进行训练,将剩余的具有判别性的开集样本进一步帮助模型提升其判别性。本发明方法针对数据集中存在闭集噪声和开集噪声的问题,使用了类扩展的思想,接纳了一部分开集类别样本,并且充分利用了剩余的开集样本,进一步提升了深度学习模型的准确率。

Description

基于带有闭集噪声和开集噪声标签的鲁棒学习方法
技术领域
本发明涉及一种基于带有闭集噪声和开集噪声标签的鲁棒学习方法。
背景技术
深度神经网络(DNN)在各种任务中取得了显著的成功,例如图像分类、物体检测、语音识别和机器翻译。但需要注意的是,这样的成功主要归因于大量高质量注释的数据,而在实践中收集这些数据是昂贵甚至不可行的。事实上,现有的大部分基准数据集都是从搜索引擎或网络爬虫中收集的,这不可避免地涉及到噪声标记。
鉴于DNN的强大学习能力,模型最终将过度拟合噪声标记,导致泛化性能差。为了缓解这个问题,开发能够学习噪声标记的强大模型具有重要意义,而在存在闭集噪声的同时也存在开集噪声,因此,在这个问题中对开集噪声的处理至关重要。
曾有研究表明模型在训练过程中面对开集样本的行为,并观察到一些开集类别与多个闭集类别集成在一起,称之为Class Expansion。具体来说,对带标记的闭集样本进行训练,并对开集示例生成伪标记以促进学习,这种方法不会损害模型学习,甚至一些特定的开放集类别的示例得到了很好的分类。此外,其他开集示例均匀分布在多个封闭集类别中。另外,在训练过程中添加适当的开集示例甚至可以提高模型的性能。
发明内容
本发明的目的在于提出一种基于带有闭集噪声和开集噪声标签的鲁棒学习方法,该方法针对数据集中存在闭集噪声和开集噪声的问题,旨在最大限度地减少闭集错误标记示例的负面影响的同时,最大化有用的开集示例对模型学习过程中带来的收益。
本发明为了实现上述目的,采用如下技术方案:
基于带有闭集噪声和开集噪声标签的鲁棒学习方法,包括以下步骤:
步骤1. 获取带有闭集噪声和开集噪声的数据集;其中,D表示由图像x i 以及对应的噪声标记y i 组成的数据集,ND中的样本总数,i=1,…,N
步骤2. 开始预训练,并初始化当前预训练次数t 0和预训练总次数T 0;确定干净样本挑选方法,并初始化干净样本筛选率参数ρλ
步骤3. 搭建深度学习模型M 0以及损失函数L ce
步骤4. 将数据集D中图像x i 和对应的标记y i 输入到模型M 0进行预训练,训练T 0轮,在预训练阶段样本输入模型后获得对应的输出,并结合标记计算出样本的交叉熵损失;
步骤5. 判断当前预训练次数t 0是否达到预训练总次数T 0
若当前预训练次数t 0未达到预训练总次数T 0,返回步骤4继续训练;否则进行如下处理:
对得到的数据集D上所有样本的交叉熵损失从小到大进行升序排列,然后根据干净样本挑选方法选择干净样本参与训练,并为非干净样本打上伪标记;
将干净样本作为有监督数据集D clean ,将非干净样本作为无监督数据集D dirty
步骤6. 将干净样本标记与非干净样本的伪标记结合,获得数据集D上的更为精确的标记以及修正记录T,用于下述步骤8中的优化训练阶段;
步骤7. 重新搭建深度学习模型M 1以及分类损失L cls 、对比损失L cont ;其中,分类损失L cls 用于帮助模型分类,对比损失L cont 用于帮助模型获得更优的表征学习能力;
步骤8. 开始优化训练,并初始化当前训练次数t 1和训练总次数T 1,利用干净样本生成类别原型;并根据原型做开集决策;并维持一个动量队列;
步骤9. 判断当前训练次数t 1是否达到训练总次数T 1;若当前训练次数t 1未达到训练总次数T 1,则返回步骤8继续训练;否则转到步骤10;
步骤10. 模型训练完成后,得到能够在数据集上执行分类预测任务的深度学习模型M 1,利用该训练好的深度学习模型M 1对输入图像进行类别预测。
本发明具有如下优点:
如上所述,本发明述及了一种基于带有闭集噪声和开集噪声标签的鲁棒学习方法,该方法提出了一个两步学习框架来解决开集噪声标记学习问题,旨在利用有用的开集示例,同时最大限度地减少闭集错误标记示例的负面影响。在第一步中,本发明采用成熟的方法来处理噪声标记,并保持闭合集类别的基本概念,为了进一步提高模型的预测准确性,本发明采用改良过的对比学习方案,在第二步训练过程中包括选择的开集示例,此外,本发明使用其余被忽略的开集示例作为分界点,以增强模型的表示学习能力。
附图说明
图1是本发明实施例中基于带有闭集噪声和开集噪声标签的鲁棒学习方法的流程图。
图2是本发明实施例中整体模型的结构示意图。
图3是本发明实施例中筛选干净样本的流程示意图。
具体实施方式
下面结合附图以及具体实施方式对本发明作进一步详细说明:
本实施例述及了一种基于带有闭集噪声和开集噪声标签的鲁棒学习方法,该方法旨在利用有用的开集示例,同时最大限度地减少闭集错误标记示例的负面影响。本发明分为两个阶段,第一阶段中,利用干净样本选择策略做训练初始化,并记录下来样本修正标记以及标记修正记录供第二阶段优化,在第二阶段中,利用Class Expansion的思想,将部分开集样本融入已知类进行训练,将剩余的具有判别性的开集样本进一步帮助模型提升其判别性。
如图1所示,基于带有闭集噪声和开集噪声标签的鲁棒学习方法,包括以下步骤:
步骤1. 获取带有闭集噪声和开集噪声的数据集;其中N为数据集D中的样本总数;D代表由原始图像x i 以及对应的噪声标记y i 组成的数据集,i=1,…,N
步骤2. 开始预训练,初始化各项参数,包括当前预训练次数t 0和预训练总次数T 0,确定干净样本挑选方法,并初始化干净样本筛选率参数ρλ
其中,参数ρ控制每类挑选样本数量均衡,λ用于挑选样本的置信度阈值。本实施例中确定的干净样本挑选方法包括CSS以及MHCS样本选择方法。
下面对CSS以及MHCS两种样本选择方法作如下说明:
Class-wise Small-loss Selection (CSS)的具体策略为:根据模型预测将整个训练数据分成C个集合S j ={(x i ,y i )|j=argmax jSj f j (x i )}。
其中,x i y i 表示第i个样本及其标记,f j (x i )表示对样本x i 关于第j类的预测概率,argmax jSj f j (x i )表示对样本x i 预测概率最高的类别。
对于第j个集合,计算每个示例的交叉熵损失值l i ,并选择k=min(ρn/C,|S j |)个具有最小l i 的示例作为干净示例,其中n为数据集D中的样本总数。
与原始的小损失选择方法相比,本发明将k与平均每个类别n/C的示例数量相关联,这样可以产生大致平衡的小损失集合。
Matched High-Confidence Selection (MHCS) 的具体策略为:
为每个示例计算置信度分数。也就是说,选择那些具有高置信度e i λ的示例,同时它们的预测结果也应该与给定的标记匹配。
在实践中,本发明设置一个高阈值,以使所选样本的干净概率更高。
步骤3. 搭建深度学习模型M 0以及损失函数L ce
步骤4. 将数据集D中图像x i 和对应的标记y i 输入到模型M 0进行预训练,训练T 0轮,在预训练阶段样本输入模型后获得对应的输出,并结合标记计算出样本的交叉熵损失。
深度学习模型M 0包括特征提取器F以及分类器G
对深度学习模型M 0进行预训练的过程如下:
步骤4.1. 将数据集D中的图像x i 输入特征提取器F中得到高维特征f i
其中,f i =F(x i )。
步骤4.2. 将高维特征f i 输入分类器G中,得到类别预测结果p i =G(f i ),再利用交叉熵函数L ce 计算类别预测结果p i 和与图像x i 对应的噪声标记y i 之间的交叉熵损失l i
其中,p i 是深度学习模型M 0对图像x i 的类别预测结果;是步骤6中获得的标记/>中对应的标记,p i c 是深度学习模型M 0对图像x i 在类别c上的预测概率值,/>是步骤6中获得的标记/>关于图像x i 的独热编码在类别c上的值,C表示数据集D中的类别总数,c∈[1,C]。
步骤5. 判断当前预训练次数t 0是否达到预训练总次数T 0
若当前预训练次数t 0未达到预训练总次数T 0,返回步骤4继续训练;否则进行如下处理:
如图3所示,对得到的数据集D上所有样本的交叉熵损失从小到大进行升序排列,然后根据干净样本挑选方法选择干净样本参与训练,并为非干净样本打上伪标记。
将干净样本作为有监督数据集D clean ,将非干净样本作为无监督数据集D dirty
步骤6. 将干净样本标记与非干净样本的伪标记结合,获得完整数据集D上的更为精确的标记(对于每一个样本,若被视为了干净,则它的标记不变;反之则将其标记更改为模型预测的标记),以及修正记录T,用于下述步骤8中的优化训练阶段。
修正记录T用于表示数据集中每个样本标记是否被修改,其表达式如下:
其中,表示数据集中样本x i 的修正记录;/>的取值为0或1;当=0时,表明样本x i 被视为了干净样本,其标记没被修改;当/>=1时,表明样本x i 标记被模型修改过。
步骤7. 重新搭建深度学习模型M 1以及分类损失L cls 、对比损失L cont ;其中,分类损失L cls 用于帮助模型分类,对比损失L cont 用于帮助模型获得更优的表征学习能力。
尽管在步骤6中获得了相对干净的标记,但没有对开集样本进行处理并挖掘有用的信息,因此,在步骤7中设计了监督对比损失项来最大化开集样本所带来的收益。
如图2所示,搭建深度学习模型M 1,对于每个输入深度学习模型M 1的样本标记对(x,),通过随机数据增强函数分别生成两个视图,即query视图a1(x)和key视图a2(x)。
其中,x表示输入模型的样本,表示步骤6中获取的样本对应的标记。
再将query视图a1(x)和key视图a2(x)分别送入backbone网络gbackbone网络g'中;其中,网络g后面分别接了一个多层感知机qn和一个分类头ch,网络g'后面接了一个多层感知机kn,从而产生一对L 2规范化的嵌入向量q=qn(g(a1(x)))和k=kn(g'(a2(x)))。
其中,g'是由g动量更新而来,knqn由动量更新而来。
qn·g称为query分支,将kn·g'称为key分支。
使用动量更新方法来更新key分支网络;维护一个动量队列queue,按时间顺序存储最近的键嵌入向量,并不断地在训练过程中以先入先出的方式更新该动量队列queue
结合当前训练batch中的嵌入向量以及动量队列queue中维护的嵌入向量,得到对比嵌入向量池:A=B q B k queue
其中,B q B k 分别表示对当前训练batchquery视图的嵌入向量和key视图的嵌入向量,A表示B q B k 以及维护的动量队列的并集。
query视图的嵌入向量简称query嵌入向量,将key视图的嵌入向量简称key嵌入向量。
对于样本x,每个样本x输入深度学习模型M 1所获得的对比损失由将其query嵌入向量与对比嵌入向量池A进行对比,以获得监督对比损失。
步骤8. 开始优化训练,初始化参数,包括当前训练次数t 1和训练总次数T 1,利用干净样本生成类别原型并根据原型做开集决策。
其中开集决策阈值为φ,并维持一个动量队列,动量队列长度ι
在优化训练阶段,根据干净样本生成干净的类别原型,然后利用原型去引导开集决策,对于决策为开集的样本,由于其显著地不同于任何已知类,它通常处于类间,本发明便可利用它的这种“中间性”来增强模型学习在已知类之间的判别性。
步骤8.1. 将有监督数据集D clean 中的图像x i 的弱增强版本a1(x i )和强增强版本a2(x i )分别输入query网络和key网络,分别得到q i =qn(g(a1(x i )))和k i =kn(g'(a2(x i )))。
获得q i k i 后,会再进行一次L 2规范化,获得最终的嵌入向量,即q i =L2(q i ),k i =L2(k i )。
其中,图像x i D clean 中的图像,a1(x i )是图像x i 通过缩放、旋转操作后得到的新图像,a2(x i )是图像x i 通过不同程度变化以及扰动得到的严重失真的新图像。
步骤8.2. 首先进行warm up训练。
对于D clean 中的图像即样本,根据步骤8.1为每一个样本提取到特征q i 后,根据类别为q i 分组,然后再对每一个样本进行标准化,并据此为每一个类生成一个原型Q c
其中,n c 表示第c类的样本个数,表示第c类的第j个样本的嵌入向量,j∈[1,n c ]。在训练过程中不断地用对应类别样本的特征以动量更新的方式更新原型Q c
Q c =Normalize(γQ c +(1-γ)q)。
其中,q表示输入样本的嵌入向量,其类别为cQ c 表示类别c的原型;Normalize(·)表示对向量的标准化操作,γ是一个动量移动参数。
定义g(a1(x i ))表示a1(x i )经过backbone网络g的输出特征,则将特征g(a1(x i ))输入分类头ch中,得到类别预测结果p i =ch(g(a1(x i )))。
利用交叉熵函数L ce 分别计算类别预测结果p i 与图像x i 对应的标记之间的交叉熵损失之和以优化query分支的g和多层感知机qn;/>是步骤6中获得的/>中对应的标记。
步骤8.3. 若样本x i 对应的修正记录=1,则证明标记被修正过,修正之后的标记记为/>,按步骤8.1中提取到的x i 对应的特征q i 与对应的原型/>做距离度量。
经过计算若距离小于φ,则该样本x i Class Expansion的形式引入已知类中;否则,将样本x i 视为具有判别性的开集样本,并进一步帮助模型学习类别之间的判别性。
若样本x i 对应的修正记录=0,则样本x i 为干净样本,正常参与模型优化训练。
其中,φ为开集决策阈值,表示类别标记/>对应的原型。
对于一个样本标记对(x i ,),该样本标记对经过query分支得到的嵌入向量为:q i =qn(g(a1(x i ))),那么正常参与训练的样本定义为:
F x =I((=0)或者 (/>=1且Distance(q i ,/>)<φ) )。
其中,F x 表示正常参与训练的样本;I(·)表示指示函数,是一个随机变量,当事件发生时指示函数取值为 1,当事件不发生时指示函数取值为 0。
表示样本x i 在步骤6中获得的修正标记。
将所有样本以及其对应的类别标记存入动量队列中,以构造对比学习正例对以及利用具有判别性的开集样本帮助模型深度学习模型M 1的学习。
下面将为D clean D dirty 两部分分别构造正例集合以进行监督对比学习。
D clean 部分的正例集合P clean (x)表示为:
P clean (x)={k|kA(x),y=,T x =0}。
其中,A(x)={A\(q)};A=B q B k queueq表示qn(g(a1(x))),A\(q)表示A去掉q之后的集合;T x 表示样本x的修正标记,T x =0说明样本x为干净样本。
表示图像x在步骤6中获得的标记/>中对应的标记。
D dirty 部分的正例集合P dirty (x)表示为:
P dirty (x)={k|kA(x),y=,T x =1,Distance(k,Q y )<φ}。
其中,T x =1说明样本x标记在预训练阶段被模型M 0修正过,Distance(k,Q y )<φ表示嵌入向量k与类别y的原型的距离是否小于开集决策阈值φ
正例集合P(x)表示为:P(x)=P clean (x) ∪P dirty (x)。
步骤8.4. 根据步骤8.1中得到的q i k i ,计算样本x i 的正负样本对在低维特征空间中的距离,利用监督对比损失L cont 优化模型,具体的形式为:
其中,k +表示正例集合P(x i )中的所有样本。
τ是温度参数,k'表示A(x i )中的所有样本,即k'∈A(x i )。
另外,分类损失L cls 的公式如下:
其中,是步骤8.3中用来区分开集样本是否融入已知类的决策记录;当/>=1时,I(/>=1)返回1,否则返回0;N是样本个数,C是类别总数,i∈[1,N],j∈[1,C]。
步骤8.5. 结合步骤8.4,构造出深度学习模型M 1优化的总体Loss
Loss=L cls +βL cont
其中,β是调节对比Loss权重的参数。
步骤8.6. 根据步骤8.5的Loss更新query分支以及分类头ch后,再以动量更新的方式去更新key分支的backbone网络g'和多层感知机kn
本发明根据得到的标记以及修正记录,利用基于原型的开集决策方式区分出作为类扩充的开集样本和可辨别的具有明显区分度的开集样本,将类扩充的开集样本视为正常已知类样本参与训练,将可辨别的开集样本参与对比学习以帮助模型获取更加具有判别性的表征。
步骤9. 判断当前训练次数t 1是否达到训练总次数T 1;若当前训练次数t 1未达到训练总次数T 1,则返回步骤8继续训练;否则转到步骤10。
步骤10. 模型训练完成后,得到能够在数据集上执行分类预测任务的深度学习模型M 1,利用该训练好的深度学习模型M 1对输入图像进行类别预测。
本发明方法针对数据集中存在闭集噪声和开集噪声的问题,使用了类扩展的思想,接纳了一部分开集类别样本,并且充分利用了剩余的开集样本,提升了模型的预测准确率。
当然,以上说明仅仅为本发明的较佳实施例,本发明并不限于列举上述实施例,应当说明的是,任何熟悉本领域的技术人员在本说明书的教导下,所做出的所有等同替代、明显变形形式,均落在本说明书的实质范围之内,理应受到本发明的保护。

Claims (7)

1.基于带有闭集噪声和开集噪声标签的鲁棒学习方法,其特征在于,包括以下步骤:
步骤1. 获取带有闭集噪声和开集噪声的数据集 ;其中,D表示由图像x i 以及对应的噪声标记y i 组成的数据集,ND中的样本总数,i=1,…,N
步骤2. 开始预训练,并初始化当前预训练次数t 0和预训练总次数T 0;确定干净样本挑选方法,并初始化干净样本筛选率参数ρλ
步骤3. 搭建深度学习模型M 0以及损失函数L ce
步骤4. 将数据集D中图像x i 和对应的标记y i 输入到模型M 0进行预训练,训练T 0轮,在预训练阶段样本输入模型后获得对应的输出,并结合标记计算出样本的交叉熵损失;
步骤5. 判断当前预训练次数t 0是否达到预训练总次数T 0
若当前预训练次数t 0未达到预训练总次数T 0,返回步骤4继续训练;否则进行如下处理:
对得到的数据集D上所有样本的交叉熵损失从小到大进行升序排列,然后根据干净样本挑选方法选择干净样本参与训练,并为非干净样本打上伪标记;
将干净样本作为有监督数据集D clean ,将非干净样本作为无监督数据集D dirty
步骤6. 将干净样本标记与非干净样本的伪标记结合,获得数据集D上的更为精确的标记以及修正记录T,用于下述步骤8中的优化训练阶段;
步骤7. 重新搭建深度学习模型M 1以及分类损失L cls 、对比损失L cont ;其中,分类损失L cls 用于帮助模型分类,对比损失L cont 用于帮助模型获得更优的表征学习能力;
步骤8. 开始优化训练,并初始化当前训练次数t 1和训练总次数T 1,利用干净样本生成类别原型;并根据原型做开集决策;并维持一个动量队列;
步骤9. 判断当前训练次数t 1是否达到训练总次数T 1;若当前训练次数t 1未达到训练总次数T 1,则返回步骤8继续训练;否则转到步骤10;
步骤10. 模型训练完成后,得到能够在数据集上执行分类预测任务的深度学习模型M 1,利用该训练好的深度学习模型M 1对输入图像进行类别预测。
2.根据权利要求1所述的基于带有闭集噪声和开集噪声标签的鲁棒学习方法,其特征在于,所述步骤2中,干净样本挑选方法包括CSS以及MHCS样本选择方法。
3.根据权利要求1所述的基于带有闭集噪声和开集噪声标签的鲁棒学习方法,其特征在于,所述步骤2中,ρ用于控制每类挑选样本数量均衡,λ为挑选样本的置信度阈值。
4.根据权利要求1所述的基于带有闭集噪声和开集噪声标签的鲁棒学习方法,其特征在于,所述步骤4中,深度学习模型M 0包括特征提取器F以及分类器G
对深度学习模型M 0进行预训练的过程如下:
步骤4.1. 将数据集D中的图像x i 输入特征提取器F中得到高维特征f i
其中,f i =F(x i );
步骤4.2. 将高维特征f i 输入分类器G中,得到类别预测结果p i =G(f i ),再利用交叉熵函数L ce 计算类别预测结果p i 和与图像x i 对应的噪声标记y i 之间的交叉熵损失l i
其中,p i 是深度学习模型M 0对图像x i 的类别预测结果;是步骤6中获得的标记/>中对应的标记,p i c 是深度学习模型M 0对图像x i 在类别c上的预测概率值,/>是步骤6中获得的标记关于图像x i 的独热编码在类别c上的值,C表示数据集D中的类别总数,c∈[1,C]。
5.根据权利要求1所述的基于带有闭集噪声和开集噪声标签的鲁棒学习方法,其特征在于,所述步骤6具体为:
利用干净样本训练模型,再对非干净样本打上已知类伪标记,将干净样本的标记以及非干净样本的伪标记结合,记录为数据集D的标记;同时记录样本的标记是否被更改,得到修正记录T,用于表示数据集中每个样本标记是否被修改,其表达式如下:
其中,表示数据集中样本x i 的修正记录;/>的取值为0或1;当/>=0时,表明样本x i 被视为了干净样本,其标记没被修改;当/>=1时,表明样本x i 标记被模型修改过。
6.根据权利要求1所述的基于带有闭集噪声和开集噪声标签的鲁棒学习方法,其特征在于,所述步骤7具体为:
搭建深度学习模型M 1,对于每个输入深度学习模型M 1的样本标记对(x, ),通过随机数据增强函数分别生成两个视图,即query视图a1(x)和key视图a2(x);
其中,x表示输入模型的样本,表示步骤6中获取的样本对应的标记;
再将query视图a1(x)和key视图a2(x)分别送入backbone网络gbackbone网络g'中;其中,网络g后面分别接了一个多层感知机qn和一个分类头ch,网络g'后面接了一个多层感知机kn,从而产生一对L 2规范化的嵌入向量q=qn(g(a1(x)))和k=kn(g'(a2(x)));
其中,g'是由g动量更新而来,knqn由动量更新而来;
qn·g称为query分支,将kn·g'称为key分支;
使用动量更新方法来更新key分支网络;维护一个动量队列queue,按时间顺序存储最近的键嵌入向量,并不断地在训练过程中以先入先出的方式更新该动量队列queue
结合深度学习模型M 1当前训练batch中的嵌入向量以及动量队列queue中维护的嵌入向量,得到对比嵌入向量池:A=B q B k queue
其中,B q B k 分别表示对当前训练batchquery视图的嵌入向量和key视图的嵌入向量,A表示B q B k 以及维护的动量队列的并集;
query视图的嵌入向量简称query嵌入向量,将key视图的嵌入向量简称key嵌入向量;
对于样本x,每个样本x输入深度学习模型M 1所获得的对比损失由将其query 嵌入向量与对比嵌入向量池A进行对比,以获得监督对比损失。
7.根据权利要求6所述的基于带有闭集噪声和开集噪声标签的鲁棒学习方法,其特征在于,所述步骤8具体为:
步骤8.1. 将有监督数据集D clean 中的图像x i 的弱增强版本a1(x i )和强增强版本a2 (x i )分别输入query网络和key网络,分别得到q i =qn(g(a1 (x i )))和k i =kn(g'(a2 (x i )));
获得q i k i 后,会再进行一次L 2规范化,获得最终的嵌入向量,即q i =L2(q i ),k i =L2(k i );
其中,图像x i D clean 中的图像,a1(x i )是图像x i 通过缩放、旋转操作后得到的新图像,a2(x i ) 是图像x i 通过不同程度变化以及扰动得到的严重失真的新图像;
步骤8.2. 首先进行warm up训练;
对于D clean 中的图像即样本,根据步骤8.1为每一个样本提取到特征q i 后,根据类别为q i 分组,然后再对每一个样本进行标准化,并据此为每一个类生成一个原型Q c
其中,n c 表示第c类的样本个数,q cj 表示第c类的第j个样本的嵌入向量,j∈[1, n c ];
在训练过程中不断地用对应类别样本的特征以动量更新的方式更新原型Q c
Q c =Normalize(γQ c +(1-γ)q);
其中,q表示输入样本的嵌入向量,其类别为cQ c 表示类别c的原型;Normalize(·)表示对向量的标准化操作,γ是一个动量移动参数;
定义g(a1(x i ))表示a1(x i )经过backbone网络g的输出特征,则将特征g(a1(x i ))输入分类头ch中,得到类别预测结果p i =ch(g(a1 (x i )));
利用交叉熵函数L ce 分别计算类别预测结果p i 与图像x i 对应的标记之间的交叉熵损失之和以优化query分支的g和多层感知机qn,/>是步骤6中获得的/>中对应的标记;
步骤8.3. 若样本x i 对应的修正记录=1,则证明标记被修正过,修正之后的标记记为,按步骤8.1中提取到的x i 对应的特征q i 与对应的原型/>做距离度量;
经过计算若距离小于φ,则该样本x i Class Expansion的形式引入已知类中;否则,将样本x i 视为具有判别性的开集样本,并进一步帮助模型学习类别之间的判别性;
若样本x i 对应的修正记录=0,则样本x i 为干净样本,正常参与模型优化训练;
其中,φ为开集决策阈值,表示类别标记/>对应的原型;
对于一个样本标记对(x i , ),该样本标记对经过query分支得到的嵌入向量为:q i =qn(g(a1(x i ))),那么正常参与训练的样本定义为:
F x =I((=0)或者 (/>=1且Distance(q i , />)<φ) );
其中,F x 表示正常参与训练的样本;I(·)表示指示函数,I(·)是一个随机变量,当事件发生时指示函数取值为 1,当事件不发生时指示函数取值为 0;
表示样本x i 在步骤6中获得的修正标记;
将所有样本以及其对应的类别标记存入动量队列中,以构造对比学习正例对以及利用具有判别性的开集样本帮助模型深度学习模型M 1的学习;
下面将为D clean D dirty 两部分分别构造正例集合以进行监督对比学习;
D clean 部分的正例集合P clean (x)表示为:
P clean (x)={k| kA(x),y=,T x =0};
其中,A(x)={A(q)};A=B q B k queueq表示qn(g(a1 (x))),A(q)表示A去掉q之后的集合;T x 表示样本x的修正标记,T x =0说明样本x为干净样本;
表示图像x在步骤6中获得的标记/>中对应的标记;
D dirty 部分的正例集合P dirty (x)表示为:
P dirty (x)={k| kA(x),y=,T x =1,Distance(k,Q y )< φ};
其中,T x =1说明样本x标记在预训练阶段被模型M 0修正过,Distance(k,Q y )< φ表示嵌入向量k与类别y的原型的距离是否小于开集决策阈值φ
正例集合P (x)表示为:P (x)=P clean (x) ∪P dirty (x);
步骤8.4. 根据步骤8.1中得到的q i k i ,计算样本x i 的正负样本对在低维特征空间中的距离,利用监督对比损失L cont 优化模型,具体的形式为:
其中,k +表示正例集合P (x i )中的所有样本;
τ是温度参数,k'表示A(x i )中的所有样本,即k'∈A(x i );
另外,分类损失L cls 的公式如下:
其中,是步骤8.3中用来区分开集样本是否融入已知类的决策记录;当/>=1时,I(=1)返回1,否则返回0;N是样本个数,C是类别总数,i∈[1,N],j∈[1,C];
步骤8.5. 结合步骤8.4,构造出深度学习模型M 1优化的总体Loss
Loss=L cls +βL cont
其中,β是调节对比Loss权重的参数;
步骤8.6. 根据步骤8.5的Loss更新query分支以及分类头ch后,再以动量更新的方式去更新key分支的backbone网络g'和多层感知机kn
CN202311031130.7A 2023-08-16 2023-08-16 基于带有闭集噪声和开集噪声标签的鲁棒学习方法 Pending CN116757261A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311031130.7A CN116757261A (zh) 2023-08-16 2023-08-16 基于带有闭集噪声和开集噪声标签的鲁棒学习方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311031130.7A CN116757261A (zh) 2023-08-16 2023-08-16 基于带有闭集噪声和开集噪声标签的鲁棒学习方法

Publications (1)

Publication Number Publication Date
CN116757261A true CN116757261A (zh) 2023-09-15

Family

ID=87959413

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311031130.7A Pending CN116757261A (zh) 2023-08-16 2023-08-16 基于带有闭集噪声和开集噪声标签的鲁棒学习方法

Country Status (1)

Country Link
CN (1) CN116757261A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118097319A (zh) * 2024-04-29 2024-05-28 南京航空航天大学 在线流数据中带有未见类和噪声标签的图像分类方法

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114358117A (zh) * 2021-11-24 2022-04-15 珠海亿智电子科技有限公司 基于网络数据的模型训练方法、装置、电子设备及介质
CN115331088A (zh) * 2022-10-13 2022-11-11 南京航空航天大学 基于带有噪声和不平衡的类标签的鲁棒学习方法
CN116089883A (zh) * 2023-01-30 2023-05-09 北京邮电大学 用于提高已有类别增量学习新旧类别区分度的训练方法

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114358117A (zh) * 2021-11-24 2022-04-15 珠海亿智电子科技有限公司 基于网络数据的模型训练方法、装置、电子设备及介质
CN115331088A (zh) * 2022-10-13 2022-11-11 南京航空航天大学 基于带有噪声和不平衡的类标签的鲁棒学习方法
CN116089883A (zh) * 2023-01-30 2023-05-09 北京邮电大学 用于提高已有类别增量学习新旧类别区分度的训练方法

Non-Patent Citations (5)

* Cited by examiner, † Cited by third party
Title
RAGAV SACHDEVA等: "EvidentialMix: Learning with Combined Open-set and Closed-set Noisy Labels", 《ARXIV》, pages 1 - 9 *
SHAO‑YUAN LI等: "Improving deep label noise learning with dual active label correction", 《MACHINE LEARNING》, no. 2022, pages 1103 *
WENHAI WAN等: "Unlocking the Power of Open Set : A New Perspective for Open-set Noisy Label Learning", 《ARXIV》, pages 1 - 10 *
诸建超: "基于重心学习的图像识别及应用平台", 《中国优秀硕士学位论文全文数据库 信息科技辑》, no. 12, pages 138 - 365 *
陈蕾 等: "矩阵补全模型及其算法研究综述", 《软件学报》, vol. 2, no. 6, pages 1547 - 1564 *

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118097319A (zh) * 2024-04-29 2024-05-28 南京航空航天大学 在线流数据中带有未见类和噪声标签的图像分类方法

Similar Documents

Publication Publication Date Title
WO2022037233A1 (zh) 一种基于自监督知识迁移的小样本视觉目标识别方法
CN110609891B (zh) 一种基于上下文感知图神经网络的视觉对话生成方法
Panda et al. Contemplating visual emotions: Understanding and overcoming dataset bias
CN112507901B (zh) 一种基于伪标签自纠正的无监督行人重识别方法
CN110135461B (zh) 基于分层注意感知深度度量学习的情感图像检索的方法
CN109871885A (zh) 一种基于深度学习和植物分类学的植物识别方法
CN110575663B (zh) 一种基于人工智能的体育辅助训练方法
US20210319215A1 (en) Method and system for person re-identification
CN113326731A (zh) 一种基于动量网络指导的跨域行人重识别算法
CN115641613A (zh) 一种基于聚类和多尺度学习的无监督跨域行人重识别方法
CN116757261A (zh) 基于带有闭集噪声和开集噪声标签的鲁棒学习方法
CN114880478B (zh) 基于主题信息增强的弱监督方面类别检测方法
CN116311483B (zh) 基于局部面部区域重构和记忆对比学习的微表情识别方法
Das et al. NAS-SGAN: a semi-supervised generative adversarial network model for atypia scoring of breast cancer histopathological images
CN115331065B (zh) 基于解码器迭代筛选的鲁棒噪声多标签图像学习方法
CN115270752A (zh) 一种基于多层次对比学习的模板句评估方法
CN115511012A (zh) 一种最大熵约束的类别软标签识别训练方法
CN115630649A (zh) 一种基于生成模型的医学中文命名实体识别方法
CN114882534A (zh) 基于反事实注意力学习的行人再识别方法、系统、介质
CN113764034B (zh) 基因组序列中潜在bgc的预测方法、装置、设备及介质
CN110867225A (zh) 字符级临床概念提取命名实体识别方法及系统
CN117313709B (zh) 一种基于统计信息和预训练语言模型的生成文本检测方法
CN110378384B (zh) 一种结合特权信息和排序支持向量机的图像分类方法
CN112883930A (zh) 基于全连接网络的实时真假运动判断方法
CN117274657A (zh) 基于课程知识蒸馏的耐噪声木薯叶病害分类方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination