CN116151366A - 一种基于在线蒸馏的噪声标签鲁棒性学习方法 - Google Patents

一种基于在线蒸馏的噪声标签鲁棒性学习方法 Download PDF

Info

Publication number
CN116151366A
CN116151366A CN202310158386.8A CN202310158386A CN116151366A CN 116151366 A CN116151366 A CN 116151366A CN 202310158386 A CN202310158386 A CN 202310158386A CN 116151366 A CN116151366 A CN 116151366A
Authority
CN
China
Prior art keywords
model
noise
loss
tag
lcn
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310158386.8A
Other languages
English (en)
Inventor
黄贻望
黄雨鑫
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tongren University
Original Assignee
Tongren University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tongren University filed Critical Tongren University
Priority to CN202310158386.8A priority Critical patent/CN116151366A/zh
Publication of CN116151366A publication Critical patent/CN116151366A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T90/00Enabling technologies or technologies with a potential or indirect contribution to GHG emissions mitigation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明涉及标签校正与知识蒸馏技术领域,尤其涉及了一种基于在线蒸馏的噪声标签鲁棒性学习方法,将含噪声数据集输入教师模型,主模型与LCN分别生成伪标签与元伪标签;计算主模型输出对数与元伪标签之间的交叉熵损失;将噪声数据输入学生模型,并使用教师模型生成的伪标签训练,计算损失;根据损失计算梯度并更新学生模型;将干净数据输入更新后的学生模型,计算其损失并反馈给主模型,并更新主模型;将干净数据输入主模型并计算其损失,并更新LCN。该基于在线蒸馏的噪声标签鲁棒性学习方法,在噪声数据下,将知识蒸馏技术与标签校正技术相结合,并采用LCN生成的元伪标签作为真实标签,增强了模型训练的可行性。

Description

一种基于在线蒸馏的噪声标签鲁棒性学习方法
技术领域
本发明涉及标签校正与知识蒸馏技术领域,具体为一种基于在线蒸馏的噪声标签鲁棒性学习方法。
背景技术
如今大规模干净数据集是深度神经网络取得成功的一个关键因素,但大规模干净数据集的成本相当昂贵且相当稀缺,而在实际应用中通过搜索引擎与自动标签软件等方式很轻易的获取大量的非专业化标注数据集,这些数据集往往会含有一定程度的噪声标签,而这些噪声标签通常会导致深度神经网络出现过拟合现象,从而降低模型的泛化能力。因此,能够高效的在噪声标签数据下进行深度神经网络的鲁棒性训练成为当下一个重要的研究热点。
标签噪声学习在数据层面上通常利用噪声转移矩阵作为辅助信息构建满足统计一致性的学习算法,或利用噪声转移矩阵对噪声数据进行清洗,即在获取噪声转移矩阵后,就可以对数据进行校正。校正的方法又可以分为前向校正与后向校正。但噪声转移矩阵的获取依赖训练数据中的“锚点”(特定类别的数据点),而噪声转移矩阵估计的准确度又直接影响目标分类器的性能,故此类技术潜在的缺点在于有些训练数据并不存在“锚点”。而从目标层面上,通过正则化技术可以有效防止神经网络过拟合,限制模型对噪声的拟合。或者使用重加权技术减轻误标注数据对优化目标的贡献,理想情况下,误标注数据权重趋近于0;而标注正确数据的重趋近于1。算法在实际情况下通常无法给出完全准确的判断,因此样本权重通常为一定范围内的实数(可能为负,可能大于1)。另外,标签校正技术通过修正噪声数据中的错误标签,并使用修正后的标签训练模型,能够有效抑制标签噪声对模型的过拟合现象。但这种方法对于不同的数据集需要设置不同的超参数,自适应能力不足,难以普遍应用于现实生活中。目前,可通过元学习技术来改善这一缺陷,利用元学习的双层优化框架对标签校正方法中的分类器网络与标签校正网络进行联合优化,使其能够相互促进学习,从而提升模型的自适应能力。在模型层面上,越复杂的模型越能适应含有标签噪声的数据集,其迭代训练中模型测试精度会呈现一个“双层下降”的过程。但预训练大模型需要运用于不同的下游任务,现实情况下将大模型对一个下游任务进行重新训练需要耗费较大的资源,且训练后的大模型难以部署到边缘设备上。
为了噪声数据下使轻量化模型也能够达到与复杂模型相近甚至超越复杂模型的性能,提出了一种基于在线蒸馏的噪声标签鲁棒性学习方法,该方法利用知识蒸馏框架构建师生模型解决小模型在噪声标签中泛化性能不佳的问题,其中教师模型采用标签校正模型,同时使用元学习来连接轻量化模型与教师模型,在训练高性能轻量化模型的同时还能提升教师模型生成伪标签的质量与其分类性能。
发明内容
本发明的目的在于提供一种基于在线蒸馏的噪声标签鲁棒性学习方法,以解决上述背景技术中提出的问题。
为实现上述目的,本发明提供如下技术方案:一种基于在线蒸馏的噪声标签鲁棒性学习方法,包括以下步骤:
S1:将含噪声数据集输入教师模型(MLC),主模型与LCN分别生成伪标签与元伪标签;
S2:计算主模型输出对数与元伪标签之间的交叉熵损失;
S3:将噪声数据输入学生模型(轻量化神经网络),并使用教师模型生成的伪标签训练,计算损失;
S4:根据S2中的损失计算梯度并更新学生模型;
S5:将干净数据输入更新后的学生模型,计算其损失并反馈给主模型,并更新主模型;
S6:将干净数据输入主模型并计算其损失,并更新LCN。
优选的,所述步骤S1中,需要生成相应的伪标签,其具体步骤为:
S1-1:将噪声数据输入主模型网络,输出相应数据的对数,对数通过softmax层得到伪标签;
S1-2:提取噪声数据在主模型网络的特征输出并提供给LCN;
S1-3:LCN接收S1-2中的特征输出与相应噪声数据的标签,得到元伪标签。
优选的,所述S1中的MLC模型是由一个主模型(深度神经网络)与标签校正网络(LCN,即多层感知机)构成,主模型是一个参数为的Resnet网络,LCN是一个参数为的MLP,主模型与LCN之间通过元学习框架连接。
优选的,所述在线蒸馏的噪声标签鲁棒性学习方法主体框架由一个MLC模块与知识蒸馏模块构成,在知识蒸馏模块中,学生模型分别通过主模型与LCN生成的伪标签进行训练。
优选的,所述步骤S3中,通过知识蒸馏技术训练学生模型(MLC模型为教师模型),具体步骤如下:
S3-1:计算教师模型中主模型网络输出伪标签与学生模型输出之间的KL散度;
S3-2:LCN生成的元伪标签质量更高,将其视为真实标签,计算其与学生模型之间的交叉熵损失;
S3-3:计算S3-1与S3-2损失的梯度,并更新学生模型参数。
优选的,所述步骤S5中,具体步骤如下:
S5-1:计算更新后的学生模型在干净数据集上的损失;
S5-2:采用策略梯度计算S4-1中的损失关于主模型参数的梯度;
S5-3:计算S2中交叉熵损失关于主模型参数的梯度;
S5-4:根据S4-2与S4-3更新主模型。
优选的,所述步骤S6中,具体步骤如下:
S6-1:计算更新后的主模型在干净数据集上的损失;
S6-2:计算S6-1中的损失关于LCN参数的梯度;
S6-3:根据S6-2更新LCN。
与现有技术相比,本发明的有益效果是:
1.该基于在线蒸馏的噪声标签鲁棒性学习方法,在噪声数据下,将知识蒸馏技术与标签校正技术相结合,并采用LCN生成的元伪标签作为真实标签(知识蒸馏技术极其依赖真实标签),增强了模型训练的可行性。
2.该基于在线蒸馏的噪声标签鲁棒性学习方法,通过元学习技术将学生模型在干净数据上的损失反馈给教师模型,达到约束教师模型的效果,使教师模型生成更高质量的伪标签以供学生模型训练。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本发明的MLC模型架构图;
图2为本发明的知识蒸馏模型架构图;
图3为本发明的在两种噪声类型的不同噪声水平下的模型精度对比图;
图4为本发明的在UNIF噪声类型的不同噪声水平下的模型测试精度对比图;
图5为本发明的基于在线蒸馏的噪声标签鲁棒性学习方法模型框架图;
图6为本发明基于在线蒸馏的噪声标签鲁棒性学习方法执行流程框图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
请参阅图1-图6,本发明提供一种技术方案:一种基于在线蒸馏的噪声标签鲁棒性学习方法,如图5所示,该方法包括以下步骤:
S1:将含噪声数据集输入教师模型(MLC),主模型与LCN分别生成伪标签与元伪标签。
S2:计算主模型输出对数与元伪标签之间的交叉熵损失。
S3:将噪声数据输入学生模型(轻量化神经网络),并使用教师模型生成的伪标签训练,计算损失。
S4:根据S2中的损失计算梯度并更新学生模型。
S5:将干净数据输入更新后的学生模型,计算其损失并反馈给主模型,并更新主模型。
S6:将干净数据输入主模型并计算其损失,并更新LCN。
在步骤S1中,需要生成相应的伪标签,其具体步骤为:
S1-1:将噪声数据输入主模型网络,输出相应数据的对数,对数通过softmax层得到伪标签。
S1-2:提取噪声数据在主模型网络的特征输出并提供给LCN。
S1-3:LCN接收S1-2中的特征输出与相应噪声数据的标签,得到元伪标签。
在步骤S3中,通过知识蒸馏技术训练学生模型(MLC模型为教师模型),具体步骤如下:
S3-1:计算教师模型中主模型网络输出伪标签与学生模型输出之间的KL散度。
S3-2:LCN生成的元伪标签质量更高,将其视为真实标签,计算其与学生模型之间的交叉熵损失。
S3-3:计算S3-1与S3-2损失的梯度,并更新学生模型参数。
在步骤S5中,具体步骤如下:
S4-1:计算更新后的学生模型在干净数据集上的损失。
S4-2:采用策略梯度计算S4-1中的损失关于主模型参数的梯度。
S4-3:计算S2中交叉熵损失关于主模型参数的梯度。
S4-4:根据S4-2与S4-3更新主模型。
在步骤S6中,具体步骤如下:
S6-1:计算更新后的主模型在干净数据集上的损失。
S6-2:计算S6-1中的损失关于LCN参数的梯度。
S6-3:根据S6-2更新LCN。
元标签校正模型MLC
对于公共干净数据集,可以将其拆分为一个大的噪声数据集Du={xl,yl}n与一个小的干净数据集D={xu,yu}m,m n,可以采用UNIF与FLIP两种方法将数据集转化为对称标签噪声数据与非对称噪声数据。对于真实噪声数据集无需此操作。
MLC模型是由一个主模型(深度神经网络)与标签校正网络(LCN,即多层感知机)构成,如图1所示:
主模型是一个参数为θt的Resnet网络,LCN是一个参数为ω的MLP,主模型与LCN之间通过元学习框架连接。主模型参数更新公式如下:
Figure BDA0004093339820000061
其中ηt为主模型学习率,且
Figure BDA0004093339820000062
Figure BDA0004093339820000064
分别代表主模型与LCN,/>
Figure BDA0004093339820000063
表示噪声数据在主模型上的特征输出。若LCN生成了高质量的伪标签,则更新后的主模型在干净数据集上应实现较低的损失,MLC的总体形式可以表现为下列函数:
Figure BDA0004093339820000071
MLC算法具体流程如下:
1.将噪声数据提供给LCN生成相应的伪标签;
2.将噪声数据提供给主模型并计算预测的对数;
3.计算对数与伪标签之间的损失,并计算损失的梯度,即主模型的参数;
4.更新主模型参数;
5.将干净数据提供给更新后的主模型,并计算器损失;
6.计算(5)中损失的梯度,并更新LCN。
知识蒸馏是一种模型压缩技术,同时带有正则化效果。传统的知识蒸馏通过预训练的复杂教师模型所提炼出的监督信息来训练轻量化学生模型,使学生模型拥有更好的性能。其主要思想是通过引入温度系数t来改造softmax函数:
Figure BDA0004093339820000072
原始softmax函数的t=1,当温度系数t越高时,softmax函数输出每个值的概率分布更均匀,即增加了对负标签的关注程度,而负标签中都包含有一定的信息,尤其是一些值显著高于平均值的负标签对提升模型性能具有巨大价值。知识蒸馏具体流程如图2所示,知识蒸馏则是通过真实标签与教师模型生成的软标签对学生模型进行训练:
1.将数据输入预训练的复杂教师模型与轻量化的学生模型,并输出相应的软标签;
2.计算教师模型与学生模型所输出的软标签之间的损失(KL散度,Lsoft);
3.计算学生模型输出软标签与真实标签之间的交叉熵损失(Lhard);
4.根据(2)与(3)的损失之和计算其梯度,并更新学生模型参数。
基于在线蒸馏的噪声标签鲁棒性学习方法的具体计算过程与实施细节:
本发明模型的主体框架如图3所示,由一个MLC模块与知识蒸馏模块构成。在知识蒸馏模块中,学生模型分别通过主模型与LCN生成的伪标签进行训练。知识蒸馏对真实标签有较高的依赖,所以我们将LCN生成的元伪标签
Figure BDA00040933398200000811
视为真实标签,形成一个完整的基于响应知识的蒸馏过程:
将噪声数据Du={xu,yu}n输入主模型,从而得到相关噪声数据的特征输出
Figure BDA0004093339820000081
与对数输出/>
Figure BDA0004093339820000082
(其中θt为主模型参数),以及伪标签yt
将特征输出
Figure BDA0004093339820000083
与错误标签yu输入LCN中得到元伪标签/>
Figure BDA0004093339820000084
其中gω为LCN(其中ω为LCN参数);
将噪声数据输入学生模型得到其对数输出
Figure BDA0004093339820000085
(其中θs为学生模型参数)与软标签ys,并将元伪标签/>
Figure BDA0004093339820000086
作为真实标签,计算/>
Figure BDA0004093339820000087
与/>
Figure BDA0004093339820000088
之间的交叉熵损失:
Figure BDA0004093339820000089
计算ys与yt之间的KL散度:
Figure BDA00040933398200000810
根据知识蒸馏计算其蒸馏损失(α,β为超参数):
LKDst,ω)=αLreal+βLpseudo (6)
根据蒸馏损失更新学生模型(其中ηs为学生模型学习率):
Figure BDA0004093339820000091
学生模型依赖于教师模型生成的伪标签进行训练更新,即
Figure BDA0004093339820000092
所以教师模型生成伪标签的质量决定了学生模型的性能,为了使教师模型生成更高质量的伪标签,可以通过借鉴MPL模型的学生模型反馈机制(元学习双层优化策略),将学生模型在小型干净数据集D={xl,yl}m上的训练损失反馈给主模型,从而提升主模型性能,使它能够生成更高质量的伪标签:
Figure BDA0004093339820000093
其中
Figure BDA0004093339820000094
更新主模型参数:
Figure BDA0004093339820000095
计算更新后的主模型在干净数据集D={xl,yl}n上的对数输出
Figure BDA0004093339820000096
计算
Figure BDA0004093339820000097
与真实标签yl之间的交叉熵损失/>
Figure BDA0004093339820000098
更新LCN(ηω为LCN学习率):
Figure BDA0004093339820000099
具体算法流程如下表所示(ω为LCN参数,γ为噪声水平,lr为学习率):
Figure BDA00040933398200000910
/>
Figure BDA0004093339820000101
在CIFAR10数据集下,我们采用Resnet34作为教师模型的主模型,在CIFAR100数据集下我们采用Resnet32作为主模型,学生模型采用Resnet18网络。教师模型与学生模型采用不同的学习率优化方法,教师模型使用MultiStepLR方法分别在80次与100次迭代中改变学习率大小,学生模型将warmup与余弦法相结合作为学习率优化策略。蒸馏温度在我们的实验上影响较大,所以我们在CIFAR10数据集上教师模型与学生模型之间温度系数设置为3,在CIFAR100数据集上主模型与学生模型之间温度系数设置为1。
图3显示了MLC与我们的方法在CIFAR10数据集下不同噪声水平的测试精度,其中NT与NS代表我们算法的教师模型与学生模型.我们采用的学习率lr=0.03,蒸馏温度T=3,在UNIF噪声类型中,当噪声水平大于50%时与基础算法MLC相比有明显的提升,证明我们的方法在高噪声水平下拥有更优的泛化性能。而在FLIP噪声类型下,由于一个类中的任何实例只能以固定概率翻转到另一个类中,在这种噪声下,噪声水平应低于50%,否则错误标记数据占大多数将给模型的泛化性能带来较大的影响。在噪声类别较小的情况下,我们的算法在所有噪声水平下对比MLC算法均有明显的提升,尤其是学生模型的测试精度均高于MLC模型。
图4是在UNIF噪声类型的不同噪声水平下的测试精度对比。左图与右图分别代表噪声水平为0.4与0.9时在120次迭代中的精度对比。我们的教师模型与MLC模型采用相同的学习率优化策略,在80次迭代时均有较高的精度提升,而学生模型采用的学习率优化策略是将warmup与余弦法相结合,且模型容量较小,故在噪声水平较低时所精度曲线并不平缓。在高噪声水平下,MLC模型在80次迭代后明显出现了过拟合现象,而我们的教师模型仍然呈现平稳上升的趋势并未发生过拟合现象,证明我们的方法在高噪声水平下拥有更好的泛化性能与正则化效果,同时学生模型相对于低噪声水平时测试精度也呈现较为平缓的上升趋势。且在不同噪声水平下我们的方法性能均高于MLC算法。
对于上述公式(2),学生模型的反馈损失关于教师模型参数θt的求导过程,即式(9)中
Figure BDA0004093339820000111
在一批噪声数据中主模型生成的伪标签为yt Du(xu,yu),标签校正网络生成的元伪标签为/>
Figure BDA0004093339820000112
为了简化公式,将学生模型与教师模型分别表示为S与T,且LCN用Y表示。通过更新主模型参数,以最小化学生模型在小型干净数据集D(xl,yl)上的期望。
Figure BDA0004093339820000113
为了简化符号,作如下定义:
Figure BDA0004093339820000121
从而依据链式法则,将公式(11)转化为:
Figure BDA0004093339820000122
公式(13)中的第一项可由反向传播直接计算得出,故主要关注于第二项的推导过程:
Figure BDA0004093339820000123
故为了简化符号,作如下定义:
Figure BDA0004093339820000124
Figure BDA0004093339820000125
从而公式(14)可转化为:
Figure BDA0004093339820000126
yt
Figure BDA0004093339820000127
通过教师模型生成,所以与θt存在依赖关系,但由公式(6)与(7)可知,gt(yt)与/>
Figure BDA0004093339820000128
与θt无关,所以可以根据贝尔曼期望与策略梯度可做如下转化:
对公式(15)进行假设:
Figure BDA0004093339820000129
/>
从而可通过策略梯度将公式转化为:
Figure BDA0004093339820000131
上式中可以根据交叉熵损失函数的定义可对
Figure BDA0004093339820000132
进行替换,同时可对公式(17)进行相同的假设:
Figure BDA0004093339820000133
根据公式(19)可得:
Figure BDA0004093339820000134
故最后可将公式(13)推导为:
Figure BDA0004093339820000135
强化学习基本概念:
对于公式(18)的假设应用到了策略梯度方法,策略梯度是强化学习中的一种算法,以下是强化学习基本符号的定义:
状态s,动作a;
策略函数π(a|s),它是一个概率密度函数π(a|s)=P(A=a|S=s),即给定状态s做出动作a的概率;
状态转移p(s'|s,a),它是一个条件概率密度函数p(s'|s,a)=P(S'=s'|S=s,A=a);
回报R,即在状态s下做出动作a所能得到的奖励大小;
Return Ut=Rt+γRt+12Rt+2+…+γn-tRn
动作价值函数Qπ(st,at)=Ε[Ut|St=st,At=at]
状态价值函数Vπ(st)=ΕA[Qπ(st,A)]=∑aπ(a|st)·Qπ(st,a)
状态价值函数将A作为随机变量然后关于A求期望,且其期望仅与π,s有关,即Aπ(·|st),故根据期望的定义可将其写为连加的形式,即Vπ(st)=ΕA[Qπ(st,A)]=Σaπ(a|st)·Qπ(st,a)。
策略梯度:
(1)使用策略网络π(a|s;θ)近似π(a|s);
(2)定义状态价值函数V(s;θ)=∑aπ(a|s;θ)·Qπ(s,a);
(3)关于θ对函数V(s;θ)求导,即
Figure BDA0004093339820000141
我们所提算法将噪声数据集Du={xu,yu}n中的数据输入xu近似为状态s,类别yu近似为动作a。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个......”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
尽管已经示出和描述了本发明的实施例,对于本领域的普通技术人员而言,可以理解在不脱离本发明的原理和精神的情况下可以对这些实施例进行多种变化、修改、替换和变型,本发明的范围由所附权利要求及其等同物限定。

Claims (7)

1.一种基于在线蒸馏的噪声标签鲁棒性学习方法,其特征在于:包括以下步骤:
S1:将含噪声数据集输入教师模型(MLC),主模型与LCN分别生成伪标签与元伪标签;
S2:计算主模型输出对数与元伪标签之间的交叉熵损失;
S3:将噪声数据输入学生模型(轻量化神经网络),并使用教师模型生成的伪标签训练,计算损失;
S4:根据S2中的损失计算梯度并更新学生模型;
S5:将干净数据输入更新后的学生模型,计算其损失并反馈给主模型,并更新主模型;
S6:将干净数据输入主模型并计算其损失,并更新LCN。
2.根据权利要求1所述的一种基于在线蒸馏的噪声标签鲁棒性学习方法,其特征在于:所述S1中的MLC模型是由一个主模型(深度神经网络)与标签校正网络(LCN,即多层感知机)构成,主模型是一个参数为的Resnet网络,LCN是一个参数为的MLP,主模型与LCN之间通过元学习框架连接。
3.根据权利要求1所述的一种基于在线蒸馏的噪声标签鲁棒性学习方法,其特征在于:所述在线蒸馏的噪声标签鲁棒性学习方法主体框架由一个MLC模块与知识蒸馏模块构成,在知识蒸馏模块中,学生模型分别通过主模型与LCN生成的伪标签进行训练。
4.根据权利要求1所述的一种基于在线蒸馏的噪声标签鲁棒性学习方法,其特征在于:所述步骤S1中,需要生成相应的伪标签,其具体步骤为:
S1-1:将噪声数据输入主模型网络,输出相应数据的对数,对数通过softmax层得到伪标签;
S1-2:提取噪声数据在主模型网络的特征输出并提供给LCN;
S1-3:LCN接收S1-2中的特征输出与相应噪声数据的标签,得到元伪标签。
5.根据权利要求1所述的一种基于在线蒸馏的噪声标签鲁棒性学习方法,其特征在于:所述步骤S3中,通过知识蒸馏技术训练学生模型(MLC模型为教师模型),具体步骤如下:
S3-1:计算教师模型中主模型网络输出伪标签与学生模型输出之间的KL散度;
S3-2:LCN生成的元伪标签质量更高,将其视为真实标签,计算其与学生模型之间的交叉熵损失;
S3-3:计算S3-1与S3-2损失的梯度,并更新学生模型参数。
6.根据权利要求1所述的一种基于在线蒸馏的噪声标签鲁棒性学习方法,其特征在于:所述步骤S5中,具体步骤如下:
S5-1:计算更新后的学生模型在干净数据集上的损失;
S5-2:采用策略梯度计算S4-1中的损失关于主模型参数的梯度;
S5-3:计算S2中交叉熵损失关于主模型参数的梯度;
S5-4:根据S4-2与S4-3更新主模型。
7.根据权利要求1所述的一种基于在线蒸馏的噪声标签鲁棒性学习方法,其特征在于:所述步骤S6中,具体步骤如下:
S6-1:计算更新后的主模型在干净数据集上的损失;
S6-2:计算S6-1中的损失关于LCN参数的梯度;
S6-3:根据S6-2更新LCN。
CN202310158386.8A 2023-02-23 2023-02-23 一种基于在线蒸馏的噪声标签鲁棒性学习方法 Pending CN116151366A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310158386.8A CN116151366A (zh) 2023-02-23 2023-02-23 一种基于在线蒸馏的噪声标签鲁棒性学习方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310158386.8A CN116151366A (zh) 2023-02-23 2023-02-23 一种基于在线蒸馏的噪声标签鲁棒性学习方法

Publications (1)

Publication Number Publication Date
CN116151366A true CN116151366A (zh) 2023-05-23

Family

ID=86354114

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310158386.8A Pending CN116151366A (zh) 2023-02-23 2023-02-23 一种基于在线蒸馏的噪声标签鲁棒性学习方法

Country Status (1)

Country Link
CN (1) CN116151366A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117237720A (zh) * 2023-09-18 2023-12-15 大连理工大学 基于强化学习的标签噪声矫正图像分类方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117237720A (zh) * 2023-09-18 2023-12-15 大连理工大学 基于强化学习的标签噪声矫正图像分类方法
CN117237720B (zh) * 2023-09-18 2024-04-12 大连理工大学 基于强化学习的标签噪声矫正图像分类方法

Similar Documents

Publication Publication Date Title
Zheng et al. Layer-wise learning based stochastic gradient descent method for the optimization of deep convolutional neural network
CN112905891B (zh) 基于图神经网络的科研知识图谱人才推荐方法及装置
CN109635204A (zh) 基于协同过滤和长短记忆网络的在线推荐系统
CN112699247A (zh) 一种基于多类交叉熵对比补全编码的知识表示学习框架
CN113065649B (zh) 一种复杂网络拓扑图表示学习方法、预测方法及服务器
CN114154643A (zh) 基于联邦蒸馏的联邦学习模型的训练方法、系统和介质
CN112967088A (zh) 基于知识蒸馏的营销活动预测模型结构和预测方法
CN114329232A (zh) 一种基于科研网络的用户画像构建方法和系统
CN116644755B (zh) 基于多任务学习的少样本命名实体识别方法、装置及介质
CN116151366A (zh) 一种基于在线蒸馏的噪声标签鲁棒性学习方法
Liu et al. An improved Adam optimization algorithm combining adaptive coefficients and composite gradients based on randomized block coordinate descent
CN117009545A (zh) 一种持续多模态知识图谱的构建方法
CN115525771A (zh) 基于上下文数据增强的少样本知识图谱表示学习方法及系统
Lv et al. Intelligent model update strategy for sequential recommendation
CN114969078A (zh) 一种联邦学习的专家研究兴趣实时在线预测更新方法
Wang et al. Knowledge-enhanced semi-supervised federated learning for aggregating heterogeneous lightweight clients in iot
CN113326884A (zh) 大规模异构图节点表示的高效学习方法及装置
CN110083676B (zh) 一种基于短文本的领域动态跟踪方法
Fan et al. Convergence analysis for sparse Pi-sigma neural network model with entropy error function
CN115131605A (zh) 一种基于自适应子图的结构感知图对比学习方法
CN114218365B (zh) 一种机器阅读理解方法、系统、计算机及存储介质
Guo et al. Collaborative Extreme Noise Classifier: A Multi-Network Approach Based Extreme Noise Classification
CN115936115B (zh) 基于图卷积对比学习和XLNet的知识图谱嵌入方法
ZHANG et al. Multilingual Knowledge Graph Completion Based on Structure Features of the Dual-Branch
US20240119291A1 (en) Dynamic neural network model sparsification

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination