CN115995018A - 基于样本感知蒸馏的长尾分布视觉分类方法 - Google Patents
基于样本感知蒸馏的长尾分布视觉分类方法 Download PDFInfo
- Publication number
- CN115995018A CN115995018A CN202211579446.5A CN202211579446A CN115995018A CN 115995018 A CN115995018 A CN 115995018A CN 202211579446 A CN202211579446 A CN 202211579446A CN 115995018 A CN115995018 A CN 115995018A
- Authority
- CN
- China
- Prior art keywords
- teacher
- feature
- student
- distillation
- loss
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000009826 distribution Methods 0.000 title claims abstract description 60
- 238000000034 method Methods 0.000 title claims abstract description 58
- 238000004821 distillation Methods 0.000 title claims abstract description 39
- 230000000007 visual effect Effects 0.000 title claims abstract description 16
- 230000008447 perception Effects 0.000 title claims abstract description 9
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 51
- 238000012549 training Methods 0.000 claims abstract description 50
- 238000007781 pre-processing Methods 0.000 claims abstract description 5
- 239000013598 vector Substances 0.000 claims description 43
- 230000006870 function Effects 0.000 claims description 18
- 239000011159 matrix material Substances 0.000 claims description 18
- 230000008569 process Effects 0.000 claims description 13
- 238000004364 calculation method Methods 0.000 claims description 11
- 238000005457 optimization Methods 0.000 claims description 9
- 238000005070 sampling Methods 0.000 claims description 7
- 230000004927 fusion Effects 0.000 claims description 6
- 230000006641 stabilisation Effects 0.000 claims description 4
- 238000011105 stabilization Methods 0.000 claims description 4
- 238000005259 measurement Methods 0.000 claims description 3
- 239000000758 substrate Substances 0.000 claims 1
- 238000001514 detection method Methods 0.000 abstract description 4
- 238000012545 processing Methods 0.000 abstract description 2
- 238000002474 experimental method Methods 0.000 description 7
- 238000012360 testing method Methods 0.000 description 6
- 230000002708 enhancing effect Effects 0.000 description 3
- 238000013508 migration Methods 0.000 description 3
- 230000005012 migration Effects 0.000 description 3
- 238000012935 Averaging Methods 0.000 description 2
- 238000012952 Resampling Methods 0.000 description 2
- 238000003745 diagnosis Methods 0.000 description 2
- 201000010099 disease Diseases 0.000 description 2
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 238000002679 ablation Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 230000036541 health Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000001568 sexual effect Effects 0.000 description 1
Images
Landscapes
- Image Analysis (AREA)
Abstract
基于样本感知蒸馏的长尾分布视觉分类方法,涉及图像处理、目标检测领域。建立长尾分布数据集,对输入图片预处理,训练教师网络模型和学生网络模型,计算交叉熵损失和特征中心稳定学习损失,利用选择性知识蒸馏模块计算选择性知识蒸馏损失;三种损失对学生网络模型优化训练。提出特征中心稳定学习模块:计算得到全局性类间特征中心,利用全局特征中心的类间关系将局部特征中心优化,丰富尾部类数据特征的丰富性和表达能力,利用优化得到的局部特征中心对样本分类;提出选择性知识蒸馏模块:根据教师模型与学生模型的知识正确性及置信度有侧重性地将教师模型的知识蒸馏给学生。可用于长尾图像分类、目标检测等。
Description
技术领域
本发明涉及图像处理、目标检测领域,尤其是涉及安全或健康的关键应用,例如自动驾驶和医疗/疾病诊断等数据本质上是严重失调的、存在长尾效应现实问题的一种基于样本感知蒸馏的长尾分布视觉分类方法。
背景技术
长尾分布问题因为其应用前景以及实际应用价值,近年来受到广泛的关注,并涌现出许多优秀的算法。这些算法大致分为三大类:基于重加权/重采样的长尾分布学习算法,基于解耦表示和分类器的长尾分布学习算法,基于知识迁移的长尾分布学习算法。相比其他两类算法,基于知识迁移的长尾分布学习算法在目前受到的关注度较高。该方法通过将头部的知识或者学习得足够充分后的教师知识迁移至学生模型达到在尾部类数据不充足的情况下也能取得较好的性能。这类方法通常也会与其他额外模块叠加使用,对尾部类的样本的特征表达能力进行增强。
知识蒸馏被广泛的用于模型压缩和迁移学习当中,其中自蒸馏和互蒸馏是知识蒸馏领域中两个很重要的分支。自蒸馏是一种模仿模型自身在不同训练阶段输出的一种学习策略,而互蒸馏是采用多个网络(2个或更多)同时训练,每个网络在训练过程中不仅接受来自真值标记的监督,还参考同伴网络的学习经验提升泛化能力。在整个过程中,两个网络之间不断分享学习经验,实现互相学习共同进步。
现有的长尾分布视觉分类方法通常利用对尾部类增加权重或重采样,但这类方法往往过多关注尾部类,重加权和重采样的方法泛化性能差。
发明内容
本发明的目的在于针对现有技术存在的重加权和重采样的方法泛化性能差等问题,提供可有效提升长尾分布问题的分类性能的一种基于样本感知蒸馏的长尾分布视觉分类方法,利用深度学习模型提升尾部类样本的特征表达能力,有侧重性地进行知识蒸馏提升长尾分布图像的分类性能。
本发明包括以下步骤:
1)建立长尾分布数据集,对数据集采样作为输入图片并预处理;
2)将步骤1)预处理后的图像输入教师网络,训练教师网络模型后,将训练集所有样本输预训练教师模型中,得到特征向量的均值求全局类别特征中心,取出预测置信度中最高置信度的标签,得到教师预测结果;
3)将步骤1)预处理后的图像输入学生网络中训练学生网络模型,样本经特征编码器得到特征向量,特征向量进入两个分支,分别计算交叉熵损失和特征中心稳定学习损失;取出预测置信度中最高置信度所对应的类别标签,得到学生预测结果;
4)学生预测结果、教师预测结果及真实标签三者利用选择性知识蒸馏模块计算得到选择性知识蒸馏损失;
5)联合交叉熵损失、特征中心稳定学习损失和选择性知识蒸馏损失三种损失对学生网络模型优化训练。
在步骤1)中,所述长尾分布数据集其中N,C分别表示图像样本总数和总类别数;对该数据集而言,不同类别对应的样本数量是不均衡的,对第c个类别,若对应的样本数量为nc,则nmin<...<nc<...<nmax;每个批次随机选取设定的batch_size大小的图片数量作为网络的输入;
所述预处理包括对输入图片进行归一化,随机裁剪至固定大小(p*p),随机翻转进行数据增强。
在步骤2)中,所述训练教师网络模型的具体步骤可为:将预处理后的图像输入教师网络(Teacher Network),图像经过网络的特征编码器得到64维度的特征向量ft和网络逻辑预测输出zt,该逻辑预测输出zt经过功能函数归一化后得到预测置信度pt,该置信度pt与真实标签y构成交叉熵损失LCE对模型进行约束;
训练结束后,将训练集所有样本输入到预训练教师模型中,得到所有样本的特征向量f′t、逻辑预测输出z′t及预测置信度p′t;利用特征向量f′t对每个类别求特征向量的均值得到全局类别特征中心Qg;取出预测置信度p′t中最高置信度的标签,得到教师预测结果yt。
在步骤3)中,所述训练学生网络模型的具体步骤可为:
将图像输入到学生网络(Student Network)中,其中学生网络的模型与教师网络模型一致,且学生网络与教师网络不共享参数;每个批次(batch)的图像经过学生网络的特征编码器可以得到64维度的特征向量fs,该特征向量fs进入两个分支;第一个分支进一步得到网络逻辑预测输出zs,并经过功能函数归一化后得到预测置信度ps,该预测置信度ps和真实标签y计算得到交叉熵损失LCE;第二个分支通过对该特征向量求类别特征均值得到局部类别特征中心Ql,并和全局类别特征中心Qg一起输入到特征中心稳定学习模块,计算得到特征中心稳定学习损失LSFCL;取出预测置信度ps中最高置信度所对应的类别标签,得到学生预测结果ys。
在步骤3)中,所述特征中心稳定学习模块用于提升尾部类样本的特征表达能力,具体步骤如下:
(3.1)对于训练样本,训练集图像在教师网络输出的特征向量ft,其经过逻辑预测得到zt,如式(1)所示;zt经过softmax函数,得到预测置信度pt,如式(2)所示;预测结果yt,是所有预测置信度的最大值所对应的类别标号,如式(3)所示;将一个batch的图像集输入学生网络对应的输出为特征向量fs,逻辑预测输出zs以及预测结果ys,计算过程如下:
zt=logits(ft),zs=logits(fs) (1)
pt=softmax(zt),ps=softmax(zs) (2)
yt=argmax(pt),ys=argmax(ps) (3)
其中,argmax函数表示取出预测置信度中最高置信度对应的索引,即类别标签;
(3.2)利用特征中心稳定学习模块用于提升尾部类样本的特征表达能力;主要包括以下步骤:
ⅰ.利用教师模型得到的全局类别特征中心Qg作为特征中心稳定学习模块的输入,并利用Qg计算得到表示类间关系的亲和度矩阵进一步利用功能函数softmax对亲和度矩阵A在去掉对角线后归一化得到归一化亲和度矩阵
ⅲ.利用归一化亲和度矩阵对当前批次的局部类别特征中心Ql加权优化,即使用移动加权平均(EMA)的方法,利用优化后的局部类别特征中心和全局类别特征中心Qg在每个批次中更新当前批次的局部类别特征中心,得到优化后的局部类别特征中心Qr,计算过程公式如下:
ⅳ.利用产生的Qr拉近每个样本的特征与其对应的特征中心的距离,使其向其对应的特征中心靠齐,采用曼哈顿距离(L1-norm,距离的绝对值之和)作为距离度量方式,得到特征中心稳定学习模块的损失函数,如下式:
LSFCL=||θ(σ(Qr),σ(fbatch),y)||1 (6)
其中,σ(a)=a/||a||表示的是欧式距离(L2-norm,距离平方和),θ(b,c)表示的是b和c之间的余弦相似性,fbatch表示当前batch的样本特征;
在步骤4)中,所述选择性知识蒸馏模块用于提升长尾分布问题的分类性能,具体步骤如下:
(4.1)知识蒸馏通常使用KL散度(Kullback-Leibler divergence)衡量学生模型和教师模型对相同批次样本预测后得到的分布的差异;Kullback-Leibler divergence(KL)的具体计算公式如下:
其中,zs,zt分别为学生模型、教师模型的预测分布;
(4.2)选择性知识蒸馏模块在知识蒸馏的基础上对知识进行挑选;每个批次随机采样batch_size个样本输入到两个模型,得到学生预测结果ys和教师预测结果yt,学生预测结果ys、教师预测结果yt和真实标签y三者作为选择性知识蒸馏模块的输入并计算选择性知识蒸馏损失;该模块中的蒸馏权重由三种情况组成,公式为分别是:
ⅰ.教师预测结果与真实标签不同(教师预测错误),蒸馏权重置0;
ⅱ.教师预测结果与学生预测结果相同,蒸馏权重置为pr,p为1-所预测类别置信度,r取2;
ⅲ.教师预测结果与真实标签一致(教师预测正确),学生预测结果与真实标签不一致,蒸馏权重置1;
总的选择性知识蒸馏损失计算如下式:
通过优化选择性知识蒸馏损失,可以有选择性地有效减小两个模型间的分布差异,此外,学生模型还有效剔除教师模型错误的知识,提升该学生模型对长尾分布问题的分类性能;
总的损失函数如下式:
Ltotal=LCE+LSKD+α·LSFCL (9)
其中,α为损失的平衡系数;重复上述步骤迭代设定的次数,直至训练结束。
在步骤5)中,所述三种损失进行网络优化训练,对于学生网络模型的整个训练过程中,联合交叉熵损失LCE、特征中心稳定学习损失LSFCL和选择性知识蒸馏损失LSKD三种损失进行网络优化训练,交叉熵损失LCE用于对模型进行约束,特征中心稳定学习损失LSFCL用于帮助增强尾部类样本的特征表达能力,选择性知识蒸馏损失LSKD用于帮助将知识有侧重性地蒸馏给学生网络,以提高长尾分布问题的分类性能;经过训练,学生网络模型在尾部类样本上的特征表达能力有所提升,并学习到教师网络模型正确的知识。
本发明根据互蒸馏的核心思想,通过学习模仿教师模型的预测分布,在充分学习教师分布的同时,有选择地剔除教师模型错误的知识,使得学生模型学习到的知识更可靠更准确。利用互蒸馏的思想,同一批次的样本经过两个不同的模型,通过有选择地互相拟合彼此的特征分布,使得二者分布差异尽量拟合的情况下,有效保留最正确的信息。
与现有技术相比,本发明具有以下突出优点:
1.本发明首先提出基于样本感知蒸馏的长尾分布视觉分类方法。考虑到长尾分布问题存在的数据严重不平衡问题,首先通过特征中心稳定学习模块提升尾部类样本的特征表达能力,使得尾部类样本具有更泛化的表达能力;接着利用选择性知识蒸馏模块根据教师模型与学生模型的知识正确性及置信度有侧重性地将教师模型的知识蒸馏给学生,在有效减小两个模型间的分布差异之外还剔除教师模型错误的知识,进一步保证分类结果的准确性。
2.巧妙使用数据增强的思想,通过利用全局的类间关系(类间亲和度矩阵)对batch中的样本特征进行特征融合,达到对尾部类数据增强的效果。这个操作是与常规数据增强不同,不需要提前对数据集操作,而是在训练过程中自发完成。
3.巧妙利用选择性知识蒸馏,使得来自两个模型的所有样本在互相学习彼此的特征分布的基础上,有侧重性地选择教师模型中较为准确的知识。
附图说明
图1是本发明的基于样本感知蒸馏的长尾分布视觉分类方法框架。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下实施例将结合附图对本发明进行作进一步的说明。应当理解,此处所描述的具体实施例仅仅用于解释本发明,并不用于限定本发明。相反,本发明涵盖任何由权利要求定义的在本发明的精髓和范围上做的替代、修改、等效方法以及方案。进一步,为了使公众对本发明有更好的了解,以下对本发明的细节描述中,详尽描述了一些特定的细节部分。对本领域技术人员来说没有这些细节部分的描述也可以完全理解本发明。
本发明实施例包括以下步骤:
(1)长尾分布数据集其中N,C分别表示图像样本总数和总类别数。对该数据集而言,不同类别对应的样本数量是不均衡的,对第c个类别,假设对应的样本数量为nc,那么nmin<...<nc<...<nmax。每个批次随机选取设定的batch_size大小的图片数量作为网络的输入;
(2)对输入图片归一化,并随机裁剪至固定大小(p*p),并随机翻转进行数据增强;
(3)预先训练教师网络模型:将图像输入到教师网络(Teacher Network),图像经过网络的特征编码器可以得到64维度的特征向量ft,进一步可以得到网络逻辑预测输出zt,该逻辑预测输出zt经过功能函数归一化后得到预测置信度pt,该置信度pt与真实标签y构成交叉熵损失LCE对模型进行约束。训练结束后,将训练集所有样本输入到预训练教师模型中,得到所有样本的特征向量f′t、逻辑预测输出z′t及预测置信度p′t。利用特征向量f′t对每个类别求特征向量的均值得到全局类别特征中心Qg。取出预测置信度p′t中最高置信度的标签,得到教师预测结果yt。
(4)训练学生网络模型:将图像输入到学生网络(Student Network)中,其中学生网络的模型与教师网络模型一致,且学生网络与教师网络不共享参数。每个批次(batch)的图像经过学生网络的特征编码器可以得到64维度的特征向量fs,该特征向量fs进入两个分支。第一个分支进一步得到网络逻辑预测输出zs,并经过功能函数归一化后得到预测置信度ps,该预测置信度ps和真实标签y计算得到交叉熵损失LCE。第二个分支通过对该特征向量求类别特征均值得到局部类别特征中心Ql,并和全局类别特征中心Qg一起输入到特征中心稳定学习模块,计算得到特征中心稳定学习损失LSFCL。取出预测置信度ps中最高置信度所对应的类别标签,得到学生预测结果ys。学生预测结果ys、教师预测结果yt及真实标签y三者利用选择性知识蒸馏模块计算得到选择性知识蒸馏损失LSKD。对于学生网络模型的整个训练过程中,联合交叉熵损失LCE、特征中心稳定学习损失LSFCL和选择性知识蒸馏损失LSKD三种损失进行网络优化训练。特征中心稳定学习损失LSFCL帮助增强尾部类样本的特征表达能力,选择性知识蒸馏损失LSKD帮助将知识有侧重性地蒸馏给学生网络,从而提高长尾分布问题的分类性能。
(4.1)对于训练样本,经过要求1中的步骤(3),训练集图像在教师网络输出的特征向量ft,其经过逻辑预测得到zt,如式(1)所示。zt经过softmax函数,得到预测置信度pt,如式(2)所示。预测结果yt,是所有预测置信度的最大值所对应的类别标号,如式(3)所示。经过要求1中的步骤(4),将一个batch的图像集输入学生网络对应的输出为特征向量fs,逻辑预测输出zs以及预测结果ys,计算过程如下式(1)~(3):
zt=logits(ft),zs=logits(fs) (1)
pt=softmax(zt),ps=softmax(zs) (2)
yt=argmax(pt),ys=argmax(ps) (3)
其中,argmax函数表示取出预测置信度中最高置信度对应的索引,即类别标签。
(4.2)利用特征中心稳定学习模块用于提升尾部类样本的特征表达能力。主要包括以下步骤:
ⅰ.利用教师模型得到的全局类别特征中心Qg作为特征中心稳定学习模块的输入,并利用Qg计算得到表示类间关系的亲和度矩阵进一步利用功能函数softmax对亲和度矩阵A在去掉对角线后归一化得到归一化亲和度矩阵
ⅲ.利用归一化亲和度矩阵对当前批次的局部类别特征中心Ql加权优化,即并使用移动加权平均(EMA)的方法,利用优化后的局部类别特征中心和全局类别特征中心Qg在每个批次中更新当前批次的局部类别特征中心,得到优化后的局部类别特征中心Qr,计算过程公式如式(4)和(5):
ⅳ.利用产生的Qr拉近每个样本的特征与其对应的特征中心的距离,使其向其对应的特征中心靠齐,在此采用曼哈顿距离(L1-norm,距离的绝对值之和)作为距离度量方式,最终得到特征中心稳定学习模块的损失函数,如式(6):
LSFCL=||θ(σ(Qr),σ(fbatch),y||1 (6)
其中σ(a)=a/||a||表示的是欧式距离(L2-norm,距离平方和),θ(b,c)表示的是b和c之间的余弦相似性,fbatch表示当前batch的样本特征。
所述利用选择性知识蒸馏模块进一步提升长尾分布问题的分类性能:
(4.3)知识蒸馏通常使用KL散度(Kullback-Leibler divergence)衡量学生模型和教师模型对相同批次样本预测后得到的分布的差异。Kullback-Leibler divergence(KL)的具体计算公式如式(7):
其中zs,zt分别为学生模型、教师模型的预测分布。
(4.4)选择性知识蒸馏模块在知识蒸馏的基础上对知识进行挑选。每个批次随机采样batch_size个样本输入到两个模型,得到学生预测结果ys和教师预测结果yt。学生预测结果ys、教师预测结果yt和真实标签y三者作为选择性知识蒸馏模块的输入并计算选择性知识蒸馏损失。该模块中的蒸馏权重由三种情况组成,公式为分别是:
ⅰ.教师预测结果与真实标签不同(教师预测错误),蒸馏权重置0;
ⅱ.教师预测结果与学生预测结果相同,蒸馏权重置为pr,p为1-所预测类别置信度,r取2;
ⅲ.教师预测结果与真实标签一致(教师预测正确),学生预测结果与真实标签不一致,蒸馏权重置1。
总的选择性知识蒸馏损失计算如式(8):
通过优化选择性知识蒸馏损失,可以有选择性地有效减小两个模型间的分布差异,此外,学生模型还有效剔除教师模型错误的知识,因此进一步提升该学生模型对长尾分布问题的分类性能。
总的损失函数如式(9)计算,其中α为损失的平衡系数。重复上述步骤迭代设定的次数,直至训练结束。
Ltotal=LCE+LSKD+α·LSFCL (9)
(5)经过上述的训练后,学生模型在尾部类样本上的特征表达能力有所提升,并学习到教师模型正确的知识。在测试阶段,利用该学生模型对测试数据集进行类别预测,计算样本的分类情况;
(6)根据分类情况及分类的评价指标算得Top-K(K=1)、每个类别的分类精度及整体的平均精度(mAP)。
本发明设计特征中心稳定学习模块:首先计算得到全局性类间特征中心,其次利用全局特征中心的类间关系将局部特征中心优化(refine),从而丰富尾部类数据特征的丰富性和表达能力,最后利用优化得到的局部特征中心对样本分类;本发明提出选择性知识蒸馏模块:根据教师模型与学生模型的知识正确性及置信度有侧重性地将教师模型的知识蒸馏给学生。本发明可用于长尾图像分类(例如,医疗/疾病诊断),目标检测(例如,自动驾驶)等。
参见图1,本发明的框架为:
步骤1,获取模型的输入图像。
对数据集采样,每个批次随机选取设定的batch_size大小的图片数量作为网络的输入。
对输入图片归一化,CIFAR数据集随机裁剪至指定大小(32*32),ImageNet数据集随机裁剪至指定大小(224*224),并随机裁剪至固定大小(p*p),并随机翻转进行数据增强;
步骤2,得到学生模型和教师模型的特征向量(ys,ft)、逻辑预测输出(zs,zt)、预测置信度(ps,pt)以及预测结果(ys,yt)
(2a)将步骤1数据增强后的图像输入教师网络(Teacher Network)中,训练过程只使用交叉熵损失LCE对模型约束。
(2b)将步骤1数据增强后的图像输入学生网络(Student Network)中,对于每一张图像,得到对应的64维度的特征向量(fs),进一步得到逻辑预测输出(zs)并得到学生预测结果(ys)。在训练学生网络的过程,利用训练好的教师模型在参数不回传的情况下计算得到教师模型对应的特征向量(ft)、逻辑预测输出(zt)和预测结果(yt)。
步骤3,利用三个损失项对学生网络模型优化训练。
(3a)使用常用的交叉熵损失LCE、特征中心稳定学习损失LSFCL和选择性知识蒸馏损失LSKD的总合对模型进行优化训练;
(3b)特征中心稳定学习模块主要包括以下步骤:ⅰ.利用教师模型迭代完整数据集计算得到的特征向量(ft)对每个类别求特征向量的均值得到全局性类别特征中心Qg,并利用Qg计算得到表示类间关系的亲和度矩阵A;ⅱ.在当前batch中,利用学生模型计算当前batch的特征向量并对每个类别求特征向量的均值得到局部(batch)类别特征中心Ql;ⅲ.利用亲和度矩阵A首先对当前batch中的样本进行特征融合,从而提升尾部类样本的特征表达能力,其次对局部类别特征中心Ql进行加权优化(refine),从而更新局部类别特征中心;ⅳ.在当前batch中,拉近每个样本与其对应的特征中心的距离。
(3c)选择性知识蒸馏模块主要利用挑选知识进行蒸馏的思想:学生模型在每个batch中,根据自身预测的标签(ys)、教师预测的标签(yt)及真实标签(y)三者的关系,侧重性地蒸馏教师模型正确的知识,过滤教师模型错误的知识,从而提高学生模型知识的正确性和可靠性。
实验结果以及结果分析:
实验一,用本发明在CIFAR-10/CIFAR-100两个数据集上图像分类。
为验证算法的有效性,在CIFAR-10/CIFAR-100的测试集上,消融实验,表1为实验结果。其中,‘CE’表示交叉熵损失LCE,‘SKD’表示互蒸馏损失LSKD,‘SFCL’表示互蒸馏损失LSFCL‘√’和‘×’表示分别表示使用和不使用对应项的损失,‘CIFAR-10-Top-1’,‘CIFAR-100-Top-1’分别表示两个CIFAR数据集在不平衡因子为100的情况下的平均精度。实验结果表明,本发明所提出的两种损失分别对长尾分布问题分类任务上,都有较大程度的性能提升,验证本发明方法的有效性。
实验二,用本发明在ImageNet2012-LT数据集上图像分类。
为验证算法的有效性,在ImageNet2012-LT的数据集上测试。表2为实验结果,从结果可以发现,本发明提出的基于样本感知蒸馏的长尾分布视觉分类方法在ImageNet2012-LT数据集上同样获得卓越的性能提升。
结合实验一和实验二,本发明在现有的三个长尾分布数据集上都有显著的性能优势,超越当前学术领域的最高水平,验证本发明提出的方法有效提高尾部类样本的特征表达能力并成功选择性蒸馏教师模型的有效知识。
表1.本发明在CIFAR-10/CIFAR-100测试集上的消融实验
表2.本发明在ImageNet2012-LT数据集上的测试结果
ImageNet-Top-1 |
42.81 |
本发明根据互蒸馏的核心思想,通过学习模仿教师模型的预测分布,在充分学习教师分布的同时,有选择地剔除教师模型错误的知识,使得学生模型学习到的知识更可靠更准确。利用互蒸馏的思想,同一批次的样本经过两个不同的模型,通过有选择地互相拟合彼此的特征分布,使得二者分布差异尽量拟合的情况下,有效保留最正确的信息。
Claims (8)
1.基于样本感知蒸馏的长尾分布视觉分类方法,其特征在于包括以下步骤:
1)建立长尾分布数据集,对数据集采样作为输入图片并预处理;
2)将预处理后的图像输入教师网络,训练教师网络模型后,将训练集所有样本输预训练教师模型中,得到特征向量的均值求全局类别特征中心,取出预测置信度中最高置信度的标签,得到教师预测结果;
3)将预处理后的图像输入学生网络中训练学生网络模型,样本经特征编码器得到特征向量,特征向量进入两个分支,分别计算交叉熵损失和特征中心稳定学习损失;取出预测置信度中最高置信度所对应的类别标签,得到学生预测结果;
4)学生预测结果、教师预测结果及真实标签三者利用选择性知识蒸馏模块计算得到选择性知识蒸馏损失;
5)联合交叉熵损失、特征中心稳定学习损失和选择性知识蒸馏损失三种损失对学生网络模型优化训练。
3.如权利要求1所述基于样本感知蒸馏的长尾分布视觉分类方法,其特征在于在步骤1)中,所述预处理包括对输入图片归一化,随机裁剪至固定大小,随机翻转进行数据增强。
5.如权利要求1所述基于样本感知蒸馏的长尾分布视觉分类方法,其特征在于在步骤3)中,所述训练学生网络模型的具体步骤为:
将图像输入到学生网络中,其中学生网络的模型与教师网络模型一致,且学生网络与教师网络不共享参数;每个批次的图像经过学生网络的特征编码器得到64维度的特征向量fs,该特征向量fs进入两个分支;第一个分支进一步得到网络逻辑预测输出zs,并经过功能函数归一化后得到预测置信度ps,该预测置信度ps和真实标签y计算得到交叉熵损失LCE;第二个分支通过对该特征向量求类别特征均值得到局部类别特征中心Ql,并和全局类别特征中心Qg一起输入到特征中心稳定学习模块,计算得到特征中心稳定学习损失LSFCL;取出预测置信度ps中最高置信度所对应的类别标签,得到学生预测结果ys;学生预测结果ys、教师预测结果yt及真实标签y三者利用选择性知识蒸馏模块计算得到选择性知识蒸馏损失LSKD。
6.如权利要求5所述基于样本感知蒸馏的长尾分布视觉分类方法,其特征在于在步骤3)中,所述特征中心稳定学习模块用于提升尾部类样本的特征表达能力,具体步骤如下:
(3.1)对于训练样本,训练集图像在教师网络输出的特征向量ft,其经过逻辑预测得到zt,如式(1)所示;zt经过softmax函数,得到预测置信度pt,如式(2)所示;预测结果yt,是所有预测置信度的最大值所对应的类别标号,如式(3)所示;将一个batch的图像集输入学生网络对应的输出为特征向量fs,逻辑预测输出zs以及预测结果ys,计算过程如下:
zt=logits(ft),zs=logits(fs) (1)
pt=softmax(zt),ps=softmax(zs) (2)
yt=argmax(pt),ys=argmax(ps) (3)
其中,argmax函数表示取出预测置信度中最高置信度对应的索引,即类别标签;
(3.2)利用特征中心稳定学习模块用于提升尾部类样本的特征表达能力;包括以下步骤:
i.利用教师模型得到的全局类别特征中心Qg作为特征中心稳定学习模块的输入,并利用Qg计算得到表示类间关系的亲和度矩阵进一步利用功能函数softmax对亲和度矩阵A在去掉对角线后归一化得到归一化亲和度矩阵
iii.利用归一化亲和度矩阵对当前批次的局部类别特征中心Ql加权优化,即使用移动加权平均EMA的方法,利用优化后的局部类别特征中心和全局类别特征中心Qg在每个批次中更新当前批次的局部类别特征中心,得到优化后的局部类别特征中心Qr,计算过程公式如下:
iv.利用产生的Qr拉近每个样本的特征与其对应的特征中心的距离,使其向其对应的特征中心靠齐,采用曼哈顿距离作为距离度量方式,L1-norm,距离的绝对值之和,得到特征中心稳定学习模块的损失函数,如下式:
LSFCL=||θ(σ(Qr),σ(fbatch),y)||1 (6)
其中,σ(a)=a/||a||表示的是欧式距离,L2-norm,距离平方和,θ(b,c)表示的是b和c之间的余弦相似性,fbatch表示当前batch的样本特征。
7.如权利要求1所述基于样本感知蒸馏的长尾分布视觉分类方法,其特征在于在步骤4)中,所述选择性知识蒸馏模块用于提升长尾分布问题的分类性能,具体步骤如下:
(4.1)知识蒸馏通常使用KL散度衡量学生模型和教师模型对相同批次样本预测后得到的分布的差异;KL的具体计算公式如下:
其中,zs,zt分别为学生模型、教师模型的预测分布;
(4.2)选择性知识蒸馏模块在知识蒸馏的基础上对知识进行挑选;每个批次随机采样batch_size个样本输入到两个模型,得到学生预测结果ys和教师预测结果yt,学生预测结果ys、教师预测结果yt和真实标签y三者作为选择性知识蒸馏模块的输入并计算选择性知识蒸馏损失;该模块中的蒸馏权重由三种情况组成,公式为分别是:
i.教师预测结果与真实标签不同,教师预测错误,蒸馏权重置0;
ii.教师预测结果与学生预测结果相同,蒸馏权重置为pr,p为1-所预测类别置信度,r取2;
iii.教师预测结果与真实标签一致,教师预测正确,学生预测结果与真实标签不一致,蒸馏权重置1;
总的选择性知识蒸馏损失计算如下式:
通过优化选择性知识蒸馏损失,有选择性地减小两个模型间的分布差异,学生模型有效剔除教师模型错误的知识,提升该学生模型对长尾分布问题的分类性能;
总的损失函数如下式:
Ltotal=LCE+LSKD+α·LSFCL (9)
其中,α为损失的平衡系数;重复迭代设定的次数,直至训练结束。
8.如权利要求1所述基于样本感知蒸馏的长尾分布视觉分类方法,其特征在于在步骤5)中,三种损失网络优化训练,对于学生网络模型的整个训练过程中,联合交叉熵损失LCE、特征中心稳定学习损失LSFCL和选择性知识蒸馏损失LSKD三种损失进行网络优化训练,交叉熵损失LCE用于对模型约束,特征中心稳定学习损失LSFCL用于帮助增强尾部类样本的特征表达能力,选择性知识蒸馏损失LSKD用于帮助将知识有侧重性地蒸馏给学生网络,以提高长尾分布问题的分类性能;经过训练,学生网络模型在尾部类样本上的特征表达能力有所提升,并学习到教师网络模型正确的知识。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211579446.5A CN115995018A (zh) | 2022-12-09 | 2022-12-09 | 基于样本感知蒸馏的长尾分布视觉分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211579446.5A CN115995018A (zh) | 2022-12-09 | 2022-12-09 | 基于样本感知蒸馏的长尾分布视觉分类方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115995018A true CN115995018A (zh) | 2023-04-21 |
Family
ID=85994666
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211579446.5A Pending CN115995018A (zh) | 2022-12-09 | 2022-12-09 | 基于样本感知蒸馏的长尾分布视觉分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115995018A (zh) |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116205290A (zh) * | 2023-05-06 | 2023-06-02 | 之江实验室 | 一种基于中间特征知识融合的知识蒸馏方法和装置 |
CN116415005A (zh) * | 2023-06-12 | 2023-07-11 | 中南大学 | 一种面向学者学术网络构建的关系抽取方法 |
CN116502621A (zh) * | 2023-06-26 | 2023-07-28 | 北京航空航天大学 | 一种基于自适应对比知识蒸馏的网络压缩方法和装置 |
CN117333757A (zh) * | 2023-11-16 | 2024-01-02 | 中国科学院空天信息创新研究院 | 图像处理方法、装置、设备及存储介质 |
CN117372785A (zh) * | 2023-12-04 | 2024-01-09 | 吉林大学 | 一种基于特征簇中心压缩的图像分类方法 |
CN117474037A (zh) * | 2023-12-25 | 2024-01-30 | 深圳须弥云图空间科技有限公司 | 基于空间距离对齐的知识蒸馏方法及装置 |
CN117892841A (zh) * | 2024-03-14 | 2024-04-16 | 青岛理工大学 | 基于渐进式联想学习的自蒸馏方法及系统 |
-
2022
- 2022-12-09 CN CN202211579446.5A patent/CN115995018A/zh active Pending
Cited By (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116205290A (zh) * | 2023-05-06 | 2023-06-02 | 之江实验室 | 一种基于中间特征知识融合的知识蒸馏方法和装置 |
CN116205290B (zh) * | 2023-05-06 | 2023-09-15 | 之江实验室 | 一种基于中间特征知识融合的知识蒸馏方法和装置 |
CN116415005A (zh) * | 2023-06-12 | 2023-07-11 | 中南大学 | 一种面向学者学术网络构建的关系抽取方法 |
CN116415005B (zh) * | 2023-06-12 | 2023-08-18 | 中南大学 | 一种面向学者学术网络构建的关系抽取方法 |
CN116502621A (zh) * | 2023-06-26 | 2023-07-28 | 北京航空航天大学 | 一种基于自适应对比知识蒸馏的网络压缩方法和装置 |
CN116502621B (zh) * | 2023-06-26 | 2023-10-17 | 北京航空航天大学 | 一种基于自适应对比知识蒸馏的网络压缩方法和装置 |
CN117333757A (zh) * | 2023-11-16 | 2024-01-02 | 中国科学院空天信息创新研究院 | 图像处理方法、装置、设备及存储介质 |
CN117372785A (zh) * | 2023-12-04 | 2024-01-09 | 吉林大学 | 一种基于特征簇中心压缩的图像分类方法 |
CN117372785B (zh) * | 2023-12-04 | 2024-03-26 | 吉林大学 | 一种基于特征簇中心压缩的图像分类方法 |
CN117474037A (zh) * | 2023-12-25 | 2024-01-30 | 深圳须弥云图空间科技有限公司 | 基于空间距离对齐的知识蒸馏方法及装置 |
CN117474037B (zh) * | 2023-12-25 | 2024-05-10 | 深圳须弥云图空间科技有限公司 | 基于空间距离对齐的知识蒸馏方法及装置 |
CN117892841A (zh) * | 2024-03-14 | 2024-04-16 | 青岛理工大学 | 基于渐进式联想学习的自蒸馏方法及系统 |
CN117892841B (zh) * | 2024-03-14 | 2024-05-31 | 青岛理工大学 | 基于渐进式联想学习的自蒸馏方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN115995018A (zh) | 基于样本感知蒸馏的长尾分布视觉分类方法 | |
CN109948425B (zh) | 一种结构感知自注意和在线实例聚合匹配的行人搜索方法及装置 | |
WO2019015246A1 (zh) | 图像特征获取 | |
CN112507901B (zh) | 一种基于伪标签自纠正的无监督行人重识别方法 | |
CN110503000B (zh) | 一种基于人脸识别技术的教学抬头率测量方法 | |
CN111325115A (zh) | 带有三重约束损失的对抗跨模态行人重识别方法和系统 | |
CN114298122B (zh) | 数据分类方法、装置、设备、存储介质及计算机程序产品 | |
CN111652293A (zh) | 一种多任务联合判别学习的车辆重识别方法 | |
CN114241273A (zh) | 基于Transformer网络和超球空间学习的多模态图像处理方法及系统 | |
CN110348516B (zh) | 数据处理方法、装置、存储介质及电子设备 | |
CN115830531A (zh) | 一种基于残差多通道注意力多特征融合的行人重识别方法 | |
CN116543269B (zh) | 基于自监督的跨域小样本细粒度图像识别方法及其模型 | |
Zeng et al. | Geo-localization via ground-to-satellite cross-view image retrieval | |
Kordopatis-Zilos et al. | Geotagging social media content with a refined language modelling approach | |
CN114882267A (zh) | 一种基于相关区域的小样本图像分类方法及系统 | |
CN115546553A (zh) | 一种基于动态特征抽取和属性修正的零样本分类方法 | |
CN115761408A (zh) | 一种基于知识蒸馏的联邦域适应方法及系统 | |
CN115170874A (zh) | 一种基于解耦蒸馏损失的自蒸馏实现方法 | |
CN114972506A (zh) | 一种基于深度学习和街景图像的图像定位方法 | |
CN113361928A (zh) | 一种基于异构图注意力网络的众包任务推荐方法 | |
CN110472088A (zh) | 一种基于草图的图像检索方法 | |
CN114723998A (zh) | 基于大边界贝叶斯原型学习的小样本图像分类方法及装置 | |
CN114708637A (zh) | 一种基于元学习的人脸动作单元检测方法 | |
Biswas et al. | Large scale image clustering with active pairwise constraints | |
CN109871835B (zh) | 一种基于互斥正则化技术的人脸识别方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |