CN115880529A - 基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统 - Google Patents

基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统 Download PDF

Info

Publication number
CN115880529A
CN115880529A CN202211534488.7A CN202211534488A CN115880529A CN 115880529 A CN115880529 A CN 115880529A CN 202211534488 A CN202211534488 A CN 202211534488A CN 115880529 A CN115880529 A CN 115880529A
Authority
CN
China
Prior art keywords
model
target
image
attention
bird
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211534488.7A
Other languages
English (en)
Inventor
陈志泊
杨锋
张颖
王康
陈伊鑫
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Forestry University
Original Assignee
Beijing Forestry University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Forestry University filed Critical Beijing Forestry University
Priority to CN202211534488.7A priority Critical patent/CN115880529A/zh
Publication of CN115880529A publication Critical patent/CN115880529A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A90/00Technologies having an indirect contribution to adaptation to climate change
    • Y02A90/10Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation

Landscapes

  • Image Analysis (AREA)

Abstract

本发明公开了基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统,属于计算机视觉技术领域,包括以下步骤:(1)获取鸟类数据集;(2)基于注意力引导实现数据增强,并训练教师模型;(3)基于解耦知识蒸馏压缩鸟类分类模型,实现教师模型和学生模型同时数据增强;(4)基于目标定位再识别的思想,预测阶段将目标图像输入最终的轻量级分类模型,获得最终鸟类细粒度分类结果。本发明应用于鸟类细粒度分类中,基于注意力引导实现数据增强,弥补了鸟类数据集不充足的问题;基于解耦知识蒸馏实现了鸟类分类模型的高效压缩,并在此基础上实现教师模型和学生模型同时数据增强的方法,再次提升学生模型的预测精度,获得高准确率的轻量级鸟类分类模型。

Description

基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统
技术领域
本发明涉及计算机视觉技术领域,尤其涉及知识蒸馏和鸟类细粒度图像分类技术,具体涉及一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统。
背景技术
鸟类对于维持生态系统平衡至关重要,其群落组成和种类分布成为检测大自然环境变化的重要指标。环境学家经常利用鸟类对环境变化的敏感来监测生态系统,且大多数生态应用都依赖于鸟类,例如环境污染检测、生物多样性检测、气候变化检测和濒危鸟类救援等。近年来,对鸟类的识别主要包括专家识别、雷达识别和声音识别。专家识别虽然保证了识别的精确度,但人力、时间成本较高;雷达识别通过自动识别降低了人工成本,但并不能保证较高的精准度;声音识别精准度较高但容易受到识别区域和周围噪音的影响。目前图像采集技术已经日趋成熟,深度学习技术也在不断发展,鸟类图像分类模型的研究实现了自动化监测,不仅在保证监测准确度的前提下降低了人力成本,并且为生态环境监测领域提供强有力的技术支撑。
鸟类识别属于细粒度分类,是对鸟类的子类进行分类,即精确到“种”的识别。由于同一类别的鸟类通常呈现不同的姿势和视角,不同类别之间存在细微差别,这使分类任务极具挑战,仅依靠当前先进的粗粒度卷积神经网络(CNN),如VGG、ResNet和Inception很难获得准确的分类结果。细粒度图像分类的关键步骤是提取目标中更具鉴别性的局部精细特征。在早期研究中,通常依赖于目标位置或属性的注释来关注图像的局部特征,属于强监督方法。这类方法在进行模型训练时,不仅需要图像的类别标签,还需要额外的目标重要区域标注。强监督方法虽然识别效果更加准确,但识别效率低,且前期对图像标注开销大。基于弱监督的细粒度图像分类方法成为了近几年基于深度学习的研究趋势。近年来,循环注意卷积网络模型RA-CNN和精细特征提取模型NTS-Net被提出。其中,RA-CNN模型是依靠循环预测一个注意区域的位置并提取相应的特征;而后者则将特征金字塔网络FPN引入细粒度分类任务,使模型对目标的三个区域进行定位。以上方法均对目标部分区域进行定位,限制了模型对目标区域全部精细特征的学习。
知识蒸馏最早提出是通过最小化教师和学生logit之间的差异来传递知识,是一种新兴的压缩模型的方法。但目前的logit蒸馏直接使用KL散度函数计算教师和学生logit之间的差异大小,由于KL散度函数高度耦合,抑制了非目标类之间的差异对总差异值的贡献,极大的限制了logit蒸馏效果。
综上,开发高准确率的轻量级鸟类分类模型,实现自动化鸟类监测,具有重要的研究价值。
发明内容
针对现有方法的不足之处,本发明提供一种基于注意力和解耦知识蒸馏的鸟类细粒度分类的方法及系统。该方法提出一种基于注意力引导的数据增强方法,利用图像注意力图获取目标和局部区域图像,提高训练的数据集质量,并在此基础上实现了基于区域定位再识别的细粒度分类方法;基于解耦知识蒸馏技术,实现了鸟类分类模型的高效压缩,训练出既能满足预测精度也能嵌入移动端使用的学生模型。除此之外,在知识蒸馏过程中本方法还实现了教师模型和学生模型同时实现数据增强的方法,在知识蒸馏的基础上再次提升学生模型的预测精度,最终得以快速获得鸟类细粒度分类。
为了达到上述技术目的,本发明采用的技术方案为:
一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法,包括:
步骤1根据预设方式获取鸟类数据集;
步骤2将步骤1中获取的数据集输入教师模型,使用DenseNet121深度卷积网络作为特征提取器,基于注意力引导实现数据增强,获得目标图像和局部区域图像,输入教师模型并将损失值最小的教师模型作为训练好的教师模型;
步骤3将步骤1中获得的数据集和步骤2教师模型输出的目标图像和局部区域图像输入学生模型,所述学生模型采用轻量级的卷积网络ShuffleNetV2作为特征提取器,基于解耦logit蒸馏,实现鸟类细粒度分类模型压缩,实现教师模型和学生模型同时数据增强的方法;
步骤4基于目标定位再识别的思想,预测阶段将目标图像输入最终的轻量级分类模型,获得最终鸟类细粒度分类结果。
进一步的,步骤2采用的注意力引导增强数据的方法,包括两种,一种是裁剪目标图像进行增强,一种是裁剪局部区域图像数据增强。
步骤2.1.1通过通道叠加原图特征图得到原图注意力图A,注意力图A的计算公式为:
Figure SMS_1
式中使用F∈RC×H×W表示卷积神经网络模型的最后一个卷积层输出的具有C个信道和空间大小为H×W的特征图集,fi是特征图集的第i个特征图,A为所有特征图每个通道对应位置相加得到的通道为1、大小为H×W的注意力图。
步骤2.1.2根据步骤2.1.1中得到的原图注意力图中能直观观察到关键部分的所在区域,下一步需要计算目标像素值大小的阈值
Figure SMS_2
阈值/>
Figure SMS_3
的计算公式为:
Figure SMS_4
步骤2.1.3根据步骤2.1.2中得到的阈值
Figure SMS_5
判断注意力图A每个像素点是否是目标的一部分,使用该方式定位到目标的全部区域生成大小为H×W的目标位置掩膜,掩膜计算公式为:
Figure SMS_6
根据上式得到掩膜图,受图像背景和噪音的影响,该图中存在多个连通面积,取最大连通面积的最小边界框作为目标的定位框,对应至原图中作为目标图像,并缩放至合适大小。
步骤2.1.4进行局部区域图像的裁剪,利用卷积输出特征图的特性,计算目标图像的注意力图A2
Figure SMS_7
该式中f为目标图像生成的特征图。
步骤2.1.5根据步骤2.1.4中获得的多个目标重要区域,使用滑动窗口的方式框选出多个目标重要局部区域,计算窗口所有像素点的注意力值的平均值,窗口注意力平均值计算公式为:
Figure SMS_8
式中,Hw和Ww为窗口的高度和宽度,Aw为注意力图中窗口区域。其中,
Figure SMS_9
的大小与区域的信息量成正比,/>
Figure SMS_10
越大,代表这部分区域的信息量越大。把窗口对应至目标图像中,裁剪出局部区域图像,实现了数据增强。
进一步的,步骤2中教师模型的训练使用交叉熵损失函数计算预测损失,具体分为以下三步:
步骤2.2.1使用原始图像对模型进行训练,原始图像经过特征提取器得到原始图像特征图,原始图像特征图经过全局池化输入全连接层(fc),计算原始图像预测损失Lram
步骤2.2.2基于原始图像特征图,实现裁剪目标图像数据增强方法得到目标图像,目标图像缩放至合适尺寸输入模型,得到目标图像特征图和目标图像预测损失Lobject
步骤2.2.3再次通过裁剪局部区域图像数据增强方法,以滑动窗口的方式获取多个局部区域图像,局部区域图像缩放后输入模型,计算出局部区域图像预测损失Lparts,;
以上各损失函数计算公式如下:
Lraw=-log(Pr(c))
Lobject=-log(Po(c))
Figure SMS_11
其中,c是图像的真实标签,pr是原始图像输出类别标签,po是目标图像输出类别概率,pp(n)是局部图像输出类别概率,其中n是局部区域图像的数量。总的损失值Ltotal计算公式如下:
Ltotal=Lraw+Lobject+Lparts
进一步的,学生模型训练阶段将图像分别输入教师模型和学生模型,对学生模型的预测输出分别使用交叉熵损失计算预测损失值Lhard和KL散度函数计算学生模型与老师模型输出的差距值Lsoft,计算公式如下:
Lhard=-log(P(c))
Figure SMS_12
/>
其中,c是图像的真是标签,P是模型输出i类别概率,T和S分别表示教师与学生,B为目标类的二元输概率,
Figure SMS_13
为非目标类的多类别输出概率,α为NCKD新的权重。总损失计算公式如下:
Ltotal=Lhard+Lsoft
进一步的,步骤3中采用解耦知识蒸馏方法,获取教师模型输出的鸟类细粒度属于某一类别的概率,具体包括:
步骤3.1引入超参数温度T之后通过softmax计算第i个类别的概率pi,softmax计算公式如下:
Figure SMS_14
T是超参数温度,模型输出记为Z=[z1,z2,...,zt,...,zc]∈R1×C,其中zi是第i类输出值,C是任务分类个数,pi表示第i个类别预测概率,记模型输出为P=[p1,p2,...,pt,...,pc]∈R1×C
步骤3.2使用softmax公式计算目标类(pt)和所有其他非目标类(p\t)的预测概率,公式如下:
Figure SMS_15
Figure SMS_16
记B=[pt,p\t]∈R1×2表示模型目标类和非目标类预测概率;目标类知识蒸馏TCKD的定义为:
TCKD=KL(BT‖BS)
其中,S和T分别代表老师和学生;记
Figure SMS_17
Figure SMS_18
为第i个非目标类预测概率,公式入下:
Figure SMS_19
NCKD的定义为:
Figure SMS_20
步骤3.3对KL损失函数进行拆解,首先把目标类分类概率从叠加运算中抽离出来:
Figure SMS_21
知识蒸馏损失可写为:
Figure SMS_22
根据上式可知,NCKD的权重与
Figure SMS_23
相耦合,限制了非目标类知识传递。为改善上述情况,本方法给NCKD赋予一个新的权重,定义为解耦知识蒸馏(DKD),DKD的损失函数定义如下:
DKD=TCKD+αNCKDDKD通过优化非目标类知识蒸馏权重,消除目标类预测概率对非目标类知识传递的抑制。
进一步的,步骤4基于目标定位再识别的思想,预测阶段通过基于视觉注意力的目标区域定位方法,定位目标区域获取目标图像,将目标图像输入最终的轻量级分类模型,获得最终鸟类细粒度分类结果。
本发明还提供了基于知识蒸馏的鸟类细粒度分析系统,采用上述方法进行鸟类细粒度分类工作,包括:
数据处理模块:用于对已有鸟类图像数据集进行目标和关键区域的定位,实现数据增强,对数据进行预处理;
模型训练模块:用于对处理完毕的数据集在特定条件下,DesNet121作为教师模型特征提取器,ShuffleNetV2作为学生模型的特征提取器;
知识蒸馏模块:用于提出的解耦知识蒸馏损失函数,调整参数权重,训练获得轻量级的学生模型;
目标检测模块:用于利用最终训练好的学生模型,基于目标再定位的方法对鸟类数据进行最终的细粒度分类;
控制处理模块:用于向其他模块发出指令,按序完成分类步骤。
进一步的,数据处理模块包括:
图像获取模块,用于通过预设方式获取鸟类图像,建立鸟类数据集;
图像增强模块,用于将图像进行特征提取后进行通道叠加得到注意力图,根据注意力图中的注意力值的分布完成目标区域和重要局部区域的裁剪,实现数据增强;
进一步的,模型训练模块包括:
教师模型训练模块,通过输入原始图像、目标图像和局部区域图像对DesNet121特征提取模型进行训练;
学生模型训练模块,通过输入原始图像即教师模型和学生模型得到的目标图像和局部区域图像对轻量级卷积网络ShuffleNetV2进行训练;
进一步的,知识蒸馏模块包括:
预测损失模块,通过对目标类预测结果和非目标预测结果分别计算KL损失值;
NCKD权重控制模块,通过对NCKD赋予合适的权重,消除目标类预测概率对非目标类知识传递的抑制。
本发明所述的一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统,其显著优点是:提出基于注意力引导的数据增强方法,利用图像特征图定位目标和关键区域,对定位区域实现数据增强,提高训练集的数据质量;采用复杂模型DesNet121作为教师模型,采用轻量级网络ShufflenetV2作为学生模型的特征提取器;基于解耦知识蒸馏方法,优化了非目标类知识蒸馏权重,消除目标类预测概率对非目标类知识传递的抑制,显著提升了知识蒸馏效果,实现了鸟类分类模型的高效压缩,在满足精度的前提下训练出参数量和计算量更小的学生模型。
附图说明
图1所示为本发明所述方法流程图;
图2所示为教师模型的训练结构图;
图3所示为目标区域定位过程示意图;
图4所示为目标局部区域定位过程示意图;
图5所示为学生模型的训练及预测结构图;
图6所示为本发明所述方法中设计的解耦知识蒸馏结构图;
具体实施方式
下面结合附图对本发明的技术方案进一步说明,本实施例在本发明技术方案为前提下进行实施,给出详细的实施步骤和具体操作流程。
实施例一
如图1所示,本发明所述的一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法,具体包括以下步骤:
步骤1根据预设方式获取鸟类数据集;
具体的,本文所采用的鸟类数据来自加利福尼亚理工学院提供的鸟类数据库,包含200种常见的鸟类,如CommonYellowthroat、Rock Wren、Marsh Wren等。
步骤2教师模型的训练结构如图2所示,将步骤1中获取的数据集输入教师模型,使用DenseNet121深度卷积网络作为特征提取器,基于注意力引导实现数据增强,获得目标图像和局部区域图像,输入教师模型并将损失值最小的教师模型作为训练好的教师模型。
具体的,开发鸟类分类模型,构建注意力引导数据增强方法,本方法提出的注意力引导数据增强包括两部分,分别获取目标图像和局部区域图像的获取。
具体的,目标区域的定位过程如图3所示,首先通过通道叠加得到原图的注意力图A,其计算公式如下:
Figure SMS_24
使用F∈RC×H×W表示卷积神经网络模型的最后一个卷积层输出的具有C个信道和空间大小为H×W的特征图集,fi是特征图集的第i个特征图,A为所有特征图每个通道对应位置相加得到的通道为1、大小为H×W的注意力图。在注意力图A中,像素值较高的区域作为关键部分所在区域,通过计算目标像素值大小的阈值来判断是否为目标区域,阈值
Figure SMS_25
的计算公式为:
Figure SMS_26
使用上式定位生成H×W的目标位置掩膜,掩膜计算公式如下:
Figure SMS_27
受到图像背景和噪音的影响,在掩模图中可能会存在多个连通面积,使用最大连通的最小边界框作为目标的定位框,并将目标图像放缩至合适大小。
具体的,目标局部区域图像的获取过程如图4所示,局部区域图像基于目标图像获得,首先,通过目标图像的特征图,计算目标图像的注意力图A2,注意力图A2计算公式如下:
Figure SMS_28
该式中f为目标图像生成的特征图。使用不同大小的滑动窗口在注意力图A2上滑动,并计算出窗口在滑动过程中每个位置窗口的注意力平均值,窗口注意力平均值计算公式如下:
Figure SMS_29
式中,Hw和WW为窗口的高度和宽度,Aw为注意力图中窗口区域。其中,
Figure SMS_30
的大小与区域的信息量成正比,/>
Figure SMS_31
越大,代表这部分区域的信息量越大。把窗口对应至目标图像中,裁剪出局部区域图像,实现了数据增强。在窗口选取过程中,为了避免同一区域的多次选择,通常排除与已选区域交并比过大的窗口。裁剪出窗口区域作为局部区域图像,并且在输入模型之前放缩至合适大小。使用目标图像训练模型,提高了模型对目标的识别能力,能消除部分背景和噪音的预测的影响;使用局部区域图像训练模型,实现了数据增强,提高了模型对目标精细特征的提取能力,提升了鸟类细粒度分类模型的分类效果。
具体的,教师模型的训练阶段将原始图像、经过数据增强获得的目标图像和局部区域图像同时进行训练。
进一步的,首先使用原始图像对模型进行训练,原始图像经过特征提取器得到原始图像特征图,原始图像特征图经过全局池化输入全连接层(fc),使用交叉熵损失函数计算原始图像的预测损失Lraw,并基于原始图像特征图,实现裁剪目标图像特征图以及目标图像预测损失Lobject,通过裁剪局部区域图像数据增强方法,以滑动窗口的方式获取多个局部区域图像,进行放缩后输入模型,计算局部图像预测损失Lparts。各损失函数计算公式如下:
Lraw=-log(Pr(c))
Lobject=-log(Po(c))
Figure SMS_32
其中,c是图像的真实标签,Pr是原始图像输出类别的概率,Po是目标图像输出类别概率,Pp(n)是局部区域图像输出类别概率,其中n是局部区域图像的数量。总的损失值为上述三类损失至之和,计算公式如下:
Ltotal=Lraw+Lobject+Lparts
使用上式计算的总损失反向传播对鸟类细粒度分类模型参数进行优化。使用原始图像、目标图像和局部区域图像共同训练模型,提升模型对目标区别性区域的识别能力,实现模型细粒度分类效果。在测试阶段删除裁剪局部区域图像数据增强过程和局部区域图像预测过程,使用目标图像预测作为最终的输出。
步骤3学生模型的训练结构如图5所示,将步骤1中获得的数据集和步骤2教师模型输出的目标图像和局部区域图像输入学生模型,所述学生模型采用轻量级的卷积网络ShuffleNetV2作为特征提取器,基于解耦知识蒸馏,实现鸟类细粒度分类模型压缩,实现教师模型和学生模型同时数据增强的方法。
具体的,学生模型的训练数据包含原始图像、基于教师模型获取的目标图像和局部区域图像、基于学生模型获取的目标图像和局部区域图像5类图像,每张图像的识别过程都受到教师模型的指导。
进一步的,将图像分别输入教师模型和学生模型,对学生模型的预测输出分别使用交叉熵损失计算公式计算预测损失值Lhard和KL散度函数计算学生与教师模型输出的差距值Lsoft,计算公式如下:
Lhard=-log(P(c))
Figure SMS_33
其中,c是图像的真实标签,P是模型输出类别概率,T和S分别是教师和学生,B为目标类的二元输出概率,
Figure SMS_34
为非目标类的多类别输出概率,α为NCKD新的权重。总损失计算公式如下:
Ltotal=Lhard+Lsoft
使用上式计算得到的总损失反向传播对学生模型进行优化。
具体的,如图6所示,构建解耦知识蒸馏方法,通过对目标类和非目标类分别使用了logit蒸馏,实现知识蒸馏的解耦,并提出解耦知识蒸馏(DKD)的损失函数,公式如下:
DKD=TCKD+αNCKD
其中,TCKD和NCKD分别是目标类知识蒸馏和非目标类知识蒸馏,α为本方法赋予NCKD的权重值。
具体的,首先通过softmax公式计算第i个类别的预测概率pi,公式如下:
Figure SMS_35
其中,T为超温度参数,模型输出记为Z=[z1,z2,...,zt,...,zc]∈R1×C,其中zi是第i类输出值,C是任务分类个数,记模型输出为P=[p1,p2,...,pt,...,pc]∈R1×C
进一步的,引入超温度参数T能够显示更多非目标类和目标类之间的相似知识,使得老师模型和学生模型的输出更加平滑,指导学生模型达到更高精度。一般的,使用KL散度作为知识蒸馏(KD)的损失函数,计算公式如下:
Figure SMS_36
其中,S和T分表表示老师和学生。通过softmax方式计算出目标类(pt)和非目标类(p\t)的预测概率,如下所示:
Figure SMS_37
Figure SMS_38
记B=[pt,p\t]∈R1×2表示模型目标类和非目标类预测概率。目标类知识蒸馏(TCKD)的定义如下:
TCKD=KL(BT‖BS)
进一步的,对非目标类的预测概率进行计算,记为
Figure SMS_39
其中,/>
Figure SMS_40
表示第i个非目标类预测概率,计算方法如下:
Figure SMS_41
具体的,NCKD的定义为:
Figure SMS_42
进一步的,对KL损失函数进行拆解,首先把目标类分类概率从叠加运算中抽离出来:
Figure SMS_43
知识蒸馏损失可写为:
Figure SMS_44
/>
根据上式可知,NCKD的权重与
Figure SMS_45
相耦合,在模型预测的过程中目标概率出现接近1的情况下,限制了非目标类知识传递。因此,本方法给NCKD赋予一个新的权重,定义为解耦知识蒸馏(DKD),DKD的损失函数定义如下:
DKD=TCKD+αNCKDDKD通过优化非目标类知识蒸馏权重,消除目标类预测概率对非目标类知识传递的抑制。在鸟类分类任务中,DKD显著提升了知识蒸馏效果。
步骤4基于目标定位再识别的思想,预测阶段将目标图像输入最终的轻量级分类模型,获得最终鸟类细粒度分类结果。
具体的,首先基于注意力图定位得到目标图像,在预测阶段,将目标图像输入最终的轻量级分类模型,得到分类结果。
实施例二
一种基于注意力和解耦知识蒸馏的鸟类细粒度分类系统,包括:
数据处理模块:用于对已有鸟类图像数据集进行目标和关键区域的定位,实现数据增强,对数据进行预处理;
模型训练模块:用于对处理完毕的数据集在特定条件下,DesNet121作为教师模型特征提取器,ShuffleNetV2作为学生模型的特征提取器;
知识蒸馏模块:用于提出的解耦的知识蒸馏损失函数,调整参数权重,训练获得轻量级的学生模型;
目标检测模块:用于利用最终训练好的学生模型,基于目标再定位的方法对鸟类数据进行最终的细粒度分类;
控制处理模块:用于向其他模块发出指令,按序完成分类步骤。
进一步的,数据处理模块包括:
图像获取模块,用于通过预设方式获取鸟类图像,建立鸟类数据集;
图像增强模块,用于将图像进行特征提取后进行通道叠加得到注意力图,根据注意力图中的注意力值的分布完成目标区域和重要局部区域的裁剪,实现数据增强;
进一步的,模型训练模块包括:
教师模型训练模块,通过输入原始图像、目标图像和局部区域图像对DesNet121特征提取模型进行训练;
学生模型训练模块,通过输入原始图像即教师模型和学生模型得到的目标图像和局部区域图像对轻量级卷积网络ShuffleNetV2进行训练;
进一步的,知识蒸馏模块包括:
预测损失模块,通过对目标类预测结果和非目标预测结果分别计算KL损失值;
NCKD权重控制模块,通过对NCKD赋予合适的权重,消除目标类预测概率对非目标类知识传递的抑制。
本系统使用DesNet121和ShuffleNetV2特征提取模型分别对数据集进行训练,通过动态调整NCKD的权重,实现解耦知识蒸馏方法,提升学生模型对鸟类精细特征的学习效果,获得高准确率、低成本的鸟类分类系统。
本文中所描述的具体实施例仅仅是对本发明精神作举例说明。本发明所属技术领域的技术人员可以对所描述的具体实施例做各类修改或补充或采用类似的方式替代,但并不会偏离本发明的精神或超越所附权利要求书定义的范围。

Claims (7)

1.一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法,其特征在于,包括以下步骤:
(1)根据预设方式获取鸟类数据集;
(2)将步骤(1)中获取的数据集输入教师模型,使用DenseNet121深度卷积网络作为特征提取器,基于注意力引导实现数据增强,获得目标图像和局部区域图像,输入教师模型并将损失值最小的教师模型作为训练好的教师模型;
(3)将步骤(1)中获得的数据集和步骤(2)教师模型输出的目标图像和局部区域图像输入学生模型,所述学生模型采用轻量级的卷积网络ShuffleNetV2作为特征提取器,基于解耦知识蒸馏,实现鸟类细粒度分类模型压缩,实现教师模型和学生模型同时数据增强的方法;
(4)基于目标定位再识别的思想,预测阶段将目标图像输入最终的轻量级分类模型,获得最终鸟类细粒度分类结果。
2.根据权利要求1所述的一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法,其特征在于,所述步骤(2)中注意力引导的数据增强方法,具体包括:
(2.1.1)通过提取原图特征图中的目标位置信息生成注意力图A,注意力图A的计算公式为:
Figure QLYQS_1
(2.1.2)将注意力图A中像素值较高的区域作为目标区域,通过计算计算目标像素值大小的阈值
Figure QLYQS_2
阈值/>
Figure QLYQS_3
的计算公式为:
Figure QLYQS_4
(2.1.3)使用阈值
Figure QLYQS_5
判断注意力图A中每个像素点是否为目标的一部分,并生成H×W的目标位置掩膜,掩膜计算公式如下:
Figure QLYQS_6
在掩膜图中使用最大连通面积的最小边界作为目标图像;
(2.1.4)通过目标图像的特征图,计算目标图像的注意力图A2,注意力图A2计算公式如下:
Figure QLYQS_7
该式中f为目标图像生成的特征图;
(2.1.5)使用滑动窗口的方式在注意力图A2框选目标重要区域,计算窗口中所有像素低点的平均值,窗口注意力平均值计算公式如下:
Figure QLYQS_8
式中,Hw和Ww为窗口的高度和宽度,Aw为注意力图中窗口区域。其中,
Figure QLYQS_9
的大小与区域的信息量成正比,/>
Figure QLYQS_10
越大,代表这部分区域的信息量越大。把窗口对应至目标图像中,裁剪出局部区域图像,实现了数据增强。
3.根据权利要求1所述的一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法,其特征在于,所述步骤(2)中教师模型的训练具体包括:
(2.2.1)使用原始图像经过DenseNet121特征提取,得到原始图像特征图,原始图像特征图经过全局池化输入全连接层(fc),计算原始图像预测损失Lraw
(2.2.2)基于原始图像特征图实现数据增强得到目标图像,并将目标图像放缩至合适大小输入教师模型,得到目标图像特征图和目标图像预测损失Lobject
(2.2.3)基于目标图像进行裁剪并通过滑动窗口的方式获得多个局部区域图像,并输入模型,计算局部区域图像预测损失Lparts
以上预测损失计算公式如下:
Lraw=-log(Pr(c))
Lobject=-log(Po(c))
Figure QLYQS_11
其中,c是图像的真实标签,Pr是原始图像输出类别概率,Po是目标图像输出类别概率,Pp(n)是局部区域图像输出类别概率,其中n是局部区域图像的数量。总的损失值为三个损失值相加,总损失计算方法为:
Ltotal=Lraw+Lobject+Lparts
4.根据权利要求1所述的一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法,其特征在于,所述步骤(3)中学生模型的训练中,将图像分别输入教师模型和学生模型,对学生模型的预测输出分别使用交叉熵损失计算预测损失值Lhard和解耦知识蒸馏函数(DKD)计算学生与教师模型输出的差距值Lsoft,计算公式如下:
Lhard=-log(P(c))
Lsoft=DKD(PT,PS)
其中,c是图像的真是标签,P是模型输出i类别概率,T和S分别表示教师与学生,总损失计算公式如下:
Ltotal=Lhard+Lsoft
5.根据权利要求1所述的一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法,其特征在于,所述步骤(3)中的解耦知识蒸馏方法,具体包括:
(3.1)采用解耦知识蒸馏方法,获取教师模型输出的鸟类细粒度属于某一类别的概率;具体包括:引入超参数温度T之后通过softmax计算第i个类别的概率pi,softmax计算公式如下:
Figure QLYQS_12
T是超参数温度,模型输出记为Z=[z1,z2,...,zt,...,zc]∈R1×C,其中zi是第i类输出值,C是任务分类个数,pi表示第i个类别预测概率,记模型输出为P=[p1,p2,...,pt,...,pc]∈R1×C
(3.2)使用softmax公式计算目标类(pt)和所有其他非目标类(p\t)的预测概率,公式如下:
Figure QLYQS_13
Figure QLYQS_14
记B=[pt,p\t]∈R1×2表示模型目标类和非目标类预测概率;目标类知识蒸馏TCKD的定义为:
TCKD=KL(BT‖BS)
其中,S和T分别代表老师和学生;记
Figure QLYQS_15
为第i个非目标类预测概率,公式入下:
Figure QLYQS_16
NCKD的定义为:
Figure QLYQS_17
(3.3)对KL损失函数进行拆解,首先把目标类分类概率从叠加运算中抽离出来:
Figure QLYQS_18
知识蒸馏损失可写为:
Figure QLYQS_19
根据上式可知,NCKD的权重与
Figure QLYQS_20
相耦合,限制了非目标类知识传递。为改善上述情况,本方法给NCKD赋予一个新的权重,定义为解耦知识蒸馏(DKD),DKD的损失函数定义如下:
DKD=TCKD+αNCKD
DKD通过优化非目标类知识蒸馏权重,消除目标类预测概率对非目标类知识传递的抑制。
6.根据根据权利要求1所述的一种基于注意力和解耦知识蒸馏的鸟类细粒度分类方法,其特征在于,所述步骤(4)中的教师模型和学生模型同时数据增强的方法,具体包括:
在学生模型的训练阶段同时使用原始图像、基于教师模型获取的目标图像和局部区域图像、基于学生模型获取的目标图像和局部区域图像5类图像作为数据增强后的数据训练学生模型,每张图像都会接受教师模型的指导。
7.基于注意力和解耦知识蒸馏的鸟类细粒度分类系统,其特征在于,采用如权利要求1~5中任一项所述的方法进行鸟类细粒度分类的工作,包括:
数据处理模块:用于对已有鸟类图像数据集进行目标和关键区域的定位,实现数据增强,对数据进行预处理;
模型训练模块:用于对处理完毕的数据集在特定条件下,DesNet121作为教师模型特征提取器,ShuffleNetV2作为学生模型的特征提取器;
知识蒸馏模块:用于提出的解耦知识蒸馏损失函数,调整参数权重,训练获得轻量级的学生模型;
目标检测模块:用于利用最终训练好的学生模型,基于目标再定位的方法对鸟类数据进行最终的细粒度分类;
控制处理模块:用于向其他模块发出指令,按序完成分类步骤。
CN202211534488.7A 2022-12-02 2022-12-02 基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统 Pending CN115880529A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211534488.7A CN115880529A (zh) 2022-12-02 2022-12-02 基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211534488.7A CN115880529A (zh) 2022-12-02 2022-12-02 基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统

Publications (1)

Publication Number Publication Date
CN115880529A true CN115880529A (zh) 2023-03-31

Family

ID=85765462

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211534488.7A Pending CN115880529A (zh) 2022-12-02 2022-12-02 基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统

Country Status (1)

Country Link
CN (1) CN115880529A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116883745A (zh) * 2023-07-13 2023-10-13 南京恩博科技有限公司 一种基于深度学习的动物定位模型及方法
CN117036698A (zh) * 2023-07-27 2023-11-10 中国矿业大学 一种基于双重特征知识蒸馏的语义分割方法

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116883745A (zh) * 2023-07-13 2023-10-13 南京恩博科技有限公司 一种基于深度学习的动物定位模型及方法
CN116883745B (zh) * 2023-07-13 2024-02-27 南京恩博科技有限公司 一种基于深度学习的动物定位模型及方法
CN117036698A (zh) * 2023-07-27 2023-11-10 中国矿业大学 一种基于双重特征知识蒸馏的语义分割方法
CN117036698B (zh) * 2023-07-27 2024-06-18 中国矿业大学 一种基于双重特征知识蒸馏的语义分割方法

Similar Documents

Publication Publication Date Title
Dong et al. PGA-Net: Pyramid feature fusion and global context attention network for automated surface defect detection
CN108961235B (zh) 一种基于YOLOv3网络和粒子滤波算法的缺陷绝缘子识别方法
CN110532900B (zh) 基于U-Net和LS-CNN的人脸表情识别方法
CN107563372B (zh) 一种基于深度学习ssd框架的车牌定位方法
CN104992223B (zh) 基于深度学习的密集人数估计方法
CN112241762B (zh) 一种用于病虫害图像分类的细粒度识别方法
CN109165623B (zh) 基于深度学习的水稻病斑检测方法及系统
CN115880529A (zh) 基于注意力和解耦知识蒸馏的鸟类细粒度分类方法及系统
EP3690741A2 (en) Method for automatically evaluating labeling reliability of training images for use in deep learning network to analyze images, and reliability-evaluating device using the same
CN112464911A (zh) 基于改进YOLOv3-tiny的交通标志检测与识别方法
CN110569843B (zh) 一种矿井目标智能检测与识别方法
CN103049763A (zh) 一种基于上下文约束的目标识别方法
CN112084930A (zh) 一种全视野数字病理切片的病灶区域分类方法及其系统
CN112861970B (zh) 一种基于特征融合的细粒度图像分类方法
Huang et al. Qualitynet: Segmentation quality evaluation with deep convolutional networks
CN112528058B (zh) 基于图像属性主动学习的细粒度图像分类方法
CN110334584A (zh) 一种基于区域全卷积网络的手势识别方法
CN113205026A (zh) 一种基于Faster RCNN深度学习网络改进的车型识别方法
CN115410258A (zh) 基于注意力图像的人脸表情识别方法
CN114972759A (zh) 基于分级轮廓代价函数的遥感图像语义分割方法
CN116665148A (zh) 基于合成孔径雷达数据的海上船舶检测方法
CN114743109A (zh) 多模型协同优化高分遥感图像半监督变化检测方法及系统
CN114626476A (zh) 基于Transformer与部件特征融合的鸟类细粒度图像识别方法及装置
CN115719475A (zh) 一种基于深度学习的三阶段轨旁设备故障自动检测方法
CN111242028A (zh) 基于U-Net的遥感图像地物分割方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination