CN113673533A - 一种模型训练方法及相关设备 - Google Patents

一种模型训练方法及相关设备 Download PDF

Info

Publication number
CN113673533A
CN113673533A CN202010412910.6A CN202010412910A CN113673533A CN 113673533 A CN113673533 A CN 113673533A CN 202010412910 A CN202010412910 A CN 202010412910A CN 113673533 A CN113673533 A CN 113673533A
Authority
CN
China
Prior art keywords
network
classification
loss
target
feature
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010412910.6A
Other languages
English (en)
Inventor
唐福辉
张晓鹏
钮敏哲
王子辰
韩建华
田奇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Huawei Technologies Co Ltd
Original Assignee
Huawei Technologies Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Huawei Technologies Co Ltd filed Critical Huawei Technologies Co Ltd
Priority to CN202010412910.6A priority Critical patent/CN113673533A/zh
Priority to PCT/CN2021/088787 priority patent/WO2021227804A1/zh
Publication of CN113673533A publication Critical patent/CN113673533A/zh
Priority to US17/986,081 priority patent/US20230075836A1/en
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/20Image preprocessing
    • G06V10/30Noise filtering
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Multimedia (AREA)
  • Data Mining & Analysis (AREA)
  • Databases & Information Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • General Engineering & Computer Science (AREA)
  • Medical Informatics (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例提供一种模型训练方法及相关装置,可用于人工智能、计算机视觉等领域,用来进行图像检测,该方法包括:分别通过第一网络的特征提取层和第二网络的特征提取层提取目标图像中的特征信息;并分别通过高斯掩膜进一步提取特征信息中关于目标物体的特征,得到第一局部特征和第二局部特征;再通过第一局部特征和第二局部特征确定特征损失;并且通过第一网络和第二网络基于同样的区域提议集合进行预测得到第一分类预测值和第二分类预测值,再根据第一分类预测值和第二分类预测值得到分类损失;之后根据分类损失和特征损失对第二网络训练,得到目标网络。采用本申请实施例,能够使得用于检测图像的目标网络的预测速度更快、预测准确度更高。

Description

一种模型训练方法及相关设备
技术领域
本申请涉及图像处理领域,尤其涉及一种模型训练方法及相关设备。
背景技术
目标检测指的是在图像中对目标物体进行分类和定位,如图1所示,通过目标检测可以对图像中的伞101进行分类和定位,也可以对图像中的人102进行分类和定位。对图像进行目标检测的应用非常广泛,如自动驾驶、平安城市、以及手机终端等,因此对检测准确度和速度都有很高的要求。对图像进行目标检测通常是通过神经网络来实现的,然而,大神经网络检测的准确度虽高,速度却很慢;小神经网络检测的速度快,准确度很低。
如何训练出一个检测速度更快、检测结果更准确的神经网络是本领域的技术人员正在研究的技术问题。
发明内容
本申请实施例公开了一种模型训练方法及相关设备,可以应用于人工智能、计算机视觉等领域中,用来进行图像检测,该方法及相关设备能够提高网络的预测效率和准确度。
第一方面,本申请实施例提供一种模型训练方法,该方法包括:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息,其中,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到目标网络。
在上述方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
结合第一方面,在第一方面的第一种可能的实现方式中,所述方法还包括:
通过所述第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
通过所述第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述根据所述特征损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
结合第一方面,或者第一方面的第一种可能的实现方式,在第一方面的第二种可能的实现方式中,所述根据所述特征损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失训练所述第二网络;
通过第三网络对经过训练后的所述第二网络进行训练,得到目标网络,其中,所述第三网络的深度大于所述第一网络的深度。
在该可能的实现方式中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
第二方面,本申请实施例提供一种模型训练方法,该方法包括:
基于第一网络训练第二网络得到中间网络;
基于第三网络训练所述中间网络,得到目标网络,其中,所述第一网络、所述第二网络和所述第三网络均为分类网络,且所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度。
在上述方法中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
结合第二方面,在第二方面的第一种可能的实现方式中,所述基于第一网络训练第二网络包括:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到所述中间网络。
在该可能的实现方式中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
结合第二方面的第一种可能的实现方式,在第二方面的第二种可能的实现方式中,所述方法还包括:
通过第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
通过第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述根据所述特征损失训练所述第二网络,得到所述中间网络,包括:
根据所述特征损失和所述分类损失训练所述第二网络,得到所述中间网络。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
结合第一方面,或者第二方面,或者第一方面的任一种可能的实现方式,或者第二方面的任一种可能的实现方式,在又一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
结合第一方面,或者第二方面,或者第一方面的任一种可能的实现方式,或者第二方面的任一种可能的实现方式,在又一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
结合第一方面,或者第二方面,或者第一方面的任一种可能的实现方式,或者第二方面的任一种可能的实现方式,在又一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
结合第一方面,或者第二方面,或者第一方面的任一种可能的实现方式,或者第二方面的任一种可能的实现方式,在又一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000031
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000032
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000033
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000034
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000039
表示基于
Figure BDA0002493947430000035
和ym得到的交叉熵损失,
Figure BDA0002493947430000036
表示基于
Figure BDA0002493947430000037
Figure BDA0002493947430000038
得到的二值交叉熵损失,β为预设的权重平衡因子。
结合第一方面,或者第二方面,或者第一方面的任一种可能的实现方式,或者第二方面的任一种可能的实现方式,在又一种可能的实现方式中,所述方法还包括:
根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;
根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失、所述分类损失、所述回归损失和所述RPN损失训练所述第二网络,得到目标网络。
结合第一方面,或者第二方面,或者第一方面的任一种可能的实现方式,或者第二方面的任一种可能的实现方式,在又一种可能的实现方式中,所述根据所述特征损失训练所述第二网络,得到目标网络之后,还包括:
向模型使用设备发送所述目标网络,其中所述目标网络用于预测图像中的内容。
第三方面,本申请实施例提供一种图像检测方法,该方法包括:
获取目标网络,其中,所述目标网络为通过第一网络对第二网络进行训练后得到的网络,通过所述第一网络训练所述第二网络用到的参数包括特征损失,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
通过所述目标网络识别图像中的内容。
在上述方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
结合第三方面,在第三方面的第一种可能的实现方式中,训练所述第二网络用到的参数还包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
结合第三方面,或者第三方面的第一种可能的实现方式,在第二种可能的实现方式中,所述目标网络具体为通过所述第一网络对第二网络进行训练,并通过第三网络对训练得到的网络进一步进行训练之后的网络,其中,所述第三网络的深度大于所述第一网络的深度。
在该可能的实现方式中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
第四方面,本申请实施例提供一种图像检测方法,该方法包括:
获取目标网络,其中,所述目标网络为通过多个网络迭代对第二网络进行训练得到的网络,所述多个网络均为分类网络,所述多个网络至少包括第一网络和第三网络,所述第三网络用于在所述第一网络对第二网络进行训练得到中间网络后对所述中间网络进行训练,其中,所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度;
通过所述目标网络识别图像中的内容。
在上述方法中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
结合第四方面,在第四方面的第一种可能的实现方式中,所述第一网络对第二网络进行训练时用到的参数包括特征损失,其中,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息。
在该可能的实现方式中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
结合第四方面的第一种可能的实现方式,在第四方面的第二种可能的实现方式中,所述第一网络对第二网络进行训练时用到的参数包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
结合第三方面,或者第四方面,或者第三方面的任一种可能的实现方式,或者第四方面的任一种可能的实现方式,在又一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
结合第三方面,或者第四方面,或者第三方面的任一种可能的实现方式,或者第四方面的任一种可能的实现方式,在又一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
结合第三方面,或者第四方面,或者第三方面的任一种可能的实现方式,或者第四方面的任一种可能的实现方式,在又一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
结合第三方面,或者第四方面,或者第三方面的任一种可能的实现方式,或者第四方面的任一种可能的实现方式,在又一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000051
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000052
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000053
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000061
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000062
表示基于
Figure BDA0002493947430000063
和ym得到的交叉熵损失,
Figure BDA0002493947430000064
表示基于
Figure BDA0002493947430000065
Figure BDA0002493947430000066
得到的二值交叉熵损失,β为预设的权重平衡因子。
结合第三方面,或者第四方面,或者第三方面的任一种可能的实现方式,或者第四方面的任一种可能的实现方式,在又一种可能的实现方式中,训练所述第二网络用到的参数还包括所述第二网络的回归损失和RPN损失,其中,所述第二网络的回归损失和RPN损失为根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值确定的。
结合第三方面,或者第四方面,或者第三方面的任一种可能的实现方式,或者第四方面的任一种可能的实现方式,在又一种可能的实现方式中,所述获取目标网络,包括:
接收模型训练设备发送的目标网络,其中所述模型训练设备用于训练得到所述目标网络。
第五方面,本申请实施例提供一种模型训练装置,该装置包括:
特征提取单元,用于通过第一网络的特征提取层提取目标图像中的第一特征信息;
所述特征提取单元,还用于通过第二网络的特征提取层提取目标图像中的第二特征信息,其中,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
第一优化单元,用于通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
第二优化单元,用于通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
第一确定单元,用于通过所述第一局部特征和所述第二局部特征确定特征损失;
权重调整单元,用于根据所述特征损失训练所述第二网络,得到目标网络。
在上述方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
结合第五方面,在第五方面的第一种可能的实现方式中,所述装置还包括:
第一生成单元,用于通过所述第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
第二生成单元,用于通过所述第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
第二确定单元,用于根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述权重调整单元具体用于:根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
结合第五方面,或者第五方面的任一种可能的实现方式,在第五方面的第二种可能的实现方式中,所述在根据所述特征损失训练所述第二网络,得到目标网络,所述权重调整单元具体用于:
根据所述特征损失训练所述第二网络;
通过第三网络对经过训练后的所述第二网络进行训练,得到目标网络,其中,所述第三网络的深度大于所述第一网络的深度。
在该可能的实现方式中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
第六方面,本申请实施例提供一种模型训练装置,该装置包括:
第一训练单元,用于基于第一网络训练第二网络得到中间网络;
第二训练单元,用于基于第三网络训练所述中间网络,得到目标网络,其中,所述第一网络、所述第二网络和所述第三网络均为分类网络,且所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度。
在上述方法中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
结合第六方面,在第六方面的第一种可能的实现方式中,所述基于第一网络训练第二网络得到中间网络包括:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到所述中间网络。
在该可能的实现方式中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
结合第六方面,或者第六方面的第一种可能的实现方式,在第六方面的第二种可能的实现方式中,所述装置还包括:
第一生成单元,用于通过第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
第二生成单元,用于通过第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
第二确定单元,用于根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述根据所述特征损失训练所述第二网络,得到所述中间网络,具体为:
根据所述特征损失和所述分类损失训练所述第二网络,得到所述中间网络。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
结合第五方面,或者第六方面,或者第五方面的任一种可能的实现方式,或者第六方面的任一种可能的实现方式,在又一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
结合第五方面,或者第六方面,或者第五方面的任一种可能的实现方式,或者第六方面的任一种可能的实现方式,在又一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
结合第五方面,或者第六方面,或者第五方面的任一种可能的实现方式,或者第六方面的任一种可能的实现方式,在又一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
结合第五方面,或者第六方面,或者第五方面的任一种可能的实现方式,或者第六方面的任一种可能的实现方式,在又一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000081
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000082
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000083
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000084
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000085
表示基于
Figure BDA0002493947430000086
和ym得到的交叉熵损失,
Figure BDA0002493947430000087
表示基于
Figure BDA0002493947430000088
Figure BDA0002493947430000089
得到的二值交叉熵损失,β为预设的权重平衡因子。
结合第五方面,或者第六方面,或者第五方面的任一种可能的实现方式,或者第六方面的任一种可能的实现方式,在又一种可能的实现方式中,所述装置还包括:
第三确定单元,用于根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;
所述权重调整单元具体用于:根据所述特征损失、所述分类损失、所述回归损失和所述RPN损失训练所述第二网络,得到目标网络。
结合第五方面,或者第六方面,或者第五方面的任一种可能的实现方式,或者第六方面的任一种可能的实现方式,在又一种可能的实现方式中,还包括:
发送单元,用于在所述权重调整单元根据所述特征损失训练所述第二网络,得到目标网络之后,向模型使用设备发送所述目标网络,其中所述目标网络用于预测图像中的内容。
第七方面,本申请实施例提供一种图像检测装置,该装置包括:
获取单元,用于获取目标网络,其中,所述目标网络为通过第一网络对第二网络进行训练后得到的网络,通过所述第一网络训练所述第二网络用到的参数包括特征损失,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
识别单元,用于通过所述目标网络识别图像中的内容。
在上述方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
结合第七方面,在第七方面的第一种可能的实现方式中,训练所述第二网络用到的参数还包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
结合第七方面,或者第七方面的第一种可能的实现方式,在第七方面的第二种可能的实现方式中,所述目标网络具体为通过所述第一网络对第二网络进行训练,并通过第三网络对训练得到的网络进一步进行训练之后的网络,其中,所述第三网络的深度大于所述第一网络的深度。
在该可能的实现方式中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
第八方面,本申请实施例提供一种图像检测装置,该装置包括:
获取单元,用于获取目标网络,其中,所述目标网络为通过多个网络迭代对第二网络进行训练得到的网络,所述多个网络均为分类网络,所述多个网络至少包括第一网络和第三网络,所述第三网络用于在所述第一网络对第二网络进行训练得到中间网络后对所述中间网络进行训练,其中,所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度;
识别单元,用于通过所述目标网络识别图像中的内容。
在上述方法中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
结合第八方面,在第八方面的第一种可能的实现方式中,所述第一网络对第二网络进行训练时用到的参数包括特征损失,其中,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息。
在该可能的实现方式中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
结合第八方面的第一种可能的实现方式,在第八方面的第二种可能的实现方式中,所述第一网络对第二网络进行训练时用到的参数包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
结合第七方面,或者第八方面,或者第七方面的任一种可能的实现方式,或者第八方面的任一种可能的实现方式,在又一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
结合第七方面,或者第八方面,或者第七方面的任一种可能的实现方式,或者第八方面的任一种可能的实现方式,在又一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
结合第七方面,或者第八方面,或者第七方面的任一种可能的实现方式,或者第八方面的任一种可能的实现方式,在又一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
结合第七方面,或者第八方面,或者第七方面的任一种可能的实现方式,或者第八方面的任一种可能的实现方式,在又一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000101
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000102
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000103
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000111
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000112
表示基于
Figure BDA0002493947430000113
和ym得到的交叉熵损失,
Figure BDA0002493947430000114
表示基于
Figure BDA0002493947430000115
Figure BDA0002493947430000116
得到的二值交叉熵损失,β为预设的权重平衡因子。
结合第七方面,或者第八方面,或者第七方面的任一种可能的实现方式,或者第八方面的任一种可能的实现方式,在又一种可能的实现方式中:训练所述第二网络用到的参数还包括所述第二网络的回归损失和RPN损失,其中,所述第二网络的回归损失和RPN损失为根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值确定的。
结合第七方面,或者第八方面,或者第七方面的任一种可能的实现方式,或者第八方面的任一种可能的实现方式,在又一种可能的实现方式中,所述获取单元具体用于:
接收模型训练设备发送的目标网络,其中所述模型训练设备用于训练得到所述目标网络。
第九方面,本申请实施例提供一种模型训练设备,该模型训练设备包括存储器和处理器,该存储器用于存储计算机程序,该处理器用于调用该计算机程序来实现上述第一方面,或者第二方面,或者第一方面的任一种可能的实现方式,或者第二方面的任一种可能的实现方式所描述的方法。
第十方面,本申请实施例提供一种模型使用设备,该模型使用设备包括存储器和处理器,该存储器用于存储计算机程序,该处理器用于调用该计算机程序来实现上述第三方面,或者第四方面,或者第三方面的任一种可能的实现方式,或者第四方面的任一种可能的实现方式所描述的方法。
第十一方面,本申请实施例提供一种计算机可读存储介质,所述计算机可读存储介质用于存储计算机程序,当所述计算机程序在处理器上运行时,实现上述第一方面,或者第二方面,或者第三方面,或者第四方面,或者其中任一个方面的任一种可能的实现方式所描述的方法。
附图说明
以下对本申请实施例用到的附图进行介绍。
图1是本申请实施例提供的一种图像检测的场景示意图;
图2是本申请实施例提供的一种模型蒸馏的场景示意图;
图3是本申请实施例提供的又一种模型蒸馏的场景示意图;
图4是本申请实施例提供的又一种模型蒸馏的场景示意图;
图5是本申请实施例提供的又一种模型蒸馏的场景示意图;
图6是本申请实施例提供的一种模型训练的架构示意图;
图7是本申请实施例提供的又一种模型蒸馏的场景示意图;
图8是本申请实施例提供的又一种图像检测的场景示意图;
图9是本申请实施例提供的又一种图像检测的场景示意图;
图10是本申请实施例提供的一种模型训练方法的流程示意图;
图11是本申请实施例提供的一种高斯掩膜的原理示意图;
图12是本申请实施例提供的又一种模型蒸馏的场景示意图;
图13是本申请实施例提供的一种模型训练装置的结构示意图;
图14是本申请实施例提供的又一种模型训练装置的结构示意图;
图15是本申请实施例提供的一种图像检测装置的结构示意图;
图16是本申请实施例提供的又一种图像检测装置的结构示意图;
图17是本申请实施例提供的一种模型训练设备的结构示意图;
图18是本申请实施例提供的又一种模型训练设备的结构示意图;
图19是本申请实施例提供的一种模型使用设备的结构示意图;
图20是本申请实施例提供的又一种模型使用设备的结构示意图。
具体实施方式
下面结合本申请实施例中的附图对本申请实施例进行描述。
在人工智能、计算机视觉等领域中,通常涉及图像检测,即对图像中的物体(或者说目标)进行识别,图像检测通常是通过神经网络来实现的,在面临如何训练出一个检测速度更快、检测结果更准确的神经网络来进行目标检测的技术问题时,可以采用模型蒸馏的方式来解决,如图2所示,具体是让大神经网络指导小神经网络的训练,或者说小神经网络模拟大神经网络的表现,从而将大神经网络检测结果更准确的性能赋予小神经网络,使得小神经网络既具备检测速度快,又具备检测结果更准确的优点。
下面例举几种通过大神经网络指导小神经网络训练的方案。
例如,如图3所示,通过大神经网络的特征提取层提取图像的特征信息,然后送到大神经网络的分类层产生分类软标签;以及通过小神经网络的特征提取层提取图像的特征信息,然后送到小神经网络的分类层产生分类软标签;接着通过两个网络产生的分类软标签确定分类损失;然后通过分类损失指导小神经网络的分类层的训练。然而,这种方案中大神经网络对小神经网络的指导不够全面(如忽略了对特征提取层的指导),也不够精细化,因此指导之后的效果并不理想。
再如,如图4所示,将大神经网络的特征提取层的信息传递给小神经网络,然后对具有大神经网络的特征提取层的信息的小神经网络进行压缩,得到一个新的更瘦、更深的小神经网络,图4中的WT为大神经网络(老师网络)的模型权重,不同层可以有不同的权重,例如,
Figure BDA0002493947430000121
等;图4中的WS为小神经网络(学生网络)的模型权重,不同层可以有不同的权重,例如,
Figure BDA0002493947430000122
等。然而,这种方案中需要构建新的小神经网络,且构建的新的小神经网络的层数变得更多,具有一定的复杂性。并且,这种方案中通过特征提取层提取的是整幅图中的特征,存在较多背景噪声,从而导致检测结果不理性。
再如,如图5所示,选择图像中目标区域附近的区域作为蒸馏区域,让小神经网络学习大神经网络在该蒸馏区域的特征提取层表达,然而蒸馏区域中依然存在很多无用的背景噪声,因此学习的特征提取层表达依旧不理想;并且,小神经网络仅对特征提取层的表达进行了学习,因此性能提升有限。
鉴于大神经网络指导小神经网络训练依旧存在较多局限性,本申请实施例下面进一步提供相关的架构、设备、及方法来进一步提升大神经网络指导小神经网络训练的效果。
请参数图6,图6是本申请实施例提供的一种模型训练的架构示意图,该架构包括模型训练设备601和一个或多个模型使用设备602,该模型训练设备601与该模型使用设备602之间通过有线或者无线的方式进行通信,因此该模型训练设备601可以将训练出的用于预测图像中的目标物体的模型(或者网络)发送给模型使用设备602;相应地,模型使用设备602通过接收到的模型来预测待预测图像中的目标物体。
可选的,该模型使用设备602可以将基于该模型预测的结果反馈给上述模型训练设备601,使得模型训练设备601可以进一步基于该模型使用设备602的预测结果对模型做进一步的训练;重新训练好的模型可以发送给模型使用设备602对原来的模型进行更新。
该模型训练设备601可以为具有较强计算能力的设备,例如,一个服务器,或者由多个服务器组成的一个服务器集群。该模型训练设备601中可以包括很多神经网络,其中层比较多的神经网络相对于层比较少的神经网络而言可以称为大神经网络,层比较少的神经网络相对于层比较多的神经网络而言可以称为小神经网络,即第一网络的深度大于第二网络的深度。
如图7所示,该模型训练设备601中包括第一网络701和第二网络702,该第一网络701可以为比第二网络702大的神经网络,第一网络701和第二网络702均包括特征提取层(也可以称为特征层)和分类层(也可以称分类器,或者分类头),其中,特征提取层用于提取图像中的特征信息,分类层用于基于提取的特征信息对图像中的目标物体进行分类。
本申请实施例中,可以将第一网络701作为老师网络,将第二网络702作为学生网络,由第一网络701来指导第二网络702的训练,该过程可以称为蒸馏。本申请实施例中,该第一网络701指导第二网络702的思路包括如下三个技术点:
1、通过第一网络701的特征提取层和第二网络702的特征提取层分别提取特征信息,然后通过高斯掩膜突出将这两个网络的特征提取层提取的特征信息中关于目标物体的特征信息;再通过第一网络提取的关于目标物体的特征信息和第二网络提取的关于目标物体的特征信息确定特征损失,之后通过该特征损失指导第二网络702的特征提取层的训练。
2、第一网络701与第二网络702选取同样的区域提议集合,例如,通过共享区域提议网络(Region proposal network,RPN)的方式来使得该第一网络和第二网络均具有区域提议集合,因此第一网络701和第二网络702可以基于同样的区域提议生成软标签,然后基于第一网络701生成的软标签与第二网络702生成的软标签得到二值交叉熵损失(BCEloss),之后通过二值交叉熵损失(BCEloss)来指导第二网络702的分类层的训练。
3、采用渐进式蒸馏的方式指导第二网络702的训练。例如,假若上述第一网络701为一个101层(res101)的神经网络,第二网络702为一个50层(res50)的神经网络,那么基于上述第1和/或第2项的技术点通过第一网络701对第二网络702进行训练得到目标神经网络(可以标记为res101-50)之后,进一步通过第三神经网络对目标神经网络(res101-50)进行训练,通过第三神经网络对目标神经网络进行训练的原理,与通过第一网络701对第二网络702进行训练的原理相同,此处不再赘述。这里的第三神经网络是一个比第一网络701更大的神经网络,例如,该第三神经网络为一个152层(res152)的神经网络。
以上第1项、第2项、第3项的实现将在后续的方法实施例中做更详细的阐述。
该模型使用设备602为需要对图像进行识别(或者说检测)的设备,例如,手持设备(例如,手机、平板电脑、掌上电脑等)、车载设备(例如,汽车、自行车、电动车、飞机、船舶等)、可穿戴设备(例如智能手表(如iWatch等)、智能手环、计步器等)、智能家居设备(例如,冰箱、电视、空调、电表等)、智能机器人、车间设备,等等。
下面分别以模型使用设备602为汽车和手机为例进行举例说明。
汽车实现无人驾驶或者电脑驾驶,是目前非常流行的课题。随着经济的发展,全球汽车数量的不断增加,道路拥挤以及驾驶事故给人民财产和社会财产造成了很大的损失。人为因素是造成交通事故的主要因素,如何降低人为的失误,智能的避障以及合理的规划是提高驾驶安全系数的重要课题。自动驾驶的出现,给这一切带来了可能,它不需要人类操作即能感知周围的环境并导航。目前,全球各大公司都开始关注和开发自动驾驶系统,如谷歌、特斯拉、百度等。自动驾驶技术已成为各国争抢的战略制高点。由于相机设备具有价格便宜、使用方便等优势,构建以视觉感知为主的感知系统是很多公司的研发方向。如图8所示,是自动驾驶过程中车载相机采集的路面场景,汽车上的视觉感知系统充当人类的眼睛(即计算机视觉),通过检测网络(例如,后续方法实施例中提及的目标网络、第一网络、第二网络、第三网络、第四网络等,该检测网络也可以称作分类网络、或者检测模型,或者检测模块等)自动对相机采集的图像进行检测,从而确定汽车周围的物体和位置(即检测图像中的目标物体)。
例如,汽车通过检测网络识别图像中的物体,如果发现图像中有人,且距离汽车较近,那么汽车可以控制减速,或者控制停车,以避免造成人员伤亡;如果发现图像中有其他汽车,那么可以适当控制汽车的行驶速度,避免追尾;如果发现图像中有物体正在快速撞向汽车,那么可以控制器汽车通过移位或者变道等方式进行避让。
再如,汽车通过检测网络识别图像中的物体,如果发现路面有交通路线(如双黄线、单黄线、车道分界线等),那么可以对汽车的行驶状态进行预判,如果预判发现汽车可能会压线,那么可对汽车进行相应控制以避免压线;当然也可以在识别车道分界线及其位置的情况下,据此信息来决策如何进行变道,其余交通线路的控制依次类推。
再如,汽车通过检测网络识别图像中的物体,然后将识别出的某个物体作为标的来测算汽车的行驶速度、加速度、转弯角度等信息。
手机原本只是被作为一种通讯工具,方便人们的沟通。随着全球经济的发展,人们生活质量的提高,大家对手机的体验感以及性能的追求也越来越高。除了娱乐、导航、购物、拍照之外,检测和识别功能也受到了很大的关注。目前检测图像中的目标物体的识别技术已经在很多手机APP中应用,包括美图秀秀、魔漫相机、神拍手、Camera360、支付宝扫脸支付等等。开发者只需要调用授权的人脸检测、人脸关键点检测和人脸分析的移动端SDK包,就可以自动识别出照片、视频中的人脸身份(即检测图像中的目标物体)。如图9所示,是手机相机采集到的人脸图片。手机通过检测网络(例如,后续方法实施例中提及的目标网络,该检测网络也可以称作检测模型,或者检测模块,或者检测识别系统等)对图片中的物体进行定位和识别,找到人脸(还会获取人脸位置),同时判断图像中的人脸是否是某个特定人的人脸(即计算机视角),如果是就可以进行下一步的操作,如移动支付、人脸登录等多种功能。
设备(如汽车、手机等)通过计算机视角感知周围的环境信息,从而基于该环境信息进行相应的智能控制,基于计算机视角完成的智能控制是人工智能的一种实现。
请参见图10,图10是本申请实施例提供的一种模型训练方法的流程示意图,该方法可以基于图6所示的模型训练系统来实现,该方法包括但不限于如下步骤:
步骤S1001:模型训练设备分别通过第一网络的特征提取层和第二网络的特征提取层提取目标图像中的特征信息。
具体地,第一网络用于作为老师网络,第二网络用于作为学生网络,在第一网络指导第二网络训练的过程中,即模型蒸馏过程中,第一网络的特征提取层和第二网络的特征提取层针对同样的图像提取特征信息,为了方便描述,可以称该同样的图像可以称之为目标图像。
可选的,第一网络的层大于第二网络的层,例如,该第一网络可以为一个101层(res101)的神经网络,第二网络可以为一个50层(res50)的神经网络。
本申请实施例中的特征信息可以通过向量来表示,或者其他机器可识别的方式来表示。
步骤S1002:模型训练设备通过高斯掩膜突出第一特征信息中关于目标物体的特征。
本申请发明人发现,在基于神经网络的目标物体检测(即识别)过程中,检测性能的增益很大程度上来自于特征提取层(backbone层)提取的特征信息,因此对特征提取层的模拟是模型训练的重要环节。本申请实施例中,引入高斯掩膜实际就是突出目标图像中目标物体的特征,镇压目标物体以外的背景的特征的过程;实际也是突出对目标物体的响应,弱化边缘信息的过程。需要说明的是,采用高斯掩膜的方式不仅可以抑制目标物体所在的矩形框(一般为能够框柱目标物体的最小矩形框)外的背景的特征,还可以抑制矩形框内除目标物体以外的背景的特征,因此最大限度地突出了目标物体的特征。
如图11所示,通过高斯掩膜可以突出目标图像中的人物(即目标物体),弱化人物以外的背景的特征,表现在高斯掩膜的模型数据上就是目标物体对应的特征信息在坐标中形成一个较大的突起,背景的信息在坐标中的高度接近于零,近似于处在一个高度为零的平面上。
为了便于理解,下面例举针对目标图像的高斯掩膜定义,具体如公式1-1所示:
Figure BDA0002493947430000151
公式1-1中,(x,y)为目标图像中的像素点的坐标,B为该目标图像中的目标物体的一个正例区域提议,该正例区域提议B的几何规格为w×h,正例区域提议B的中心点坐标为(x0,y0),σx和σy分别为x轴和y轴上的衰减因子,可选的,为了方便起见,可以设置σx=σy,这个高斯掩膜只对目标真值框有效,框外的背景全部滤掉。当存在多个关于目标物体的正例区域提议时,目标物体中的像素点(x,y)可能存在多个高斯掩膜值(分别对应不同的正例区域提议),这时可选择该多个高斯掩膜值中的最大值作为该像素点(x,y)的高斯掩膜值M(x,y),高斯掩膜值M(x,y)可以通过公式1-2表示:
Figure BDA0002493947430000152
公式1-2中,Np为该目标图像中的目标物体的正例区域提议的数量,M1(x,y)是其中第一个正例区域提议中点(x,y)的高斯掩膜的值,M2(x,y)是其中第二个正例区域提议中点(x,y)的高斯掩膜的值,
Figure BDA0002493947430000153
是其中第Np个正例区域提议中点(x,y)的高斯掩膜的值,其余依次类推。从公式1-2以看出,目标物体中某个像素点(x,y)的高斯掩膜的值取多个值中的最大值。
为了便于描述,称通过第一网络的特征提取层提取到的所述目标图像中的特征信息为第一特征信息,称通过高斯掩膜突出的第一特征信息中关于目标物体的特征为第一局部特征。
步骤S1003:模型训练设备通过高斯掩膜突出第二特征信息中关于所述目标物体的特征。
为了便于描述,称通过第二网络的特征提取层提取到的所述目标图像中的特征信息为第二特征信息,称通过高斯掩膜突出的第二特征信息中关于目标物体的特征为第二局部特征。
步骤S1004:模型训练设备通过所述第一局部特征和所述第二局部特征确定特征损失。
本申请实施例中,上述第一局部特征为第一网络得到的针对目标图像中目标物体的特征,第二局部特征为第二网络得到的针对目标图像中目标物体的特征,第一局部特征与第二局部特征之间的差异,能够反映第一网络的特征提取层与第二网络的特征提取层之间的差异。本申请实施例中的特征损失(也称蒸馏损失)能够体现第二局部特征相较第一局部特征的差异。
可选的,第一局部特征可以表示为
Figure BDA0002493947430000154
第二局部特征可以表示为
Figure BDA0002493947430000155
可以通过公式1-3来计算该特征损失Lb
Figure BDA0002493947430000156
本申请实施例中,
Figure BDA0002493947430000157
引入A是为了实现归一化操作;其中目标图像的规格为WxH,Mij表示目标图像中一个像素点的高斯掩膜的值,i可以依此从1到W取值,j可以依此从1到H取值,目标图像中的任意一个像素点(i,j)的高斯掩膜的值Mij可以通过上述公式1-1和公式1-2计算得到,此处不再赘述。另外,
Figure BDA0002493947430000161
表示第二网络提取的像素点(i,j)的特征,
Figure BDA0002493947430000162
表示第一网络提取的像素点(i,j)的特征;C代表第一网络、第二网络提取目标图像中的特征信息时特征图的通道数。
步骤S1005:模型训练设备通过第一网络的分类层生成区域提议集合中的目标区域提议的分类预测值。
具体地,可以采用相应的算法或者策略来保证第一网络和第二网络均具有该区域提议集合,例如,第一网络和第二网络共享RPN来使得第一网络和第二网络均具有该区域提议集合。例如,共享的RPN中可以包括2000个区域提议,而该区域提议集合包括该2000个区域提议中的512个区域提议,可以在第一网络和第二网络中配置相同的检测器,使得第一网络和第二网络能够从2000个共享的区域提议中提取出同样的512个区域提议,即区域提议集合。图12示意了一种通过共享RPN以及配置同样的检测器来使得第一网络和第二网络选出同样的区域提议(proposal)的流程示意图,选出的全部区域提议统称为区域提议集合。
所述RPN为所述第一网络和所述第二网络共享的RPN,既可以是第二网络共享给第一网络的,也可以是第一网络共享给第二网络的,或者其他方式共享的。
在一种可选的方案中,目标区域提议为所述区域提议集合中的全部区域提议,即包括目标物体的正例区域提议和目标物体的负例区域提议。可选的,区域提议集合中哪些是正例区域提议,哪些是负例区域提议,可以是人为预先标记好的,也可以是机器自动标记好的;通常的划分标准是,如果哪个区域提议与目标物体所在的矩形框(通常为框柱该目标物体的最小矩形框)的重合度超过设定的参考阈值(例如,可设置为50%,或其他值),则可以将该区域提议归类为目标物体的正例区域提议,否则将该区域提议归类为目标物体的负例区域提议。
在又一种可选的方案中,目标区域提议为区域提议集合中属于所述目标物体的正例区域提议。
本申请实施例中,可以称通过第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值为第一分类预测值以方便后续描述。
步骤S1006:模型训练设备通过第二网络的分类层生成区域提议集合中的目标区域提议的分类预测值。
本申请实施例中,可以称通过第二网络的分类层生成的区域提议集合中的目标区域提议的分类预测值为第二分类预测值以方便后续描述。
无论是上述第一分类预测值,还是上述第二分类预测值,其均用于表征对相应区域提议的分类倾向,或者说概率,例如,第一网络的分类标签生成的区域提议1的分类预测值表征了该区域提议1内的物体被分类为人的概率为0.8,被分类为树的概率为0.3,被分类为汽车的概率为0.1。需要说明的是,不同的网络中的分类层针对同一个区域提议分类得到的分类预测值可能不同,因为不同网络的模型参数一般不同,它们的预测能力一般存在差异。
步骤S1007:模型训练设备根据第一分类预测值和第二分类预测值确定分类损失。
具体地,本申请实施例正是通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失,实际上是将第一预测值作为软标签来确定第二网络的分类层的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此模型训练效果更好。
在一种可选的方案中,该分类损失满足公式1-4所示的关系:
Figure BDA0002493947430000171
在公式1-4中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000172
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000173
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000174
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000175
表示基于
Figure BDA0002493947430000176
和ym得到的交叉熵损失,
Figure BDA0002493947430000177
表示基于
Figure BDA0002493947430000178
Figure BDA0002493947430000179
得到的二值交叉熵损失,β为预设的权重平衡因子。
步骤S1008:模型训练设备根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络。
本申请实施例中提到的根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络的含义是:目标网络是通过对第二网络进行训练得到的,训练过程中用到的参数包括但不限于该特征损失和该分类损失,即还可能用到这两项以外的其他参数;并且训练过程中可能仅用到了基于第一网络得到的该特征损失和该分类损失(即基于第一网蒸馏第二网络得到目标网络),也可能不仅用到了基于第一网络得到的该特征损失和该分类损失,还用到了基于其他网络(一个或多个)得到信息(即基于第一网络和其他网络蒸馏第二网络得到目标网络)。
可选的,在根据所述特征损失和所述分类损失训练所述第二网络的过程中,可以基于所述特征损失Lb和所述分类损失Lcls确定总损失L,然后通过总损失来训练该第二网络。可选的,也可以通过特征损失训练第二网络中的一部分模型参数(例如,特征提取层的模型参数),通过分类损失训练第二网络中的又一部分模型参数(例如,分类层的模型参数)。当然,还可以通过其他方式来使用特征损失和分类损失训练所述第二网络。
针对通过总损失L来训练第二网络的情况,下面例举两种可选的计算总损失的案例。
案例1,通过公式1-5计算总损失L。
L=δLb+Lcls 公式1-5
在公式1-5中,δ是预设或者预先训练出的权重平衡因子。
案例2,根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;即第二网络在不依赖第一网络的情况下,训练得到回归损失Lreg和RPN损失Lrpn,然后结合回归损失Lreq、RPN损失Lrpn、上述特征损失Lb和分类损失Lcls得到总损失L,可选的,总损失L的计算方式如公式1-6所示。
L=δLb+Lcls+Lreg+Lrpn 公式1-6
本申请实施例中,确定分类损失和特征损失的先后顺序不做限定,可以同时执行,也可以确定分类损失的流程先执行,还可以确定特征损失的流程先执行。
在一种可选的方案中,模型训练设备根据所述特征损失和所述分类损失训练所述第二网络,得到还不是最终的目标网络,而是一个中间网络,后续还要通过其他网络(如第三网络)再对该中间网络进行训练,这个过程可以看做是针对第二网络采用渐进式蒸馏。原理如下:在基于上述第一网络对第二网络进行蒸馏(即根据第一网络和第二网络得到特征损失和分类损失,然后根据特征损失和分类损失对第二网络进行训练)得到中间网络后,还可以进一步通过一个比第一网络的层更多的第三网络对该中间网络再次进行蒸馏,通过第三网络对中间网络进行蒸馏的原理,与通过第一网络对第二网络进行蒸馏的原理相同;后续还可以逐步使用层数更多的网络对新训练出网络做进一步蒸馏,直至对第二网络的蒸馏达到预期目标,从而得到目标网络。
例如,假若上述第一网络701为一个101层(res101)的神经网络,第二网络702为一个50层(res50)的神经网络,那么基于上述第1和/或第2项的技术点通过第一网络701对第二网络702进行蒸馏得到中间神经网络(可以标记为res101-50)之后,进一步通过第三神经网络对中间神经网络(res101-50)进行蒸馏(即依次通过第一网络701和第三网络对第二网络702进行蒸馏),通过第三神经网络对中间神经网络进行蒸馏的原理,与通过第一网络701对第二网络702进行蒸馏的原理相同,此处不再赘述。第三神经网络是一个比第一网络701更大的神经网络,例如,一个152层(res152)的神经网络。
需要说明的是,在图10所示的方法实施例中,可以采用上述例举的三个技术点中的一个,或者两个,或者三个。例如,可以采用渐进式蒸馏的技术点(上述三个技术点中的第1个),但是针对具体如何提取特征,以及具体如何使用区域提议不做特殊限定(即对是否使用上述三个技术点中的第2个,第3个不做限定)。
步骤S1009:模型训练设备向模型使用设备发送目标网络。
步骤S1010:模型使用设备接收模型训练设备发送的目标网络。
具体地,该模型使用设备接收到该目标网络后,通过该目标网络来预测(或者说检测,或者说估计)图像中的内容(即识别图像中的目标),例如,识别图像中的是否存在人脸,当存在时人脸在图像中的位置具体是什么;或者识别图像中是否存在道路障碍物,当存在时障碍物在图像中的位置是什么,等等。具体使用场景可参照图6所示的架构中关于模型使用设备602的介绍。
为了验证上述实施例的效果,本申请发明人在两个标准的检测数据集上进行了验证,这两个检测数据集为:COCO2017数据集、BDD100k数据集。其中COCO数据集包含了80个物体类别,11万张训练图片和5000张验证图片,BDD100k数据集包含了10个类别,一共具有10万张图片。对这两个数据集,都采用coco的评估标准进行评估,即类别平均准确度(mAP)。
表1展示了不同的蒸馏策略方案,其中的层数分别为res18,res50,res101,res152的网络(或说模型)都已经在COCO数据集上进行了预训练,只需采用上述实施例进行蒸馏即可。
表1不同的蒸馏策略
策略编号 老师网络 学生网络 蒸馏后的学生网络
1 res50 res18 res50-18
2 res101 res18 res101-18
3 res101 res50-18 res101-50-18
4 res152 res101-50-18 res152-101-50-18
在表1中,res50-18表示通过50层的网络对18层的网络进行蒸馏后得到的网络;res101-18表示通过101层的网络对18层的网络进行蒸馏后得到的网络;res101-50-18表示通过101层的网络对蒸馏后得到的网络res101-18进一步蒸馏所得到的网络;res152-101-50-18表示通过152层的网络对蒸馏后得到的网络res101-50-18进一步蒸馏所得到的网络。可选的,网络res50可以认为是上述第一网络,网络res18可以认为是上述第二网络,网络101可以认为是上述第三网络,网络res152可以认为是第四网络,该第四网络是一个比第三网络的层数更多的网络。在第二网络依次经过第一网络、第三网络、第四网络蒸馏之后得到的网络res152-101-50-18就可以发送给上述模型使用设备来对图像中的目标进行检测。
表2展示了不同网络在COCO数据集上的评估结果,其中,网络Res50-18检测的准确度相比于原始网络res18有了明显提高,提升了2.8个百分点,网络res101-18检测的准确度比网络res18的准确度提升了3.2个点,采用渐进式的蒸馏方法得到的网络res101-50-18比单次蒸馏得到的网络re50-18有了进一步的提升,值得一提的是,网络res152-101-50-18检测的准确度相比于网络res18提升得特别多,有4.4个百分点,并且蒸馏后的mAP达到了0.366,已经超越了网络res50检测的准确度0.364。换言之,虽然网络res18比网络res50具有更少的网络层数,且原始mAP相差的很多,有4.2个百分点,但是本申请实施例的方法通过对网络res18进行渐进式蒸馏,使得蒸馏后的网络res18超越了网络res50的性能。
表2在COCO数据集上不同网络的性能评估结果
网络 MAP AP50 AP75 Aps(小) Apm(中) Apl(大)
res18 0.322 0.534 0.34 0.183 0.354 0.411
res50 0.364 0.581 0.393 0.216 0.398 0.468
res101 0.385 0.602 0.417 0.225 0.429 0.492
res152 0.409 0.622 0.448 0.241 0.453 0.532
res50-18 0.35 0.56 0.373 0.187 0.384 0.459
res101-18 0.354 0.563 0.38 0.186 0.387 0.473
res101-50-18 0.358 0.567 0.387 0.184 0.392 0.479
res152-101-50-18 0.366 0.574 0.396 0.184 0.399 0.5
在表2中,MAP为平均精度均值,AP50为检测评价函数(Intersection over Union,IOU)大于0.5时的精度均值,AP75为IOU大于0.75时的精度均值,Aps为小物体的精度均值,Apm为中物体的精度均值,Apl为大物体的精度均值。
如表3所示。分别采用网络res50和网络res101作为老师网络,网络res18作为学生网络,+1表示只采用上述第1项技术点(即采用高斯掩膜突出目标物体),+2表示采用上述第2项技术点(即选取同样的区域提议集合),+3表示采用上述第3项技术点(即渐进式蒸馏)。其中网络res18(+1)和网络res18(+1+2)中的学生网络res18是未在COCO数据集上进行预训练的,res18(+1+2+3)的学生网络res18是在COCO上进行了预训练,有利于拉近它与老师网络之间的差异距离,相当于渐进式蒸馏的一种方案。可以看到,随着不断地改进,蒸馏的效果在逐步提升,也证明了上述三项技术点的有效性。
表3每个技术点的实验效果
res18 res18(+1) res18(+1+2) res18(+1+2+3)
老师网络res50 0.322 0.342 0.344 0.35
老师网络res101 0.322 0.347 0.349 0.354
为了验证本申请在不同数据上的适用性,在BDD100k数据集上也进行了对比实验,结果如表4所示。原始网络res18(作为学生网络)与网络res50(作为老师网络)之间有2.1个百分点的mAP(准确度)差距,采用本申请实施例的方法进行蒸馏后,蒸馏后得到的网络res50-18比原始的网络res18的检测mAP(准确度)提升了1.5个百分点,与老师网络res50只差0.6个百分点,弥补了近75%的mAP(准确度)差距,效果非常明显。
表4BDD100k数据集上的性能评估结果
网络 mAP AP50 AP75 Aps Apm Apl
res18 0.321 0.619 0.289 0.159 0.374 0.513
res50 0.342 0.644 0.314 0.171 0.399 0.546
res50-18 0.336 0.636 0.309 0.166 0.393 0.539
在图10所描述的方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
进一步地,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
进一步地,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
上述详细阐述了本申请实施例的方法,下面提供了本申请实施例的装置。
请参见图13,图13是本申请实施例提供的一种模型训练装置130的结构示意图,该模型训练装置130可以为上述方法实施例中的模型训练设备或者该模型训练设备中的器件,该模型训练装置130可以包括特征提取单元1301、第一优化单元1302、第二优化单元1303、第一确定单元1304和权重调整单元1305,其中,各个单元的详细描述如下。
特征提取单元1301,用于通过第一网络的特征提取层提取目标图像中的第一特征信息;
所述特征提取单元1301,还用于通过第二网络的特征提取层提取目标图像中的第二特征信息,其中,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
第一优化单元1302,用于通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
第二优化单元1303,用于通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
第一确定单元1304,用于通过所述第一局部特征和所述第二局部特征确定特征损失;
权重调整单元1305,用于根据所述特征损失训练所述第二网络,得到目标网络。
在上述方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述装置还包括:
第一生成单元,用于通过所述第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
第二生成单元,用于通过所述第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
第二确定单元,用于根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述权重调整单元具体用于:根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述在根据所述特征损失训练所述第二网络,得到目标网络,所述权重调整单元具体用于:
根据所述特征损失训练所述第二网络;
通过第三网络对经过训练后的所述第二网络进行训练,得到目标网络,其中,所述第三网络的深度大于所述第一网络的深度。
在该可能的实现方式中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
在一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
在一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
在一种可能的实现方式,或者第六方面的任一种可能的实现方式,在又一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
在一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000211
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000212
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000213
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000214
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000215
表示基于
Figure BDA0002493947430000216
Figure BDA0002493947430000217
得到的交叉熵损失,
Figure BDA0002493947430000218
表示基于
Figure BDA0002493947430000219
Figure BDA00024939474300002110
得到的二值交叉熵损失,β为预设的权重平衡因子。
在一种可能的实现方式中,所述装置还包括:
第三确定单元,用于根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;
所述权重调整单元具体用于:根据所述特征损失、所述分类损失、所述回归损失和所述RPN损失训练所述第二网络,得到目标网络。
在一种可能的实现方式中,所述装置还包括:
发送单元,用于在所述权重调整单元根据所述特征损失训练所述第二网络,得到目标网络之后,向模型使用设备发送所述目标网络,其中所述目标网络用于预测图像中的内容。
需要说明的是,各个单元的实现及有益效果还可以对应参照图10所示的方法实施例的相应描述。
请参见图14,图14是本申请实施例提供的一种模型训练装置140的结构示意图,该模型训练装置140可以为上述方法实施例中的模型训练设备或者该模型训练设备中的器件,该模型训练装置140可以包括第一训练单元1401和第二训练单元1402,其中,各个单元的详细描述如下。
第一训练单元1401,用于基于第一网络训练第二网络得到中间网络;
第二训练单元1402,用于基于第三网络训练所述中间网络,得到目标网络,其中,所述第一网络、所述第二网络和所述第三网络均为分类网络,且所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度。
在上述方法中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
在一种可能的实现方式中,所述基于第一网络训练第二网络得到中间网络包括:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到所述中间网络。
在该可能的实现方式中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述装置还包括:
第一生成单元,用于通过第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
第二生成单元,用于通过第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
第二确定单元,用于根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述根据所述特征损失训练所述第二网络,得到所述中间网络,具体为:
根据所述特征损失和所述分类损失训练所述第二网络,得到所述中间网络。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
在一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
在一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
在一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000231
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000232
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000233
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000234
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000235
表示基于
Figure BDA0002493947430000236
和ym得到的交叉熵损失,
Figure BDA0002493947430000237
表示基于
Figure BDA0002493947430000238
Figure BDA0002493947430000239
得到的二值交叉熵损失,β为预设的权重平衡因子。
在一种可能的实现方式中,所述装置还包括:
第三确定单元,用于根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;
所述权重调整单元具体用于:根据所述特征损失、所述分类损失、所述回归损失和所述RPN损失训练所述第二网络,得到目标网络。
在一种可能的实现方式中,该装置还包括:
发送单元,用于在所述权重调整单元根据所述特征损失训练所述第二网络,得到目标网络之后,向模型使用设备发送所述目标网络,其中所述目标网络用于预测图像中的内容。
需要说明的是,各个单元的实现及有益效果还可以对应参照图10所示的方法实施例的相应描述。
请参见图15,图15是本申请实施例提供的一种图像检测装置150的结构示意图,该图像检测装置150可以为上述方法实施例中的模型使用设备或者该模型使用设备中的器件,该图像检测装置150可以包括获取单元1501和识别单元1502,其中,各个单元的详细描述如下。
获取单元1501,用于获取目标网络,其中,所述目标网络为通过第一网络对第二网络进行训练后得到的网络,通过所述第一网络训练所述第二网络用到的参数包括特征损失,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
识别单元1502,用于通过所述目标网络识别图像中的内容。
在上述方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,训练所述第二网络用到的参数还包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述目标网络具体为通过所述第一网络对第二网络进行训练,并通过第三网络对训练得到的网络进一步进行训练之后的网络,其中,所述第三网络的深度大于所述第一网络的深度。
在该可能的实现方式中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
在一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
在一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
在一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
在一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000241
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000242
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000243
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000244
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000251
表示基于
Figure BDA0002493947430000252
和ym得到的交叉熵损失,
Figure BDA0002493947430000253
表示基于
Figure BDA0002493947430000254
Figure BDA0002493947430000255
得到的二值交叉熵损失,β为预设的权重平衡因子。
在一种可能的实现方式中:训练所述第二网络用到的参数还包括所述第二网络的回归损失和RPN损失,其中,所述第二网络的回归损失和RPN损失为根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值确定的。
在一种可能的实现方式中,所述获取单元具体用于:
接收模型训练设备发送的目标网络,其中所述模型训练设备用于训练得到所述目标网络。
需要说明的是,各个单元的实现及有益效果还可以对应参照图10所示的方法实施例的相应描述。
请参见图16,图16是本申请实施例提供的一种图像检测装置160的结构示意图,该图像检测装置160可以为上述方法实施例中的模型使用设备或者该模型使用设备中的器件,该图像检测装置160可以包括获取单元1601和识别单元1602,其中,各个单元的详细描述如下。
获取单元1601,用于获取目标网络,其中,所述目标网络为通过多个网络迭代对第二网络进行训练得到的网络,所述多个网络均为分类网络,所述多个网络至少包括第一网络和第三网络,所述第三网络用于在所述第一网络对第二网络进行训练得到中间网络后对所述中间网络进行训练,其中,所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度;
识别单元1602,用于通过所述目标网络识别图像中的内容。
在上述方法中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
在一种可能的实现方式中,所述第一网络对第二网络进行训练时用到的参数包括特征损失,其中,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息。
在该可能的实现方式中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述第一网络对第二网络进行训练时用到的参数包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式,或者第八方面的任一种可能的实现方式,在又一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
在一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
在一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
在一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000261
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000262
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000263
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000264
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000265
表示基于
Figure BDA0002493947430000266
和ym得到的交叉熵损失,
Figure BDA0002493947430000267
表示基于
Figure BDA0002493947430000268
Figure BDA0002493947430000269
得到的二值交叉熵损失,β为预设的权重平衡因子。
在一种可能的实现方式中:训练所述第二网络用到的参数还包括所述第二网络的回归损失和RPN损失,其中,所述第二网络的回归损失和RPN损失为根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值确定的。
在一种可能的实现方式中,所述获取单元具体用于:
接收模型训练设备发送的目标网络,其中所述模型训练设备用于训练得到所述目标网络。
需要说明的是,各个单元的实现及有益效果还可以对应参照图10所示的方法实施例的相应描述。
请参见图17,图17是本申请实施例提供的一种模型训练设备170的结构示意图,该模型训练设备170包括处理器1701、存储器1702和通信接口1703,所述处理器1701、存储器1702和通信接口1703通过总线相互连接。
存储器1702包括但不限于是随机存储记忆体(random access memory,RAM)、只读存储器(read-only memory,ROM)、可擦除可编程只读存储器(erasable programmableread only memory,EPROM)、或便携式只读存储器(compact disc read-only memory,CD-ROM),该存储器1702用于相关计算机程序及数据。通信接口1703用于接收和发送数据。
处理器1701可以是一个或多个中央处理器(central processing unit,CPU),在处理器1701是一个CPU的情况下,该CPU可以是单核CPU,也可以是多核CPU。
该模型训练设备170中的处理器1701用于读取所述存储器1702中存储的计算机程序代码,执行以下操作:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到目标网络。
在上述方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述处理器还用于:
通过所述第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
通过所述第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述根据所述特征损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,在根据所述特征损失训练所述第二网络,得到目标网络方面,所述处理器具体用于:
根据所述特征损失训练所述第二网络;
通过第三网络对经过训练后的所述第二网络进行训练,得到目标网络,其中,所述第三网络的深度大于所述第一网络的深度。
在该可能的实现方式中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
在一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
在一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
在一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
在一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000271
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000281
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000282
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000283
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000284
表示基于
Figure BDA0002493947430000285
和ym得到的交叉熵损失,
Figure BDA0002493947430000286
表示基于
Figure BDA0002493947430000287
Figure BDA0002493947430000288
得到的二值交叉熵损失,β为预设的权重平衡因子。
在一种可能的实现方式中,所述处理器还用于:
根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;
根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失、所述分类损失、所述回归损失和所述RPN损失训练所述第二网络,得到目标网络。
在一种可能的实现方式中,在根据所述特征损失训练所述第二网络,得到目标网络之后,所述处理器还用于:
通过通信接口1703向模型使用设备发送所述目标网络,其中所述目标网络用于预测图像中的内容。
需要说明的是,各个操作的实现还可以对应参照图10所示的方法实施例的相应描述。
请参见图18,图18是本申请实施例提供的一种模型训练设备180的结构示意图,该模型训练设备180包括处理器1801、存储器1802和通信接口1803,所述处理器1801、存储器1802和通信接口1803通过总线相互连接。
存储器1802包括但不限于是随机存储记忆体(random access memory,RAM)、只读存储器(read-only memory,ROM)、可擦除可编程只读存储器(erasable programmableread only memory,EPROM)、或便携式只读存储器(compact disc read-only memory,CD-ROM),该存储器1802用于相关计算机程序及数据。通信接口1803用于接收和发送数据。
处理器1801可以是一个或多个中央处理器(central processing unit,CPU),在处理器1801是一个CPU的情况下,该CPU可以是单核CPU,也可以是多核CPU。
该模型训练设备180中的处理器1801用于读取所述存储器1802中存储的计算机程序代码,执行以下操作:
基于第一网络训练第二网络得到中间网络;本申请各个实施例中,基于第一网络训练第二网络实质就是通过第一网络蒸馏第二网络,以及基于第三网络训练已被第一网络训练的第二网络实质就是通过第三网络蒸馏已被第一网络蒸馏的第二网络,此处统一说明。
基于第三网络训练所述中间网络,得到目标网络,其中,所述第一网络、所述第二网络和所述第三网络均为分类网络,且所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度。
在上述方法中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
在一种可能的实现方式中,在基于第一网络训练第二网络得到中间网络方面,所述处理器具体用于:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到所述中间网络。
在该可能的实现方式中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述处理器1801还用于:
通过第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
通过第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
根据所述第一分类预测值和所述第二分类预测值确定分类损失;
在根据所述特征损失训练所述第二网络,得到所述中间网络方面,所述处理器具体用于:
根据所述特征损失和所述分类损失训练所述第二网络,得到所述中间网络。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
在一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
在一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
在一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000291
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000292
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000293
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000294
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000295
表示基于
Figure BDA0002493947430000296
和ym得到的交叉熵损失,
Figure BDA0002493947430000297
表示基于
Figure BDA0002493947430000298
Figure BDA0002493947430000299
得到的二值交叉熵损失,β为预设的权重平衡因子。
在一种可能的实现方式中,所述方法还包括:
根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;
根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失、所述分类损失、所述回归损失和所述RPN损失训练所述第二网络,得到目标网络。
在一种可能的实现方式中,所述根据所述特征损失训练所述第二网络,得到目标网络之后,还包括:
通过通信接口1803向模型使用设备发送所述目标网络,其中所述目标网络用于预测图像中的内容。
需要说明的是,各个操作的实现还可以对应参照图10所示的方法实施例的相应描述。
请参见图19,图19是本申请实施例提供的一种模型使用设备190的结构示意图,该模型使用设备190也可以称为图像检测设备或者其他名称,该模型使用设备190包括处理器1901、存储器1902和通信接口1903,所述处理器1901、存储器1902和通信接口1903通过总线相互连接。
存储器1902包括但不限于是随机存储记忆体(random access memory,RAM)、只读存储器(read-only memory,ROM)、可擦除可编程只读存储器(erasable programmableread only memory,EPROM)、或便携式只读存储器(compact disc read-only memory,CD-ROM),该存储器1902用于相关计算机程序及数据。通信接口1903用于接收和发送数据。
处理器1901可以是一个或多个中央处理器(central processing unit,CPU),在处理器1901是一个CPU的情况下,该CPU可以是单核CPU,也可以是多核CPU。
该模型使用设备190中的处理器1901用于读取所述存储器1902中存储的计算机程序代码,执行以下操作:
获取目标网络,其中,所述目标网络为通过第一网络对第二网络进行训练后得到的网络,通过所述第一网络训练所述第二网络用到的参数包括特征损失,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
通过所述目标网络识别图像中的内容。
在上述方法中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,训练所述第二网络用到的参数还包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述目标网络具体为通过所述第一网络对第二网络进行训练,并通过第三网络对训练得到的网络进一步进行训练之后的网络,其中,所述第三网络的深度大于所述第一网络的深度。
在该可能的实现方式中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
在一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
在一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
在一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
在一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000311
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000312
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000313
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000314
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000315
表示基于
Figure BDA0002493947430000316
和ym得到的交叉熵损失,
Figure BDA0002493947430000317
表示基于
Figure BDA0002493947430000318
Figure BDA0002493947430000319
得到的二值交叉熵损失,β为预设的权重平衡因子。
在一种可能的实现方式中,训练所述第二网络用到的参数还包括所述第二网络的回归损失和RPN损失,其中,所述第二网络的回归损失和RPN损失为根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值确定的。
在一种可能的实现方式中,在获取目标网络方面,所述处理器具体用于:
通过通信接口1903接收模型训练设备发送的目标网络,其中所述模型训练设备用于训练得到所述目标网络。
需要说明的是,各个操作的实现还可以对应参照图10所示的方法实施例的相应描述。
请参见图20,图20是本申请实施例提供的一种模型使用设备200的结构示意图,该模型使用设备200也可以称为图像检测设备或者其他名称,该模型使用设备200包括处理器2001、存储器2002和通信接口2003,所述处理器2001、存储器2002和通信接口2003通过总线相互连接。
存储器2002包括但不限于是随机存储记忆体(random access memory,RAM)、只读存储器(read-only memory,ROM)、可擦除可编程只读存储器(erasable programmableread only memory,EPROM)、或便携式只读存储器(compact disc read-only memory,CD-ROM),该存储器2002用于相关计算机程序及数据。通信接口2003用于接收和发送数据。
处理器2001可以是一个或多个中央处理器(central processing unit,CPU),在处理器2001是一个CPU的情况下,该CPU可以是单核CPU,也可以是多核CPU。
该模型使用设备200中的处理器2001用于读取所述存储器2002中存储的计算机程序代码,执行以下操作:
获取目标网络,其中,所述目标网络为通过多个网络迭代对第二网络进行训练得到的网络,所述多个网络均为分类网络,所述多个网络至少包括第一网络和第三网络,所述第三网络用于在所述第一网络对第二网络进行训练得到中间网络后对所述中间网络进行训练,其中,所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度;
通过所述目标网络识别图像中的内容。
在上述方法中,在通过第一网络对第二网络进行训练之后,进一步使用层更多的第三网络对已训练的第二网络做进一步训练,能够稳定提升第二网络的性能。
在一种可能的实现方式中,所述第一网络对第二网络进行训练时用到的参数包括特征损失,其中,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息。
在该可能的实现方式中,通过高斯掩膜突出第一网络提取的特征信息中关于目标物体的局部特征,以及突出第二网络提取的特征信息中关于目标物体的局部特征,然后根据两网络中关于目标物体的局部特征确定特征损失,后续基于该特征损失对第二网络进行训练。通过高斯掩膜滤掉了图像的背景噪声(包括目标物体的方框外的背景噪声和方框内的背景噪声),在此基础上得到的特征损失更能够反映出第二网络与第一网络的差异,因此基于该特征损失对第二网络进行训练能够使得第二网络对特征的表达更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述第一网络对第二网络进行训练时用到的参数包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
在该可能的实现方式中,通过选取同样的区域提议集合的方式,使得第一网络的分类层和第二网络的分类层基于同样的区域提议来生成分类预测值,在基于的区域提议相同的情况下,两个网络生成的预测值的差异一般就是由于这两个网络的模型参数的差异导致的,因此本申请实施例基于第一预测值与第二预测值的差异确定用于训练第二网络的分类损失;通过这种方式可以最大程度得第二网络相对于第一网络的损失,因此基于分类损失对第二模型进行训练能够使得第二网络的分类结果更趋近于第一网络,模型蒸馏效果很好。
在一种可能的实现方式中,所述第一网络和所述第二网络通过共享区域提议网络RPN的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
在一种可能的实现方式中,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
在一种可能的实现方式中,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
在一种可能的实现方式中,所述分类损失Lcls满足如下关系:
Figure BDA0002493947430000331
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure BDA0002493947430000332
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure BDA0002493947430000333
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure BDA0002493947430000334
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure BDA0002493947430000335
表示基于
Figure BDA0002493947430000336
和ym得到的交叉熵损失,
Figure BDA0002493947430000337
表示基于
Figure BDA0002493947430000338
Figure BDA0002493947430000339
得到的二值交叉熵损失,β为预设的权重平衡因子。
在一种可能的实现方式中,训练所述第二网络用到的参数还包括所述第二网络的回归损失和RPN损失,其中,所述第二网络的回归损失和RPN损失为根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值确定的。
在一种可能的实现方式中,所述获取目标网络,包括:
接收模型训练设备发送的目标网络,其中所述模型训练设备用于训练得到所述目标网络。
需要说明的是,各个操作的实现还可以对应参照图10所示的方法实施例的相应描述。
本申请实施例还提供一种芯片系统,所述芯片系统包括至少一个处理器,存储器和接口电路,所述存储器、所述收发器和所述至少一个处理器通过线路互联,所述至少一个存储器中存储有计算机程序;所述计算机程序被所述处理器执行时,图10所示的方法流程得以实现。
本申请实施例还提供一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,当其在处理器上运行时,图10所示的方法流程得以实现。
本申请实施例还提供一种计算机程序产品,当所述计算机程序产品在处理器上运行时,图10所示的方法流程得以实现。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,该流程可以由计算机程序来计算机程序相关的硬件完成,该计算机程序可存储于计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法实施例的流程。而前述的存储介质包括:ROM或随机存储记忆体RAM、磁碟或者光盘等各种可存储计算机程序代码的介质。

Claims (51)

1.一种模型训练方法,其特征在于,包括:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息,其中,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到目标网络。
2.根据权利要求1所述的方法,其特征在于,所述方法还包括:
通过所述第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
通过所述第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述根据所述特征损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络。
3.根据权利要求2所述的方法,其特征在于,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
4.根据权利要求3所述的方法,其特征在于,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
5.根据权利要求2-4任一项所述的方法,其特征在于,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
6.根据权利要求2-5任一项所述的方法,其特征在于,所述分类损失Lcls满足如下关系:
Figure FDA0002493947420000011
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure FDA0002493947420000012
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure FDA0002493947420000013
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure FDA0002493947420000014
所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure FDA0002493947420000015
表示基于
Figure FDA0002493947420000016
和ym得到的交叉熵损失,
Figure FDA0002493947420000017
表示基于
Figure FDA0002493947420000018
Figure FDA0002493947420000019
得到的二值交叉熵损失,β为预设的权重平衡因子。
7.根据权利要求2-6任一项所述的方法,其特征在于,所述方法还包括:
根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;
根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失、所述分类损失、所述回归损失和所述RPN损失训练所述第二网络,得到目标网络。
8.根据权利要求1所述的方法,其特征在于,所述根据所述特征损失训练所述第二网络,得到目标网络,包括:
根据所述特征损失训练所述第二网络;
通过第三网络对经过训练后的所述第二网络进行训练,得到目标网络,其中,所述第三网络的深度大于所述第一网络的深度。
9.根据权利要求1-8任一项所述的方法,其特征在于,所述根据所述特征损失训练所述第二网络,得到目标网络之后,还包括:
向模型使用设备发送所述目标网络,其中所述目标网络用于预测图像中的内容。
10.一种模型训练方法,其特征在于,包括:
基于第一网络训练第二网络得到中间网络;
基于第三网络训练所述中间网络,得到目标网络,其中,所述第一网络、所述第二网络和所述第三网络均为分类网络,且所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度。
11.根据权利要求10所述的方法,其特征在于,所述基于第一网络训练第二网络得到中间网络包括:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到所述中间网络。
12.根据权利要求11所述的方法,其特征在于,所述方法还包括:
通过第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
通过第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述根据所述特征损失训练所述第二网络,得到所述中间网络,包括:
根据所述特征损失和所述分类损失训练所述第二网络,得到所述中间网络。
13.一种图像检测方法,其特征在于,包括:
获取目标网络,其中,所述目标网络为通过第一网络对第二网络进行训练后得到的网络,通过所述第一网络训练所述第二网络用到的参数包括特征损失,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
通过所述目标网络识别图像中的内容。
14.根据权利要求13所述的方法,其特征在于,训练所述第二网络用到的参数还包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
15.根据权利要求14所述的方法,其特征在于,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
16.根据权利要求15所述的方法,其特征在于,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
17.根据权利要求14-16任一项所述的方法,其特征在于,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
18.根据权利要求14-17任一项所述的方法,其特征在于,所述分类损失Lcls满足如下关系:
Figure FDA0002493947420000031
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure FDA0002493947420000032
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure FDA0002493947420000033
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure FDA0002493947420000034
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure FDA0002493947420000035
表示基于
Figure FDA0002493947420000036
和ym得到的交叉熵损失,
Figure FDA0002493947420000037
表示基于
Figure FDA0002493947420000038
Figure FDA0002493947420000039
得到的二值交叉熵损失,β为预设的权重平衡因子。
19.根据权利要求14-18任一项所述的方法,其特征在于:
训练所述第二网络用到的参数还包括所述第二网络的回归损失和RPN损失,其中,所述第二网络的回归损失和RPN损失为根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值确定的。
20.根据权利要求13-19任一项所述的方法,其特征在于,所述目标网络具体为通过所述第一网络对第二网络进行训练,并通过第三网络对训练得到的网络进一步进行训练之后的网络,其中,所述第三网络的深度大于所述第一网络的深度。
21.根据权利要求13-20任一项所述的方法,其特征在于,所述获取目标网络,包括:
接收模型训练设备发送的目标网络,其中所述模型训练设备用于训练得到所述目标网络。
22.一种图像检测方法,其特征在于,包括:
获取目标网络,其中,所述目标网络为通过多个网络迭代对第二网络进行训练得到的网络,所述多个网络均为分类网络,所述多个网络至少包括第一网络和第三网络,所述第三网络用于在所述第一网络对第二网络进行训练得到中间网络后对所述中间网络进行训练,其中,所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度;
通过所述目标网络识别图像中的内容。
23.根据权利要求22所述的方法,其特征在于,所述第一网络对第二网络进行训练时用到的参数包括特征损失,其中,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息。
24.根据权利要求22或23所述的方法,其特征在于,所述第一网络对第二网络进行训练时用到的参数包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
25.一种模型训练装置,其特征在于,包括:
特征提取单元,用于通过第一网络的特征提取层提取目标图像中的第一特征信息;
所述特征提取单元,还用于通过第二网络的特征提取层提取目标图像中的第二特征信息,其中,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
第一优化单元,用于通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
第二优化单元,用于通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征;
第一确定单元,用于通过所述第一局部特征和所述第二局部特征确定特征损失;
权重调整单元,用于根据所述特征损失训练所述第二网络,得到目标网络。
26.根据权利要求25所述的装置,其特征在于,所述装置还包括:
第一生成单元,用于通过所述第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
第二生成单元,用于通过所述第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
第二确定单元,用于根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述权重调整单元具体用于:根据所述特征损失和所述分类损失训练所述第二网络,得到目标网络。
27.根据权利要求26所述的装置,其特征在于,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
28.根据权利要求27所述的装置,其特征在于,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
29.根据权利要求26-28任一项所述的装置,其特征在于,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
30.根据权利要求26-29任一项所述的装置,其特征在于,所述分类损失Lcls满足如下关系:
Figure FDA0002493947420000051
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure FDA0002493947420000052
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure FDA0002493947420000053
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure FDA0002493947420000054
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure FDA0002493947420000055
表示基于
Figure FDA0002493947420000056
和ym得到的交叉熵损失,
Figure FDA0002493947420000057
表示基于
Figure FDA0002493947420000058
Figure FDA0002493947420000059
得到的二值交叉熵损失,β为预设的权重平衡因子。
31.根据权利要求26-30任一项所述的装置,其特征在于,所述装置还包括:
第三确定单元,用于根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值,确定所述第二网络的回归损失和RPN损失;
所述权重调整单元具体用于:根据所述特征损失、所述分类损失、所述回归损失和所述RPN损失训练所述第二网络,得到目标网络。
32.根据权利要求25所述的装置,其特征在于,所述在根据所述特征损失训练所述第二网络,得到目标网络,所述权重调整单元具体用于:
根据所述特征损失训练所述第二网络;
通过第三网络对经过训练后的所述第二网络进行训练,得到目标网络,其中,所述第三网络的深度大于所述第一网络的深度。
33.根据权利要求25-32任一项所述的装置,其特征在于,还包括:
发送单元,用于在所述权重调整单元根据所述特征损失训练所述第二网络,得到目标网络之后,向模型使用设备发送所述目标网络,其中所述目标网络用于预测图像中的内容。
34.一种模型训练装置,其特征在于,包括:
第一训练单元,用于基于第一网络训练第二网络得到中间网络;
第二训练单元,用于基于第三网络训练所述中间网络,得到目标网络,其中,所述第一网络、所述第二网络和所述第三网络均为分类网络,且所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度。
35.根据权利要求34所述的装置,其特征在于,所述基于第一网络训练第二网络得到中间网络包括:
通过第一网络的特征提取层提取目标图像中的第一特征信息;
通过第二网络的特征提取层提取目标图像中的第二特征信息;
通过高斯掩膜提取所述第一特征信息中关于目标物体的特征,得到第一局部特征;
通过高斯掩膜提取所述第二特征信息中关于所述目标物体的特征,得到第二局部特征;
通过所述第一局部特征和所述第二局部特征确定特征损失;
根据所述特征损失训练所述第二网络,得到所述中间网络。
36.根据权利要求35所述的装置,其特征在于,所述装置还包括:
第一生成单元,用于通过第一网络的分类层生成区域提议集合中的目标区域提议的第一分类预测值;
第二生成单元,用于通过第二网络的分类层生成所述区域提议集合中的所述目标区域提议的第二分类预测值;
第二确定单元,用于根据所述第一分类预测值和所述第二分类预测值确定分类损失;
所述根据所述特征损失训练所述第二网络,得到所述中间网络,具体为:
根据所述特征损失和所述分类损失训练所述第二网络,得到所述中间网络。
37.一种图像检测装置,其特征在于,包括:
获取单元,用于获取目标网络,其中,所述目标网络为通过第一网络对第二网络进行训练后得到的网络,通过所述第一网络训练所述第二网络用到的参数包括特征损失,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息,所述第一网络和所述第二网络均为分类网络,且所述第一网络的深度大于所述第二网络的深度;
识别单元,用于通过所述目标网络识别图像中的内容。
38.根据权利要求37所述的装置,其特征在于,训练所述第二网络用到的参数还包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
39.根据权利要求38所述的装置,其特征在于,所述第一网络和所述第二网络通过共享区域提议网络(RPN)的方式使得所述第一网络和所述第二网络均具有所述区域提议集合。
40.根据权利要求39所述的装置,其特征在于,所述RPN为所述第二网络共享给所述第一网络的,或者为所述第一网络共享给所述第二网络的。
41.根据权利要求38-40任一项所述的装置,其特征在于,所述目标区域提议为所述区域提议集合中的全部区域提议,或者为所述区域提议集合中属于所述目标物体的正例区域提议。
42.根据权利要求28-41任一项所述的装置,其特征在于,所述分类损失Lcls满足如下关系:
Figure FDA0002493947420000071
其中,K为所述区域提议集合中区域提议的总数,Np为所述区域提议集合中属于所述目标物体的正例区域提议的总数,
Figure FDA0002493947420000072
为所述第二网络的分类层对所述区域提议集合中第m个区域提议预测的分类预测值,ym为所述区域提议集合中第m个区域提议对应的真值标签,
Figure FDA0002493947420000073
为所述第二网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第二分类预测值,
Figure FDA0002493947420000074
为所述第一网络的分类层对所述区域提议集合中第n个属于所述目标物体的正例区域提议预测的第一分类预测值,
Figure FDA0002493947420000075
表示基于
Figure FDA0002493947420000076
和ym得到的交叉熵损失,
Figure FDA0002493947420000077
表示基于
Figure FDA0002493947420000078
Figure FDA0002493947420000079
得到的二值交叉熵损失,β为预设的权重平衡因子。
43.根据权利要求38-42任一项所述的装置,其特征在于:
训练所述第二网络用到的参数还包括所述第二网络的回归损失和RPN损失,其中,所述第二网络的回归损失和RPN损失为根据所述目标图像中的区域提议的真值标签和所述第二网络对所述目标图像中的区域提议预测的预测值确定的。
44.根据权利要求37-43任一项所述的装置,其特征在于,所述目标网络具体为通过所述第一网络对第二网络进行训练,并通过第三网络对训练得到的网络进一步进行训练之后的网络,其中,所述第三网络的深度大于所述第一网络的深度。
45.根据权利要求37-44任一项所述的装置,其特征在于,所述获取单元具体用于:
接收模型训练设备发送的目标网络,其中所述模型训练设备用于训练得到所述目标网络。
46.一种图像检测装置,其特征在于,包括:
获取单元,用于获取目标网络,其中,所述目标网络为通过多个网络迭代对第二网络进行训练得到的网络,所述多个网络均为分类网络,所述多个网络至少包括第一网络和第三网络,所述第三网络用于在所述第一网络对第二网络进行训练得到中间网络后对所述中间网络进行训练,其中,所述第三网络的深度大于所述第一网络的深度,所述第一网络的深度大于所述第二网络的深度;
识别单元,用于通过所述目标网络识别图像中的内容。
47.根据权利要求46所述的装置,其特征在于,所述第一网络对第二网络进行训练时用到的参数包括特征损失,其中,所述特征损失为根据第一局部特征和第二局部特征确定的,所述第一局部特征为通过高斯掩膜从第一特征信息中提取的关于目标物体的特征,所述第二局部特征为通过高斯掩膜从第二特征信息中提取的关于所述目标物体的特征,所述第一特征信息为通过所述第一网络的特征提取层提取到的目标图像中的特征信息,所述第二特征信息为通过所述第二网络的特征提取层提取到的所述目标图像中的特征信息。
48.根据权利要求46或47所述的装置,其特征在于,所述第一网络对第二网络进行训练时用到的参数包括分类损失,其中,所述分类损失为根据第一分类预测值和第二分类预测值确定的,所述第一分类预测值为通过所述第一网络的分类层生成的区域提议集合中的目标区域提议的分类预测值,所述第二分类预测值为通过所述第二网络的分类层生成的所述区域提议集合中的所述目标区域提议的分类预测值。
49.一种模型训练设备,其特征在于,包括处理器和存储器,所述存储器用于存储计算机程序,所述处理器用于调用所述计算机程序来执行权利要求1-12任一项所述的方法。
50.一种模型使用设备,其特征在于,包括处理器和存储器,所述存储器用于存储计算机程序,所述处理器用于调用所述计算机程序来执行权利要求13-24任一项所述的方法。
51.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质用于存储计算机程序,当所述计算机程序在处理器上运行时,实现权利要求1-24任一项所述的方法。
CN202010412910.6A 2020-05-15 2020-05-15 一种模型训练方法及相关设备 Pending CN113673533A (zh)

Priority Applications (3)

Application Number Priority Date Filing Date Title
CN202010412910.6A CN113673533A (zh) 2020-05-15 2020-05-15 一种模型训练方法及相关设备
PCT/CN2021/088787 WO2021227804A1 (zh) 2020-05-15 2021-04-21 一种模型训练方法及相关设备
US17/986,081 US20230075836A1 (en) 2020-05-15 2022-11-14 Model training method and related device

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010412910.6A CN113673533A (zh) 2020-05-15 2020-05-15 一种模型训练方法及相关设备

Publications (1)

Publication Number Publication Date
CN113673533A true CN113673533A (zh) 2021-11-19

Family

ID=78526362

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010412910.6A Pending CN113673533A (zh) 2020-05-15 2020-05-15 一种模型训练方法及相关设备

Country Status (3)

Country Link
US (1) US20230075836A1 (zh)
CN (1) CN113673533A (zh)
WO (1) WO2021227804A1 (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114581751A (zh) * 2022-03-08 2022-06-03 北京百度网讯科技有限公司 图像识别模型的训练方法和图像识别方法、装置
CN117542085A (zh) * 2024-01-10 2024-02-09 湖南工商大学 基于知识蒸馏的园区场景行人检测方法、装置及设备

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114444558A (zh) * 2020-11-05 2022-05-06 佳能株式会社 用于对象识别的神经网络的训练方法及训练装置
JP2022090491A (ja) * 2020-12-07 2022-06-17 キヤノン株式会社 画像処理装置、画像処理方法、及びプログラム
CN115131357B (zh) * 2022-09-01 2022-11-08 合肥中科类脑智能技术有限公司 一种输电通道挂空悬浮物检测方法

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107247989B (zh) * 2017-06-15 2020-11-24 北京图森智途科技有限公司 一种实时的计算机视觉处理方法及装置
CN108921294A (zh) * 2018-07-11 2018-11-30 浙江大学 一种用于神经网络加速的渐进式块知识蒸馏方法
CN109961442B (zh) * 2019-03-25 2022-11-18 腾讯科技(深圳)有限公司 神经网络模型的训练方法、装置和电子设备
CN110472730A (zh) * 2019-08-07 2019-11-19 交叉信息核心技术研究院(西安)有限公司 一种卷积神经网络的自蒸馏训练方法和可伸缩动态预测方法

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114581751A (zh) * 2022-03-08 2022-06-03 北京百度网讯科技有限公司 图像识别模型的训练方法和图像识别方法、装置
CN114581751B (zh) * 2022-03-08 2024-05-10 北京百度网讯科技有限公司 图像识别模型的训练方法和图像识别方法、装置
CN117542085A (zh) * 2024-01-10 2024-02-09 湖南工商大学 基于知识蒸馏的园区场景行人检测方法、装置及设备
CN117542085B (zh) * 2024-01-10 2024-05-03 湖南工商大学 基于知识蒸馏的园区场景行人检测方法、装置及设备

Also Published As

Publication number Publication date
US20230075836A1 (en) 2023-03-09
WO2021227804A1 (zh) 2021-11-18

Similar Documents

Publication Publication Date Title
CN113673533A (zh) 一种模型训练方法及相关设备
US10176405B1 (en) Vehicle re-identification techniques using neural networks for image analysis, viewpoint-aware pattern recognition, and generation of multi- view vehicle representations
EP4152204A1 (en) Lane line detection method, and related apparatus
CN111652114B (zh) 一种对象检测方法、装置、电子设备及存储介质
CN108805016B (zh) 一种头肩区域检测方法及装置
US9098744B2 (en) Position estimation device, position estimation method, and program
CN112926461B (zh) 神经网络训练、行驶控制方法及装置
Raja et al. SPAS: Smart pothole-avoidance strategy for autonomous vehicles
CN113159198A (zh) 一种目标检测方法、装置、设备及存储介质
CN115690714A (zh) 一种基于区域聚焦的多尺度道路目标检测方法
US20200160059A1 (en) Methods and apparatuses for future trajectory forecast
CN113189989B (zh) 车辆意图预测方法、装置、设备及存储介质
CN106462736A (zh) 用于人脸检测的处理设备和方法
Nejad et al. Vehicle trajectory prediction in top-view image sequences based on deep learning method
CN111178181B (zh) 交通场景分割方法及相关装置
WO2023179593A1 (zh) 数据处理方法及装置
JP7269694B2 (ja) 事象発生推定のための学習データ生成方法・プログラム、学習モデル及び事象発生推定装置
CN116501820A (zh) 车辆轨迹预测方法、装置、设备及存储介质
CN107451719B (zh) 灾区车辆调配方法和灾区车辆调配装置
Pan et al. A Hybrid Deep Learning Algorithm for the License Plate Detection and Recognition in Vehicle-to-Vehicle Communications
CN115408710A (zh) 一种图像脱敏方法和相关装置
CN116152675A (zh) 一种基于深度学习的无人机救援方法及系统
CN111062311B (zh) 一种基于深度级可分离卷积网络的行人手势识别与交互方法
CN113627332A (zh) 一种基于梯度控制联邦学习的分心驾驶行为识别方法
CN114863685B (zh) 一种基于风险接受程度的交通参与者轨迹预测方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination