CN115482418B - 基于伪负标签的半监督模型训练方法、系统及应用 - Google Patents

基于伪负标签的半监督模型训练方法、系统及应用 Download PDF

Info

Publication number
CN115482418B
CN115482418B CN202211232414.8A CN202211232414A CN115482418B CN 115482418 B CN115482418 B CN 115482418B CN 202211232414 A CN202211232414 A CN 202211232414A CN 115482418 B CN115482418 B CN 115482418B
Authority
CN
China
Prior art keywords
pseudo
label
classification
result
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202211232414.8A
Other languages
English (en)
Other versions
CN115482418A (zh
Inventor
徐昊
彭成斌
陈传梓
邱晓杰
肖辉
严迪群
董理
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Centran Technology Co ltd
Original Assignee
Centran Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Centran Technology Co ltd filed Critical Centran Technology Co ltd
Priority to CN202211232414.8A priority Critical patent/CN115482418B/zh
Publication of CN115482418A publication Critical patent/CN115482418A/zh
Application granted granted Critical
Publication of CN115482418B publication Critical patent/CN115482418B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Evolutionary Computation (AREA)
  • General Physics & Mathematics (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Data Mining & Analysis (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • Databases & Information Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Multimedia (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种针对半监督图像分类任务下基于伪负标签的模型训练方法、系统及应用。所述训练方法包括:获取结构相同的两个基础模型及训练集;计算有监督损失值;将无标签数据分别进行弱增强和强增强操作;分别输入基础模型获得伪标签和预测结果;基于伪标签生成伪负标签,负标签代表图片不属于的类别,并基于模型预测和对方提供的伪负标签计算无监督损失值;基于有监督和无监督损失值,迭代更新参数。本发明所提供的训练方法通过生成伪负标签,避免了对伪标签进行筛选,有效提高了无标签数据利用率,并降低两个基础模型参数之间的耦合程度;通过伪负标签选择生成更高效的伪负标签供对方基础模型学习,从而显著提高了训练效率和模型分类准确性。

Description

基于伪负标签的半监督模型训练方法、系统及应用
技术领域
本发明属于计算机技术领域,具体涉及计算机视觉和机器学习技术领域,尤其涉及一种基于伪负标签的半监督模型训练方法、系统及应用。
背景技术
图像分类是计算机视觉领域中最为重要的一项任务。随着深度学习的发展,该项任务取得了重大的突破。
然而对于一般的全监督学习方法,模型的性能很大程度上依靠数据的规模。对于在现实任务场景中,数据的采集十分便利,但是数据的标注往往伴随着巨大了人力成本和时间成本。因此,对半监督学习方法的探索尤为重要。
半监督学习主要是通过将少量的有标签数据和大量的无标签数据结合去训练一个具有强泛化能力的AI模型。通常情况下,半监督图像分类方法通过对无标记的数据生成伪标签,利用伪标签进行熵最小化从而进一步指导模型训练。但是由于生成的伪标签往往带有噪声,这会导致训练后期模型会对不同的噪声产生过拟合现象。对此一些方法采用较高的阈值去对伪标签进行筛选,保留较高置信度的伪标签。这种操作虽然滤除了大部分噪声标签但是导致了无标签数据利用率大大降低,所以这些方法也伴随着局限性。
此外,在半监督图像分类方法中,基于多模型的相互学习方法通过多个模型彼此提供伪标签作为训练的指导目标,从而达到彼此促进收敛的效果。但是在这个学习过程,模型可能会将错误的目标提供给对方学习,这将导致对方模型性能降低,从而导致整个训练框架崩坏。其次,在训练后期由于模型趋于收敛,从而使得彼此传递的学习目标趋于一致,进而导致相互学习框架退化为自学习框架。
发明内容
针对现有技术的不足,本发明的目的在于针对半监督图像分类任务,提供一种基于伪负标签的模型训练方法、系统及应用。基于伪负标签的模型相互学习有效提高了无标签数据利用率,并且降低了两个基础模型参数之间的耦合程度。通过伪负标签选择模块帮助生成更高效的伪负标签供对方学习,从而提高了训练效率。
为实现前述发明目的,本发明采用的技术方案包括:
第一方面,本发明提供一种基于伪负标签的图像分类模型训练方法,包括:
1)获取结构相同的第一基础模型和第二基础模型以及训练集,所述训练集包括有标签数据及其对应的真值标签,以及无标签数据;
2)将任一所述有标签数据分别进行两次不同的弱增强操作,获得有标签第一结果和有标签第二结果,基于所述有标签第一结果和有标签第二结果及其对应的真值标签分别计算所述第一基础模型的第一有监督损失值和第二基础模型的第二有监督损失值;
3)将任一所述无标签数据分别进行弱增强操作和强增强操作,获得无标签弱增强结果和无标签强增强结果;
4)将所述无标签弱增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一伪标签和第二伪标签;将所述无标签强增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一分类结果和第二分类结果;
5)基于所述第一伪标签和第二伪标签生成第一伪负标签和第二伪负标签,其中,伪负标签代表不同于对应的伪标签的其余标签类别,并基于所述第一分类结果和第二伪负标签计算第一无监督损失值,基于所述第二分类结果和第一伪负标签计算第二无监督损失值;
6)基于所述第一有监督损失值和第一无监督损失值,迭代更新所述第一基础模型的参数,基于所述第二有监督损失值和第二无监督损失值,迭代更新所述第二基础模型的参数。
第二方面,本发明还提供一种基于伪负标签的半监督图像分类模型训练系统,包括:
模型数据模块,用于获取结构相同的第一基础模型和第二基础模型以及训练集,所述训练集包括有标签数据及其对应的真值标签,以及无标签数据;
有标签增强模块,用于将任一所述有标签数据分别进行两次不同的弱增强操作,获得有标签第一结果和有标签第二结果;
有监督损失模块,用于基于所述有标签第一结果和有标签第二结果及其对应的真值标签分别计算所述第一基础模型的第一有监督损失值和第二基础模型的第二有监督损失值;
无标签增强模块,用于将任一所述无标签数据分别进行弱增强操作和强增强操作,获得无标签弱增强结果和无标签强增强结果;
无标签分类模块,用于将所述无标签弱增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一伪标签和第二伪标签;将所述无标签强增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一分类结果和第二分类结果;
伪负标签模块,用于基于所述第一伪标签和第二伪标签生成第一伪负标签和第二伪负标签,其中,伪负标签代表不同于对应的伪标签的其余类别标签,并基于所述第一分类结果和第二伪负标签计算第一无监督损失值,基于所述第二分类结果和第一伪负标签计算第二无监督损失值;
迭代更新模块,用于基于所述第一有监督损失值和第一无监督损失值,迭代更新所述第一基础模型的参数,基于所述第二有监督损失值和第二无监督损失值,迭代更新所述第二基础模型的参数。
第三方面,本发明还提供上述训练方法训练获得的图像分类模型。
第四方面,本发明还提供一种电子设备,包括存储器和处理器,所述存储器中存储有计算机程序,所述计算机程序被所述处理器运行时,执行上述模型训练方法中的步骤或运行上述图像分类模型。
第五方面,本发明还提供一种可读存储介质,所述刻度存储介质中存储有计算机程序,所述计算机程序被运行时,执行上述模型训练方法中的步骤或运行上述图像分类模型。
基于上述技术方案,与现有技术相比,本发明的有益效果包括:
本发明所提供的训练方法通过生成伪负标签,避免了对伪标签进行筛选,有效提高了无标签数据的利用率,并且降低了两个基础模型参数之间的耦合程度;通过伪负标签选择生成更高效的伪负标签供对方基础模型学习,从而显著提高了训练效率和模型分类准确性。
上述说明仅是本发明技术方案的概述,为了能够使本领域技术人员能够更清楚地了解本申请的技术手段,并可依照说明书的内容予以实施,以下以本发明的较佳实施例并配合详细附图说明如后。
附图说明
图1是本发明一典型实施案例提供的图像分类模型的训练方法的流程示意图;
图2是本发明一典型实施案例提供的图像分类模型的训练方法的系统结构与过程示意图。
具体实施方式
鉴于现有技术中的不足,本案发明人经长期研究和大量实践,得以提出本发明的技术方案。
如下将对该技术方案、其实施过程及原理等作进一步的解释说明。在下面的描述中阐述了很多具体细节以便于充分理解本发明,但是,本发明还可以采用其他不同于在此描述的方式来实施,因此,本发明的保护范围并不受下面公开的具体实施例的限制。
而且,诸如“第一”和“第二”等之类的关系术语仅仅用来将一个与另一个具有相同名称的部件或方法步骤区分开来,而不一定要求或者暗示这些部件或方法步骤之间存在任何这种实际的关系或者顺序。
参见图1-图2,本发明实施例提供一种基于伪负标签的图像分类模型训练方法,包括如下的步骤:
1)获取结构相同的第一基础模型和第二基础模型以及训练集,所述训练集包括有标签数据及其对应的真值标签,以及无标签数据。
2)将任一所述有标签数据分别进行两次不同的弱增强操作,获得有标签第一结果和有标签第二结果,基于所述有标签第一结果和有标签第二结果及其对应的真值标签分别计算所述第一基础模型的第一有监督损失值和第二基础模型的第二有监督损失值。
3)将任一所述无标签数据分别进行弱增强操作和强增强操作,获得无标签弱增强结果和无标签强增强结果。
4)将所述无标签弱增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一伪标签和第二伪标签;将所述无标签强增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一分类结果和第二分类结果。
5)基于所述第一伪标签和第二伪标签生成第一伪负标签和第二伪负标签,其中,伪负标签代表不同于对应的伪标签的误分类结果,并基于所述第一分类结果和第二伪负标签计算第一无监督损失值,基于所述第二分类结果和第一伪负标签计算第二无监督损失值。
6)基于所述第一有监督损失值和第一无监督损失值,迭代更新所述第一基础模型的参数,基于所述第二有监督损失值和第二无监督损失值,迭代更新所述第二基础模型的参数。
其中,所述第一和第二基础模型优选可以采用所述有标签数据以及真值标签进行预训练获得,也可以直接获取在其他方法或程序中初步训练得到的图像分类模型。所述真值标签亦称地面真值标签。并且,为方便展示,本发明实施例所提供的图2进行了向左旋转。
基于上述技术方案,作为一些典型的应用实例,上述训练方法可以采用如下的步骤得以实施:
S1、创建并独立初始化两个结构一致的基础模型。
S2、对有标签数据进行弱增强操作,对无标签数据分别进行弱增强和强增强操作。
S3、将弱增强后的有标签数据分别输入两个模型中,通过对模型预测和对应标签计算交叉熵得到各自的有监督损失值。
S4、对于无标签数据,将弱增强处理的数据输入两个模型去生成伪标签,再将强增强处理的数据输入两个模型去生成预测分类。
S5、基于以上产生的伪标签通过伪负标签选择机制去为对方模型生成伪负标签,作为其对强增强数据预测的学习目标并计算无监督损失值。
S6、使用梯度下降更新两个模型参数,重复S2-S5直至收敛。
其过程可以概括为:用有标签数据训练两个基础网络;利用两个模型为弱增强处理之后无标签数据生成伪标签;基于伪标签,通过伪负标签选择模块生成高效的伪负标签作为对方模型的学习目标;在有标签数据下,分别计算两个模型对各个类别的误分类分数并利用指数移动平均进行更新;更新两个模型参数。
在一些实施方案中,步骤2)具体可以包括:
将任一所述有标签数据分别进行两次不同的弱增强操作,获得两个有标签弱增强结果并分别输入所述第一基础模型和第二基础模型,获得两个有标签分类结果,并结合对应的所述真值标签计算所述第一基础模型的第一有监督损失值和第二基础模型的第二有监督损失值。
在一些实施方案中,步骤5)中,所述伪负标签的生成方法具体可以包括:
基于所述有标签分类结果,统计所述真值标签所指示的类别之外的其余各个类别的分类概率。
基于所述分类概率,从所述伪标签所指示的类别之外的其余各个类别中随机抽取若干类别作为所述伪负标签。
在一些实施方案中,通过使用有标签数据计算基础模型误分类到其余各个类别的分类分数,所述分类分数代表所述分类概率。
所述分类分数的计算公式为:
Rk=Soft max(Prk)
其中,Prk[j]代表所述基础模型对第k类别数据误分到第j类上的概率分数;Nk代表模型对于本次迭代中的多个有标签数据中第k类数据误分类的总个数;pi代表所述基础模型预测的概率向量;Rk代表标准化过后的所述概率分数,并作为所述分类分数使用。
且在迭代过程中,采用指数移动平均的方法更新Prk[j]。
其中,Rk Prk代表-个向量。描述向量中每个元素操作。Rk=Softmax(Prk)描述向量整体操作。
在一些实施方案中,步骤4)中,所述伪标签的计算公式可以为:
其中,Y(p)代表所述伪标签;OneHot代表热编码操作;代表所述无标签弱增强结果;/>代表基础模型对所述无标签弱增强结果进行分类获得的分类概率分布。
在一些实施方案中,步骤5)中,所述伪负标签的计算公式可以为:
Y(c)∈z(Y(p),m)
其中,Y(c)代表所述伪负标签,从集合z(Y(p),m)中随机选择;K代表总类别个数;m为一个大于等于1小于K的正整数,代表随机选择的伪负标签的个数;v代表一个包含K个特征的一维向量。每个特征取值为0或1,其中值为1代表该索引对应的类别选作为伪负标签。
在一些实施方案中,所述第一有监督损失值和第二有监督损失值的损失函数分别可以为:
上述公式中,
其中,代表所述第一有监督损失值;/>代表所述第二有监督损失值;/>代表两次不同的弱增强操作获得的两个有标签弱增强结果;fθ代表所述第一基础模型,/>代表所述第二基础模型;Y(t)代表所述真值标签编码之后的热向量,Y(1)代表所述有标签分类结果。
所述第一无监督损失值和第二无监督损失值的损失函数分别可以为:
上述公式中,
其中,代表所述第一无监督损失值,/>代表所述第二无监督损失值;/>代表所述无标签弱增强结果;Y(c)代表所述伪负标签编码后的热向量;Y(2)代表所述第一分类结果或第二分类结果。
在一些实施方案中,可以将有监督损失值与无监督损失值的线型加和值作为总损失值对相应的基础模型的参数进行更新。
所述总损失值的计算公式分别可以为:
其中,λ代表有监督损失和无监督损失之间的平衡系数;l(1)和l(2)分别代表所述第一基础模型和第二基础模型对应的总损失值。
在一些实施方案中,所述训练方法中的迭代可以分批次进行,即将训练集分为多个批次,每一批次对应一次迭代。所述批次中的数据数量可以是大于一的若干个,例如2-256个等等,也可以仅仅只有一个,此时相当于不分批进行。
基于上述技术方案,在实际应用中,上述训练方法的执行步骤可以是:
(1)创建并独立初始化两个基础模型fθθ和/>是相应的模型参数,为了使得双方各方面能力相同,因此设定两个模型的网络结构一致。其次该训练集由少量的有标签数据和大量的无标签数据组成。将有标签数据和无标签数据划分成特定大小的批次,依次输入两个模型进行训练。
(2)对于小批次的有标签数据,对当中每个图像进行两次不同的弱增强操作,其中弱增强操作包含了随机翻转和水平方向平移等操作,对应公式可以如下:
其中Xi表示该批次中第i张图像数据,和/>表示不同的弱增强操作。/>表示不同增强处理之后的图像数据。
(3)经过上述操作处理之后,将得到的两个有标签数据批次分别做为两个基础模型的输入。之后对模型预测和相应地面真值标签计算有监督损失值。使用交叉熵函数作为有监督训练的损失函数,公式如下:
其中Y(1)表示所述基础模型预测的概率向量,Y(t)表示数据地面真值标签编码之后的热向量。
因此,两个模型的有监督损失可以如下所示:
和/>分别表示两个模型的有监督损失值。
(4)对于小批次的无标签数据,对当中每个图像分别进行弱增强和强增强操作,其中弱增强操作包含了随机翻转和水平方向平移等操作,强增强操作包含了色彩抖动等操作。对应公式如下:
其中Xj表示该无标签数据批次中第j张图像数据,Aw和As表示弱增强和强增强操作。
和/>表示不同增强处理之后的图像数据。
(5)将输入两个基础模型当中生成各自的伪标签。之后基于伪标签为对方模型产生伪负标签,供对方模型学习。当中任意一模型产伪标签的方式如下:
OneHot表示将模型预测的概率分布进行热编码操作。Y(p)表示模型对该无标签数据生成的伪标签。其次基于上述操作产生的伪标签,进一步去生成伪负标签,方式如下:
Y(c)∈z(Y(p),m)
伪负标签Y(c)从集合z(Y(p),m)中随机选择。K表示总类别个数,m为一个正整数其大小为大于等于1小于K,表示选择伪负标签的个数。
通过以上操作,两个模型分别为对方生成伪负标签Y(c1)和Y(c2)
之后计算两个模型对强增强数据产生的预测和相应的伪负标签之间的无标签损失。无标签损失函数如下:
其中Y(2)表示基础模型预测的无标签数据的概率向量,Y(c)表示所述伪负标签编码后的热向量。
因此,两个模型的无标签损失如下:
和/>分别表示两个模型的无监督损失值。
最终,两个模型各自的总损失值可以如下所示:
其中λ是有监督损失和无监督损失之间的平衡系数,其取值范围为0.5~1之间,当然并不仅限于上述范围,本领域技术人员可以对λ的取值进行适应性调整。
(6)对每一个模型,通过使用有标签数据计算模型对各个类别误分类到其余类别的分数。分数计算公式如下:
Prk表示基础模型对第k类别数据误分到其余类别的分数向量。Prk[j]表示基础模型对第k类别数据误分到第j类上的概率分数。Nk表示模型对于该批次中第k类数据误分类的总个数。pij表示基础模型对第i个样本第j类的预测概率。在迭代过程中,采用指数移动平均去更新Prk
在使用之前对其进行标准化操作,如下:
Rk=Softmax(Prk)
当模型生成的伪标签类别为k时,R中的各个概率分数将作为对应的其余类别被选中的概率,从而使得在生成伪负标签时,对方模型容易误分的类别被选中的可能性更大。因此,这种依据有标签数据得出基础模型对某一类别容易误判为其他类别的概率并基于该概率来优选伪负标签的方式,使得伪负标签更具有针对性,进而非常有利于训练效率和训练得到的模型的准确性。
(7)使用梯度下降更新两个模型参数,重复步骤(2)-步骤(6)直至两个基础模型收敛,从而得到能够用于图像分类的图像分类模型。
具体的,所述步骤(1)是作为通常的初始化步骤,具有一对基础模型,并且双方模型具有相同的网络结构。在每次迭代过程中,学生网络参数使用梯度下降进行更新。所述步骤(2)中对于小批次的有标签数据当中每个图像进行两次不同的弱增强操作,对小批次的无标签数据图像进行弱增强和强增强操作。所述步骤(3)中,利用交叉熵函数计算两个模型的有监督损失值。所述步骤(4)中,利用模型对弱增强处理图像的预测生成伪标签。
更加具体的,步骤(5)中所述的伪负标签学习,其过程具体包括:
(i)对于每一个模型使用有标签数据计算其对各个类别误分类的分数值并通过指数移动平均进行更新。
(ii)通过上一步骤产生伪标签传入伪负标签选择模块。其伪负标签代表该图像不属于某一类别或者某些类别,即所述伪负标签代表了误判的类别。对于伪标签为类别k来说,根据上述Rk中各个其余类别被选择的概率大小进行随机抽取,抽取的类别数大于等于1,小于总类别数。最后将生成的伪负标签作为对方模型的负面学习目标。
(iii)对两个模型将有监督损失和无监督损失累加得到总体的损失。
以上具体示例了高效、准确训练图像分类模型的具体实现方法,继续参见图2,本发明实施例还提供一种基于伪负标签的图像分类模型训练系统,该训练系统正是应用上述训练方法的,其包括:
模型数据模块,用于获取结构相同的第一基础模型和第二基础模型以及训练集,所述训练集包括有标签数据及其对应的真值标签,以及无标签数据。
有标签增强模块,用于将任一所述有标签数据分别进行两次不同的弱增强操作,获得有标签弱增强第一结果和有标签弱增强第二结果。
有监督损失模块,用于基于所述有标签数据第一结果和第二结果及其对应的真值标签分别计算所述第一基础模型的第一有监督损失值和第二基础模型的第二有监督损失值。
无标签增强模块,用于将任一所述无标签数据分别进行弱增强操作和强增强操作,获得无标签弱增强结果和无标签强增强结果。
无标签分类模块,用于将所述无标签弱增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一伪标签和第二伪标签;将所述无标签强增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一分类结果和第二分类结果。
伪负标签模块,用于基于所述第一伪标签和第二伪标签生成第一伪负标签和第二伪负标签,其中,伪负标签代表不同于对应的伪标签的其余类别标签,并基于所述第一分类结果和第二伪负标签计算第一无监督损失值,基于所述第二分类结果和第一伪负标签计算第二无监督损失值。
迭代更新模块,用于基于所述第一有监督损失值和第一无监督损失值,迭代更新所述第一基础模型的参数,基于所述第二有监督损失值和第二无监督损失值,迭代更新所述第二基础模型的参数。
同理,本发明实施例还提供上述训练方法训练获得的图像分类模型。
作为上述技术方案的一些典型的应用实例,本发明实施例还应用上述训练方法和系统,实际进行了多个模型的训练,且与现有技术中的多种训练方法和系统进行了比较,尤其是他们的选取有标签数据的数量与识别准确率。
具体的,利用已知的网络数据对该基于伪负标签的半监督图像分类方法的可靠性进行了进行验证,同时和现有的半监督图像分类方法进行了比较,如下所示:
实施例1
本实施例示例本发明所提供的训练方法与现有的半监督图像分类模型训练方法的比较:
本实施例中分别使用WideResNet-28-2构建了两个基础模型以及使用CNN-13构建了两个基础模型,一次构建的两个模型具有相同的结构,且有标签数据和无标签数据划分的批次大小为N=256,损失平衡系数λ=0.5,指数移动平均中α=0.99,伪负标签抽取类别数m=3。本实例在所有的实验中都使用了数据扩充的方法,包括平移、翻转和色彩抖动。所有网络都用随机梯度下降(SGD)优化器进行训练。初始学习率为0.03,动量为0.9,并且使用多项式衰减策略,即1-(iter/max_iter)0.9,用于调整学习率。
下表1和表2所示为在两种不同的数据集上,采用不同结构网络做为基础模型以及不同的有标签数据划分下,本实例所述半监督图像分类方法和现有的半监督图像分类方法的比较结果。
表1 CIFAR-10数据集和CNN-13网络结构下不同训练方法的模型预测结果
表2 SVHN数据集和WideResNet-28-2网络结构下不同训练方法的模型预测结果
其中CIFAR-10和SVHN数据集的训练集分别有50000和73257张图片。其中,表中第一栏数字代表了随机选择有标签数据的数量,表中数值为模型在测试数据集中的识别准确率(%);DNLL为本实施例所提供的训练方法。表1使用的数据集和网络分别为CIFAR-10和CNN-13,表二使用的数据集和网络分别为SVHN和WideResNet-28-2。由表1和表2可知,本实施例所述识别方法有效提高了半监督分类模型的训练效率和图像分类模型的分类准确率。
实施例2
本实施例示例本发明所提供的训练方法与现有的基于多模型相互学习的半监督图像分类方法比较:
下表3所示为在CIFAR-10数据集上,采用不同结构网络做为基础模型以及不同的有标签数据划分下,本实例所提供半监督图像分类训练算法和现有的基于多模型相互学习的半监督图像分类训练方法的比较结果,其中的DNLL均为本实施例所提供的训练算法,其余方法为现有的训练方法。由表3可知,本实施例所述方法在两种不同的网络结构下均有效提高了半监督分类模型的性能。
表3 CIFAR-10数据集上不同训练方法的模型预测结果
实施例3
本实施例示例本发明所提供的训练方法与现有的基于单模型的自学习框架比较:
下表4所示为在CIFAR-10数据集上,在特定的有标签数据划分下,本实例所述半监督图像分类算法在基于双模型相互学习框架和基于单模型的自学习框架下的结果比较。其中,ML为本实施例基于双模型相互学习的实现,SL为基于单模型自学习的实现。由表4可知,本实施例所述识别方法在双模型相互学习框架下的性能更佳。
表4 CIFAR-10数据集上不同训练方法的模型预测结果
基于上述各实施例,可以明确,本发明实施例所提供的训练方法,通过生成伪负标签,避免了对伪标签进行筛选,有效提高了无标签数据的利用率,并且降低了两个基础模型参数之间的耦合程度;并通过伪负标签选择模块生成更高效的伪负标签供对方基础模型学习,从而显著提高了训练效率和模型分类准确性。
本实施例还提供了一种电子设备,包括:一个或多个处理器;以及存储器,其中存储器用于存储一个或多个可执行指令;一个或多个处理器被配置为经由执行一个或多个可执行指令以本实施例所述训练的步骤或运行本发明实施例所提供的图像分类模型。
本发明的具体实施手段可以是系统、方法和/或计算机程序产品。计算机程序产品可以包括计算机可读存储介质,其上载有用于使处理器实现本发明的各个方面的计算机可读程序指令。计算机可读存储介质可以是保持和存储由指令执行设备使用的指令的有形设备。计算机可读存储介质例如可以包括但不限于电存储设备、磁存储设备、光存储设备、电磁存储设备、半导体存储设备或者上述的任意合适的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、静态随机存取存储器(SRAM)、便携式压缩盘只读存储器(CD-ROM)、数字多功能盘(DVD)、记忆棒、软盘、机械编码设备、例如其上存储有指令的打孔卡或凹槽内凸起结构、以及上述的任意合适的组合。
需要说明的是,虽然上文按照特定顺序描述了各个步骤,但是并不意味着必须按照上述特定顺序来执行各个步骤,实际上,这些步骤中的一些可以并发执行,甚至改变顺序,只要能够实现所需要的功能即可。
应当理解,以上所描述的实施例是本发明一部分实施例,而不是全部的实施例。本发明的实施例的详细描述并非旨在限制要求保护的本发明的范围,而是仅仅表示本发明的选定实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。

Claims (8)

1.一种基于伪负标签的图像分类半监督模型训练方法,其特征在于,包括:
1)获取结构相同的第一基础模型和第二基础模型以及训练集,所述训练集包括有标签数据及其对应的真值标签,以及无标签数据;
2)将任一所述有标签数据分别进行两次不同的弱增强操作,获得有标签第一结果和有标签第二结果,并分别输入所述第一基础模型和第二基础模型,获得两个有标签分类结果,并结合对应的所述真值标签计算所述第一基础模型的第一有监督损失值和第二基础模型的第二有监督损失值;
3)将任一所述无标签数据分别进行弱增强操作和强增强操作,获得无标签弱增强结果和无标签强增强结果;
4)将所述无标签弱增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一伪标签和第二伪标签;将所述无标签强增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一分类结果和第二分类结果;
5)基于所述第一伪标签和第二伪标签生成第一伪负标签和第二伪负标签,其中,伪负标签代表不同于对应的伪标签的其余类别标签,并基于所述第一分类结果和第二伪负标签计算第一无监督损失值,基于所述第二分类结果和第一伪负标签计算第二无监督损失值;
所述伪负标签的生成方法具体包括:
基于所述有标签分类结果,统计所述真值标签所指示的类别之外的其余各个类别的分类概率;
基于所述分类概率,从所述伪标签所指示的类别之外的其余各个类别中随机抽取若干类别作为所述伪负标签;
6)基于所述第一有监督损失值和第一无监督损失值,迭代更新所述第一基础模型的参数,基于所述第二有监督损失值和第二无监督损失值,迭代更新所述第二基础模型的参数,获得用于图像分类的图像分类模型。
2.根据权利要求1所述的训练方法,其特征在于,通过使用有标签数据计算基础模型误分类到其余各个类别的分类分数,所述分类分数代表所述分类概率;
所述分类分数的计算公式为:
Rk=Softmax(Prk)
其中,Prk表示基础模型对第k类别数据误分到其余类别的分数向量,Prk[j]代表所述基础模型对第k类别数据误分到第j类上的概率分数;Nk代表模型对于本次迭代中的多个有标签数据中第k类数据误分类的总个数;pij表示基础模型对第i个样本第j类的预测概率;Rk代表标准化过后的所述概率分数,并作为所述分类分数使用;
且在迭代过程中,采用指数移动平均的方法更新Prk[j]。
3.根据权利要求1所述的训练方法,其特征在于,步骤4)中,所述伪标签的计算公式为:
其中,Y(p)代表所述伪标签;OneHot代表热编码操作;代表所述无标签弱增强结果;代表基础模型对所述无标签弱增强结果进行分类获得的分类概率分布。
4.根据权利要求3所述的训练方法,其特征在于,步骤5)中,所述伪负标签的计算公式为:
Y(c)∈z(Y(p),m)
其中,Y(c)代表所述伪负标签,从集合z(Y(p),m)中随机选择;K代表总类别个数;m为一个大于等于1小于K的正整数,代表随机选择的伪负标签的个数;v代表一个包含K个特征的一维向量。
5.根据权利要求1所述的训练方法,其特征在于,所述第一有监督损失值和第二有监督损失值的损失函数分别为:
上述公式中,
其中,代表所述第一有监督损失值;/>代表所述第二有监督损失值;/>和/>代表两次不同的弱增强操作获得的两个有标签弱增强结果;fθ代表所述第一基础模型,/>代表所述第二基础模型;Y(t)代表所述真值标签编码之后的热向量,Y(1)代表所述有标签分类结果;
所述第一无监督损失值和第二无监督损失值的损失函数分别为:
上述公式中,
其中,代表所述第一无监督损失值,/>代表所述第二无监督损失值;/>代表所述无标签弱增强结果;Y(c)代表所述伪负标签编码后的热向量;Y(2)代表所述第一分类结果或第二分类结果。
6.根据权利要求5所述的训练方法,其特征在于,将有监督损失值与无监督损失值的线型加和值作为总损失值对相应的基础模型的参数进行更新;
所述总损失值的计算公式分别为:
其中,λ代表有监督损失和无监督损失之间的平衡系数,其取值范围为0.5~1之间;l(1)和l(2)分别代表所述第一基础模型和第二基础模型对应的总损失值。
7.一种基于伪负标签的图像分类模型训练系统,其特征在于,包括:
模型数据模块,用于获取结构相同的第一基础模型和第二基础模型以及训练集,所述训练集包括有标签数据及其对应的真值标签,以及无标签数据;
有标签增强模块,用于将任一所述有标签数据分别进行两次不同的弱增强操作,获得有标签第一结果和有标签第二结果;
有监督损失模块,用于基于所述有标签第一结果和有标签第二结果及其对应的真值标签分别计算所述第一基础模型的第一有监督损失值和第二基础模型的第二有监督损失值;
无标签增强模块,用于将任一所述无标签数据分别进行弱增强操作和强增强操作,获得无标签弱增强结果和无标签强增强结果;
无标签分类模块,用于将所述无标签弱增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一伪标签和第二伪标签;将所述无标签强增强结果分别输入所述第一基础模型和第二基础模型中进行分类,获得第一分类结果和第二分类结果;
伪负标签模块,用于基于所述第一伪标签和第二伪标签生成第一伪负标签和第二伪负标签,其中,伪负标签代表不同于对应的伪标签的误分类结果,并基于所述第一分类结果和第二伪负标签计算第一无监督损失值,基于所述第二分类结果和第一伪负标签计算第二无监督损失值;
迭代更新模块,用于基于所述第一有监督损失值和第一无监督损失值,迭代更新所述第一基础模型的参数,基于所述第二有监督损失值和第二无监督损失值,迭代更新所述第二基础模型的参数。
8.权利要求1-6中任意一项所述的训练方法训练获得的用于图像分类的图像分类模型。
CN202211232414.8A 2022-10-09 2022-10-09 基于伪负标签的半监督模型训练方法、系统及应用 Active CN115482418B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211232414.8A CN115482418B (zh) 2022-10-09 2022-10-09 基于伪负标签的半监督模型训练方法、系统及应用

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211232414.8A CN115482418B (zh) 2022-10-09 2022-10-09 基于伪负标签的半监督模型训练方法、系统及应用

Publications (2)

Publication Number Publication Date
CN115482418A CN115482418A (zh) 2022-12-16
CN115482418B true CN115482418B (zh) 2024-06-07

Family

ID=84393563

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211232414.8A Active CN115482418B (zh) 2022-10-09 2022-10-09 基于伪负标签的半监督模型训练方法、系统及应用

Country Status (1)

Country Link
CN (1) CN115482418B (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115690100B (zh) * 2022-12-28 2023-04-07 珠海横琴圣澳云智科技有限公司 半监督信号点检测模型训练方法、信号点检测方法和装置
CN116778239B (zh) * 2023-06-16 2024-06-11 酷哇科技有限公司 面向实例分割模型的半监督训练方法及设备
CN117197474B (zh) * 2023-09-28 2024-08-02 江苏开放大学(江苏城市职业学院) 一种基于类别均衡及交叉合并策略的噪声标签学习方法
CN118334352B (zh) * 2024-06-13 2024-08-13 宁波大学 一种点云语义分割模型的训练方法、系统、介质及设备

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111078911A (zh) * 2019-12-13 2020-04-28 宁波大学 一种基于自编码器的无监督哈希方法
KR20200046173A (ko) * 2018-10-18 2020-05-07 부산대학교 산학협력단 언레이블 데이터를 사용한 나이브 반지도 심층 학습의 제공 방법 및 그 시스템
CN112232416A (zh) * 2020-10-16 2021-01-15 浙江大学 一种基于伪标签加权的半监督学习方法
CN114037876A (zh) * 2021-12-16 2022-02-11 马上消费金融股份有限公司 一种模型优化方法和装置
WO2022042002A1 (zh) * 2020-08-31 2022-03-03 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN114648779A (zh) * 2022-03-14 2022-06-21 宁波大学 基于自标签精炼深度学习模型的无监督行人重识别方法
CN114743109A (zh) * 2022-04-28 2022-07-12 湖南大学 多模型协同优化高分遥感图像半监督变化检测方法及系统
CN114881149A (zh) * 2022-05-10 2022-08-09 杭州海康威视数字技术股份有限公司 一种模型训练方法及装置、目标检测方法及装置
CN114943879A (zh) * 2022-07-22 2022-08-26 中国科学院空天信息创新研究院 基于域适应半监督学习的sar目标识别方法

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11694093B2 (en) * 2018-03-14 2023-07-04 Adobe Inc. Generation of training data to train a classifier to identify distinct physical user devices in a cross-device context
US11416772B2 (en) * 2019-12-02 2022-08-16 International Business Machines Corporation Integrated bottom-up segmentation for semi-supervised image segmentation
KR20210149530A (ko) * 2020-06-02 2021-12-09 삼성에스디에스 주식회사 이미지 분류 모델 학습 방법 및 이를 수행하기 위한 장치

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
KR20200046173A (ko) * 2018-10-18 2020-05-07 부산대학교 산학협력단 언레이블 데이터를 사용한 나이브 반지도 심층 학습의 제공 방법 및 그 시스템
CN111078911A (zh) * 2019-12-13 2020-04-28 宁波大学 一种基于自编码器的无监督哈希方法
WO2022042002A1 (zh) * 2020-08-31 2022-03-03 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN112232416A (zh) * 2020-10-16 2021-01-15 浙江大学 一种基于伪标签加权的半监督学习方法
CN114037876A (zh) * 2021-12-16 2022-02-11 马上消费金融股份有限公司 一种模型优化方法和装置
CN114648779A (zh) * 2022-03-14 2022-06-21 宁波大学 基于自标签精炼深度学习模型的无监督行人重识别方法
CN114743109A (zh) * 2022-04-28 2022-07-12 湖南大学 多模型协同优化高分遥感图像半监督变化检测方法及系统
CN114881149A (zh) * 2022-05-10 2022-08-09 杭州海康威视数字技术股份有限公司 一种模型训练方法及装置、目标检测方法及装置
CN114943879A (zh) * 2022-07-22 2022-08-26 中国科学院空天信息创新研究院 基于域适应半监督学习的sar目标识别方法

Non-Patent Citations (5)

* Cited by examiner, † Cited by third party
Title
Negative Pseudo Labeling using Class Proportion for Semantic Segmentation in Pathology;Hiroki Tokunaga 等;《arXiv:2007.08044v1》;全文 *
一种结合GAN和伪标签的深度半监督模型研究;杨灿;;中国科技信息(17);全文 *
基于ResNet 的音频场景声替换造假的检测算法;严迪群 等;《计算机应用》;全文 *
基于伪标签半监督核局部Fisher判别分析轴承故障诊断;陶新民;任超;徐朗;何庆;刘锐;邹俊荣;;振动与冲击(17);全文 *
基于可信标签扩展传递的跨领域倾向性分析;侯秀艳;刘培玉;孟凡龙;;计算机应用研究(05);全文 *

Also Published As

Publication number Publication date
CN115482418A (zh) 2022-12-16

Similar Documents

Publication Publication Date Title
CN115482418B (zh) 基于伪负标签的半监督模型训练方法、系统及应用
CN110347839A (zh) 一种基于生成式多任务学习模型的文本分类方法
US11288324B2 (en) Chart question answering
CN112668579A (zh) 基于自适应亲和力和类别分配的弱监督语义分割方法
CN111738169B (zh) 一种基于端对端网络模型的手写公式识别方法
CN112148831B (zh) 图文混合检索方法、装置、存储介质、计算机设备
CN113378938B (zh) 一种基于边Transformer图神经网络的小样本图像分类方法及系统
CN112507912B (zh) 一种识别违规图片的方法及装置
CN112597324A (zh) 一种基于相关滤波的图像哈希索引构建方法、系统及设备
CN114612767B (zh) 一种基于场景图的图像理解与表达方法、系统与存储介质
CN112364747B (zh) 一种有限样本下的目标检测方法
CN111325237A (zh) 一种基于注意力交互机制的图像识别方法
CN111582506A (zh) 基于全局和局部标记关系的偏多标记学习方法
CN116486419A (zh) 一种基于孪生卷积神经网络的书法字识别方法
CN115393631A (zh) 基于贝叶斯层图卷积神经网络的高光谱图像分类方法
CN114998647B (zh) 基于注意力多实例学习的乳腺癌全尺寸病理图像分类方法
CN118468061B (zh) 一种算法自动匹配及参数优化方法及系统
CN115457332A (zh) 基于图卷积神经网络和类激活映射的图像多标签分类方法
CN111709442A (zh) 一种面向图像分类任务的多层字典学习方法
CN113240033B (zh) 一种基于场景图高阶语义结构的视觉关系检测方法及装置
Qin Application of efficient recognition algorithm based on deep neural network in English teaching scene
CN117975464A (zh) 基于U-Net的电气二次图纸文字信息的识别方法及系统
CN117788917A (zh) 基于伪标签的半监督酒店设施识别模型的训练方法、系统
CN113535928A (zh) 基于注意力机制下长短期记忆网络的服务发现方法及系统
CN105787045A (zh) 一种用于可视媒体语义索引的精度增强方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
TA01 Transfer of patent application right
TA01 Transfer of patent application right

Effective date of registration: 20240205

Address after: 518000 1002, Building A, Zhiyun Industrial Park, No. 13, Huaxing Road, Henglang Community, Longhua District, Shenzhen, Guangdong Province

Applicant after: Shenzhen Wanzhida Technology Co.,Ltd.

Country or region after: China

Address before: 315000 Fenghua Road, Jiangbei District, Ningbo, Zhejiang Province, No. 818

Applicant before: Ningbo University

Country or region before: China

TA01 Transfer of patent application right
TA01 Transfer of patent application right

Effective date of registration: 20240520

Address after: 219, 2nd Floor, Teaching Building Section II (Science and Technology Park Building A), West Campus of Beijing University of Chemical Technology, No. 98 Zizhuyuan Road, Haidian District, Beijing, 100000

Applicant after: CENTRAN TECHNOLOGY Co.,Ltd.

Country or region after: China

Address before: 518000 1002, Building A, Zhiyun Industrial Park, No. 13, Huaxing Road, Henglang Community, Longhua District, Shenzhen, Guangdong Province

Applicant before: Shenzhen Wanzhida Technology Co.,Ltd.

Country or region before: China

GR01 Patent grant