CN114155436B - 长尾分布的遥感图像目标识别逐步蒸馏学习方法 - Google Patents

长尾分布的遥感图像目标识别逐步蒸馏学习方法 Download PDF

Info

Publication number
CN114155436B
CN114155436B CN202111471933.5A CN202111471933A CN114155436B CN 114155436 B CN114155436 B CN 114155436B CN 202111471933 A CN202111471933 A CN 202111471933A CN 114155436 B CN114155436 B CN 114155436B
Authority
CN
China
Prior art keywords
model
teacher
training
teacher model
head
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202111471933.5A
Other languages
English (en)
Other versions
CN114155436A (zh
Inventor
赵文达
刘佳妮
刘瑜
卢湖川
何友
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Dalian University of Technology
Original Assignee
Dalian University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Dalian University of Technology filed Critical Dalian University of Technology
Priority to CN202111471933.5A priority Critical patent/CN114155436B/zh
Publication of CN114155436A publication Critical patent/CN114155436A/zh
Application granted granted Critical
Publication of CN114155436B publication Critical patent/CN114155436B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明属于图像信息处理技术领域,提出了一种逐步蒸馏学习的长尾分布的遥感图像目标识别方法,具体为一种利用头尾数据之间的联系,并结合知识蒸馏完成遥感图像分类的方法。我们使用结构相同的三个教师模型与一个学生模型。提出了渐进式教师模型的学习以及自校正采样算法,在学生模型训练过程中可以很好的解决长尾问题,使最终的分类准确度得到提升。本发明利用蒸馏的方法以及提出的渐进式教师学习和自校正采样学习算法,增强了网络特征提取能力,目前存在的解决长尾问题的各种方法仍然存在各种弊端,比如不能充分利用头部数据的优势、对超参数敏感等等,本发明的逐步蒸馏学习方法方法有效的解决了这些问题,本发明方法能够提升分类网络的准确度。

Description

长尾分布的遥感图像目标识别逐步蒸馏学习方法
技术领域
本发明属于图像信息处理技术领域,特别是涉及遥感图像目标识别的方法。
背景技术
目前,与本专利相关的方法包括两方面:第一是基于深度学习的长尾分布图像目标识别算法;第二是基于特征表示的蒸馏学习算法。
基于深度学习的长尾分布图像目标识别算法主要分为三类:一类是对长尾分布数据进行重采样的方法,在训练集上实现样本平衡,包括对头部样本欠采样以及对尾部样本过采样。Ren等人在文献《Ensemble based adaptive over-sampling method forimbalanced data learning in computer aided detection of microaneurysm》中提出了一种基于集成的自适应过采样算法,减少了不平衡数据引入的归纳偏差,克服了假阳性减少中的类不平衡问题。一类是给不同类别的损失设置不同权重的方法,通常会对损失函数中的尾类分配较大的权重,对头类的权重相对较小,使损失函数更关注尾类,加强尾类的优化。Cui等人在文献《Class-balanced loss based on effective number of samples》中引入了一种新的理论框架,通过与每个样本的一个小的邻近区域关联来测量数据重叠,并设计了一个重新加权方案,使用每个类的有效样本数来重新平衡损失。第三类是最近提出的多专家网络,训练一个多专家网络,然后设计不同的方法来结合不同专家网络的学习结果。Wang等人在文献《Long-tailed recognition by routing diverse distribution-aware experts》中提出了一种新的共享早期层和独立的通道减少的后期层的多专家模型,通过分布感知多样性损失减少了模型偏差,通过动态专家路由模型降低了计算成本。
基于特征表示的蒸馏学习算法的研究也有多种方式,例如,He等人在文献《Distilling virtual examples for longtailed recognition》中从知识精馏的角度解决了长尾视觉识别问题,提出了一种虚拟实例的提取方法。Ju等人在文献《Relationalsubsets knowledge distillation for long-tailed retinal diseases recognition》中提出了根据先验知识将长尾数据划分为多个类子集并分别进行学习,强制模型集中学习特定于子集的知识。Zhang等人在文献《Balanced knowledge distillation for long-tailed learning》中通过最小化实例平衡分类损失和类平衡蒸馏损失的组合来训练学生模型,解决了修改分类损失以增加对尾类的学习重点但却牺牲了头类的性能的问题。
现实生活中的数据通常呈现极端不平衡现象,使得真实数据的分布通常呈现出“长尾”分布的形态。另外,针对于遥感图像的目标识别问题,同样可以使用自然域的图像的目标识别方法来完成,但是遥感图像和自然图像在域上具有较大的差别,直接使用准确度必然下降。因此由于遥感数据集呈现极端的长尾分布状态,导致使用传统分类器对遥感图像分类时性能恶劣。我们的方法将整个过程分为两个阶段,第一阶段将呈长尾分布的数据集根据每个类别的数量划分成三个子集分别训练三个教师模型,由于大量数据训练的模型具有良好的特征提取能力,为了充分利用这一优势,提出教师模型之间渐进式学习,第二个阶段可以利用已经训练好的教师模型辅助学生模型进行学习,这个过程中我们还提出自校正采样学习方法,有效的针对每个训练迭代过程中学生模型的学习结果动态更新采样权重,增加学生模型的识别准确度。
发明内容
针对提升遥感网络识别准确度的问题,提出了一种利用头尾数据之间的联系,并结合知识蒸馏完成遥感图像分类的方法。我们使用结构相同的三个教师模型与一个学生模型。提出了渐进式教师模型的学习以及自校正采样算法,在学生模型训练过程中可以很好的解决长尾问题,使最终的分类准确度得到提升。
本发明的技术方案:
一种逐步蒸馏学习的长尾分布的遥感图像目标识别方法,步骤如下:
整个训练过程主要分为教师模型的训练阶段和学生模型的训练阶段;
(1)教师模型的训练阶段
构建一个基础的分类网络用于教师模型的训练:Resnet50的前四个模块作为网络的特征提取主干的卷积模型,Resnet50的第四个模块输出的特征作为辅助其他教师模型进行训练的特征表示;
首先将呈长尾分布的训练集划分成三个子集,分别为头部子集、中部子集和尾部子集;教师模型包括头部教师模型、中部教师模型和尾部教师模型;将头部子集输入到对应的头部教师模型中,训练出一个具有良好特征提取能力的头部教师模型,然后将中部子集输入到对应的中部教师模型中,在进行训练时,该中部子集还将输入到头部教师模型中,此时头部教师模型冻结,将头部教师模型第四个模型输出的特征与中部教师模型对应位置对应图片的特征进行比较,对比的MSE损失为:
其中,下角标M表示中部教师模型,FM'和FM分别为头部和中部教师模型的特征,n为batch size;
同样,训练尾部子集对应的尾部教师模型时,利用头部教师模型和中部教师模型进行辅助训练,对比的MSE损失为:
其中,下角标T表示尾部教师模型,F″T、FT'和FT分别为中部、头部和尾部教师模型的特征;
通过式(1)和(2)的约束可训练得到三个特征提取能力以及分类能力都比较好的教师模型TH、TM和TT
(2)学生模型的训练阶段
在第二个阶段进行学生模型的训练时,将TH、TM和TT学到的知识蒸馏到学生模型中,学生模型的网络结构与教师模型的网络结构完全相同;由于学生模型在学习开始时对任何类都具有相同的特征提取能力,因此在每次训练开始时采用均匀采样,然后再使用提出的自校正采样学习;具体来说,教师模型第四个模块输出的特征与学生模型对应位置对应图片的特征进行比较,该MSE损失为:
其中,F'和F分别为教师模型和学生模型的特征;然后,根据损失LMSE来评估学生模型的学习质量,设计了一个公式来根据LMSE得到每个类的权重w:
w=α×log(LMSE+1) (4)
其中,α为超参数;最后将获得的权重w应用到采样器中,学生模型对应某一类学习质量越好,损失越小,w越小,所以该类在下一个batch中的采样概率越小,反之亦然;
另外,三个学生模型和教师模型的分类损失函数都是相同的,每个模型的分类损失为:
其中,c为数据集的类别,fi为模型分类的概率,yi为真值;
最终教师模型学习到的知识全部蒸馏到学生模型中,并且利用自校正采样算法让学生模型对于某个学习效果不好的类进行再次学习。
本发明的有益效果:本发明的长尾遥感图像目标识别的逐步蒸馏学习方法,利用了蒸馏的方法以及提出的渐进式教师学习和自校正采样学习算法,增强了网络特征提取能力,目前存在的解决长尾问题的各种方法仍然存在各种弊端,比如:不能充分利用头部数据的优势、对超参数敏感等等,本发明的逐步蒸馏学习方法方法有效的解决了这些问题,本发明方法能够提升分类网络的准确度。
附图说明
图1为网络整体训练流程图。
图2为有监督阶段结构示意图。
图3为半监督阶段结构示意图。
具体实施方式
以下结合附图和技术方案,进一步说明本发明的具体实施方式。
图1为网络整体训练流程图,第一步,通过划分好的子集训练得到三个具有较好特征提取能力的教师模型TH、TM和TT,第二步,将第一步训练好的三个教师模型学习到的知识,蒸馏到学生模型S中,同时使用自校正采样实现对模型S的训练。图2的教师模型训练的具体流程包括,首先用划分好的子集对教师模型进行训练,将图片输入到卷积层中实现特征提取,图中的长方体为提取到的特征,然后再对分类器进行训练,由于数据量大所以训练得到的模型的特征提取能力越好,为了充分利用这一优势,对教师模型进行渐进式训练,即利用已经训练好的教师模型进行辅助训练,图中用虚线表示模型之间的辅助训练,具体来说,首先用传统方法对头部教师模型进行训练,得到特征提取能力很好的头部教师模型;然后,中部子集不仅输入到对应的中部教师模型中,还会输入到训练好的头部教师模型中,此时,头部教师模型参数固定,使用公式(1)、(2)对比对应位置对应图片的特征表示,获得特征提取能力较好的中部教师模型;尾部教师模型类似,将尾部子集输入到头部教师模型和中部教师模型中,固定头部教师模型和中部教师模型的参数,对比对应特征。利用这种方法可以提高数据量较少的教师模型的特征提取能力。图3主要是学生模型训练的展示,输入为完整的长尾分布的训练集,固定三个教师模型的参数,将三个教师模型学习到的知识通过特征比较蒸馏到学生模型中,根据特征对比,得到关于每个类的特征对比损失,由此对学生模型的学习效果进行评估,损失越大说明学习效果越差,再次对该类采样的概率也就越大,反之亦然,通过不断对学生模型学习效果的评估,逐渐改善模型学习能力,这样整个学生模型的分类能力也会得到显著提升。

Claims (1)

1.一种逐步蒸馏学习的长尾分布的遥感图像目标识别方法,其特征在于,步骤如下:
整个训练过程主要分为教师模型的训练阶段和学生模型的训练阶段;
(1)教师模型的训练阶段
构建一个基础的分类网络用于教师模型的训练:Resnet50的前四个模块作为网络的特征提取主干的卷积模型,Resnet50的第四个模块输出的特征作为辅助其他教师模型进行训练的特征表示;
首先将呈长尾分布的训练集划分成三个子集,分别为头部子集、中部子集和尾部子集;教师模型包括头部教师模型、中部教师模型和尾部教师模型;将头部子集输入到对应的头部教师模型中,训练出一个具有良好特征提取能力的头部教师模型,然后将中部子集输入到对应的中部教师模型中,在进行训练时,该中部子集还将输入到头部教师模型中,此时头部教师模型冻结,将头部教师模型第四个模型输出的特征与中部教师模型对应位置对应图片的特征进行比较,对比的MSE损失为:
其中,下角标M表示中部教师模型,FM'和FM分别为头部和中部教师模型的特征,n为batch size;
同样,训练尾部子集对应的尾部教师模型时,利用头部教师模型和中部教师模型进行辅助训练,对比的MSE损失为:
其中,下角标T表示尾部教师模型,F″T、FT'和FT分别为中部、头部和尾部教师模型的特征;
通过式(1)和(2)的约束可训练得到三个特征提取能力以及分类能力都比较好的教师模型TH、TM和TT
(2)学生模型的训练阶段
在第二个阶段进行学生模型的训练时,将TH、TM和TT学到的知识蒸馏到学生模型中,学生模型的网络结构与教师模型的网络结构完全相同;由于学生模型在学习开始时对任何类都具有相同的特征提取能力,因此在每次训练开始时采用均匀采样,然后再使用提出的自校正采样学习;具体来说,教师模型第四个模块输出的特征与学生模型对应位置对应图片的特征进行比较,该MSE损失为:
其中,F'和F分别为教师模型和学生模型的特征;然后,根据损失LMSE来评估学生模型的学习质量,设计了一个公式来根据LMSE得到每个类的权重w:
w=α×log(LMSE+1) (4)
其中,α为超参数;最后将获得的权重w应用到采样器中,学生模型对应某一类学习质量越好,损失越小,w越小,所以该类在下一个batch中的采样概率越小,反之亦然;
另外,三个学生模型和教师模型的分类损失函数都是相同的,每个模型的分类损失为:
其中,c为数据集的类别,fi为模型分类的概率,yi为真值;
最终教师模型学习到的知识全部蒸馏到学生模型中,并且利用自校正采样算法让学生模型对于某个学习效果不好的类进行再次学习。
CN202111471933.5A 2021-12-06 2021-12-06 长尾分布的遥感图像目标识别逐步蒸馏学习方法 Active CN114155436B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111471933.5A CN114155436B (zh) 2021-12-06 2021-12-06 长尾分布的遥感图像目标识别逐步蒸馏学习方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111471933.5A CN114155436B (zh) 2021-12-06 2021-12-06 长尾分布的遥感图像目标识别逐步蒸馏学习方法

Publications (2)

Publication Number Publication Date
CN114155436A CN114155436A (zh) 2022-03-08
CN114155436B true CN114155436B (zh) 2024-05-24

Family

ID=80452731

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111471933.5A Active CN114155436B (zh) 2021-12-06 2021-12-06 长尾分布的遥感图像目标识别逐步蒸馏学习方法

Country Status (1)

Country Link
CN (1) CN114155436B (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114511887B (zh) * 2022-03-31 2022-07-05 北京字节跳动网络技术有限公司 组织图像的识别方法、装置、可读介质和电子设备
CN115019123B (zh) * 2022-05-20 2023-04-18 中南大学 一种遥感图像场景分类的自蒸馏对比学习方法
CN115272881B (zh) * 2022-08-02 2023-03-21 大连理工大学 动态关系蒸馏的长尾遥感图像目标识别方法
CN116758391B (zh) * 2023-04-21 2023-11-21 大连理工大学 噪声抑制蒸馏的多域遥感目标泛化性识别方法

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
KR20200121206A (ko) * 2019-04-15 2020-10-23 계명대학교 산학협력단 심층 네트워크와 랜덤 포레스트가 결합된 앙상블 분류기의 경량화를 위한 교사-학생 프레임워크 및 이를 기반으로 하는 분류 방법
CN112199535A (zh) * 2020-09-30 2021-01-08 浙江大学 一种基于集成知识蒸馏的图像分类方法
CN112529178A (zh) * 2020-12-09 2021-03-19 中国科学院国家空间科学中心 一种适用于无预选框检测模型的知识蒸馏方法及系统
CN112766087A (zh) * 2021-01-04 2021-05-07 武汉大学 一种基于知识蒸馏的光学遥感图像舰船检测方法
CN113255822A (zh) * 2021-06-15 2021-08-13 中国人民解放军国防科技大学 一种用于图像检索的双重知识蒸馏方法

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11966670B2 (en) * 2018-09-06 2024-04-23 Terrafuse, Inc. Method and system for predicting wildfire hazard and spread at multiple time scales
US11720727B2 (en) * 2018-09-06 2023-08-08 Terrafuse, Inc. Method and system for increasing the resolution of physical gridded data

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
KR20200121206A (ko) * 2019-04-15 2020-10-23 계명대학교 산학협력단 심층 네트워크와 랜덤 포레스트가 결합된 앙상블 분류기의 경량화를 위한 교사-학생 프레임워크 및 이를 기반으로 하는 분류 방법
CN112199535A (zh) * 2020-09-30 2021-01-08 浙江大学 一种基于集成知识蒸馏的图像分类方法
CN112529178A (zh) * 2020-12-09 2021-03-19 中国科学院国家空间科学中心 一种适用于无预选框检测模型的知识蒸馏方法及系统
CN112766087A (zh) * 2021-01-04 2021-05-07 武汉大学 一种基于知识蒸馏的光学遥感图像舰船检测方法
CN113255822A (zh) * 2021-06-15 2021-08-13 中国人民解放军国防科技大学 一种用于图像检索的双重知识蒸馏方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
基于可见光遥感图像的船只目标检测识别方法;陈亮;王志茹;韩仲;王冠群;周浩天;师皓;胡程;龙腾;;科技导报;20171028(第20期);全文 *
基于对抗学习和知识蒸馏的神经网络压缩算法;刘金金;李清宝;李晓楠;计算机工程与应用;20210618;第57卷(第021期);全文 *

Also Published As

Publication number Publication date
CN114155436A (zh) 2022-03-08

Similar Documents

Publication Publication Date Title
CN114155436B (zh) 长尾分布的遥感图像目标识别逐步蒸馏学习方法
CN107392919B (zh) 基于自适应遗传算法的灰度阈值获取方法、图像分割方法
CN103729678B (zh) 一种基于改进dbn模型的水军检测方法及系统
CN109117793B (zh) 基于深度迁移学习的直推式雷达高分辨距离像识别方法
CN114841257B (zh) 一种基于自监督对比约束下的小样本目标检测方法
CN112633406A (zh) 一种基于知识蒸馏的少样本目标检测方法
CN111738303A (zh) 一种基于层次学习的长尾分布图像识别方法
CN112819063B (zh) 一种基于改进的Focal损失函数的图像识别方法
CN115272881B (zh) 动态关系蒸馏的长尾遥感图像目标识别方法
CN112784872B (zh) 一种基于开放集联合迁移学习的跨工况故障诊断方法
CN110598018A (zh) 一种基于协同注意力的草图图像检索方法
CN111191685A (zh) 一种损失函数动态加权的方法
CN112527993A (zh) 一种跨媒体层次化深度视频问答推理框架
CN116318928A (zh) 一种基于数据增强和特征融合的恶意流量识别方法及系统
Meng et al. QoE-based big data analysis with deep learning in pervasive edge environment
CN113901448A (zh) 基于卷积神经网络和轻量级梯度提升机的入侵检测方法
CN110569761B (zh) 一种基于对抗学习的手绘草图检索遥感图像的方法
CN115984213A (zh) 基于深度聚类的工业产品外观缺陷检测方法
CN114612747A (zh) 基于无监督加权哈希的遥感图像检索方法
CN114780879A (zh) 一种用于知识超图的可解释性链接预测方法
CN114169504B (zh) 基于自适应滤波的图卷积神经网络池化方法
CN116756391A (zh) 一种基于图数据增强的不平衡图节点神经网络分类方法
CN116821905A (zh) 一种基于知识搜索的恶意软件检测方法及系统
CN114973350B (zh) 一种源域数据无关的跨域人脸表情识别方法
CN114859317A (zh) 雷达目标自适应反向截断智能识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant