CN116342978A - 目标检测网络训练及目标检测方法、装置、电子设备 - Google Patents
目标检测网络训练及目标检测方法、装置、电子设备 Download PDFInfo
- Publication number
- CN116342978A CN116342978A CN202310332431.7A CN202310332431A CN116342978A CN 116342978 A CN116342978 A CN 116342978A CN 202310332431 A CN202310332431 A CN 202310332431A CN 116342978 A CN116342978 A CN 116342978A
- Authority
- CN
- China
- Prior art keywords
- target
- information
- feature map
- training
- network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 314
- 238000001514 detection method Methods 0.000 title claims abstract description 252
- 238000013528 artificial neural network Methods 0.000 claims abstract description 215
- 238000000605 extraction Methods 0.000 claims abstract description 162
- 238000000034 method Methods 0.000 claims abstract description 88
- 239000013598 vector Substances 0.000 claims description 132
- 230000006870 function Effects 0.000 claims description 66
- 238000012545 processing Methods 0.000 claims description 26
- 238000004590 computer program Methods 0.000 claims description 11
- 238000013140 knowledge distillation Methods 0.000 abstract description 28
- 230000008569 process Effects 0.000 description 22
- 238000010606 normalization Methods 0.000 description 12
- 238000010586 diagram Methods 0.000 description 11
- 238000004422 calculation algorithm Methods 0.000 description 8
- 230000010365 information processing Effects 0.000 description 4
- 230000006978 adaptation Effects 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 238000013461 design Methods 0.000 description 2
- 230000004927 fusion Effects 0.000 description 2
- 239000004973 liquid crystal related substance Substances 0.000 description 2
- 238000012544 monitoring process Methods 0.000 description 2
- 239000004065 semiconductor Substances 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- 241000287196 Asthenes Species 0.000 description 1
- 230000004913 activation Effects 0.000 description 1
- 230000003044 adaptive effect Effects 0.000 description 1
- 238000013475 authorization Methods 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000012805 post-processing Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Medical Informatics (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了目标检测网络训练及目标检测方法、装置、电子设备,该训练方法为:将训练图像集中各训练图像输入第一神经网络,获得各训练图像的第一特征图信息,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息;将所述各训练图像输入第二神经网络,获得所述各训练图像的第二特征图信息,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息;根据所述各训练图像的所述第二目标检测信息和所述第一目标检测信息,训练所述第二神经网络,在满足预设条件的情况下,获得目标检测网络,实现一种通用的知识蒸馏。
Description
技术领域
本公开涉及计算机技术领域,具体而言,涉及目标检测网络训练及目标检测方法、装置、电子设备。
背景技术
目标检测是找出图像中感兴趣的目标或物体,是计算机视觉中的重要问题,在实际部署时,由于不同硬件设备的算力限制,可以使用知识蒸馏算法,对神经网络进行压缩,例如,使用大神经网络训练小神经网络,但是相关技术中,知识蒸馏算法通常仅针对同构网络进行,要求大神经网络和小神经网络的检测器相同并且骨架网络结构相同,降低了知识蒸馏算法应用的灵活性和适用场景。
发明内容
本公开实施例至少提供目标检测网络训练及目标检测方法、装置、电子设备。
第一方面,本公开实施例提供了一种目标检测网络训练方法,包括:
获取训练图像集,其中,所述训练图像集中包括各训练图像;
将所述各训练图像输入第一神经网络,获得所述各训练图像的第一特征图信息,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,其中,所述提取模块是基于所述第一神经网络训练而获得的;
将所述各训练图像输入第二神经网络,获得所述各训练图像的第二特征图信息,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息;
根据所述各训练图像的所述第二目标检测信息和所述第一目标检测信息,训练所述第二神经网络,在满足预设条件的情况下,获得目标检测网络。
本公开实施例中,将训练图像集中各训练图像输入第一神经网络,通过提取模块获得各训练图像的第一目标检测信息,并将各训练图像输入第二神经网络,通过提取模块获得各训练图像的第二目标检测信息,进而根据各训练图像的第二目标检测信息和第一目标检测信息,训练第二神经网络,获得目标检测网络,这样,通过提取模块,从第一神经网络中提取相关检测知识,再利用提取到的相关检测知识,指导第二神经网络的训练,完成了知识蒸馏过程,并且该方案通过提取模块来作为中介实现,不需要限制第一神经网络和第二神经网络的网络结构,因此可以适用于任意类型检测器和跨骨架网络的知识蒸馏,提高了应用灵活性,并且还可以提高第二神经网络的目标检测的准确性。
一种可选的实施方式中,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,包括:
根据所述提取模块中目标初始向量和所述第一特征图信息,通过所述提取模块的第一目标交叉注意力层,获得查询向量;
根据所述查询向量和所述第一特征图信息,通过所述提取模块的第二目标交叉注意力层,获得所述第一目标检测信息。
本公开实施例中,将目标初始向量和第一特征图信息,进行交叉注意力处理,获得查询向量,再与第一特征图信息进行交叉注意力处理,可以获得第一神经网络中的相关检测知识,通过交叉注意力处理可以更关注训练图像中目标部分,提高准确性。
一种可选的实施方式中,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息,包括:
根据所述提取模块中目标初始向量和所述第一特征图信息,通过所述提取模块的第一目标交叉注意力层,获得查询向量;
根据所述查询向量和所述第二特征图信息,通过所述提取模块的第二目标交叉注意力层,获得所述第二目标检测信息。
本公开实施例中,利用第一神经网络的输出来获得查询向量,进而基于该查询向量和第二特征图信息进行交叉注意力,获得第二神经网络的第二目标检测信息,可以提高第二神经网络学习第一神经网络中检测知识的性能。
一种可选的实施方式中,所述第一目标检测信息包括第一目标内容信息和第一目标位置信息,所述第二目标检测信息包括第二目标内容信息和第二目标位置信息;
所述提取模块包括第一网络分支和第二网络分支,所述第一网络分支用于提取所述第一目标内容信息或所述第二目标内容信息,所述第一网络分支的网络结构至少包括第一交叉注意力层和第二交叉注意力层;
并且所述第二网络分支用于提取所述第一目标位置信息或第二目标位置信息,所述第二网络分支的网络结构至少包括第三交叉注意力层和第四交叉注意力层。
本公开实施例中,相关检测知识可以包括内容知识和位置知识,并通过提取模块的两个独立的第一网络分支和第二网络分支进行提取,可以从第一神经网络中提取到更丰富的相关检测知识,进而提高对第二神经网络训练的精度。
一种可选的实施方式中,所述满足预设条件,包括:
迭代训练次数达到阈值,或者目标损失函数满足收敛条件;
其中,所述目标损失函数至少包括第一损失函数和第二损失函数的加权和,所述第一损失函数表示所述第一目标内容信息和所述第二目标内容信息之间的损失函数,所述第二损失函数表示所述第一目标位置信息和所述第二目标位置信息之间的损失函数。
本公开实施例中,提取模块包括两个网络分支,分别获得内容知识和位置知识,进而可以基于两个网络分支的损失函数的加权和,确定最终的目标损失函数,提高训练准确性。
一种可选的实施方式中,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,包括:
将所述第一特征图信息输入所述第一网络分支,根据所述第一网络分支中第一初始向量和所述第一特征图信息,通过所述第一交叉注意力层,获得第一查询向量,并根据所述第一查询向量和所述第一特征图信息,通过所述第二交叉注意力层,获得所述第一目标内容信息;
将所述第一特征图信息输入所述第二网络分支,根据所述第二网络分支中第二初始向量和所述第一特征图信息,通过所述第三交叉注意力层,获得第二查询向量,并根据所述第二查询向量和所述第一特征图信息,通过所述第四交叉注意力层,获得所述第一目标位置信息,其中,所述第二初始向量与所述各训练图像中目标候选框的位置相关。
本公开实施例中,通过提取模块的第一网络分支和第二网络分支,分别提取第一神经网络的第一目标内容信息和第二目标位置信息,提高相关检测知识提取准确性和丰富性。
一种可选的实施方式中,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息,包括:
将所述第一特征图信息和所述第二特征图信息输入所述第一网络分支,根据所述第一网络分支中第一初始向量和所述第一特征图信息,通过所述第一交叉注意力层,获得第一查询向量,并根据所述第一查询向量和所述第二特征图信息,通过所述第二交叉注意力层,获得所述第二目标内容信息;
将所述第一特征图信息和所述第二特征图信息输入所述第二网络分支,根据所述第二网络分支中第二初始向量和所述第一特征图信息,通过所述第三交叉注意力层,获得第二查询向量,并根据所述第二查询向量和所述第二特征图信息,通过所述第四交叉注意力层,获得所述第二目标位置信息,其中,所述第二初始向量与所述各训练图像中目标候选框的位置相关。
本公开实施例中,基于第一神经网络的特征输出而获得第一查询向量和第二查询向量,进而分别基于提取模块的第一网络分支和第二网络分支,获得第二神经网络的第二目标内容信息和第二目标位置信息,使得第二神经网络模仿第一神经网络,而达到训练第二神经网络的目的,并且提高第二神经网络训练的精度。
一种可选的实施方式中,所述提取模块的训练方式,包括以下步骤:
获取第二训练图像集,其中,所述第二训练图像集中包括各第二训练图像;
将所述各第二训练图像输入所述第一神经网络,获得所述各第二训练图像的第三特征图信息,并根据所述第三特征图信息,通过所述提取模块,获得所述各第二训练图像的第三目标检测信息;
根据所述第三目标检测信息,训练所述提取模块,在迭代训练次数达到阈值或者提取学习损失函数满足收敛条件的情况下,获得训练完成后的提取模块,其中,所述提取学习损失函数包括预测标签和真实标签之间的损失,所述预测标签是基于所述第三目标检测信息进行目标检测后得到。
本公开实施例中,可以将提取模块附加到第一神经网络的特征输出上,进而基于第一神经网络训练获得提取模块,可以使得提取模块可以提取到与检测任务相关的检测知识。
一种可选的实施方式中,所述第三目标检测信息包括第三目标内容信息和第三目标位置信息,则根据所述第三特征图信息,通过所述提取模块,获得所述各第二训练图像的第三目标检测信息,包括:
将所述第三特征图信息输入所述提取模块的第一网络分支,获得所述各第二训练图像的第三目标内容信息;
将所述第三特征图信息输入所述提取模块的第二网络分支,获得所述各第二训练图像的第三目标位置信息。
本公开实施例中,提取模块可以分为第一网络分支和第二网络分支,并且这两个网络分支分别进行训练,其网络参数独立,可以提高对内容知识和位置知识提取的准确性。
一种可选的实施方式中,所述预测标签包括预测目标位置标签和预测类别标签,并所述预测目标位置标签和所述预测类别标签是基于所述第三目标内容信息或所述第三目标位置信息而获得的,所述真实标签包括真实目标位置标签和真实类别标签;
则所述提取学习损失函数包括所述预测目标位置标签与所述真实目标位置标签之间的损失,以及所述预测类别标签与所述真实目标类别标签之间的损失的加和。
本公开实施例中,基于两个网络分支的提取模块的训练,损失函数可以包括位置损失函数和类别损失函数,提高训练准确性。
第二方面,本公开实施例还提供一种目标检测方法,其特征在于,包括:
获取待检测图像;
利用目标检测网络对所述待检测图像进行目标检测,获得目标类别,所述目标检测网络基于上述第一方面或第一方面中任一种可能的实施方式中所述的目标检测网络训练方法得到。
本公开实施例中,基于提取模块提取第一神经网络的第一目标检测信息,并提取第二神经网络的第二目标检测信息,进而训练获得目标检测网络,基于训练后的目标检测信息,可以对待检测图像进行目标检测,获得从待检测图像中检测到的目标类别,这样,第二神经网络相较与第一神经网络更轻量,并且具有和第一神经网络类似的目标检测功能,因此可以直接将第二神经网络部署到相应的应用场景中,第二神经网络对于硬件设备的要求也较低,提高了应用灵活性,并且还可以保证目标检测的性能。
第三方面,本公开实施例还提供一种目标检测网络训练装置,包括:
第一获取模块,用于获取训练图像集,其中,所述训练图像集中包括各训练图像;
第一处理模块,用于将所述各训练图像输入第一神经网络,获得所述各训练图像的第一特征图信息,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,其中,所述提取模块是基于所述第一神经网络训练而获得的;
第二处理模块,用于将所述各训练图像输入第二神经网络,获得所述各训练图像的第二特征图信息,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息;
第一训练模块,用于根据所述各训练图像的所述第二目标检测信息和所述第一目标检测信息,训练所述第二神经网络,在满足预设条件的情况下,获得目标检测网络。
第四方面,本公开实施例还提供一种目标检测装置,包括:
第二获取模块,用于获取待检测图像;
检测模块,用于利用目标检测网络对所述待检测图像进行目标检测,获得目标类别,所述目标检测网络基于上述第一方面或第一方面中任一种可能的实施方式中所述的目标检测网络训练方法得到。
关于上述目标检测网络或目标检测装置、电子设备、及计算机可读存储介质的效果描述参见上述目标检测网络或目标检测方法的说明,这里不再赘述。
第五方面,本公开可选实现方式还提供一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现上述第一方面,或第二方面中任一种可能的实施方式中的步骤。
第六方面,本公开可选实现方式还提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述第一方面,或第二方面中任一种可能的实施方式中的步骤。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,而非限制本公开的技术方案。
为使本公开的上述目的、特征和优点能更明显易懂,下文特举较佳实施例,并配合所附附图,作详细说明如下。
附图说明
为了更清楚地说明本公开实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,此处的附图被并入说明书中并构成本说明书中的一部分,这些附图示出了符合本公开的实施例,并与说明书一起用于说明本公开的技术方案。应当理解,以下附图仅示出了本公开的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1示出了本公开实施例所提供的一种目标检测网络训练方法的流程图;
图2示出了本公开实施例所提供的提取模块训练过程的逻辑原理图;
图3示出了本公开实施例所提供的第二神经网络训练过程的逻辑原理图;
图4示出了本公开实施例所提供的一种目标检测方法流程图;
图5示出了本公开实施例所提供的一种目标检测网络训练装置的示意图;
图6示出了本公开实施例所提供的一种目标检测装置的示意图;
图7示出了本公开实施例所提供的一种电子设备的示意图。
具体实施方式
为使本公开实施例的目的、技术方案和优点更加清楚,下面将结合本公开实施例中附图,对本公开实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本公开一部分实施例,而不是全部的实施例。通常在此处描述和示出的本公开实施例的组件可以以各种不同的配置来布置和设计。因此,以下对本公开的实施例的详细描述并非旨在限制要求保护的本公开的范围,而是仅仅表示本公开的选定实施例。基于本公开的实施例,本领域技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本公开保护的范围。
为便于对本公开技术方案的理解,首先对本公开实施例中的技术用语加以说明:
知识蒸馏:即对知识进行蒸馏,对于神经网络来说,知识被包含在训练好的模型参数里,知识蒸馏是模型压缩的一种方法,在知识蒸馏中,通过提取性能更好的大神经网络的监督信息,构建一个小神经网络,同时使得小神经网络也具有较好的性能和精度,这里的大神经网络也可以称为教师(Teacher)网络,小神经网络称为学生(Student)网络,即将Teacher网络中学习到的知识迁移到Student网络中。
骨架网络(backbone):骨架网络为神经网络的主干网络,主要用于进行特征提取,生成特征图(feature map)供后面的网络模块使用,例如骨架网络为卷积神经网络,例如在计算机视觉领域,通常要先通过骨架网络对图像进行特征提取,这部分是后续下游任务的基础,后续的下游任务需基于提取出的特征来进行,例如分类任务等。
特征金字塔网络(Feature Pyramid Network,FPN):主要解决了目标检测中多尺度问题,同时利用低层特征高分辨率和高层特征的高语义信息,通过融合这些不同层的特征达到预测的效果。
经研究发现,目标检测是计算机视觉中的重要问题,例如在智能监控、智慧城市、智能家居、自动驾驶和机器人等领域都有着广泛的应用,在实际部署时,由于不同硬件设备的算力限制,可以使用知识蒸馏算法,对神经网络进行压缩,满足不同硬件设备对神经网络的参数、推理速度等的要求,但是相关技术中,知识蒸馏算法通常仅针对同构网络进行,要求大神经网络和小神经网络属于同一检测器并且骨架网络结构相同,降低了知识蒸馏算法应用的灵活性和适用场景。
基于上述研究,本公开提供了一种目标检测网络训练方法,设计一种自适应的提取模块,基于第一神经网络训练获得该提取模块,进而在基于训练图像集训练第二神经网络时,将各训练图像输入第一神经网络,通过提取模块,获得各训练图像的第一目标检测信息,并将各训练图像输入第二神经网络,通过提取模块获得各训练图像的第二目标检测信息,进而根据各训练图像的第二目标检测信息和第一目标检测信息,训练第二神经网络,这样,本公开实施例中,通过提取模块提取第一神经网络的相关目标检测知识,并传递给第二神经网络,而不需要限制第一神经网络和第二神经网络的骨架网络或检测器的网络架构,因此应用更加灵活,可以适用于跨检测器网络架构和跨骨架网络架构的知识蒸馏,并且基于第一神经网络提取到的相关目标检测知识,指导第二神经网络的训练,也可以保证第二神经网络的性能。
针对以上方案所存在的缺陷,均是发明人在经过实践并仔细研究后得出的结果,因此,上述问题的发现过程以及下文中本公开针对上述问题所提出的解决方案,都应该是发明人在本公开过程中对本公开做出的贡献。
应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。
为便于对本实施例进行理解,首先对本公开实施例所公开的一种目标检测网络训练方法进行详细介绍,本公开实施例所提供的目标检测网络训练方法的执行主体一般为具有一定计算能力的电子设备,该电子设备例如包括:终端设备或服务器或其它处理设备,终端设备可以为用户设备(User Equipment,UE)、移动设备、蜂窝电话、无绳电话、个人数字助理(Personal Digital Assistant,PDA)、手持设备、计算设备、车载设备、可穿戴设备等,其中,个人数字助理是一种手持式电子设备,具有电子计算机的某些功能,可以用来管理个人信息,也可以上网浏览,收发电子邮件等,一般不配备键盘,也可以称为掌上电脑。在一些可能的实现方式中,该目标检测网络训练方法可以通过处理器调用存储器中存储的计算机可读指令的方式来实现。
下面以执行主体为服务器为例对本公开实施例提供的目标检测网络训练方法加以说明。
参见图1所示,为本公开实施例提供的目标检测网络训练方法的流程图,所述方法包括:
S101:获取训练图像集,其中,训练图像集中包括各训练图像。
本公开实施例中,对于目标检测方法和目标检测网络的训练方法的应用场景并不进行限制,训练图像集即可以是不同场景下的训练图像,例如,智能交通、机器人、自动驾驶等应用场景。
其中,训练图像集中各训练图像,例如可以为已标注图像,即各训练图像包括已标注目标的位置标签和类别标签,可以用于第二神经网络训练过程中的位置和类别损失函数的计算。又例如,由于本公开实施例中基于第一神经网络来指导训练第二神经网络,也可以使第二神经网络来直接模仿第一神经网络的输出结果,则可以将第二神经网络输出的位置和类别,作为第二神经网络训练过程中的位置标签和类别标签,对此,本公开实施例中并不进行限制。
S102:将各训练图像输入第一神经网络,获得各训练图像的第一特征图信息,根据第一特征图信息,通过提取模块,获得各训练图像的第一目标检测信息,其中,提取模块是基于第一神经网络训练而获得的。
本公开实施例中,提供了一种通用知识蒸馏方法,设计附加的提取模块,该提取模块先基于第一神经网络进行训练,以提取第一神经网络的相关目标检测知识,然后该提取模块可以连接到第二神经网络中,以利用从第一神经网络提取出的相关目标检测知识指导第二神经网络训练。
例如,第一神经网络为Teacher网络,第二神经网络为Student网络,在训练提取模块时,可以固定第一神经网络的其它网络结构的网络参数,而仅训练该提取模块,相比于训练完整的第一神经网络,可以极大提高效率。
为便于理解,这里对提取模块进行简单介绍,提取模块可以理解为相对独立的模块,提取模块的网络结构主要包括两部分,其中第一部分用于获得查询向量,第二部分用于目标检测信息提取,第一部分的输出为第二部分的输入,并且这两部分的网络结构可以是相同或类似的,本公开实施例中,对于提取模块的网络结构并不进行限制。
例如,提取模块的一种可能的网络结构:第一部分的网络结构依次为输入层、第一目标自注意力层、残差连接和归一化层、第一目标交叉注意力层、残差连接和归一化层、前向反馈网络层、残差连接和归一化层,第二部分的网络依次为第二目标自注意力层、残差连接和归一化层、第二目标交叉注意力层、残差连接和归一化层、前向反馈网络层、残差连接和归一化层、输出层。
其中,第一部分的输入层可以为通过第一神经网络获得的第一特征图信息或第二神经网络获得的第二特征图信息,以第一特征图信息为例进行之后处理逻辑说明,然后,通过第一目标自注意力层对目标初始向量进行特征融合和提取,通过残差连接和归一化层对第一目标自注意力层的输出和目标初始向量进行残差连接和归一化处理,进而通过第一目标交叉注意力层,将残差连接和归一化层的输出和第一特征图信息进行交叉特征融合,再通过一个残差连接和归一化层,对第一目标交叉注意力层的输入和输出进行残差连接和归一化处理,然后再通过前向反馈网络层进行特征非线性处理,并再通过一个残差连接和归一化层,对前向反馈网络层的输入和输出进行残差连接和归一化处理,获得查询向量;进而查询向量输入到第二部分,通过第二部分的各网络层进行处理,具体与第一部分中各网络层的处理类似,最终可以输出第一目标检测信息。
具体执行该步骤S102时,本公开提供了一种可能的实施方式:
1)将各训练图像输入第一神经网络,获得各训练图像的第一特征图信息。
具体地,将各训练图像输入第一神经网络,通过第一神经网络的骨架网络,对各训练图像进行特征提取,获得各训练图像的第一特征图信息。
并且,本公开实施例中,该骨架网络可以结合FPN方法,获得多个不同尺度的第一特征图信息,以用于后续提取模块的训练,可以提升特征提取和检测算法对于不同尺度检测目标的鲁棒性。
2)根据提取模块中目标初始向量和第一特征图信息,通过提取模块的第一目标交叉注意力层,获得查询向量。
其中,该目标初始向量可以理解为提取模块中预训练后的参数信息,在提取模块预训练后可以获得。
3)根据查询向量和第一特征图信息,通过提取模块的第二目标交叉注意力层,获得第一目标检测信息。
本公开实施例中,基于第一神经网络训练第二神经网络时,固定已训练的第一神经网络和已训练的提取模块的各种参数,从而基于已训练的提取模块,提取第一神经网络中的第一目标检测信息。
S103:将各训练图像输入第二神经网络,获得各训练图像的第二特征图信息,根据第二特征图信息和第一特征图信息,通过提取模块,获得各训练图像的第二目标检测信息。
执行该步骤S103时,本公开实施例中还提供了一种可能的实施方式:
1)将各训练图像输入第二神经网络,获得各训练图像的第二特征图信息。
具体地包括:将各训练图像输入第二神经网络,通过第二神经网络的骨架网络,对各训练图像进行特征提取,获得各训练图像的第二特征图信息。
同样地,本公开实施例中,第二神经网络中也可以将FPN应用于骨架网络,可以获得不同尺度的第二特征图信息,提高第二神经网络对于不同尺度图像的目标检测效果。
2)根据提取模块中目标初始向量和第一特征图信息,通过提取模块的第一目标交叉注意力层,获得查询向量。
即本公开实施例中,利用第一神经网络输出的第一特征图信息,获得查询向量,以用于第二神经网络的目标检测信息提取,可以使得第二神经网络可以学习到第一神经网络的中间特征信息,并且不是直接传递第一特征图信息,而是结合第一特征图信息和交叉注意力,交叉注意力可以更好地反映骨架网络的激活,在知识蒸馏中可以提高第二神经网络的性能。
3)根据查询向量和第二特征图信息,通过提取模块的第二目标交叉注意力层,获得第二目标检测信息。
S104:根据各训练图像的第二目标检测信息和第一目标检测信息,训练第二神经网络,在满足预设条件的情况下,获得目标检测网络。
其中,一种可能实施例中,目标检测检测网络即为训练后的第二神经网络,对第二神经网络的训练即是训练获得第二神经网络的各网络参数的取值,例如包括第二神经网络的骨架网络的网络参数等,第二神经网络相较于第一神经网络,为更轻量更小的网络模型,并且由于基于第一神经网络来训练第二神经网络,因此第二神经网络可以具有第一神经网络的功能,可以将第二神经网络部署到要求更低的硬件设备或边缘计算设备,满足不同硬件设备的需求。
另外,需要说明的是,本公开实施例中,根据第一目标检测信息和第二目标检测信息,来训练第二神经网络,使得第二神经网络可以从第一神经网络中吸收检测相关知识,模仿第一神经网络输出的第一目标检测信息,第一目标检测信息和第二目标检测信息之间的损失函数也可以称知识蒸馏损失函数,当然,本公开实施例中,在训练第二神经网络过程中,目标损失函数还包括检测算法本身的损失,例如分类损失和目标定位损失。
本公开实施例中,获取训练图像集,将训练图像集中各训练图像输入第一神经网络,获得各训练图像的第一特征图信息,根据第一特征图信息,通过提取模块获得各训练图像的第一目标检测信息,并将各训练图像输入第二神经网络,获得各训练图像的第二特征图信息,根据第二特征图信息和第一特征图信息,通过提取模块,获得各训练图像的第二目标检测信息,进而根据各训练图像的第二目标检测信息和第一目标检测信息,训练第二神经网络,获得目标检测网络,这样,通过设计提取模块,基于第一神经网络训练获得提取模块,进而通过提取模块提取第一神经网络的第一目标检测信息,再使用提取的第一目标检测信息来指导第二神经网络的训练学习,使用该方案,可以完成任意检测器架构和任意骨架网络结构之间的知识蒸馏,降低了为满足检测器知识蒸馏要求对第一神经网络和第二神经网络的适配成本,并且提高了第二神经网络的精度和性能。
基于上述实施例,本公开实施例中,先从第一神经网络中进行相关检测知识的提取,然后再利用提取的知识来指导第二神经网络的训练,一种可能实施例中,提取的第一目标检测信息包括第一目标内容信息和第一目标位置信息,第二目标检测信息包括第二目标内容信息和第二目标位置信息,即从第一神经网络中提取了内容知识和位置知识来指导第二神经网络的训练,提高了知识蒸馏性能,相应地,本公开实施例中,提供了相应的提取和训练方式。
本公开实施例中,提取模块包括第一网络分支和第二网络分支,第一网络分支用于提取第一目标内容信息或第二目标内容信息,第一网络分支的网络结构至少包括第一交叉注意力层和第二交叉注意力层;并且第二网络分支用于提取第一目标位置信息或第二目标位置信息,第二网络分支的网络结构至少包括第三交叉注意力层和第四交叉注意力层。
则针对上述步骤S102中根据第一特征图信息,通过第一神经网络中提取模块,获得各训练图像的第一目标检测信息,提供了一种可能的实施方式:
1)将第一特征图信息输入第一网络分支,根据第一网络分支中第一初始向量和第一特征图信息,通过第一交叉注意力层,获得第一查询向量,并根据第一查询向量和第一特征图信息,通过第二交叉注意力层,获得第一目标内容信息。
其中,Wcontent和θcontent分别为第一交叉注意力层和第二交叉注意力层中的可学习参数,可以通过预训练而确定,第一初始向量E也可以通过预先训练而获得。
2)将第一特征图信息输入第二网络分支,根据第二网络分支中第二初始向量和第一特征图信息,通过第三交叉注意力层,获得第二查询向量,并根据第二查询向量和第一特征图信息,通过第四交叉注意力层,获得第一目标位置信息,其中,第二初始向量与各训练图像中目标候选框的位置相关。
其中,Wpos和θpos分别为第三交叉注意力层和第四交叉注意力层中的可学习参数,可以通过预训练而确定,并且第二初始向量与第一初始向量不同,第二初始向量与位置信息相关,可以通过随机抖动的真实目标候选框的位置,通过全连接层(fully connectedlayers,FC)而生成,例如, 表示包含真实目标候选框和基于真实目标候选框随机抖动而获得,FC层的网络参数也可以通过预先训练而获得。
本公开实施例中,通过两个独立的第一网络分支和第二网络分支,第二网络分支和第二网络分支的网络参数不共享,其主要区别在于第一网络分支中第一初始向量可以由一组可学习参数获得,不包括目标位置信息,第二网络分支中第二初始向量是与目标位置信息相关,但是最终第一网络分支和第二网络分支的输出都可以用于检测任务,获得检测到的目标的候选框位置和类别。
则针对上述步骤S103中根据第二特征图信息和第一特征图信息,通过提取模块,获得各训练图像的第二目标检测信息,具体包括:
1)将第一特征图信息和第二特征图信息输入第一网络分支,根据第一网络分支中第一初始向量和第一特征图信息,通过第一交叉注意力层,获得第一查询向量,并根据第一查询向量和第二特征图信息,通过第二交叉注意力层,获得第二目标内容信息。
本公开实施例中,针对通过第二神经网络提取的第二目标内容信息和第二目标位置信息时,第一查询向量和第二查询向量都是使用第一神经网络输出的第一特征图信息而获得,并且在训练第二神经网络过程中,会固定提取模块的所有网络参数,目的是为了能够更加准确训练第二神经网络,使得第二神经网络更准确地模仿第一神经网络。
2)将第一特征图信息和第二特征图信息输入第二网络分支,根据第二网络分支中第二初始向量和第一特征图信息,通过第三交叉注意力层,获得第二查询向量,并根据第二查询向量和第二特征图信息,通过第四交叉注意力层,获得第二目标位置信息,其中,第二初始向量与各训练图像中目标候选框的位置相关。
进而基于提取模块包括的第一网络分支和第二网络分支,训练第二神经网络时,训练的预设条件包括迭代训练次数达到阈值,或者目标损失函数满足收敛条件,其中,目标损失函数至少包括第一损失函数和第二损失函数的加权和,第一损失函数表示第一目标内容信息和第二目标内容信息之间的损失函数,第二损失函数表示第一目标位置信息和第二目标位置信息之间的损失函数。例如,训练第二神经网络的目标损失函数中第一目标检测信息和第二目标检测信息之间的损失函数,也可以称为知识蒸馏损失函数,该知识蒸馏损失函数记为则/>
“*”表示可以为content或pos,N′为真实目标候选框的数目,N-N′为基于真实目标候选框进行随机抖动而生成的候选框的数目,即N可以表示E或E′查询向量的最终数目,表示正集,/>表示负集,均方误差(Mean Squared Error,MSE)即表示计算均方误差。
这样,本公开实施例中,通过第一查询向量和第二查询向量来提取检测相关知识,即提取内容知识和位置知识,并且在训练第二神经网络过程中,固定提取模块的网络参数,将其直接应用到第二神经网络的特征输出上,去模仿第一神经网络的输出,提高训练效率,实现了一种通用的知识蒸馏方法,不仅可以应用于不同类型检测器,也可以应用于跨骨架网络的知识蒸馏场景,降低了适配成本,提高了灵活性。
下面对本公开实施例中提取模块的训练过程进行说明,本公开实施例中,可以将提取模块附加到第一神经网络的特征输出上,基于第一神经网络进行训练,为便于提取模块的训练效率和准确性,在训练提取模块过程中,可以固定第一神经网络的其它网络参数,具体地,针对提取模块的训练方式,本公开提供了一种可能的实施方式:
1)获取第二训练图像集,其中,第二训练图像集中包括各第二训练图像。
2)将各第二训练图像输入第一神经网络,获得各第二训练图像的第三特征图信息,并根据第三特征图信息,通过提取模块,获得各第二训练图像的第三目标检测信息。
其中,第三目标检测信息包括第三目标内容信息和第三目标位置信息,即在提取模块的训练过程中,也是对提取模块的第一网络分支和第二网络分支分别进行训练,则针对该步骤,本公开提供了一种可能的实施方式:
将第三特征图信息输入提取模块的第一网络分支,获得各第二训练图像的第三目标内容信息;将第三特征图信息输入提取模块的第二网络分支,获得各第二训练图像的第三目标位置信息。
其中,在训练提取模块时,第一网络分支和第二网络分支提取第三目标内容信息和第三目标位置信息的过程,和在上述实施例训练第二神经网络过程中提取第一目标内容信息和第一目标位置信息信息的过程类似,这不过这时提取模块的各网络参数即是需要训练学习的,而不是固定的,例如,提取模块中进行训练学习的网络参数包括E、Wcontent、θcontent、FC、Wpos和θpos等。
3)根据第三目标检测信息,训练提取模块,在迭代训练次数达到阈值或者提取学习损失函数满足收敛条件的情况下,获得训练完成后的提取模块,其中,提取学习损失函数包括预测标签和真实标签之间的损失,预测标签是基于第三目标检测信息进行目标检测后得到。
具体地,预测标签包括预测目标位置标签和预测类别标签,并预测目标位置标签和预测类别标签是基于第三目标内容信息或第三目标位置信息而获得的,真实标签包括真实目标位置标签和真实类别标签;则提取学习损失函数包括预测目标位置标签与真实目标位置标签之间的损失,以及预测类别标签与真实目标类别标签之间的损失的加和。
本公开实施例中,提取第三目标检测信息后,为保证提取的第三目标检测信息与检测任务的相关性,还可以使用二分图匹配算法给第三目标检测信息分配真实目标框,进行检测任务的训练,具体地,根据第三目标内容信息和第三目标位置信息,通过提取模块的前馈神经网络(Feed Forward Networks,FFN)层,可以获得一组固定大小的N个预测,每个预测可以为是否查询到目标,或查询到目标的候选框位置和类别,针对第一网络分支,由于第一网络分支中未预定义位置和查询向量的对应关系,可以进行二分图匹配,将查询向量分为正集和负集/>σcontent()定义为第i个查询向量的索引,则若第i个查询向量属于正集,则/>表示真实目标框的位置,即表示第一网络分支中第i个查询向量所对应的真实目标位置标签,/>即表示第一网络分支中第i个查询向量相应的真实目标类别标签,否则若第i个查询向量属于负集,则没有目标框,并且/>即为背景标签(即无目标);而针对第二网络分支,第二网络分支中查询向量是与位置信息相关的,即查询向量与位置的对应关系是已知的,第二网络分支的/> pos()和第一网络分支中定义类似,相应地/>表示第二网络分支中第i个查询向量所对应的真实目标位置标签,/>表示第二网络分支中第i个查询向量相应的真实目标类别标签,因此,提取模块的提取学习损失函数记为/>可以表示为:
其中,“*”表示可以为content或pos,N为正集和负集的总数目,N′为正集的数目,表示正集,/>表示负集,/>表示预测目标框位置标签和真实目标框位置标签之间损失函数,/>表示预测类别标签和真实类别标签之间的损失函数,/>相应地表示预测目标位置标签或预测类别标签。/>
本公开实施例中,基于第一神经网络对提取模块进行训练,使得提取模块可以从第一神经网络中提取到与检测任务相关的检测知识,并且在训练提取模块过程中,可以固定第一神经网络的其它网络参数,提高训练效率。
需要说明的是,本公开实施例中,对于提取模块的网络结构并不进行限制,例如可以使用可变形转换器(transformer)等,进而基于该提取模块从第一神经网络中提取检测知识,并使用提取到的检测知识指导第二神经网络训练,因此对于第二神经网络和第一神经网络的检测器和骨架网络也不进行限制,例如,可以为可变形端到端目标检测(Detection Transformer,DETR)、残差网络(Residual Neural Network,ResNet)等,本公开实施例中并不进行限制。
基于上述实施例,下面对目标检测网络的训练过程的逻辑原理进行简单说明,本公开实施例目标检测网络的训练过程可以分为两方面:第一方面为提取模块的训练,第二方面为第二神经网络的训练,下面分别进行介绍。
第一方面:提取模块的训练。
参阅图2所示,为本公开实施例中提取模块训练过程的逻辑原理图。如图2所示,获取第二训练图像集,将第二训练图像集中各第二训练图像输入第一神经网络,获得各训练图像的第三特征图信息,在训练提取模块过程中,会固定第二神经网络的其它网络参数,例如,如图2所示,会固定第一神经网络的骨架网络和FPN网络等,可以理解的是此时第二神经网络的其它网络参数也是经过预先训练好的。
如图2所示,将第三特征图信息输入第一网络分支,通过第一初始向量E、第三特征图信息和第一交叉注意力层fcontent(·),获得第一查询向量其中,E可以随机生成,由一组可学习参数得到,然后,再通过第二交叉注意力层fextract1(·),获得第三目标内容信息。
同样地,将第三特征图信息输入第二网络分支,通过第二初始向量E′、第三特征图信息和第三交叉注意力层fpos(·),获得第二查询向量其中,在第二网络分支中,E′根据随机抖动真实目标框获得的/>结合FC而生成,然后,再通过第四交叉注意力层fextract2(·),获得第三目标位置信息。
进而根据第三目标内容信息和第二目标位置信息进行检测任务的训练,获得检测结果,例如是否检测到目标,或者检测到的目标的类别和位置,根据预测输出的预测框位置和真实目标框位置,以及预设类别标签和真实类别标签之间的损失函数,而对提取模块进行训练,即可以获得训练后的提取模块,可以确定提取模块中各网络参数的取值。
第二方面:第二神经网络的训练。
参阅图3所示,为本公开实施例中第二神经网络训练过程的逻辑原理图,如图3所示,将训练图像集中各训练图像分别输入第一神经网络和第二神经网络,获得各训练图像的第一特征图信息,以及各训练图像的第二特征图信息,其中,在第二神经网络的训练过程中,会固定第一神经网络以及提取模块的网络参数,而对第二神经网络的网络参数进行训练。
进而,如图3所示,将第一特征图信息分别输入第一神经网络中提取模块和第二神经网络中提取模块,与图2中提取模块的处理过程相同,根据第一特征图信息分别获得第一查询向量和第二查询向量,然后该第一查询向量和第二查询向量会分别再与第一特征图信息和第二特征图信息进行交叉注意力处理,最终获得第一神经网络的第一目标内容信息和第一目标位置信息,以及第二神经网络的第二目标内容信息和第二目标位置信息,进而基于第一目标内容信息和第二目标内容信息之间损失函数,以及第一目标位置信息和第二目标位置信息之间损失函数,训练第二神经网络,即可获得训练后的第二神经网络,即目标检测网络,其中,在训练第二神经网络过程中,还可以包括检测任务本身的位置和类别损失函数(该部分在图3中未示出)。
这样,本公开实施例中,通过设计提取模块,实现了通用知识蒸馏方法,对于第一神经网络和第二神经网络的检测器和骨架网络的网络结构并不进行限制,可以应用于任意检测器架构和任意骨架网络架构,提高了灵活性,降低了为满足知识蒸馏要求对第一神经网络和第二神经网络的适配成本,并且还可以提高第二神经网络的精度。
相比于相关技术中知识蒸馏方法,相关技术中知识蒸馏方法以第一神经网络输出的第一特征表示和第二神经网络输出的第二特征表示,计算知识蒸馏损失函数以进行学习,这种学习方式就需要第一神经网络和第二神经网络之间特征像素是对齐匹配的,要求知识蒸馏的第一神经网络和第二神经网络是同构网络,不同结构的第一神经网络和第二神经网络时,由于不同的语义差距,第二神经网络从第一神经网络学习就会失败或出错,而本公开实施例中,通过提取模块从第一神经网络进行知识提取,进而再通过提取模块,基于从第一神经网络提取到的知识信息,来指导第二神经网络的训练学习,因此,本公开实施例中不需要限制第一神经网络和第二神经网络的网络结构,可以实现一种通用知识蒸馏方法,适用场景更广。
本公开实施例中,训练完成第二神经网络后,可以将第二神经网络部署在电子设备中,由于第二神经网络较第一神经网络更小更轻量,因此第二神经网络对于电子设备的要求更小,例如可以将第二神经网络部署在一些边缘计算设备,提高任务处理效率。
进而可以基于已训练的第二神经网络,即目标检测网络,进行目标检测,相应地,本公开还提供一种目标检测方法,参阅图4所示,为本公开实施例中目标检测方法流程图,包括:
S401:获取待检测图像。
S402:利用已训练的目标检测网络对待检测图像进行目标检测,获得目标类别。
其中,本公开实施例中目标检测网络,即可以通过上述实施例中目标检测网络训练方法而获得,具体实施方式和上述实施例中相同,这里就不再进行赘述了。
本公开实施例中对于目标检测方法和目标检测网络的训练方法的应用场景并不进行限制,基于不同应用场景所对应的训练图像样本进行训练而得到的目标检测网络,即可以应用到对应的应用场景中以进行目标检测,例如,一种可能实施例,在智能监控领域,可以利用目标检测网络,对待检测监控图像中人体进行目标检测,确定待检测监控图像中是否存在人体,并在确定存在人体情况下,确定人体是否为目标用户。
本公开实施例中,通过提取模块,将其应用于第二神经网络的特征输出上,去模仿第一神经网络的输出,完成了知识蒸馏,使得第二神经网络也可以实现第一神经网络的目标检测功能,进而可以直接应用第二神经网络,进行目标检测,降低了对部署硬件的要求,也提高了目标检测的性能。
本领域技术人员可以理解,在具体实施方式的上述方法中,各步骤的撰写顺序并不意味着严格的执行顺序而对实施过程构成任何限定,各步骤的具体执行顺序应当以其功能和可能的内在逻辑确定。
基于同一发明构思,本公开实施例中还提供了与目标检测网络训练方法对应的目标检测网络训练装置,由于本公开实施例中的装置解决问题的原理与本公开实施例上述目标检测网络训练方法相似,因此装置的实施可以参见方法的实施,重复之处不再赘述。
参照图5所示,为本公开实施例提供的一种目标检测网络训练装置的示意图,该装置包括:
第一获取模块51,用于获取训练图像集,其中,所述训练图像集中包括各训练图像;
第一处理模块52,用于将所述各训练图像输入第一神经网络,获得所述各训练图像的第一特征图信息,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,其中,所述提取模块是基于所述第一神经网络训练而获得的;
第二处理模块53,用于将所述各训练图像输入第二神经网络,获得所述各训练图像的第二特征图信息,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息;
第一训练模块54,用于根据所述各训练图像的所述第二目标检测信息和所述第一目标检测信息,训练所述第二神经网络,在满足预设条件下,获得目标检测网络。
一种可选的实施方式中,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息时,第一处理模块52用于:
根据所述提取模块中目标初始向量和所述第一特征图信息,通过所述提取模块的第一目标交叉注意力层,获得查询向量;
根据所述查询向量和所述第一特征图信息,通过所述提取模块的第二目标交叉注意力层,获得所述第一目标检测信息。
一种可选的实施方式中,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息时,第二处理模块53用于:
根据所述提取模块中目标初始向量和所述第一特征图信息,通过所述提取模块的第一目标交叉注意力层,获得查询向量;
根据所述查询向量和所述第二特征图信息,通过所述提取模块的第二目标交叉注意力层,获得所述第二目标检测信息。
一种可选的实施方式中,所述第一目标检测信息包括第一目标内容信息和第一目标位置信息,所述第二目标检测信息包括第二目标内容信息和第二目标位置信息;
所述提取模块包括第一网络分支和第二网络分支,所述第一网络分支用于提取所述第一目标内容信息或所述第二目标内容信息,所述第一网络分支的网络结构至少包括第一交叉注意力层和第二交叉注意力层;
并且所述第二网络分支用于提取所述第一目标位置信息或第二目标位置信息,所述第二网络分支的网络结构至少包括第三交叉注意力层和第四交叉注意力层。
一种可选的实施方式中,所述满足预设条件,第一训练模块54用于:迭代训练次数达到阈值,或者目标损失函数满足收敛条件;
其中,所述目标损失函数至少包括第一损失函数和第二损失函数的加权和,所述第一损失函数表示所述第一目标内容信息和所述第二目标内容信息之间的损失函数,所述第二损失函数表示所述第一目标位置信息和所述第二目标位置信息之间的损失函数。
一种可选的实施方式中,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息时,第一处理模块52用于:
将所述第一特征图信息输入所述第一网络分支,根据所述第一网络分支中第一初始向量和所述第一特征图信息,通过所述第一交叉注意力层,获得第一查询向量,并根据所述第一查询向量和所述第一特征图信息,通过所述第二交叉注意力层,获得所述第一目标内容信息;
将所述第一特征图信息输入所述第二网络分支,根据所述第二网络分支中第二初始向量和所述第一特征图信息,通过所述第三交叉注意力层,获得第二查询向量,并根据所述第二查询向量和所述第一特征图信息,通过所述第四交叉注意力层,获得所述第一目标位置信息,其中,所述第二初始向量与所述各训练图像中目标候选框的位置相关。
一种可选的实施方式中,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息时,第二处理模块53用于:
将所述第一特征图信息和所述第二特征图信息输入所述第一网络分支,根据所述第一网络分支中第一初始向量和所述第一特征图信息,通过所述第一交叉注意力层,获得第一查询向量,并根据所述第一查询向量和所述第二特征图信息,通过所述第二交叉注意力层,获得所述第二目标内容信息;
将所述第一特征图信息和所述第二特征图信息输入所述第二网络分支,根据所述第二网络分支中第二初始向量和所述第一特征图信息,通过所述第三交叉注意力层,获得第二查询向量,并根据所述第二查询向量和所述第二特征图信息,通过所述第四交叉注意力层,获得所述第二目标位置信息,其中,所述第二初始向量与所述各训练图像中目标候选框的位置相关。
一种可选的实施方式中,还包括第二训练模块55,针对所述提取模块的训练方式,该第二训练模块55执行以下步骤:
获取第二训练图像集,其中,所述第二训练图像集中包括各第二训练图像;
将所述各第二训练图像输入所述第一神经网络,获得所述各第二训练图像的第三特征图信息,并根据所述第三特征图信息,通过所述提取模块,获得所述各第二训练图像的第三目标检测信息;
根据所述第三目标检测信息,训练所述提取模块,在迭代训练次数达到阈值或者提取学习损失函数满足收敛条件的情况下,获得训练完成后的提取模块,其中,所述提取学习损失函数包括预测标签和真实标签之间的损失,所述预测标签是基于所述第三目标检测信息进行目标检测后得到。
一种可选的实施方式中,所述第三目标检测信息包括第三目标内容信息和第三目标位置信息,则根据所述第三特征图信息,通过所述提取模块,获得所述各第二训练图像的第三目标检测信息时,第二训练模块55用于:
将所述第三特征图信息输入所述提取模块的第一网络分支,获得所述各第二训练图像的第三目标内容信息;
将所述第三特征图信息输入所述提取模块的第二网络分支,获得所述各第二训练图像的第三目标位置信息。
一种可选的实施方式中,所述预测标签包括预测目标位置标签和预测类别标签,并所述预测目标位置标签和所述预测类别标签是基于所述第三目标内容信息或所述第三目标位置信息而获得的,所述真实标签包括真实目标位置标签和真实类别标签;
则所述提取学习损失函数包括所述预测目标位置标签与所述真实目标位置标签之间的损失,以及所述预测类别标签与所述真实目标类别标签之间的损失的加和。
基于同一发明构思,本公开实施例中还提供了与目标检测方法对应的目标检测装置,由于本公开实施例中的装置解决问题的原理与本公开实施例上述目标检测方法相似,因此装置的实施可以参见方法的实施,重复之处不再赘述。参照图6所示,为本公开实施例提供的一种目标检测装置的示意图,所述装置包括:
第二获取模块61,用于获取待检测图像;
检测模块62,用于利用已训练的目标检测网络,对所述待检测图像进行目标检测,获得目标类别。
其中,该已训练的目标检测网络,即是基于本公开实施例中的目标检测网络训练方法而训练生成的。
关于装置中的各模块的处理流程、以及各模块之间的交互流程的描述可以参照上述方法实施例中的相关说明,这里不再详述。
本公开实施例还提供了一种电子设备,如图7所示,为本公开实施例提供的电子设备结构示意图,包括:
处理器71和存储器72;所述存储器72存储有处理器71可执行的机器可读指令,处理器71用于执行存储器72中存储的机器可读指令,所述机器可读指令被处理器71执行时,处理器71执行下述步骤:
获取训练图像集,其中,所述训练图像集中包括各训练图像;
将所述各训练图像输入第一神经网络,获得所述各训练图像的第一特征图信息,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,其中,所述提取模块是基于所述第一神经网络训练而获得的;
将所述各训练图像输入第二神经网络,获得所述各训练图像的第二特征图信息,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息;
根据所述各训练图像的所述第二目标检测信息和所述第一目标检测信息,训练所述第二神经网络,在满足预设条件的情况下,获得目标检测网络。
或者,该处理器71还可以执行下述步骤:
获取待检测图像;
利用已训练的目标检测网络,对所述待检测图像进行目标检测,获得目标类别。
上述存储器72包括内存721和外部存储器722;这里的内存721也称内存储器,用于暂时存放处理器71中的运算数据,以及与硬盘等外部存储器722交换的数据,处理器71通过内存721与外部存储器722进行数据交换。
上述指令的具体执行过程可以参考本公开实施例中所述的目标检测网络训练方法或目标检测方法的步骤,此处不再赘述。
本公开实施例还提供一种计算机可读存储介质,该计算机可读存储介质上存储有计算机程序,该计算机程序被处理器运行时执行上述方法实施例中所述的目标检测网络训练方法或目标检测方法的步骤。其中,该存储介质可以是易失性或非易失的计算机可读取存储介质。
本公开实施例还提供一种计算机程序产品,该计算机程序产品承载有程序代码,所述程序代码包括的指令可用于执行上述方法实施例中所述的目标检测网络训练方法或目标检测方法的步骤,具体可参见上述方法实施例,在此不再赘述。
其中,上述计算机程序产品可以具体通过硬件、软件或其结合的方式实现。在一个可选实施例中,所述计算机程序产品具体体现为计算机存储介质,在另一个可选实施例中,计算机程序产品具体体现为软件产品,例如软件开发包(Software Development Kit,SDK)等等。
若本公开技术方案涉及个人信息,应用本公开技术方案的产品在处理个人信息前,已明确告知个人信息处理规则,并取得个人自主同意。若本公开技术方案涉及敏感个人信息,应用本公开技术方案的产品在处理敏感个人信息前,已取得个人单独同意,并且同时满足“明示同意”的要求。例如,在摄像头等个人信息采集装置处,设置明确显著的标识告知已进入个人信息采集范围,将会对个人信息进行采集,若个人自愿进入采集范围即视为同意对其个人信息进行采集;或者在个人信息处理的装置上,利用明显的标识/信息告知个人信息处理规则的情况下,通过弹窗信息或请个人自行上传其个人信息等方式获得个人授权;其中,个人信息处理规则可包括个人信息处理者、个人信息处理目的、处理方式以及处理的个人信息种类等信息。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统和装置的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。在本公开所提供的几个实施例中,应该理解到,所揭露的系统、装置和方法,可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,又例如,多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些通信接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本公开各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。
所述功能如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个处理器可执行的非易失的计算机可读取存储介质中。基于这样的理解,本公开的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台电子设备(可以是个人计算机,服务器,或者网络设备等)执行本公开各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-OnlyMemory,ROM)、随机存取存储器(Random Access Memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
最后应说明的是:以上所述实施例,仅为本公开的具体实施方式,用以说明本公开的技术方案,而非对其限制,本公开的保护范围并不局限于此,尽管参照前述实施例对本公开进行了详细的说明,本领域的普通技术人员应当理解:任何熟悉本技术领域的技术人员在本公开揭露的技术范围内,其依然可以对前述实施例所记载的技术方案进行修改或可轻易想到变化,或者对其中部分技术特征进行等同替换;而这些修改、变化或者替换,并不使相应技术方案的本质脱离本公开实施例技术方案的精神和范围,都应涵盖在本公开的保护范围之内。因此,本公开的保护范围应所述以权利要求的保护范围为准。
Claims (15)
1.一种目标检测网络训练方法,其特征在于,包括:
获取训练图像集,其中,所述训练图像集中包括各训练图像;
将所述各训练图像输入第一神经网络,获得所述各训练图像的第一特征图信息,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,其中,所述提取模块是基于所述第一神经网络训练而获得的;
将所述各训练图像输入第二神经网络,获得所述各训练图像的第二特征图信息,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息;
根据所述各训练图像的所述第二目标检测信息和所述第一目标检测信息,训练所述第二神经网络,在满足预设条件的情况下,获得目标检测网络。
2.根据权利要求1所述的方法,其特征在于,所述根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,包括:
根据所述提取模块中目标初始向量和所述第一特征图信息,通过所述提取模块的第一目标交叉注意力层,获得查询向量;
根据所述查询向量和所述第一特征图信息,通过所述提取模块的第二目标交叉注意力层,获得所述第一目标检测信息。
3.根据权利要求1所述的方法,其特征在于,所述根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息,包括:
根据所述提取模块中目标初始向量和所述第一特征图信息,通过所述提取模块的第一目标交叉注意力层,获得查询向量;
根据所述查询向量和所述第二特征图信息,通过所述提取模块的第二目标交叉注意力层,获得所述第二目标检测信息。
4.根据权利要求1-3任一项所述的方法,其特征在于,所述第一目标检测信息包括第一目标内容信息和第一目标位置信息,所述第二目标检测信息包括第二目标内容信息和第二目标位置信息;
所述提取模块包括第一网络分支和第二网络分支,所述第一网络分支用于提取所述第一目标内容信息或所述第二目标内容信息,所述第一网络分支的网络结构至少包括第一交叉注意力层和第二交叉注意力层;
并且所述第二网络分支用于提取所述第一目标位置信息或第二目标位置信息,所述第二网络分支的网络结构至少包括第三交叉注意力层和第四交叉注意力层。
5.根据权利要求4所述的方法,其特征在于,所述满足预设条件,包括:
迭代训练次数达到阈值,或者目标损失函数满足收敛条件;
其中,所述目标损失函数至少包括第一损失函数和第二损失函数的加权和,所述第一损失函数表示所述第一目标内容信息和所述第二目标内容信息之间的损失函数,所述第二损失函数表示所述第一目标位置信息和所述第二目标位置信息之间的损失函数。
6.根据权利要求4所述的方法,其特征在于,所述根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,包括:
将所述第一特征图信息输入所述第一网络分支,根据所述第一网络分支中第一初始向量和所述第一特征图信息,通过所述第一交叉注意力层,获得第一查询向量,并根据所述第一查询向量和所述第一特征图信息,通过所述第二交叉注意力层,获得所述第一目标内容信息;
将所述第一特征图信息输入所述第二网络分支,根据所述第二网络分支中第二初始向量和所述第一特征图信息,通过所述第三交叉注意力层,获得第二查询向量,并根据所述第二查询向量和所述第一特征图信息,通过所述第四交叉注意力层,获得所述第一目标位置信息,其中,所述第二初始向量与所述各训练图像中目标候选框的位置相关。
7.根据权利要求4所述的方法,其特征在于,所述根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息,包括:
将所述第一特征图信息和所述第二特征图信息输入所述第一网络分支,根据所述第一网络分支中第一初始向量和所述第一特征图信息,通过所述第一交叉注意力层,获得第一查询向量,并根据所述第一查询向量和所述第二特征图信息,通过所述第二交叉注意力层,获得所述第二目标内容信息;
将所述第一特征图信息和所述第二特征图信息输入所述第二网络分支,根据所述第二网络分支中第二初始向量和所述第一特征图信息,通过所述第三交叉注意力层,获得第二查询向量,并根据所述第二查询向量和所述第二特征图信息,通过所述第四交叉注意力层,获得所述第二目标位置信息,其中,所述第二初始向量与所述各训练图像中目标候选框的位置相关。
8.根据权利要求1-7任一项所述的方法,其特征在于,所述提取模块的训练方式,包括以下步骤:
获取第二训练图像集,其中,所述第二训练图像集中包括各第二训练图像;
将所述各第二训练图像输入所述第一神经网络,获得所述各第二训练图像的第三特征图信息,并根据所述第三特征图信息,通过所述提取模块,获得所述各第二训练图像的第三目标检测信息;
根据所述第三目标检测信息,训练所述提取模块,在迭代训练次数达到阈值或者提取学习损失函数满足收敛条件的情况下,获得训练后的提取模块,其中,所述提取学习损失函数包括预测标签和真实标签之间的损失,所述预测标签是基于所述第三目标检测信息进行目标检测后得到。
9.根据权利要求8所述的方法,其特征在于,所述第三目标检测信息包括第三目标内容信息和第三目标位置信息,则根据所述第三特征图信息,通过所述提取模块,获得所述各第二训练图像的第三目标检测信息,包括:
将所述第三特征图信息输入所述提取模块的第一网络分支,获得所述各第二训练图像的第三目标内容信息;
将所述第三特征图信息输入所述提取模块的第二网络分支,获得所述各第二训练图像的第三目标位置信息。
10.根据权利要求9所述的方法,其特征在于,所述预测标签包括预测目标位置标签和预测类别标签,并所述预测目标位置标签和所述预测类别标签是基于所述第三目标内容信息或所述第三目标位置信息而获得的,所述真实标签包括真实目标位置标签和真实类别标签;
则所述提取学习损失函数包括所述预测目标位置标签与所述真实目标位置标签之间的损失,以及所述预测类别标签与所述真实目标类别标签之间的损失的加和。
11.一种目标检测方法,其特征在于,包括:
获取待检测图像;
利用目标检测网络对所述待检测图像进行目标检测,获得目标类别,所述目标检测网络基于权利要求1至10任一项所述的目标检测网络训练方法得到。
12.一种目标检测网络训练装置,其特征在于,包括:
第一获取模块,用于获取训练图像集,其中,所述训练图像集中包括各训练图像;
第一处理模块,用于将所述各训练图像输入第一神经网络,获得所述各训练图像的第一特征图信息,根据所述第一特征图信息,通过提取模块,获得所述各训练图像的第一目标检测信息,其中,所述提取模块是基于所述第一神经网络训练而获得的;
第二处理模块,用于将所述各训练图像输入第二神经网络,获得所述各训练图像的第二特征图信息,根据所述第二特征图信息和所述第一特征图信息,通过所述提取模块,获得所述各训练图像的第二目标检测信息;
第一训练模块,用于根据所述各训练图像的所述第二目标检测信息和所述第一目标检测信息,训练所述第二神经网络,在满足预设条件的情况下,获得目标检测网络。
13.一种目标检测装置,其特征在于,包括:
第二获取模块,用于获取待检测图像;
检测模块,用于利用目标检测网络对所述待检测图像进行目标检测,获得目标类别,所述目标检测网络基于权利要求1至10任一项所述的目标检测网络训练方法得到。
14.一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现权利要求1-10或11任一项所述方法的步骤。
15.一种计算机可读存储介质,其上存储有计算机程序,其特征在于:所述计算机程序被处理器执行时实现权利要求1-10或11任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310332431.7A CN116342978A (zh) | 2023-03-30 | 2023-03-30 | 目标检测网络训练及目标检测方法、装置、电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310332431.7A CN116342978A (zh) | 2023-03-30 | 2023-03-30 | 目标检测网络训练及目标检测方法、装置、电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116342978A true CN116342978A (zh) | 2023-06-27 |
Family
ID=86894661
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310332431.7A Pending CN116342978A (zh) | 2023-03-30 | 2023-03-30 | 目标检测网络训练及目标检测方法、装置、电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116342978A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117911877A (zh) * | 2024-03-20 | 2024-04-19 | 岳正检测认证技术有限公司 | 一种基于机器视觉的建筑通信光缆故障识别方法 |
-
2023
- 2023-03-30 CN CN202310332431.7A patent/CN116342978A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117911877A (zh) * | 2024-03-20 | 2024-04-19 | 岳正检测认证技术有限公司 | 一种基于机器视觉的建筑通信光缆故障识别方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Chen et al. | An edge traffic flow detection scheme based on deep learning in an intelligent transportation system | |
CN109766840B (zh) | 人脸表情识别方法、装置、终端及存储介质 | |
CN110909630B (zh) | 一种异常游戏视频检测方法和装置 | |
CN111382868A (zh) | 神经网络结构搜索方法和神经网络结构搜索装置 | |
CN109918684A (zh) | 模型训练方法、翻译方法、相关装置、设备及存储介质 | |
CN113516227B (zh) | 一种基于联邦学习的神经网络训练方法及设备 | |
CN111782840A (zh) | 图像问答方法、装置、计算机设备和介质 | |
CN115512005A (zh) | 一种数据处理方法及其装置 | |
CN113628059A (zh) | 一种基于多层图注意力网络的关联用户识别方法及装置 | |
CN116342978A (zh) | 目标检测网络训练及目标检测方法、装置、电子设备 | |
CN114580794B (zh) | 数据处理方法、装置、程序产品、计算机设备和介质 | |
CN113536970A (zh) | 一种视频分类模型的训练方法及相关装置 | |
Sun et al. | Two-stage deep regression enhanced depth estimation from a single RGB image | |
Zhong | A convolutional neural network based online teaching method using edge-cloud computing platform | |
CN113869366A (zh) | 模型训练方法、亲属关系分类方法、检索方法及相关装置 | |
CN112801138A (zh) | 基于人体拓扑结构对齐的多人姿态估计方法 | |
Modaghegh et al. | Learning of relevance feedback using a novel kernel based neural network | |
CN117540024B (zh) | 一种分类模型的训练方法、装置、电子设备和存储介质 | |
CN117556150B (zh) | 多目标预测方法、装置、设备及存储介质 | |
CN116245809A (zh) | 设备检测模型构建方法、装置、计算机设备和存储介质 | |
CN116628253A (zh) | 一种搜索请求推荐方法、装置、电子设备和存储介质 | |
CN115497059A (zh) | 一种基于注意力网络的车辆行为识别方法 | |
CN117541971A (zh) | 一种目标检测方法、装置、存储介质和电子设备 | |
CN116758365A (zh) | 视频处理方法、机器学习模型训练方法及相关装置、设备 | |
CN116977691A (zh) | 角色识别模型的训练方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |