CN115100543A - 面向小样本遥感影像场景分类的自监督自蒸馏元学习方法 - Google Patents

面向小样本遥感影像场景分类的自监督自蒸馏元学习方法 Download PDF

Info

Publication number
CN115100543A
CN115100543A CN202210878368.2A CN202210878368A CN115100543A CN 115100543 A CN115100543 A CN 115100543A CN 202210878368 A CN202210878368 A CN 202210878368A CN 115100543 A CN115100543 A CN 115100543A
Authority
CN
China
Prior art keywords
self
remote sensing
sensing image
classification
learning
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210878368.2A
Other languages
English (en)
Inventor
夏鲁瑞
周浩楠
李森
陈雪旗
卢妍
张占月
王鹏
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Peoples Liberation Army Strategic Support Force Aerospace Engineering University
Original Assignee
Peoples Liberation Army Strategic Support Force Aerospace Engineering University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Peoples Liberation Army Strategic Support Force Aerospace Engineering University filed Critical Peoples Liberation Army Strategic Support Force Aerospace Engineering University
Priority to CN202210878368.2A priority Critical patent/CN115100543A/zh
Publication of CN115100543A publication Critical patent/CN115100543A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements
    • G06V20/10Terrestrial scenes
    • G06V20/13Satellite images
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • G06V10/42Global feature extraction by analysis of the whole pattern, e.g. using frequency domain transformations or autocorrelation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • G06V10/44Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Multimedia (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Computation (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Astronomy & Astrophysics (AREA)
  • Remote Sensing (AREA)
  • Image Analysis (AREA)

Abstract

本发明具体公开了一种面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,所述方法包括以下步骤:S1、基于自监督学习并利用遥感图像数据对预设的深度神经网络进行训练,获取嵌入网络模型;S2、将嵌入网络模型嵌入到元学习框架中进行训练,得到小样本分类嵌入网络模型;S3、基于自蒸馏训练对小样本分类嵌入网络模型进行迭代优化,进而完成面向小样本遥感影像分类的自监督自蒸馏元学习,得到训练好的SMSR模型。本发明通过将自监督学习训练得到具有特征提取能力的嵌入网络模型应用到元学习框架中进行小样本遥感影像分类训练,再经过自蒸馏精简得到训练好的SMSR模型,进而实现面向小样本遥感影像分类的自监督自蒸馏元学习。

Description

面向小样本遥感影像场景分类的自监督自蒸馏元学习方法
技术领域
本发明涉及遥感影像处理技术领域,尤其涉及一种面向小样本遥感影像场景分类的自监督自蒸馏元学习方法。
背景技术
随着越来越多的遥感卫星被发射升空,遥感影像的获取变得越来越简单,对更快速更智能的遥感影像理解的需求也日益增加。随着遥感影像分辨率的提高,遥感影像中地物的细节越来越丰富,也产生了众多新的应用。其中,遥感影像场景分类是将每张遥感影像指定一个标签,这一标签表明了该影像中包含的地物所构成的场景,如商业区、学校或者公园。以往针对不同地物的特点手工设计特征来进行遥感影像理解的方法因为费时费力且准确性不高,已经逐渐被基于机器(深度)学习的方法所代替。
深度学习技术在遥感图像处理领域有着日渐广泛的应用。通过在大规模的有标签遥感影像数据集上进行有监督训练,可以得到适用于该数据集的深度神经网络。遥感影像场景分类是对遥感影像进行智能解译的重要环节,但是在某些特定场景下,获得的遥感影像分类数据往往十分有限且缺少标注,无法反映真实数据分布,用于直接训练深度学习网络,得到的网络的性能会出现明显的下降。
鉴于此,如何解决遥感影像场景分类中存在的小样本问题是本技术领域人员亟待解决的技术问题。
发明内容
本发明的目的是克服现有技术中的缺点和不足,提供一种面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,通过自监督学习训练得到具有特征提取能力的嵌入网络模型,将嵌入网络模型应用到元学习框架中进行小样本遥感影像分类训练,然后使用自蒸馏进行精简,从而得到训练好的SMSR模型,最后利用训练好的SMSR模型即可对实时获取的遥感影像进行分类。
为解决上述技术问题,本发明提供一种面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,所述方法包括以下步骤:
S1、基于自监督学习并利用遥感图像数据对预设的深度神经网络进行训练,进而得到能够提取遥感影像特征信息的嵌入网络模型;
S2、将步骤S1中所获取的嵌入网络模型嵌入到元学习的框架中进行训练,并使元学习训练后的嵌入网络模型能够利用已有先验知识对新样本进行分类,进而得到小样本分类嵌入网络模型;
S3、基于自蒸馏训练对小样本分类嵌入网络模型进行迭代优化,进而完成面向小样本遥感影像分类的自监督自蒸馏元学习,得到训练好的SMSR模型。
优选地,所述步骤S1具体实现方式为:
S11、将对比学习设置为自监督学习的辅助任务,并选取NCE损失为对比学习的衡量指标;
S12、利用不同数据增强方法对同一张遥感图像进行数据增强,并计算在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失;
S13、利用随机梯度下降法不断减少在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失,直至该NCE损失的值趋于稳定,进而得到能够提取遥感影像特征信息的嵌入网络模型。
优选地,所述步骤S12中不同数据增强方法包括缩放、随机裁剪、随机中心裁剪、仿射变换、颜色抖动、随机水平翻转和随机灰度变换。
优选地,所述对比学习通过选取来源于预设的深度神经网络中不同卷积层所生成的特征图进行对比。
优选地,所述步骤S12中在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失用公式表示为:
L(xm,xn)=L(fa(xm),f5(xn))+L(fa(xm),fe(xn))+L(f5(xm),f5(xn))+L(fa(xn),f5(xm))+L(fa(xn),fe(xm))+L(fe(xm),fe(xn)) (1)
式(1)中,xm和xn分别表示同一张遥感影像经过两种不同数据增强方法进行数据增强的处理结果,fa表示预设的深度神经网络提取的全局特征图,f5表示预设的深度神经网络中第五层卷积提取的局部特征图,fe表示预设的深度神经中最后一层卷积提取的网络最后一层输出特征图,其中,
Figure BDA0003763087360000031
Figure BDA0003763087360000032
Figure BDA0003763087360000033
Figure BDA0003763087360000034
Figure BDA0003763087360000035
Figure BDA0003763087360000036
N表示所有遥感影像,d表示欧几里得距离的平方。
优选地,所述步骤S2的具体实现方式为:
S21、将步骤S1所获取的嵌入网络模型嵌入到元学习的框架中并作为待训练嵌入网络模型;
S22、将由所有遥感影像构成的数据集中的样本分为若干个种类,从每一个种类中均随机抽取K个样本作为支持集,其余样本作为查询集,并将支持集中的样本输入待训练嵌入网络模型中进行度量,得到支持集中第i个样本的分类结果wi,用公式表示为:
Figure BDA0003763087360000037
式(2)中,f()表示待训练嵌入网络模型,S表示支持集,yi表示支持集中的第i个样本,zi表示样本yi的对应标签;
S23、将查询集中的样本输入待训练嵌入网络模型得到查询集中样本的分类结果,然后计算查询集中样本的分类结果与支持集中样本的分类结果之间的差异,进而判断出查询集中样本的种类;
S24、利用随机梯度下降法不断降低查询集中样本的分类结果与支持集中样本的分类结果之间的差异,以使元学习训练后的待训练嵌入网络模型获得利用已有先验知识对新样本进行分类的能力,进而得到小样本分类嵌入网络模型。
优选地,所述查询集中样本的分类结果与支持集中样本的分类结果之间的差异用公式表示为:
Lmeta=log∑nexp(-D(f(q),wi)) (3)
式(3)中,q表示查询集中样本,wi表示待训练嵌入网络模型对支持集中第i个样本的分类结果,
Figure BDA0003763087360000041
表示查询集中样本与支持集中样本的分类结果之间的差异。
优选地,所述步骤S3的具体实现方式为:首先计算自蒸馏训练过程中的损失函数,即计算自蒸馏训练中第k代预测值和第k代实际值之间的差异以及第k代预测值与第k-1代预测值之间的差异;然后利用诱导网络降低两种差异值,不断迭代直至两种差异值之和减小到最小值并在最小值附近趋于稳定,从而得到训练好的SMSR模型。
优选地,所述自蒸馏训练第k次迭代训练中损失函数取最小值时的模型参数化用公式表示为:
Figure BDA0003763087360000042
式(4)中,
Figure BDA0003763087360000043
表示自蒸馏训练第k次迭代得到的参数化模型,KL表示KullbackLeibler散度,α表示超参数,Q表示自蒸馏训练的数据集,Lce为交叉熵损失,
Figure BDA0003763087360000044
表示参数化模型,
Figure BDA0003763087360000051
表示参数化模型
Figure BDA0003763087360000052
对数据集Q的分类结果。
优选地,还包括:S4、将实时获取的遥感影像输入训练好的SMSR模型中进行分类,进而实现遥感影像的场景分类。
与现有技术比较,本发明利用自监督学习训练预设的深度神经网络得到具有特征提取能力的嵌入网络模型,即从少量无标签的数据中提取先验知识的能力,然后将该嵌入网络模型应用到元学习框架中进行小样本遥感影像分类训练后得到利用已知先验知识对新样本进行分类的小样本分类嵌入网络模型,再将小样本分类嵌入网络模型经过自蒸馏进行精简优化,从而得到训练好的SMSR模型,最后利用训练好的SMSR模型即可对实时获取的遥感影像进行分类,实现小样本遥感影像的场景分类,基于所述训练好的SMSR模型对遥感影像进行场景分类的分类结果更优于传统的有监督学习遥感影像场景分类方法的分类结果。
附图说明
图1是本发明一种面向小样本遥感影像场景分类的自监督自蒸馏元学习方法的流程图,
图2是本发明实施例中MASATI遥感影像数据集的部分示意图,
图3是本发明中遥感图像基于自监督学习网络输出特征图的流程图。
具体实施方式
为了使本技术领域的人员更好地理解本发明的技术方案,下面结合附图对本发明作进一步的详细说明。
如图1-图3所示,一种面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,所述方法包括以下步骤:
S1、基于自监督学习并利用遥感图像数据对预设的深度神经网络进行训练,进而获取一个能够提取遥感影像特征信息的嵌入网络模型;
S2、将步骤S1中所获取的嵌入网络模型嵌入到元学习的框架中进行训练,并使元学习训练后的嵌入网络模型能够利用已有先验知识对新样本进行分类,进而得到小样本分类嵌入网络模型;
S3、基于自蒸馏训练对小样本分类嵌入网络模型进行迭代优化,进而完成面向小样本遥感影像分类的自监督自蒸馏元学习,得到训练好的SMSR(Self-supervised Metalearning and Self-distillation Remote Sensing Scene Classification,面向小样本遥感影像场景分类的自监督自蒸馏元学习方法)模型;
S4、将实时获取的遥感影像输入训练好的SMSR模型中进行分类,从而实现遥感影像的场景分类。
本实施例中,首先,利用自监督学习训练预设的深度神经网络得到具有特征提取能力的嵌入网络模型;然后,将嵌入网络模型应用到元学习框架中进行小样本遥感影像分类训练后并使用自蒸馏进行精简优化,从而得到训练好的SMSR模型;最后利用训练好的SMSR模型对实时获取的遥感影像进行分类,实现小样本遥感影像场景分类。基于所述训练好的SMSR模型对遥感影像进行场景分类的分类结果更优于传统的有监督学习遥感影像场景分类方法的分类结果。
其中,所述步骤S1具体实现方式为:
S11、将对比学习设置为自监督学习的辅助任务,并选取NCE(Noise ContrastiveEstimation,噪音对比估计)损失为对比学习的衡量指标,NCE损失用于计量目标对象之间的差别程度和互信息的大小,所述对比学习通过选取来源于预设的深度神经网络中不同卷积层所生成的特征图进行对比,进而能够进一步扩大对比对象之间的差别,提升对比学习的复杂程度;
S12、利用不同数据增强方法对同一张遥感图像进行数据增强,并计算在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失,其中,不同数据增强方法包括缩放、随机裁剪、随机中心裁剪、仿射变换、颜色抖动、随机水平翻转和随机灰度变换;
S13、利用随机梯度下降法不断减少在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失,直至该NCE损失的值趋于稳定,进而得到能够提取遥感影像特征信息的嵌入网络模型。
其中,在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失用公式表示为:
L(xm,xn)=L(fa(xm),f5(xn))+L(fa(xm),fe(xn))+L(f5(xm),f5(xn))+L(fa(xn),f5(xm))+L(fa(xn),fe(xm))+L(fe(xm),fe(xn)) (1)
式(1)中,xm和xn分别表示同一张遥感影像经过两种不同数据增强方法进行数据增强的处理结果,fa表示预设的深度神经网络提取的全局特征图,f5表示预设的深度神经网络中第五层卷积提取的局部特征图,fe表示预设的深度神经中最后一层卷积提取的网络最后一层输出特征图,其中,
Figure BDA0003763087360000071
Figure BDA0003763087360000072
Figure BDA0003763087360000073
Figure BDA0003763087360000074
Figure BDA0003763087360000075
Figure BDA0003763087360000076
N表示所有遥感影像,d表示欧几里得距离的平方。
本实施例中,通过将对比学习作为自监督学习的辅助任务,当两个对比目标之间的差异越小,其互信息就越高,通过不断扩大两张遥感影像样本之间的互信息,促使深度神经网络学习到遥感影像样本中的特征信息,同时,为了扩大不同数据增强结果之间的差异,避免模型训练的坍缩,采用不同的数据增强方法对同一张遥感图像进行数据增强,然后计算在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失,最后通过不断减少NCE损失直至其值趋于稳定后,即得到能够提取遥感影像特征信息的嵌入网络模型。
其中,所述步骤S2的具体实现方式为:
S21、将步骤S1所获取的嵌入网络模型嵌入到元学习的框架中并作为待训练嵌入网络模型;
S22、将由所有遥感影像构成的数据集中的样本分为若干个种类,从每一个种类中均随机抽取K个样本作为支持集,其余样本作为查询集,并将支持集中的样本输入待训练嵌入网络模型中进行度量,得到支持集中第i个样本的分类结果wi,用公式表示为:
Figure BDA0003763087360000081
式(2)中,f()表示待训练嵌入网络模型,S表示支持集,yi表示支持集中的第i个样本,zi表示样本yi的对应标签,其中,所述查询集用于测试小样本分类嵌入网络模型的分类性能;
S23、将查询集中的样本输入待训练嵌入网络模型得到查询集中样本的分类结果,然后计算查询集中样本的分类结果与支持集中样本的分类结果之间的差异,进而判断出查询集中样本的种类,其中,查询集中样本的分类结果与支持集中样本的分类结果之间的差异最小的种类即为查询集中样本的种类,计算查询集中每一个样本的分类结果与计算支持集中第i个样本的分类结果相同,此处不再赘述;
S24、利用随机梯度下降法不断降低查询集中样本的分类结果与支持集中样本的分类结果之间的差异,以使元学习训练后的待训练嵌入网络模型获得利用已有先验知识对新样本进行分类的能力,进而得到小样本分类嵌入网络模型,其中,所述查询集中样本的分类结果与支持集中样本的分类结果之间的差异用公式表示为:
Lmeta=logΣnexp(-D(f(q),wi)) (3)
式(3)中,q表示查询集中样本,wi表示待训练嵌入网络模型对支持集中第i个样本的分类结果,
Figure BDA0003763087360000091
表示查询集中样本的分类结果与支持集中样本的分类结果之间的差异。
本实施例中,由于支持集和查询集中存在多种类的样本,对于多种类的样本综合考虑时,因此多种样本分类结果之间的差异需要指数求和再对数进行计算。其中,小样本分类嵌入网络模型的获取具体方式为:将自监督学习训练得到的嵌入网络模型,进而使得嵌入网络模型得到从少量无标签的数据中提取先验知识的能力,再将该嵌入网络模型应用到元学习的框架中进行训练,有效加深了网络学习进度,得到了利用已知先验知识对新样本进行分类的小样本分类嵌入网络模型。
需要说明的是,在本实施例中,对于查询集中样本的分类结果与支持集中样本的分类结果之间的差异,由于待分类样本的种类不止一种,故多种待分类样本综合考虑时需要使用公式(3)中的指数求和再对数进行计算,相比直接将查询集中样本的分类结果与支持集中样本的分类结果之间的差异相加,这样处理可以避免单个样本对整体影响过大,造成过拟合。
其中,所述步骤S3的具体实现方式为:首先计算自蒸馏训练过程中的损失函数,即计算自蒸馏训练中第k代预测值和第k代实际值之间的差异以及第k代预测值与第k-1代预测值之间的差异;然后利用诱导网络降低两种差异值,不断迭代直至两种差异值之和减小到最小值并在最小值附近趋于稳定,从而得到训练好的SMSR模型。本实施例中,利用自蒸馏对小样本分类嵌入网络模型进行优化,使得模型的结构更加精简,其模型的精度大大得到提升。
其中,所述自蒸馏训练第k次迭代训练中损失函数取最小值时的模型参数化用公式表示为:
Figure BDA0003763087360000092
式(4)中,
Figure BDA0003763087360000093
表示自蒸馏训练第k次迭代得到的参数化模型,KL表示KullbackLeibler散度,α表示超参数,Q表示自蒸馏训练的数据集,Lce为交叉熵损失,
Figure BDA0003763087360000101
表示参数化模型,
Figure BDA0003763087360000102
表示参数化模型
Figure BDA0003763087360000103
对数据集Q的分类结果。
为了进一步理解本发明的工作原理和技术效果,下面进行了一系列相关实验予以说明。
1、选取数据集
如图2所示,数据集中的遥感影像数据来源于MAritime SATellite Imagerydataset遥感影像数据集(记为MASATI),该数据集于2016年3月至2019年6月间从欧洲、非洲、亚洲、地中海以及大西洋和太平洋的不同区域汇编而来,提供可见光谱的光学遥感影像场景。数据集中含有陆地、海岸、海洋、船舶等七类样本,每张的分辨率为512x512,总计约7389张影像。
2、深度神经网络模型训练与结果
如图3所示,利用自监督学习对预设的深度神经网络进行训练,自监督学习网络对输入的遥感影像进行特征提取,生成对应的特征图进行对比学习,训练结束后得到的自监督网络即为嵌入网络模型。
表1为自监督学习卷积网络的结构参数,其中,每个卷积块由多个卷积层组成,从Conv3x到Conv7x中的每个卷积块都包含16层卷积(Conv表示卷积层),使用1x1卷积来减小通道数,降低模型复杂度,每个卷积块内使用相同的输出通道数,前向运算时将依次将每块的输入和输出在通道中相连接,构成自监督学习网络。
表1自监督学习网络结构参数
Figure BDA0003763087360000104
Figure BDA0003763087360000111
基于自监督的元学习训练时,选取的遥感影像小样本场景分类任务为1-shot 5-way。训练前将遥感影像图片预处理裁剪为256×256并输入网络中,优化器使用SGD(随机梯度下降法),learning rate(学习率)取0.0001,dropout(深度神经网络中的一个参数)取0.5,,batch size(批尺寸)取32,总共训练300个epoch(其中,一个epoch表示将所有样本训练一次的过程)。
对模型进行自蒸馏训练时,在同一数据集上使用基于自监督的元学习训练相得到的模型进行自蒸馏,超参数α取0.5,总共进行10个epoch的自蒸馏训练,经过自蒸馏得到的模型命名为SMSR模型,即基于自监督元学习与自蒸馏小样本遥感影像场景分类方法模型。
为了进行对比,本实验选取了几种经典的有监督小样本学习算法在1-shot 5-way(其中,5-Way就是5路或5类,1-Shot就是1次或1个)任务中进行训练,同时在数据集中进行测试,测试结果如表2所示:
表2 MASATI 1-shot 5-way准确率对比,置信度选取95%
算法名称 嵌入网络 1-shot5-way准确率
MAML Convnet 53.70±0.82%
RelationNet Convnet 54.21±0.14%
SNAIL ResNet50 56.25±0.29%
ProtoNet ResNet50 57.63±0.66%
Self-jig ResNet50 58.16±0.84%
SMSR 自监督学习网络 59.62±0.16%
由表2可知,本发明提出的SMSR模型分类方法在MASATI遥感影像数据集(深度学习中的遥感影像数据集)中进行的1-shot 5-way任务中,相比经典有监督小样本学习算法准确率有一定的优势,其中比Self-jig准确率提高了约1.5%。
3消融实验
为了验证本发明所提出的SMSR模型分类方法中所采用的自监督元学习和自蒸馏的有效性,还进一步做了一系列消融实验。
3.1自监督学习消融实验
为了验证自监督学习的有效性,使用有监督训练代替SMSR中的自监督训练,嵌入模型直接使用ResNet50(一种神经网络)代替自监督学习网络进行元学习,其他部分保持和SMSR模型相同,并将该算法命名为MSR模型,在和SMSR模型相同的训练环境下对MSR模型进行训练并进行测试,实验结果如表3所示。
表3 MASATI 1-shot 5-way准确率对比,置信度选取95%
Figure BDA0003763087360000121
Figure BDA0003763087360000131
由表3可知,有监督学习在仅使用少量有标签数据训练嵌入网络时,会因为样本太少发生过拟合,进一步导致元学习模型泛化能力不足,在1-shot 5-way任务中的性能不如使用自监督学习的SMSR模型,因此,通过自监督学习可以得到性能良好的嵌入网络模型,提高元学习模型的性能。
3.2自蒸馏消融实验
为了验证自蒸馏的有效性,去除SMSR模型中的自蒸馏部分,其他结构与SMSR模型保持一致,在与SMSR模型相同的训练环境下进行训练和测试,实验结果如表4所示。
表4 MASATI 1-shot 5-way准确率对比,置信度选取95%
算法名称 基础网络 1-shot5-way准确率
SMSR- 自监督学习网络 62.83±0.91%
SMSR 自监督学习网络 64.54±0.16%
其中,SMSR-模型代表不进行自蒸馏的SMSR模型,由表4可以看出,经过自蒸馏训练的SMSR模型比SMSR-准确率高了约1.7%,由此可见,自蒸馏训练能够使得模型更加精炼。
以上对本发明所提供的一种面向小样本遥感影像场景分类的自监督自蒸馏元学习方法进行了详细介绍。本文中应用了具体个例对本发明的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本发明的核心思想。应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以对本发明进行若干改进和修饰,这些改进和修饰也落入本发明权利要求的保护范围内。

Claims (10)

1.面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述方法包括以下步骤:
S1、基于自监督学习并利用遥感图像数据对预设的深度神经网络进行训练,进而得到能够提取遥感影像特征信息的嵌入网络模型;
S2、将步骤S1中所获取的嵌入网络模型嵌入到元学习的框架中进行训练,并使元学习训练后的嵌入网络模型能够利用已有先验知识对新样本进行分类,进而得到小样本分类嵌入网络模型;
S3、基于自蒸馏训练对小样本分类嵌入网络模型进行迭代优化,进而完成面向小样本遥感影像分类的自监督自蒸馏元学习,得到训练好的SMSR模型。
2.如权利要求1所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述步骤S1具体实现方式为:
S11、将对比学习设置为自监督学习的辅助任务,并选取NCE损失为对比学习的衡量指标;
S12、利用不同数据增强方法对同一张遥感图像进行数据增强,并计算在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失;
S13、利用随机梯度下降法不断减少在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失,直至该NCE损失的值趋于稳定,进而得到能够提取遥感影像特征信息的嵌入网络模型。
3.如权利要求2所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述步骤S12中不同数据增强方法包括缩放、随机裁剪、随机中心裁剪、仿射变换、颜色抖动、随机水平翻转和随机灰度变换。
4.如权利要求3所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述对比学习通过选取来源于预设的深度神经网络中不同卷积层所生成的特征图进行对比。
5.如权利要求4所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述步骤S12中在不同数据增强方法下该遥感图像的数据增强处理结果之间的NCE损失用公式表示为:
L(xm,xn)=L(fa(xm),f5(xn))+L(fa(xm),fe(xn))+L(f5(xm),f5(xn))+L(fa(xn),f5(xm))+L(fa(xn),fe(xm))+L(fe(xm),fe(xn)) (1)
式(1)中,xm和xn分别表示同一张遥感影像经过两种不同数据增强方法进行数据增强的处理结果,fa表示预设的深度神经网络提取的全局特征图,f5表示预设的深度神经网络中第五层卷积提取的局部特征图,fe表示预设的深度神经中最后一层卷积提取的网络最后一层输出特征图,其中,
Figure FDA0003763087350000021
Figure FDA0003763087350000022
Figure FDA0003763087350000023
Figure FDA0003763087350000024
Figure FDA0003763087350000025
Figure FDA0003763087350000026
N表示所有遥感影像,d表示欧几里得距离的平方。
6.如权利要求5所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述步骤S2的具体实现方式为:
S21、将步骤S1所获取的嵌入网络模型嵌入到元学习的框架中并作为待训练嵌入网络模型;
S22、将由所有遥感影像构成的数据集中的样本分为若干个种类,从每一个种类中均随机抽取K个样本作为支持集,其余样本作为查询集,并将支持集中的样本输入待训练嵌入网络模型中进行度量,得到支持集中第i个样本的分类结果wi,用公式表示为:
Figure FDA0003763087350000031
式(2)中,f()表示待训练嵌入网络模型,S表示支持集,yi表示支持集中的第i个样本,zi表示样本yi的对应标签;
S23、将查询集中的样本输入待训练嵌入网络模型得到查询集中样本的分类结果,然后计算查询集中样本的分类结果与支持集中样本的分类结果之间的差异,进而判断出查询集中样本的种类;
S24、利用随机梯度下降法不断降低查询集中样本的分类结果与支持集中样本的分类结果之间的差异,以使元学习训练后的待训练嵌入网络模型获得利用已有先验知识对新样本进行分类的能力,进而得到小样本分类嵌入网络模型。
7.如权利要求6所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述查询集中样本的分类结果与支持集中样本的分类结果之间的差异用公式表示为:
Lmeta=log∑nexp(-D(f(q),wi)) (3)
式(3)中,q表示查询集中样本,wi表示待训练嵌入网络模型对支持集中第i个样本的分类结果,
Figure FDA0003763087350000032
表示查询集中样本与支持集中样本的分类结果之间的差异。
8.如权利要求7所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述步骤S3的具体实现方式为:首先计算自蒸馏训练过程中的损失函数,即计算自蒸馏训练中第k代预测值和第k代实际值之间的差异以及第k代预测值与第k-1代预测值之间的差异;然后利用诱导网络降低两种差异值,不断迭代直至两种差异值之和减小到最小值并在最小值附近趋于稳定,从而得到训练好的SMSR模型。
9.如权利要求8所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,所述自蒸馏训练第k次迭代训练中损失函数取最小值时的模型参数化用公式表示为:
Figure FDA0003763087350000041
式(4)中,
Figure FDA0003763087350000042
表示自蒸馏训练第k次迭代得到的参数化模型,KL表示为KullbackLeibler散度,α表示超参数,Q表示自蒸馏训练的数据集,Lce为交叉熵损失,
Figure FDA0003763087350000043
表示参数化模型,
Figure FDA0003763087350000044
表示参数化模型
Figure FDA0003763087350000045
对数据集Q的分类结果。
10.如权利要求1-9任一项所述的面向小样本遥感影像场景分类的自监督自蒸馏元学习方法,其特征在于,还包括:
S4、将实时获取的遥感影像输入训练好的SMSR模型中进行分类,进而实现遥感影像的场景分类。
CN202210878368.2A 2022-07-25 2022-07-25 面向小样本遥感影像场景分类的自监督自蒸馏元学习方法 Pending CN115100543A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210878368.2A CN115100543A (zh) 2022-07-25 2022-07-25 面向小样本遥感影像场景分类的自监督自蒸馏元学习方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210878368.2A CN115100543A (zh) 2022-07-25 2022-07-25 面向小样本遥感影像场景分类的自监督自蒸馏元学习方法

Publications (1)

Publication Number Publication Date
CN115100543A true CN115100543A (zh) 2022-09-23

Family

ID=83298644

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210878368.2A Pending CN115100543A (zh) 2022-07-25 2022-07-25 面向小样本遥感影像场景分类的自监督自蒸馏元学习方法

Country Status (1)

Country Link
CN (1) CN115100543A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116071609A (zh) * 2023-03-29 2023-05-05 中国科学技术大学 基于目标特征动态自适应提取的小样本图像分类方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116071609A (zh) * 2023-03-29 2023-05-05 中国科学技术大学 基于目标特征动态自适应提取的小样本图像分类方法
CN116071609B (zh) * 2023-03-29 2023-07-18 中国科学技术大学 基于目标特征动态自适应提取的小样本图像分类方法

Similar Documents

Publication Publication Date Title
CN106909924B (zh) 一种基于深度显著性的遥感影像快速检索方法
CN114037844B (zh) 基于滤波器特征图的全局秩感知神经网络模型压缩方法
CN110781924B (zh) 一种基于全卷积神经网络的侧扫声纳图像特征提取方法
CN110222767B (zh) 基于嵌套神经网络和栅格地图的三维点云分类方法
CN110334645B (zh) 一种基于深度学习的月球撞击坑识别方法
CN111753995B (zh) 一种基于梯度提升树的局部可解释方法
CN114266988A (zh) 基于对比学习的无监督视觉目标跟踪方法及系统
CN111178438A (zh) 一种基于ResNet101的天气类型识别方法
Hassam et al. A single stream modified mobilenet v2 and whale controlled entropy based optimization framework for citrus fruit diseases recognition
CN113269224A (zh) 一种场景图像分类方法、系统及存储介质
CN114898775B (zh) 一种基于跨层交叉融合的语音情绪识别方法及系统
CN115965862A (zh) 基于掩码网络融合图像特征的sar舰船目标检测方法
CN112597919A (zh) 基于YOLOv3剪枝网络和嵌入式开发板的实时药盒检测方法
CN114550014B (zh) 道路分割方法及计算机装置
CN115100543A (zh) 面向小样本遥感影像场景分类的自监督自蒸馏元学习方法
CN116883650A (zh) 一种基于注意力和局部拼接的图像级弱监督语义分割方法
CN111815526A (zh) 基于图像滤波和cnn的有雨图像雨条纹去除方法及系统
CN115147727A (zh) 一种遥感影像不透水面提取方法及系统
CN116912685B (zh) 一种水体识别方法、系统及电子设备
CN117409316A (zh) 一种基于TransUNet的地震数据喀斯特特征智能识别定位方法
CN115994242A (zh) 影像检索方法、装置、设备及存储介质
CN114299091A (zh) 一种基于DA-Net的杂草自动分割方法
CN113128362A (zh) 一种基于yolov3的无人机视角下小目标快速检测方法
CN116524358B (zh) 用于目标识别的sar数据集扩增方法
CN117593755B (zh) 一种基于骨架模型预训练的金文图像识别方法和系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination