CN115240052A - 一种目标检测模型的构建方法及装置 - Google Patents
一种目标检测模型的构建方法及装置 Download PDFInfo
- Publication number
- CN115240052A CN115240052A CN202210993151.6A CN202210993151A CN115240052A CN 115240052 A CN115240052 A CN 115240052A CN 202210993151 A CN202210993151 A CN 202210993151A CN 115240052 A CN115240052 A CN 115240052A
- Authority
- CN
- China
- Prior art keywords
- target detection
- network
- prediction
- module
- auxiliary head
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/20—Image preprocessing
- G06V10/22—Image preprocessing by selection of a specific region containing or referencing a pattern; Locating or processing of specific regions to guide the detection or recognition
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/20—Image preprocessing
- G06V10/26—Segmentation of patterns in the image field; Cutting or merging of image elements to establish the pattern region, e.g. clustering-based techniques; Detection of occlusion
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/44—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/766—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using regression, e.g. by projecting features on hyperplanes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/80—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
- G06V10/806—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/70—Labelling scene content, e.g. deriving syntactic or semantic representations
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Computation (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及计算机视觉技术领域,解决了现有建模方法越来越复杂的技术问题,尤其涉及一种目标检测模型的构建方法,包括以下步骤:根据基础目标检测网络,通过采用多层注意力机制构建网络预测辅助头;根据网络预测辅助头的预测结果,在标签分配任务中动态自适应寻优划分正负样本,并对预设的训练损失函数进行优化,得到优化后的训练损失函数,设计新的加权策略;采用优化后的训练损失函数和新的加权策略对基础目标检测网络进行优化;对优化后的基础目标检测网络进行训练,得到目标检测模型。本发明通过构建目标检测网络预测辅助头并计算网络预测辅助头的预测结果,从而能够在无需复杂建模的情况下提升目标检测模型的性能和检测精度。
Description
技术领域
本发明涉及计算机视觉技术领域,尤其涉及一种目标检测模型的构建方法及装置。
背景技术
目标检测是计算机视觉领域中的一个基础视觉识别任务,视觉目标检测即在给定图像中找出属于特定目标类别的对象及其准确位置,并为每个对象分配对应的类别标签。
目前,基于深度学习的视觉目标检测技术取得了重大突破,但是,现有技术中基于深度学习的视觉目标检测模型构建方法,在享受网络预测结果带来的正向收益的同时,却从未明确说明网络预测的结果对模型训练的重要性,反而在建模方法上越来越复杂,增加了建模难度,使得简单的问题变得复杂化。为此,本文提出了一种基于网络预测辅助训练的目标检测模型构建方法及装置。
发明内容
针对现有技术的不足,本发明提供了一种目标检测模型的构建方法及装置,解决了现有建模方法越来越复杂的技术问题,达到了无需复杂的数学建模,便能显著提高模型的检测精度的目的。
为解决上述技术问题,本发明提供了如下技术方案:一种目标检测模型的构建方法,包括以下步骤:
S1、根据目标检测标签分配方向通用算法,确定基础目标检测网络;
S2、采用多层注意力机制根据所述基础目标检测网络,构建网络预测辅助头;
S3、根据所述网络预测辅助头的预测结果,在标签分配任务中动态自适应寻优划分正负样本;
S4、根据所述网络预测辅助头的预测结果,对预设的训练损失函数进行优化,得到优化后的训练损失函数,并设计新的加权策略;
S5、采用优化后的训练损失函数和新的加权策略,对所述基础目标检测网络进行优化;
S6、根据随机梯度下降法,对优化后的基础目标检测网络进行优化训练,得到目标检测模型。
进一步地,所述步骤S2的具体步骤如下:
S21、根据基础目标检测网络的特征金字塔模块输出的初步特征图,得到动态任务感知模块增强后的第一特征图;
S22、根据动态任务感知模块得到的第一特征图,进行多次卷积操作得到增强后的第二特征图;
S23、根据动态任务交互增强模块得到的第二特征图,通过不同的卷积操作得到网络预测辅助头。
进一步地,步骤S3中所述网络预测辅助头的预测结果包括分类任务分支预测结果、回归任务分支预测结果以及质量评估分支预测结果。
进一步地,所述步骤S3具体包括以下步骤:
S31、获取基础目标检测网络的特征金字塔模块输出的所有特征图上生成的锚框;
S32、根据获取的锚框和网络预测辅助头的预测结果,得到网络预测的分类、回归和质量评估得分,并构造预测样本评分函数和先验样本评分函数;
S33、根据网络预测的分类、回归和质量评估得分,计算每个锚框的综合得分,并合计所有锚框的回归得分之和k;
S34、根据先验样本评分函数,选取前Top-k锚框作为正样本锚框候选集;
S35、根据正样本锚框候选集,动态更新k值,使其为正样本锚框候选集中所有锚框的回归得分之和;
S36、根据预测样本评分函数和更新后的k值,选取前Top-k锚框作为正样本,剩余锚框设置为负样本。
进一步地,所述预测样本评分函数指基于网络预测辅助头中分类任务分支和质量评估任务分支的预测结果,得到的分类得分和中心度得分的乘积;所述先验样本评分函数指基于网络预测辅助头中回归任务分支的预测结果以及真实标签中心点先验,得到的回归得分和中心先验得分的乘积。
进一步地,步骤S4中所述优化后的训练损失函数的表达式为:
L=αLcls+βLloc+γLcenter-iou
式中,α、β、γ是训练中用到的超参数,Lcls是分类任务分支的损失函数,Lreg是回归任务分支的损失函数,Lcenter-iou是质量评估任务分支的损失函数。
本发明还提供了一种技术方案:一种目标检测模型的构建装置,包括:
基础目标检测网络确定模块,所述基础目标检测网络确定模块用于根据目标检测标签分配方向通用算法,确定基础目标检测网络;
网络预测辅助头构建模块,所述网络预测辅助头构建模块用于采用多层注意力机制根据所述基础目标检测网络,构建网络预测辅助头;
正负样本划分模块,所述正负样本划分模块用于根据网络预测辅助头的预测结果,在标签分配任务中动态自适应寻优划分正负样本;
损失函数优化模块,所述损失函数优化模块用于根据所述网络预测辅助头的预测结果,对预设的训练损失函数进行优化,得到优化后的训练损失函数,并设计新的加权策略;
模型优化模块,所述模型优化模块用于采用优化后的训练损失函数和新的加权策略,对所述基础目标检测网络进行优化;
模型训练模块,所述模型训练模块用于根据随机梯度下降法,对优化后的基础目标检测网络进行优化训练,得到目标检测模型。
进一步地,所述网络预测辅助头构建模块包括:
第一特征图获取单元,所述第一特征图获取单元用于根据基础目标检测网络的特征金字塔模块输出的初步特征图,得到动态任务感知模块增强后的第一特征图;
第二特征图获取单元,所述第二特征图获取单元用于根据动态任务感知模块得到的第一特征图,进行多次卷积操作得到增强后的第二特征图;
网络预测辅助头生成单元,所述网络预测辅助头生成单元用于根据动态任务交互增强模块得到的第二特征图,通过不同的卷积操作得到网络预测辅助头。
进一步地,所述正负样本划分模块包括:
锚框获取单元,所述锚框获取单元用于获取基础目标检测网络的特征金字塔模块输出的所有特征图上生成的锚框;
第一计算单元,所述第一计算单元用于根据获取的锚框和网络预测辅助头的预测结果,得到网络预测的分类、回归和质量评估得分,并构造预测样本评分函数和先验样本评分函数;
第二计算单元,所述第二计算单元用于根据网络预测的分类、回归和质量评估得分,计算每个锚框的综合得分,并合计所有锚框的回归得分之和k;
正样本锚框候选集生成单元,所述正样本锚框候选集生成单元用于根据先验样本评分函数,选取前Top-k锚框作为正样本锚框候选集;
动态更新单元,所述动态更新单元用于根据正样本锚框候选集,动态更新k值,使其为正样本锚框候选集中所有锚框的回归得分之和;
正负样本生成单元,所述正负样本生成单元用于根据预测样本评分函数和更新后的k值,选取前Top-k锚框作为正样本,剩余锚框设置为负样本。
借由上述技术方案,本发明提供了一种目标检测模型的构建方法及装置,至少具备以下有益效果:
1、本发明通过构建目标检测网络预测辅助头,可以在保证优秀的局部特征提取能力的同时,更大程度地增加全局建模能力,减少目标检测模型对复杂背景的关注,增加对前景目标的关注,提升目标检测头多任务分支的交互能力,显著提高目标检测头的表达能力,达到了无需复杂的数学建模,便能显著提高检测精度的目的。
2、本发明通过对当前一阶段目标检测器ATSS进行改进,通过训练得到最优的目标检测模型,并在COCO数据集上进行前向推理实验,实验结果证明网络预测结果在目标检测模型训练中非常重要,从而能够在无需复杂建模的情况下提升目标检测模型的性能,显著提升了检测精度,且能够在相同训练条件下超越最先进的一阶段目标检测模型,具有较高的社会价值和应用前景。
附图说明
此处所说明的附图用来提供对本申请的进一步理解,构成本申请的一部分,本申请的示意性实施例及其说明用于解释本申请,并不构成对本申请的不当限定。在附图中:
图1为本发明提供的目标检测模型构建方法的流程图;
图2为本发明提供的目标检测模型构建方法的构建网络预测辅助头的流程图;
图3为本发明提供的目标检测模型构建方法的网络预测辅助头的网络架构示意图;
图4为本发明提供的目标检测模型构建方法的网络预测参与训练前后效果对比图;
图5为本发明提供的目标检测模型构建方法的划分正负样本的流程图;
图6为本发明提供的目标检测模型构建方法的训练架构示意图;
图7为本发明提供的目标检测系统的原理框图;
图8为本发明提供的目标检测系统的网络预测辅助头构建模块的框图;
图9为本发明提供的目标检测系统的正负样本划分模块的框图。
图中:10、基础目标检测网络确定模块;20、网络预测辅助头构建模块;201、第一特征图获取单元;202、第二特征图获取单元;203、网络预测辅助头生成单元;30、正负样本划分模块;301、锚框获取单元;302、第一计算单元;303、第二计算单元;304、正样本锚框候选集生成单元;305、动态更新单元;306、正负样本生成单元;40、损失函数优化模块;50、模型优化模块;60、模型训练模块。
具体实施方式
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图和具体实施方式对本发明作进一步详细的说明。借此对本申请如何应用技术手段来解决技术问题并达成技术功效的实现过程能充分理解并据以实施。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分步骤是可以通过程序来指令相关的硬件来完成,因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
请参照图1-图6,示出了根据本实施例的一种目标检测模型的构建方法,如图1所示,包括以下步骤:
S1、根据目标检测标签分配方向通用算法,确定基础目标检测网络。
具体的,将目标检测标签分配方向通用算法《Bridging the Gap BetweenAnchor-based and Anchor-free Detection via Adaptive Training SampleSelection》(简称为ATSS)的网络结构,确定为基础目标检测网络,该基础目标检测网络包括骨干网络、特征金字塔模块、通用目标检测头、标签分配模块以及损失函数模块。
S2、采用多层注意力机制根据基础目标检测网络,构建网络预测辅助头。
具体的,通过对输入图像进行初步特征提取后,得到包含不同尺度以及语义信息的特征图后,采用多层注意力机制从尺度感知、空间感知和任务感知等角度进行重新设计,可得到网络预测辅助头,提升目标检测头的表达能力,得到可靠的预测结果。如图2和图3所示,构建网络预测辅助头的具体步骤如下:
S21、根据基础目标检测网络的特征金字塔模块输出的初步特征图,得到动态任务感知模块增强后的第一特征图;
具体的,通过动态任务感知模块分别从尺度感知、空间感知和任务感知等角度,使用堆叠多种注意力机制来增强初步特征图上不同尺度不同语义信息的动态提取,同时加强不同分支任务之间的交互,能够提升目标检测模型的预测能力。其中,动态任务感知模块的具体实现步骤如下:
1)输入来自基础目标检测网络的特征金字塔模块输出的初步特征图Rl;
2)基于三个维度S×L×C使用三个序列化的注意力,每个注意力仅仅关注一个维度,具体通过以下公式实现:
W(Rl)=π(Rl)·Rl
π(Rl)=C(S(L(Rl)·Rl)·Rl)·Rl
式中,W(·)是统一注意力的动态任务感知目标检测头方法,L(·)为尺度感知注意力模块,S(·)为空间感知注意力模块,C(·)为任务感知注意力模块。
在特征层级之间做尺度感知,基于语义重要性对不同尺度的特征进行融合,L(·)的具体公式如下:
式中,f(·)为线性函数,近似为1×1卷积操作,σ(·)为hard-sigmoid函数,且σ(·)的表达式如下:
在空间位置之间做空间感知,聚焦于不同空间位置的判别能力,考虑到S的高维度,需要对空间感知注意力模块进行解耦,分为两个步骤:首先使用可变性卷积学习稀疏化,然后在同一空间位置上跨层级聚合特征;S(·)的具体公式如下:
式中,K为稀疏采样的位置数量,pk+Δpk为进行位置偏移以聚焦于有判别力的区域,Δmk为一个关于位置pk的可自学习的重要性度量因子。
在输出通道之间做任务感知,为了能够进行联合学习与目标表示的泛化性,使用任务感知注意力模块C(·),动态地开关特征通道来选择不同的任务,C(·)的具体公式如下:
式中,[α1,α2,β1,β2]T=θ(·)是超函数,用来学习控制激活阈值。
3)输出动态任务感知模块增强后的第一特征图Xfpn。
S22、根据动态任务感知模块得到的第一特征图,进行多次卷积操作得到增强后的第二特征图。
具体的,根据动态任务感知模块得到的第一特征图Xfpn,先进行6次卷积操作得到然后使用动态任务交互增强模块提取任务交互特征,这样可以提升模型在多任务分支下的预测能力。其中,动态任务交互增强模块的具体实现步骤如下:
1)输入动态任务感知模块的增强后的第一特征图Xfpn;
2)提取第一特征图Xfpn中的任务交互特征,具体公式如下:
3)计算跨层级之间的交互特征,具体公式如下:
4)输出动态任务交互增强模块增强后的第二特征图Ztask,具体公式如下:
Ztask=conv2(δ(conv1(Xtask)))
S23、根据动态任务交互增强模块得到的第二特征图,通过不同的卷积操作得到网络预测辅助头。
根据获取的网络预测辅助头的各个分支的预测结果,使用sigmoid函数得到网络预测辅助头分类任务分支和质量评估分支的得分矩阵,使用IOU函数得到网络预测辅助头的回归任务分支得分矩阵。
其中,sigmoid函数表达式如下:
IOU函数采用如下公式:
式中,P表示模型预测的锚框,G表示真实边界框。
本实施例中,如图4所示,图4(a)是模型训练初期的预测能力热力图,图4(b)是有网络预测参与指导的模型,图4(c)是网络预测参与指导后模型预测能力热力图,其中,图4(b)中的“gt”表示G,也称之为真实边界框;“bbox_pred”表示P,即模型预测的锚框;“center”表示真实目标框的中心点。
通过构建网络预测辅助头,可以在保证优秀的局部特征提取能力的同时,更大程度地增加全局建模能力,减少模型对复杂背景的关注,增加对前景目标的关注,提升目标检测头多任务分支的交互能力,以便提高目标检测头的表达能力。
S3、根据网络预测辅助头的预测结果,在标签分配任务中动态自适应寻优划分正负样本。
具体的,基于当前得到的网络预测辅助头的分类分支预测结果回归分支预测结果以及质量评估分支预测结果在标签分配任务中动态自适应寻优划分正负样本,随着模型拟合能力的增强,匹配算法的寻优结果就越好,引导模型朝着更好的方向优化。如图5所示,划分正负样本的具体步骤如下:
S31、获取基础目标检测网络的特征金字塔模块输出的所有特征图上生成的锚框。
需要说明的是,锚框也称之为先验框(也简称为Anchor),是基于特征图生成的一系列预定义的锚框,真实边界框也被称之为ground truth bounding box(也简称为GT框)。
本实施例中,使用基于锚框的目标检测网络模型在特征图上的每一个位置均设置一个宽高比为1∶1的先验框,联合使用分类和回归任务并通过锚框来识别目标物体。
S32、根据获取的锚框和网络预测辅助头的预测结果,得到网络预测的分类、回归和质量评估得分,并构造预测样本评分函数和先验样本评分函数。
具体的,基于锚框的目标检测器,获得网络预测头的各个分支的结果,将分类头的预测结果经过sigmoid函数得到当前模型的分类得分矩阵;将回归头的预测结果记过计算IOU交并比得到当前模型的回归得分矩阵;将质量评估头的预测结果经过sigmoid函数得到当前模型的质量评估得分矩阵。
本实施例中,设计了两种不同的样本评价体系,分别为预测样本评分函数和先验样本评分函数,具体设计步骤如下:
第一种,基于网络预测辅助头中分类任务分支和质量评估分支的预测结果,计算分类得分和中心度得分的乘积来评价当前模型推理能力下每个锚框的质量,得到预测样本评分函数Snet,具体公式如下:
Snet=sigmoid(Scls(fθ(a,i),g))×sigmoid(Scenterness(fθ(a,i),g))
式中,Scls和Scenterness分别是网络预测的每个锚框的分类头的得分和中心度分支头的得分,a和g表示模型预测的锚框和对应的真实标签,i和fθ分别是输入图像和带参数的模型。
第二种,基于网络预测辅助头中回归任务分支的预测结果以及真实标签中心点先验,计算回归得分和中心先验得分的乘积来得到先验样本评分函数Sinp:具体公式如下:
Sin_gt=centerness(fθ(a,i),g),
式中,Sloc是网络预测辅助头的回归得分,因为定位头的输出结果是经过编码后的位置偏移量,所以解码后通过IOU函数计算a和g之间的IOU(Intersection-over-Union)来得到定位得分Sloc;Sin_gt是通过centerness(·)方法计算而得出的每个a中心点落在g内的得分,l*、r*、t*和b*分别表示g的中心点到真实标签框的左右上下边的距离。
S33、根据网络预测的分类、回归和质量评估得分,计算每个锚框的综合得分,并合计所有锚框的回归得分之和k。
具体的,针对每一个真实锚框,根据得到的分类得分矩阵、回归得分矩阵以及质量评估得分矩阵,可计算出每一个锚框的综合得分,本实施例中,设定一个初始的K值,将其设置为所有锚框的综合得分之和。
S34、根据先验样本评分函数,选取前Top-k锚框作为正样本锚框候选集。
具体的,通过根据先验样本评分函数Sinp,基于锚框的综合得分使用动态Top-k进行排序,并选取前Top-k锚框作为正样本锚框候选集。
S35、根据正样本锚框候选集,动态更新k值,使其为正样本锚框候选集中所有锚框的回归得分之和。
具体的,根据正样本锚框候选集,更新动态K值,使其为所有正样本锚框候选集中锚框的回归得分之和,并且将其取值范围限制在0~40之间。
S36、根据预测样本评分函数和更新后的k值,选取前Top-k锚框作为正样本,剩余锚框设置为负样本。
S4、根据所述网络预测辅助头的预测结果,对预设的训练损失函数进行优化,得到优化后的训练损失函数,并设计新的加权策略。
具体的,通过根据根据网络预测辅助头的预测结果,对预设的训练损失函数进行优化,得到优化后的用于动态调整每个样本在训练中贡献度的训练损失函数。
本实施例中,主要联合三个分支任务来对目标检测网络模型进行训练,设样本在训练过程中的总体损失函数设为L,则L的具体表达式如下:
L=αLcls+βLloc+γLcenter-iou
式中,α、β、γ是训练中用到的超参数,Lcls是分类任务分支的损失函数,Lreg是回归任务分支的损失函数,Lcenter-iou是质量评估任务分支的损失函数。
1)针对分类任务分支的损失函数Lcls,具体公式如下:
2)针对回归任务分支的损失函数Lreg,具体公式如下:
3)针对质量评估任务分支损失函数Lcenter-iou,具体公式如下:
S5、采用优化后的训练损失函数和新的加权策略,对基础目标检测网络进行优化。
具体的,通过采用优化后的训练损失函数L和新的加权策略,对基础目标检测网络进行优化,实现平衡难易样本训练。
需要说明的是,在标签分配任务中选择出符合条件的正样本后,由于预定义的正样本先验框和目标框的重叠面积并不是完全相同,重叠面积大的为易优化样本,重叠面积小的为难优化样本,并且大部分正样本先验框和目标框的重叠面积非常小,即存在大量难优化样本,在基础目标检测网络优化过程中难优化样本会产生更大的梯度,容易主导模型朝着错误的方向优化,所以训练过程中存在难易样本不平衡问题,因而需要根据得到的预测样本评分函数和先验样本评分函数,设计新的样本加权策略,拉低难优化样本产生的梯度,提升易优化样本的梯度,来进行平衡难易样本训练,可以引导模型朝着正确的方向优化。
S6、根据随机梯度下降法,对优化后的基础目标检测网络进行优化训练,得到目标检测模型。
具体的,根据随机梯度下降法,对正负样本在训练中的贡献度进行动态调整并保存每个样本在训练中最好的权重,以获得最优的目标检测模型。
本发明通过构建目标检测网络预测辅助头,可以在保证优秀的局部特征提取能力的同时,更大程度地增加全局建模能力,减少目标检测模型对复杂背景的关注,增加对前景目标的关注,提升目标检测头多任务分支的交互能力,显著提高目标检测头的表达能力,达到了无需复杂的数学建模,便能显著提高检测精度的目的。
本实施例中,可通过以下实验结果证明网络预测结果在目标检测模型训练中的重要性。
1、实验条件
实验所使用的数据是COCO数据集,COCO数据集是一个可用于图像检测(imagedetection),语义分割(semantic segmentation)和图像标题生成(image captioning)的大规模数据集。它有超过330K张图像(其中220K张是有标注的图像),包含150万个目标,80个目标类别(行人、汽车、大象等),91种材料类别(草、墙、天空等),每张图像包含五句图像的语句描述,且有250,000个带有关键点标注的行人。实验所用系统环境为Linux操作系统,显卡为NVidia RTX 3090GPU。
2、实验步骤
基于上述实验数据进行实验的具体步骤如下:先将COCO数据集图像中的训练集数据输入到目标检测模型进行训练,并设定优化方法为随机梯度下降优化方法,动量设置为0.9,初始学习率为0.005,训练12个批次,在第3个批次,第11个批次学习率降为原来的1/10,经过训练优化后保存训练结果最好的批次的权重,并将上述保存的最好批次的权重加载至目标检测模型中,最后将COCO数据集图像中的验证集数据输入到最优目标检测模型中进行测试,并输出预测结果。
3、实验结果
如表1所示,对比其它不同目标检测模型的检测结果,本发明构建的目标检测模型实现AP(IOU=0.5∶0.95)提高4.7个百分点,大目标检测提高7.6个百分点,中等目标检测提高5.0个百分点,小目标检测提高3.1个百分点,且在相同训练条件下,检测精度明显优于当前先进的一阶段目标检测模型。
表1 本发明与其它检测方法的检测结果对比表(单位:%)
Method | Iteration | Backbone | AP | AP<sub>50</sub> | AP<sub>75</sub> | AP<sub>s</sub> | AP<sub>m</sub> | AP<sub>l</sub> | Reference |
RetinaNet | 90K | ResNet50 | 36.5 | 52.6 | 39.3 | 21.9 | 40.5 | 47.7 | ICCV17 |
FCOS | 90K | ResNet50 | 38.7 | 57.5 | 41.7 | 22.6 | 42.7 | 49.9 | ICCV19 |
FreeAnchor | 90K | ResNet50 | 38.4 | 57.0 | 41.1 | 21.9 | 41.7 | 51.8 | NeurIPS19 |
ATSS | 90K | ResNet50 | 39.4 | 57.5 | 42.7 | 22.9 | 42.9 | 51.2 | CVPR20 |
PAA(w/Voting) | 90K | ResNet50 | 40.4 | 58.4 | 42.9 | 22.9 | 44.3 | 54.0 | ECCV20 |
AutoAssign | 90K | ResNet50 | 40.4 | 59.6 | 43.7 | 22.7 | 44.1 | 52.9 | - |
OTA | 90K | ResNet50 | 40.7 | 58.4 | 44.3 | 23.2 | 45.0 | 53.6 | CVPR21 |
DDOD | 90K | ResNet50 | 41.7 | 60.0 | 45.3 | 23.6 | 44.8 | 55.1 | ACM MM21 |
TOOD | 90K | ResNet50 | 42.4 | 59.7 | 46.2 | 25.4 | 45.5 | 55.5 | ICCV21 |
DyHead | 90K | ResNet50 | 43.0 | 60.7 | 46.8 | 24.7 | 46.4 | 53.9 | CVPR21 |
本发明方法 | 90K | ResNet50 | 44.1 | 61.1 | 47.9 | 26.0 | 47.9 | 58.8 | - |
本实施例中,如图6所示,本发明是基于表1中ATSS作为基线模型,进行了一系列改进,通过训练得到的最优目标检测模型,其中,Input表示COCO数据集中训练集图像,Backbone为基础目标检测网络的骨干网络,FPN为基础目标检测网络的特征金字塔模块,Head为基础目标检测网络的目标检测头,Label Assignment为基础目标检测网络的标签分配模块,Loss为基础目标检测网络的损失函数模块。
参照图7-图9,本发明还提供了一种技术方案:一种目标检测模型的构建装置,如图7所示,包括:
基础目标检测网络确定模块10,基础目标检测网络确定模块10用于根据目标检测标签分配方向通用算法,确定基础目标检测网络;
网络预测辅助头构建模块20,网络预测辅助头构建模块20用于采用多层注意力机制根据基础目标检测网络,构建网络预测辅助头;
正负样本划分模块30,正负样本划分模块30用于根据网络预测辅助头的预测结果,在标签分配任务中动态自适应寻优划分正负样本;
损失函数优化模块40,损失函数优化模块40用于根据网络预测辅助头的预测结果,对预设的训练损失函数进行优化,得到优化后的训练损失函数,并设计新的加权策略;
模型优化模块50,模型优化模块50用于采用优化后的训练损失函数和新的加权策略对基础目标检测网络进行优化;
模型训练模块60,模型训练模块60用于根据随机梯度下降法,对优化后的基础目标检测网络进行优化训练,得到目标检测模型。
其中,如图8所示,网络预测辅助头构建模块20包括:
第一特征图获取单元201,第一特征图获取单元201用于根据基础目标检测网络的特征金字塔模块输出的初步特征图,得到动态任务感知模块增强后的第一特征图;
第二特征图获取单元202,第二特征图获取单元202用于根据动态任务感知模块得到的第一特征图,进行多次卷积操作得到增强后的第二特征图;
网络预测辅助头生成单元203,网络预测辅助头生成单元203用于根据动态任务交互增强模块得到的第二特征图,通过不同的卷积操作得到网络预测辅助头。
其中,如图9所示,正负样本划分模块30包括:
锚框获取单元301,锚框获取单元301用于获取基础目标检测网络的特征金字塔模块输出的所有特征图上生成的锚框;
第一计算单元302,第一计算单元302用于根据获取的锚框和网络预测辅助头的预测结果,得到网络预测的分类、回归和质量评估得分,并构造预测样本评分函数和先验样本评分函数;
第二计算单元303,第二计算单元303用于根据网络预测的分类、回归和质量评估得分,计算每个锚框的综合得分,并合计所有锚框的回归得分之和k;
正样本锚框候选集生成单元304,正样本锚框候选集生成单元304用于根据先验样本评分函数,选取前Top-k锚框作为正样本锚框候选集;
动态更新单元305,动态更新单元305用于根据正样本锚框候选集,动态更新k值,使其为正样本锚框候选集中所有锚框的回归得分之和;
正负样本生成单元306,正负样本生成单元306用于根据预测样本评分函数和更新后的k值,选取前Top-k锚框作为正样本,剩余锚框设置为负样本。
本发明通过对当前一阶段目标检测器ATSS进行改进,构建目标检测网络预测辅助头并计算网络预测辅助头的预测结果,根据网络预测辅助头的预测结果进行训练得到最优目标检测器模型并在COCO数据集上进行前向推理实验,实验结果证明网络预测结果在目标检测模型训练中的重要性,从而能够在无需复杂建模的情况下提升目标检测模型的性能,显著提升基线模型的检测精度,并且能够在相同训练条件下超越最先进的一阶段目标检测模型,达到了无需复杂的数学建模,便能显著提高检测精度的目的。
以上实施方式对本发明进行了详细介绍,本文中应用了具体个例对本发明的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本发明的方法及其核心思想;同时,对于本领域的一般技术人员,依据本发明的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本发明的限制。
Claims (9)
1.一种目标检测模型的构建方法,其特征在于,包括以下步骤:
S1、根据目标检测标签分配方向通用算法,确定基础目标检测网络;
S2、采用多层注意力机制根据所述基础目标检测网络,构建网络预测辅助头;
S3、根据所述网络预测辅助头的预测结果,在标签分配任务中动态自适应寻优划分正负样本;
S4、根据所述网络预测辅助头的预测结果,对预设的训练损失函数进行优化,得到优化后的训练损失函数并设计新的加权策略;
S5、采用优化后的训练损失函数和新的加权策略,对所述基础目标检测网络进行优化;
S6、根据随机梯度下降法,对优化后的基础目标检测网络进行优化训练,得到目标检测模型。
2.根据权利要求1所述的目标检测模型的构建方法,其特征在于,所述步骤S2的具体步骤如下:
S21、根据基础目标检测网络的特征金字塔模块输出的初步特征图,得到动态任务感知模块增强后的第一特征图;
S22、根据动态任务感知模块得到的第一特征图,进行多次卷积操作得到增强后的第二特征图;
S23、根据动态任务交互增强模块得到的第二特征图,通过不同的卷积操作得到网络预测辅助头。
3.根据权利要求1所述的目标检测模型的构建方法,其特征在于,步骤S3中所述网络预测辅助头的预测结果包括分类任务分支预测结果、回归任务分支预测结果以及质量评估分支预测结果。
4.根据权利要求3所述的目标检测模型的构建方法,其特征在于,所述步骤S3具体包括以下步骤:
S31、获取基础目标检测网络的特征金字塔模块输出的所有特征图上生成的锚框;
S32、根据获取的锚框和网络预测辅助头的预测结果,得到网络预测的分类、回归和质量评估得分,并构造预测样本评分函数和先验样本评分函数;
S33、根据网络预测的分类、回归和质量评估得分,计算每个锚框的综合得分,并合计所有锚框的回归得分之和k;
S34、根据先验样本评分函数,选取前Top-k锚框作为正样本锚框候选集;
S35、根据正样本锚框候选集,动态更新k值,使其为正样本锚框候选集中所有锚框的回归得分之和;
S36、根据预测样本评分函数和更新后的k值,选取前Top-k锚框作为正样本,剩余锚框设置为负样本。
5.根据权利要求4所述的目标检测模型的构建方法,其特征在于,所述预测样本评分函数指基于网络预测辅助头中分类任务分支和质量评估任务分支的预测结果,得到的分类得分和中心度得分的乘积;所述先验样本评分函数指基于网络预测辅助头中回归任务分支的预测结果以及真实标签中心点先验,得到的回归得分和中心先验得分的乘积。
6.根据权利要求1所述的目标检测模型的构建方法,其特征在于,步骤S4中所述优化后的训练损失函数的表达式为:
L=αLcls+βLloc+γLcenter-iou
式中,α、β、γ是训练中用到的超参数,Lcls是分类任务分支的损失函数,Lreg是回归任务分支的损失函数,Lcenter-iou是质量评估任务分支的损失函数。
7.一种目标检测模型的构建装置,其特征在于,包括:
基础目标检测网络确定模块(10),所述基础目标检测网络确定模块(10)用于根据目标检测标签分配方向通用算法,确定基础目标检测网络;
网络预测辅助头构建模块(20),所述网络预测辅助头构建模块(20)用于采用多层注意力机制根据所述基础目标检测网络,构建网络预测辅助头;
正负样本划分模块(30),所述正负样本划分模块(30)用于根据网络预测辅助头的预测结果,在标签分配任务中动态自适应寻优划分正负样本;
损失函数优化模块(40),所述损失函数优化模块(40)用于根据所述网络预测辅助头的预测结果,对预设的训练损失函数进行优化,得到优化后的训练损失函数,并设计新的加权策略;
模型优化模块(50),所述模型优化模块(50)用于采用优化后的训练损失函数和新的加权策略,对所述基础目标检测网络进行优化;
模型训练模块(60),所述模型训练模块(60)用于根据随机梯度下降法,对优化后的基础目标检测网络进行优化训练,得到目标检测模型。
8.根据权利要求7所述的目标检测模型的构建装置,其特征在于,所述网络预测辅助头构建模块(20)包括:
第一特征图获取单元(201),所述第一特征图获取单元(201)用于根据基础目标检测网络的特征金字塔模块输出的初步特征图,得到动态任务感知模块增强后的第一特征图;
第二特征图获取单元(202),所述第二特征图获取单元(202)用于根据动态任务感知模块得到的第一特征图,进行多次卷积操作得到增强后的第二特征图;
网络预测辅助头生成单元(203),所述网络预测辅助头生成单元(203)用于根据动态任务交互增强模块得到的第二特征图,通过不同的卷积操作得到网络预测辅助头。
9.根据权利要求7所述的目标检测模型的构建装置,其特征在于,所述正负样本划分模块(30)包括:
锚框获取单元(301),所述锚框获取单元(301)用于获取基础目标检测网络的特征金字塔模块输出的所有特征图上生成的锚框;
第一计算单元(302),所述第一计算单元(302)用于根据获取的锚框和网络预测辅助头的预测结果,得到网络预测的分类、回归和质量评估得分,并构造预测样本评分函数和先验样本评分函数;
第二计算单元(303),所述第二计算单元(303)用于根据网络预测的分类、回归和质量评估得分,计算每个锚框的综合得分,并合计所有锚框的回归得分之和k;
正样本锚框候选集生成单元(304),所述正样本锚框候选集生成单元(304)用于根据先验样本评分函数,选取前Top-k锚框作为正样本锚框候选集;
动态更新单元(305),所述动态更新单元(305)用于根据正样本锚框候选集,动态更新k值,使其为正样本锚框候选集中所有锚框的回归得分之和;
正负样本生成单元(306),所述正负样本生成单元(306)用于根据预测样本评分函数和更新后的k值,选取前Top-k锚框作为正样本,剩余锚框设置为负样本。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210993151.6A CN115240052A (zh) | 2022-08-18 | 2022-08-18 | 一种目标检测模型的构建方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210993151.6A CN115240052A (zh) | 2022-08-18 | 2022-08-18 | 一种目标检测模型的构建方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115240052A true CN115240052A (zh) | 2022-10-25 |
Family
ID=83679407
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210993151.6A Pending CN115240052A (zh) | 2022-08-18 | 2022-08-18 | 一种目标检测模型的构建方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115240052A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116229333A (zh) * | 2023-05-08 | 2023-06-06 | 西南交通大学 | 基于难易等级自适应动态调整的难易目标解耦检测方法 |
CN117936001A (zh) * | 2024-03-25 | 2024-04-26 | 四川鸿霖科技有限公司 | 一种实验数据建模评测方法及系统 |
CN117994251A (zh) * | 2024-04-03 | 2024-05-07 | 华中科技大学同济医学院附属同济医院 | 基于人工智能的糖尿病足溃疡严重程度评估方法及系统 |
-
2022
- 2022-08-18 CN CN202210993151.6A patent/CN115240052A/zh active Pending
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116229333A (zh) * | 2023-05-08 | 2023-06-06 | 西南交通大学 | 基于难易等级自适应动态调整的难易目标解耦检测方法 |
CN116229333B (zh) * | 2023-05-08 | 2023-07-21 | 西南交通大学 | 基于难易等级自适应动态调整的难易目标解耦检测方法 |
CN117936001A (zh) * | 2024-03-25 | 2024-04-26 | 四川鸿霖科技有限公司 | 一种实验数据建模评测方法及系统 |
CN117994251A (zh) * | 2024-04-03 | 2024-05-07 | 华中科技大学同济医学院附属同济医院 | 基于人工智能的糖尿病足溃疡严重程度评估方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111914622B (zh) | 一种基于深度学习的人物交互检测方法 | |
CN115240052A (zh) | 一种目标检测模型的构建方法及装置 | |
CN110942072B (zh) | 基于质量评估的质量分、检测模型训练、检测方法及装置 | |
Xu et al. | Scale-aware feature pyramid architecture for marine object detection | |
CN111461212B (zh) | 一种用于点云目标检测模型的压缩方法 | |
CN110837836A (zh) | 基于最大化置信度的半监督语义分割方法 | |
Choi et al. | AMC-loss: Angular margin contrastive loss for improved explainability in image classification | |
CN115019039B (zh) | 一种结合自监督和全局信息增强的实例分割方法及系统 | |
CN114330499A (zh) | 分类模型的训练方法、装置、设备、存储介质及程序产品 | |
CN115909443A (zh) | 基于多头注意力机制的表情识别模型及其训练方法 | |
CN111444865A (zh) | 一种基于逐步求精的多尺度目标检测方法 | |
CN115063664A (zh) | 用于工业视觉检测的模型学习方法、训练方法及系统 | |
CN116091979A (zh) | 一种基于特征融合和通道注意力的目标跟踪方法 | |
CN112819024A (zh) | 模型处理方法、用户数据处理方法及装置、计算机设备 | |
CN111179272A (zh) | 一种面向道路场景的快速语义分割方法 | |
CN113436115A (zh) | 一种基于深度无监督学习的图像阴影检测方法 | |
CN117371511A (zh) | 图像分类模型的训练方法、装置、设备及存储介质 | |
Yan et al. | DEST: Deep enhanced swin transformer toward better scoring for NAFLD | |
Gaihua et al. | Instance segmentation convolutional neural network based on multi-scale attention mechanism | |
CN115424012A (zh) | 一种基于上下文信息的轻量图像语义分割方法 | |
CN113822293A (zh) | 用于图数据的模型处理方法、装置、设备及存储介质 | |
Jain et al. | Flynet–neural network model for automatic building detection from satellite images | |
Bi et al. | YOLO-RFB: An improved traffic sign detection model | |
Zhang et al. | Multi-level ensemble network for scene recognition | |
Pang et al. | Adaptive-MAML: Few-shot metal surface defects diagnosis based on model-agnostic meta-learning |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |