CN109344897A - 一种基于图片蒸馏的通用物体检测框架及其实现方法 - Google Patents

一种基于图片蒸馏的通用物体检测框架及其实现方法 Download PDF

Info

Publication number
CN109344897A
CN109344897A CN201811150901.3A CN201811150901A CN109344897A CN 109344897 A CN109344897 A CN 109344897A CN 201811150901 A CN201811150901 A CN 201811150901A CN 109344897 A CN109344897 A CN 109344897A
Authority
CN
China
Prior art keywords
rcnn
wae
faster rcnn
faster
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN201811150901.3A
Other languages
English (en)
Other versions
CN109344897B (zh
Inventor
王青
赵惠
陈添水
林倞
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Sun Yat Sen University
National Sun Yat Sen University
Original Assignee
National Sun Yat Sen University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by National Sun Yat Sen University filed Critical National Sun Yat Sen University
Priority to CN201811150901.3A priority Critical patent/CN109344897B/zh
Publication of CN109344897A publication Critical patent/CN109344897A/zh
Application granted granted Critical
Publication of CN109344897B publication Critical patent/CN109344897B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Abstract

本发明公开了一种基于图片蒸馏的通用物体检测框架及其实现方法,该框架包括:Faster RCNN模型,构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;Wae Faster RCNN检测模型,将输入图像分解成两个分辨率只有原图一半的子图,构建并利用Wae Faster RCNN网络结构分别对低频和高频子图进行物体检测,将两个子图的检测结果进行融合得到最终检测结果;训练指导单元,对Wae Faster RCNN检测模型进行训练,并在训练时引入知识蒸馏机制,利用已训练好的Faster RCNN模型的输出作为软目标来指导Wae Faster RCNN模型的训练。

Description

一种基于图片蒸馏的通用物体检测框架及其实现方法
技术领域
本发明涉及计算机视觉技术领域,特别是涉及一种基于图片蒸馏的通用物体检测框架及其实现方法。
背景技术
通用物体检测是计算机视觉领域最基础的研究方向,它的具体任务是对给定图像,输出该图像包含的物体的边界框和类别。近年来,随着卷积神经网络的发展,通用物体检测已取得重大进展。目前基于CNN的通用物体检测方法主要分为两种:以RCNN,FastRCNN,Faster RCNN,Mask RCNN为代表的基于分类的通用物体检测方法和以YOLO系列、SSD为代表的基于回归的物体检测方法。基于分类的通用物体检测方法一般检测精度较高于基于回归的通用物体检测方法,应用较为广泛,但其检测速度相对较慢。
具体地说,RCNN提出应用候选框策略来解决检测问题,即先用传统方法对图片预测一系列可能含有物体的候选框,再对候选框进行分类和位置微调。RCNN需要提前保存图像的候选框且每个候选框要单独经过网络提取特征,占用内存大且检测时间长;Fast RCNN采用ROI Pooling对此进行改进,使得每张图片仅需经过网络一次,速度有所提高,但仍然偏慢,Faster RCNN在Fast RCNN的基础上,提出了RPN(Region Proposal Network)来提取候选框,速度较传统方法有明显提高,但仍远远不够,Mask RCNN进一步改进Faster RCNN,添加了一个分支使用现有的检测对目标进行并行预测,提高了对小物体的检测精度,而且Mask RCNN的检测速度在5fps,已经是速度比较快的基于分类的通用物体检测框架了,但这个速度离实时检测还有些遥远。
发明内容
为克服上述现有技术存在的不足,本发明之目的在于提供一种基于图片蒸馏的通用物体检测框架及其实现方法,以提高基于分类的通用物体检测技术的检测速度。
为达上述及其它目的,本发明提出一种基于图片蒸馏的通用物体检测框架,包括:
Faster RCNN模型,用于构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;
Wae Faster RCNN检测模型,用于将输入图像分解成两个分辨率只有原图一半的子图,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果;
训练指导单元,用于对所述Wae Faster RCNN检测模型进行训练,并在所述WaeFaster RCNN检测模型训练时引入知识蒸馏机制,利用训练好的Faster RCNN模型的输出作为软目标来指导所述Wae Faster RCNN检测模型的训练。
优选地,所述Wae Faster RCNN检测模型包括:
图像分解单元,用于利用训练好的Anto-Encoder模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图;
检测单元,用于构建所述Wae Faster RCNN网络结构,利用所述Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测;
融合处理单元,用于对低频子图与高频子图的检测结果进行融合,得到融合后的检测结果。
优选地,所述图像分解单元采用类小波自动编码器WAE进行图像分解,以将输入图像分解成分辨率只有原图一半的低频子图和高频子图,两个子图分别包含原图的低频信息和高频信息。
优选地,对于低频子图与高频子图,所述检测单元分别构建所述Wae Faster RCNN网络结构的低频子网络和高频子网络,该低频子网络的RPN和Fast RCNN,采用完整版Faster RCNN的RPN和Fast RCNN,该高频子网络的RPN和Fast RCNN,采用轻量版FasterRCNN的RPN和Fast RCNN。
优选地,所述轻量版Faster RCNN的部分卷积层通道数为所述完整版Faster RCNN的四分之一。
优选地,所述融合处理单元将低频子图的检测结果和高频子图的检测结果进行融合,作为最终的检测结果。
优选地,所述训练指导单元利用训练好的Faster RCNN模型的输出作为软目标对所述Wae Faster RCNN检测模型的Fast RCNN部分的训练进行指导。
为达到上述目的,本发明还提供一种基于图片蒸馏的通用物体检测框架的实现方法,包括如下步骤:
步骤S1,构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;
步骤S2,将输入图像分解成两个分辨率只有原图一半的子图,构建Wae FasterRCNN网络结构,利用所述Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果;
步骤S3,对所述Wae Faster RCNN检测模型进行训练,并在Wae Faster RCNN检测模型训练时引入知识蒸馏机制,利用训练好的Faster RCNN模型的输出作为软目标来指导所述Wae Faster RCNN检测模型的训练。
优选地,步骤S2进一步包括;
步骤S201,利用训练好的分类模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图;
步骤S202,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,对于低频子图与高频子图,分别构建所述WaeFaster RCNN网络结构的低频子网络和高频子网络,该低频子网络的RPN和Fast RCNN,采用完整版Faster RCNN的RPN和Fast RCNN,该高频子网络的RPN和Fast RCNN,采用轻量版Faster RCNN的RPN和Fast RCNN;
步骤S203,用于对低频子图与高频子图的检测结果进行融合,得到融合的检测结果。
优选地,于步骤S3中,利用所述Faster RCNN模型的Fast RCNN得到的候选框得分指导所述Wae Faster RCNN检测模型的Fast RCNN的候选框得分的训练,即在每次迭代时,先将当前处理的图片及对应的候选框输入到所述Faster RCNN模型,进行前向传播,得到Faster RCNN模型的候选框类别得分,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标St,再将同样的图片及候选框输入到Wae Faster RCNN检测模型的Fast RCNN部分,进行前向传播,根据所述Faster RCNN模型得到的软目标Soft target与所述Wae Faster RCNN检测模型得到的软输出Soft output计算软损失Soft loss,并根据所述Wae Faster RCNN检测模型得到的硬输出Hard output和真实标签Hard target 计算硬损失Hard loss,得到总的分类部分的损失函数classify loss=Hard loss+λSoft loss,λ是权重。
与现有技术相比,本发明一种基于图片蒸馏的通用物体检测框架及其实现方法通过采用类小波自动编码器将输入图像分解成两个分辨率只有原图一半的子图,然后对两个子图进行后续检测步骤,最后将两个子图的检测结果进行平均得到最终检测结果,本发明由于仅采用分辨率只有原图一半的子图进行检测使得检测速度提高了两倍,但不可避免地会导致精度的下降,因此在训练时引入知识蒸馏的机制,用复杂的但是检测精度高的Faster RCNN模型的输出作为软目标来指导检测模型的训练,从而保证检测精度。
附图说明
图1为本发明一种基于图片蒸馏的通用物体检测框架的结构示意图;
图2为本发明具体实施例中基于图片蒸馏的通用物体检测框架的架构示意图;
图3为本发明具体实施例中Faster RCNN模型得到软目标的过程示意图;
图4为本发明具体实施例中Wae Faster RCNN检测模型的训练过程示意图;
图5为本发明一种基于图片蒸馏的通用物体检测框架的实现方法的步骤流程图。
具体实施方式
以下通过特定的具体实例并结合附图说明本发明的实施方式,本领域技术人员可由本说明书所揭示的内容轻易地了解本发明的其它优点与功效。本发明亦可通过其它不同的具体实例加以施行或应用,本说明书中的各项细节亦可基于不同观点与应用,在不背离本发明的精神下进行各种修饰与变更。
图1为本发明一种基于图片蒸馏的通用物体检测框架的结构示意图。如图1所示,本发明一种基于图片蒸馏的通用物体检测框架,包括:
Faster RCNN模型10,用于构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型。由于这里Faster RCNN模型的构建与训练采用的是现有技术,在此不予赘述。
Wae Faster RCNN检测模型20,用于将输入图像分解成两个分辨率只有原图一半的子图,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果。
训练指导单元30,用于对Wae Faster RCNN检测模型进行训练,并在Wae FasterRCNN检测模型训练时引入知识蒸馏机制,用复杂的但是检测精度高的训练好的FasterRCNN模型的输出作为软目标(soft target)来指导Wae Faster RCNN检测模型的训练。
具体地,Wae Faster RCNN检测模型20进一步包括:
图像分解单元201,用于利用训练好的Auto-Encoder(自编码器)模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图。在本发明具体实施例中,图像分解单元201应用了类小波自动编码器(Wavelet-like Auto-Encoder,简称WAE)进行图像分解,以将输入图像分解成分辨率只有原图一半的低频子图和高频子图,两个子图分别包含原图的低频信息和高频信息。在本发明具体实施例中,图像分解的网络结构如表1所示:
表1
其中,含有“conv”的表示卷积层,括号内为卷积层参数,分别为卷积核个数、填充0个数、卷积核大小,步长,“relu”表示激活层,含“CA”的表示该层输出为低频子图,含“CH”的表示该层输出为高频子图,粗体表示该层的输出即为网络输出。
检测单元202,用于构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测。在本发明具体实施例中,对于低频子图与高频子图,分别构建低频子网络和高频子网络。Wae Faster RCNN网络的RPN(RegionProposal Network)部分,低频子网络对低频子图应用完整版Faster RCNN的RPN,高频子网络对高频子图应用轻量版Faster RCNN的RPN,其中轻量版Faster RCNN的RPN部分卷积层通道数是完整版的四分之一。在本发明具体实施例中,Wae Faster RCNN网络的低频子网络和高频子网络的RPN部分结构如下表2所示:
表2
其中,包含“conv”的表示卷积层,括号内为卷积层参数,分别为卷积核个数、填充0个数、卷积核大小,步长。“relu”表示激活层,“batchnorm”表示批量归一化层,“maxpool”表示最大池化层,括号内为最大池化层参数,分别为卷积核大小和下采样步长,“eltwise”开头的表示eltwise层,括号内为eltwise层参数,表示对每对元素的操作,非斜体部分表示RPN与Fast RCNN共享的网络结构,即主干网络,斜体部分表示RPN特有的网络结构,含“CA”的为低频子网络的部分,含“CH”的为高频子网络的部分,粗体表示该层的输出即为网络输出,表中断开的部分无特殊操作,只是为了方便表示,对断开部分上一行重新进行了排列。
Wae Faster RCNN网络的Fast RCNN部分,对低频子图应用完整版Faster RCNN的Fast RCNN,对高频子图应用轻量版Faster RCNN的Fast RCNN,其中轻量版Faster Rcnn网络的RPN部分卷积层通道数是完整版的四分之一,这里使用的Fast RCNN不完全和FasterRCNN中的一致,主要是对全卷积层的神经元个数做了修改。Wae Faster RCNN网络结构的Fast RCNN部分的具体网络结构如表3所示:
表3
其中,包含“conv”的表示卷积层,括号内为卷积层参数,分别为卷积核个数、填充0个数、卷积核大小,步长,表中“relu”表示激活层。“maxpool”表示最大池化层,括号内为最大池化层参数,分别为卷积核大小和下采样步长。“fc”开头的表示全连接层,括号内为全连接参数,为神经元个数。“ROIPooling”表示感兴趣区域池化层,括号内为感兴趣区域池化层的参数,分别为卷积核宽度、卷积核长度,空间缩放尺度(该层与输入图像相比缩小的倍数),“dropout”表示dropout层,括号内为dropout层参数,表示丢失率。“batchnorm”开头的表示批量归一化层。“concat”开头的表示连接层,括号内为连接层参数,表示按某一维度连接,“eltwise”开头的表示eltwise层,括号内为eltwise层参数,表示对每对元素的操作。非斜体部分表示RPN与Fast Rcnn共享的网络结构,即主干网络,斜体部分表示Fast RCNN特有的网络结构。含“CA”的为低频子网络的部分,含“CH”或“fusion”的为高频子网络的部分,粗体表示该层的输出即为网络输出。
融合处理单元203,用于对低频子图与高频子图的检测结果进行融合,得到融合的检测结果。在本发明具体实施例中,融合处理单元203将低频子图与高频子图的检测结果进行平均,得到最终检测结果。
在本发明中,训练指导单元30采用Faster RCNN来指导Wae Faster RCNN检测模型的训练。经过实验发现,Wae Faster RCNN的RPN阶段生成的候选框与Faster RCNN的质量相当,差别只在于Fast RCNN部分。因此,训练指导单元30只对Fast RCNN部分的训练进行指导。具体的,训练指导单元30用Faster RCNN的Fast RCNN得到的候选框得分指导WaeFaster RCNN检测模型的Fast RCNN的候选框得分的训练,即在每次迭代时,先将当前处理的图片及对应的候选框输入到Faster RCNN模型,进行前向传播,得到Faster RCNN模型的候选框类别得分,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标St,再将同样的图片及候选框输入到Wae Faster RCNN检测模型的Fast RCNN部分,进行前向传播,根据Faster RCNN模型得到的软目标Soft target与Wae Faster RCNN检测模型得到的软输出Soft output计算软损失Soft loss,并根据Wae Faster RCNN检测模型得到的硬输出Hard output和真实标签Hard target计算硬损失Hard loss,这样总的分类部分的损失函数classify loss=Hard loss+λSoft loss,λ是权重。
图2为本发明具体实施例中基于图片蒸馏的通用物体检测框架的架构示意图。如图2所示,左边的Teacher model为复杂模型,即Faster RCNN模型,右边的Student model为Wae Faster RCNN检测模型,其参数需要训练,它以Image I作为输入,经过Wae encodinglayer(即图像分解单元)将Image I分解成两个子图(左边是低频子图,右边是高频子图)。对于低频子图,应用复杂的模型(本发明采用如Teacher model的Faster RCNN模型,由于输入图片的分辨率减半,速度会比对原图应用teacher model快),得到检测结果(Studentmodel的左分支)。对于高频子图,应用简化的复杂模型(本发明将如teacher model的Faster RCNN模型的通道数变为原来的四分之一),得到检测结果(Student model的右分支)。将两个分支的结果进行融合得到最终结果。
虽然Student model将输入图片变为原来的一半会加快检测速度,但无疑会带来精度的下降,所以在训练的时候要引入知识蒸馏来保证精度,知识蒸馏就是用训练好的复杂模型(即左边的Teacher model)的输出来指导简单模型(右边的Student model)的训练。
训练时,将相同的图片输入Teacher model和Student model,将Teacher model得到的软目标Soft target与Student model得到的软输出Soft output计算软损失Softloss(这个过程就是知识蒸馏),同时将Student model得到的硬输出Hard output和真实标签Hard target计算硬损失Hard loss,总的分类部分的损失函数classify loss=Hardloss+λSoft loss,λ是权重。
图3为本发明具体实施例中Faster RCNN模型得到软目标的过程示意图。具体地,输入图像,经过CNN,RoI Pooling,NN得到分类结果teacher_cls和边界框位置teacher_bbox(到目前为止是Faster Rcnn模型的Fast Rcnn检测物体的过程),对于分类结果teacher_cls,先除以一个温度系数T,再进过Softmax变换,即得到软化的概率分布Softtarget(软目标)St。
以下将配合图4来具体说明本发明具体实施例中Wae Faster RCNN检测模型的训练过程,在本发明具体实施例中,Wae Faster RCNN检测模型的训练过程包括如下四个阶段
第一阶段:训练Wae Faster RCNN检测模型的RPN部分。用训练好的WAE分类网络进行Wae Faster RCNN模型的初始化。固定两个conv3_1之前的权值,只微调conv3_1之后的权值。RPN的低频子网络,高频子网络,两者输出的平均都有各自的损失函数,其损失函数类比原Faster RCNN的RPN损失函数得到。
第二阶段:训练Wae Faster RCNN检测模型的Fast RCNN部分。用训练好的WAE分类网络进行初始化,固定两个conv3_1之前的权值,只微调conv3_1之后的权值。在每次迭代时,先将当前处理的图片及对应的候选框输入到Faster RCNN,进行前向传播,得到原Faster RCNN的候选框类别得分teacher_cls,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标,图3中的St。将同样的图片及候选框输入到Wae FasterRCNN的Fast RCNN部分,进行前向传播,该过程如图4所示。低频子网络输出候选框分数CA_cls和候选框位置CA_bbox,高频子网络输出候选框分数CH_cls和候选框位置CH_bbox。将CA_cls与CH_cls进行平均得到Avg_cls,对CA_bbox和CH_bbox进行平均得到Avg_bbox,对CA_cls进行两种操作:除以温度参数T并做softmax变换得到CA_cls_soft和直接做softmax变换得到CA_cls_hard。对CH_cls和Avg_cls类似。对于低频子网络,分类损失有两部分组成:CA_cls_hard与真实值cls的交叉熵损失和CA_cls_soft与St的交叉熵损失,赋予第一个损失较小权重,定位损失为CA_bbox与真实值bbox的Smooth L1损失。高频子网络和两个子网络平均之后计算的损失类似。
第三阶段:用第二阶段得到的权值初始化Wae Faster RCNN的RPN网络,固定conv5_1以及之前的层,只微调RPN特有的层。
第四阶段:用第三阶段得到的权值初始化Wae Faster RCNN的Fast RCNN网络,固定conv5_1以及之前的层,只微调Fast RCNN特有的层。
图5为本发明一种基于图片蒸馏的通用物体检测框架的实现方法的步骤流程图。如图5所示,本发明一种基于图片蒸馏的通用物体检测框架的实现方法,包括如下步骤:
步骤S1,构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型。由于这里Faster RCNN模型的构建与训练采用的是现有技术,在此不予赘述。
步骤S2,将输入图像分解成两个分辨率只有原图一半的子图,构建Wae FasterRCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果。
步骤S3,对Wae Faster RCNN检测模型进行训练,并在Wae Faster RCNN检测模型训练时引入知识蒸馏机制,用复杂的但是检测精度高的训练好的Faster RCNN模型的输出作为软目标(soft target)来指导Wae Faster RCNN检测模型的训练。
具体地,步骤S2进一步包括:
步骤S201,利用训练好的Auto-Encoder模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图。在本发明具体实施例中,应用了类小波自动编码器(Wavelet-like Auto-Encoder,简称WAE)进行图像分解,以将输入图像分解成分辨率只有原图一半的低频子图和高频子图,两个子图分别包含原图的低频信息和高频信息。
步骤S202,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测。在本发明具体实施例中,对于低频子图与高频子图,分别构建低频子网络和高频子网络。Wae Faster RCNN网络的RPN(Region ProposalNetwork)部分,低频子网络对低频子图应用完整版Faster RCNN的RPN,高频子网络对高频子图应用轻量版Faster RCNN的RPN,其中轻量版Faster RCNN的RPN部分卷积层通道数是完整版的四分之一。Wae Faster RCNN网络的Fast RCNN部分,对低频子图应用完整版FasterRCNN的Fast RCNN,对高频子图应用轻量版Faster RCNN的Fast RCNN,其中轻量版FasterRcnn网络的RPN部分卷积层通道数是完整版的四分之一,这里使用的Fast RCNN不完全和Faster RCNN模型中的一致,主要是对全卷积层的神经元个数做了修改。
步骤S203,用于对低频子图与高频子图的检测结果进行融合,得到融合的检测结果。在本发明具体实施例中,将低频子图与高频子图的检测结果进行平均,得到最终检测结果。
于步骤S3中,采用Faster RCNN模型的输出来指导Wae Faster RCNN检测模型的训练。经过实验发现,Wae Faster RCNN的RPN阶段生成的候选框与Faster RCNN的质量相当,差别只在于Fast RCNN部分。因此,Faster RCNN模型的输出只对Fast RCNN部分的训练进行指导。具体的,于步骤S3中,用Faster RCNN的Fast RCNN得到的候选框得分指导Wae FasterRCNN检测模型的Fast RCNN的候选框得分的训练,即在每次迭代时,先将当前处理的图片及对应的候选框输入到Faster RCNN模型,进行前向传播,得到Faster RCNN模型的候选框类别得分,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标St,再将同样的图片及候选框输入到Wae Faster RCNN检测模型的Fast RCNN部分,进行前向传播,根据Faster RCNN模型得到的软目标Soft target与Wae Faster RCNN检测模型得到的软输出Soft output计算软损失Soft loss,并根据Wae Faster RCNN检测模型得到的硬输出Hard output和真实标签Hard target计算硬损失Hard loss,这样总的分类部分的损失函数classify loss=Hard loss+λSoft loss,λ是权重。
综上所述,本发明一种基于图片蒸馏的通用物体检测框架及其实现方法通过采用类小波自动编码器将输入图像分解成两个分辨率只有原图一半的子图,然后对两个子图进行后续检测步骤,最后将两个子图的检测结果进行平均得到最终检测结果,本发明由于仅采用分辨率只有原图一半的子图进行检测使得检测速度提高了两倍,但不可避免地会导致精度的下降,因此在训练时引入知识蒸馏的机制,用复杂的但是检测精度高的Faster RCNN模型的输出作为软目标来指导检测模型的训练,从而保证检测精度。
上述实施例仅例示性说明本发明的原理及其功效,而非用于限制本发明。任何本领域技术人员均可在不违背本发明的精神及范畴下,对上述实施例进行修饰与改变。因此,本发明的权利保护范围,应如权利要求书所列。

Claims (10)

1.一种基于图片蒸馏的通用物体检测框架,包括:
Faster RCNN模型,用于构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;
Wae Faster RCNN检测模型,用于将输入图像分解成两个分辨率只有原图一半的子图,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果;
训练指导单元,用于对所述Wae Faster RCNN检测模型进行训练,并在所述Wae FasterRCNN检测模型训练时引入知识蒸馏机制,利用训练好的Faster RCNN模型的输出作为软目标来指导所述Wae Faster RCNN检测模型的训练。
2.如权利要求1所述的一种基于图片蒸馏的通用物体检测框架,其特征在于,所述WaeFaster RCNN检测模型包括:
图像分解单元,用于利用训练好的Anto-Encoder模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图;
检测单元,用于构建所述Wae Faster RCNN网络结构,利用所述Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测;
融合处理单元,用于对低频子图与高频子图的检测结果进行融合,得到融合后的检测结果。
3.如权利要求2所述的一种基于图片蒸馏的通用物体检测框架,其特征在于:所述图像分解单元采用类小波自动编码器WAE进行图像分解,以将输入图像分解成分辨率只有原图一半的低频子图和高频子图,两个子图分别包含原图的低频信息和高频信息。
4.如权利要求2所述的一种基于图片蒸馏的通用物体检测框架,其特征在于:对于低频子图与高频子图,所述检测单元分别构建所述Wae Faster RCNN网络结构的低频子网络和高频子网络,该低频子网络的RPN和Fast RCNN,采用完整版Faster RCNN的RPN和FastRCNN,该高频子网络的RPN和Fast RCNN,采用轻量版Faster RCNN的RPN和Fast RCNN。
5.如权利要求4所述的一种基于图片蒸馏的通用物体检测框架,其特征在于:所述轻量版Faster RCNN的部分卷积层通道数为所述完整版Faster RCNN的四分之一。
6.如权利要求2所述的一种基于图片蒸馏的通用物体检测框架,其特征在于:所述融合处理单元将低频子图的检测结果和高频子图的检测结果进行融合,作为最终的检测结果。
7.如权利要求1所述的一种基于图片蒸馏的通用物体检测框架,其特征在于:所述训练指导单元利用训练好的Faster RCNN模型的输出作为软目标对所述Wae Faster RCNN检测模型的Fast RCNN部分的训练进行指导。
8.一种基于图片蒸馏的通用物体检测框架的实现方法,包括如下步骤:
步骤S1,构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;
步骤S2,将输入图像分解成两个分辨率只有原图一半的子图,构建Wae Faster RCNN网络结构,利用所述Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果;
步骤S3,对所述Wae Faster RCNN检测模型进行训练,并在Wae Faster RCNN检测模型训练时引入知识蒸馏机制,利用训练好的Faster RCNN模型的输出作为软目标来指导所述Wae Faster RCNN检测模型的训练。
9.如权利要求8所述的一种基于图片蒸馏的通用物体检测框架的实现方法,其特征在于,步骤S2进一步包括;
步骤S201,利用训练好的分类模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图;
步骤S202,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,对于低频子图与高频子图,分别构建所述Wae FasterRCNN网络结构的低频子网络和高频子网络,该低频子网络的RPN和Fast RCNN,采用完整版Faster RCNN的RPN和Fast RCNN,该高频子网络的RPN和Fast RCNN,采用轻量版FasterRCNN的RPN和Fast RCNN;
步骤S203,用于对低频子图与高频子图的检测结果进行融合,得到融合的检测结果。
10.如权利要求8所述的一种基于图片蒸馏的通用物体检测框架的实现方法,其特征在于,于步骤S3中,利用所述Faster RCNN模型的Fast RCNN得到的候选框得分指导所述WaeFaster RCNN检测模型的Fast RCNN的候选框得分的训练,即在每次迭代时,先将当前处理的图片及对应的候选框输入到所述Faster RCNN模型,进行前向传播,得到Faster RCNN模型的候选框类别得分,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标St,再将同样的图片及候选框输入到Wae Faster RCNN检测模型的Fast RCNN部分,进行前向传播,根据所述Faster RCNN模型得到的软目标Soft target与所述WaeFaster RCNN检测模型得到的软输出Soft output计算软损失Soft loss,并根据所述WaeFaster RCNN检测模型得到的硬输出Hard output和真实标签Hard target计算硬损失Hardloss,得到总的分类部分的损失函数classify loss=Hard loss+λSoft loss,λ是权重。
CN201811150901.3A 2018-09-29 2018-09-29 一种基于图片蒸馏的通用物体检测系统及其实现方法 Active CN109344897B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201811150901.3A CN109344897B (zh) 2018-09-29 2018-09-29 一种基于图片蒸馏的通用物体检测系统及其实现方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201811150901.3A CN109344897B (zh) 2018-09-29 2018-09-29 一种基于图片蒸馏的通用物体检测系统及其实现方法

Publications (2)

Publication Number Publication Date
CN109344897A true CN109344897A (zh) 2019-02-15
CN109344897B CN109344897B (zh) 2022-03-25

Family

ID=65307678

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201811150901.3A Active CN109344897B (zh) 2018-09-29 2018-09-29 一种基于图片蒸馏的通用物体检测系统及其实现方法

Country Status (1)

Country Link
CN (1) CN109344897B (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110335242A (zh) * 2019-05-17 2019-10-15 杭州数据点金科技有限公司 一种基于多模型融合的轮胎x光病疵检测方法
CN112101573A (zh) * 2020-11-16 2020-12-18 智者四海(北京)技术有限公司 一种模型蒸馏学习方法、文本查询方法及装置
CN112307976A (zh) * 2020-10-30 2021-02-02 北京百度网讯科技有限公司 目标检测方法、装置、电子设备以及存储介质

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103390164A (zh) * 2012-05-10 2013-11-13 南京理工大学 基于深度图像的对象检测方法及其实现装置
CN103679677A (zh) * 2013-12-12 2014-03-26 杭州电子科技大学 一种基于模型互更新的双模图像决策级融合跟踪方法
CN107358258A (zh) * 2017-07-07 2017-11-17 西安电子科技大学 基于nsct双cnn通道和选择性注意机制的sar图像目标分类
CN107563381A (zh) * 2017-09-12 2018-01-09 国家新闻出版广电总局广播科学研究院 基于全卷积网络的多特征融合的目标检测方法
CN107886117A (zh) * 2017-10-30 2018-04-06 国家新闻出版广电总局广播科学研究院 基于多特征提取和多任务融合的目标检测算法
CN108470183A (zh) * 2018-02-05 2018-08-31 西安电子科技大学 基于聚类细化残差模型的极化sar分类方法

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103390164A (zh) * 2012-05-10 2013-11-13 南京理工大学 基于深度图像的对象检测方法及其实现装置
CN103679677A (zh) * 2013-12-12 2014-03-26 杭州电子科技大学 一种基于模型互更新的双模图像决策级融合跟踪方法
CN107358258A (zh) * 2017-07-07 2017-11-17 西安电子科技大学 基于nsct双cnn通道和选择性注意机制的sar图像目标分类
CN107563381A (zh) * 2017-09-12 2018-01-09 国家新闻出版广电总局广播科学研究院 基于全卷积网络的多特征融合的目标检测方法
CN107886117A (zh) * 2017-10-30 2018-04-06 国家新闻出版广电总局广播科学研究院 基于多特征提取和多任务融合的目标检测算法
CN108470183A (zh) * 2018-02-05 2018-08-31 西安电子科技大学 基于聚类细化残差模型的极化sar分类方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
GUOBIN CHEN ET AL: "Learning Efficient Object Detection Models with Knowledge Distillation", 《31ST CONFERENCE ON NEURAL INFORMATION PROCESSING SYSTEMS (NIPS 2017)》 *
TIANSHUI CHEN ET AL: "Learning a Wavelet-like Auto-Encoder to Accelerate Deep Neural Networks", 《ARXIV》 *

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110335242A (zh) * 2019-05-17 2019-10-15 杭州数据点金科技有限公司 一种基于多模型融合的轮胎x光病疵检测方法
CN112307976A (zh) * 2020-10-30 2021-02-02 北京百度网讯科技有限公司 目标检测方法、装置、电子设备以及存储介质
CN112101573A (zh) * 2020-11-16 2020-12-18 智者四海(北京)技术有限公司 一种模型蒸馏学习方法、文本查询方法及装置
CN112101573B (zh) * 2020-11-16 2021-04-30 智者四海(北京)技术有限公司 一种模型蒸馏学习方法、文本查询方法及装置

Also Published As

Publication number Publication date
CN109344897B (zh) 2022-03-25

Similar Documents

Publication Publication Date Title
Fu et al. Fast and accurate detection of kiwifruit in orchard using improved YOLOv3-tiny model
AU2019101133A4 (en) Fast vehicle detection using augmented dataset based on RetinaNet
CN107609525B (zh) 基于剪枝策略构建卷积神经网络的遥感图像目标检测方法
CN109919122A (zh) 一种基于3d人体关键点的时序行为检测方法
CN109191455A (zh) 一种基于ssd卷积网络的大田作物病虫害检测方法
CN110378281A (zh) 基于伪3d卷积神经网络的组群行为识别方法
CN108021889A (zh) 一种基于姿态外形和运动信息的双通道红外行为识别方法
CN109344897A (zh) 一种基于图片蒸馏的通用物体检测框架及其实现方法
CN111680655A (zh) 一种面向无人机航拍影像的视频目标检测方法
CN110210431B (zh) 一种基于点云语义标注和优化的点云分类方法
CN114049381A (zh) 一种融合多层语义信息的孪生交叉目标跟踪方法
CN109508675A (zh) 一种针对复杂场景的行人检测方法
CN107870992A (zh) 基于多通道主题模型的可编辑服装图像搜索方法
CN108170823B (zh) 一种基于高层语义属性理解的手绘交互式三维模型检索方法
CN111161244B (zh) 基于FCN+FC-WXGBoost的工业产品表面缺陷检测方法
Du et al. Expanding receptive field yolo for small object detection
CN107918772A (zh) 基于压缩感知理论和gcForest的目标跟踪方法
CN109657634A (zh) 一种基于深度卷积神经网络的3d手势识别方法及系统
CN101276370B (zh) 基于关键帧的三维人体运动数据检索方法
CN109344898A (zh) 基于稀疏编码预训练的卷积神经网络图像分类方法
Lu et al. A CNN-transformer hybrid model based on CSWin transformer for UAV image object detection
CN112256904A (zh) 一种基于视觉描述语句的图像检索方法
CN112613428A (zh) 基于平衡损失的Resnet-3D卷积牛视频目标检测方法
CN113657414B (zh) 一种物体识别方法
CN105956604B (zh) 一种基于两层时空邻域特征的动作识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant