CN116486296A - 目标检测方法、装置及计算机可读存储介质 - Google Patents
目标检测方法、装置及计算机可读存储介质 Download PDFInfo
- Publication number
- CN116486296A CN116486296A CN202310266820.4A CN202310266820A CN116486296A CN 116486296 A CN116486296 A CN 116486296A CN 202310266820 A CN202310266820 A CN 202310266820A CN 116486296 A CN116486296 A CN 116486296A
- Authority
- CN
- China
- Prior art keywords
- data
- image
- loss function
- data set
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 75
- 238000012549 training Methods 0.000 claims abstract description 65
- 238000000034 method Methods 0.000 claims abstract description 37
- 230000006870 function Effects 0.000 claims description 117
- 238000009826 distribution Methods 0.000 claims description 37
- 230000008569 process Effects 0.000 claims description 13
- 238000012545 processing Methods 0.000 claims description 11
- 238000007781 pre-processing Methods 0.000 claims description 10
- 238000013519 translation Methods 0.000 claims description 4
- 230000003321 amplification Effects 0.000 claims description 2
- 238000003199 nucleic acid amplification method Methods 0.000 claims description 2
- 238000012935 Averaging Methods 0.000 claims 1
- 238000002372 labelling Methods 0.000 abstract description 9
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 238000012360 testing method Methods 0.000 description 4
- 238000010200 validation analysis Methods 0.000 description 4
- 238000013135 deep learning Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000009499 grossing Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 241001465754 Metazoa Species 0.000 description 2
- 230000009471 action Effects 0.000 description 2
- 238000012512 characterization method Methods 0.000 description 2
- 238000002474 experimental method Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000000295 complement effect Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000007405 data analysis Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000003709 image segmentation Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 230000001737 promoting effect Effects 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 239000000126 substance Substances 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 230000001131 transforming effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/40—Scenes; Scene-specific elements in video content
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/778—Active pattern-learning, e.g. online learning of image or video features
- G06V10/7784—Active pattern-learning, e.g. online learning of image or video features based on feedback from supervisors
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- General Physics & Mathematics (AREA)
- Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computing Systems (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明提供一种目标检测方法、装置及计算机可读存储介质,属于人工智能技术领域。其中,本发明的目标检测方法包括下述步骤:获取待检测的目标数据;将所述目标数据输入到预先训练的检测模型中,其中,所述预先训练的检测模型是经过半监督学习训练生成的,所述半监督学习时基于伪标签数据集和有标签数据集共同训练的,所述伪标签数据集是基于所述无标签数据集生成的;输出所述目标数据对应类别。本发明采用半监督式检测模型来进行目标检测,其能从带标签图片数据和无标签数据集联合训练的方法,模型仅依靠较少的人工标注数据,能达到甚至超越监督学习方法达到的精度,且避免了复杂的多阶段训练方案导致的性能较差问题。
Description
技术领域
本发明属于人工智能技术领域,具体涉及一种目标检测方法、装置及计算机可读存储介质。
背景技术
目标检测技术是计算机视觉技术的基础,目标检测技术可以检测出图像中包含的诸如人像、动物或物品等多种目标对象。在实际应用中目标检测技术可应用于诸多场景,目标检测一般是在图像中定位目标物体并赋予目标物体相应的标签。
当前目标检测数据的物体通常比较大,收据集的收集和人工标注需耗费大量的人力成本,且缺乏有效的数据分析及增强方法。
当前,基于半监督学习的深度学习在很多任务上都获得了成功的进展,一般需要通过大规模的有标签数据进行预训练,并且,目前的半监督学习方法常采用多阶段训练方式,初始检测器的性能影响伪标签质量,并导致影响最终性能,多用于图像分类、图像分割等领域。
因此,针对上述技术问题,本发明提出一种新的目标检测方法、装置以及基于计算机可读存储介质。
发明内容
本发明旨在至少解决现有技术中存在的技术问题之一,提供一种目标检测方法、装置及计算机可读存储介质。
本发明的一方面,提供一种目标检测方法,包括下述步骤:
获取待检测的目标数据;
将所述目标数据输入到预先训练的检测模型中,其中,所述预先训练的检测模型是经过半监督学习训练生成的,所述半监督学习时基于伪标签数据集和有标签数据集共同训练的,所述伪标签数据集是基于无标签数据集生成的;
输出所述目标数据对应类别。
可选的,所述半监督学习时,对有标签数据集训练,包括:
将所述有标签数据集中的有标签数据输入至学生模型,计算得到有标签的损失函数,根据所述损失函数缩小其预测标签与真实标签之间的概率分布差异,以获取真实标签。
可选的,所述半监督学习时,对无标签数据集进行如下处理,包括:
对所述无标签数据集中的无标签数据经强增强处理与弱增强处理;
将弱增强处理后的无标签数据输入至教师模型,强增强处理后的无标签数据输入至学生模型;
将所述教师模型的输出和所述学生模型的输出计算得到无标签损失函数,根据所述损失函数缩小其真实标签与预测标签之间的概率分布差异,以获取伪标签。
可选的,所述根据所述损失函数缩小其预测标签与真实标签之间的概率分布差异,包括:
根据交叉熵损失函数计算目标数据的预测标签与对应的真实标签之间的交叉熵损失函数值;
根据平滑L1损失函数的计算目标数据预测框的预测标签与对应的真实标签之间的平滑损失函数值;
将所述交叉熵损失函数值与所述平滑损失函数值相加得到总损失函数,根据所述总损失函数得到总损失值最小值。
可选的,所述交叉熵损失函数公式如下:
所述平滑L1损失函数的公式如下:
所述总损失函数公式如下:
Ltotal=LcrossEntropy+LsmoothL1
其中,Ltotal代表总损失函数,LcrossEntropy代表交叉熵损失函数,LsmoothL1代表平滑L1损失函数,p(x)代表真实标签的概率分布,q(x)表示预测标签的概率分布,box_true代表预测框分类器框的真实标签,box_pred代表预测框分类器的预测标签。
可选的,所述教师模型是所述学生模型经过指数滑动平均得到,具体公式如下:
shadowVariable=decay*shadowVariable+(1-decay)*Variable;
其中,shadowVariable为经过指数滑动平均处理后得到的参数值,Variable为当前epoch轮次的参数值,decay的范围为0-1。
可选的,所述半监督学习之前,还包括:
改变所述有标签数据集和无标签数据集中输入图像的图像大小形成第一图像,再对所述第一图像裁剪形成第二图像,将所述第二图像经降采样形成第三图像;
将所述第一图像、所述第二图像以及所述第三图像进行数据扩增,作为同一输入图像的三个view输入网络;
通过所述第一图像与所述第二图像不同位置的对比学习使所述检测模型学到平移不变性,对所述第一图像和所述第三图像不同尺度的对比学习使所述检测模型学到尺度不变形。
本发明的另一方面,提出一种目标检测装置,所述装置包括:
数据获取模块,用于获取待检测的目标数据;
数据输入模块,用于将所述目标数据输入到预先训练的检测模型中,其中,所述预先训练的检测模型是经过半监督学习训练生成的,所述半监督学习时基于伪标签数据集和有标签数据集共同训练的,所述伪标签数据集是基于无标签数据集生成的;
数据输出模块,用于输出所述目标数据对应类别。
可选的,所述装置还包括模型训练模块与数据预处理模块;其中,
所述模型训练模块,用于对有标签数据集与无标签数据集进行训练处理;
数据预处理模块,用于对检测模型预先训练之前对所述标签数据集与无标签数据集的数据进行预处理,以使模型学习到平移不变性和尺度不变性。
本发明的另一方面,提出一种计算机可读存储介质,所述计算机可读存储介质存储有多条指令,所述指令适于处理器加载并执行如前文记载所述的方法步骤。
本发明提出一种目标检测方法、装置以及一种计算机可读存储介质,本发明采用半监督式检测模型来进行目标检测,其能从带标签图片数据和无标签数据集联合训练的方法,模型仅依靠较少的人工标注数据,能达到甚至超越监督学习方法达到的精度,解决了目标检测的真实数据集难以获得,数据集的收集和人工标注需耗费大量的人力成本等问题,且避免了复杂的多阶段训练方案导致的性能较差问题。
附图说明
图1为本发明一实施例的目标检测方法的流程框图;
图2为本发明另一实施例的目标检测装置的示意图。
具体实施方式
为使本领域技术人员更好地理解本发明的技术方案,下面结合附图和具体实施方式对本发明作进一步详细描述。显然,所描述的实施例是本发明的一部分实施例,而不是全部的实施例。基于所描述的本发明的实施例,本领域普通技术人员在无需创造性劳动的前提下所获得的所有其他实施例,都属于本发明保护范围。
除非另外具体说明,本发明中使用的技术术语或者科学术语应当为本发明所属领域内具有一般技能的人士所理解的通常意义。本发明中使用的“包括”或者“包含”等既不限定所提及的形状、数字、步骤、动作、操作、构件、原件和/或它们的组,也不排除出现或加入一个或多个其他不同的形状、数字、步骤、动作、操作、构件、原件和/或它们的组。此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示技术特征的数量与顺序。
在发明的一些描述中,除非另有明确的规定和限定,术语“安装”、“连接”、“相连”或者“固定”等类似的词语并非限定于物理的或者机械的连接,而是可以包括电性的连接,不管是直接的还是通过中间媒体间接连接,可以是两个元件内部的连通或者两个元件的互相作用关系。以及,术语“中心”、“纵向”、“横向”、“长度”、“宽度”、“厚度”、“上”、“下”、“前”、“后”、“左”、“右”、“竖直”、“水平”、“顶”、“底”、“内”、“外”、等指示的方位或位置关系为基于附图所示的方位或位置关系,仅用于表示相对位置关系,当被描述对象的绝对位置改变后,则该相对位置关系也可能相应地改变。
如图1所示,本发明的一方面,提出一种目标检测方法S100,包括下述步骤S110~S130:
S110、获取待检测的目标数据。
需要说明的是,本实施例的待检测目标数据可以为任意类型、任意格式、任意尺寸的图像,可获取人像、动物或物品等多种目标数据的图像。
S120、将目标数据输入到预先训练的检测模型中,其中,预先训练的检测模型是经过半监督学习训练生成的,半监督学习时基于伪标签数据集和有标签数据集共同训练的,伪标签数据集是基于无标签数据集生成的。
需要说明的是,针对目前对数据的标注成本较高,本示例采用半监督学习,其半监督相对于全监督含义是训练中使用了部分没有标注的数据,即基于少量已标注数据和大量无标签数据进行训练,以降低标注成本。
简单来说,监督分类常用做法可以归纳为:简单自训练或伪标签学习,其用有标签数据训练一个分类器,然后用这个分类器对无标签数据进行分类,这样就会产生伪标签(pseudo label)或软标签(soft label),挑选认为分类正确的无标签样本,把选出来的无标签样本用来训练分类器;协同训练,其假设每个数据可以从不同的角度进行分类,不同角度可以训练出不同的分类器,然后用这些从不同角度训练出来的分类器对无标签样本进行分类,再选出认为可信的无标签样本加入训练集中。由于这些分类器从不同角度训练出来的,可以形成一种互补,而提高分类精度。
基于上述半监督学习的特性,本实施例提出了一种端对端的半监督式目标检测框架,训练每次迭代中同时对未标注数据打伪标签,使用伪标签数据和少量的标注数据同时训练,这种端到端的方法避免了复杂的多阶段训练方案导致的性能较差问题。
具体地,检测模型的构建过程如下:
S1201、获取目标检测数据集,例如,COCO数据集,将该数据集分为训练集、验证集和测试集以及无标记数据集,本实施例在118K训练集和123K无标记数据集上进行训练,在5K验证集上进行验证。
S1202、基于半监督学习训练检测模型,具体包括:
第一、提出端对端的半监督目标检测框架,训练每次迭代中同时对未标注数据打伪标签,使用伪标签数据和少量的标注数据同时训练。具体来说,在一个数据批次,会按照设定的比率随机从118K训练集和123K无标记数据集采样标注数据和未标注数据。训练过程中会使用两个模型,一个负责检测训练,一个负责给未标注数据打伪标签。前者是学生模型,后者为教师模型,是学生模型经过指数滑动平均(Exponential Moving Average,EMA)得到,即教师模型的输出作为学生模型输出的监督标签,学生模型输出预测标签,教师模型输出预测标签。EMA的意义在于利用滑动平均的参数来提高模型在测试数据上的健壮性,其公式如下:
shadowVariable=decay*shadowVariable+(1-decay)*Variable
其中,shadowVariable为经过指数滑动平均处理后得到的参数值,Variable为当前epoch轮次的参数值,decay控制着模型更新的速度,越大越趋于稳定,教师模型能较多保留之前的值,较少融入学生模型输出,其范围为0-1,在实际运用中,decay通常会设为一个十分接近1的常数,在我们实验中采用0.999进行模型训练。
本实施例中EMA对每一个待更新训练学习的变量(variable)都会维护一个影子变量(shadow variable),影子变量的初始值就是这个变量的初始值。
第二、为了使检测模型可以学习到检测对象级别(object-level)所需要的平移不变性和尺度不变性,能让网络学习到同一类物体的表征尺度大小和位置无关的特性,进一步提高性能,本实施例对输入图像进行了预处理,包括:将输入图像改变计算机图像的大小(resize)到第一预设图像大小,例如,224*224图像大小作为第一图像V1,再使用第一图像V1的随机裁剪作为第二图像V2,裁剪后将第二图像V2 resize成与第一图像V1一样的大小。再将第二图像V2经过降采样为第二预设图像大小,例如,112*112得到第三图像V3,之后将第一图像V1、第二图像V2和第三图像V3经过数据扩增,作为同一张图像的三个view输入网络,其中,view代表图片的三种形态(原图、裁剪图以及降裁剪图),第一图像V1和第二图像V2不同位置的对比学习使检测模型学到了平移不变性,第一图像V1和第三图像V3不同尺度的对比学习使检测模型学到了尺度不变性。
第三、在半监督学习中,对有标签的数据采用常规的pipeline流程(深度学习的流程,主要描述了数据是如何在节点之间流动的),利用学生模型进行预测,计算得到有标签的损失函数loss,包括分类和回归分支损失函数loss。
需要说明的是,在网络训练阶段,真实标签包括目标数据的真实标签(cls_true)以及目标数据预测框的真实标签(box_true),预测标签包括目标数据的预测标签(cls_pred)以及目标数据预测框的预测标签(box_pred)。
具体地,在训练阶段,可通过将每个分类器的预测输出的预测标签cls_pred和相应的真实标签cls_true,根据交叉熵损失函数计算上述目标数据的预测标签和真实标签的损失,以得到交叉熵损失函数值,再根据每个预测框分类器的预测输出的预测标签box_pred和相应的真实标签box_true,通过平滑L1函数Smooth L1Loss计算上述目标数据预测框的预测标签和真实标签的损失,得到平滑损失函数值,然后将交叉损失函数值与平滑损失函数值相加进行反向传播,使总损失函数值得到的损失值达到最小值,从而训练网络。其中,交叉熵损失函数定义如下:
平滑L1损失函数的公式如下:
总损失函数公式如下:
Ltotal=LcrossEntropy+LsmoothL1
其中,Ltotal代表总损失函数,LcrossEntropy代表交叉熵损失函数,LsmoothL1代表平滑L1损失函数,p(x)代表真实标签的概率分布,q(x)表示预测标签的概率分布,box_true代表预测框分类器框的真实标签,box_pred代表预测框分类器的预测标签。
本实施例交叉损失函数通过缩小真实标签和预测标签两个概率分布的差异,其中,真是标签是由教师模型输出得到,预测标签是由学生模型输出得到,通过交叉损失函数可使预测概率分布尽可能达到真实概率分布。
第四、参考FixMatch,无标签数据经过强和弱两种不同的数据增强处理,其中弱增强处理后的数据输入到教师模型,得到伪标签分布;而强增强处理后的数据输入至学生模型,得到预测概率分布,采用一致性损失函数loss对两个分布进行约束,使他们尽可能相同,即将教师模型的输出和学生模型的输出计算得到无标签的损失函数loss,包括分类和回归分支损失函数loss,从而训练网络。也就是说,将弱增强处理的数据输入至教师模型,得到的伪标签输出数据,利用损失函数对该伪标签输出数据进行训练模型,再者,将强增强处理的数据输入至学生模型,得到预测概率分布,同样利用损失函数对该预测输出数据进行训练模型。
具体地,本实施例将上述教师模型输出的数据和学生模型输出的数据均计算得到损失函数,具体过程如下:根据交叉熵损失函数计算目标数据的预测标签cls_pred与对应的真实标签cls_true之间的交叉熵损失函数值;再根据平滑L1损失函数的计算目标数据预测框的预测标签box_pred与对应的真实标签box_true之间的平滑损失函数值;之后,将交叉熵损失函数值与平滑损失函数值相加得到总损失函数,根据总损失函数得到总损失值最小值,然后将交叉损失函数值与平滑损失函数值相加进行反向传播,使总损失函数值得到的损失值达到最小值。
其中,交叉熵损失函数定义如下:
平滑L1损失函数的公式如下:
总损失函数公式如下:
Ltotal=LcrossEntropy+LsmoothL1
其中,Ltotal代表总损失函数,LcrossEntropy代表交叉熵损失函数,LsmoothL1代表平滑L1损失函数,p(x)代表真实标签的概率分布,q(x)表示预测标签的概率分布,box_true代表预测框分类器框的真实标签,box_pred代表预测框分类器的预测标签。
本实施例利用弱增强生成hard伪标签,然后利用增强后的预测值和伪标签进行一致性正则化学习,具体学习过程如下:
1.在有标签图像数据上训练教师模型;
2.使用训练好的教师模型生成无标签图像的伪标签(即边界框及其类别标签);
3.对未标记的图像应用强数据增强,并在应用全局几何变换时变换相应的伪标记(即边界框);
4.计算无标签损失和有标签损失以训练检测器。
应当理解的是,本实施例基于上述有标签数据训练形成教师模型,再利用教师模型对无标签数据进行预测。
本实施例交叉损失函数通过缩小真实标签和预测标签两个概率分布的差异,来使预测概率分布尽可能达到真实概率分布,以获取伪标签。也就是说,本实施例同时对无标签数据给予一个伪标签,用训练中的模型对无标签数据进行预测,以概率最高的类别作为无标签数据的伪标签。
本实施例训练形成基于目标检测的半监督式辅助模型,在给定带标签的训练集以及无标签的无标记数据集的情况下提出了一个半监督解决方案,设计了从带标签图片数据和无标签数据集联合训练的方法,模型仅依靠较少的人工标注数据,即能达到甚至超越监督学习方法达到的精度,从而有效提升网络对目标物体的检测和促进迁移学习的更好的任务对齐和体系结构对齐。
S130、输出目标数据对应类别,即基于上述步骤S120形成的半监督式检测模型,可对待检测的目标数据进行检测,并输出目标数据对应的类别。
本发明提出了一个基于端到端的半监督的检测模型,训练每次迭代中同时对未标注数据打伪标签,使用伪标签数据和少量的标注数据同时训练,这种端到端的方法避免了复杂的多阶段训练方案导致的性能较差问题,模型仅依靠较少的人工标注数据,即能达到甚至超越监督学习方法达到的精度。
如图2所示,本发明的另一方面,提出一种目标检测装置200,该装置包括:数据获取模块210,数据输入模块220,数据输出模块230。其中,数据获取模块210,用于获取待检测的目标数据;数据输入模块220,用于将目标数据输入到预先训练的检测模型中,其中,预先训练的检测模型是经过半监督学习训练生成的,半监督学习时基于伪标签数据集和有标签数据集共同训练的,伪标签数据集是基于无标签数据集生成的。数据输出模块230,用于输出目标数据对应类别。
进一步地,本实施例的装置还包括模型训练模块与数据预处理模块;其中,模型训练模块,用于对有标签数据集与无标签数据集进行训练处理;数据预处理模块,用于对检测模型预先训练之前对标签数据集与无标签数据集的数据进行预处理,以使模型学习到平移不变性和尺度不变性。
具体地,利用模型训练模块对检测模型的训练过程如下:
第一、获取目标检测数据集,例如,COCO数据集,将该数据集分为训练集、验证集和测试集以及无标记数据集,本实施例在118K训练集和123K无标记数据集上进行训练,在5K验证集上进行验证。
第二、基于半监督学习训练检测模型,具体包括:
1)提出端对端的半监督目标检测框架,训练每次迭代中同时对未标注数据打伪标签,使用伪标签数据和少量的标注数据同时训练。具体来说,在一个数据批次,会按照设定的比率随机从118K训练集和123K无标记数据集采样标注数据和未标注数据。训练过程中会使用两个模型,一个负责检测训练,一个负责给未标注数据打伪标签。前者是学生模型,后者为教师模型,是学生模型经过指数滑动平均(Exponential Moving Average,EMA)得到,即教师模型的输出作为学生模型输出的监督标签,学生模型输出预测标签,教师模型输出预测标签。EMA的意义在于利用滑动平均的参数来提高模型在测试数据上的健壮性,其公式如下:
shadowVariable=decay*shadowVariable+(1-decay)*Variable
其中,shadowVariable为经过指数滑动平均处理后得到的参数值,Variable为当前epoch轮次的参数值,decay控制着模型更新的速度,越大越趋于稳定,教师模型能较多保留之前的值,较少融入学生模型输出,其范围为0-1,在实际运用中,decay通常会设为一个十分接近1的常数,在我们实验中采用0.999进行模型训练。
本实施例中EMA对每一个待更新训练学习的变量(variable)都会维护一个影子变量(shadow variable),影子变量的初始值就是这个变量的初始值。
2)、在半监督学习中,对有标签的数据采用常规的pipeline流程(深度学习的流程,主要描述了数据是如何在节点之间流动的),利用学生模型进行预测,计算得到有标签的损失函数loss,包括分类和回归分支损失函数loss。
在训练阶段,最后可通过将每个分类器的预测输出的预测标签cls_pred和相应的真实标签cls_true,根据交叉熵损失函数计算上述目标数据的预测标签和真实标签的损失,以得到交叉损失函数值,再根据每个预测框分类器的预测输出的预测标签box_pred和相应的真实标签box_true,通过平滑L1函数Smooth L1 Loss计算上述目标数据预测框的预测标签和真实标签的损失,得到平滑损失函数值,然后将交叉损失函数值与平滑损失函数值相加进行反向传播,从而训练网络,损失函数定义如下:
所述平滑L1损失函数的公式如下:
所述总损失函数公式如下:
Ltotal=LcrossEntropy+LsmoothL1
其中,Ltotal代表总损失函数,LcrossEntropy代表交叉熵损失函数,LsmoothL1代表平滑L1损失函数,p(x)代表真实标签的概率分布,q(x)表示预测标签的概率分布,box_true代表预测框分类器框的真实标签,box_pred代表预测框分类器的预测标签。
本实施例交叉损失函数通过缩小真实标签和预测标签两个概率分布的差异,来使预测概率分布尽可能达到真实概率分布。
3)、参考FixMatch,无标签数据经过强和弱两种不同的数据增强处理,其中弱增强处理后的数据输入到教师模型,得到伪标签分布;而强增强处理后的数据输入至学生模型,得到预测概率分布,采用一致性损失函数loss对两个分布进行约束,使他们尽可能相同,即将教师模型的输出和学生模型的输出计算得到无标签的损失函数loss,包括分类和回归分支损失函数loss,从而训练网络。也就是说,将弱增强处理的数据输入至教师模型,得到的伪标签输出数据,利用损失函数对该伪标签输出数据进行训练模型,再者,将强增强处理的数据输入至学生模型,得到预测概率分布,同样利用损失函数对该预测输出数据进行训练模型。
其中,将上述教师模型输出的数据和学生模型输出的数据均计算得到损失函数,具体过程如下:根据交叉熵损失函数计算目标数据的预测标签cls_pred与对应的真实标签cls_true之间的交叉熵损失函数值;再根据平滑L1损失函数的计算目标数据预测框的预测标签box_pred与对应的真实标签box_true之间的平滑损失函数值;之后,将交叉熵损失函数值与平滑损失函数值相加得到总损失函数,根据总损失函数得到总损失值最小值,然后将交叉损失函数值与平滑损失函数值相加进行反向传播,使总损失函数值得到的损失值达到最小值。
其中,交叉熵损失函数定义如下:
平滑L1损失函数的公式如下:
总损失函数公式如下:
Ltotal=LcrossEntropy+LsmoothL1
其中,Ltotal代表总损失函数,LcrossEntropy代表交叉熵损失函数,LsmoothL1代表平滑L1损失函数,p(x)代表真实标签的概率分布,q(x)表示预测标签的概率分布,box_true代表预测框分类器框的真实标签,box_pred代表预测框分类器的预测标签。
进一步地,在利用模型训练模块训练形成检测模型的过程之前,还可以利用数据预处理模块对检测模型预先训练之前对标签数据集与无标签数据集的数据进行预处理,具体包括:
将输入图像改变计算机图像的大小(resize)到第一预设图像大小,例如,224*224图像大小作为第一图像V1,再使用第一图像V1的随机裁剪作为第二图像V2,裁剪后将第二图像V2 resize成与第一图像V1一样的大小。再将第二图像V2经过降采样为第二预设图像大小,例如,112*112得到第三图像V3,之后将第一图像V1、第二图像V2和第三图像V3经过数据扩增,作为同一张图像的三个view输入网络,其中,view代表图片的三种形态(原图、裁剪图以及降裁剪图),第一图像V1和第二图像V2不同位置的对比学习使检测模型学到了平移不变性,第一图像V1和第三图像V3不同尺度的对比学习使检测模型学到了尺度不变性。
本发明的另一方面,提出一种计算机可读存储介质,该计算机存储介质存储有多条指令,指令适于处理器加载并执行前文记载的方法步骤。
需要说明的是,计算机可读存储介质可以是本发明的装置、设备、系统中所包含的,也可以是单独存在。
其中,计算机可读存储介质可是任何包含或存储程序的有形介质,其可以是电、磁、光、电磁、红外线、半导体的系统、装置、设备,更具体的例子包括但不限于:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、光纤、随机访问存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件,或它们任意合适的组合。
另外,计算机可读存储介质也可包括在基带中或作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码,其具体例子包括但不限于电磁信号、光信号,或它们任意合适的组合。
本发明提出一种目标检测方法、目标检测装置以及一种计算机可读存储介质,相对于现有技术具有以下有益效果:
第一、本发明提出一种半监督学习方式,设计了从带标签图片数据和无标签数据集联合训练的方法,模型仅依靠较少的人工标注数据,能达到甚至超越监督学习方法达到的精度,解决了目标检测真实数据难以获得、以及数据集的收集和人工标注耗费大量人力成本的问题。
第二、本发明提出一种端对端的半监督目标检测模型,训练每次迭代中同时对未标注数据打伪标签,使用伪标签数据和少量的标注数据同时训练,这种端到端的方法避免了复杂的多阶段训练方案导致的性能较差问题。
第三、本发明还设计了可以学习到检测对象级别所需要的平移不变性和尺寸不变性的目标检测方法,能让网络学习到同一类物体的表征尺度大小和位置无关的特性,进一步提高性能,解决了目前目标检测方法缺乏有效的平移不变性和尺度不变性的问题。
可以理解的是,以上实施方式是为了说明本发明的原理而采用的示例性实施方式,本发明并不局限于此。对于本领域内的普通技术人员而言,在不脱离本发明的精神和实质的情况下,可以做出各种变型和改进,这些变型和改进也视为本发明的保护范围。
Claims (10)
1.一种目标检测方法,其特征在于,包括下述步骤:
获取待检测的目标数据;
将所述目标数据输入到预先训练的检测模型中,其中,所述预先训练的检测模型是经过半监督学习训练生成的,所述半监督学习时基于伪标签数据集和有标签数据集共同训练的,所述伪标签数据集是基于无标签数据集生成的;
输出所述目标数据对应类别。
2.根据权利要求1所述的方法,其特征在于,所述半监督学习时,对有标签数据集训练,包括:
将所述有标签数据集中的有标签数据输入至学生模型,计算得到有标签的损失函数,根据所述损失函数缩小其预测标签与真实标签之间的概率分布差异,以获取真实标签。
3.根据权利要求2所述的方法,其特征在于,所述半监督学习时,对无标签数据集进行如下处理,包括:
对所述无标签数据集中的无标签数据经强增强处理与弱增强处理;
将弱增强处理后的无标签数据输入至教师模型,强增强处理后的无标签数据输入至学生模型;
将所述教师模型的输出和所述学生模型的输出分别计算得到无标签损失函数,根据所述损失函数缩小其真实标签与预测标签之间的概率分布差异,以获取伪标签。
4.根据权利要求3所述的方法,其特征在于,所述根据所述损失函数缩小其预测标签与真实标签之间的概率分布差异,包括:
根据交叉熵损失函数计算目标数据的预测标签与对应的真实标签之间的交叉熵损失函数值;
根据平滑L1损失函数的计算目标数据预测框的预测标签与对应的真实标签之间的平滑损失函数值;
将所述交叉熵损失函数值与所述平滑损失函数值相加得到总损失函数,根据所述总损失函数得到总损失值最小值。
5.根据权利要求4所述的方法,其特征在于,所述交叉熵损失函数公式如下:
所述平滑L1损失函数的公式如下:
所述总损失函数公式如下:
Ltotal=LcrossEntropy+LsmoothL1
其中,Ltotal代表总损失函数,LcrossEntropy代表交叉熵损失函数,LsmoothL1代表平滑L1损失函数,p(x)代表真实标签的概率分布,q(x)表示预测标签的概率分布,box_true代表预测框分类器框的真实标签,box_pred代表预测框分类器的预测标签。
6.根据权利要求4所述的方法,其特征在于,所述教师模型是所述学生模型经过指数滑动平均得到,具体公式如下:
shadowVariable=decay*shadowVariable+(1-decay)*Variable;
其中,shadowVariable为经过指数滑动平均处理后得到的参数值,Variable为当前epoch轮次的参数值,decay的范围为0-1。
7.根据权利要求1所述的方法,其特征在于,所述半监督学习之前,还包括:
改变所述有标签数据集和无标签数据集中输入图像的图像大小形成第一图像,再对所述第一图像裁剪形成第二图像,将所述第二图像经降采样形成第三图像;
将所述第一图像、所述第二图像以及所述第三图像进行数据扩增,作为同一输入图像的三个view输入网络;
通过所述第一图像与所述第二图像不同位置的对比学习使所述检测模型学到平移不变性,对所述第一图像和所述第三图像不同尺度的对比学习使所述检测模型学到尺度不变形。
8.一种目标检测装置,其特征在于,所述装置包括:
数据获取模块,用于获取待检测的目标数据;
数据输入模块,用于将所述目标数据输入到预先训练的检测模型中,其中,所述预先训练的检测模型是经过半监督学习训练生成的,所述半监督学习时基于伪标签数据集和有标签数据集共同训练的,所述伪标签数据集是基于无标签数据集生成的;
数据输出模块,用于输出所述目标数据对应类别。
9.根据权利要求8所述的装置,其特征在于,所述装置还包括模型训练模块与数据预处理模块;其中,
所述模型训练模块,用于对有标签数据集与无标签数据集进行训练处理;
数据预处理模块,用于对检测模型预先训练之前对所述标签数据集与无标签数据集的数据进行预处理,以使模型学习到平移不变性和尺度不变性。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有多条指令,所述指令适于处理器加载并执行如权利要求1至7任一项所述的方法步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310266820.4A CN116486296A (zh) | 2023-03-20 | 2023-03-20 | 目标检测方法、装置及计算机可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310266820.4A CN116486296A (zh) | 2023-03-20 | 2023-03-20 | 目标检测方法、装置及计算机可读存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116486296A true CN116486296A (zh) | 2023-07-25 |
Family
ID=87222188
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310266820.4A Pending CN116486296A (zh) | 2023-03-20 | 2023-03-20 | 目标检测方法、装置及计算机可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116486296A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116863277A (zh) * | 2023-07-27 | 2023-10-10 | 北京中关村科金技术有限公司 | 结合rpa的多媒体数据检测方法及系统 |
CN116935168A (zh) * | 2023-09-13 | 2023-10-24 | 苏州魔视智能科技有限公司 | 训练目标检测模型的方法、装置、计算机设备及存储介质 |
-
2023
- 2023-03-20 CN CN202310266820.4A patent/CN116486296A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116863277A (zh) * | 2023-07-27 | 2023-10-10 | 北京中关村科金技术有限公司 | 结合rpa的多媒体数据检测方法及系统 |
CN116935168A (zh) * | 2023-09-13 | 2023-10-24 | 苏州魔视智能科技有限公司 | 训练目标检测模型的方法、装置、计算机设备及存储介质 |
CN116935168B (zh) * | 2023-09-13 | 2024-01-30 | 苏州魔视智能科技有限公司 | 目标检测的方法、装置、计算机设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Patil et al. | MSFgNet: A novel compact end-to-end deep network for moving object detection | |
CN108960245B (zh) | 轮胎模具字符的检测与识别方法、装置、设备及存储介质 | |
CN109165623B (zh) | 基于深度学习的水稻病斑检测方法及系统 | |
Wang et al. | Dairy goat detection based on Faster R-CNN from surveillance video | |
CN111460927B (zh) | 对房产证图像进行结构化信息提取的方法 | |
CN103049763B (zh) | 一种基于上下文约束的目标识别方法 | |
CN116486296A (zh) | 目标检测方法、装置及计算机可读存储介质 | |
CN113688665B (zh) | 一种基于半监督迭代学习的遥感影像目标检测方法及系统 | |
Xing et al. | Traffic sign recognition using guided image filtering | |
Naufal et al. | Preprocessed mask RCNN for parking space detection in smart parking systems | |
CN113591671A (zh) | 一种基于Mask-Rcnn识别鱼类生长检测方法 | |
CN116758421A (zh) | 一种基于弱监督学习的遥感图像有向目标检测方法 | |
Ren et al. | MPSA: A multi-level pixel spatial attention network for thermal image segmentation based on Deeplabv3+ architecture | |
Jia et al. | Polar-Net: Green fruit instance segmentation in complex orchard environment | |
Yu et al. | Automatic segmentation of golden pomfret based on fusion of multi-head self-attention and channel-attention mechanism | |
Zhang et al. | Damaged apple detection with a hybrid YOLOv3 algorithm | |
CN113192108A (zh) | 一种针对视觉跟踪模型的人在回路训练方法及相关装置 | |
CN112949634A (zh) | 一种铁路接触网鸟窝检测方法 | |
CN117496138A (zh) | 面向点云分割的伪实例对比学习实现方法、装置及介质 | |
Li et al. | Automatic Counting Method of Fry Based on Computer Vision | |
Das et al. | Object Detection on Scene Images: A Novel Approach | |
CN110826394A (zh) | 基于卷积神经网络算法的水库识别方法和装置 | |
Liu et al. | A study on the design and implementation of an improved AdaBoost optimization mathematical algorithm based on recognition of packaging bottles | |
CN114022509B (zh) | 基于多个动物的监控视频的目标跟踪方法及相关设备 | |
CN115512331A (zh) | 一种交通标志检测方法、装置、计算机设备及计算机可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |