CN116563602A - 基于类别级软目标监督的细粒度图像分类模型训练方法 - Google Patents
基于类别级软目标监督的细粒度图像分类模型训练方法 Download PDFInfo
- Publication number
- CN116563602A CN116563602A CN202310352190.2A CN202310352190A CN116563602A CN 116563602 A CN116563602 A CN 116563602A CN 202310352190 A CN202310352190 A CN 202310352190A CN 116563602 A CN116563602 A CN 116563602A
- Authority
- CN
- China
- Prior art keywords
- model
- ema
- target
- category
- fine
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 63
- 238000000034 method Methods 0.000 title claims abstract description 59
- 238000013145 classification model Methods 0.000 title claims abstract description 22
- 239000011159 matrix material Substances 0.000 claims abstract description 17
- 230000008569 process Effects 0.000 claims abstract description 8
- 238000010276 construction Methods 0.000 claims abstract description 5
- 238000009826 distribution Methods 0.000 claims description 14
- 238000010606 normalization Methods 0.000 claims description 13
- 238000012545 processing Methods 0.000 claims description 8
- 230000005540 biological transmission Effects 0.000 claims description 2
- 230000000694 effects Effects 0.000 abstract description 7
- 230000006870 function Effects 0.000 description 10
- 238000010586 diagram Methods 0.000 description 8
- 238000004590 computer program Methods 0.000 description 7
- 241000282472 Canis lupus familiaris Species 0.000 description 6
- 238000012986 modification Methods 0.000 description 4
- 230000004048 modification Effects 0.000 description 4
- 238000013528 artificial neural network Methods 0.000 description 3
- 238000005457 optimization Methods 0.000 description 3
- 238000011160 research Methods 0.000 description 3
- 238000003860 storage Methods 0.000 description 3
- 230000004075 alteration Effects 0.000 description 2
- 238000004519 manufacturing process Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
- 239000002699 waste material Substances 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
- G06V10/765—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/098—Distributed learning, e.g. federated learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Multimedia (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及基于类别级软目标监督的细粒度图像分类模型训练方法,以有标签的数据预训练一目标模型;以目标模型的参数初始化EMA模型,根据EMA模型中全连接层的参数计算一相似度矩阵,基于相似度矩阵获得类别级软标签,与图像进行关联;输入图像,基于目标模型和EMA模型构建模型训练的总损失更新目标模型;以新的目标模型更新EMA模型,并用新EMA模型计算出新的类别级软标签;重复并最小化总损失,实现细粒度图像分类模型的训练。本发明能在面对细粒度图像分类的问题上取得良好的效果,既保留了类别之间的关系,也不需要额外的空间储存预训练模型,不需要复杂的聚类过程,也不需要额外的预训练模型来获取软标签;准确率高。
Description
技术领域
本发明涉及计算;推算或计数的技术领域,特别涉及一种基于类别级软目标监督的细粒度图像分类模型训练方法。
背景技术
图像分类是计算机视觉分类领域的一个经典问题,其目标是将不同的图像划分到不同的类别。近年来,深度神经网络在视觉分类领域取得了非常不错的应用效果,已经成为解决计算机视觉领域众多机器学习任务的首选建模工具,尤其在监督模式下训练的大规模神经网络在图像分类任务中取得了明显优于其他传统模型的泛化能力。在过去几年中,深度神经网络推动图像分类取得了很大的进步,但是常见的图像分类集中的类别的粒度仍然较粗,比如,狗这个类别下,还可以细分为拉布拉多犬、金毛寻回犬、边境牧羊犬等细分类别,这就导致了在一些网络中对于这些图像的分类效果不佳。粗粒度的分类已经越来越无法满足实际生产生活的需要,而细粒度图像的分类就是针对这类问题继续的研究。
近年来,细粒度图像分类无论在工业界还是学术界都有着广泛的研究需求与应用场景。对比普通的图像分类问题,细粒度分类面对的图像数据具有更加相似的外观特性。由于分类的粒度很小,细粒度图像分类非常困难,在某些类别上甚至专家都难以区分。主要原因有三点:1.子类之间差异细微:只在某个局部上有细微差异;2.子类内部差异巨大;3.受视角、背景、遮挡等因素影响较大。这些困难令细粒度图像分类成为一项极具挑战的研究任务。在实际生活中,识别不同的子类别又存在着巨大的应用需求。例如,在生态保护中,有效识别不同种类的生物是进行生态研究的重要前提,如果能够借助于计算机视觉的技术实现低成本的细粒度图像识别,那么无论对于学术界还是工业界而言,都有着非常重要的意义。
针对以上问题需要对类别之间的关系进行建模,用于模型训练,而由于硬标签的one-hot的特性,不适用于细粒度图像之间的区分,因此提出了软标签的概念。硬标签和软标签的不同在于,硬标签对于分类结果不是1就是0,而软标签对于分类结果是根据每个类别对应的概率给出一个不那么确定的标签,这就让类别之间有了更多的关联和信息,可以让模型学习到更多的知识。Label Smoothing是一种获得软标签的方法,但是有效性不够,由于其只是单纯的添加随机噪声,也无法反映标签之间的关系,因此对模型的提升有限,甚至有欠拟合的风险。另一种是通过额外的预训练模型来获取软标签,但是会需要额外的空间来储存预训练模型,这就造成了空间的浪费,所以其有效性也不足。
发明内容
本发明解决了现有技术中存在的问题,提供了一种基于类别级软目标监督的细粒度图像分类模型训练方法。
本发明所采用的技术方案是,一种基于类别级软目标监督的细粒度图像分类模型训练方法,所述方法以有标签的数据预训练一目标模型;以目标模型的参数初始化EMA模型,根据EMA模型中全连接层的参数计算一相似度矩阵,基于相似度矩阵获得类别级软标签tClu,与图像进行关联;
输入图像,基于目标模型和EMA模型构建模型训练的总损失更新目标模型;以新的目标模型更新EMA模型,并用新EMA模型计算出新的类别级软标签;重复并最小化总损失,实现细粒度图像分类模型的训练。
本发明中,一般以有标签的样本训练40个epoch,获得目标模型;训练目标模型和EMA模型的过程中,在150个epoch前,每过1个epoch,根据目标函数的参数更新EMA模型,在150个epoch后,每过3个epoch根据目标函数的参数更新EMA模型,在更新EMA模型后更新类级别软标签tCla;通过最小化总损失来训练网络参数,完成细粒度图像分类学习,使其具备较高的细粒度图像分类准确率。
优选地,为训练模型使用交叉熵损失,所述目标函数为,
gc=f(xc;φ)
其中,f表示目标模型,C表示类别总量,c对应每个类别,yc是样本xc的标签,pc是经过归一化处理后模型预测的到第c个类别的概率,T是温度系数,0<T<5,gc是样本xc经过目标模型f后得到的归一化之前的输出,φ是目标模型的参数。
优选地,所述目标模型包括特征提取器和分类器,目标模型的特征提取器和分类器在每一轮(每个epoch)基于总损失的梯度反传进行更新;所述EMA模型包括特征提取器和分类器,EMA模型的特征提取器和分类器初始化为目标模型中的特征提取器和分类器,训练过程中EMA模型的特征提取器和分类器基于当前目标模型通过指数滑动平均进行更新。
优选地,EMA模型的分类器的全连接层VK×C包括每个类别的权重,K为维度,这个信息可以用来近似表示类别之间的关系,得到一个相似度矩阵S,S=VTV,S的每一行每一列都包含一个类别对所有类别的相似度信息,对S的每一列以softmax归一化处理,得到类别级软标签tCla。
优选地,输入图像,基于目标模型和EMA模型构建模型训练的总损失L包括:
基于目标模型预测的概率分布,计算其与样本真实硬标签之间的交叉熵损失LHard;因为目标模型对于类别的分类不一定准确,容易造成最终的分类准确率较低,为缓解这一问题,引入基于模型得到的预测和真实标签的交叉熵损失来优化模型;
基于目标模型预测的概率分布,计算其与类别级软标签之间的KL散度LCla;虽然真实标签能帮助模型优化,但是由于真实标签是one-hot的数据,这就导致了其中包含的信息很少,对于细粒度图像无法获取足够的信息,靠真实标签难以让模型分类细粒度图像,为了让模型学习到更多的类别之间的关系,引入类别级软标签tCla,能让模型学习到更多的知识,最终提高模型对细粒度图像分类的准确率,基于类别级软标签tCla和样本经过目标模型得到的输出软标签算出一个KL散度LCla;
基于目标模型预测的概率分布,计算其与EMA模型输出概率分布之间的KL散度LEMA;虽然真实标签和类级别软标签能帮助模型优化,但是目标模型也可能会在训练的过程中往错误的方向训练,这就需要依靠EMA模型来限制目标模型,让其和之前的训练相差不大并向EMA模型靠近,基于样本经过EMA模型得到的输出软标签和样本经过目标模型得到的输出软标签算出一个KL散度LEMA。
优选地,L=LHard+λ1LCla+λ2LEMA,其中,λ1和λ2为对应损失的权重系数,λ1>0,λ2>0;一般来说,λ1和λ2均取1。
优选地,KL散度LCla满足,
Lcla=KL(tCla,p)
其中,KL散度为KL(t,p),C表示类别总量,c对应每个类别,tCla为类级别的软标签,p是经过归一化处理后模型预测的概率分布。
优选地,KL散度LEMA满足,
LEMA=KL(tEMA,p)
其中,KL()为KL散度,tEMA为样本经过EMA模型得到的输出,p是样本经过目标模型得到的概率分布。
优选地,由于EMA是冻结的,不参与训练,所以需要根据目标模型的参数来更新EMA模型;更新目标模型,以一定比例的EMA模型的参数和剩下比例的目标模型参数及预设的更新频率规则更新EMA模型,根据新的EMA模型计算出新的类级别软标签,
φ'←αφ'+(1-α)φ
其中,φ′是EMA模型参数,φ是当前目标模型参数,α是对应的权重系数,α∈(0,1);一般来说,α=0.95。
优选地,给定一张需要进行分类的图片,以训练后的目标模型利用其训练时学习的类别之间的相关性,输出图片对应的类别。
本发明涉及一种基于类别级软目标监督的细粒度图像分类模型训练方法,以有标签的数据预训练一目标模型;以目标模型的参数初始化EMA模型,根据EMA模型中全连接层的参数计算一相似度矩阵,基于相似度矩阵获得类别级软标签tClu,与图像进行关联;输入图像,基于目标模型和EMA模型构建模型训练的总损失更新目标模型;以新的目标模型更新EMA模型,并用新EMA模型计算出新的类别级软标签;重复并最小化总损失,实现细粒度图像分类模型的训练。
本发明的有益效果在于:
(1)相较于其他分类方法存在不注重细粒度图像的分类,且虽然能在一些粗粒度图像的分类上取得良好的效果,但是对于细粒度图像的分类效果就明显不足的问题,本发明能将相似类之间的关系学习出来,能在面对细粒度图像分类的问题上取得良好的效果;
(2)由于相比于基于硬标签的方法,软标签中包含了更多的知识,考虑了类别之间的关系,使得模型能学习到更多的知识,对于分类准确率的提升有着巨大的帮助,本发明根据分类器的全连接层的权重得出各个类别之间的关系,由此计算出获取软标签的矩阵,既保留了类别之间的关系,也不需要额外的空间储存预训练模型;
(3)相比于其他基于软标签的方法,本发明不需要复杂的聚类过程,也不需要额外的预训练模型来获取软标签;
(4)有效性在CUB-200-2011、Stanford Dogs与MIT67数据集上得到验证,其模型在CUB-200-2011数据集上的准确率能到达71.2%,在Stanford Dogs数据集上的准确率到达69.3%,在MIT67数据集上的准确率能到达70%。
附图说明
图1为本发明的流程图;
图2为本发明的模型示意图;
图3为本发明的类别级软标签示意图。
具体实施方式
下面结合实施例对本发明做进一步的详细描述,但本发明的保护范围并不限于此。
本发明涉及一种基于类别级软目标监督的细粒度图像分类模型训练方法,用带标签的数据预训练一个目标模型;通过目标模型的参数初始化EMA模型;根据EMA模型中全连接层的参数计算出一个相似度矩阵;对相似度矩阵的每一列进行softmax归一化获得类别级软标签,将其中一列作为一个类别图像的软标签;计算目标模型对输入的预测及其真实标签之间的交叉熵损失,计算目标模型对输入的预测及其类别级软标签之间的KL散度,计算目标模型对输入的预测及EMA对输入的预测之间的KL散度;把交叉熵损失和两个KL散度相加作为用于模型训练的总损失;更新了目标模型后根据新的目标模型更新EMA模型,并用新EMA模型计算出新的类别级软标签。对总损失的最小化实现细粒度间的一致性和去相关性,最终实现进行有效的细粒度图像分类。
结合实施例,所述方法包括以下步骤:
步骤1:选取数据集,以有标签的样本预训练目标模型,得到训练40个epoch的目标模型;此处选取公开数据集,如采用CUB-200-2011这个公开数据集作为实验的数据集,其中CUB-200-2011是包含200类鸟类子类的数据集,共有11788张鸟类图像,其中训练数据集有5994张图像,测试集有5794张图像,每张图像均提供了图像类标记信息;
用CUB-200-2011训练目标模型,选取Resnet18网络作为骨干模型,最后是一层200类的分类器,以有标签的训练集训练目标模型时,使用交叉熵损失来保持模型的效率,先使用40个epoch将目标模型训练出来。其中,温度参数T为3,批次大小为64;可以通过使用pytorch等深度学习框架对数据集执行以上操作,将图片输入载入DataLoader中,遍历DataLoader中的数据,输入编码器中,获取它们的模型输出,计算损失,使用sgd优化器优化模型。
用有标签的数据集训练目标模型时,使用标准交叉熵损失增加鲁棒性。通过计算交叉熵损失,目标损失为,
gc=f(xc;φ)
其中,f表示目标模型,C表示类别总量,c对应每个类别,yc是样本xc的真实标签,pc是经过softmax归一化处理后模型预测的到第c个类别的概率,T是温度参数,0<T<5,gc是某个样本xc经过目标模型f后得到的归一化之前的输出,φ是目标模型的参数。
步骤2:所述目标模型包括特征提取器和分类器,目标模型的特征提取器和分类器在每一轮基于总损失的梯度反传进行更新;所述EMA模型包括特征提取器和分类器,EMA模型的特征提取器和分类器初始化为目标模型中的特征提取器和分类器,训练过程中EMA模型的特征提取器和分类器基于当前目标模型通过指数滑动平均进行更新。
在第40个epoch时用目标模型初始化EMA模型,特征提取器Ft固定,分类器Ct固定。
步骤3:根据EMA模型中全连接分类层的参数计算出一个相似度矩阵,对相似度矩阵的每一列进行softmax归一化获得类别级软标签,将其中一列作为一个类别图像的软标签;EMA的分类层的全连接层VK×C包括每个类别的权重,K表示输入向量的维度,C表示类别的总个数;这个信息可以用来近似表示类别之间的关系,具体地说,这些权重V可以通过S=VTV得到一个相似度矩阵S,S的每一行每一列都包含一个类别对所有类别的相似度信息。对S的每一列做softmax归一化,得到类别级软标签tCla。
实施例中,将初始化后的EMA模型的分类层的全连接层V512×200中的200个类别的权重V*提取出来,这些权重V可以通过S=VtV得到一个200×200的相似度矩阵S,S的每一行每一列都包含一个类别对所有类别的相似度信息,对S的每一列做softmax归一化,得到类别级软标签tCla。
步骤4:基于目标模型的分类器的预测,得到样本的输出软标签,计算样本真实硬标签与输出软标签的交叉熵损失。基于目标模型的分类器的预测,得到样本的输出软标签,基于真实标签获得类别软标签,计算类别级软标签与输出软标签的KL散度。基于EMA模型的分类器的预测,得到样本的输出软标签,计算类EMA模型得出的输出软标签与目标模型得出的输出软标签的KL散度。
所述步骤4包括以下步骤:
步骤4.1:因为目标模型对于类别的分类不一定准确,容易造成最终的分类准确率较低,为缓解这一问题,引入基于模型得到的预测和真实标签的交叉熵损失,见公式(1),来优化模型;
步骤4.2:虽然真实标签能帮助模型优化,但是由于真实标签是one-hot的数据,这就导致了其中包含的信息很少,对于细粒度图像无法获取足够的信息,靠真实标签难以让模型分类细粒度图像。为了让模型学习到更多的类别之间的关系,引入类别级软标签tCla,能让模型学习到更多的知识,最终提高模型对细粒度图像分类的准确率。基于类别级软标签tCla和样本经过目标模型得到的输出软标签算出一个KL散度,
LCla=KL(tCla,p) (3)
其中,KL表示KL散度,c对应每个类别,C表示类别总个数,tc表示第c个类别对应的类别级软标签,pc是对某一给定向量的softmax归一化操作后得到的结果,tCla表示所有类的类级别的软标签,p是经过归一化处理后模型预测的概率分布;
步骤4.3:虽然真实标签和类级别软标签能帮助模型优化,但是目标模型也可能会在训练的过程中往错误的方向训练。这就需要依靠EMA模型来将目标模型拉住,让其和之前的训练相差不大,向EMA模型靠近。基于样本经过EMA模型得到的输出软标签和样本经过目标模型得到的输出软标签算出一个KL散度,
LEMA=KL(tEMA,p) (4)
其中,KL表示KL散度,tEMA表示样本经过EMA模型得到的输出,p是样本经过目标模型得到的概率分布。
步骤5:在目标模型更新后,以一定权重和预设的更新频率规则进行EMA模型的更新,并根据新的EMA模型得到新的类级别软标签;
所述步骤5中,由于EMA是冻结的,不参与训练,所以需要根据目标模型的参数来更新EMA模型。以一定比例的EMA模型的参数加上剩下比例的目标模型参数,得到新的EMA的模型:
φ′←αφ′+(1-α)φ (5)
其中,φ′是EMA模型参数,φ是当前目标模型参数,α是对应的权重系数,α=0.95。
在更新了EMA模型后,根据新的EMA模型计算出新的类级别软标签。
步骤6:将上述损失相加得到总损失L,通过最小化所有损失函数共同训练目标模型,完成细粒度图像分类学习,使其具备较高的细粒度图像分类准确率。
所述步骤6中,完整目标函数L为,
L=LHard+λ1LCla+λ2LEMA (7)
其中,λ1和λ2为对应的超参数,λ1=1,λ2=1。
基于完整目标函数L优化目标模型;优化目标域模型采用SGD优化器训练,动量为0.9,权重衰减为5*10-4,批次大小为64,学习率为0.1;训练时,EMA模型在第40-150个epoch间每一个epoch更新一次,在第150个epoch后每三个epoch更新一次,超参λ1=1.0,λ2=1.0,epoch设为200。
基于本发明训练目标模型可以实现对细粒度图像的分类,减少细粒度图像对图像分类效果的影响;给定一张需要进行分类的图片,以训练后的目标模型利用其训练时学习的类别之间的相关性,输出图片对应的类别。
基于此方法,可以实现计算机介质及程序、设备的开发。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
尽管已描述了本发明的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本发明范围的所有变更和修改。
显然,本领域的技术人员可以对本发明进行各种改动和变型而不脱离本发明的精神和范围。这样,倘若本发明的这些修改和变型属于本发明权利要求及其等同技术的范围之内,则本发明也意图包含这些改动和变型在内。
Claims (10)
1.一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:所述方法以有标签的数据预训练一目标模型;以目标模型的参数初始化EMA模型,根据EMA模型中全连接层的参数计算一相似度矩阵,基于相似度矩阵获得类别级软标签tCla,与图像进行关联;
输入图像,基于目标模型和EMA模型构建模型训练的总损失更新目标模型;以新的目标模型更新EMA模型,并用新EMA模型计算出新的类别级软标签;重复并最小化总损失,实现细粒度图像分类模型的训练。
2.根据权利要求1所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:所述目标函数为,
gc=f(xc;φ)
其中,f表示目标模型,C表示类别总量,c对应每个类别,yc是样本xc的标签,pc是经过归一化处理后模型预测的到第c个类别的概率,T是温度系数,0<T<5,gc是样本xc经过目标模型f后得到的归一化之前的输出,φ是目标模型的参数。
3.根据权利要求1所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:所述目标模型包括特征提取器和分类器,目标模型的特征提取器和分类器在每一轮基于总损失的梯度反传进行更新;所述EMA模型包括特征提取器和分类器,EMA模型的特征提取器和分类器初始化为目标模型中的特征提取器和分类器,训练过程中EMA模型的特征提取器和分类器基于当前目标模型通过指数滑动平均进行更新。
4.根据权利要求1所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:EMA模型的分类器的全连接层VK×C包括每个类别的权重,K为维度,得到一个相似度矩阵S,
S=VTV,对S的每一列归一化处理,得到类别级软标签tcla。
5.根据权利要求1所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:输入图像,基于目标模型和EMA模型构建模型训练的总损失L包括:
基于目标模型预测的概率分布,计算其与样本真实硬标签之间的交叉熵损失LHard;
基于目标模型预测的概率分布,计算其与类别级软标签之间的KL散度LCla;
基于目标模型预测的概率分布,计算其与EMA模型输出概率分布之间的KL散度LEMA。
6.根据权利要求5所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:L=LHard+λ1LCla+λ2LEMA,其中,λ1和λ2为对应损失的权重系数,λ1>0,λ2>0。
7.根据权利要求5所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:KL散度LCla满足,
LCla=KL(tCla,p)
其中,KL散度为KL(t,p),C表示类别总量,c对应每个类别,tCla为类级别的软标签,p是经过归一化处理后模型预测的概率分布。
8.根据权利要求5所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:KL散度LeMA满足,
LeMA=KL(teMA,p)
其中,KL()为KL散度,tEMA为样本经过EMA模型得到的输出,p是样本经过目标模型得到的概率分布。
9.根据权利要求1所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:目标模型更新后,以一定权重和预设的更新频率规则进行EMA模型的更新,
φ′←αφ′+(1-α)φ
其中,φ′是EMA模型参数,φ是当前目标模型参数,α是对应的权重系数,α∈(0,1)。
10.根据权利要求1所述的一种基于类别级软目标监督的细粒度图像分类模型训练方法,其特征在于:给定一张需要进行分类的图片,以训练后的目标模型利用其训练时学习的类别之间的相关性,输出图片对应的类别。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310352190.2A CN116563602A (zh) | 2023-04-04 | 2023-04-04 | 基于类别级软目标监督的细粒度图像分类模型训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310352190.2A CN116563602A (zh) | 2023-04-04 | 2023-04-04 | 基于类别级软目标监督的细粒度图像分类模型训练方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116563602A true CN116563602A (zh) | 2023-08-08 |
Family
ID=87497278
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310352190.2A Pending CN116563602A (zh) | 2023-04-04 | 2023-04-04 | 基于类别级软目标监督的细粒度图像分类模型训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116563602A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117036829A (zh) * | 2023-10-07 | 2023-11-10 | 之江实验室 | 一种基于原型学习实现标签增强的叶片细粒度识别方法和系统 |
-
2023
- 2023-04-04 CN CN202310352190.2A patent/CN116563602A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117036829A (zh) * | 2023-10-07 | 2023-11-10 | 之江实验室 | 一种基于原型学习实现标签增强的叶片细粒度识别方法和系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Kan | Machine learning applications in cell image analysis | |
CN111814854B (zh) | 一种无监督域适应的目标重识别方法 | |
Angermueller et al. | Deep learning for computational biology | |
CN110490239B (zh) | 图像质控网络的训练方法、质量分类方法、装置及设备 | |
Myronenko et al. | Accounting for dependencies in deep learning based multiple instance learning for whole slide imaging | |
CN113469186B (zh) | 一种基于少量点标注的跨域迁移图像分割方法 | |
CN109447096B (zh) | 一种基于机器学习的扫视路径预测方法和装置 | |
CN109063743B (zh) | 基于半监督多任务学习的医疗数据分类模型的构建方法 | |
CN113821670B (zh) | 图像检索方法、装置、设备及计算机可读存储介质 | |
CN113761261A (zh) | 图像检索方法、装置、计算机可读介质及电子设备 | |
CN113657561A (zh) | 一种基于多任务解耦学习的半监督夜间图像分类方法 | |
Mougeot et al. | A deep learning approach for dog face verification and recognition | |
Dai et al. | Hybrid deep model for human behavior understanding on industrial internet of video things | |
CN111126464A (zh) | 一种基于无监督域对抗领域适应的图像分类方法 | |
CN114913923A (zh) | 针对单细胞染色质开放性测序数据的细胞类型识别方法 | |
CN111126155B (zh) | 一种基于语义约束生成对抗网络的行人再识别方法 | |
CN116563602A (zh) | 基于类别级软目标监督的细粒度图像分类模型训练方法 | |
CN113947133A (zh) | 小样本图像识别的任务重要性感知元学习方法 | |
CN109271957A (zh) | 人脸性别识别方法以及装置 | |
CN114819091B (zh) | 基于自适应任务权重的多任务网络模型训练方法及系统 | |
CN116306793A (zh) | 一种基于对比孪生网络的具有目标任务指向性的自监督学习方法 | |
CN114330514A (zh) | 一种基于深度特征与梯度信息的数据重建方法及系统 | |
CN112115996B (zh) | 图像数据的处理方法、装置、设备及存储介质 | |
CN116704208A (zh) | 基于特征关系的局部可解释方法 | |
CN114821248B (zh) | 面向点云理解的数据主动筛选标注方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |