CN114943879B - 基于域适应半监督学习的sar目标识别方法 - Google Patents
基于域适应半监督学习的sar目标识别方法 Download PDFInfo
- Publication number
- CN114943879B CN114943879B CN202210860624.5A CN202210860624A CN114943879B CN 114943879 B CN114943879 B CN 114943879B CN 202210860624 A CN202210860624 A CN 202210860624A CN 114943879 B CN114943879 B CN 114943879B
- Authority
- CN
- China
- Prior art keywords
- sample
- loss
- samples
- domain
- enhanced
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
- G06V10/7753—Incorporation of unlabelled data, e.g. multiple instance learning [MIL]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/10—Terrestrial scenes
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Data Mining & Analysis (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Engineering & Computer Science (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biophysics (AREA)
- Mathematical Physics (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Medical Informatics (AREA)
- Biomedical Technology (AREA)
- Probability & Statistics with Applications (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
本发明提供一种基于域适应半监督学习的SAR目标识别方法,涉及SAR目标识别技术领域,用以解决无标记样本对初始模型的继续优化的作用较为有限、模型的准确度不高的技术问题。该方法将无标记样本做强弱两种方式的增强,将更具有多样性、识别难度更大的强增强样本用于模型的训练,并将强增强样本所对应弱增强样本的伪标签作为其伪标签以保证伪标签的正确性,通过对强弱增强样本的使用使得模型能够得到更加有效的训练;通过构建域适应损失减少标记和无标记样本间的域差异,有效地减少了错误伪标签的数量;通过在无标记样本的分类损失中加入Top‑k损失降低错误伪标签对模型训练的影响。由此,本发明可有效增强SAR目标识别的准确度。
Description
技术领域
本发明涉及合成孔径雷达(Synthetic Aperture Radar,SAR)目标识别技术领域,尤其涉及一种基于域适应半监督学习的SAR目标识别方法。
背景技术
合成孔径雷达是一种主动式对地观测系统,能够实现全天时、全天候的观测。因此,SAR在海洋监测、地物勘探等方面具有独一无二的优势。SAR图像目标识别旨在识别SAR图像中的目标类别,是SAR应用中的一项重要任务。传统的SAR图像目标识别方法主要通过手动设计并提取目标的几何、微波散射等特征,再结合基于机器学习的分类器进行识别。随着深度学习的发展,通过构建神经网络自动学习层级特征成为了SAR图像目标识别领域的主流方法,但是这些方法依赖于大量的标记数据。在实际应用中,SAR图像的获取和标注需要花费大量的人力和物力,高昂的成本限制了深度学习方法在该领域中的应用。
针对SAR图像标记样本少这一问题的主要解决方案包括基于半监督学习的方法、基于迁移学习的方法和基于元学习的方法,这些方法采用不同的方式减少对标记样本量的需求。基于半监督学习的方法通过生成无标记样本的伪标签,使无标记样本也能够用于模型训练;基于迁移学习的方法将由其他任务中学习到的知识迁移到目标任务中,从而减少模型训练对样本量的需求;基于元学习的方法构建多个元任务,通过优化所有元任务的方式获得全局模型,使其在只需要少量目标任务数据的情况下便可以快速适应目标任务。
本发明属于基于半监督学习的方法,这类方法的关键在于如何有效地使用无标记样本。现有文献首先利用标记样本训练初始模型,再利用初始模型计算无标记样本的伪标签,最后使用无标记样本和其对应的伪标签对模型进行优化。但是,在实现本发明构思的过程中,发明人发现现有文献中至少存在以下三方面问题:
1)现有方法利用初始模型预测置信度高的无标记样本和其对应伪标签来优化模型,初始模型对筛选出的无标记样本具有较高的预测置信度,说明初始模型对正确计算出这些样本的类别已经有较大的把握,因此这些样本对初始模型的继续优化的作用较为有限;
2)现有方法没有考虑标记样本和无标记样本间的域差异,由于无标记样本的伪标签是由标记样本训练的初始模型得到的,因而当标记样本和无标记样本间存在域差异时,容易为无标记样本生成错误的伪标签,因为后续要利用无标记样本及其伪标签对初始模型进行优化,使用带有错误伪标签的无标记样本进行训练会严重影响模型的优化;
3)现有方法没有考虑错误的伪标签对模型训练的影响,当采用带有错误伪标签的无标记样本训练模型时,会使模型向着错误的方向更新,从而降低模型的准确度。
发明内容
有鉴于此,本发明提供了一种基于域适应半监督学习的SAR目标识别方法,至少部分解决无标记样本对初始模型的继续优化的作用较为有限、模型的准确度不高的技术问题。
本发明提供的基于域适应半监督学习的SAR目标识别方法,包括:获取标记样本集和无标记样本集,对标记样本集中的每个标记样本进行弱增强,生成第一弱增强样本,对无标记样本集中的每个无标记样本分别进行弱增强和强增强,生成第二弱增强样本和强增强样本;将第一弱增强样本、第二弱增强样本和强增强样本分别输入卷积神经网络模型中的特征提取器,获得对应的特征图,计算标记样本集和无标记样本集间的域适应损失;将特征图展开为向量后输入卷积神经网络模型中的特征分类器,分别计算第一弱增强样本、第二弱增强样本和强增强样本的预测概率;根据第一弱增强样本的预测概率,计算每个标记样本的分类损失;根据第二弱增强样本的预测概率筛选部分无标记样本,计算部分无标记样本的分类损失;根据域适应损失、每个标记样本的分类损失和部分无标记样本的分类损失,计算标记样本集和无标记样本集的总损失;使用梯度下降算法优化总损失,更新卷积神经网络模型的参数;加载训练好的卷积神经网络模型,输入待测样本,输出待测样本的预测类别。
进一步地,弱增强依次包括翻转和裁剪变换,强增强包括以下中的任意两种:对比度变换、亮度变换、颜色变换、图像旋转、图像锐化、横向剪切、纵向剪切、横向平移、纵向平移、随机剪切。
进一步地,特征提取器使用ResNet18网络,输入样本的尺寸为128×128,特征图的尺寸为512×1×1;特征分类器包含全连接层和softmax层,其中,全连接层的输入为第一弱增强样本、第二弱增强样本和强增强样本展开后的512维向量,输出为10维向量,全连接层的输出经过softmax层后得到10维的预测概率向量。
进一步地,标记样本集和无标记样本集间的域适应损失根据以下公式计算得出:
式中,L da 为域适应损失;N x 为标记样本集中的标记样本总个数;N u 为无标记样本集
中的无标记样本总个数;分别为第i个和第j个第一弱增强样本的特征图;分
别为第i个和第j个第二弱增强样本的特征图;k(·)表示高斯核函数。
进一步地,每个标记样本的分类损失根据以下公式计算得出:
式中,L ce (p,y)为交叉熵函数;c为预设的目标类别总数;p=[p 1 ,…,p c ]T为预测概
率;y=[y 1 ,…,y c ]T为类别标签;[·]T表示对向量的转置操作;为第i个标记样本x i 的分类
损失;为第i个第一弱增强样本的预测概率;为第i个第一弱增强样本的类别标签。
进一步地,根据第二弱增强样本的预测概率筛选部分无标记样本,包括:判断每个第二弱增强样本的预测概率中的最大元素是否不小于预设概率阈值,如果是,则保留该第二弱增强样本,否则,移除该第二弱增强样本。
进一步地,计算部分无标记样本的分类损失,包括:将部分无标记样本划分为部分弱增强样本和部分强增强样本;使用部分弱增强样本的预测概率,计算部分强增强样本的伪标签;根据部分强增强样本的预测概率和部分强增强样本的伪标签,计算得出部分无标记样本的分类损失。
进一步地,部分无标记样本的分类损失由交叉熵损失和Top-k损失的加权求和得到,其中:交叉熵损失是使用交叉熵函数计算出的部分强增强样本的预测概率与部分强增强样本的伪标签之间的差异;Top-k损失是使用Top-k损失函数计算出的部分强增强样本的预测概率与部分强增强样本的伪标签之间的差异。
进一步地,Top-k损失根据以下公式计算得出:
式中,表示Top-k损失;表示类别空间;表示由中元素构
成的k元组集合;表示k元组集合中包含元素y的子集合;为预测概率p中最大k个元
素对应位置所构成的k元组;为中间系数,其计算方法为:当时,=0,否
则,,α为间隔参数;τ表示温控参数。
进一步地,标记样本集和无标记样本集的总损失由域适应损失、每个标记样本的分类损失和部分无标记样本的分类损失求和得出。
与现有技术相比,本发明提供的基于域适应半监督学习的SAR目标识别方法,至少具有以下有益效果:
(1)针对现有方法筛选出的无标记样本对提升模型性能作用有限的问题,本方法将无标记样本做强弱两种方式的增强,基于同一样本的弱增强和强增强样本应该具有相同类别标签的一致性准则,在初始模型更容易正确预测其类别的弱增强样本上计算伪标签并进行样本的筛选,将更具有多样性、识别难度更大的强增强样本用于模型的训练,使得模型能够得到更加有效的训练;
(2)针对现有方法没有考虑标记和无标记样本间的域差异导致错误伪标签多的问题,本方法通过构建域适应损失减少标记和无标记样本间的域差异,有效地减少了错误伪标签的数量;
(3)针对现有方法没有考虑错误伪标签对模型训练产生影响的问题,本方法通过在无标记样本的分类损失中加入Top-k损失降低错误伪标签对模型训练的影响。
附图说明
通过以下参照附图对本发明实施例的描述,本发明的上述以及其他目的、特征和优点将更为清楚,在附图中:
图1示意性示出了根据本发明实施例的基于域适应半监督学习的SAR目标识别方法的操作流程图;
图2示意性示出了根据本发明实施例的基于域适应半监督学习的SAR目标识别方法的流程图;
图3示意性示出了根据本发明实施例的无标记样本筛选过程的流程图;
图4示意性示出了根据本发明实施例的无标记样本的分类损失计算过程的流程图;
图5示意性示出了根据本发明实施例的卷积神经网络模型训练过程中的迭代次数与总损失曲线。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚明白,以下结合具体实施例,并参照附图,对本发明进一步详细说明。显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
在此使用的术语仅仅是为了描述具体实施例,而并非意在限制本发明。在此使用的术语“包括”、“包含”等表明了所述特征、步骤、操作和/或部件的存在,但是并不排除存在或添加一个或多个其他特征、步骤、操作或部件。
在此使用的所有术语(包括技术和科学术语)具有本领域技术人员通常所理解的含义,除非另外定义。应注意,这里使用的术语应解释为具有与本说明书的上下文相一致的含义,而不应以理想化或过于刻板的方式来解释。
图1示意性示出了根据本发明实施例的基于域适应半监督学习的SAR目标识别方法的操作流程图。
如图1所示,本发明实施例提供的基于域适应半监督学习的SAR目标识别方法,主要包括以下关键步骤:输入标记和无标记样本集,生成弱增强与强增强样本,计算域适应损失,计算样本预测概率,计算标记样本分类损失,筛选无标记样本,计算无标记样本分类损失,计算总损失,优化总损失更新模型参数以及模型加载和测试。
图2示意性示出了根据本发明实施例的基于域适应半监督学习的SAR目标识别方法的流程图。
结合图2,对图1所示的方法进行详细描述。如图2所示,该实施例的基于域适应半监督学习的SAR目标识别方法,可以包括操作S110~操作S180。
在操作S110,获取标记样本集和无标记样本集,对标记样本集中的每个标记样本进行弱增强,生成第一弱增强样本,对无标记样本集中的每个无标记样本分别进行弱增强和强增强,生成第二弱增强样本和强增强样本。
具体地,获取标记样本集X和无标记样本集U。对于每一次迭代过程,获取的标记样本集X包含N x 个标记样本x i ,i=1,…, N x 。获取的无标记样本集U包含N u 个无标记样本u j ,j=1,…, N u 。
本实施例中,弱增强依次包括翻转和裁剪变换,强增强包括以下中的任意两种:对比度变换、亮度变换、颜色变换、图像旋转、图像锐化、横向剪切、纵向剪切、横向平移、纵向平移、随机剪切。
具体来说,在弱增强过程中依次对每个样本(包含标记样本和无标记样本)进行翻转和裁剪2种变换得到弱增强样本。强增强包含10种更为复杂的图像变换方式,具体为:对比度变换、亮度变换、颜色变换、图像旋转、图像锐化、横向剪切、纵向剪切、横向平移、纵向平移、随机剪切,在强增强过程中先在10种变换方式中随机选取2种,再对每个无标记样本依次进行这2种变换得到强增强样本。
在操作S120,将第一弱增强样本、第二弱增强样本和强增强样本分别输入卷积神经网络模型中的特征提取器,获得对应的特征图,计算标记样本集和无标记样本集间的域适应损失。
本实施例中,特征提取器例如可以使用ResNet18网络,输入样本的尺寸为128×128,特征图的尺寸为512×1×1。
接着,根据这三种特征图,计算标记样本集和无标记样本集间的域适应损失。本实施例中,标记样本集和无标记样本集间的域适应损失根据以下公式计算得出:
式中,L da 为域适应损失;N x 为标记样本集中的标记样本总个数;N u 为无标记样本集
中的无标记样本总个数;分别为第i个和第j个第一弱增强样本的特征图;分
别为第i个和第j个第二弱增强样本的特征图;k(·)表示高斯核函数。
可见,该域适应损失基于最大平均差异来计算。通过优化域适应损失L da ,可以减小标记样本集和无标记样本集间的域差异。
在操作S130,将特征图展开为向量后输入卷积神经网络模型中的特征分类器,分别计算第一弱增强样本、第二弱增强样本和强增强样本的预测概率。
本实施例中,特征分类器可以包含全连接层和softmax层。其中,全连接层的输入为第一弱增强样本、第二弱增强样本和强增强样本展开后的512维向量,输出为10维向量,全连接层的输出经过softmax层后得到10维的预测概率向量。在第1次迭代时,全连接层的参数随机生成,在第t次(t>1)迭代时,采用第t-1次迭代输出的模型参数。
在操作S140,根据第一弱增强样本的预测概率,计算每个标记样本的分类损失。
本实施例中,使用交叉熵函数计算每个标记样本的分类损失,交叉熵函数的计算方法如下所示:
式中,L ce (p,y)为交叉熵函数;c为预设的目标类别总数;p=[p 1 ,…,p c ]T为预测概率;y=[y 1 ,…,y c ]T为类别标签;[·]T表示对向量的转置操作。
在此基础上,每个标记样本的分类损失根据以下公式计算得出:
在操作S150,根据第二弱增强样本的预测概率筛选部分无标记样本,计算部分无标记样本的分类损失。
图3示意性示出了根据本发明实施例的无标记样本筛选过程的流程图。
如图3所示,本实施例中,上述操作S150中的根据第二弱增强样本的预测概率筛选部分无标记样本,可以进一步包括操作S1501。
在操作S1501,判断每个第二弱增强样本的预测概率中的最大元素是否不小于预设概率阈值,如果是,则保留该第二弱增强样本,否则,移除该第二弱增强样本。
接着,进行部分无标记样本的分类损失的计算。
图4示意性示出了根据本发明实施例的部分无标记样本的分类损失计算过程的流程图。
如图4所示,本实施例中,上述操作S150中的计算部分无标记样本的分类损失,可以进一步包括操作S1502~操作S1504。
在操作S1502,将部分无标记样本划分为部分弱增强样本和部分强增强样本。
将筛选后的无标记样本,也即部分无标记样本记为,为筛选后无标
记样本的数目,也即部分无标记样本的数目。在该部分无标记样本中,分别记为划分
出的部分弱增强样本和部分强增强样本,分别记为部分弱增强样本和部分强增强
样本的预测概率。
在操作S1503,使用部分弱增强样本的预测概率,计算部分强增强样本的伪标签。
在操作S1504,根据部分强增强样本的预测概率和部分强增强样本的伪标签,计算得出部分无标记样本的分类损失。
本实施例中,部分无标记样本的分类损失由交叉熵损失和Top-k损失的加权求和得到。
其中,交叉熵损失是使用交叉熵函数计算出的部分强增强样本的预测概率与部分强增强样本的伪标签之间的差异。
Top-k损失是使用Top-k损失函数计算出的部分强增强样本的预测概率与部分强增强样本的伪标签之间的差异。可以理解的是,将类别标签y的类别标号(即y中非零元素的位置)记为y,当y属于预测概率p中最大的k个元素对应位置构成的集合时,Top-k损失L Top-k 均能输出较小的损失值,因此能够减小错误伪标签对模型训练的影响。
具体地,Top-k损失根据以下公式计算得出:
式中,表示Top-k损失;表示类别空间;表示由中元素构
成的k元组集合;表示k元组集合中包含元素y的子集合;为预测概率p中最大k个元
素对应位置所构成的k元组;为中间系数,其计算方法为:当时,=0,否
则,,α为间隔参数;τ表示温控参数。
接着,在操作S160,根据域适应损失、每个标记样本的分类损失和部分无标记样本的分类损失,计算标记样本集和无标记样本集的总损失。
本实施例中,标记样本集和无标记样本集的总损失由域适应损失、每个标记样本的分类损失和部分无标记样本的分类损失求和得出。
具体地,标记样本集和无标记样本集的总损失L total ,包含标记样本集和无标记样
本集间的域适应损失L da 、N x 个标记样本的分类损失的和、个部分无标记样本
的分类损失的和,计算方法如下所示:
接着,在操作S170,使用梯度下降算法优化总损失,更新卷积神经网络模型的参数。
使用梯度下降算法优化总损失L total ,根据总损失,重复上述操作S110到操作S170,对卷积神经网络模型的参数进行迭代优化,直至当前迭代次数达到预先设定的总迭代次数。
在操作S180,加载训练好的卷积神经网络模型,输入待测样本,输出待测样本的预测类别。
该预测类别即为最终的识别结果。
通过本发明的实施例,将无标记样本做强弱两种方式的增强,将更具有多样性、识别难度更大的强增强样本用于模型的训练,并将强增强样本所对应弱增强样本的伪标签作为其伪标签以保证伪标签的正确性,通过对强弱增强样本的使用使得模型能够得到更加有效的训练。并且,通过构建域适应损失减少标记和无标记样本间的域差异,有效地减少了错误伪标签的数量,以及,通过在无标记样本的分类损失中加入Top-k损失降低错误伪标签对模型训练的影响。综合上述三方面因素,有效增强SAR目标识别的准确度。
以上只是示例性说明,本发明的实施例不限于此。例如,在一些实施例中,上述操作S120中的特征提取器可以采用其他神经网络结构,如VGG、Inception等网络。
又例如,在一些实施例中,上述操作S120中的域适应损失的计算还可以采用KL散度(Kullback-Leibler Divergence)、JS散度(Jensen–Shannon Divergence)、Wasserstein距离等。
下面通过实际数据的处理示例来验证本发明上述的实施例的方法的处理效果。实验采用MSTAR数据集的10分类任务,在该任务中包含10个类别的SAR车辆目标。训练集包含俯仰角为17度的共2747个样本,测试集包含俯仰角为15度的共2425个样本。
步骤1,从训练集中的每种目标类别中随机选取15个作为标记样本,构成标记样本集X,剩余训练集样本构成无标记样本集U。对于每一次迭代过程,从X和U中分别读取64个标记样本x i ,i=1,…,64和64个无标记样本u j ,j=1,…,64。即N x =64,N u =64。
特征提取器使用ResNet18网络,输入图像的尺寸为128×128,特征图的尺寸为512×1×1。计算域适应损失L da 。
TSNE是由T分布和随机近邻嵌入(Stochastic Neighbor Embedding,SNE)组成,是一种可视化工具,将高位数据降到2-3维,然后画成图。TSNE是目前效果最好的数据降维和可视化方法。
为了便于辨别,发明人利用TSNE可视化特征图对域差异损失的效果进行验证,通过对比加入域适应损失前的结果与加入域适应损失后的结果,可以得到在加入域适应损失前,标记样本和无标记样本间存在较大的域差异,在加入域适应损失后,域差异明显减小,表明本发明实施例的方法加入域适应损失的有效性。
步骤3,计算第一弱增强样本、第二弱增强样本和强增强样本的预测概率。
步骤5,筛选部分无标记样本。保留预测概率中的最大元素大于或者等于预设
概率阈值T的无标记样本,T取0.8。将筛选后的无标记样本记为,为筛选后无
标记样本的数目。分别记为划分出的部分弱增强样本和部分强增强样本,
为对应样本的预测概率。
步骤6,计算部分无标记样本的分类损失。计算部分无标记样本的交叉熵损失和Top-k损失L Top-k 。在计算Top-k损失时,相关参数设置如下:k=3;间隔参数α=1;
温控参数τ=0.5。计算部分无标记样本的分类损失,其中,权重参数λ设置为0.2。
步骤7,计算标记样本集和无标记样本集的总损失L total 。
步骤8,使用Adam梯度下降算法优化总损失L total ,更新模型参数,学习率设置为0.0001。重复步骤1到8,直至达到预先设定的总迭代次数N t ,N t 设置为6000。
图5示意性示出了根据本发明实施例的卷积神经网络模型训练过程中的迭代次数与总损失曲线。横轴表示迭代次数,纵轴表示损失值,由图5中可知随着迭代次数的增加损失值持续下降,表明模型进行了有效的训练。
步骤9,加载训练好的卷积神经网络模型,输入测试样本,输出预测类别,得到最终的识别结果。
最终计算得知,本发明实施例的方法的识别率为89.32%,现有方案的识别率为60.54%,识别结果说明了本发明实施例的方法的先进性。
从以上的描述中,可以看出,本发明上述的实施例提供的基于域适应半监督学习的SAR目标识别方法,至少实现了以下技术效果:
1)现有方法使用初始模型预测置信度高的无标记样本和其对应伪标签来优化模型,本方法将无标记样本做强弱两种方式的增强,以初始模型在弱增强样本上的预测置信度为筛选无标记样本的依据,根据同一样本的弱增强样本和强增强样本应该具有相同类别标签的一致性准则,将筛选后无标记样本的强增强样本和其对应弱增强样本的伪标签对模型进行优化;
2)现有方法没有考虑标记样本和无标记样本间的域差异,本方法通过构建标记样本和无标记样本之间的域适应损失减小域差异;
3)现有方法没有考虑错误的伪标签对模型训练的影响,本方法在计算无标记样本的分类损失过程中加入了Top-k损失,能够减轻错误的伪标签对模型训练的影响。
附图中示出了一些方框图和/或流程图。应理解,方框图和/或流程图中的一些方框或其组合可以由计算机程序指令来实现。这些计算机程序指令可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器,从而这些指令在由该处理器执行时可以创建用于实现这些方框图和/或流程图中所说明的功能/操作的装置。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。因此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本发明的描述中,“多个”的含义是至少两个,例如两个、三个等,除非另有明确具体的限定。此外,位于元件之前的单词“一”或“一个”不排除存在多个这样的元件。
以上所述的具体实施例,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施例而已,并不用于限制本发明,凡在本发明的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于域适应半监督学习的SAR目标识别方法,其特征在于,包括:
获取标记样本集和无标记样本集,对所述标记样本集中的每个标记样本进行弱增强,生成第一弱增强样本,对所述无标记样本集中的每个无标记样本分别进行弱增强和强增强,生成第二弱增强样本和强增强样本;
将所述第一弱增强样本、第二弱增强样本和强增强样本分别输入卷积神经网络模型中的特征提取器,获得对应的特征图,计算所述标记样本集和无标记样本集间的域适应损失;
将所述特征图展开为向量后输入所述卷积神经网络模型中的特征分类器,分别计算所述第一弱增强样本、第二弱增强样本和强增强样本的预测概率;
根据所述第一弱增强样本的预测概率,计算每个所述标记样本的分类损失;
根据所述第二弱增强样本的预测概率筛选部分无标记样本,计算所述部分无标记样本的分类损失;
根据所述域适应损失、每个所述标记样本的分类损失和所述部分无标记样本的分类损失,计算所述标记样本集和无标记样本集的总损失;
使用梯度下降算法优化所述总损失,更新所述卷积神经网络模型的参数;
加载训练好的卷积神经网络模型,输入待测样本,输出所述待测样本的预测类别。
2.根据权利要求1所述的基于域适应半监督学习的SAR目标识别方法,其特征在于,所述弱增强依次包括翻转和裁剪变换,所述强增强包括以下中的任意两种:
对比度变换、亮度变换、颜色变换、图像旋转、图像锐化、横向剪切、纵向剪切、横向平移、纵向平移、随机剪切。
3.根据权利要求1所述的基于域适应半监督学习的SAR目标识别方法,其特征在于,所述特征提取器使用ResNet18网络,输入样本的尺寸为128×128,特征图的尺寸为512×1×1;
所述特征分类器包含全连接层和softmax层,其中,所述全连接层的输入为所述第一弱增强样本、第二弱增强样本和强增强样本展开后的512维向量,输出为10维向量,所述全连接层的输出经过所述softmax层后得到10维的预测概率向量。
6.根据权利要求1所述的基于域适应半监督学习的SAR目标识别方法,其特征在于,根据所述第二弱增强样本的预测概率筛选部分无标记样本,包括:
判断每个所述第二弱增强样本的预测概率中的最大元素是否不小于预设概率阈值,如果是,则保留该第二弱增强样本,否则,移除该第二弱增强样本。
7.根据权利要求1所述的基于域适应半监督学习的SAR目标识别方法,其特征在于,计算所述部分无标记样本的分类损失,包括:
将所述部分无标记样本划分为部分弱增强样本和部分强增强样本;
使用所述部分弱增强样本的预测概率,计算所述部分强增强样本的伪标签;
根据所述部分强增强样本的预测概率和所述部分强增强样本的伪标签,计算得出所述部分无标记样本的分类损失。
8.根据权利要求7所述的基于域适应半监督学习的SAR目标识别方法,其特征在于,所述部分无标记样本的分类损失由交叉熵损失和Top-k损失的加权求和得到,其中:
所述交叉熵损失是使用交叉熵函数计算出的所述部分强增强样本的预测概率与所述部分强增强样本的伪标签之间的差异;
所述Top-k损失是使用Top-k损失函数计算出的所述部分强增强样本的预测概率与所述部分强增强样本的伪标签之间的差异。
10.根据权利要求1所述的基于域适应半监督学习的SAR目标识别方法,其特征在于,所述标记样本集和无标记样本集的总损失由所述域适应损失、每个所述标记样本的分类损失和所述部分无标记样本的分类损失求和得出。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210860624.5A CN114943879B (zh) | 2022-07-22 | 2022-07-22 | 基于域适应半监督学习的sar目标识别方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210860624.5A CN114943879B (zh) | 2022-07-22 | 2022-07-22 | 基于域适应半监督学习的sar目标识别方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114943879A CN114943879A (zh) | 2022-08-26 |
CN114943879B true CN114943879B (zh) | 2022-10-04 |
Family
ID=82910617
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210860624.5A Active CN114943879B (zh) | 2022-07-22 | 2022-07-22 | 基于域适应半监督学习的sar目标识别方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114943879B (zh) |
Families Citing this family (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115482418B (zh) * | 2022-10-09 | 2024-06-07 | 北京呈创科技股份有限公司 | 基于伪负标签的半监督模型训练方法、系统及应用 |
CN117253097B (zh) * | 2023-11-20 | 2024-02-23 | 中国科学技术大学 | 半监督域适应图像分类方法、系统、设备及存储介质 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111881983A (zh) * | 2020-07-30 | 2020-11-03 | 平安科技(深圳)有限公司 | 基于分类模型的数据处理方法、装置、电子设备及介质 |
CN112395987A (zh) * | 2020-11-18 | 2021-02-23 | 西安电子科技大学 | 基于无监督域适应cnn的sar图像目标检测方法 |
CN114332568A (zh) * | 2022-03-16 | 2022-04-12 | 中国科学技术大学 | 域适应图像分类网络的训练方法、系统、设备及存储介质 |
CN114492574A (zh) * | 2021-12-22 | 2022-05-13 | 中国矿业大学 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20220230065A1 (en) * | 2019-05-06 | 2022-07-21 | Google Llc | Semi-supervised training of machine learning models using label guessing |
-
2022
- 2022-07-22 CN CN202210860624.5A patent/CN114943879B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111881983A (zh) * | 2020-07-30 | 2020-11-03 | 平安科技(深圳)有限公司 | 基于分类模型的数据处理方法、装置、电子设备及介质 |
CN112395987A (zh) * | 2020-11-18 | 2021-02-23 | 西安电子科技大学 | 基于无监督域适应cnn的sar图像目标检测方法 |
CN114492574A (zh) * | 2021-12-22 | 2022-05-13 | 中国矿业大学 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
CN114332568A (zh) * | 2022-03-16 | 2022-04-12 | 中国科学技术大学 | 域适应图像分类网络的训练方法、系统、设备及存储介质 |
Non-Patent Citations (1)
Title |
---|
基于主动学习的半监督领域自适应方法研究;姚明海等;《高技术通讯》;20200815(第08期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN114943879A (zh) | 2022-08-26 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114943879B (zh) | 基于域适应半监督学习的sar目标识别方法 | |
CN111160311B (zh) | 基于多注意力机制双流融合网络的黄河冰凌语义分割方法 | |
CN111369572B (zh) | 一种基于图像修复技术的弱监督语义分割方法和装置 | |
CN113486981B (zh) | 基于多尺度特征注意力融合网络的rgb图像分类方法 | |
Jeon et al. | Partially supervised classification using weighted unsupervised clustering | |
CN108399420B (zh) | 一种基于深度卷积网络的可见光舰船虚警剔除方法 | |
CN108805157B (zh) | 基于部分随机监督离散式哈希的遥感图像分类方法 | |
CN105069796B (zh) | 基于小波散射网络的sar图像分割方法 | |
CN113095417A (zh) | 基于融合图卷积和卷积神经网络的sar目标识别方法 | |
CN104820841A (zh) | 基于低阶互信息和光谱上下文波段选择的高光谱分类方法 | |
CN104517120A (zh) | 基于多路分层正交匹配的遥感图像场景分类方法 | |
CN108596204B (zh) | 一种基于改进型scdae的半监督调制方式分类模型的方法 | |
CN113468939A (zh) | 一种基于监督最小化深度学习模型的sar目标识别方法 | |
CN113066528B (zh) | 基于主动半监督图神经网络的蛋白质分类方法 | |
CN117475236B (zh) | 用于矿产资源勘探的数据处理系统及其方法 | |
CN115705393A (zh) | 一种基于持续学习的雷达辐射源分级识别方法 | |
CN111209813B (zh) | 基于迁移学习的遥感图像语义分割方法 | |
CN111832463A (zh) | 一种基于深度学习的交通标志检测方法 | |
CN115730656A (zh) | 一种利用混合未标记数据的分布外样本检测方法 | |
CN111210433A (zh) | 一种基于各向异性势函数的马氏场遥感图像分割方法 | |
CN115661539A (zh) | 一种嵌入不确定性信息的少样本图像识别方法 | |
CN114169462A (zh) | 一种特征引导的深度子领域自适应隐写检测方法 | |
CN111860547B (zh) | 基于稀疏表示的图像分割方法、装置、设备及存储介质 | |
CN113671493B (zh) | 一种基于特征融合的海面小目标检测方法及系统 | |
CN113688950B (zh) | 用于图像分类的多目标特征选择方法、装置和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |