CN109977913B - 一种目标检测网络训练方法、装置及电子设备 - Google Patents
一种目标检测网络训练方法、装置及电子设备 Download PDFInfo
- Publication number
- CN109977913B CN109977913B CN201910277616.6A CN201910277616A CN109977913B CN 109977913 B CN109977913 B CN 109977913B CN 201910277616 A CN201910277616 A CN 201910277616A CN 109977913 B CN109977913 B CN 109977913B
- Authority
- CN
- China
- Prior art keywords
- output
- target
- network
- loss
- layer
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/20—Image preprocessing
- G06V10/25—Determination of region of interest [ROI] or a volume of interest [VOI]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/40—Scenes; Scene-specific elements in video content
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Biophysics (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明实施例提供了一种目标检测网络训练方法、装置及电子设备,其中,该方法包括:获取携带目标的样本和未携带目标的样本,将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失,其中,第二SSD网络与第一SSD网络具有相同的网络参数;对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。可以降低第一SSD网络从非目标位置检测出目标的可能性,从而提高训练后的第一SSD网络检测的准确度。
Description
技术领域
本发明涉及图像识别技术领域,特别是涉及一种目标检测网络训练方法、装置及电子设备。
背景技术
Logo(LOGO type,徽标)是企业综合信息传递的媒介,通过形象的Logo可以让消费者记住公司主体和品牌文化,起到对Logo拥有公司的识别和推广的作用,可以是文字、图标以及二者的混合,Logo也常常被加入到公司的产品中,以表明生产该产品的公司的身份,例如,在视频媒体领域中,许多公司会把公司的Logo嵌入到其创作或播放的图片或者视频中,例如,视频网站在播放视频时,会将该视频网站的Logo嵌入到播放的视频中,以表明该视频网站对该视频的播放的权利。
视频网站在将拥有的Logo嵌入到在该视频网站播放的视频中时,可以对该视频进行Logo检测,以避免该视频中有其他公司的Logo,还可以避免在该视频中重复添加该视频网站的Logo。
目前,常用的Logo检测方法主要是基于深度学习的目标检测方法,该目标检测方法可以检测Logo在图片或视频帧中的位置,并同时可以检测出Logo的类别。在采用该基于深度学习的目标检测方法进行Logo检测时,首先需要对该基于深度学习的目标检测方法中的深度学习网络进行训练,然后采用训练好的深度学习网络对待检测图片或视频帧进行检测。
在对该深度学习网络时,需要提供标注有Logo的样本集,该样本集的每个样本中必须至少含有一个Logo,并且,每个样本中必须要标注出Logo的位置和类型。这样,在对该深度学习网络进行训练时,该深度学习网络可以将把Logo以外的区域作为背景,从而可以学会区分Logo和背景。
然而,发明人在实现本发明的过程中发现,现有技术至少存在如下问题:
若仅采用标注有Logo的样本集对深度学习网络进行训练,并采用训练好的深度学习网络进行Logo检测时,容易将待检测图片或视频帧中,与该样本集中样本的Logo相似的图形检测为Logo,从而造成误检。
发明内容
本发明实施例的目的在于提供一种目标检测网络训练方法、装置及电子设备,以实现提高训练后的神经网络对目标进行检测的准确度。具体技术方案如下:
在本发明实施例的一个方面,本发明实施例提供了一种目标检测网络训练方法,该方法包括:
获取携带目标的样本和未携带目标的样本,其中,携带目标的样本中标注有目标的类别和目标的位置,未携带目标的样本为除标注有目标的类别的样本外的样本;
将携带目标的样本输入至第一SSD(Single Shot MultiBox Detector,单次多框检测器)网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失,其中,第二SSD网络与第一SSD网络具有相同的网络参数;
对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
可选的,在获取携带目标的样本和未携带目标的样本之前,该目标检测网络训练方法,还包括:
获取多个携带目标的样本;
采用多个携带目标的样本对预先建立的第一SSD网络和预先建立的第二SSD进行训练,得到第一SSD网络和第二SSD网络;
可选的,获取携带目标的样本和未携带目标的样本,包括:
获取未携带目标的图片集,并采用第一SSD网络对未携带目标的图片集进行检测,得到第一误检测图片,其中,第一误检测图片中未携带目标;
将获取的多个携带目标的样本作为携带目标的样本,将第一误检测图片作为未携带目标的样本。
可选的,第二SSD网络包括:基础特征层、第一卷积层、第二卷积层、第三卷积层、第四卷积层、池化层以及第二输出层,第二输出层包括:第一输出子层、第二输出子层、第三输出子层、第四输出子层、第五输出子层以及第六输出子层;
可选的,将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失,包括:
将未携带目标的样本输入至第二SSD网络的基础特征层,得到基础特征层输出的基础特征图;
将基础特征图输入至第二SSD网络的第一卷积层和第一输出子层,得到第一卷积层输出的卷积后的第一特征图和第一输出子层输出的第一类别损失;
将第一特征图输入至第二SSD网络的第二卷积层和第二输出子层,得到第二卷积层输出的卷积后的第二特征图和第二输出子层输出的第二类别损失;
将第二特征图输入至第二SSD网络的第三卷积层和第三输出子层,得到第三卷积层输出的卷积后的第三特征图和第三输出子层输出的第三类别损失;
将第三特征图输入至第二SSD网络的第四卷积层和第四输出子层,得到第四卷积层输出的卷积后的第四特征图和第四输出子层输出的第四类别损失;
将第四特征图输入至第二SSD网络的池化层和第五输出子层,得到池化层输出的池化后的特征图和第五输出子层输出的第五类别损失;
将池化后的特征图输入至第六输出子层,得到第六输出子层输出的第六类别损失;
将第一类别损失、第二类别损失、第三类别损失、第四类别损失、第五类别损失以及第六类别损失,作为第二输出层输出的类别损失。
可选的,在对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数之前,本发明实施例的一种目标检测网络训练方法还包括:
对第二输出层输出的类别损失按照从大到小的顺序进行排序,得到排序后的类别损失;
获取预设的类别损失阈值,并在排序后的类别损失中,选择大于或等于预设的类别损失阈值的类别损失;
对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数,包括:
对选择的大于或等于预设的类别损失阈值的类别损失以及第一输出层输出的类别损失和位置损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
可选的,在基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数之后,本发明实施例的一种目标检测网络训练方法还包括:
获取多个新的携带目标的样本和多个新的未携带目标的样本;
采用多个新的携带目标的样本和多个新的未携带目标的样本,对更新参数后的第一SSD网络和更新参数后的第二SSD网络进行训练,得到训练完成的第一SSD网络和训练完成的第二SSD网络。
可选的,本发明实施例的一种目标检测网络训练方法还包括:
采用训练完成的第一SSD网络,对除第一误检测图片外的未携带目标的图片集进行检测,得到第二误检测图片,其中,第二误检测图片中未携带目标;
将第二误检测图片、第一误检测图片以及多个携带目标的样本作为训练样本,对训练完成的第一SSD网络和训练完成的第二SSD网络进行训练。
在本发明实施例的又一方面,本发明实施例还提供了一种目标检测网络训练装置,该装置包括:
第一样本获取模块,用于获取携带目标的样本和未携带目标的样本,其中,携带目标的样本中标注有目标的类别和目标的位置,未携带目标的样本为除标注有目标的类别的样本外的样本;
样本输入模块,用于将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失,其中,第二SSD网络与第一SSD网络具有相同的网络参数;
更新模块,用于对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
可选的,目标检测网络训练装置还包括:
携带目标样本获取模块,用于获取多个携带目标的样本;
第一训练模块,用于采用多个携带目标的样本对预先建立的第一SSD网络和预先建立的第二SSD进行训练,得到第一SSD网络和第二SSD网络;
第一样本获取模块,包括:
检测子模块,用于获取未携带目标的图片集,并采用第一SSD网络对未携带目标的图片集进行检测,得到第一误检测图片,其中,第一误检测图片中未携带目标;
样本获取子模块,将获取的多个携带目标的样本作为携带目标的样本,将第一误检测图片作为未携带目标的样本。
可选的,第二SSD网络包括:基础特征层、第一卷积层、第二卷积层、第三卷积层、第四卷积层、池化层以及第二输出层,第二输出层包括:第一输出子层、第二输出子层、第三输出子层、第四输出子层、第五输出子层以及第六输出子层;
可选的,样本输入模块,包括:
第一输入子模块,用于将未携带目标的样本输入至第二SSD网络的基础特征层,得到基础特征层输出的基础特征图;
第二输入子模块,用于将基础特征图输入至第二SSD网络的第一卷积层和第一输出子层,得到第一卷积层输出的卷积后的第一特征图和第一输出子层输出的第一类别损失;
第三输入子模块,用于将第一特征图输入至第二SSD网络的第二卷积层和第二输出子层,得到第二卷积层输出的卷积后的第二特征图和第二输出子层输出的第二类别损失;
第四输入子模块,用于将第二特征图输入至第二SSD网络的第三卷积层和第三输出子层,得到第三卷积层输出的卷积后的第三特征图和第三输出子层输出的第三类别损失;
第五输入子模块,用于将第三特征图输入至第二SSD网络的第四卷积层和第四输出子层,得到第四卷积层输出的卷积后的第四特征图和第四输出子层输出的第四类别损失;
第六输入子模块,用于将第四特征图输入至第二SSD网络的池化层和第五输出子层,得到池化层输出的池化后的特征图和第五输出子层输出的第五类别损失;
第七输入子模块,用于将池化后的特征图输入至第六输出子层,得到第六输出子层输出的第六类别损失;
转换子模块,用于将第一类别损失、第二类别损失、第三类别损失、第四类别损失、第五类别损失以及第六类别损失,作为第二输出层输出的类别损失。
可选的,本发明实施例的一种目标检测网络训练装置还包括:
类别损失排序模块,用于对第二输出层输出的类别损失按照从大到小的顺序进行排序,得到排序后的类别损失;
类别损失选择模块,用于获取预设的类别损失阈值,并在排序后的类别损失中,选择大于或等于预设的类别损失阈值的类别损失;
更新模块,具体用于:
对选择的大于或等于预设的类别损失阈值的类别损失以及第一输出层输出的类别损失和位置损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
可选的,本发明实施例的一种目标检测网络训练装置还包括:
第二样本获取模块,用于获取多个新的携带目标的样本和多个新的未携带目标的样本;
第二训练模块,用于采用多个新的携带目标的样本和多个新的未携带目标的样本,对更新参数后的第一SSD网络和更新参数后的第二SSD网络进行训练,得到训练完成的第一SSD网络和训练完成的第二SSD网络。
可选的,本发明实施例的一种目标检测网络训练装置还包括:
检测模块,用于采用训练完成的第一SSD网络,对除第一误检测图片外的未携带目标的图片集进行检测,得到第二误检测图片,其中,第二误检测图片中未携带目标;
第三训练模块,用于将第二误检测图片、第一误检测图片以及多个携带目标的样本作为训练样本,对训练完成的第一SSD网络和训练完成的第二SSD网络进行训练。
在本发明实施例的又一方面,本发明实施例还提供了一种电子设备,该电子设备包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现上述任一所述的目标检测网络训练方法。
在本发明实施例的又一方面,本发明实施例还提供了一种计算机可读存储介质,计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述任一所述的目标检测网络训练方法。
在本发明实施例的又一方面,本发明实施例还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述任一所述的目标检测网络训练方法。
本发明实施例提供的一种目标检测网络训练方法、装置及电子设备,可以将标注有目标类别和目标位置的样本输入至第一SSD网络,将未携带目标的样本输入至与第一SSD网络具有相同网络参数的第二SSD网络,这样,可以使得该第一SSD网络输出与该标注有目标类别和目标位置的样本对应的类别损失和位置损失,使得第二SSD网络输出与该未携带目标的样本对应的类别损失,由于该未携带目标的样本中未携带目标,当该类别损失较大时,则说明该未携带目标的样本中存在与该携带目标的样本中的目标相似的对象,使得该第二SSD网络将该未携带目标的样本中的对象预测为该携带目标的样本中的目标。通过对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。可以使得在对该第一SSD网络的网络参数和第二SSD网络的网络参数更新后,降低该第二SSD网络将未携带目标的样本中的对象预测为目标的可能性,由于第一SSD网络和第二SSD网络共享参数,因此,可以使得该第一SSD网络在对待检测的图片进行检测时,能够更好的识别出该待检测图片中与目标相似的对象,提高该第一SSD网络检测的准确度,减少误检。当然,实施本发明的任一产品或方法必不一定需要同时达到以上所述的所有优点。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍。
图1为本发明实施例一种目标检测网络训练方法第一种实施方式的流程图;
图2a为图1所示的目标检测网络中第一SSD网络示例性的结构示意图;
图2b为图1所示的目标检测网络中第二SSD网络示例性的结构示意图;
图3为本发明实施例一种目标检测网络训练方法第二种实施方式的流程图;
图4为本发明实施例一种目标检测网络训练方法第三种实施方式的流程图;
图5为本发明实施例一种目标检测网络训练方法第四种实施方式的流程图;
图6为本发明实施例的一种目标检测网络训练装置的结构示意图;
图7为本发明实施例的一种电子设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行描述。
为了解决现有技术存在的问题,本发明实施例提供了一种目标检测网络训练方法、装置及电子设备,以实现提高训练后的神经网络对目标进行检测的准确度。
下面,首先对本发明实施例的一种目标检测网络训练方法进行介绍,如图1所示,为本发明实施例一种目标检测网络训练方法第一种实施方式的流程图,该方法可以包括:
S110,获取携带目标的样本和未携带目标的样本。
其中,携带目标的样本中标注有目标的类别和目标的位置,未携带目标的样本为除标注有目标的类别的样本外的样本。
在对神经网络进行训练时,可以首先为该待训练的神经网络设置样本,这样,在对神经网络进行训练时,该神经网络可以获取到训练时使用的样本。
在一些示例中,为了为本发明实施例的目标检测网络训练方法中使用的神经网络进行训练,本发明实施例为该目标检测网络设置了两种样本,该两种样本可以是携带目标的样本和未携带目标的样本。
在一些示例中,该目标可以是该目标检测网络训练完成时,能够检测出来的对象,例如,该目标可以是Logo,也可以是人物、动物等。这样,在采用携带有目标的样本训练完成后,该目标检测网络可以检测出各个待检测图片中是否含有该目标。
在一些示例中,为了提高该目标检测网络的检测准确度,本发明实施例采用了辅助网络对该目标检测网络进行辅助训练。因此,该未携带目标的样本可以是为该辅助网络设置的样本。
当将待训练的目标检测网络和辅助网络设置在一个电子设备上时,可以在该电子设备上预先存储携带目标的样本和未携带目标的样本,这样,该电子设备可以获取到该携带有目标的样本和未携带目标的样本。
在一些示例中,还可以在于该电子设备通信连接的另一电子设备上预先存储携带目标的样本和未携带目标的样本,该电子设备可以从该另一电子设备上获取该预先存储的携带目标的样本和未携带目标的样本。
在一些示例中,该未携带目标的样本可以是预先设置的多个样本中,除标注有目标的类别的样本外的样本,该预先设置的多个样本中的每个样本中,都至少标注有该样本携带的对象的类别。例如,假设一电子设备中设置有携带Logo的样本、携带有建筑的样本、携带有人物的样本以及携带有动物的样本,又假设目标为Logo,则该未携带目标的样本可以包括:携带有建筑的样本、携带有人物的样本以及携带有动物的样本。
S120,将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失。
其中,第二SSD网络与第一SSD网络具有相同的网络参数。
在一些示例中,在获取到携带目标的样本和未携带目标的样本后,为了对第一SSD网络进行训练,可以将该携带目标的样本输入至第一SSD网络,将未携带目标的样本输入至第二SSD网络,这样,第一SSD网络可以对该携带目标的样本进行预测,得到该携带目标的样本中目标的类别信息和位置信息,也即对该携带目标的样本进行预测的预测值,以计算该第一SSD网络对携带目标的样本进行预测的预测值与真实值之间的误差,也即第一输出层输出的类别损失和位置损失。
将未携带目标的样本输入至第二SSD网络,这样,第二SSD网络可以使用与第一SSD网络相同的网络参数对该未携带目标的样本进行预测,得到该未携带目标的样本的对象预测为目标的类别信息,然后与该未携带目标的样本中标注的类别信息进行对比,以计算该第二SSD网络对未携带目标的样本进行预测的预测值与真实值之间的误差,也即,第二输出层输出的类别损失。
在一些示例中,该类别信息可以是对携带目标的样本进行预测时,预测的对象为目标的概率,或者对未携带目标的样本进行预测时,预测的对象为目标的概率,该位置信息可以是预测的对象的位置。
在一些示例中,该第二SSD网络可以是基于第一SSD网络改变得到的,该第一SSD网络或第二SSD网络可以包括多个特征提取层,该第一SSD网络还可以包括第一输出层,该第一输出层可以接收第一SSD网络的各个特征提取层输出的各个特征图,然后可以计算出该携带目标的样本对应的类别损失和位置损失。该第二SSD网络还可以包括第二输出层,该第二输出层可以接收该第二SSD网络的各个特征提取层输出的各个特征图,然后可以计算出该未携带目标的样本对应的类别损失。
例如,该第一SSD网络可以是如图2a所示的神经网络,该第二SSD网络可以是如图2b所示的神经网络。
在图2a中,该多个特征提取层可以包括基础特征层210、第一卷积层220、第二卷积层230、第三卷积层240、第四卷积层250、池化层260以及第一输出层270。该第一输出层可以接收基础特征层210输出的基础特征图、第一卷积层220输出的第一特征图、第二卷积层230输出的第二特征图、第三卷积层240输出的第三特征图、第四卷积层250输出的第四特征图、以及池化层260输出的池化后的特征图。然后可以计算出该携带目标的样本的类别损失和位置损失。
在一些示例中,该第一输出层270中可以包括多个输出子层,例如,如图2a所示的多个输出子层271。该输出子层271的数量与特征提取层的数量相对应。该输出子层271可以包括类别损失计算单元和位置损失计算单元。各个输出子层271在接收到各个特征提取层输入的特征图后,可以计算出各个特征图的类别损失和位置损失,然后对各个特征图的类别损失和位置损失进行合并,从而可以得到携带目标的样本的类别损失和位置损失。
在图2b中,该多个特征提取层可以包括基础特征层210、第一卷积层220、第二卷积层230、第三卷积层240、第四卷积层250、池化层260以及第二输出层280。
在一些示例中,该基础特征层210可以是VGG16网络或VGG19网络,可以提取出该未携带目标的样本的基础特征图,该基础特征层输出的基础特征图的尺寸为38*38,通道数为512。该基础特征层210可以将得到的基础特征图输出至第一卷积层220和第二输出层280。
在一些示例中,该第一卷积层220可以接收基础特征层210输出的基础特征图,然后对该基础特征图进行卷积,从而可以得到该第一卷积层220输出的第一特征图,该第一卷积层220可以将得到的第一特征图传输至第二卷积层230和第二输出层280。
在又一些示例中,该第一卷积层220可以包括两个全连接卷积子层。该两个全连接卷积子层的输出的特征图的尺寸为19*19,通道数为1024。
该第二卷积层230在得到第一特征图后,可以继续对该第一特征图进行卷积,从而可以输出卷积后得到的第二特征图,然后可以将该第二特征图输出至相连的第三卷积层240和第二输出层280。
在一些示例中,该第二卷积层输出的第二特征图的尺寸为10*10,通道数为512。
该第三卷积层240在得到第二特征图后,可以继续对该第二特征图进行卷积,从而可以输出卷积后得到的第三特征图,然后可以将该第三特征图输出至相连的第四卷积层250和第二输出层280。
在一些示例中,该第三卷积层230输出的第三特征图的尺寸为5*5,通道数为512。
该第四卷积层250在得到该第三特征图后,可以继续对该第三特征图进行卷积,从而可以输出卷积后得到的第四特征图,然后可以将该第四特征图输出至相连的池化层260和第二输出层280。
在一些示例中,该第四卷积层250输出的第四特征图的尺寸为3*3,通道数为256。
该池化层260在得到该第四特征图后,可以对该第四特征图进行池化,然后输出池化后的特征图,并将该池化后的特征图传输至第二输出层280。
在一些示例中,该池化层260输出的池化后的特征图的尺寸为1*1,通道数为256。
第二输出层280在得到基础特征层210传输的基础特征图、第一卷积层220传输的第一特征图、第二卷积层230传输的第二特征图、第三卷积层240传输的第三特征图、第四卷积层250传输的第四特征图和池化层260传输的池化后的特征图后,可以基于该基础特征图、第一特征图、第二特征图、第三特征图、第四特征图以及池化后的特征图,计算与该未携带目标的样本对应的类别损失。
在又一些示例中,如图2b所示,该第二输出层280可以包括多个输出子层281。例如,可以包括如图2b所示的第一输出子层281、第二输出子层282、第三输出子层283、第四输出子层284、第五输出子层285、第六输出子层286。
该多个第二输出子层281分别与基础特征层210、第一卷积层220、第二卷积层230、第三卷积层240、第四卷积层250和池化层260一一对应。因此,基础特征层210可以将基础特征图传输至对应的第一输出子层281,第一卷积层220也可以将第一特征图传输至第二输出子层282,第二卷积层230可以将第二特征图传输至第三输出子层283,第四卷积层240可以将第三特征图传输至第四输出子层284,第四卷积层250可以将第四特征图传输至第五输出子层285,池化层260可以将池化后的特征图传输至第六输出子层286。
这样,第一输出子层281可以输出基础特征图对应的第一类别损失,第二输出子层282可以输出第一特征图对应的第二类别损失,第三输出子层283可以输出第二特征图对应的第三类别损失,第四输出子层284可以输出第三特征图对应的第四类别损失,第五输出子层285可以输出第四特征图对应的第五类别损失,第六输出子层286可以输出池化后的特征图对应的第六类别损失。
在一些示例中,该第二输出层281还可以包括损失合并子层。该损失合并子层可以将第一类别损失、第二类别损失、第三类别损失、第四类别损失、第五类别损失以及第六类别损失,作为第二输出层输出的类别损失。这样,可以得到第二输出层输出的类别损失。
S130,对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
在得到第一输出层输出的类别损失和位置损失、第二输出层输出的类别损失后,为了使得训练完成的第一SSD网络能够更准确的进行检测,区分出与目标相似的图像,提高训练完成的第一SSD网络检测的准确度,可以对该第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失。然后可以基于该总损失,更新第一SSD网络中的网络参数和第二SSD网络中的网络参数,以便在下一次训练时,该第一SSD网络中的网络参数与该第二SSD网络中的网络参数相同。
在一些示例中,由于输入至该第二SSD网络的是未携带目标的样本,因此,输出的类别损失可以反映出该第二SSD网络的预测值与该未携带目标的样本的真实值之间的误差。通过反向传播,可以降低该误差,从而可以实现对该未携带目标中的对象更准确的预测。由于该第一SSD网络的网络参数与该第二SSD的网络参数相同,因此,该第一SSD网络可以更清楚的区分目标和与目标相似的对象,从而可以避免误检,提高训练后的神经网络对目标进行检测的准确度。
本发明实施例提供的一种目标检测网络训练方法,可以将标注有目标类别和目标位置的样本输入至第一SSD网络,将未携带目标的样本输入至与第一SSD网络具有相同网络参数的第二SSD网络,这样,可以使得该第一SSD网络输出与该标注有目标类别和目标位置的样本对应的类别损失和位置损失,使得第二SSD网络输出与该未携带目标的样本对应的类别损失,由于该未携带目标的样本中未携带目标,当该类别损失较大时,则说明该未携带目标的样本中存在与该携带目标的样本中的目标相似的对象,使得该第二SSD网络将该未携带目标的样本中的对象预测为该携带目标的样本中的目标。通过对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。可以使得在对该第一SSD网络的网络参数和第二SSD网络的网络参数更新后,降低该第二SSD网络将未携带目标的样本中的对象预测为目标的可能性,由于第一SSD网络和第二SSD网络共享参数,因此,可以使得该第一SSD网络在对待检测的图片进行检测时,能够更好的识别出该待检测图片中与目标相似的对象,提高该第一SSD网络检测的准确度,减少误检。
在一些示例中,上述的步骤S110~S130描述的是进行一次训练的过程,在实际训练时,可以对该第一SSD网络和第二SSD网络进行多次训练。每次训练时,可以使用不同的携带目标的样本和不同的未携带目标的样本。
例如,可以获取多个新的携带目标的样本和多个新的未携带目标的样本;然后
采用该多个新的携带目标的样本和多个新的未携带目标的样本,通过步骤S120~S130,对更新参数后的第一SSD网络和更新参数后的第二SSD网络进行训练,这样,通过多次训练后,可以得到训练完成的第一SSD网络和训练完成的第二SSD网络。
在一些示例中,在对第一SSD网络和第二SSD网络进行多次训练后,当训练后的第一SSD网络和训练后的第二SSD网络满足预设输出条件时,可以将该训练后的第一SSD网络作为训练完成的第一SSD网络,将该训练后的第二SSD网络作为训练完成的第二SSD网络。
在一些示例中,该预设输出条件可以是对第一SSD和第二SSD网络进行训练的次数大于或等于预先设置的训练次数阈值,或者,第一输出层输出的类别损失和位置损失与第二输出层输出的类别损失之和,小于或等于预设损失值阈值。
在图1所示的一种目标检测网络训练方法的基础上,为了进一步提高训练完成训练后的神经网络对目标进行检测的准确度,本发明实施例还提供了一种可能的实现方式,如图3所示,为本发明实施例一种目标检测网络训练方法第二种实施方式的流程图,该方法可以包括:
S111,获取多个携带目标的样本;
S112,采用多个携带目标的样本对预先建立的第一SSD网络和预先建立的第二SSD进行训练,得到第一SSD网络和第二SSD网络。
S113,获取未携带目标的图片集,并采用第一SSD网络对未携带目标的图片集进行检测,得到第一误检测图片。
其中,第一误检测图片中未携带目标;
S114,将获取的多个携带目标的样本作为携带目标的样本,将第一误检测图片作为未携带目标的样本。
在一些示例中,虽然通过随机选择的携带目标的样本和未携带目标的样本对上述的第一SSD网络和第二SSD网络进行训练,可以在训练完成后,提高第一SSD网络进行检测的准确度。但是,为了进一步提高训练完成的第一SSD网络检测的准确度,并且减少对第一SSD网络和第二SSD网络的训练次数。可以有针对性的选择一些未携带目标的样本。例如,可以基于携带目标的样本选择未携带目标的样本。
具体的,可以首先获取多个携带目标的样本,然后采用该多个携带目标的样本,对预先建立的第一SSD网络和预先建立的第二SSD进行训练。
在一些示例中,该预先建立的第一SSD网络和预先建立的第二SSD网络具有相同的参数。在又一些示例中,可以是将该多个携带目标的样本输入至预先建立的第一SSD网络进行训练,在训练完成后,可以将该训练得到的第一SSD网络的网络参数更新至该预先建立的第二SSD网络。
在又一些示例中,也可以是在对预先建立的第一SSD网络进行一次训练后,在更新该第一SSD网络的网络参数时,同时更新预先建立第二SSD网络的网络参数。这样,在对预先建立的第一SSD网络训练完成后,同时也完成了对预先建立的第二SSD网络的训练。从而可以得到第一SSD网络和第二SSD网络。
在得到第一SSD网络和第二SSD网络后,可以采用该第一SSD网络,在预先设置的未携带目标的图片集中进行检测,得到检测出目标的图片,由于该检测出目标的图片实际并未携带目标,因此,该检测出目标的图片为误检图片。第一SSD网络能够将该误检图片检测为携带目标的图片,说明该误检图片与目标有一定的相似性,若采用该误检图片作为未携带目标的图片对第一SSD网络和第二SSD网络训练时的样本,可以进一步提高该第一SSD网络检测时的准确度,减少与目标相似的图片的误检。
S120,将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失。
S130,对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
应当理解的是,步骤S120~S130可以参考本发明实施例中的第一种实施方式,这里不再赘述。
在图3所示的一种目标检测网络训练方法的基础上,本发明实施例还提供了一种可能的实现方式,如图4所示,为本发明实施例一种目标检测网络训练方法第三种实施方式的流程图,在S120,将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失之后,该方法还可以包括:
S140,对第二输出层输出的类别损失按照从大到小的顺序进行排序,得到排序后的类别损失;
S150,获取预设的类别损失阈值,并在排序后的类别损失中,选择大于或等于预设的类别损失阈值的类别损失;
S131,对选择的大于或等于预设的类别损失阈值的类别损失以及第一输出层输出的类别损失和位置损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
在一些示例中,该第二输出层280可以包括多个子输出层,因此,该第二输出层可以输出多个类别损失,该多个类别损失中,可以存在较大的类别损失,也可以存在较小的类别损失,各个类别损失可以不同,因此,为了对第一SSD网络和第二SSD网络进行较好的训练,可以在该多个类别损失中,选择一些较大类别损失,用于更新第一SSD网络的网络参数和第二SSD网络的网络参数。该类别损失越大,则说明该第二SSD网络的预测值与真实值之间的误差越大,这样,采用较大的类别损失,可以较好的更新第一SSD网络的网络参数和第二SSD网络的网络参数。
在一些示例中,可以对第二输出层输出的类别损失按照从大到小的顺序进行排序,然后在得到的排序后的类别损失中,选择大于或等于预设的类别损失阈值的类别损失。
在一些示例中,该预设的类别损失阈值,可以是根据经验预先设置的阈值。
在选择类别损失后,可以对选择的大于或等于预设的类别损失阈值的类别损失以及第一输出层输出的类别损失和位置损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
应当理解的是,步骤S111~S120可以参考本发明实施例中的第二种实施方式,这里不再赘述。
在图3所示的一种目标检测网络训练方法的基础上,本发明实施例还提供了一种可能的实现方式,如图5所示,为本发明实施例一种目标检测网络训练方法第四种实施方式的流程图,在S130,对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数之后,该方法还可以包括:
S160,获取多个新的携带目标的样本和多个新的未携带目标的样本;
S170,采用多个新的携带目标的样本和多个新的未携带目标的样本,对更新参数后的第一SSD网络和更新参数后的第二SSD网络进行训练,得到训练完成的第一SSD网络和训练完成的第二SSD网络。
S180,采用训练完成的第一SSD网络,对除第一误检测图片外的未携带目标的图片集进行检测,得到第二误检测图片。
其中,第二误检测图片中未携带目标;
S190,将第二误检测图片、第一误检测图片以及多个携带目标的样本作为训练样本,对训练完成的第一SSD网络和训练完成的第二SSD网络进行训练。
在一些示例中,为了进一步提高训练完成的第一SSD网络进行检测的准确度,在通过步骤S160和S170对更新参数后的第一SSD网络和更新参数后的第二SSD网络进行训练后,还可以采用该训练完成的第一SSD网络,对除第一误检测图片外的未携带目标的图片集进行检测,当该除第一误检测图片外的未携带目标的图片集中存在与目标相似的图片时,该训练完成的第一SSD网络可能会将与目标相似的图片检测为携带目标的图片,从而可以得到第二误检测图片。
然后,可以将该第二误检测图片、第一误检测图片以及多个携带目标的样本作为训练样本,继续对训练完成的第一SSD网络和训练完成的第二SSD网络进行训练,这样,在训练完成后,该第一SSD网络能够进一步区分出与该目标相似的图片,避免造成误检,提高该第一SSD网络的准确度。
应当理解的是,步骤S111~S120可以参考本发明实施例中的第二种实施方式,这里不再赘述。
为了更清楚的说明本发明实施例,这里以一个完整的训练过程进行说明。
首先,可以为预先建立第一SSD网络和第二SSD网络,并为该预先建立的第一SSD网络和预先建立的第二SSD网络设置初始参数,接下来,可以预先设置一个携带目标的样本集,该携带目标的样本集中可以包括多个携带目标的样本。进而可以采用该携带目标的样本集对预先建立的第一SSD网络和预先建立的第二SSD网络进行训练。假设训练N次得到了采用该携带目标的样本集训练的第一SSD网络和第二SSD网络。
紧接着,可以采用该训练得到的第一SSD网络在预先设置的包括多个不携带目标的图片集中进行检测,然后将误检测得到的图片加入该预先设置的携带目标的样本集中。
应当理解的是,由于该多个不携带目标的图片集中,每个图片均不携带目标,因此,采用该训练得到的第一SSD网络进行检测得到的图片是误检测图片,也即,检测得到的图片是该训练得到的第一SSD网络认为携带目标的图片。
随后,可以采用加入该误检测图片的携带目标的样本集,通过本发明实施例的第一种实施方式,进行N次训练,可以得到再次训练后的第一SSD网络和第二SSD网络。
继续采用该再次训练后的第一SSD网络在除该误检测图片后的包括多个不携带目标的图片集中检测,得到再次检测的误检测图片,进而可以将该再次检测的误检测图片,加入上述已经加入误检测图片的携带目标的样本集中。
重复上述步骤,直至对不携带目标的图片集进行M轮检测。
最后,将第M轮检测得到的误检测图片加入样本集中,然后采用加入第M检测的误检测图片的样本集继续对第一SSD网络和第二SSD网络训练N次,总共对第一SSD网络和第二SSD网络进行了N*(M+1)次训练,也即,对预先建立的第一SSD网络和预先建立的第二SSD网络进行了(M+1)轮训练,每一轮训练进行N次,这样,便可以得到训练完成的第一SSD网络和第二SSD网络。在训练完成后,可以采用该训练完成的第一SSD网络对待检测图片进行检测,检测该待检测图片中是否携带目标,在携带目标时,该携带的目标的类别以及在该待检测图片中的位置。
在一些示例中,由于携带目标的样本中目标具有固定的位置,例如,有Logo的样本图片中Logo具有固定的位置,因此,在采用包括该携带目标的样本进行训练得到第一SSD网络后,该第一SSD网络在对待检测图片进行检测时,可以检测到该待检测图片中,与目标所在位置相同的位置是否有Logo。这样,提高该第一SSD网络检测的准确度,减少误检。
相应于上述的方法实施例,本发明实施例还提供了一种目标检测网络训练装置,如图6所示,为本发明实施例的一种目标检测网络训练装置的结构示意图,该装置可以包括:
第一样本获取模块610,用于获取携带目标的样本和未携带目标的样本,其中,携带目标的样本中标注有目标的类别和目标的位置,未携带目标的样本为除标注有目标的类别的样本外的样本;
样本输入模块620,用于将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失,其中,第二SSD网络与第一SSD网络具有相同的网络参数;
更新模块630,用于对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
本发明实施例提供的一种目标检测网络训练装置,可以将标注有目标类别和目标位置的样本输入至第一SSD网络,将未携带目标的样本输入至与第一SSD网络具有相同网络参数的第二SSD网络,这样,可以使得该第一SSD网络输出与该标注有目标类别和目标位置的样本对应的类别损失和位置损失,使得第二SSD网络输出与该未携带目标的样本对应的类别损失,由于该未携带目标的样本中未携带目标,当该类别损失较大时,则说明该未携带目标的样本中存在与该携带目标的样本中的目标相似的对象,使得该第二SSD网络将该未携带目标的样本中的对象预测为该携带目标的样本中的目标。通过对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。可以使得在对该第一SSD网络的网络参数和第二SSD网络的网络参数更新后,降低该第二SSD网络将未携带目标的样本中的对象预测为目标的可能性,由于第一SSD网络和第二SSD网络共享参数,因此,可以使得该第一SSD网络在对待检测的图片进行检测时,能够更好的识别出该待检测图片中与目标相似的对象,提高该第一SSD网络检测的准确度,减少误检。
具体的,该目标检测网络训练装置,还可以包括:
携带目标样本获取模块,用于获取多个携带目标的样本;
第一训练模块,用于采用多个携带目标的样本对预先建立的第一SSD网络和预先建立的第二SSD进行训练,得到第一SSD网络和第二SSD网络;
具体的,第一样本获取模块610,可以包括:
检测子模块,用于获取未携带目标的图片集,并采用第一SSD网络对未携带目标的图片集进行检测,得到第一误检测图片,其中,第一误检测图片中未携带目标;
样本获取子模块,将获取的多个携带目标的样本作为携带目标的样本,将第一误检测图片作为未携带目标的样本。
具体的,第二SSD网络可以包括:基础特征层、第一卷积层、第二卷积层、第三卷积层、第四卷积层、池化层以及第二输出层,第二输出层可以包括:第一输出子层、第二输出子层、第三输出子层、第四输出子层、第五输出子层以及第六输出子层;
具体的,样本输入模块620,可以包括:
第一输入子模块,用于将未携带目标的样本输入至第二SSD网络的基础特征层,得到基础特征层输出的基础特征图;
第二输入子模块,用于将基础特征图输入至第二SSD网络的第一卷积层和第一输出子层,得到第一卷积层输出的卷积后的第一特征图和第一输出子层输出的第一类别损失;
第三输入子模块,用于将第一特征图输入至第二SSD网络的第二卷积层和第二输出子层,得到第二卷积层输出的卷积后的第二特征图和第二输出子层输出的第二类别损失;
第四输入子模块,用于将第二特征图输入至第二SSD网络的第三卷积层和第三输出子层,得到第三卷积层输出的卷积后的第三特征图和第三输出子层输出的第三类别损失;
第五输入子模块,用于将第三特征图输入至第二SSD网络的第四卷积层和第四输出子层,得到第四卷积层输出的卷积后的第四特征图和第四输出子层输出的第四类别损失;
第六输入子模块,用于将第四特征图输入至第二SSD网络的池化层和第五输出子层,得到池化层输出的池化后的特征图和第五输出子层输出的第五类别损失;
第七输入子模块,用于将池化后的特征图输入至第六输出子层,得到第六输出子层输出的第六类别损失;
转换子模块,用于将第一类别损失、第二类别损失、第三类别损失、第四类别损失、第五类别损失以及第六类别损失,作为第二输出层输出的类别损失。
具体的,本发明实施例的一种目标检测网络训练装置,还可以包括:
类别损失排序模块,用于对第二输出层输出的类别损失按照从大到小的顺序进行排序,得到排序后的类别损失;
类别损失选择模块,用于获取预设的类别损失阈值,并在排序后的类别损失中,选择大于或等于预设的类别损失阈值的类别损失;
更新模块630,具体用于:
对选择的大于或等于预设的类别损失阈值的类别损失以及第一输出层输出的类别损失和位置损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
具体的,本发明实施例的一种目标检测网络训练装置,还可以包括:
第二样本获取模块,用于获取多个新的携带目标的样本和多个新的未携带目标的样本;
第二训练模块,用于采用多个新的携带目标的样本和多个新的未携带目标的样本,对更新参数后的第一SSD网络和更新参数后的第二SSD网络进行训练,得到训练完成的第一SSD网络和训练完成的第二SSD网络。
具体的,本发明实施例的一种目标检测网络训练装置,还可以包括:
检测模块,用于采用训练完成的第一SSD网络,对除第一误检测图片外的未携带目标的图片集进行检测,得到第二误检测图片,其中,第二误检测图片中未携带目标;
第三训练模块,用于将第二误检测图片、第一误检测图片以及多个携带目标的样本作为训练样本,对训练完成的第一SSD网络和训练完成的第二SSD网络进行训练。
本发明实施例还提供了一种电子设备,如图7所示,包括处理器701、通信接口702、存储器703和通信总线704,其中,处理器701,通信接口702,存储器703通过通信总线704完成相互间的通信,
存储器703,用于存放计算机程序;
处理器701,用于执行存储器703上所存放的程序时,执行上述实施例中任一所述的目标检测网络训练方法的步骤,例如,执行如下步骤:
获取携带目标的样本和未携带目标的样本,其中,携带目标的样本中标注有目标的类别和目标的位置,未携带目标的样本为除标注有目标的类别的样本外的样本;
将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失,其中,第二SSD网络与第一SSD网络具有相同的网络参数;
对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
上述电子设备提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,简称PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,简称EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口用于上述电子设备与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,简称RAM),也可以包括非易失性存储器(non-volatile memory),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,简称CPU)、网络处理器(Network Processor,简称NP)等;还可以是数字信号处理器(Digital Signal Processing,简称DSP)、专用集成电路(Application SpecificIntegrated Circuit,简称ASIC)、现场可编程门阵列(Field-Programmable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
在本发明提供的又一实施例中,还提供了一种计算机可读存储介质,该计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述实施例中任一所述的目标检测网络训练方法的步骤,例如,执行如下步骤:
获取携带目标的样本和未携带目标的样本,其中,携带目标的样本中标注有目标的类别和目标的位置,未携带目标的样本为除标注有目标的类别的样本外的样本;
将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失,其中,第二SSD网络与第一SSD网络具有相同的网络参数;
对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
在本发明提供的又一实施例中,还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述实施例中任一所述的目标检测网络训练方法,例如,执行如下步骤:
获取携带目标的样本和未携带目标的样本,其中,携带目标的样本中标注有目标的类别和目标的位置,未携带目标的样本为除标注有目标的类别的样本外的样本;
将携带目标的样本输入至第一SSD网络,得到第一SSD网络中的第一输出层输出的类别损失和位置损失;并将未携带目标的样本输入至第二SSD网络,得到第二SSD网络中的第二输出层输出的类别损失,其中,第二SSD网络与第一SSD网络具有相同的网络参数;
对第一输出层输出的类别损失和位置损失以及第二输出层输出的类别损失求和,得到总损失,并基于总损失更新第一SSD网络中的网络参数和第二SSD网络中的网络参数。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本发明实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘Solid State Disk(SSD))等。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于系统实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本发明的较佳实施例而已,并非用于限定本发明的保护范围。凡在本发明的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本发明的保护范围内。
Claims (13)
1.一种目标检测网络训练方法,其特征在于,所述方法包括:
获取携带目标的样本和未携带所述目标的样本,其中,所述携带目标的样本中标注有所述目标的类别和所述目标的位置;
将所述携带目标的样本输入至第一单次多框检测器SSD网络,得到所述第一SSD网络中的第一输出层输出的类别损失和位置损失;并将所述未携带所述目标的样本输入至第二SSD网络,得到所述第二SSD网络中的第二输出层输出的类别损失,其中,所述第二SSD网络与所述第一SSD网络具有相同的网络参数;
对所述第一输出层输出的类别损失和位置损失以及所述第二输出层输出的类别损失求和,得到总损失,并基于所述总损失更新所述第一SSD网络中的网络参数和所述第二SSD网络中的网络参数,以使得在训练完成后,采用训练完成的第一SSD网络对待检测图片进行检测。
2.根据权利要求1所述的方法,其特征在于,在所述获取携带目标的样本和未携带所述目标的样本之前,所述方法还包括:
获取多个携带所述目标的样本;
采用所述多个携带所述目标的样本对预先建立的第一SSD网络和预先建立的第二SSD进行训练,得到所述第一SSD网络和所述第二SSD网络;
所述获取携带目标的样本和未携带所述目标的样本,包括:
获取未携带所述目标的图片集,并采用所述第一SSD网络对所述未携带所述目标的图片集进行检测,得到第一误检测图片,其中,所述第一误检测图片中未携带所述目标;
将获取的多个携带所述目标的样本作为所述携带目标的样本,将所述第一误检测图片作为所述未携带所述目标的样本。
3.根据权利要求1或2所述的方法,其特征在于,所述第二SSD网络包括:基础特征层、第一卷积层、第二卷积层、第三卷积层、第四卷积层、池化层以及第二输出层,所述第二输出层包括:第一输出子层、第二输出子层、第三输出子层、第四输出子层、第五输出子层以及第六输出子层;
所述将所述未携带所述目标的样本输入至第二SSD网络,得到所述第二SSD网络中的第二输出层输出的类别损失,包括:
将所述未携带所述目标的样本输入至所述第二SSD网络的基础特征层,得到所述基础特征层输出的基础特征图;
将所述基础特征图输入至所述第二SSD网络的第一卷积层和所述第一输出子层,得到所述第一卷积层输出的卷积后的第一特征图和所述第一输出子层输出的第一类别损失;
将所述第一特征图输入至所述第二SSD网络的第二卷积层和所述第二输出子层,得到所述第二卷积层输出的卷积后的第二特征图和所述第二输出子层输出的第二类别损失;
将所述第二特征图输入至所述第二SSD网络的第三卷积层和所述第三输出子层,得到所述第三卷积层输出的卷积后的第三特征图和所述第三输出子层输出的第三类别损失;
将所述第三特征图输入至所述第二SSD网络的第四卷积层和所述第四输出子层,得到所述第四卷积层输出的卷积后的第四特征图和所述第四输出子层输出的第四类别损失;
将所述第四特征图输入至所述第二SSD网络的池化层和所述第五输出子层,得到所述池化层输出的池化后的特征图和所述第五输出子层输出的第五类别损失;
将所述池化后的特征图输入至所述第六输出子层,得到所述第六输出子层输出的第六类别损失;
将所述第一类别损失、所述第二类别损失、所述第三类别损失、所述第四类别损失、所述第五类别损失以及所述第六类别损失,作为所述第二输出层输出的类别损失。
4.根据权利要求3所述的方法,其特征在于,在所述对所述第一输出层输出的类别损失和位置损失以及所述第二输出层输出的类别损失求和,得到总损失,并基于所述总损失更新所述第一SSD网络中的网络参数和所述第二SSD网络中的网络参数之前,所述方法还包括:
对所述第二输出层输出的类别损失按照从大到小的顺序进行排序,得到排序后的类别损失;
获取预设的类别损失阈值,并在所述排序后的类别损失中,选择大于或等于所述预设的类别损失阈值的类别损失;
所述对所述第一输出层输出的类别损失和位置损失以及所述第二输出层输出的类别损失求和,得到总损失,并基于所述总损失更新所述第一SSD网络中的网络参数和所述第二SSD网络中的网络参数,包括:
对所述选择的大于或等于所述预设的类别损失阈值的类别损失以及所述第一输出层输出的类别损失和位置损失求和,得到总损失,并基于所述总损失更新所述第一SSD网络中的网络参数和所述第二SSD网络中的网络参数。
5.根据权利要求2所述的方法,其特征在于,在所述基于所述总损失更新所述第一SSD网络中的网络参数和所述第二SSD网络中的网络参数之后,所述方法还包括:
获取多个新的携带所述目标的样本和多个新的未携带所述目标的样本;
采用所述多个新的携带所述目标的样本和多个新的未携带所述目标的样本,对所述更新参数后的第一SSD网络和所述更新参数后的第二SSD网络进行训练,得到训练完成的第一SSD网络和训练完成的第二SSD网络。
6.根据权利要求5所述的方法,其特征在于,所述方法还包括:
采用所述训练完成的第一SSD网络,对除所述第一误检测图片外的未携带所述目标的图片集进行检测,得到第二误检测图片,其中,所述第二误检测图片中未携带所述目标;
将所述第二误检测图片、所述第一误检测图片以及所述多个携带所述目标的样本作为训练样本,对所述训练完成的第一SSD网络和所述训练完成的第二SSD网络进行训练。
7.一种目标检测网络训练装置,其特征在于,所述装置包括:
第一样本获取模块,用于获取携带目标的样本和未携带所述目标的样本,其中,所述携带目标的样本中标注有所述目标的类别和所述目标的位置;
样本输入模块,用于将所述携带目标的样本输入至第一单次多框检测器SSD网络,得到所述第一SSD网络中的第一输出层输出的类别损失和位置损失;并将所述未携带所述目标的样本输入至第二SSD网络,得到所述第二SSD网络中的第二输出层输出的类别损失,其中,所述第二SSD网络与所述第一SSD网络具有相同的网络参数;
更新模块,用于对所述第一输出层输出的类别损失和位置损失以及所述第二输出层输出的类别损失求和,得到总损失,并基于所述总损失更新所述第一SSD网络中的网络参数和所述第二SSD网络中的网络参数,以使得在训练完成后,采用训练完成的第一SSD网络对待检测图片进行检测。
8.根据权利要求7所述的装置,其特征在于,所述装置还包括:
携带目标样本获取模块,用于获取多个携带所述目标的样本;
第一训练模块,用于采用所述多个携带所述目标的样本对预先建立的第一SSD网络和预先建立的第二SSD进行训练,得到所述第一SSD网络和所述第二SSD网络;
所述第一样本获取模块,包括:
检测子模块,用于获取未携带所述目标的图片集,并采用所述第一SSD网络对所述未携带所述目标的图片集进行检测,得到第一误检测图片,其中,所述第一误检测图片中未携带所述目标;
样本获取子模块,将获取的多个携带所述目标的样本作为所述携带目标的样本,将所述第一误检测图片作为所述未携带所述目标的样本。
9.根据权利要求7或8所述的装置,其特征在于,所述第二SSD网络包括:基础特征层、第一卷积层、第二卷积层、第三卷积层、第四卷积层、池化层以及第二输出层,所述第二输出层包括:第一输出子层、第二输出子层、第三输出子层、第四输出子层、第五输出子层以及第六输出子层;
所述样本输入模块,包括:
第一输入子模块,用于将所述未携带所述目标的样本输入至所述第二SSD网络的基础特征层,得到所述基础特征层输出的基础特征图;
第二输入子模块,用于将所述基础特征图输入至所述第二SSD网络的第一卷积层和所述第一输出子层,得到所述第一卷积层输出的卷积后的第一特征图和所述第一输出子层输出的第一类别损失;
第三输入子模块,用于将所述第一特征图输入至所述第二SSD网络的第二卷积层和所述第二输出子层,得到所述第二卷积层输出的卷积后的第二特征图和所述第二输出子层输出的第二类别损失;
第四输入子模块,用于将所述第二特征图输入至所述第二SSD网络的第三卷积层和所述第三输出子层,得到所述第三卷积层输出的卷积后的第三特征图和所述第三输出子层输出的第三类别损失;
第五输入子模块,用于将所述第三特征图输入至所述第二SSD网络的第四卷积层和所述第四输出子层,得到所述第四卷积层输出的卷积后的第四特征图和所述第四输出子层输出的第四类别损失;
第六输入子模块,用于将所述第四特征图输入至所述第二SSD网络的池化层和所述第五输出子层,得到所述池化层输出的池化后的特征图和所述第五输出子层输出的第五类别损失;
第七输入子模块,用于将所述池化后的特征图输入至所述第六输出子层,得到所述第六输出子层输出的第六类别损失;
转换子模块,用于将所述第一类别损失、所述第二类别损失、所述第三类别损失、所述第四类别损失、所述第五类别损失以及所述第六类别损失,作为所述第二输出层输出的类别损失。
10.根据权利要求9所述的装置,其特征在于,所述装置还包括:
类别损失排序模块,用于对所述第二输出层输出的类别损失按照从大到小的顺序进行排序,得到排序后的类别损失;
类别损失选择模块,用于获取预设的类别损失阈值,并在所述排序后的类别损失中,选择大于或等于所述预设的类别损失阈值的类别损失;
所述更新模块,具体用于:
对所述选择的大于或等于所述预设的类别损失阈值的类别损失以及所述第一输出层输出的类别损失和位置损失求和,得到总损失,并基于所述总损失更新所述第一SSD网络中的网络参数和所述第二SSD网络中的网络参数。
11.根据权利要求8所述的装置,其特征在于,所述装置还包括:
第二样本获取模块,用于获取多个新的携带所述目标的样本和多个新的未携带所述目标的样本;
第二训练模块,用于采用所述多个新的携带所述目标的样本和多个新的未携带所述目标的样本,对所述更新参数后的第一SSD网络和所述更新参数后的第二SSD网络进行训练,得到训练完成的第一SSD网络和训练完成的第二SSD网络。
12.根据权利要求11所述的装置,其特征在于,所述装置还包括:
检测模块,用于采用所述训练完成的第一SSD网络,对除所述第一误检测图片外的未携带所述目标的图片集进行检测,得到第二误检测图片,其中,所述第二误检测图片中未携带所述目标;
第三训练模块,用于将所述第二误检测图片、所述第一误检测图片以及所述多个携带所述目标的样本作为训练样本,对所述训练完成的第一SSD网络和所述训练完成的第二SSD网络进行训练。
13.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现权利要求1-6任一所述的方法步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910277616.6A CN109977913B (zh) | 2019-04-08 | 2019-04-08 | 一种目标检测网络训练方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910277616.6A CN109977913B (zh) | 2019-04-08 | 2019-04-08 | 一种目标检测网络训练方法、装置及电子设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN109977913A CN109977913A (zh) | 2019-07-05 |
CN109977913B true CN109977913B (zh) | 2021-11-05 |
Family
ID=67083474
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910277616.6A Active CN109977913B (zh) | 2019-04-08 | 2019-04-08 | 一种目标检测网络训练方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN109977913B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111242081B (zh) * | 2020-01-19 | 2023-05-12 | 深圳云天励飞技术有限公司 | 视频检测方法、目标检测网络训练方法、装置及终端设备 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106548155A (zh) * | 2016-10-28 | 2017-03-29 | 安徽四创电子股份有限公司 | 一种基于深度信念网络的车牌检测方法 |
CN106778867A (zh) * | 2016-12-15 | 2017-05-31 | 北京旷视科技有限公司 | 目标检测方法和装置、神经网络训练方法和装置 |
CN108898620A (zh) * | 2018-06-14 | 2018-11-27 | 厦门大学 | 基于多重孪生神经网络与区域神经网络的目标跟踪方法 |
CN109446889A (zh) * | 2018-09-10 | 2019-03-08 | 北京飞搜科技有限公司 | 基于孪生匹配网络的物体追踪方法及装置 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109035204B (zh) * | 2018-06-25 | 2021-06-08 | 华南理工大学 | 一种焊缝目标实时检测方法 |
-
2019
- 2019-04-08 CN CN201910277616.6A patent/CN109977913B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106548155A (zh) * | 2016-10-28 | 2017-03-29 | 安徽四创电子股份有限公司 | 一种基于深度信念网络的车牌检测方法 |
CN106778867A (zh) * | 2016-12-15 | 2017-05-31 | 北京旷视科技有限公司 | 目标检测方法和装置、神经网络训练方法和装置 |
CN108898620A (zh) * | 2018-06-14 | 2018-11-27 | 厦门大学 | 基于多重孪生神经网络与区域神经网络的目标跟踪方法 |
CN109446889A (zh) * | 2018-09-10 | 2019-03-08 | 北京飞搜科技有限公司 | 基于孪生匹配网络的物体追踪方法及装置 |
Non-Patent Citations (2)
Title |
---|
"Fully-Convolutional Siamese Networks for Object Tracking";Luca Bertinetto et al.;《arXiv:1606.09549v2》;20160914;第1-16页 * |
"SSD: Single Shot MultiBox Detector";Wei Liu et al.;《arXiv:1512.02325v5》;20161229;第1-17页 * |
Also Published As
Publication number | Publication date |
---|---|
CN109977913A (zh) | 2019-07-05 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108921206B (zh) | 一种图像分类方法、装置、电子设备及存储介质 | |
CN111079570B (zh) | 一种人体关键点识别方法、装置及电子设备 | |
CN108171203B (zh) | 用于识别车辆的方法和装置 | |
CN110033018B (zh) | 图形相似度判断方法、装置及计算机可读存储介质 | |
CN109685528A (zh) | 基于深度学习检测仿冒产品的系统和方法 | |
JP6892606B2 (ja) | 位置特定装置、位置特定方法及びコンピュータプログラム | |
CN111526119A (zh) | 异常流量检测方法、装置、电子设备和计算机可读介质 | |
CN112163480B (zh) | 一种行为识别方法及装置 | |
CN111861909A (zh) | 一种网络细粒度图像去噪分类方法 | |
CN111178364A (zh) | 一种图像识别方法和装置 | |
CN115100739B (zh) | 人机行为检测方法、系统、终端设备及存储介质 | |
CN109977913B (zh) | 一种目标检测网络训练方法、装置及电子设备 | |
CN115797735A (zh) | 目标检测方法、装置、设备和存储介质 | |
CN111027412A (zh) | 一种人体关键点识别方法、装置及电子设备 | |
CN111325067B (zh) | 违规视频的识别方法、装置及电子设备 | |
CN113076961B (zh) | 一种图像特征库更新方法、图像检测方法和装置 | |
CN110659954A (zh) | 作弊识别方法、装置、电子设备及可读存储介质 | |
CN108647986B (zh) | 一种目标用户确定方法、装置及电子设备 | |
CN113557546A (zh) | 图像中关联对象的检测方法、装置、设备和存储介质 | |
CN111881007B (zh) | 操作行为判断方法、装置、设备及计算机可读存储介质 | |
CN111222558A (zh) | 图像处理方法及存储介质 | |
CN116257885A (zh) | 基于联邦学习的隐私数据通信方法、系统和计算机设备 | |
CN112149698A (zh) | 一种困难样本数据的筛选方法及装置 | |
CN112884866B (zh) | 一种黑白视频的上色方法、装置、设备及存储介质 | |
CN110399803B (zh) | 一种车辆检测方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |