CN112528862A - 基于改进的交叉熵损失函数的遥感图像目标检测方法 - Google Patents
基于改进的交叉熵损失函数的遥感图像目标检测方法 Download PDFInfo
- Publication number
- CN112528862A CN112528862A CN202011462894.8A CN202011462894A CN112528862A CN 112528862 A CN112528862 A CN 112528862A CN 202011462894 A CN202011462894 A CN 202011462894A CN 112528862 A CN112528862 A CN 112528862A
- Authority
- CN
- China
- Prior art keywords
- remote sensing
- loss function
- sensing image
- network
- candidate
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/10—Terrestrial scenes
- G06V20/13—Satellite images
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/20—Scenes; Scene-specific elements in augmented reality scenes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Multimedia (AREA)
- Astronomy & Astrophysics (AREA)
- Remote Sensing (AREA)
- Image Analysis (AREA)
Abstract
本发明提出了一种基于改进的交叉熵损失函数的遥感图像目标检测方法,用于解决现有技术中存在的目标检测精度较低的技术问题,实现步骤为:1)获取训练样本集和测试样本集;2)构建基于改进的交叉熵损失函数的遥感图像目标检测模型;3)对基于改进的交叉熵损失函数的遥感图像目标检测模型进行迭代训练;4)获取遥感图像目标的检测结果。本发明通过调制因子控制分类准确率低的类别中的样本对损失函数的贡献程度,使训练更加关注这些样本,有效提升了部分分类准确率较低类别的检测精度,从而提升了整体的检测精度。
Description
技术领域
本发明属于图像处理技术领域,涉及一种遥感图像目标检测方法,具体涉及一种基于改进交叉熵损失函数的遥感图像目标检测方法,可以应用于地形勘探和视频监控等领域。
背景技术
近年来,计算机视觉飞速发展,在地形勘探中,需要目标检测识别出相应的地形及存在的飞机和建筑等,在视频监控中,目标检测可以跟踪关注的舰船,飞机的轨迹等。
目标检测方法是通过提取图像中的特征在图像中寻找目标物体进行分类并标明目标位置的过程。图像存储为像素点的矩阵,检测时通过提取出和目标物体相关的信息进行检测。目标检测任务的重点就是提升检测精度,检测精度分为每一个类别目标的平均检测精度AP和所有类别目标的平均检测精度均值mAP,平均精度AP和平均精度均值mAP越大,表示目标检测效果越好。其中,召回率=检测正确目标总数/目标总数,准确率=检测正确目标总数/检测目标总数。绘制准确率-召回率曲线,曲线与召回率所在的坐标轴包围面积表示每一类目标的平均检测精度AP,对所有类目标的平均检测精度求均值则得到平均检测精度均值mAP。检测精度受多方面因素的影响,如图像的像素高低,特征提取的效果优劣等。目标检测方法主要分为选定候选区域,提取特征和分类回归。
现有的目标识别检测算法主要分为传统的目标检测与识别算法与基于深度学习的目标检测与识别算法。传统的算法主要基于手动设计的特征,如边缘、纹理等特征,灵活性较差,且算法复杂度很高,重复性工作量大,不能有效利用图像的深层特征。而近期兴起的基于深度学习的方法能够有效提取图像的深层特征,充分利用图像的信息,大大提升了识别的准确率。深度学习方法又分为one-stage方法和two-stage方法。One-stage方法计算简单,但准确率相对较低。而two-stage方法虽然计算相对复杂,但是准确率有所提升。
遥感图像分为光学遥感图像和SAR图像,区别于自然图像,它的特征难以提取,且与自然图像的特征不通用,目标较小,某些类别间相似性很大,尤其是SAR图像为灰度图像,即单通道图像,在训练过程中与自然图像模型参数不同,无法通用。某些类别的目标的特征不够明显,导致特征提取效果不佳,因此,这些类别的目标检测精度相较于其他类别会低,从而影响整体的平均检测精度。因此,如何提升遥感图像中这些检测精度较低样本的检测精度是一个非常有意义的课题。
近年来,一些学者对two-stage方法中检测精度较低样本的精度提升做了一些改进。例如Rui Liu等2020年在第39届会议Chinese Control Conference(CCC)中发表的论文“An Improved Faster-RCNN Algorithm for Object Detection in Remote SensingImages”,公开了一种针对检测精度较低样本的遥感图像的目标检测方法,该方法设置阈值,手动将分类准确率较低的样本挑出来,重新输入网络进行训练,以达到提升这些样本分类准确率的目的。虽然有效,但增加了训练时间,并且对分类准确率低的样本的划分过于绝对,不能自动调节不同难分程度的样本对损失函数的贡献。
发明内容
本发明的目的在于针对上述现有技术的不足,提出了一种基于改进交叉熵损失函数的遥感图像目标检测方法,用于解决现有技术中存在的检测精度较低的技术问题。
本发明的技术思路是:获取训练样本和测试样本,搭建基于改进的交叉熵损失函数的Faster R-CNN的遥感图像目标检测模型,包括特征提取子网络、区域生成子网络、ROIAlign池化层、分类定位子网络,损失函数为改进的交叉熵损失函数FL,利用训练样本训练遥感图像目标检测网络模型,得到训练好的遥感图像目标检测网络模型之后,再将测试样本输入到训练好的遥感图像目标检测网络模型之后,获取遥感图像目标预测类别与目标的边界框,得到最终的检测结果。
根据上述技术思路,实现本发明目的采取技术方案包括如下步骤:
(1)获取训练样本集和测试样本集:
从遥感图像数据集中获取包含C种目标类别共N幅带标签的遥感图像,每幅遥感图像至少包含一个目标,并对每个目标类别进行独热编码,将真实的目标类别编码为1,其他的目标类别编码为0;并将N幅带标签的遥感图像中的m幅图像作为训练样本集,将其余带标签的遥感图像作为测试样本集,C≥2,N≥200,
(2)构建基于改进的交叉熵损失函数的遥感图像目标检测模型:
(2a)构建基于Faster R-CNN的遥感图像目标检测模型:
构建包括依次级联的特征提取子网络、区域生成子网络、ROI Align池化层和分类定位子网络,且特征提取子网络的末端与ROI Align池化层相连的Faster R-CNN的遥感图像目标检测模型;其中,特征提取子网络包括多个级联的特征提取模块,每个特征提取模块包含依次连接的多个卷积层-ReLU层和一个最大池化层;区域生成子网络包括依次连接的卷积层-ReLU层、并行连接的第一分类子网络和包含一个卷积层的第一定位模块、proposal层;第一分类子网络包括依次连接的卷积层、reshape层、softmax层、reshape层;分类定位子网络包括依次连接的全连接层、并行连接的第二分类子网络和包含一个全连接层的第二定位模块,第二分类子网络包括级联的全连接层和softmax层;
(2b)定义改进的交叉熵损失函数FL:
FL=FLcls1(pi)+Lreg1(d′1i,d1i)+FLcls2(pc)+Lreg2(d'2k,d2k)
FLcls1(pi)=-(1-pi)γlog(pi),i∈I
FLcls2(pc)=-(1-pc)γlog(pc),c∈C
其中,FLcls1(pi)表示第一分类子网络的改进的交叉熵损失函数,Lreg1(d′1i,d1i)表示第一定位模块的损失函数,FLcls2(pc)表示第二分类子网络的改进的交叉熵损失函数,Lreg2(d'2k,d2k)表示第二定位模块的损失函数;pi表示在FLcls1(pi)中由第一定位子网络生成并筛选的第i个候选框A1i中的内容为目标的概率,i∈I,I为A1i的个数;pc表示在FLcls2(pc)中由第二定位子网络中第k个候选框A2k中的目标为第c类的概率,c∈C,k∈K,K为非极大值抑制中设置的候选框个数,K≤I;(1-pi)γ和(1-pc)γ表示调制因子,γ表示指数参数;d′1i表示A1i的预测偏移量,d′1i=[d′1ix,d′1iy,d′1iw,d′1ih],d′1ix和d′1iy分别表示A1i的中心位置坐标在x轴和y轴的预测偏移量,d′1iw和d′1ih分别表示A1i在宽和高上的预测偏移量;A1i=[A1ix,A1iy,A1iw,A1ih],A1ix和A1iy分别表示A1i的中心位置坐标在x轴和y轴的值,A1iw和A1ih分别表示A1i的宽和高;d1i表示A1i的真实偏移量,d1i=[d1ix,d1iy,d1iw,d1ih],d1ix和d1iy分别表示A1i的中心位置坐标在x轴和y轴的真实偏移量,d1iw和d1ih分别表示A1i在宽和高上的真实偏移量;d'2k表示A2k的预测偏移量,d'2k=[d'2kx,d'2ky,d'2kw,d'2kh],d'2kx和d'2ky分别表示A2k的中心位置坐标在x轴和y轴的预测偏移量,d'2kw和d'2kh分别表示A2k在宽和高上的预测偏移量;A2k=[A2kx,A2ky,A2kw,A2kh]A2kx和A2ky分别表示A2k的中心位置坐标在x轴和y轴的值,A2kw和A2kh分别表示A2k的宽和高;d2k表示A2k的真实偏移量;d2k=[d2kx,d2ky,d2kw,d2kh],d2kx和d2ky分别表示A2k中心位置坐标在x轴和y轴的真实偏移量,d2kw和d2kh分别表示A2k在宽和高上的真实偏移量;smoothL1(x)为平滑损失函数;
(3)对基于改进的交叉熵损失函数的遥感图像目标检测模型进行迭代训练:
(3a)初始化基于Faster R-CNN的遥感图像目标检测模型的网络参数θ,迭代次数为t,最大迭代次数为T,T=20,并令t=0;
(3b)将训练样本集作为遥感图像目标检测模型的输入进行前向传播,特征提取子网络对每个训练样本进行特征提取,得到特征图集合f={f1,f2,...,fq,...,fm},其中fq表示第q个训练样本对应的大小为a×b的特征图,q∈m,每个特征图中的每个像素点对应训练样本中的一个区域;
(3c)区域生成子网络生成预测候选框坐标:
(3c1)以fq中每个像素点为中心,生成该像素点的9种初始候选框,得到fq的9×a×b个初始候选框,并计算每个初始候选框A0i与每个真实框Gr之间的交集与并集的比值IoU,Gr表示第r个真实框,r∈R,R表示真实框的个数,R≤I,再根据IoU对所有初始候选框进行筛选,将所筛选的多个初始候选框作为候选框A1i,筛选规则为:当IoU>0.7时,初始候选框包含有目标,标记为1;当IoU<0.3时,初始候选框不包含有目标,标记为0;舍弃0.3<IoU<0.7时的初始候选框;当真实框Gr没有与之对应的初始候选框的IoU>0.7时,将与真实框Gr的IoU最大的框也标记为1;
(3c2)区域生成子网络中的卷积层-ReLU层对每个筛选后的候选框A1i进行特征提取,得到特征图集合f'={f1',f′2,...,f′q,...,f′m};
(3c3)第一分类子网络根据特征图集合f'={f1',f′2,...,f′q,...,f′m}计算每个候选框A1i含有目标的概率pi;第一定位模块根据特征图集合f'={f1',f′2,...,f′q,...,f′m}计算候选框A1i的预测偏移量d′1j和预测候选框坐标[A′1ix,A′1iy,A′1iw,A′1ih];
(3c4)proposal层采用非极大值抑制方法,选取所有候选框中含有目标的概率得分pi排名前K个候选框,其中每个候选框坐标为[A2kx,A2ky,A2kw,A2kh],k∈K;
(3d)ROI Align池化层将前K个候选框的每个候选框的坐标值A2kx,A2ky,A2kw,A2kh分别缩小16倍,映射到特征图fq的对应位置得到K个感兴趣区域;将每个感兴趣区域划分为相同大小的7×7的子图;对每个子图进行最大池化操作,得到统一尺度的K个感兴趣区域;
(3e)分类定位子网络的全连接层将统一尺度的K个感兴趣区域提取特征,得到特征图集合f”={f1”,f″2,...,f″q,...,f″K};第二分类子网络通过f”={f1”,f″2,...,f″q,...,f″K}计算每个候选框内目标的预测类别pc,同时第二定位模块通过f”={f1”,f″2,...,f″q,...,f″K}计算候选框的精确坐标[Akx,Aky,Akw,Akh];
(3f)采用第一分类子网络的改进的交叉熵损失函数FLcls1(pi),通过pi计算自己的损失值FLcls1,采用第一定位模块的损失函数Lreg1(d′1i,d1i),通过d′1i和d1i计算自己的损失值Lreg1,采用第二分类子网络的改进的交叉熵损失函数FLcls2(pc),通过pc计算自己的损失值FLcls2,采用第二定位模块的损失函数Lreg2(d'2k,d2k),通过d'2k和d2k计算自己的损失值Lreg2,再采用反向传播方法,通过FLcls1、Lreg1、FLcls2和Lreg2计算遥感图像目标检测模型参数梯度,然后采用梯度下降算法,通过遥感图像目标检测模型参数梯度对网络参数θ进行更新;
(3g)判断t=T是否成立,若是,得到训练好的基于改进的交叉熵损失函数的遥感图像目标检测模型,否则,令t=t+1,并执行步骤(3b);
(4)获取遥感图像目标的检测结果:
将测试样本集作为训练好的基于改进的交叉熵损失函数的遥感图像目标检测模型的输入进行检测,得到每个目标的类别和边界框四个顶点坐标。
本发明与现有技术相比,具有以下优点:
本发明通过对Faster R-CNN中损失函数中的分类部分的交叉熵损失函数进行改进,改进后的交叉熵损失函数更加关注于遥感图像中分类准确率较低的类别样本,通过调制因子自动获取样本的难分程度,降低不同程度易分样本对于改进后的交叉熵损失函数的贡献,使得分类准确率较低的类别样本对改进后的交叉熵损失函数贡献更大,且分类准确率越低的样本对改进后的交叉熵损失函数的贡献越大,从而提升了分类准确率较低的类别样本的目标检测精度,进而提升了遥感图像目标检测模型的平均精度。
附图说明
图1是本发明的实现流程图;
图2是本发明遥感图像目标检测模型的结构示意图。
具体实施方式
以下结合附图和具体实施例,对本发明作进一步详细描述:
参照图1,本发明包括如下步骤:
(1)获取训练样本集和测试样本集:
从遥感图像数据集中获取包含C种目标类别共N幅带标签的遥感图像,每幅遥感图像至少包含一个目标,并对每个目标类别进行独热编码,将真实的目标类别编码为1,其他的目标类别编码为0;并将N幅带标签的遥感图像中的m幅图像作为训练样本集,将其余带标签的遥感图像作为测试样本集,C≥2,N≥200,
由于数据集过少,本发明合成制作了遥感图像目标检测数据集SAR_OD。首先将公开的遥感图像数据集MSTAR数据集中每张含有目标的大小128×128像素图片的目标提取出来,因为目标检测过程中,每个目标的阴影部分也会包含一定的信息,所以要将阴影的信息一并提取出来单独保存。根据MSTAR数据集中车辆目标和场景图像,合成制作了遥感图像目标检测数据集SAR_OD。
在本实施例中,C取9,代表MSTAR数据集中的8类目标以及背景,8类目标分别为装甲运输车(BTR70_SNC71、BTR60),步兵战车(BMP2_SN9563),坦克(T62、T72_SN132),装甲侦察车(BRDM2),自行榴弹炮(2S1),推土机(D7)。N为420,m为220。
由于数据集是用于目标检测与识别任务,需要对目标的位置和类别信息进行标注,同时为了使数据集能够具有更广泛的适用性,本发明将数据集制作成VOC2007格式。此数据集格式来源于世界级别的计算机视觉领域的挑战赛PASCAL VOC(The PASCAL VisualObject Classes)挑战赛,其数据格式也成为目标检测领域较为通用的数据格式。VOC2007共包含5个文件夹,其中SegmentationClass与SegmentationObject与图像分割相关,本发明中不做讨论。JPEGImages文件夹存放的是数据图片,即合成后的带有目标的遥感图像,本数据集图片为PNG格式。Annotations文件夹中存放的是.xml格式的文件,每个文件与合成后的JPEGImages文件夹中的图片对应,解释图片相关信息。其中filename元素为图片名称;size元素为图像尺寸,包括图像的长、宽以及通道数;segmented元素代表是否用于分割,本发明不用于分割,所以该元素值为0;每个object元素代表一个目标物体,包括目标类别以及位置信息,位置信息为每个目标左上角与右下角位置。ImageSets文件夹包含4个子文件夹存放的是.txt格式数据,代表每一种类型的任务对应的图像数据。Action文件夹为人的动作,Layout文件夹是具有人体部位的数据,Segmentation存放的是可用于分割的数据,本发明数据集用到的是Main文件夹,存放的是目标检测与识别相关的数据,文件夹中Trainval.txt与Train.txt存放的是训练集的索引,Val.txt存放验证集索引,Test.txt文件存放测试集索引,每个文件中的每行是一个图片名称(不含后缀)。
(2)构建基于改进的交叉熵损失函数的遥感图像目标检测模型:
(2a)构建基于Faster R-CNN的遥感图像目标检测模型,其结构如图2所示:
构建包括依次级联的特征提取子网络、区域生成子网络、ROI Align池化层和分类定位子网络,且特征提取子网络的末端与ROI Align池化层相连的Faster R-CNN的遥感图像目标检测模型;其中,特征提取子网络包括多个级联的特征提取模块,每个特征提取模块包含依次连接的多个卷积层-ReLU层和一个最大池化层;本发明中特征提取子网络包括5个特征提取模块;其中,模块1和模块2由2个卷积层-ReLU层和一个最大池化层组成;模块3,模块4,模块5由3个卷积层-ReLU层和一个最大池化层组成;区域生成子网络包括依次连接的卷积层-ReLU层、并行连接的第一分类子网络和包含一个卷积层的第一定位模块、proposal层;第一分类子网络包括依次连接的卷积层、reshape层、softmax层、reshape层;分类定位子网络包括依次连接的全连接层、并行连接的第二分类子网络和包含一个全连接层的第二定位模块,第二分类子网络包括级联的全连接层和softmax层;
(2b)定义改进的交叉熵损失函数FL:
FL=FLcls1(pi)+Lreg1(d′1i,d1i)+FLcls2(pc)+Lreg2(d'2k,d2k)
FLcls1(pi)=-(1-pi)γlog(pi),i∈I
FLcls2(pc)=-(1-pc)γlog(pc),c∈C
其中,FLcls1(pi)表示第一分类子网络的改进的交叉熵损失函数,Lreg1(d′1i,d1i)表示第一定位模块的损失函数,FLcls2(pc)表示第二分类子网络的改进的交叉熵损失函数,Lreg2(d'2k,d2k)表示第二定位模块的损失函数;pi表示在FLcls1(pi)中由第一定位子网络生成并筛选的第i个候选框A1i中的内容为目标的概率,i∈I,I为A1i的个数;pc表示在FLcls2(pc)中由第二定位子网络中第k个候选框A2k中的目标为第c类的概率,c∈C,k∈K,K为非极大值抑制中设置的候选框个数,K≤I;(1-pi)γ和(1-pc)γ表示调制因子,γ表示指数参数;d′1i表示A1i的预测偏移量,d′1i=[d′1ix,d′1iy,d′1iw,d′1ih],d′1ix和d′1iy分别表示A1i的中心位置坐标在x轴和y轴的预测偏移量,d′1iw和d′1ih分别表示A1i在宽和高上的预测偏移量;A1i=[A1ix,A1iy,A1iw,A1ih],A1ix和A1iy分别表示A1i的中心位置坐标在x轴和y轴的值,A1iw和A1ih分别表示A1i的宽和高;d1i表示A1i的真实偏移量,d1i=[d1ix,d1iy,d1iw,d1ih],d1ix和d1iy分别表示A1i的中心位置坐标在x轴和y轴的真实偏移量,d1iw和d1ih分别表示A1i在宽和高上的真实偏移量;d'2k表示A2k的预测偏移量,d'2k=[d'2kx,d'2ky,d'2kw,d'2kh],d'2kx和d'2ky分别表示A2k的中心位置坐标在x轴和y轴的预测偏移量,d'2kw和d'2kh分别表示A2k在宽和高上的预测偏移量;A2k=[A2kx,A2ky,A2kw,A2kh]A2kx和A2ky分别表示A2k的中心位置坐标在x轴和y轴的值,A2kw和A2kh分别表示A2k的宽和高;d2k表示A2k的真实偏移量;d2k=[d2kx,d2ky,d2kw,d2kh],d2kx和d2ky分别表示A2k中心位置坐标在x轴和y轴的真实偏移量,d2kw和d2kh分别表示A2k在宽和高上的真实偏移量;smoothL1(x)为平滑损失函数;
其中,原始的Faster R-CNN中采用的交叉熵损失函数如下所示:
Lcls(pi)=-log(pi)
其中i为目标的真实类别的编号。
在应用经典的Faster-RCNN进行目标检测的实验中,发现虽然训练集中各类别的目标数量分布较为平均,但是每个类别目标的平均精度(AP)有明显差别,每类目标的分类难度有差异,为了平衡上述问题,本发明采用Focal Loss损失函数改进原来的交叉熵损失。
Focal Loss函数的提出最初是为了解决one-stage算法中正负样本不均衡的问题,其中以二分类为例,损失函数如下所示,
FL(pi)=-αi(1-pi)γlog(pi),i∈I
p为候选框内为目标的概率,候选框内为目标即为正样本,为背景即为负样本。其中αi取不同的值可以调节正负样本对于损失函数的贡献,本发明中采用经典的two-stage算法Faster R-CNN。因为区域生成网络固定了正负样本的比例,而且SAR_OD数据集各类目标的样本数量均衡,所以本发明的损失函数不设置αi。(1-pi)γ部分称为调制因子可以聚焦于难分类的样本,其中当一个样本的分类错误,这时pi的值很小,调制因子接近于1和原来的损失相近,当pi接近1时,代表该目标分类良好,这时调制因子接近于0,降低该目标对损失函数的贡献。在参数γ的选择上,当γ=0时,调制因子不起作用,γ越大,调制因子的影响作用越大。在本发明中,参数γ取0.5,该损失函数通过调制因子降低易分样本对于损失函数的贡献,使得分类准确率低的样本对损失函数贡献更大,从而达到提升遥感图像目标检测平均精度的目的。因此,最终的改进的交叉熵损失函数公式为:
FLcls1(pi)=-(1-pi)0.5log(pi),i∈I
FLcls2(pc)=-(1-pc)0.5log(pc),c∈C
(3)对基于改进的交叉熵损失函数的遥感图像目标检测模型进行迭代训练:
(3a)初始化基于Faster R-CNN的遥感图像目标检测模型的网络参数θ,迭代次数为t,最大迭代次数为T,T=20,并令t=0;
(3b)将训练样本集作为遥感图像目标检测模型的输入进行前向传播,特征提取子网络对每个训练样本进行特征提取,得到特征图集合f={f1,f2,...,fq,...,fm},其中fq表示第q个训练样本对应的大小为a×b的特征图,q∈m,每个特征图中的每个像素点对应训练样本中的一个区域;
(3c)区域生成子网络生成预测候选框坐标:
(3c1)以fq中每个像素点为中心,生成该像素点的9种初始候选框,得到fq的9×a×b个初始候选框,9代表对每个像素点生成的9种不同纵横比,不同在原图上的对应尺度的候选框数量;9种候选框由三组长宽比:1:1、1:2、2:1,三组在原图上的对应尺度:128×128、256×256、512×512像素的边框排列组合而成。得到初始候选框后,计算每个初始候选框A0i与每个真实框Gr之间的交集与并集的比值IoU,其中,IoU的计算公式为:
IoU表示A0i与Gr之间的交集与并集之间的比值,area(A)代表A0j的面积,area(G)代表真实边框的面积。
Gr表示第r个真实框,r∈R,R表示真实框的个数,R≤I,再根据IoU对所有初始候选框进行筛选,将所筛选的多个初始候选框作为候选框A1i,筛选规则为:当IoU>0.7时,初始候选框包含有目标,标记为1;当IoU<0.3时,初始候选框不包含有目标,标记为0;舍弃0.3<IoU<0.7时的初始候选框;当真实框Gr没有与之对应的初始候选框的IoU>0.7时,将与真实框Gr的IoU最大的框也标记为1;
(3c2)区域生成子网络中的卷积层-ReLU层对每个筛选后的候选框A1i进行特征提取,得到特征图集合f'={f1',f′2,...,f′q,...,f′m};
(3c3)第一分类子网络根据特征图集合f'={f1',f′2,...,f′q,...,f′m}计算每个候选框A1i含有目标的概率pi;第一定位模块根据特征图集合f'={f1',f′2,...,f′q,...,f′m}计算候选框A1i的预测偏移量d′1j和预测候选框坐标[A′1ix,A′1iy,A′1iw,A′1ih];
(3c4)proposal层采用非极大值抑制方法,选取所有候选框中含有目标的概率得分pi排名前K个候选框,其中每个候选框坐标为[A2kx,A2ky,A2kw,A2kh],k∈K;具体做法为:将每个候选框按pi进行排序,设置一个阈值threshold,0≤threshold≤1,仅保留候选框中与真实框Gr的IoU≤threshold的候选框,然后在剩余的候选框中选择pi最大K个候选框。
(3d)ROI Align池化层将前K个候选框的每个候选框的坐标值A2kx,A2ky,A2kw,A2kh分别缩小16倍,映射到特征图fq的对应位置得到K个感兴趣区域;将每个感兴趣区域划分为相同大小的7×7的子图;对每个子图进行最大池化操作,得到统一尺度的K个感兴趣区域;
(3e)分类定位子网络的全连接层将统一尺度的K个感兴趣区域提取特征,得到特征图集合f”={f1”,f″2,...,f″q,...,f″K};第二分类子网络通过f”={f1”,f″2,...,f″q,...,f″K}计算每个候选框内目标的预测类别pc,同时第二定位模块通过f”={f1”,f″2,...,f″q,...,f″K}计算候选框的精确坐标[Akx,Aky,Akw,Akh];
(3f)采用第一分类子网络的改进的交叉熵损失函数FLcls1(pi),通过pi计算自己的损失值FLcls1,采用第一定位模块的损失函数Lreg1(d′1i,d1i),通过d′1i和d1i计算自己的损失值Lreg1,采用第二分类子网络的改进的交叉熵损失函数FLcls2(pc),通过pc计算自己的损失值FLcls2,采用第二定位模块的损失函数Lreg2(d'2k,d2k),通过d'2k和d2k计算自己的损失值Lreg2,再采用反向传播方法,通过FLcls1、Lreg1、FLcls2和Lreg2计算遥感图像目标检测模型参数梯度,然后采用梯度下降算法,通过遥感图像目标检测模型参数梯度对网络参数θ进行更新;
(3g)判断t=T是否成立,若是,得到训练好的基于改进的交叉熵损失函数的遥感图像目标检测模型,否则,令t=t+1,并执行步骤(3b);
(4)获取遥感图像目标的检测结果:
将测试样本集作为训练好的基于改进的交叉熵损失函数的遥感图像目标检测模型的输入进行检测,得到每个目标的类别和边界框四个顶点坐标。
其中,每个目标边界框四个顶点坐标(x0,y0)、(x1,y1)、(x2,y2)、(x3,y3)的计算公式为:
Claims (5)
1.一种基于改进的交叉熵损失函数的遥感图像目标检测方法,其特征在于,包括如下步骤:
(1)获取训练样本集和测试样本集:
从遥感图像数据集中获取包含C种目标类别共N幅带标签的遥感图像,每幅遥感图像至少包含一个目标,并对每个目标类别进行独热编码,将真实的目标类别编码为1,其他的目标类别编码为0;并将N幅带标签的遥感图像中的m幅图像作为训练样本集,将其余带标签的遥感图像作为测试样本集,C≥2,N≥200,
(2)构建基于改进的交叉熵损失函数的遥感图像目标检测模型:
(2a)构建基于Faster R-CNN的遥感图像目标检测模型:
构建包括依次级联的特征提取子网络、区域生成子网络、ROI Align池化层和分类定位子网络,且特征提取子网络的末端与ROI Align池化层相连的Faster R-CNN的遥感图像目标检测模型;其中,特征提取子网络包括多个级联的特征提取模块,每个特征提取模块包含依次连接的多个卷积层-ReLU层和一个最大池化层;区域生成子网络包括依次连接的卷积层-ReLU层、并行连接的第一分类子网络和包含一个卷积层的第一定位模块、proposal层;第一分类子网络包括依次连接的卷积层、reshape层、softmax层、reshape层;分类定位子网络包括依次连接的全连接层、并行连接的第二分类子网络和包含一个全连接层的第二定位模块,第二分类子网络包括级联的全连接层和softmax层;
(2b)定义改进的交叉熵损失函数FL:
FL=FLcls1(pi)+Lreg1(d′1i,d1i)+FLcls2(pc)+Lreg2(d′2k,d2k)
FLcls1(pi)=-(1-pi)γlog(pi),i∈I
FLcls2(pc)=-(1-pc)γlog(pc),c∈C
其中,FLcls1(pi)表示第一分类子网络的改进的交叉熵损失函数,Lreg1(d′1i,d1i)表示第一定位模块的损失函数,FLcls2(pc)表示第二分类子网络的改进的交叉熵损失函数,Lreg2(d′2k,d2k)表示第二定位模块的损失函数;pi表示在FLcls1(pi)中由第一定位子网络生成并筛选的第i个候选框A1i中的内容为目标的概率,i∈I,I为A1i的个数;pc表示在FLcls2(pc)中由第二定位子网络中第k个候选框A2k中的目标为第c类的概率,c∈C,k∈K,K为非极大值抑制中设置的候选框个数,K≤I;(1-pi)γ和(1-pc)γ表示调制因子,γ表示指数参数;d′1i表示A1i的预测偏移量,d′1i=[d′1ix,d′1iy,d′1iw,d′1ih],d′1ix和d′1iy分别表示A1i的中心位置坐标在x轴和y轴的预测偏移量,d′1iw和d′1ih分别表示A1i在宽和高上的预测偏移量;A1i=[A1ix,A1iy,A1iw,A1ih],A1ix和A1iy分别表示A1i的中心位置坐标在x轴和y轴的值,A1iw和A1ih分别表示A1i的宽和高;d1i表示A1i的真实偏移量,d1i=[d1ix,d1iy,d1iw,d1ih],d1ix和d1iy分别表示A1i的中心位置坐标在x轴和y轴的真实偏移量,d1iw和d1ih分别表示A1i在宽和高上的真实偏移量;d′2k表示A2k的预测偏移量,d′2k=[d′2kx,d′2ky,d′2kw,d′2kh],d′2kx和d′2ky分别表示A2k的中心位置坐标在x轴和y轴的预测偏移量,d′2kw和d′2kh分别表示A2k在宽和高上的预测偏移量;A2k=[A2kx,A2ky,A2kw,A2kh]A2kx和A2ky分别表示A2k的中心位置坐标在x轴和y轴的值,A2kw和A2kh分别表示A2k的宽和高;d2k表示A2k的真实偏移量;d2k=[d2kx,d2ky,d2kw,d2kh],d2kx和d2ky分别表示A2k中心位置坐标在x轴和y轴的真实偏移量,d2kw和d2kh分别表示A2k在宽和高上的真实偏移量;smoothL1(x)为平滑损失函数;
(3)对基于改进的交叉熵损失函数的遥感图像目标检测模型进行迭代训练:
(3a)初始化基于Faster R-CNN的遥感图像目标检测模型的网络参数θ,迭代次数为t,最大迭代次数为T,T=20,并令t=0;
(3b)将训练样本集作为遥感图像目标检测模型的输入进行前向传播,特征提取子网络对每个训练样本进行特征提取,得到特征图集合f={f1,f2,...,fq,...,fm},其中fq表示第q个训练样本对应的大小为a×b的特征图,q∈m,每个特征图中的每个像素点对应训练样本中的一个区域;
(3c)区域生成子网络生成预测候选框坐标:
(3c1)以fq中每个像素点为中心,生成该像素点的9种初始候选框,得到fq的9×a×b个初始候选框,并计算每个初始候选框A0i与每个真实框Gr之间的交集与并集的比值IoU,Gr表示第r个真实框,r∈R,R表示真实框的个数,R≤I,再根据IoU对所有初始候选框进行筛选,将所筛选的多个初始候选框作为候选框A1i,筛选规则为:当IoU>0.7时,初始候选框包含有目标,标记为1;当IoU<0.3时,初始候选框不包含有目标,标记为0;舍弃0.3<IoU<0.7时的初始候选框;当真实框Gr没有与之对应的初始候选框的IoU>0.7时,将与真实框Gr的IoU最大的框也标记为1;
(3c2)区域生成子网络中的卷积层-ReLU层对每个筛选后的候选框A1i进行特征提取,得到特征图集合f'={f′1,f′2,...,f′q,...,f′m};
(3c3)第一分类子网络根据特征图集合f'={f′1,f′2,...,f′q,...,f′m}计算每个候选框A1i含有目标的概率pi;第一定位模块根据特征图集合f'={f′1,f′2,...,f′q,...,f′m}计算候选框A1i的预测偏移量d′1j和预测候选框坐标[A′1ix,A′1iy,A′1iw,A′1ih];
(3c4)proposal层采用非极大值抑制方法,选取所有候选框中含有目标的概率得分pi排名前K个候选框,其中每个候选框坐标为[A2kx,A2ky,A2kw,A2kh],k∈K;
(3d)ROI Align池化层将前K个候选框的每个候选框的坐标值A2kx,A2ky,A2kw,A2kh分别缩小16倍,映射到特征图fq的对应位置得到K个感兴趣区域;将每个感兴趣区域划分为相同大小的7×7的子图;对每个子图进行最大池化操作,得到统一尺度的K个感兴趣区域;
(3e)分类定位子网络的全连接层将统一尺度的K个感兴趣区域提取特征,得到特征图集合f″={f″1,f″2,...,f″q,...,f″K};第二分类子网络通过f″={f″1,f″2,...,f″q,...,f″K}计算每个候选框内目标的预测类别pc,同时第二定位模块通过f″={f″1,f″2,...,f″q,...,f″K}计算候选框的精确坐标[Akx,Aky,Akw,Akh];
(3f)采用第一分类子网络的改进的交叉熵损失函数FLcls1(pi),通过pi计算自己的损失值FLcls1,采用第一定位模块的损失函数Lreg1(d′1i,d1i),通过d′1i和d1i计算自己的损失值Lreg1,采用第二分类子网络的改进的交叉熵损失函数FLcls2(pc),通过pc计算自己的损失值FLcls2,采用第二定位模块的损失函数Lreg2(d′2k,d2k),通过d′2k和d2k计算自己的损失值Lreg2,再采用反向传播方法,通过FLcls1、Lreg1、FLcls2和Lreg2计算遥感图像目标检测模型参数梯度,然后采用梯度下降算法,通过遥感图像目标检测模型参数梯度对网络参数θ进行更新;
(3g)判断t=T是否成立,若是,得到训练好的基于改进的交叉熵损失函数的遥感图像目标检测模型,否则,令t=t+1,并执行步骤(3b);
(4)获取遥感图像目标的检测结果:
将测试样本集作为训练好的基于改进的交叉熵损失函数的遥感图像目标检测模型的输入进行检测,得到每个目标的类别和边界框四个顶点坐标。
2.根据权利要求1所述的基于改进的损失函数的遥感图像目标检测方法,其特征在于,步骤(2a)中所述特征提取网络包括依次连接的5个特征提取模块;其中,模块1和模块2由2个卷积层-ReLU层和一个最大池化层组成;模块3,模块4,模块5由3个卷积层-ReLU层和一个最大池化层组成。
4.根据权利要求1所述的基于改进的交叉熵损失函数的遥感图像目标检测方法,其特征在于,步骤(3c4)中所述的非极大值抑制方法,具体做法为:
将每个候选框按pi进行排序,设置一个阈值threshold,0≤threshold≤1,仅保留候选框中与真实框Gr的IoU≤threshold的候选框,然后在剩余的候选框中选择pi最大K个候选框。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011462894.8A CN112528862B (zh) | 2020-12-10 | 2020-12-10 | 基于改进的交叉熵损失函数的遥感图像目标检测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011462894.8A CN112528862B (zh) | 2020-12-10 | 2020-12-10 | 基于改进的交叉熵损失函数的遥感图像目标检测方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112528862A true CN112528862A (zh) | 2021-03-19 |
CN112528862B CN112528862B (zh) | 2023-02-10 |
Family
ID=74999334
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011462894.8A Active CN112528862B (zh) | 2020-12-10 | 2020-12-10 | 基于改进的交叉熵损失函数的遥感图像目标检测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112528862B (zh) |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112926510A (zh) * | 2021-03-25 | 2021-06-08 | 深圳市商汤科技有限公司 | 异常驾驶行为识别方法及装置、电子设备和存储介质 |
CN113223017A (zh) * | 2021-05-18 | 2021-08-06 | 北京达佳互联信息技术有限公司 | 目标分割模型的训练方法、目标分割方法及设备 |
CN113516639A (zh) * | 2021-06-30 | 2021-10-19 | 哈尔滨工业大学(深圳) | 基于全景x光片的口腔异常检测模型的训练方法及装置 |
CN114627373A (zh) * | 2022-02-25 | 2022-06-14 | 北京理工大学 | 一种面向遥感图像目标检测模型的对抗样本生成方法 |
CN114821201A (zh) * | 2022-06-28 | 2022-07-29 | 江苏广坤铝业有限公司 | 铝材加工的液压式撞角机及其使用方法 |
CN115082740A (zh) * | 2022-07-18 | 2022-09-20 | 北京百度网讯科技有限公司 | 目标检测模型训练方法、目标检测方法、装置、电子设备 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109711288A (zh) * | 2018-12-13 | 2019-05-03 | 西安电子科技大学 | 基于特征金字塔和距离约束fcn的遥感船舶检测方法 |
CN109919108A (zh) * | 2019-03-11 | 2019-06-21 | 西安电子科技大学 | 基于深度哈希辅助网络的遥感图像快速目标检测方法 |
AU2019101133A4 (en) * | 2019-09-30 | 2019-10-31 | Bo, Yaxin MISS | Fast vehicle detection using augmented dataset based on RetinaNet |
CN110874593A (zh) * | 2019-11-06 | 2020-03-10 | 西安电子科技大学 | 基于掩膜的遥感图像旋转目标检测方法 |
CN110991535A (zh) * | 2019-12-04 | 2020-04-10 | 中山大学 | 一种基于多类型医学数据的pCR预测方法 |
CN111091105A (zh) * | 2019-12-23 | 2020-05-01 | 郑州轻工业大学 | 基于新的边框回归损失函数的遥感图像目标检测方法 |
WO2020181685A1 (zh) * | 2019-03-12 | 2020-09-17 | 南京邮电大学 | 一种基于深度学习的车载视频目标检测方法 |
WO2020187153A1 (zh) * | 2019-03-21 | 2020-09-24 | 腾讯科技(深圳)有限公司 | 目标检测方法、模型训练方法、装置、设备及存储介质 |
CN111985376A (zh) * | 2020-08-13 | 2020-11-24 | 湖北富瑞尔科技有限公司 | 一种基于深度学习的遥感影像舰船轮廓提取方法 |
-
2020
- 2020-12-10 CN CN202011462894.8A patent/CN112528862B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109711288A (zh) * | 2018-12-13 | 2019-05-03 | 西安电子科技大学 | 基于特征金字塔和距离约束fcn的遥感船舶检测方法 |
CN109919108A (zh) * | 2019-03-11 | 2019-06-21 | 西安电子科技大学 | 基于深度哈希辅助网络的遥感图像快速目标检测方法 |
WO2020181685A1 (zh) * | 2019-03-12 | 2020-09-17 | 南京邮电大学 | 一种基于深度学习的车载视频目标检测方法 |
WO2020187153A1 (zh) * | 2019-03-21 | 2020-09-24 | 腾讯科技(深圳)有限公司 | 目标检测方法、模型训练方法、装置、设备及存储介质 |
AU2019101133A4 (en) * | 2019-09-30 | 2019-10-31 | Bo, Yaxin MISS | Fast vehicle detection using augmented dataset based on RetinaNet |
CN110874593A (zh) * | 2019-11-06 | 2020-03-10 | 西安电子科技大学 | 基于掩膜的遥感图像旋转目标检测方法 |
CN110991535A (zh) * | 2019-12-04 | 2020-04-10 | 中山大学 | 一种基于多类型医学数据的pCR预测方法 |
CN111091105A (zh) * | 2019-12-23 | 2020-05-01 | 郑州轻工业大学 | 基于新的边框回归损失函数的遥感图像目标检测方法 |
CN111985376A (zh) * | 2020-08-13 | 2020-11-24 | 湖北富瑞尔科技有限公司 | 一种基于深度学习的遥感影像舰船轮廓提取方法 |
Non-Patent Citations (4)
Title |
---|
SHAOQING REN 等: "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks", 《IEEE TRANSACTIONS ON PATTERN ANALYSIS AND MACHINE INTELLIGENCE》 * |
WANJUN WEI 等: "Remote Sensing Image Aircraft Detection Based on Feature Fusion across Deep Learning Framework", 《2019 IEEE 10TH INTERNATIONAL CONFERENCE ON SOFTWARE ENGINEERING AND SERVICE SCIENCE (ICSESS)》 * |
李放: "基于深度卷积神经网络的高分辨率图像目标检测研究", 《中国优秀硕士学位论文全文数据库 工程科技Ⅱ辑》 * |
杨康: "基于多尺度特征与模型压缩加速的光学遥感图像目标检测", 《中国优秀硕士学位论文全文数据库 工程科技Ⅱ辑》 * |
Cited By (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112926510A (zh) * | 2021-03-25 | 2021-06-08 | 深圳市商汤科技有限公司 | 异常驾驶行为识别方法及装置、电子设备和存储介质 |
CN113223017A (zh) * | 2021-05-18 | 2021-08-06 | 北京达佳互联信息技术有限公司 | 目标分割模型的训练方法、目标分割方法及设备 |
CN113516639A (zh) * | 2021-06-30 | 2021-10-19 | 哈尔滨工业大学(深圳) | 基于全景x光片的口腔异常检测模型的训练方法及装置 |
CN113516639B (zh) * | 2021-06-30 | 2023-05-12 | 哈尔滨工业大学(深圳) | 基于全景x光片的口腔异常检测模型的训练方法及装置 |
CN114627373A (zh) * | 2022-02-25 | 2022-06-14 | 北京理工大学 | 一种面向遥感图像目标检测模型的对抗样本生成方法 |
CN114821201A (zh) * | 2022-06-28 | 2022-07-29 | 江苏广坤铝业有限公司 | 铝材加工的液压式撞角机及其使用方法 |
CN114821201B (zh) * | 2022-06-28 | 2022-09-20 | 江苏广坤铝业有限公司 | 铝材加工的液压式撞角机及其使用方法 |
CN115082740A (zh) * | 2022-07-18 | 2022-09-20 | 北京百度网讯科技有限公司 | 目标检测模型训练方法、目标检测方法、装置、电子设备 |
CN115082740B (zh) * | 2022-07-18 | 2023-09-01 | 北京百度网讯科技有限公司 | 目标检测模型训练方法、目标检测方法、装置、电子设备 |
Also Published As
Publication number | Publication date |
---|---|
CN112528862B (zh) | 2023-02-10 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112528862B (zh) | 基于改进的交叉熵损失函数的遥感图像目标检测方法 | |
CN108229397B (zh) | 基于Faster R-CNN的图像中文本检测方法 | |
CN113065558A (zh) | 一种结合注意力机制的轻量级小目标检测方法 | |
CN109684922B (zh) | 一种基于卷积神经网络的多模型对成品菜的识别方法 | |
CN111062885B (zh) | 基于多阶段迁移学习的标志检测模型训练及标志检测方法 | |
CN108647665A (zh) | 基于深度学习的航拍车辆实时检测方法 | |
CN111833322B (zh) | 一种基于改进YOLOv3的垃圾多目标检测方法 | |
CN111967313B (zh) | 一种深度学习目标检测算法辅助的无人机图像标注方法 | |
CN110796143A (zh) | 一种基于人机协同的场景文本识别方法 | |
CN109766835A (zh) | 基于多参数优化生成对抗网络的sar目标识别方法 | |
CN111126278B (zh) | 针对少类别场景的目标检测模型优化与加速的方法 | |
CN112801182B (zh) | 一种基于困难样本感知的rgbt目标跟踪方法 | |
CN111062441A (zh) | 基于自监督机制和区域建议网络的场景分类方法及装置 | |
CN112348758B (zh) | 一种光学遥感图像数据增强方法及目标识别方法 | |
CN111310609B (zh) | 基于时序信息和局部特征相似性的视频目标检测方法 | |
CN110929746A (zh) | 一种基于深度神经网络的电子卷宗标题定位提取与分类方法 | |
CN110245587B (zh) | 一种基于贝叶斯迁移学习的光学遥感图像目标检测方法 | |
CN110443862A (zh) | 基于无人机的岩性填图方法及系统、电子设备 | |
CN111178438A (zh) | 一种基于ResNet101的天气类型识别方法 | |
CN115620393A (zh) | 一种面向自动驾驶的细粒度行人行为识别方法及系统 | |
CN114139616A (zh) | 一种基于不确定性感知的无监督域适应目标检测方法 | |
CN113392930A (zh) | 基于多层次分治网络的交通标志目标检测方法 | |
CN115205727A (zh) | 一种基于无监督学习的实验智能评分方法和系统 | |
CN111507416A (zh) | 一种基于深度学习的吸烟行为实时检测方法 | |
CN114882204A (zh) | 船名自动识别方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |