CN116612450A - 一种面向点云场景的差异化知识蒸馏3d目标检测方法 - Google Patents
一种面向点云场景的差异化知识蒸馏3d目标检测方法 Download PDFInfo
- Publication number
- CN116612450A CN116612450A CN202310426368.3A CN202310426368A CN116612450A CN 116612450 A CN116612450 A CN 116612450A CN 202310426368 A CN202310426368 A CN 202310426368A CN 116612450 A CN116612450 A CN 116612450A
- Authority
- CN
- China
- Prior art keywords
- model
- distillation
- student
- teacher
- student model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 68
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 31
- 238000004821 distillation Methods 0.000 claims abstract description 109
- 238000012549 training Methods 0.000 claims abstract description 30
- 238000004364 calculation method Methods 0.000 claims abstract description 10
- 238000010586 diagram Methods 0.000 claims description 25
- 235000004522 Pentaglottis sempervirens Nutrition 0.000 claims description 10
- 230000007246 mechanism Effects 0.000 claims description 8
- 240000004050 Pentaglottis sempervirens Species 0.000 claims description 7
- 238000002372 labelling Methods 0.000 claims description 7
- 239000011159 matrix material Substances 0.000 claims description 7
- 238000001914 filtration Methods 0.000 claims description 5
- 239000013598 vector Substances 0.000 claims description 4
- 230000006870 function Effects 0.000 description 37
- 238000000034 method Methods 0.000 description 27
- 238000004422 calculation algorithm Methods 0.000 description 11
- 230000000694 effects Effects 0.000 description 5
- 230000008569 process Effects 0.000 description 5
- 238000012795 verification Methods 0.000 description 4
- 230000009471 action Effects 0.000 description 3
- 238000004590 computer program Methods 0.000 description 3
- 238000010200 validation analysis Methods 0.000 description 3
- 238000002474 experimental method Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 238000005259 measurement Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 230000000007 visual effect Effects 0.000 description 2
- GXMBHQRROXQUJS-UHFFFAOYSA-N (2-hept-2-ynylsulfanylphenyl) acetate Chemical compound CCCCC#CCSC1=CC=CC=C1OC(C)=O GXMBHQRROXQUJS-UHFFFAOYSA-N 0.000 description 1
- 230000003044 adaptive effect Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000001149 cognitive effect Effects 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000011897 real-time detection Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000011524 similarity measure Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 230000001629 suppression Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/50—Context or environment of the image
- G06V20/56—Context or environment of the image exterior to a vehicle by using sensors mounted on the vehicle
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Multimedia (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种面向点云场景的差异化知识蒸馏3D目标检测方法,包括:构建两阶段教师模型和单阶段学生模型并对两阶段教师模型进行预训练,获得预训练后的教师模型;将所构建的训练数据集中的点云数据分别输入到经训练的教师模型和未经训练的学生模型中,分别对教师模型和学生模型的3D骨干网络和2D鸟瞰图骨干网络的输出结果进行差异化特征蒸馏;利用教师模型的一阶段输出结果对学生模型的输出结果进行差异一致性蒸馏;利用教师模型的二阶段输出结果对学生模型的输出结果进行中心点匹配蒸馏,获得训练后的学生模型。本发明将知识从两阶段的目标检测器迁移到单阶段的目标检测器中,能够在不增加计算开销的前提下提高单阶段检测器的精度。
Description
技术领域
本发明属于目标检测技术领域,具体涉及一种面向点云场景的差异化知识蒸馏3D目标检测方法。
背景技术
3D目标检测在自动驾驶领域、智能交通以及智能机器人等诸多领域发挥着举足轻重的作用,而激光雷达传感器(LiDAR)是高质量3D目标检测的关键,由于点云数据携带稳定的几何和深度信息,且不易受环境周围影响的诸多优点,因此将激光雷达获取的点云数据作为周围环境的三维图像进行精准的3D目标检测算法的研究。
传统的基于单一点云数据的3D目标检测方法根据其算法结构可分为单阶段目标检测方法和两阶段目标检测方法。单阶段目标检测方法具有检测速度快,可以实现实时检测的效果等优点,但由于该类方法需要对特征空间进行下采样,容易丢失空间几何信息,不可避免地会导致检测精度的下降。两阶段目标检测方法可以通过其感兴趣区域(RoI)头解决这一问题,其通常采用虚拟点采样策略,使RoI头能够更好地从原始点云或体素空间中学习几何信息,这种方法极大地提高了目标检测的精度,但激光雷达距离目标太远会产生负样本误判为正样本的问题,且该类检测方法由于计算开销较大,在速度上明显慢于单阶段的目标检测方法。
发明内容
为了解决现有技术中存在的上述问题,本发明提供了一种面向点云场景的差异化知识蒸馏3D目标检测方法,将知识从两阶段的目标检测器(教师模型)迁移到单阶段的目标检测器(学生模型)中,在不增加计算开销的前提下提高单阶段检测器的精度。本发明要解决的技术问题通过以下技术方案实现:
本发明提供了一种面向点云场景的差异化知识蒸馏3D目标检测方法,包括:
S1:构建两阶段教师模型和单阶段学生模型并对所述两阶段教师模型进行预训练,获得预训练后的教师模型,所述教师模型包含3D骨干网络、2D鸟瞰图骨干网络、一阶段检测头和二阶段ROI头,所述学生模型包括3D骨干网络、2D鸟瞰图骨干网络和单阶段检测头;
S2:将所构建的训练数据集中的点云数据分别输入到经训练的教师模型和未经训练的学生模型中,分别对教师模型和学生模型的3D骨干网络和2D鸟瞰图骨干网络的输出结果进行差异化特征蒸馏;
S3:利用教师模型的一阶段输出结果对所述学生模型的输出结果进行差异一致性蒸馏;
S4:利用教师模型的二阶段输出结果对所述学生模型的输出结果进行中心点匹配蒸馏,获得训练后的学生模型。
在本发明的一个实施例中,所述S2包括:
S2.1:将训练数据集中的点云数据分别输入到经预训练的教师模型和未经训练的学生模型的3D骨干网络中,分别获得教师模型的鸟瞰图特征图和学生模型的鸟瞰图特征图;
S2.2:利用教师模型的鸟瞰图特征图对学生模型的鸟瞰图特征图进行差异化特征蒸馏,获得学生模型特征蒸馏后的鸟瞰图特征图;
S2.3:将教师模型的鸟瞰图特征图输入教师模型的2D鸟瞰图骨干网络,获得教师模型的RPN特征图,将学生模型特征蒸馏后的鸟瞰图特征图输入学生模型的2D鸟瞰图骨干网络,获得学生模型的特征图;
S2.4:利用教师模型的RPN特征图对学生模型的特征图进行差异化特征蒸馏,获得学生模型特征蒸馏后的特征图;
S2.5:利用教师模型的RPN特征图输出教师模型的3D框预测、目标朝向预测和分类预测,利用学生模型特征蒸馏后的特征图输出学生模型的3D框预测、目标朝向预测和分类预测;
S2.6:分别利用教师模型的3D框预测和分类预测对学生模型的3D框预测和分类预测进行差异化特征蒸馏,获得学生模型特征蒸馏后的3D框预测和分类预测结果。
在本发明的一个实施例中,所述差异化特征蒸馏过程包括:
计算教师模型中预测特征图的每个位置隶属于前景区域的可能性:
其中,ft表示教师模型θt中间的某个模块输出的特征图,而C表示数据集包含的所有标注目标的种类数,St表示教师模型预测出的可能性得分图;
计算学生模型中预测特征图的每个位置隶属于前景区域的可能性:
其中,fs表示学生模型θts中间的某个模块输出的特征图,Ss表示学生模型预测出的可能性得分图;
定义差异化得分机制:Sd=|St-Ss|;
构造特征蒸馏损失函数:
其中,Ft,Fs分别表示教师模型和学生模型给定的一组特征图,其下标c表示数据集包含的所有标注目标的种类数,下标i,j分别表示特征图的行和列,W,H表示当前特征图的长度和宽度。
获得学生模型和教师模型组成的整个网络的特征蒸馏损失:
Lfeat=γbLfbev+γ2Lf2d+γcLfcls+γrLfreg,
其中,Lfbev表示所述3D骨干网络输出特征图的损失,Lf2d表示所述2D鸟瞰图骨干网络输出特征图的损失,Lfcls和Lfreg分别表示一阶段头部的分类特征图和回归特征图的损失,γb,γ2,γc,γr为手动设置的超参数。
在本发明的一个实施例中,所述S3包括:
S3.1:利用设定的阈值分别从教师模型和学生模型中选定符合要求的边界框,组成学生集合和教师集合;
S3.2:在选出来的学生集合和教师集合中,分别计算回归蒸馏损失函数和分类蒸馏损失函数/>
S3.3:利用分类蒸馏损失函数和回归蒸馏损失函数获得差异一致性总体蒸馏函数:
其中,为两个手动设置的超参数。
在本发明的一个实施例中,在步骤S3.2中,使用Smooth-L1函数来构造回归蒸馏损失函数:
其中,∫为设定的阈值,l为一个指示函数,表示δo的损失函数,x、y、z表示被检测目标的中心点坐标,w、l、h表示被检测目标的长宽高,r表示被检测目标的朝向,N表示边界框集合的元素数量,Sd表示教师模型预测出的可能性得分与学生模型预测出的可能性得分的差。
在本发明的一个实施例中,在步骤S3.2中,所述分类蒸馏损失函数表示为:
其中,σ表示softmax函数,ct,cs分别表示教师模型和学生模型预测的分类结果向量,表示δc的损失函数。
在本发明的一个实施例中,所述S4包括:
S4.1:获得学生模型和教师模型预测的所有边界框的集合,并利用设定的阈值对所述边界框进行过滤,获得过滤后对应的边界框集合
S4.2:计算过滤后教师模型和学生模型任意两边界框中心点的欧式距离,得到尺寸为Nt×Ns的距离矩阵,Nt,Ns分别表示过滤后教师模型和学生模型边界框集合中的元素数量;
S4.3:在所述距离矩阵中统计每行元素的最大值,并从过滤后的学生集合中选择选取每行元素的最大值,形成一个元素数量为Nt的集合
S4.4:在配对的集合上分别计算回归蒸馏损失值和分类蒸馏损失值,并利用所述回归蒸馏损失值和分类蒸馏损失值构建中心匹配蒸馏函数。
在本发明的一个实施例中,在步骤S4.4中,所述回归蒸馏损失值的计算函数为:
其中,分别表示完成配对的学生边界框和教师边界框,N表示边界框集合的元素数量。
在本发明的一个实施例中,在步骤S4.4中,使用Kullback-Leibler散度函数来计算分类蒸馏损失值,损失函数的形式为:
其中,分别表示教师模型和学生模型预测的第i个边界框属于第j个类别的可能性,C表示需要预测的总类别数。
与现有技术相比,本发明的有益效果有:
1、本发明面向点云场景的差异化知识蒸馏3D目标检测方法,提出的知识蒸馏框架可以在不增加额外计算量的前提下明显提高单阶段目标检测器的性能。
2、本发明提出的差异化知识蒸馏3D目标检测方法是一种高效能的训练方法,针对目前所有3D目标检测框架均可用,具有较好的泛化性能。
以下将结合附图及实施例对本发明做进一步详细说明。
附图说明
图1是本发明实施例提供的一种面向点云场景的差异化知识蒸馏3D目标检测方法的流程框图;
图2是本发明实施例提供的一种面向点云场景的差异化知识蒸馏3D目标检测方法的详细流程图;
图3是本发明实施例提供的一种面向点云场景的差异化知识蒸馏3D目标检测方法的框架示意图;
图4是本发明实施例提供的一种差异化特征蒸馏过程的流程图;
图5是本发明实施例提供的一种中心点匹配蒸馏过程的流程图;
图6是利用本发明实施例的方法蒸馏后的学生模型在KITTI验证集合上的可视化效果。
具体实施方式
为了进一步阐述本发明为达成预定发明目的所采取的技术手段及功效,以下结合附图及具体实施方式,对依据本发明提出的一种面向点云场景的差异化知识蒸馏3D目标检测方法进行详细说明。
有关本发明的前述及其他技术内容、特点及功效,在以下配合附图的具体实施方式详细说明中即可清楚地呈现。通过具体实施方式的说明,可对本发明为达成预定目的所采取的技术手段及功效进行更加深入且具体地了解,然而所附附图仅是提供参考与说明之用,并非用来对本发明的技术方案加以限制。
应当说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的物品或者设备中还存在另外的相同要素。
实施例一
请参见图1和图2,图1是本发明实施例提供的一种面向点云场景的差异化知识蒸馏3D目标检测方法的流程框图;图2是本发明实施例提供的一种面向点云场景的差异化知识蒸馏3D目标检测方法的详细流程图。该3D目标检测方法包括:
S1:构建两阶段教师模型和单阶段学生模型并对所述两阶段教师模型进行预训练,获得预训练后的教师模型。
本实施例提出的目标检测方法的框架由一个高实时性的单阶段学生模型和一个高精度的两阶段教师模型组成,由于单阶段学生模型与两阶段教师模型之间存在的结构性差异,使得特征式蒸馏方法无法直接迁移教师模型的RoI头部网络中的知识到学生模型之中。因此,本实施例使用回应式中心匹配蒸馏来完成这种知识的迁移。考虑到两阶段教师模型的结果中通常会存在大量的冗余检测结果,需要一种正则化方式来避免学生模型学习到教师模型的这种冗余预测特性。为此,实施例使用差异一致性蒸馏和差异化特征蒸馏作为中心匹配蒸馏方法来缓解这一问题。
该方法的整体执行顺序为,首先训练一个高精度的两阶段教师模型,将其训练结果作为单阶段学生模型的预训练权重进行学生模型的训练,并在其中加入差异化特征蒸馏将两阶段的教师模型高精度的知识蒸馏到单阶段学生模型中,以实现在较小的计算开销的前提下,提高单阶段学生模型网络的3D目标检测精度。
由图2可以看出,所述教师模型包含依次连接的一个3D骨干网络、一个2D鸟瞰图骨干网络、一个一阶段检测头和一个二阶段ROI头;所述学生模型包含依次连接的一个3D骨干网络、一个2D鸟瞰图骨干网络和一个单阶段检测头。在教师模型和学生模型训练过程中,将差异化特征蒸馏、差异化一致性蒸馏以及中心点匹配蒸馏策略加入网络训练过程。下面将详细介绍每个部分的技术方案。
S2:将所构建的训练数据集中的点云数据分别输入到经训练的教师模型和未经训练的学生模型中,分别对教师模型和学生模型的3D骨干网络和2D鸟瞰图骨干网络的输出结果进行差异化特征蒸馏。
在本实施例中,所述步骤S2具体包括:
S2.1:将训练数据集中的点云数据分别输入到经预训练的教师模型和未经训练的学生模型的3D骨干网络中,分别获得教师模型的鸟瞰图特征图和学生模型的鸟瞰图特征图。
首先,构造具有大量点云数据的训练数据集,以对教师模型和学生模型进行训练,接着,将来自训练数据集的点云数据分别输入到教师模型和学生模型的3D骨干网络中,分别获得教师模型的鸟瞰图特征图和学生模型的鸟瞰图特征图。
S2.2:利用教师模型的鸟瞰图特征图对学生模型的鸟瞰图特征图进行差异化特征蒸馏,获得学生模型特征蒸馏后的鸟瞰图特征图。
S2.3:将教师模型的鸟瞰图特征图输入教师模型的2D鸟瞰图骨干网络,获得教师模型的RPN特征图,将学生模型特征蒸馏后的鸟瞰图特征图输入学生模型的2D鸟瞰图骨干网络,获得学生模型的特征图。
S2.4:利用教师模型的RPN特征图对学生模型的特征图进行差异化特征蒸馏,获得学生模型特征蒸馏后的特征图。
S2.5:利用教师模型的RPN特征图输出教师模型的3D框预测、目标朝向预测和分类预测,利用学生模型特征蒸馏后的特征图输出学生模型的3D框预测、目标朝向预测和分类预测。
S2.6:分别利用教师模型的3D框预测和分类预测对学生模型的3D框预测和分类预测进行差异化特征蒸馏,获得学生模型特征蒸馏后的3D框预测和分类预测结果。
本实施例在提出的蒸馏框架中整合了特征蒸馏过程,同时考虑到2D目标检测任务和3D目标检测任务的差异性,本发明实施例设计了一种差异化得分机制来提高特征蒸馏方法的性能。
具体地,假设yc表示当前特征图f中某个位置隶属于一个前景目标的可能性向量,其可以表示为:
yc=P(c|f,θ)
其中,c表示该目标所属的类别,θ表示预测该结果的深度学习模型(教师模型或学生模型)。
则由教师模型的一阶段检测头预测的特征图中每个位置隶属于前景区域的可能性得分图可以表示为:
其中,ft表示教师模型θt中间的某个模块输出的特征图,而C表示训练数据集包含的所有标注目标的种类数。St表示教师模型预测出的可能性得分图,在实际训练过程中,可以直接通过教师模型的一阶段检测头分类得分预测结果得到。类似地,本实施例可以得到学生模型预测的可能性得分图,定义形式如下:
其中,fs表示学生模型θts中间的某个模块输出的特征图,而C表示训练数据集包含的所有标注目标的种类数。Ss表示学生模型预测出的可能性得分图。
在训练过程中,希望学生模型能够更多地关注与教师模型存在认知差异的部分,因此本实施例定义了如下所示的差异化得分机制:
Sd=|St-Ss|
则在给定的一组特征图Ft(教师模型获得的特征图)和Fs(学生模型获得的特征图)上,本实施例的特征蒸馏损失函数可以表示为:
其中,Ft,Fs表示教师模型和学生模型给定的一组特征图,其下标c表示数据集包含的所有标注目标的种类数,下标i,j分别表示特征图的行和列,W,H分别表示当前特征图的长度和宽度。
在本发明实施例的蒸馏框架中,分别在3D骨干网络的输出特征图,2D骨干网络的输出特征图,以及一阶段分类检测头和回归检测头输出的特征图上执行特征蒸馏。整个特征蒸馏部分损失函数可以表示为:
Lfeat=γbLfbev+γ2Lf2d+γcLfcls+γrLfreg
其中,fbev表示体素3D骨干网络输出的特征图,f2d表示2D骨干网络输出的特征图,fcls和freg分别表示一阶段头部的分类特征图和回归特征图,具体地,Lfbev表示3D骨干网络输出特征图的损失,Lf2d表示2D骨干网络输出特征图的损失,Lfcls和Lfreg分别表示一阶段头部的分类特征图的损失和回归特征图的损失,而γb,γ2,γc,γr是四个手动设置的超参数。
S3:利用教师模型的一阶段输出结果对所述学生模型的输出结果进行差异一致性蒸馏。
考虑到当前的两阶段检测器倾向于将低置信度的区域误认为前景区域,单纯地使用差异化特征蒸馏无法保证学生模型不会学习到冗余预测的特性。为此本实施例设计了差异一致性蒸馏方法,该方法使用两阶段教师模型的单阶段部分预测结果对学生模型进行蒸馏,并对学生模型进行强正则化。
本实施例的差异一致性蒸馏方法使用教师模型的单阶段网络部分结果。而教师模型的单阶段网络部分结果和学生模型的输出在结构上保持一致,因此不再需要额外的边界框匹配机制。同时为了保证过程的高效性,本实施例使用差异化得分机制来对蒸馏过程进行指导。从差异得分机制定义出发,可以得知一个位置的差异化得分高有两种情况:
1)学生模型认为该位置隶属于前景区域,而教师模型的一阶段网络认为其更有可能是背景区域。
2)教师模型的一阶段网络认为该位置隶属于前景区域,而学生模型认为该位置隶属于背景区域。
在第一种情况中,学生模型可能受到教师的二阶段网络的影响进而对该区域产生了错误的判断。而第二种情况中,学生模型学习不充足导致其无法准确的预测该区域。同样地,差异一致性蒸馏方法也设置一个阈值∈来从对应的学生教师集合中选出高置信度的边界框。
在选出来的学生集合和教师集合中,本实施例分别计算分类蒸馏损失函数和回归蒸馏损失函数。对于回归蒸馏损失部分,本实施例使用Smooth-L1函数来进行计算,其形式如下:
其中,∫为设定的阈值,l为一个指示函数,当满足括号内的条件时结果取1否则取0,表示δo的损失函数,x、y、z表示被检测目标的中心点坐标,w、l、h表示被检测目标的长宽高,r表示被检测目标的朝向,N表示边界框集合的元素数量,Sd表示教师模型预测出的可能性得分与学生模型预测出的可能性得分的差。同样地,本实施例的分类蒸馏损失函数可以表示为:
其中,σ表示softmax函数,而ct,cs分别表示教师模型和学生模型预测的分类结果向量,表示δc的损失函数。通过整合两种函数,本实施例的差异一致性总体蒸馏函数可以被表示为:
其中,为两个需要手动设置的超参数。
S4:利用教师模型的二阶段输出结果对所述学生模型的输出结果进行中心点匹配蒸馏,获得训练后的学生模型。
回应式目标检测蒸馏方法通常要求学生模型和教师模型采用相同的架构。然而在本实施例的场景中,教师模型和学生模型采用的是不同的检测器架构。因此,在执行蒸馏之前需要使用一种策略来对数量不一致的学生预测结果和教师预测结果进行配对。本实施例假定学生模型预测的所有边界框的集合为Bs,教师模型预测的所有边界框集合为Bt。在进行匹配之前,需要选择一种策略来衡量两个集合之间的元素相似性。一种常见的相似性衡量机制为IoU。然而单纯的IoU匹配容易出现误匹配现象。在3D目标检测场景中同一类别的目标在尺寸上相对接近,因此中心点靠近的两个同类框之间的相似度会比较高。基于这一观察,在进行集合匹配时本发明实施例选择使用两个边界框的中心点距离作为衡量机制。
本实施例的中心点匹配蒸馏方法使用未经过非极大值抑制处理的预测集合。直接在这两个集合上执行蒸馏,一方面会带来巨大的计算开销,并影响蒸馏的实际效果。因此,本实施例设计了一种过滤方式从两个集合中选择高质量的边界框。如果可以同时得到预先训练好的学生模型和教师模型的权重,则本实施例直接设置一个固定的阈值γ来进行过滤。如果仅有预训练好的教师模型权重,则使用一个固定的阈值γt来过滤教师集合,使用一个自适应的阈值γs来过滤学生集合。调度学生阈值的方法很多,但为了尽可能的保持方法的简洁性本实施例选择一种线性增长策略。
其中,γS表示需要过滤的边界框的阈值,t表示当前的训练轮数,而T表示总共的训练轮数。
假设分别表示过滤后的边界框中心点所组成的集合,而/>则表示对应的边界框集合。假设两个集合的元素数量分别为Nt,Ns。对于每个/>以及/>本实施例的方法首先计算任意两个边界框中心点的欧式距离,从而得到尺寸为Nt×Ns的距离矩阵。在得到的距离矩阵中统计每行元素的最大值和下标,并根据这个下标从过滤后的学生集合中选择出一个元素数量为Nt的集合/>整个过程的pytorch伪代码可以如算法1所示。
在配对的集合上,本实施例分别计算回归蒸馏损失值和分类蒸馏损失值。其中的回归蒸馏损失值的计算函数使用Smooth-L1进行构造,该函数形式如下:
其中,分别表示完成配对的学生边界框和教师边界框,x、y、z表示被检测目标的中心点坐标,w、l、h表示被检测目标的宽、长以及高,r表示目标朝向。
进一步地,本实施例使用Kullback-Leibler散度函数来计算分类蒸馏损失值。该损失函数的形式如下所示:
其中,分别表示教师模型和学生模型预测的第i个边界框属于第j个类别的可能性,而C表示需要预测的总类别数。因此,整个中心匹配蒸馏函数可以表示为:
其中,α,β为两个需要手动设置的超参数。
最终,整个网络的损失函数可以分为监督损失函数和蒸馏损失函数,其中监督损失函数部分在这里我们将其表示为Lsup。整体损失函数可以表示为:
L=Lsup+Lfeat+Lcons+Lctr
其中,+Lfeat、Lcons和Lctr分别表示差异化特征蒸馏损失、差异一致性蒸馏损失以及中心点匹配蒸馏损失。
进一步地,通过KITTI和Waymo Open Dataset主流的开源数据集进行了多组实验,以对本发明实施例的差异化知识蒸馏3D目标检测方法进行算法验证,实验结果表明,在不增加计算开销的前提下,提高了单阶段3D目标检测算法的精度,各组实验结果见下表。
表1KITTI验证集3D平均精度(3D AP)预测对比结果
表2KITTI验证集BEV视图平均精度(BEV AP)预测对比结果
表1和表2分别为各3D目标检测算法在KITTI验证数据集上的3D和BEV(鸟瞰图)检测精度结果对比,其中,student-M-S和student-M-C为知识蒸馏后的SECOND模型和CentPoint模型(本发明模型),从表中可以看出,所提出的框架可以很明显地提高学生模型的表现,对于中等难度的“汽车”类别的物体检测,可以帮助Modify-SECOND和Modify-CenterPoint模型提高1.79和1.97个点的3D AP的提升。在最难检测的类别“行人”中,在中等难度上的表现,可以帮助Modify-SECOND和Modify-CenterPoint模型提高4.96和2.89个点的3D AP的提升。
表3使用20% Waymo训练数据集各算法预测对比结果
由于Waymo数据集较大,因此许多3D目标检测算法仅使用20%的训练数据验证算法的有效性,为了得到更全面的实验结果,本发明实施例同时使用20%和100%训练数据进行算法验证,实验结果如表3和表4所示。由表可以看出,Modify SECOND在Waymo上获得了更大的精度提升。对于“骑自行车人”这一类的目标预测,所提出的检测框架使Modify SECOND在level 1级别上分别提升了6.8个AP和5.5个APH,相应的CenterPoint提升了1.0和0.21。
表4使用100% Waymo训练数据集各算法预测对比结果
蒸馏后的学生模型在KITTI验证集合上的可视化效果如图6所示,其中真实值由红色长方体框标注,预测值由绿色长方体标注,由图可以看出,检测效果较好,其中误检结果由红色圈进行了标注,漏检结果由绿色圈进行了标注。在前三个场景中,所有的汽车目标都被正确的检测到,在中间的三个场景中,有几个目标被错误的检测为汽车,这些目标通常是不在数据集中标注的模糊背景对象,最后的三个场景是一些被严重遮挡的目标,学生模型没有检测到这些目标。这些误检和漏检目标的一个共同特点是它们距离LiDAR传感器相对较远,包含的点云数量非常稀疏。
本发明实施例面向点云场景的差异化知识蒸馏3D目标检测方法,提出的知识蒸馏框架可以在不增加额外计算量的前提下明显提高单阶段目标检测器的性能。本发明实施例提出的差异化知识蒸馏3D目标检测方法是一种高效能的训练方法,针对目前所有3D目标检测框架均可用,具有较好的泛化性能。
在本发明所提供的几个实施例中,应该理解到,本发明所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述模块的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个模块或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。
另外,在本发明各个实施例中的各功能模块可以集成在一个处理模块中,也可以是各个模块单独物理存在,也可以两个或两个以上模块集成在一个模块中。上述集成的模块既可以采用硬件的形式实现,也可以采用硬件加软件功能模块的形式实现。
本发明的又一实施例提供了一种存储介质,所述存储介质中存储有计算机程序,所述计算机程序用于执行上述实施例中所述面向点云场景的差异化知识蒸馏3D目标检测方法的步骤。本发明的再一方面提供了一种电子设备,包括存储器和处理器,所述存储器中存储有计算机程序,所述处理器调用所述存储器中的计算机程序时实现如上述实施例所述面向点云场景的差异化知识蒸馏3D目标检测方法的步骤。具体地,上述以软件功能模块的形式实现的集成的模块,可以存储在一个计算机可读取存储介质中。上述软件功能模块存储在一个存储介质中,包括若干指令用以使得一台电子设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本发明各个实施例所述方法的部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
以上内容是结合具体的优选实施方式对本发明所作的进一步详细说明,不能认定本发明的具体实施只局限于这些说明。对于本发明所属技术领域的普通技术人员来说,在不脱离本发明构思的前提下,还可以做出若干简单推演或替换,都应当视为属于本发明的保护范围。
Claims (9)
1.一种面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,包括:
S1:构建两阶段教师模型和单阶段学生模型并对所述两阶段教师模型进行预训练,获得预训练后的教师模型,所述教师模型包含3D骨干网络、2D鸟瞰图骨干网络、一阶段检测头和二阶段ROI头,所述学生模型包括3D骨干网络、2D鸟瞰图骨干网络和单阶段检测头;
S2:将所构建的训练数据集中的点云数据分别输入到经训练的教师模型和未经训练的学生模型中,分别对教师模型和学生模型的3D骨干网络和2D鸟瞰图骨干网络的输出结果进行差异化特征蒸馏;
S3:利用教师模型的一阶段输出结果对所述学生模型的输出结果进行差异一致性蒸馏;
S4:利用教师模型的二阶段输出结果对所述学生模型的输出结果进行中心点匹配蒸馏,获得训练后的学生模型。
2.根据权利要求1所述的面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,所述S2包括:
S2.1:将训练数据集中的点云数据分别输入到经预训练的教师模型和未经训练的学生模型的3D骨干网络中,分别获得教师模型的鸟瞰图特征图和学生模型的鸟瞰图特征图;
S2.2:利用教师模型的鸟瞰图特征图对学生模型的鸟瞰图特征图进行差异化特征蒸馏,获得学生模型特征蒸馏后的鸟瞰图特征图;
S2.3:将教师模型的鸟瞰图特征图输入教师模型的2D鸟瞰图骨干网络,获得教师模型的RPN特征图,将学生模型特征蒸馏后的鸟瞰图特征图输入学生模型的2D鸟瞰图骨干网络,获得学生模型的特征图;
S2.4:利用教师模型的RPN特征图对学生模型的特征图进行差异化特征蒸馏,获得学生模型特征蒸馏后的特征图;
S2.5:利用教师模型的RPN特征图输出教师模型的3D框预测、目标朝向预测和分类预测,利用学生模型特征蒸馏后的特征图输出学生模型的3D框预测、目标朝向预测和分类预测;
S2.6:分别利用教师模型的3D框预测和分类预测对学生模型的3D框预测和分类预测进行差异化特征蒸馏,获得学生模型特征蒸馏后的3D框预测和分类预测结果。
3.根据权利要求2所述的面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,所述差异化特征蒸馏过程包括:
计算教师模型中预测特征图的每个位置隶属于前景区域的可能性:
其中,ft表示教师模型θt中间的某个模块输出的特征图,而C表示数据集包含的所有标注目标的种类数,St表示教师模型预测出的可能性得分图;
计算学生模型中预测特征图的每个位置隶属于前景区域的可能性:
其中,fs表示学生模型θts中间的某个模块输出的特征图,Ss表示学生模型预测出的可能性得分图;
定义差异化得分机制:Sd=|St-Ss|;
构造特征蒸馏损失函数:
其中,Ft,Fs表示教师模型和学生模型给定的一组特征图,其下标c表示数据集包含的所有标注目标的种类数,下标i,j分别表示特征图的行和列,W,H分别表示当前特征图的长度和宽度;
获得学生模型和教师模型组成的整个网络的特征蒸馏损失:
Lfeat=γbLfbev+γ2Lf2d+γcLfcls+γrLfreg,
其中,Lfbev表示所述3D骨干网络输出特征图的损失,Lf2d表示所述2D鸟瞰图骨干网络输出特征图的损失,Lfcls和Lfreg分别表示一阶段头部的分类特征图和回归特征图的损失,γb,γ2,γc,γr为手动设置的超参数。
4.根据权利要求3所述的面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,所述S3包括:
S3.1:利用设定的阈值分别从教师模型和学生模型中选定符合要求的边界框,组成学生集合和教师集合;
S3.2:在选出来的学生集合和教师集合中,分别计算回归蒸馏损失函数和分类蒸馏损失函数/>
S3.3:利用分类蒸馏损失函数和回归蒸馏损失函数获得差异一致性总体蒸馏函数:
其中,为两个手动设置的超参数。
5.根据权利要求4所述的面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,在步骤S3.2中,使用Smooth-L1函数来构造回归蒸馏损失函数:
其中,∫为设定的阈值,l为一个指示函数,表示δo的损失函数,x、y、z表示被检测目标的中心点坐标,w、l、h表示被检测目标的长宽高,r表示被检测目标的朝向,N表示边界框集合的元素数量,Sd表示教师模型预测出的可能性得分与学生模型预测出的可能性得分的差。
6.根据权利要求5所述的面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,在步骤S3.2中,所述分类蒸馏损失函数表示为:
其中,σ表示softmax函数,ct,cs分别表示教师模型和学生模型预测的分类结果向量,表示δc的损失函数。
7.根据权利要求6所述的面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,所述S4包括:
S4.1:获得学生模型和教师模型预测的所有边界框的集合,并利用设定的阈值对所述边界框进行过滤,获得过滤后对应的边界框集合
S4.2:计算过滤后教师模型和学生模型任意两边界框中心点的欧式距离,得到尺寸为Nt×Ns的距离矩阵,Nt,Ns分别表示过滤后教师模型和学生模型边界框集合中的元素数量;
S4.3:在所述距离矩阵中统计每行元素的最大值,并从过滤后的学生集合中选择选取每行元素的最大值,形成一个元素数量为Nt的集合
S4.4:在配对的集合上分别计算回归蒸馏损失值和分类蒸馏损失值,并利用所述回归蒸馏损失值和分类蒸馏损失值构建中心匹配蒸馏函数。
8.根据权利要求7所述的面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,在步骤S4.4中,所述回归蒸馏损失值的计算函数为:
其中,分别表示完成配对的学生边界框和教师边界框,N表示边界框集合的元素数量。
9.根据权利要求8所述的面向点云场景的差异化知识蒸馏3D目标检测方法,其特征在于,在步骤S4.4中,使用Kullback-Leibler散度函数来计算分类蒸馏损失值,损失函数的形式为:
其中,分别表示教师模型和学生模型预测的第i个边界框属于第j个类别的可能性,C表示需要预测的总类别数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310426368.3A CN116612450A (zh) | 2023-04-19 | 2023-04-19 | 一种面向点云场景的差异化知识蒸馏3d目标检测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310426368.3A CN116612450A (zh) | 2023-04-19 | 2023-04-19 | 一种面向点云场景的差异化知识蒸馏3d目标检测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116612450A true CN116612450A (zh) | 2023-08-18 |
Family
ID=87675507
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310426368.3A Pending CN116612450A (zh) | 2023-04-19 | 2023-04-19 | 一种面向点云场景的差异化知识蒸馏3d目标检测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116612450A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117542085A (zh) * | 2024-01-10 | 2024-02-09 | 湖南工商大学 | 基于知识蒸馏的园区场景行人检测方法、装置及设备 |
-
2023
- 2023-04-19 CN CN202310426368.3A patent/CN116612450A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117542085A (zh) * | 2024-01-10 | 2024-02-09 | 湖南工商大学 | 基于知识蒸馏的园区场景行人检测方法、装置及设备 |
CN117542085B (zh) * | 2024-01-10 | 2024-05-03 | 湖南工商大学 | 基于知识蒸馏的园区场景行人检测方法、装置及设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109902677B (zh) | 一种基于深度学习的车辆检测方法 | |
CN110084292B (zh) | 基于DenseNet和多尺度特征融合的目标检测方法 | |
CN110796168A (zh) | 一种基于改进YOLOv3的车辆检测方法 | |
CN106845430A (zh) | 基于加速区域卷积神经网络的行人检测与跟踪方法 | |
CN112800964B (zh) | 基于多模块融合的遥感影像目标检测方法及系统 | |
CN107633226B (zh) | 一种人体动作跟踪特征处理方法 | |
CN108830188A (zh) | 基于深度学习的车辆检测方法 | |
US20210326638A1 (en) | Video panoptic segmentation | |
CN113486764B (zh) | 一种基于改进的YOLOv3的坑洼检测方法 | |
CN111738055B (zh) | 多类别文本检测系统和基于该系统的票据表单检测方法 | |
CN113129335B (zh) | 一种基于孪生网络的视觉跟踪算法及多模板更新策略 | |
CN110647802A (zh) | 基于深度学习的遥感影像舰船目标检测方法 | |
WO2019007253A1 (zh) | 图像识别方法、装置及设备、可读介质 | |
Zheng et al. | Improvement of grayscale image 2D maximum entropy threshold segmentation method | |
CN113408584B (zh) | Rgb-d多模态特征融合3d目标检测方法 | |
CN110363165B (zh) | 基于tsk模糊系统的多目标跟踪方法、装置及存储介质 | |
CN111950488A (zh) | 一种改进的Faster-RCNN遥感图像目标检测方法 | |
CN111126278A (zh) | 针对少类别场景的目标检测模型优化与加速的方法 | |
CN113052108A (zh) | 基于深度神经网络的多尺度级联航拍目标检测方法和系统 | |
CN110659601A (zh) | 基于中心点的深度全卷积网络遥感图像密集车辆检测方法 | |
CN116612450A (zh) | 一种面向点云场景的差异化知识蒸馏3d目标检测方法 | |
CN115100741A (zh) | 一种点云行人距离风险检测方法、系统、设备和介质 | |
Yulin et al. | Wreckage target recognition in side-scan sonar images based on an improved faster r-cnn model | |
CN110827327B (zh) | 一种基于融合的长期目标跟踪方法 | |
CN117542082A (zh) | 一种基于YOLOv7的行人检测方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |