CN115482461A - 基于自监督学习和最近邻网络的小样本sar目标分类方法 - Google Patents
基于自监督学习和最近邻网络的小样本sar目标分类方法 Download PDFInfo
- Publication number
- CN115482461A CN115482461A CN202211032994.6A CN202211032994A CN115482461A CN 115482461 A CN115482461 A CN 115482461A CN 202211032994 A CN202211032994 A CN 202211032994A CN 115482461 A CN115482461 A CN 115482461A
- Authority
- CN
- China
- Prior art keywords
- training
- feature extraction
- network model
- self
- training stage
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/10—Terrestrial scenes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/80—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
- G06V10/806—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Multimedia (AREA)
- General Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Software Systems (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种基于自监督学习和最近邻网络的小样本SAR目标分类方法,包括:从若干合成孔径雷达图像中获取训练任务集和测试任务集;构建预训练阶段网络模型;利用训练任务集对预训练阶段网络模型进行迭代训练;构建自监督训练阶段网络模型;将训练好的预训练阶段网络模型的参数加载到自监督训练阶段网络模型中,并利用训练任务集对加载后的自监督训练阶段网络模型进行迭代训练;将测试任务集输入训练好的第三特征提取模块进行特征提取,得到测试特征向量组集合;计算测试特征向量组集合中测试查询样本对应的测试特征向量和每一个测试支撑样本对应的测试特征向量的相似度,得到目标分类结果。该分类方法有效地提高了小样本目标的分类准确率。
Description
技术领域
本发明属于雷达图像处理技术领域,具体涉及一种基于自监督学习和最近邻网络的小样本SAR目标分类方法。
背景技术
合成孔径雷达(Synthetic Aperture Radar,SAR)是一种主动式的对地观测系统,可安装在飞机、卫星、宇宙飞船等飞行平台上,并具有全天时、全天候对地实施观测的优势,对地表有一定的穿透能力。因此,SAR系统在灾害监测、环境监测、海洋监测、资源勘查、测绘和军事等领域的应用上具有独特的优势,因此越来越受到广泛的关注。
SAR目标分类是一种根据不同类别的目标各自在图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法。当前主流的SAR目标分类方法可分为基于传统的分类方法和基于深度学习的分类方法。基于传统的SAR目标分类方法是基于色彩、纹理、形状、空间关系等图像特征对图像目标进行分类,通常采用人工选取特征、设计分类器,然而这往往需要依据大量的经验和较强的专业知识来针对特定目标设定特定算法,耗时长且难以推广,同时无法获得较好的分类精度。近年来随着深度学习的发展,涌现出了许多VGG、ResNet、DenseNet等一系列强大的卷积神经网络,可对SAR目标自动提取更加稳健的特征,从而取得较好的SAR目标分类结果,得到了研究人员的广泛应用。
基于深度学习的SAR目标分类方法通常需要大量的训练样本来训练模型以在测试样本上获得高的分类准确率,但是,现实中能获取的SAR图像数量较少。对于小样本SAR目标,基于深度学习的SAR目标分类方法会由于训练样本的不足出现过拟合现象,进而导致SAR目标分类准确率低。
为了解决这个问题,现有技术通过改进模型结构,设计对样本数量要求较低的特殊模型来提高小样本SAR目标的分类准确率。例如现有技术公开了一种基于混合推理网络的小样本SAR自动目标识别方法,该方法提出的基于混合推理网络的小样本学习方法克服了传统小样本学习中单独采用归纳推理或转导推理的不足,提高了识别性能;采用增强混合损失来约束嵌入网络进行学习,将样本映射到该嵌入空间,在该空间中,归纳推理和转导推理都能很好地执行;采用混合推断网络,在目标分类问题中只需要少量训练样本就能获得和传统SAR目标识别方法接近的识别正确率,在SAR图像样本有限的情况下,有效地提高了识别精度。但是,由于标记样本有限,该方法无法高度依赖嵌入网络获得更好的特征表示;模型收敛较慢,模型泛化能力较差;采用的欧几里得距离度量方式存在一定的偏差,在样本较少时,并不是很高效的度量方法。因此,该方法对SAR目标分类的精度依然较低。
发明内容
为了解决现有技术中存在的上述问题,本发明提供了一种基于自监督学习和最近邻网络的小样本SAR目标分类方法。本发明要解决的技术问题通过以下技术方案实现:
本发明实施例提供了一种基于自监督学习和最近邻网络的小样本SAR目标分类方法,包括步骤:
S1、从若干合成孔径雷达图像中获取训练任务集和测试任务集,其中,所述训练任务集中每个训练任务包括训练支撑样本和训练查询样本,所述测试任务集中每个测试任务包测试支撑样本集和测试查询样本;
S2、构建预训练阶段网络模型,其中,所述预训练阶段网络模型包括依次级联的第一特征提取模块、多层感知机模块和相似度计算模块;
S3、利用所述训练任务集对预训练阶段网络模型进行迭代训练,并利用交叉熵损失函数对所述预训练阶段网络模型进行更新,得到训练好的预训练阶段网络模型;
S4、构建自监督训练阶段网络模型,其中,所述自监督训练阶段网络模型包括并联的第二特征提取模块和第三特征提取模块,所述第一特征提取模块、所述第二特征提取模块和所述第三特征提取模块的结构相同;
S5、将所述训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述自监督训练阶段网络模型中,并利用所述训练任务集对加载后的自监督训练阶段网络模型进行迭代训练,在训练过程中,利用负余弦相似度损失函数和自监督损失函数对所述第三特征提取模块进行更新,得到训练好的第三特征提取模块;
S6、将所述测试任务集输入到所述训练好的第三特征提取模块进行特征提取,得到测试特征向量组集合;
S7、计算所述测试特征向量组集合中所述测试查询样本对应的测试特征向量和每一个所述测试支撑样本对应的测试特征向量的相似度,得到目标分类结果。
在本发明的一个实施例中,所述第一特征提取模块包括依次级联的第一特征提取子模块和第二特征提取子模块,其中,
所述第一特征提取子模块包括依次级联的第一卷积层、第一批量归一化层、第一激活函数层和最大池化层;
所述第二特征提取子模块包括依次级联的第二卷积层、第二批量归一化层和第二激活函数层;
所述多层感知机模块包括全连接层。
在本发明的一个实施例中,所述第一卷积层和所述第二卷积层中卷积核的个数均为64,卷积核的大小均为3×3,步长和填充均为1;所述第一激活函数层和所述第二激活函数层均采用Leaky Relu,其参数均为0.2;所述最大池化层的池化核大小为2×2,步长为2。
在本发明的一个实施例中,步骤S3包括:
S31、初始化设置预训练阶段迭代次数、预训练阶段最大迭代次数和第n次迭代的预训练阶段网络模型;
S32、利用所述第一特征提取模块对所述每个训练任务进行特征提取,得到第一特征向量组集合;
S33、利用所述多层感知机模块将所述特征向量组集合中的每个特征映射到样本标记空间,得到第二特征向量组集合;
S34、利用所述相似度计算模块计算所述第二特征向量组集合中所述训练查询样本对应的特征向量与每个所述训练支撑样本对应的特征向量的相似度,得到相似度得分集;
S35、采用交叉熵损失函数计算每次迭代过程中所述预训练阶段网络模型的第一损失值,并采用梯度下降法将所述第一损失值对第一权值参数的偏导在所述预训练阶段网络模型中进行反向传播,以对所述第一权值参数进行更新,得到更新后的第一权值参数;
S36、判断迭代是否完成,若是,则得到所述训练好的预训练阶段网络模型;若否,则继续进行迭代训练。
在本发明的一个实施例中,利用所述相似度计算模块计算相似度的公式为:
其中,表示与之间的相似度得分,表示支撑样本依次经过第一特征提取模块D和多层感知机模块E得到的特征向量,表示查询样本经过特征提取模块D和多层感知机模块E得到的特征向量,m表示支撑特征向量被划分的局部描述子的个数,z表示查询特征向量被划分的局部描述子的个数,xi表示支撑特征向量中第i个局部描述子,表示查询特征向量中第j个局部描述子。
在本发明的一个实施例中,所述第一损失值为:
所述更新后的第一权值参数为:
在本发明的一个实施例中,所述第二特征提取模块包括依次级联的第三卷积层、第三批量归一化层、第三激活函数层和第二最大池化层;
所述第三特征提取模块包括依次级联的第四卷积层、第四批量归一化层、第四激活函数层和第三最大池化层。
在本发明的一个实施例中,步骤S5包括:
S51、初始化设置自监督训练阶段迭代次数、自监督训练阶段最大迭代次数和第n次迭代的自监督训练阶段网络模型;
S52、将所述训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述第二特征提取模块中,利用加载后的所述第二特征提取模块对所述每个训练任务进行特征提取,得到第三特征向量组集合;将训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述第三特征提取模块中,利用加载后的所述第二特征提取模块对所述每个训练任务进行特征提取,得到第四特征向量组集合;
S53、采用负余弦相似度损失函数和自监督损失函数之和作为训练损失函数来计算每次迭代过程中所述自监督训练阶段网络模型的第二损失值,然后再利用所述第二损失值对第二权值参数的偏导在所述自监督训练阶段网络模型中进行反向传播以对所述第二权值参数进行更新,得到更新后的第二权值参数;
S54、判断迭代是否完成,若是,则得到所述训练好的自监督训练阶段网络模型;若否,则继续进行迭代训练。
在本发明的一个实施例中,所述训练损失函数为:
L=Lcos+Lself-supervised
其中,Lcos为负余弦相似度损失函数,Lself-supervised为自监督损失函数;
所述负余弦相似度损失函数为:
其中,D表示第二特征提取模块,G表示第三特征提取模块,||·||2表示l2正则化;
所述自监督损失函数为:
Lself-supervised(x)=U(P1,Z2)+U(P2,Z1)
其中,P1表示X1经过一个h模块得到的特征向量组,P2表示X2经过一个h模块得到的特征向量组,X1和X2分别表示训练任务集经过数据增强后得到的两种效果图,Z1表示X1经过一个f模块得到的特征向量组,Z2表示X2经过一个f模块得到的特征向量组X2。
与现有技术相比,本发明的有益效果:
本发明的分类方法引入了预训练阶段网络模型,通过将训练好的预训练阶段网络模型的参数加载到自监督训练阶段网络模型中,可以加快自监督训练阶段网络模型的收敛并提高模型的泛化能力;在对自监督训练阶段网络模型进行训练时,引入了自监督损失函数和负余弦相似度损失函数可获得更好的特征表示,并且引入高效的最近邻度量方式以计算支撑样本和查询样本之间的相似度,从而,该方法避免了现有技术模型收敛速度慢、模型泛化能力不强、无法获得更好的特征表示和度量方式不高效导致对SAR目标分类精度的影响,有效地提高了小样本目标的分类准确率。
附图说明
图1为本发明实施例提供的一种基于自监督学习和最近邻网络的小样本SAR目标分类方法的流程示意图;
图2为本发明实施例提供的一种对预训练阶段网络模型进行迭代训练的实现流程图;
图3为本发明实施例提供的一种对自监督训练阶段网络模型进行迭代训练的实现流程图;
图4为本发明实施例提供的一种获取小样本SAR图像的目标分类结果的实现流程图。
具体实施方式
下面结合具体实施例对本发明做进一步详细的描述,但本发明的实施方式不限于此。
实施例一
请参见图1,图1为本发明实施例提供的一种基于自监督学习和最近邻网络的小样本SAR目标分类方法的流程示意图,该小样本SAR目标分类方法包括步骤:
S1、从若干合成孔径雷达图像中获取训练任务集和测试任务集;其中,所述训练任务集中每个训练任务包括训练支撑样本和训练查询样本,所述测试任务集中每个测试任务包测试支撑样本集和测试查询样本。具体包括步骤:
S11、获取U幅合成孔径雷达SAR图像,每幅合成孔径雷达SAR图像的大小为h×h,U幅合成孔径雷达SAR图像包含C个不同目标类别,每个目标类别对应M幅SAR图像,其中C≥10,M≥200,64≤h≤2560,U≥2000。
S12、对每幅SAR图像中的目标类别进行标记,并随机选取Ctrain个目标类别对应的Ctrain×M幅SAR图像及其标签组成训练数据集将其余的Ctest个目标类别对应的Ctest×M幅SAR图像及其标签组成测试数据集其中Ctrain+Ctest=C,∩表示交集,Ctrain>C/2,Ctrain×M+Ctest×M=U。
S13、从训练数据集中随机选取的包含Ctest个目标类别的z=Ctest×M幅SAR图像,对选取的每幅SAR图像的标签进行one-hot编码,得到对应的标签向量集合,并选取该标签向量集合中z1=Ctest×K个标签向量及其对应的SAR图像组成训练支撑样本集将其余z2=Ctest(M-K)个标签向量及其对应的SAR图像组成训练查询样本集然后将与每个训练查询样本进行组合,得到训练任务集其中,表示由SAR图像及其对应的标签向量组成的第a个训练支撑样本,表示由SAR图像及其对应的标签向量组成的第b个训练查询样本,1≤K≤10,表示第b个训练任务,
S14、从测试数据集中随机选取的包含Ctest个目标类别的z=Ctest×M幅SAR图像,对选取的每幅SAR图像的标签进行one-hot编码,得到对应的标签向量集合,并选取该标签向量集合中z1=Ctest×K个标签向量及其对应的SAR图像组成测试支撑样本集将其余z2=Ctest(M-K)个标签向量及其对应的SAR图像组成测试查询样本集然后将与每个测试查询样本进行组合,得到测试任务集其中,表示由SAR图像及其对应的标签向量组成的第e测试支撑样本,表示由SAR图像及其对应的标签向量组成的第g个测试查询样本,表示第g个测试任务,
S2、构建预训练阶段网络模型。
具体的,构建的预训练阶段网络模型包括依次级联的第一特征提取模块Da、多层感知机模块E和相似度计算模块R。
在一个具体实施例中,所述第一特征提取模块Da包括依次级联的第一特征提取子模块D1和第二特征提取子模块D2。其中,所述第一特征提取子模块D1包括依次级联的第一卷积层、第一批量归一化层、第一激活函数层和最大池化层。所述第二特征提取子模块D2包括依次级联的第二卷积层、第二批量归一化层和第二激活函数层。所述多层感知机模块E包括一个全连接层。
在一个具体实施例中,第一特征提取子模块D1和第二特征提取子模块D2中,所述第一卷积层和所述第二卷积层中卷积核的个数均为64,卷积核的大小均为3×3,步长和填充均为1;所述第一激活函数层和所述第二激活函数层均采用Leaky Relu,其参数均为0.2;所述最大池化层的池化核大小为2×2,步长为2。
S3、利用所述训练任务集对预训练阶段网络模型进行迭代训练,并利用交叉熵损失函数对所述预训练阶段网络模型进行更新,得到训练好的预训练阶段网络模型。
请参见图2,图2为本发明实施例提供的一种对预训练阶段网络模型进行迭代训练的实现流程图。步骤S3具体包括步骤:
S31、初始化设置预训练阶段迭代次数、预训练阶段最大迭代次数和第n次迭代的预训练阶段网络模型。
S32、利用所述第一特征提取模块对所述每个训练任务进行特征提取,得到第一特征向量组集合。
S33、利用所述多层感知机模块将所述特征向量组集合中的每个特征映射到样本标记空间,得到第二特征向量组集合。
S34、利用所述相似度计算模块计算所述第二特征向量组集合中所述训练查询样本对应的特征向量与所述训练支撑样本对应的每个特征向量的相似度,得到相似度得分集。
其中,表示与之间的相似度得分,表示支撑样本依次经过第一特征提取模块D和多层感知机模块E得到的特征向量,表示查询样本经过特征提取模块D和多层感知机模块E得到的特征向量,m表示支撑特征向量被划分的局部描述子的个数,z表示查询特征向量被划分的局部描述子的个数,xi表示支撑特征向量中第i个局部描述子,表示查询特征向量中第j个局部描述子。
S35、采用交叉熵损失函数计算每次迭代过程中所述预训练阶段网络模型的第一损失值,并采用梯度下降法将所述第一损失值对第一权值参数的偏导在所述预训练阶段网络模型中进行反向传播,以对所述第一权值参数进行更新,得到更新后的第一权值参数。
具体的,采用交叉熵损失函数,并通过每个预测标签ym1和其对应的真实标签计算每次迭代过程中所述预训练阶段网络模型的第一损失值Ls1,然后求取第一损失值Ls1对第一权值参数ωs1的偏导再采用梯度下降法,通过将在预训练阶段网络模型中进行反向传播的方式对第一权值参数ωs1进行更新,得到更新后的第一权值参数。
第一损失值Ls1的计算公式为:
更新后的第一权值参数为:
S36、判断迭代是否完成,若是,则得到所述训练好的预训练阶段网络模型;若否,则继续进行迭代训练。
具体的,判断n≥N是否成立,若是,得到训练好的预训练阶段网络模型H1',否则,令n=n+1,并执行步骤S32。
S4、构建自监督训练阶段网络模型。
具体的,构建的自监督训练阶段网络模型包括并联的第二特征提取模块Db和第三特征提取模块G,所述第一特征提取模块Da、所述第二特征提取模块Db和所述第三特征提取模块G的结构相同。即,第二特征提取模块Db和所述第三特征提取模块G均包括次级联的第一特征提取子模块D1和第二特征提取子模块D2。其中,所述第一特征提取子模块D1包括依次级联的第一卷积层、第一批量归一化层、第一激活函数层和最大池化层。所述第二特征提取子模块D2包括依次级联的第二卷积层、第二批量归一化层和第二激活函数层。所述多层感知机模块E包括一个全连接层。第一特征提取子模块D1和第二特征提取子模块D2中,所述第一卷积层和所述第二卷积层中卷积核的个数均为64,卷积核的大小均为3×3,步长和填充均为1;所述第一激活函数层和所述第二激活函数层均采用Leaky Relu,其参数均为0.2;所述最大池化层的池化核大小为2×2,步长为2。
S5、将所述训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述自监督训练阶段网络模型中,并利用所述训练任务集对加载后的自监督训练阶段网络模型进行迭代训练,在训练过程中,利用负余弦相似度损失函数和自监督损失函数对所述第三特征提取模块进行更新,得到训练好的第三特征提取模块。
请参见图3,图3为本发明实施例提供的一种对自监督训练阶段网络模型进行迭代训练的实现流程图,步骤S5具体包括步骤:
S51、初始化设置自监督训练阶段迭代次数、自监督训练阶段最大迭代次数和第n次迭代的自监督训练阶段网络模型。
S52、将所述训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述第二特征提取模块中,利用加载后的所述第二特征提取模块对所述每个训练任务进行特征提取,得到第三特征向量组集合;将训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述第三特征提取模块中,利用加载后的所述第二特征提取模块对所述每个训练任务进行特征提取,得到第四特征向量组集合。
具体的,将训练任务集作为自监督训练阶段网络模型的输入,第二特征提取模块Db加载训练好的预训练阶段网络模型H1'中第一特征提取模块Da的参数,然后对每个训练任务进行特征提取,得到特征向量组集合其中,第二特征提取模块Db在训练过程中不进行参数更新;第三特征提取模块G加载训练好的预训练阶段网络模型H1'中第一特征提取模块Da的参数,然后对每个训练任务进行特征提取,得到特征向量组集合第三特征提取模块G在训练过程中需要进行参数更新。
S53、采用负余弦相似度损失函数和自监督损失函数之和作为训练损失函数来计算每次迭代过程中所述自监督训练阶段网络模型的第二损失值,然后再利用所述第二损失值对第二权值参数的偏导在所述自监督训练阶段网络模型中进行反向传播以对所述第二权值参数进行更新,得到更新后的第二权值参数。
具体的,采用负余弦相似度损失函数Lcos和自监督损失函数Lself-supervised之和作为训练损失函数L来计算自监督训练阶段网络模型的第二损失值Ls2,然后求取第二损失值Ls2对第二权值参数ωs2的偏导再采用梯度下降法,通过将在自监督训练阶段网络模型中进行反向传播的方式以对第二权值参数ωs2进行更新。
具体的,训练损失函数为:
L=Lcos+Lself-supervised
其中,Lcos为负余弦相似度损失函数,Lself-supervised为自监督损失函数。
所述负余弦相似度损失函数为:
其中,D表示第二特征提取模块,G表示第三特征提取模块,||·||2表示l2正则化。
自监督损失函数的计算过程为:训练任务集进行复制后得到两个任务集,对这两个任务集分别通过随机裁剪、随机水平翻转和高斯滤波三种数据增强方式后得到两种不同的增强效果和X1和X2经过一个f模块后分别得到和X1和X2经过一个h模块后分别得到和其中,f模块包含顺次级联的特征提取模块G、特征映射模块O1;h模块仅包含一个特征映射模块O2;特征映射模块O1包含顺次级联的全连接层、归一化层、激活函数层、全连接层、归一化层、激活函数层、全连接层、归一化层;特征映射模块O2包含顺次级联的全连接层、归一化层、激活函数层、全连接层。从而,自监督损失函数Lself-supervised的计算公式如下:
Lself-supervised(x)=U(P1,Z2)+U(P2,Z1)
其中,P1表示X1经过一个h模块得到的特征向量组,P2表示X2经过一个h模块得到的特征向量组,X1和X2分别表示训练任务集经过数据增强后得到的两种效果图,Z1表示X1经过一个f模块得到的特征向量组,Z2表示X2经过一个f模块得到的特征向量组X2。
S54、判断迭代是否完成,若是,则得到所述训练好的自监督训练阶段网络模型;若否,则继续进行迭代训练。
具体的,判断n≥N是否成立,若是,得到训练好的自监督训练阶段网络模型H'2,否则,令n=n+1,并执行步骤S52。
S6、将所述测试任务集输入到所述训练好的第三特征提取模块进行特征提取,得到测试特征向量组集合。
请参见图4,图4为本发明实施例提供的一种获取小样本SAR图像的目标分类结果的实现流程图。
具体的,将测试任务集作为训练好的自监督训练阶段网络模型H'2的输入进行前向传播,得到所有测试任务集的预测标签。可以理解的是,利用训练好的第三特征提取模块G对测试任务集中的每个测试任务包含的每幅SAR图像进行特征提取,得到测试特征向量组集合
S7、计算所述测试特征向量组集合中所述测试查询样本对应的测试特征向量和每一个所述测试支撑样本对应的测试特征向量的相似度,得到目标分类结果。
具体的,将测试查询样本对应的测试特征向量与中每一个元素通过相似度计算模块R来计算相似度得分,通过此方法,最终得到测试预测结果向量集合其中,z2=Ctest(M-K),每个测试预测结果向量中最大值对应的维数号即为对应的测试查询样本包括的SAR图像中目标的预测类别,从而得到目标分类结果。
进一步的,本实施例结合仿真实验,对基于自监督学习和最近邻网络的小样本SAR目标分类方法的技术效果作进一步的说明:
1.仿真实验条件和内容:
仿真实验的硬件平台为:GPU为NVIDIA GeForce RTX 1650,软件平台为:操作系统为Ubuntu18.04。仿真实验的数据集为公开的AID数据集,其中,C=30,类别为BareLand、Beach、Desert、Meadow、Mountain、Parking、Port、RailwayStation、School、StorageTanks、Airport、BaseballField、Bridge、Center、Church、Commercial、DenseResidential、Farmland、Forest、Industrial、MediumResidential、Park、Playground、Pond、Resort、River、SparseResidential、Square、Stadium、Viaduct。每类目标的SAR图像为290幅,即M=290。
为了和现有的基于混合推理网络的小样本SAR自动目标识别方法对比小样本SAR目标分类准确率,从AID数据集中选取个20目标类别的总共5800幅SAR图像及每幅SAR图像的标签作为训练样本集,即Ctrain=20,类别分别为,Airport、BaseballField、Bridge、Center、Church、Commercial、DenseResidential、Farmland、Forest、Industrial、MediumResidential、Park、Playground、Pond、Resort、River、SparseResidential、Square、Stadium、Viaduct;选取剩余10个目标类别的总共2900幅SAR图像及每幅SAR图像的标签作为测试样本集,Ctest=10,类别分别为,BareLand、Beach、Desert、Meadow、Mountain、Parking、Port、RailwayStation、School、StorageTanks。同时,每个训练/测试任务中每个目标类别采样的训练/测试支撑样本数量K=10,训练/测试查询样本数量M-K=280。
对本实施例方法和现有的基于混合推理网络的小样本SAR自动目标识别方法(即对比算法)在5-way 1-shot、5-way 5-shot两种模式下的平均准确率进行仿真对比,其结果如表1所示:
表1
N-way K-shot | 5-way 1-shot | 5-way 5-shot |
对比算法 | 60.62%±0.35 | 68.85%±0.28 |
本发明 | 67.94%±0.29 | 75.36%±0.21 |
从表1中可以看出,本实施例方法与对比算法而言,在5-way 1-shot和5-way 5-shot两种模式下平均准确率分别提高了7.32%、6.51%。
综上,本实施例的分类方法引入了预训练阶段网络模型,通过将训练好的预训练阶段网络模型的参数加载到自监督训练阶段网络模型中,可以加快自监督训练阶段网络模型的收敛并提高模型的泛化能力;在对自监督训练阶段网络模型进行训练时,引入了自监督损失函数和负余弦相似度损失函数可获得更好的特征表示,减少特征表征偏差,并且引入高效的最近邻度量方式以计算支撑样本和查询样本之间的相似度,从而,该方法避免了现有技术模型收敛速度慢、模型泛化能力不强、无法获得更好的特征表示和度量方式不高效导致对SAR目标分类精度的影响,有效地提高了小样本目标的分类准确率。
以上内容是结合具体的优选实施方式对本发明所作的进一步详细说明,不能认定本发明的具体实施只局限于这些说明。对于本发明所属技术领域的普通技术人员来说,在不脱离本发明构思的前提下,还可以做出若干简单推演或替换,都应当视为属于本发明的保护范围。
Claims (9)
1.一种基于自监督学习和最近邻网络的小样本SAR目标分类方法,其特征在于,包括步骤:
S1、从若干合成孔径雷达图像中获取训练任务集和测试任务集,其中,所述训练任务集中每个训练任务包括训练支撑样本和训练查询样本,所述测试任务集中每个测试任务包测试支撑样本集和测试查询样本;
S2、构建预训练阶段网络模型,其中,所述预训练阶段网络模型包括依次级联的第一特征提取模块、多层感知机模块和相似度计算模块;
S3、利用所述训练任务集对预训练阶段网络模型进行迭代训练,并利用交叉熵损失函数对所述预训练阶段网络模型进行更新,得到训练好的预训练阶段网络模型;
S4、构建自监督训练阶段网络模型,其中,所述自监督训练阶段网络模型包括并联的第二特征提取模块和第三特征提取模块,所述第一特征提取模块、所述第二特征提取模块和所述第三特征提取模块的结构相同;
S5、将所述训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述自监督训练阶段网络模型中,并利用所述训练任务集对加载后的自监督训练阶段网络模型进行迭代训练,在训练过程中,利用负余弦相似度损失函数和自监督损失函数对所述第三特征提取模块进行更新,得到训练好的第三特征提取模块;
S6、将所述测试任务集输入到所述训练好的第三特征提取模块进行特征提取,得到测试特征向量组集合;
S7、计算所述测试特征向量组集合中所述测试查询样本对应的测试特征向量和每一个所述测试支撑样本对应的测试特征向量的相似度,得到目标分类结果。
2.根据权利要求1所述的基于自监督学习和最近邻网络的小样本SAR目标分类方法,其特征在于,所述第一特征提取模块包括依次级联的第一特征提取子模块和第二特征提取子模块,其中,
所述第一特征提取子模块包括依次级联的第一卷积层、第一批量归一化层、第一激活函数层和最大池化层;
所述第二特征提取子模块包括依次级联的第二卷积层、第二批量归一化层和第二激活函数层;
所述多层感知机模块包括全连接层。
3.根据权利要求2所述的基于自监督学习和最近邻网络的小样本SAR目标分类方法,其特征在于,所述第一卷积层和所述第二卷积层中卷积核的个数均为64,卷积核的大小均为3×3,步长和填充均为1;所述第一激活函数层和所述第二激活函数层均采用Leaky Relu,其参数均为0.2;所述最大池化层的池化核大小为2×2,步长为2。
4.根据权利要求2所述的基于自监督学习和最近邻网络的小样本SAR目标分类方法,其特征在于,步骤S3包括:
S31、初始化设置预训练阶段迭代次数、预训练阶段最大迭代次数和第n次迭代的预训练阶段网络模型;
S32、利用所述第一特征提取模块对所述每个训练任务进行特征提取,得到第一特征向量组集合;
S33、利用所述多层感知机模块将所述特征向量组集合中的每个特征映射到样本标记空间,得到第二特征向量组集合;
S34、利用所述相似度计算模块计算所述第二特征向量组集合中所述训练查询样本对应的特征向量与每个所述训练支撑样本对应的特征向量的相似度,得到相似度得分集;
S35、采用交叉熵损失函数计算每次迭代过程中所述预训练阶段网络模型的第一损失值,并采用梯度下降法将所述第一损失值对第一权值参数的偏导在所述预训练阶段网络模型中进行反向传播,以对所述第一权值参数进行更新,得到更新后的第一权值参数;
S36、判断迭代是否完成,若是,则得到所述训练好的预训练阶段网络模型;若否,则继续进行迭代训练。
7.根据权利要求2所述的基于自监督学习和最近邻网络的小样本SAR目标分类方法,其特征在于,
所述第二特征提取模块包括依次级联的第三卷积层、第三批量归一化层、第三激活函数层和第二最大池化层;
所述第三特征提取模块包括依次级联的第四卷积层、第四批量归一化层、第四激活函数层和第三最大池化层。
8.根据权利要求7所述的基于自监督学习和最近邻网络的小样本SAR目标分类方法,其特征在于,步骤S5包括:
S51、初始化设置自监督训练阶段迭代次数、自监督训练阶段最大迭代次数和第n次迭代的自监督训练阶段网络模型;
S52、将所述训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述第二特征提取模块中,利用加载后的所述第二特征提取模块对所述每个训练任务进行特征提取,得到第三特征向量组集合;将训练好的预训练阶段网络模型中所述第一特征提取模块的参数加载到所述第三特征提取模块中,利用加载后的所述第二特征提取模块对所述每个训练任务进行特征提取,得到第四特征向量组集合;
S53、采用负余弦相似度损失函数和自监督损失函数之和作为训练损失函数来计算每次迭代过程中所述自监督训练阶段网络模型的第二损失值,然后再利用所述第二损失值对第二权值参数的偏导在所述自监督训练阶段网络模型中进行反向传播以对所述第二权值参数进行更新,得到更新后的第二权值参数;
S54、判断迭代是否完成,若是,则得到所述训练好的自监督训练阶段网络模型;若否,则继续进行迭代训练。
9.根据权利要求8所述的基于自监督学习和最近邻网络的小样本SAR目标分类方法,其特征在于,所述训练损失函数为:
L=Lcos+Lself-supervised
其中,Lcos为负余弦相似度损失函数,Lself-supervised为自监督损失函数;
所述负余弦相似度损失函数为:
其中,D表示第二特征提取模块,G表示第三特征提取模块,||·||2表示l2正则化;
所述自监督损失函数为:
Lself-supervised(x)=U(P1,Z2)+U(P2,Z1)
其中,P1表示X1经过一个h模块得到的特征向量组,P2表示X2经过一个h模块得到的特征向量组,X1和X2分别表示训练任务集经过数据增强后得到的两种效果图,Z1表示X1经过一个f模块得到的特征向量组,Z2表示X2经过一个f模块得到的特征向量组X2。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211032994.6A CN115482461A (zh) | 2022-08-26 | 2022-08-26 | 基于自监督学习和最近邻网络的小样本sar目标分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211032994.6A CN115482461A (zh) | 2022-08-26 | 2022-08-26 | 基于自监督学习和最近邻网络的小样本sar目标分类方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115482461A true CN115482461A (zh) | 2022-12-16 |
Family
ID=84423031
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211032994.6A Pending CN115482461A (zh) | 2022-08-26 | 2022-08-26 | 基于自监督学习和最近邻网络的小样本sar目标分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115482461A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116597419A (zh) * | 2023-05-22 | 2023-08-15 | 宁波弗浪科技有限公司 | 一种基于参数化互近邻的车辆限高场景识别方法 |
-
2022
- 2022-08-26 CN CN202211032994.6A patent/CN115482461A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116597419A (zh) * | 2023-05-22 | 2023-08-15 | 宁波弗浪科技有限公司 | 一种基于参数化互近邻的车辆限高场景识别方法 |
CN116597419B (zh) * | 2023-05-22 | 2024-02-02 | 宁波弗浪科技有限公司 | 一种基于参数化互近邻的车辆限高场景识别方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110136154B (zh) | 基于全卷积网络与形态学处理的遥感图像语义分割方法 | |
CN110533631B (zh) | 基于金字塔池化孪生网络的sar图像变化检测方法 | |
CN108805200B (zh) | 基于深度孪生残差网络的光学遥感场景分类方法及装置 | |
CN112699966B (zh) | 基于深度迁移学习的雷达hrrp小样本目标识别预训练及微调方法 | |
CN113486981B (zh) | 基于多尺度特征注意力融合网络的rgb图像分类方法 | |
CN113095409B (zh) | 基于注意力机制和权值共享的高光谱图像分类方法 | |
CN113655479B (zh) | 基于可变形卷积和双注意力的小样本sar目标分类方法 | |
CN111916144B (zh) | 基于自注意力神经网络和粗化算法的蛋白质分类方法 | |
CN110263644B (zh) | 基于三胞胎网络的遥感图像分类方法、系统、设备及介质 | |
CN113095416B (zh) | 基于混合损失与图注意力的小样本sar目标分类方法 | |
CN112685504A (zh) | 一种面向生产过程的分布式迁移图学习方法 | |
CN110555461A (zh) | 基于多结构卷积神经网络特征融合的场景分类方法及系统 | |
CN111832580B (zh) | 结合少样本学习与目标属性特征的sar目标识别方法 | |
CN115482461A (zh) | 基于自监督学习和最近邻网络的小样本sar目标分类方法 | |
CN115082790A (zh) | 一种基于连续学习的遥感图像场景分类方法 | |
CN111222545B (zh) | 基于线性规划增量学习的图像分类方法 | |
CN113420593B (zh) | 基于混合推理网络的小样本sar自动目标识别方法 | |
CN109063750B (zh) | 基于cnn和svm决策融合的sar目标分类方法 | |
Prasetiyo et al. | Differential augmentation data for vehicle classification using convolutional neural network | |
CN110197213A (zh) | 基于神经网络的图像匹配方法、装置和设备 | |
CN112085001A (zh) | 一种基于多尺度边缘特征检测的隧道识别模型及方法 | |
CN116343016A (zh) | 一种基于轻量型卷积网络的多角度声呐图像目标分类方法 | |
CN116089886A (zh) | 信息处理方法、装置、设备及存储介质 | |
Liang et al. | Efficient recurrent attention network for remote sensing scene classification | |
CN113673629A (zh) | 基于多图卷积网络的开集域适应遥感图像小样本分类方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |