CN113837238A - 一种基于自监督和自蒸馏的长尾图像识别方法 - Google Patents

一种基于自监督和自蒸馏的长尾图像识别方法 Download PDF

Info

Publication number
CN113837238A
CN113837238A CN202111026141.7A CN202111026141A CN113837238A CN 113837238 A CN113837238 A CN 113837238A CN 202111026141 A CN202111026141 A CN 202111026141A CN 113837238 A CN113837238 A CN 113837238A
Authority
CN
China
Prior art keywords
self
network
training
stage
supervision
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202111026141.7A
Other languages
English (en)
Other versions
CN113837238B (zh
Inventor
王利民
李天昊
武港山
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing University
Original Assignee
Nanjing University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing University filed Critical Nanjing University
Priority to CN202111026141.7A priority Critical patent/CN113837238B/zh
Publication of CN113837238A publication Critical patent/CN113837238A/zh
Application granted granted Critical
Publication of CN113837238B publication Critical patent/CN113837238B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/243Classification techniques relating to the number of classes
    • G06F18/2431Multiple classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computational Linguistics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)
  • Image Processing (AREA)

Abstract

一种基于自监督和自蒸馏的长尾图像识别方法,构建多阶段的训练框架训练特征提取网络,第一阶段在长尾分布采样下利用自监督训练特征提取网络,第二阶段在保留第一阶段特征提取网络权重的情况下,在类别平衡采样下微调特征提取网络的分类器,生成用于自蒸馏的软标签,第三阶段丢弃之前的权重,在长尾分布采用下利用软标签作为监督对特征提取网络进行自蒸馏联合训练,得到的特征提取网络用于长尾分布下的图像识别分类。本发明针对长尾数据的特征提取网络提出一种利用自监督和自蒸馏的多阶段训练方法,利用自监督方法对尾部类别得到充分的表征,同时利用自蒸馏的方法将头部类别的知识有效迁移到尾部类别中。

Description

一种基于自监督和自蒸馏的长尾图像识别方法
技术领域
本发明属于计算机软件技术领域,涉及图像分类技术,具体为一种基于自监督和自蒸馏的长尾图像识别方法。
背景技术
最近,通过在大规模类别平衡和经过细致挑选标注的数据集,例如ImageNet和Kinetics上训练强大的神经网络,深度学习在图像和视频领域的视觉识别方面取得了显着进展。与这些人为平衡的数据集不同,现实世界的数据总是遵循长尾分布,这使得收集平衡数据集更具挑战性,对于天然样本数量比较少的类别,收集大量的训练样本的成本非常高,几乎难以实现。然而,由于数据分布极不平衡,直接从长尾数据中学习又会导致性能大幅下降。
缓解长尾训练数据带来的性能下降的常见方法是基于类别重新平衡的策略,包括在训练中重新平衡的训练数据采样策略和设计根据类别重新设置权重的损失函数。这些方法可以有效地减少训练过程中头部类的支配地位,从而产生更精确的分类决策边界。然而,由于数据分布被人为扭曲造成失真,过度参数化的深度网络很容易拟合这个合成分布,因此它们经常面临过拟合尾部类别的风险。为了解决这些问题,Bingyi等人将表征学习和分类器训练的任务分离,设计了一种两阶段的训练方案(Kang B,Xie S,Rohrbach M,etal.Decoupling representation and classifier for long-tailed recognition[J].arXiv preprint arXiv:1910.09217,2019.)。这种两阶段训练方案首先学习原始数据分布下的视觉表示,然后在类别平衡采样下在冻结的特征上训练线性分类器。事实证明,这个简单的两阶段训练方案能够处理过拟合问题,并在常见的长尾基准上取得了当时最好的效果。然而,这种两阶段训练方案未能很好地处理不平衡标签分布问题,特别是在表征学习阶段,使得特征不能很好得表示尾部类别得样本。
基于上面的分析,本发明的目标是为长尾视觉识别设计一种新的学习范式,希望能够融合两种长尾识别方法的优点,即对过拟合问题的鲁棒性,并有效地处理不平衡标签问题。
发明内容
本发明要解决的问题是:自然界中物体是按照长尾的特点进行分布的,直接对长尾分布的数据进行学习会导致模型只关注头部类别和忽略尾部类别,传统解决长尾识别的方法存在过拟合尾部样本,欠拟合头部样本的问题,最近提出的方法能够解决过拟合的问题,但是在特征训练阶段对尾部标签的建模不够充分,本发明要解决的问题就是针对长尾分布下的图像,如何设计出一个既不过拟合尾部类别,又能有效对不平衡标签进行建模的长尾视觉识别方法。
本发明的技术方案为:一种基于自监督和自蒸馏的长尾图像识别方法,构建一个多阶段的训练框架用于训练深度神经网络中的特征提取网络和分类器,第一阶段在长尾分布采样下利用自监督任务训练特征提取网络,第二阶段在保留第一阶段特征提取网络权重的情况下,在类别平衡采样下微调分类器,生成用于自蒸馏的软标签,第三阶段重新训练一个同样结构的深度神经网络,在长尾分布采用下利用第二阶段的软标签作为监督,对深度神经网络进行自蒸馏联合训练,得到的深度神经网络用于长尾分布下的图像识别分类。
进一步的,本发明包括以下步骤:
1)准备阶段:准备训练所用的长尾分布的图片数据集以及深度神经网络,深度神经网络由特征提取网络和分类器组成,随机初始化深度神经网络的参数;
2)自监督引导下的特征训练阶段:在长尾分布的数据下,同时利用监督任务和自监督任务对特征提取网络进行训练;
3)软标签生成阶段:在类别平衡的方式下进行数据采样,固定步骤2)中训练得到的特征提取网络的权重参数,对分类器进行微调,微调完成后作为教师网络输出训练样本的预测结果作为软标签,供步骤4)使用;
4)自蒸馏阶段:在原始的长尾分布的数据下,重新训练一个深度神经网络,该深度神经网络具有与步骤1)同样网络结构的特征提取网络,利用步骤3)得到的软标签和真实标签同时监督进行训练;
5)分类器微调阶段:在类别平衡的方式下进行数据采样,固定步骤4)训练得到的特征提取网络参数不变,对分类器进行微调,得到最终的深度神经网络;
6)测试阶段:在类别平衡的数据集上进行测试,检测深度神经网络的图片识别能力。
作为优选实施方式,多阶段的训练具体为:
自监督特征训练阶段:准备训练所用的深度神经网络D,深度神经网络D包括特征提取网络F和分类器Gsup,对长尾分布的图片数据集进行采样得到训练图片,将训练图片送入特征提取网络获得图片的特征f,将特征f送入分类器Gsup得到类别的预测,根据真实标签计算分类的损失函数;随机初始化一个用于自监督任务的网络模块,将特征f送入自监督任务的网络模块得到输出,根据输出计算自监督的损失函数;将所计算的分类和自监督的两个损失函数相加作为最终的损失函数,利用随机梯度下降对特征提取网络进行优化,不断迭代上述过程,直至达到迭代次数;
软标签生成阶段:采取类别平衡的方法对训练数据进行采样,重新训练自监督特征训练阶段得到的特征提取网络F和分类器Gsup,训练任务为分类任务,损失函数为交叉熵损失函数,训练的方法是固定特征提取网络F的权重,通过多个可学习的参数微调分类器中每个类别的权重,不断迭代直到达到迭代次数,称这个阶段训练完成的深度神经网络为网络R;
自蒸馏阶段:初始化一个新的深度神经网络S,由特征提取网络FS和两个线性分类器Hhard和Hsoft组成,特征提取网络FS与特征提取网络F网络结构一致,对长尾分布的图片数据集进行采样得到训练图片,每张训练图片先送入网络R中,输出网络的预测结果,这个预测结果即为软标签,再将训练图片送入深度神经网络S,深度神经网络S的两个分类器分别输出两个分类结果,利用软标签和图片的原始标签分别对两个分类结果进行监督,损失函数分别为KL散度和交叉熵损失函数,不断迭代训练,直到达到迭代次数。
进一步的,还配置有分类器微调阶段:将自蒸馏阶段训练得到的深度神经网络S中硬标签监督的分类器在类别平衡的数据采样下做微调,得到最终的分类结果。
本发明提出了一个概念上简单但特别有效的多阶段训练方案,该方案由两部分组成。首先,本发明引入了一个用于长尾识别的自蒸馏框架,它可以自动挖掘标签关系;其次,提出了一种由自我监督引导的新蒸馏标签生成模块,蒸馏标签整合了来自标签和数据域的信息,可以有效地对长尾分布进行建模。
本发明与现有技术相比有如下优点
本发明提出了一种利用自监督和自蒸馏的多阶段长尾图像识别训练方法,它能够利用自监督方法对尾部类别得到充分的表征,同时利用自蒸馏的方法将头部类别的知识有效迁移到尾部类别中。
与现有技术手动设计类别平衡训练策略的方法相比,本发明中,特征提取网络是在长尾分布下进行训练,没有人为破坏原本得分布,因此能够避免对于尾部类别的过拟合和对于头部类别的欠拟合。与现有的两阶段的长尾识别训练方法相比,本发明在自蒸馏阶段中加入网络R的软标签结果,实现在特征训练阶段引入类别平衡的建模,可以获得更鲁棒的表征。
本发明在公开长尾图像识别数据集上均取得显著优于现有技术的结果。
附图说明
图1是本发明所使用的系统框架图。
具体实施方式
深度学习在大规模平衡数据集的视觉识别方面取得了显着进展,但在现实世界的长尾数据上仍然表现不佳。现有技术通常采用类别重新平衡的训练策略来有效缓解不平衡问题,但可能存在过度拟合尾部类别的风险。最近有研究人员提出的解耦方法通过使用多阶段训练方案克服了过拟合问题,但仍然无法在特征学习阶段获取尾部类别信息。在本发明中,展示了软标签可以作为一种效果优异的解决方案,将标签相关性纳入多阶段训练方案以进行长尾识别,软标签所体现的类之间的内在关系通过将知识从头到尾类转移,从而有助于长尾识别。
如图1所示,本发明构建一个多阶段的训练框架用于训练深度神经网络中的特征提取网络和分类器,第一阶段在长尾分布采样下利用自监督任务训练特征提取网络,第二阶段在保留第一阶段特征提取网络权重的情况下,在类别平衡采样下微调分类器,生成用于自蒸馏的软标签,第三阶段重新训练一个同样结构的深度神经网络,在长尾分布采用下利用第二阶段的软标签作为监督,对深度神经网络进行自蒸馏联合训练,得到的深度神经网络用于长尾分布下的图像识别分类。包括以下步骤:
包括以下步骤:
1)准备阶段:准备训练所用的长尾分布的图片数据集以及深度神经网络,深度神经网络由特征提取网络和分类器组成,随机初始化深度神经网络的参数;
2)自监督引导下的特征训练阶段:在长尾分布的数据下,同时利用监督任务和自监督任务对特征提取网络进行训练;
3)软标签生成阶段:在类别平衡的方式下进行数据采样,固定步骤2)中训练得到的特征提取网络的权重参数,对分类器进行微调,微调完成后作为教师网络输出训练样本的预测结果作为软标签,供步骤4)使用;
4)自蒸馏阶段:在原始的长尾分布的数据下,重新训练一个深度神经网络,该深度神经网络具有与步骤1)同样网络结构的特征提取网络,利用步骤3)得到的软标签和真实标签同时监督进行训练;
5)分类器微调阶段:在类别平衡的方式下进行数据采样,固定步骤4)训练得到的特征提取网络参数不变,对分类器进行微调,得到最终的深度神经网络。
下面进行具体说明。
1)准备阶段:准备训练所用的数据集,数据集分为训练集和测试集,训练集的样本是按照长尾分布的,少数类别的样本数量较多,多数类别的样本数量较少,测试集每个类别拥有相同数量的样本,每个类别的样本数量一般较少。准备训练所用的深度神经网络,记作D,深度神经网络由特征提取网络和分类器组成,特征提取网络可选择为常见的深层基础网络,如ResNet、ResNeXt、VGGNet等,分类器为全连接层,记特征提取网络为变换F,分类器为变换Gsup,随机初始化深度神经网络的参数。
2)自监督引导下的特征训练阶段:随机初始化一个用于自监督任务的网络模块,记为Gself,此模块的具体形式根据自监督任务的不同有所不同。在数据本身的长尾分布下进行采样得到训练图片x。将图片送入特征提取网络获得图片的特征f=F(x),将特征f送入分类器得到类别的预测z=Gsup(f)∈R1×c,c为类别的数量,设图片的真实类别为y,根据真实标签计算分类的损失函数:
Figure BDA0003243410610000051
将特征送入自监督任务的网络模块得到输出u=Gself(f),根据输出计算自监督的损失函数Lself,两个损失函数加权相加作为最终的损失函数L,权重分别为α1和α2
L=α1Lsup2Lself
利用随机梯度下降对网络进行优化,不断迭代上述过程,直至达到迭代次数。
自监督任务可以选择预测图片的旋转角度或实例判别。具体如下:
2.1)旋转角度预测任务:对于一张图片x,随机旋转{0°,90°,180°,270°}中的一个角度,获得旋转后的图片x′,通过网络预测旋转的角度。此任务中自监督网络模块Gself为一个全连接层实现的线性分类器,输出为u∈R1×4,设图片的旋转角度为四个角度中的第r个角度,则自监督损失函数为:
Figure BDA0003243410610000052
2.2)实例判别任务:将自监督网络模块Gself实现为一个多层感知机模型。克隆当前的深度神经网络的结构和权重,生成一个新的深度神经网络M,训练过程中M的参数依据网络D的权重进行动量更新,记动量为m,则更新的公式为:
M=m·M+(1-m)·D
对于第i张图片,经过变换T1得到简单数据增强后的输入图片xi,经过变换T2得到复杂数据增强后的输入图片x′i,将xi送入网络F和Gself,并对输出结果利用l-2范数进行归一化,得到输入图片的特征vi,将x′i送入网络F和Gself并对输出结果利用l-2范数进行归一化,得到输入图片的特征v′i
Figure BDA0003243410610000061
Figure BDA0003243410610000062
则自监督的损失函数为:
Figure BDA0003243410610000063
其中v′k为其他图片经过网络M输出的特征,称为负样本,K为负样本的数量,τ为控制温度的超参数。
3)软标签生成阶段:采取类别平衡的方法对训练数据进行采样,类别平衡采样的具体步骤为一个两阶段采样,先通过均匀分布的采样器随机选择一个类别,再在属于选中类别的样本中通过均匀分布的采样器随机选择一个样本。重新训练步骤2)得到的深度神经网络中的特征提取网络F和分类器Gsup,训练时固定特征提取网络的权重不变,引入调节分类器中权重的尺度的参数si,对于分类器中每个类别的原本的权重wi,调整尺度后的权重为:
Figure BDA0003243410610000064
训练中保持原本的权重wi不变,根据梯度优化更新参数si的值。对于图片x,利用调整后的分类器得到类别的预测值:
f=F(x)
Figure BDA0003243410610000065
其中c为类别数量,损失函数为同样为交叉熵损失函数,设正确类别为y,则损失函数为:
Figure BDA0003243410610000066
利用随机梯度下降对网络进行优化,不断迭代上述过程,直至达到迭代次数。称这个阶段训练完成的深度神经网络为网络R。
4)自蒸馏阶段:重新随机初始化一个新的深度神经网络S,由特征提取网络FS以及两个线性分类器Hhard和Hsoft,特征提取网络FS与步骤1)中准备的特征提取网络结构一致。在数据原始的长尾数据下进行采样得到训练图片x,经过图像变换进行数据增强得到x′,图片的原始标签为y。首先通过步骤3)训练完成的网络R得到图片x的伪标签
Figure BDA0003243410610000067
Figure BDA0003243410610000071
Figure BDA0003243410610000072
Figure BDA0003243410610000073
其中
Figure BDA0003243410610000074
为网络R提取的特征,
Figure BDA0003243410610000075
为网络R得到的类别预测,
Figure BDA0003243410610000076
为网络R分类器的权重,
Figure BDA0003243410610000077
为第n个类别预测,
Figure BDA0003243410610000078
代表伪标签中的第n个元素,T代表温度参数,是一个超参数,通过手工设置T=2,用来调前上面公式的分布,T越大分布越平缓,c代表类别数量。
将同样经过数据增强的图片x′送入本阶段重新初始化的深度神经网络S,得到两个分类器的预测输出zhard和zsoft
f=FS(x′)
zhard=Hhard(f)
zsoft=Hsoft(f)
利用zsoft和软标签
Figure BDA0003243410610000079
计算自蒸馏部分的损失函数:
Figure BDA00032434106100000710
其中T为控制分布平滑程度的温度参数,
Figure BDA00032434106100000711
为分类器Hsoft输出的第n、k个类别预测。
利用zhard和原始标签y计算普通分类的损失函数:
Figure BDA00032434106100000712
将上述两个损失函数按照权重λ1和λ2进行加权融合,得到本阶段最后的损失函数:
L=λ1Lkd2Lce
利用随机梯度下降对网络进行优化,不断迭代上述过程,直至达到迭代次数。
5)分类器微调阶段:将自蒸馏阶段训练得到的深度神经网络S中硬标签监督的分类器Hhard在类别平衡的数据采样下做微调,硬标签即原始标签。与软标签一样采取类别平衡的方法对训练图片进行采样,重新训练步骤4)得到的深度神经网络中的特征提取网络FS和分类器Hhard,训练时固定特征提取网络FS的权重不变,引入调节分类器中权重的尺度的参数si,对于分类器中每个类别的原本的权重hi,调整尺度后的权重为:
Figure BDA00032434106100000713
训练中保持原本的权重wi不变,根据梯度优化更新参数si的值。对于图片x,利用调整后的分类器得到类别的预测值:
f=F(x)
Figure BDA0003243410610000081
其中c为类别数量,损失函数为同样为交叉熵损失函数,设正确类别为y,则损失函数为:
Figure BDA0003243410610000082
利用随机梯度下降对网络进行优化,不断迭代上述过程,直至达到迭代次数。
6)测试阶段:测试时使用步骤1)构造的测试集,测试集每个类别拥有相同数量的图片,即测试集类别平衡的数据集,将测试集图片分别送入步骤5)得到的网络中进行预测,通过跟这些图片的正确类别进行比对,得到预测的准确率,检测深度神经网络的图片识别能力是否达到准确性的要求。
下面通过具体实施例说明本发明的实施。
利用ImageNet-LT数据集中的图片进行训练,具体使用Python3编程语言,Pytorch1.6深度学习框架实施。
图1是本发明所使用的系统框架图,对应的具体实施如下:
1)准备阶段,构建训练和测试使用的数据集ImageNet-LT,该数据集共有1000个类别,类别分布符合帕累托分布,系数为6。训练集包含12万张图片,每个类别的图片数量从1280到5不等;测试集包含5万张图片,每个类别所包含的数量一样,均为50张。准备训练所需要的神经网络,特征提取网络选择ResNeXt-50,输出特征维度为2048维,分类器采用全连接层作为分类器,输入特征维度为2048,输出特征维度为1000,随机初始化神经网络的参数。
2)自监督引导下的特征训练阶段,具体实施是采用实例判别作为具体任务,自监督任务的网络模块为一个多层感知机,即全连接层-ReLU非线性激活层-全连接层,输入特征维度和隐层特征维度均为2048维,输出特征维度为128维。构建一个结构和参数与当前网络完全相同的网络,改网络在训练过程中使用动量更新,动量更新的参数m为0.999,温度参数τ为0.2。对于一张图片分别经过变换T1和T2得到两张数据增强后的图片,其中变换T1为随机变换图片大小-随机裁剪-随机进行水平翻转-归一化,变换T2为随机变换图片大小-随机裁剪-随机颜色变换-随机灰度化-随机高斯模糊-随机水平翻转-归一化。将变换T1变换后的图片送入原始网络中,得到1000维的分类预测向量和128维的特征向量,变换T2得到的图片送入动量更新的网络中,得到128维的特征向量。利用分类预测向量计算分类损失函数,特征向量计算自监督损失函数,两个损失函数按照1:1比例进行融合得到最总损失函数。利用随机梯度下降算法进行训练,使用8块TITIAN Xp进行训练,批大小为256,训练轮数为135轮,学习率为0.1,采用余弦函数对学习率进行衰减。
3)软标签生成阶段,抛弃步骤2)中的自监督模块,保留特征提取模块和分类器。采取类别平衡的方法对训练数据进行采样,即先通过均匀分布的采样器随机选择一个类别,再在属于选中类别的样本中通过均匀分布的采样器随机选择一个样本。重新训练步骤2)得到的深度神经网络中的特征提取网络和分类器,对于采样得到的一样图片,利用变换随机变换图片大小-随机裁剪-随机进行水平翻转-归一化得到数据增强后的图片,相继送入特征提取网络和经过系数调整后的分类器,得到类别预测结果,根据真实类别计算损失函数。采用随机梯度下降算法进行训练,使用8块TITIAN Xp进行训练,批大小为512,训练轮数为5轮,对分类器进行调整的系数的学习率为0.2,其余部分(特征提取网络和分类器原始参数)的学习率为0,采用余弦函数对学习率进行衰减。
4)自蒸馏阶段,重新随机初始化一个新的深度神经网络,由特征提取网络以及两个线性分类器,特征提取网络仍选择ResNeXt-50,两个分类器均为输入维度为2048,输出维度为1000的全连接层。在数据原始的长尾数据下进行采样得到训练图片,经过图像变换随机变换图片大小-随机裁剪-随机进行水平翻转-归一化进行数据增强,将图片送入步骤3)训练得到的网络,得到预测结果,利用温度T=2进行调制并利用softmax进行归一化后,得到伪标签;将图片送入特征提取网络得到中间特征,将中间特征分别送入两个分类器得到两个预测结果,利用第一个分类结果和伪标签计算自蒸馏的损失函数,利用第二个分类结果和真实标签计算得到分类的损失函数,将两个损失函数进行1:1融合得到最终的损失函数。利用随机梯度下降算法进行训练,使用8块TITIAN Xp进行训练,批大小为256,训练轮数为135轮,学习率为0.1,采用余弦函数对学习率进行衰减。
5)分类器微调阶段,采取类别平衡的方法对训练数据进行采样,即先通过均匀分布的采样器随机选择一个类别,再在属于选中类别的样本中通过均匀分布的采样器随机选择一个样本。重新训练步骤4)得到的深度神经网络中的特征提取网络和分类器,对于采样得到的一样图片,利用变换随机变换图片大小-随机裁剪-随机进行水平翻转-归一化得到数据增强后的图片,相继送入特征提取网络和经过系数调整后的分类器,得到类别预测结果,根据真实类别计算损失函数。采用随机梯度下降算法进行训练,使用8块TITIAN Xp进行训练,批大小为512,训练轮数为5轮,对分类器进行调整的系数的学习率为0.2,其余部分(特征提取网络和分类器原始参数)的学习率为0,采用余弦函数对学习率进行衰减。
6)测试阶段,使用步骤1)构造的测试集,测试集每个类别拥有相同数量的图片,将图片分别送入步骤5)得到的网络中进行预测,与正确类别进行比对得到预测的准确率。整个测试集的准确率为56.0%,其中训练集中样本数量较多的类别准确率为66.8%,出现次数中等的准确率为53.1%,出现次数较少的准确率为35.4%。与基线方法相比,准确率分别提高3.9%,3.4%,4.5%和3.1%。

Claims (6)

1.一种基于自监督和自蒸馏的长尾图像识别方法,其特征是构建一个多阶段的训练框架用于训练深度神经网络中的特征提取网络和分类器,第一阶段在长尾分布采样下利用自监督任务训练特征提取网络,第二阶段在保留第一阶段特征提取网络权重的情况下,在类别平衡采样下微调分类器,生成用于自蒸馏的软标签,第三阶段重新训练一个同样结构的深度神经网络,在长尾分布采用下利用第二阶段的软标签作为监督,对深度神经网络进行自蒸馏联合训练,得到的深度神经网络用于长尾分布下的图像识别分类。
2.根据权利要求1所述的一种基于自监督和自蒸馏的长尾图像识别方法,其特征是包括以下步骤:
1)准备阶段:准备训练所用的长尾分布的图片数据集以及深度神经网络,深度神经网络由特征提取网络和分类器组成,随机初始化深度神经网络的参数;
2)自监督引导下的特征训练阶段:在长尾分布的数据下,同时利用监督任务和自监督任务对特征提取网络进行训练;
3)软标签生成阶段:在类别平衡的方式下进行数据采样,固定步骤2)中训练得到的特征提取网络的权重参数,对分类器进行微调,微调完成后作为教师网络输出训练样本的预测结果作为软标签,供步骤4)使用;
4)自蒸馏阶段:在原始的长尾分布的数据下,重新训练一个深度神经网络,该深度神经网络具有与步骤1)同样网络结构的特征提取网络,利用步骤3)得到的软标签和真实标签同时监督进行训练;
5)分类器微调阶段:在类别平衡的方式下进行数据采样,固定步骤4)训练得到的特征提取网络参数不变,对分类器进行微调,得到最终的深度神经网络;
6)测试阶段:在类别平衡的数据集上进行测试,检测深度神经网络的图片识别能力是否符合要求。
3.根据权利要求1或2所述的一种基于自监督和自蒸馏的长尾图像识别方法,其特征是多阶段的训练具体为:
自监督特征训练阶段:准备训练所用的深度神经网络D,深度神经网络D包括特征提取网络F和分类器Gsup,对长尾分布的图片数据集进行采样得到训练图片,将训练图片送入特征提取网络获得图片的特征f,将特征f送入分类器Gsup得到类别的预测,根据真实标签计算分类的损失函数;随机初始化一个用于自监督任务的网络模块,将特征f送入自监督任务的网络模块得到输出,根据输出计算自监督的损失函数;将所计算的分类和自监督的两个损失函数相加作为最终的损失函数,利用随机梯度下降对特征提取网络进行优化,不断迭代上述过程,直至达到迭代次数;
软标签生成阶段:采取类别平衡的方法对训练图片进行采样,重新训练自监督特征训练阶段得到的特征提取网络F和分类器Gsup,训练任务为分类任务,损失函数为交叉熵损失函数,训练的方法是固定特征提取网络F的权重,通过多个可学习的参数微调分类器Gsup中每个类别的权重,不断迭代直到达到迭代次数,称这个阶段训练完成的深度神经网络为网络R;
自蒸馏阶段:初始化一个新的深度神经网络S,由特征提取网络FS和两个线性分类器Hhard和Hsoft组成,特征提取网络FS与特征提取网络F网络结构一致,对长尾分布的图片数据集进行采样得到训练图片,每张训练图片先送入网络R中,输出网络的预测结果,这个预测结果即为软标签,再将训练图片送入深度神经网络S,深度神经网络S的两个分类器分别输出两个分类结果,利用软标签和图片的原始标签分别对两个分类结果进行监督,损失函数分别为KL散度和交叉熵损失函数,不断迭代训练,直到达到迭代次数。
4.根据权利要求3所述的一种基于自监督和自蒸馏的长尾图像识别方法,其特征是还配置有分类器微调阶段:将自蒸馏阶段训练得到的深度神经网络S中硬标签监督的分类器在类别平衡的数据采样下做微调,得到最终的分类结果。
5.根据权利要求3所述的一种基于自监督和自蒸馏的长尾图像识别方法,其特征是自监督特征训练阶段中,自监督任务包括预测图片的旋转角度和实例判别:
旋转角度预测任务:对于一张图片x,随机旋转{0°,90°,180°,270°}中的一个角度,获得旋转后的图片x′,通过网络预测旋转的角度,此任务中自监督网络模块Gself为一个全连接层实现的线性分类器,输出为u∈R1×4,设图片的旋转角度为四个角度中的第r个角度,则自监督损失函数为:
Figure FDA0003243410600000021
实例判别任务:将自监督网络模块Gself实现为一个多层感知机模型,克隆当前的深度神经网络的结构和权重,生成一个新的深度神经网络M,训练过程中网络M的参数依据网络D的权重进行动量更新,记动量为m,则更新的公式为:
M=m·M+(1-m)·D
对于第i张图片,经过变换T1得到简单数据增强后的输入图片xi,经过变换T2得到复杂数据增强后的输入图片x′i,将xi送入网络F和Gself,并对输出结果利用/-2范数进行归一化,得到输入图片的特征vi,将x′i送入网络F和Gself并对输出结果利用/-2范数进行归一化,得到输入图片的特征v′i
Figure FDA0003243410600000031
Figure FDA0003243410600000032
自监督的损失函数为:
Figure FDA0003243410600000033
其中v′k为其他图片经过网络M输出的特征,称为负样本,K为负样本的数量,τ为控制温度的超参数。
6.根据权利要求3所述的一种基于自监督和自蒸馏的长尾图像识别方法,其特征是自蒸馏阶段使用双头自蒸馏算法:
对长尾分布的图片数据集进行采样得到训练图片x,经过图像变换进行数据增强得到图片x′,设图片的原始标签为y,首先通过网络R得到图片x的伪标签
Figure FDA0003243410600000034
Figure FDA0003243410600000035
Figure FDA0003243410600000036
Figure FDA0003243410600000037
其中
Figure FDA0003243410600000038
为网络R提取的特征,
Figure FDA0003243410600000039
为网络R得到的类别预测,
Figure FDA00032434106000000310
为网络R分类器的权重,
Figure FDA00032434106000000311
为第n个类别预测,
Figure FDA00032434106000000312
代表伪标签中的第n个元素,T代表控制分布平滑程度的温度参数,c代表类别数量;
将同样经过数据增强的图片x′送入本阶段重新初始化的深度神经网络S,得到两个分类器的预测输出zhard和zsoft
f=FS(x′)
zhard=Hhard(f)
zsoft=Hsoft(f)
利用zsoft和软标签
Figure FDA00032434106000000313
计算自蒸馏部分的损失函数:
Figure FDA00032434106000000314
其中T为控制分布平滑程度的温度参数,利用zhard和原始标签y计算普通分类的损失函数:
Figure FDA0003243410600000041
将上述两个损失函数按照权重λ1和λ2进行加权融合,得到本阶段最终的损失函数:
L=λ1Lkd2Lce
利用随机梯度下降对网络进行优化,不断迭代上述过程,直至达到迭代次数。
CN202111026141.7A 2021-09-02 2021-09-02 一种基于自监督和自蒸馏的长尾图像识别方法 Active CN113837238B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111026141.7A CN113837238B (zh) 2021-09-02 2021-09-02 一种基于自监督和自蒸馏的长尾图像识别方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111026141.7A CN113837238B (zh) 2021-09-02 2021-09-02 一种基于自监督和自蒸馏的长尾图像识别方法

Publications (2)

Publication Number Publication Date
CN113837238A true CN113837238A (zh) 2021-12-24
CN113837238B CN113837238B (zh) 2023-09-01

Family

ID=78962069

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111026141.7A Active CN113837238B (zh) 2021-09-02 2021-09-02 一种基于自监督和自蒸馏的长尾图像识别方法

Country Status (1)

Country Link
CN (1) CN113837238B (zh)

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114549904A (zh) * 2022-02-25 2022-05-27 北京百度网讯科技有限公司 视觉处理及模型训练方法、设备、存储介质及程序产品
CN114595780A (zh) * 2022-03-15 2022-06-07 百度在线网络技术(北京)有限公司 图文处理模型训练及图文处理方法、装置、设备及介质
CN114627348A (zh) * 2022-03-22 2022-06-14 厦门大学 多主体任务中基于意图的图片识别方法
CN114863193A (zh) * 2022-07-07 2022-08-05 之江实验室 基于混合批归一化的长尾学习图像分类、训练方法及装置
CN114863248A (zh) * 2022-03-02 2022-08-05 武汉大学 一种基于深监督自蒸馏的图像目标检测方法
CN115272881A (zh) * 2022-08-02 2022-11-01 大连理工大学 动态关系蒸馏的长尾遥感图像目标识别方法
CN116578913A (zh) * 2023-03-31 2023-08-11 中国人民解放军陆军工程大学 一种面向复杂电磁环境的可靠无人机检测识别方法
CN116578913B (zh) * 2023-03-31 2024-05-24 中国人民解放军陆军工程大学 一种面向复杂电磁环境的可靠无人机检测识别方法

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106203530A (zh) * 2016-07-21 2016-12-07 长安大学 面向k近邻算法用于不平衡分布数据的特征权重确定方法
US20200364542A1 (en) * 2019-05-16 2020-11-19 Salesforce.Com, Inc. Private deep learning
CN112348792A (zh) * 2020-11-04 2021-02-09 广东工业大学 一种基于小样本学习和自监督学习的x光胸片图像分类方法
CN112381116A (zh) * 2020-10-21 2021-02-19 福州大学 基于对比学习的自监督图像分类方法
US20210182618A1 (en) * 2018-10-29 2021-06-17 Hrl Laboratories, Llc Process to learn new image classes without labels
CN113177612A (zh) * 2021-05-24 2021-07-27 同济大学 一种基于cnn少样本的农业病虫害图像识别方法

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106203530A (zh) * 2016-07-21 2016-12-07 长安大学 面向k近邻算法用于不平衡分布数据的特征权重确定方法
US20210182618A1 (en) * 2018-10-29 2021-06-17 Hrl Laboratories, Llc Process to learn new image classes without labels
US20200364542A1 (en) * 2019-05-16 2020-11-19 Salesforce.Com, Inc. Private deep learning
CN112381116A (zh) * 2020-10-21 2021-02-19 福州大学 基于对比学习的自监督图像分类方法
CN112348792A (zh) * 2020-11-04 2021-02-09 广东工业大学 一种基于小样本学习和自监督学习的x光胸片图像分类方法
CN113177612A (zh) * 2021-05-24 2021-07-27 同济大学 一种基于cnn少样本的农业病虫害图像识别方法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
T. LI: ""Self Supervision to Distillation for Long-Tailed Visual Recognition"", 《2021 IEEE/CVF INTERNATIONAL CONFERENCE ON COMPUTER VISION (ICCV)》, pages 610 - 619 *
李徵: ""基于地质知识蒸馏学习的油气储集层识别方法"", 《中国科学:信息科学》, vol. 51, no. 1, pages 40 - 55 *
秦晓明: ""基于深度学习的含噪声标签图像的分类研究"", 《中国优秀硕士学位论文全文数据库 信息科技辑》, no. 8, pages 138 - 562 *

Cited By (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114549904A (zh) * 2022-02-25 2022-05-27 北京百度网讯科技有限公司 视觉处理及模型训练方法、设备、存储介质及程序产品
CN114863248A (zh) * 2022-03-02 2022-08-05 武汉大学 一种基于深监督自蒸馏的图像目标检测方法
CN114863248B (zh) * 2022-03-02 2024-04-26 武汉大学 一种基于深监督自蒸馏的图像目标检测方法
CN114595780A (zh) * 2022-03-15 2022-06-07 百度在线网络技术(北京)有限公司 图文处理模型训练及图文处理方法、装置、设备及介质
CN114627348A (zh) * 2022-03-22 2022-06-14 厦门大学 多主体任务中基于意图的图片识别方法
CN114627348B (zh) * 2022-03-22 2024-05-31 厦门大学 多主体任务中基于意图的图片识别方法
CN114863193A (zh) * 2022-07-07 2022-08-05 之江实验室 基于混合批归一化的长尾学习图像分类、训练方法及装置
CN115272881A (zh) * 2022-08-02 2022-11-01 大连理工大学 动态关系蒸馏的长尾遥感图像目标识别方法
CN116578913A (zh) * 2023-03-31 2023-08-11 中国人民解放军陆军工程大学 一种面向复杂电磁环境的可靠无人机检测识别方法
CN116578913B (zh) * 2023-03-31 2024-05-24 中国人民解放军陆军工程大学 一种面向复杂电磁环境的可靠无人机检测识别方法

Also Published As

Publication number Publication date
CN113837238B (zh) 2023-09-01

Similar Documents

Publication Publication Date Title
CN113837238A (zh) 一种基于自监督和自蒸馏的长尾图像识别方法
CN109214452B (zh) 基于注意深度双向循环神经网络的hrrp目标识别方法
CN111583263B (zh) 一种基于联合动态图卷积的点云分割方法
CN108846413B (zh) 一种基于全局语义一致网络的零样本学习方法
CN109598711B (zh) 一种基于特征挖掘和神经网络的热图像缺陷提取方法
CN109239670B (zh) 基于结构嵌入和深度神经网络的雷达hrrp识别方法
CN110188827A (zh) 一种基于卷积神经网络和递归自动编码器模型的场景识别方法
CN111239137B (zh) 基于迁移学习与自适应深度卷积神经网络的谷物质量检测方法
CN109145685B (zh) 基于集成学习的果蔬高光谱品质检测方法
CN114861705A (zh) 一种基于多特征异构融合的电磁目标智能感知识别方法
CN113011487B (zh) 一种基于联合学习与知识迁移的开放集图像分类方法
CN113065520A (zh) 一种面向多模态数据的遥感图像分类方法
CN109872319B (zh) 一种基于特征挖掘和神经网络的热图像缺陷提取方法
CN116523711A (zh) 基于人工智能的教育监管系统及其方法
CN113688867B (zh) 一种跨域图像分类方法
CN115512272A (zh) 一种针对多事件实例视频的时序事件检测方法
CN115098681A (zh) 一种基于有监督对比学习的开放服务意图检测方法
CN112784927B (zh) 一种基于在线学习的半自动图像标注方法
CN113449751B (zh) 基于对称性和群论的物体-属性组合图像识别方法
CN114220145A (zh) 人脸检测模型生成方法和装置、伪造人脸检测方法和装置
CN116310463B (zh) 一种无监督学习的遥感目标分类方法
CN111652265A (zh) 一种基于自调整图的鲁棒半监督稀疏特征选择方法
CN112446432A (zh) 基于量子自学习自训练网络的手写体图片分类方法
CN114037866B (zh) 一种基于可辨伪特征合成的广义零样本图像分类方法
CN117274724B (zh) 基于可变类别温度蒸馏的焊缝缺陷分类方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant