CN109961442B - 神经网络模型的训练方法、装置和电子设备 - Google Patents
神经网络模型的训练方法、装置和电子设备 Download PDFInfo
- Publication number
- CN109961442B CN109961442B CN201910228494.1A CN201910228494A CN109961442B CN 109961442 B CN109961442 B CN 109961442B CN 201910228494 A CN201910228494 A CN 201910228494A CN 109961442 B CN109961442 B CN 109961442B
- Authority
- CN
- China
- Prior art keywords
- pixel
- neural network
- network model
- loss function
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2413—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on distances to training or reference patterns
- G06F18/24147—Distances to closest patterns, e.g. nearest neighbour classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T7/00—Image analysis
- G06T7/10—Segmentation; Edge detection
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/10—Image acquisition modality
- G06T2207/10004—Still image; Photographic image
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20081—Training; Learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20084—Artificial neural networks [ANN]
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了一种神经网络模型的训练方法、训练装置、电子设备和计算机可读存储介质。所述神经网络模型的训练方法,包括:通过第一神经网络模型提取训练图像的第一网络特征图;通过待训练的第二神经网络模型提取所述训练图像的第二网络特征图;基于所述第一网络特征图和所述第二网络特征图,确定像素级的分类损失函数;基于所述像素级的分类损失函数,训练所述第二神经网络模型。通过将大型神经网络模型学习到的知识传递给小型神经网络模型,指导小型神经网络模型训练,使得训练后的小型神经网络模型在参数量、运动速度不变的情况下,实现预测精度的显著提升。
Description
技术领域
本公开涉及图像处理领域,更具体地,本公开涉及一种用于图像语义分割的神经网络模型的训练方法、训练装置、电子设备和计算机可读存储介质。
背景技术
神经网络是一种大规模、多参数优化的工具。依靠大量的训练数据,神经网络能够学习出数据中难以总结的隐藏特征,从而完成多项复杂的任务,如图像语义分割、物体检测、动作追踪、自然语言翻译等。神经网络已被人工智能界广泛应用。
目前,在诸如图像语义分割等图像处理应用中,使用的神经网络模型(例如,ResNet101)通常具有数百层和数千通道,因而伴随着巨大的计算复杂度(例如,每秒数十亿次浮点运算(FLOPS)甚至更多),使得此类神经网络模型往往依赖于高性能的服务器集群以满足处理精度和运行速度的要求。随着诸如智能手机、无人车等移动终端对于基于神经网络模型的图像处理应用的需求的不断增加,需要在移动终端配置匹配移动终端处理能力的相对小型的神经网络模型(例如,ResNet18)以实现与服务器端接近的处理精度。
发明内容
鉴于上述问题而提出了本公开。本公开提供了一种神经网络模型的训练方法、训练装置、电子设备和计算机可读存储介质。
根据本公开的一个方面,提供了一种用于图像语义分割的神经网络模型的训练方法,包括:通过第一神经网络模型提取训练图像的第一网络特征图;通过待训练的第二神经网络模型提取所述训练图像的第二网络特征图;基于所述第一网络特征图和所述第二网络特征图,确定像素级的分类损失函数;基于所述像素级的分类损失函数,训练第二神经网络模型。
此外,根据本公开一个方面的训练方法,其中,所述像素级的分类损失函数包括:所述第一神经网络模型和所述第二神经网络模型之间的相反像素级分类损失函数,所述第一神经网络模型和所述第二神经网络模型之间的像素级知识逼近损失函数,以及所述第二神经网络模型自身的像素级分类损失函数。
此外,根据本公开一个方面的训练方法,其中,所述确定像素级的分类损失函数包括:基于所述第一网络特征图生成第一网络注意力图,并且基于所述第二网络特征图生成第二网络注意力图;所述第一网络注意力图与所述第二网络注意力图相减,生成所述第一神经网络模型和所述第二神经网络模型之间的掩膜特征图;所述第二网络特征图与所述掩膜特征图相乘,生成掩膜后的第二网络特征图;所述第一网络特征图与所述掩膜后的第二网络特征图相加作为训练用特征图,以所述第一神经网络模型的逐像素的分类损失的相反数作为所述相反像素级分类损失函数。
此外,根据本公开一个方面的训练方法,其中,所述确定像素级的分类损失函数包括:利用所述第一网络特征图确定所述第一神经网络模型的逐像素的第一分类结果,利用所述第二网络特征图确定所述第二神经网络模型的逐像素的第二分类结果;以所述逐像素的第一分类结果和所述逐像素的第二分类结果的交叉熵作为所述像素级知识逼近损失函数,以所述逐像素的第二分类结果和所述训练图像的交叉熵作为所述自身的像素级分类损失函数;以所述自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和作为所述像素级的分类损失函数。
此外,根据本公开一个方面的训练方法,其中,所述基于所述像素级的分类损失函数,训练第二神经网络模型包括以下的任一:基于所述相反像素级分类损失函数,训练第二神经网络模型;基于所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练第二神经网络模型;以及基于所述相反像素级分类损失函数、以及所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练第二神经网络模型。
根据本公开的另一个方面,提供了一种用于执行分类任务的神经网络模型的训练方法,包括:通过第一神经网络模型提取训练数据的第一网络特征图;通过待训练的第二神经网络模型提取所述训练数据的第二网络特征图;基于所述第一网络特征图和所述第二网络特征图,确定分类损失函数;基于所述分类损失函数,训练所述第二神经网络模型。
根据本公开的另一个方面,提供了一种用于图像语义分割的神经网络模型的训练装置,包括:特征图提取单元,用于通过第一神经网络模型提取训练图像的第一网络特征图,以及通过待训练的第二神经网络模型提取所述训练图像的第二网络特征图;损失函数确定单元,用于基于所述第一网络特征图和所述第二网络特征图,确定像素级的分类损失函数;训练单元,用于基于所述像素级的分类损失函数,训练第二神经网络模型。
此外,根据本公开另一个方面的训练装置,其中,所述像素级的分类损失函数包括:所述第一神经网络模型和所述第二神经网络模型之间的相反像素级分类损失函数,所述第一神经网络模型和所述第二神经网络模型之间的像素级知识逼近损失函数,以及所述第二神经网络模型自身的像素级分类损失函数。
此外,根据本公开另一个方面的训练装置,其中,所述损失函数确定单元用于:基于所述第一网络特征图生成第一网络注意力图,并且基于所述第二网络特征图生成第二网络注意力图;所述第一网络注意力图与所述第二网络注意力图相减,生成所述第一神经网络模型和所述第二神经网络模型之间的掩膜特征图;所述第二网络特征图与所述掩膜特征图相乘,生成掩膜后的第二网络特征图;所述第一网络特征图与所述掩膜后的第二网络特征图相加作为训练用特征图,以所述第一神经网络模型的逐像素的分类损失的相反数作为所述相反像素级分类损失函数。
此外,根据本公开另一个方面的训练装置,其中,所述损失函数确定单元用于:利用所述第一网络特征图确定所述第一神经网络模型的逐像素的第一分类结果,利用所述第二网络特征图确定所述第二神经网络模型的逐像素的第二分类结果;以所述逐像素的第一分类结果和所述逐像素的第二分类结果的交叉熵作为所述像素级知识逼近损失函数,以所述逐像素的第二分类结果和所述训练图像的交叉熵作为所述自身的像素级分类损失函数;以所述自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和作为所述像素级的分类损失函数。
此外,根据本公开另一个方面的训练装置,其中,所述训练单元用于执行以下的任一训练:基于所述相反像素级分类损失函数,训练第二神经网络模型;基于所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练第二神经网络模型;以及基于所述相反像素级分类损失函数、以及所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练第二神经网络模型。
根据本公开的又一个方面,提供了一种电子设备,包括:处理器;以及存储器,用于存储计算机程序指令;其中,当所述计算机程序指令由所述处理器加载并运行时,所述处理器执行如上所述的训练方法。
根据本公开的再一个方面,提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序指令,其中,所述计算机程序指令被处理器加载并运行时,所述处理器执行如上所述的训练方法。
如以下将详细描述的,根据本公开实施例的用于图像语义分割的神经网络模型的训练方法、训练装置、电子设备和计算机可读存储介质,通过在云端服务器训练好参数量大、速度慢、精度高的大型神经网络模型,将大型神经网络模型学习到的知识传递给参数量小、速度快、精度低的小型神经网络模型,指导小型神经网络模型训练,依靠与大型神经网络模型相同的训练数据,使得训练后的小型神经网络模型在参数量、运动速度不变的情况下,实现预测精度的显著提升。此外,通过指导小型神经网络模型着重学习大型神经网络模型对于训练图像中难像素区域的知识,使得小型神经网络模型有针对性地学习像素级的部分难像素区域,从而提升小型神经网络模型的预测精度。
要理解的是,前面的一般描述和下面的详细描述两者都是示例性的,并且意图在于提供要求保护的技术的进一步说明。
附图说明
通过结合附图对本公开实施例进行更详细的描述,本公开的上述以及其它目的、特征和优势将变得更加明显。附图用来提供对本公开实施例的进一步理解,并且构成说明书的一部分,与本公开实施例一起用于解释本公开,并不构成对本公开的限制。在附图中,相同的参考标号通常代表相同部件或步骤。
图1是概述根据本公开实施例的神经网络模型的训练方法的应用场景的示意图;
图2是图示根据本公开实施例的神经网络模型的训练方法的流程图;
图3是进一步图示根据本公开实施例的神经网络模型的训练方法的流程图;
图4是进一步图示根据本公开实施例的神经网络模型的训练方法的示意图;
图5是进一步图示根据本公开实施例的神经网络模型的训练方法的流程图;
图6是进一步图示根据本公开实施例的神经网络模型的训练方法的示意图;
图7是进一步图示根据本公开实施例的神经网络模型的训练方法的示意图;
图8是图示根据本公开实施例的神经网络模型的训练装置的功能框图;
图9是图示根据本公开实施例的电子设备的硬件框图;以及
图10是图示根据本公开的实施例的计算机可读存储介质的示意图。
具体实施方式
为了使得本公开的目的、技术方案和优点更为明显,下面将参照附图详细描述根据本公开的示例实施例。显然,所描述的实施例仅仅是本公开的一部分实施例,而不是本公开的全部实施例,应理解,本公开不受这里描述的示例实施例的限制。
首先,参照图1示意性地描述根据本公开实施例的神经网络模型的训练方法的应用场景。
如图1的(A)所示,对于第一设备10,利用训练数据30训练以获得训练好的第一神经网络20。在本公开的实施例中,第一设备10是云端服务器设备,训练好的第一神经网络模型20是参数量大、速度慢、精度高的大型神经网络模型。训练好的第一神经网络模型20用于图像语义分割、物体检测、物体跟踪等任务。例如,第一神经网络模型20是ResNet101模型,其参数量大(模型大小为170 M),运行速度慢(NVIDIA GTX1080显卡上,256×256大小图片需要156ms进行预测),精度高(Cityscapes数据集分割精度为76%)。
如图1的(B)所示,使用相同的训练数据30,将训练好的第一神经网络模型20学习到的知识传递给第二神经网络模型40,指导第二神经网络模型40训练。在本公开的实施例中,第二神经网络模型40是参数量小、速度快、精度低的小型神经网络模型。例如,第二神经网络模型40是ResNet18模型,其参数量大(模型大小为45 M),运行速度快(NVIDIA GTX1080显卡上,256×256大小图片需要31ms进行预测),精度较低(普通训练后,Cityscapes数据集分割精度为68%)。如下将详细描述的,在根据本公开实施例的神经网络模型的训练方法中,通过基于第一神经网络模型20得到的第一网络特征图和第二神经网络模型40得到的第二网络特征图,确定像素级的分类损失函数50。一方面,可以利用增大第一神经网络模型20的像素级分类损失,使第二神经网络模型40特征与第一神经网络模型20更加同质化;另一方面,利用知识差距感知作为分割置信度模仿学习的权重,侧重让第二神经网络模型40对于知识差距大的像素从第一神经网络模型20学习更多知识,得到训练后的第二神经网络模型40。例如,在第二神经网络模型40是ResNet18模型,参数量和运行速度完全一样的情况下,Cityscapes数据集分割精度可提升到73%,远高于未经过本公开实施例的神经网络模型的训练方法所得到的ResNet18模型,接近作为第一神经网络模型20的ResNet101模型的分割精度。
如图1的(C)所示,将训练好的第二神经网络模型40部署到第二设备60。在本公开的实施例中,第二设备60是诸如智能手机、无人车等的移动终端设备。将预测精度提高后的第二神经网络模型40部署到第二设备60进行实时预测识别,既能满足移动端设备对速度和内存空间的苛刻要求,又能提供较高的预测精度。
以上,参照图1概述了根据本公开实施例的神经网络模型的训练方法的应用场景。以下,将参照图2到图7,详细描述根据本公开实施例的神经网络模型的训练方法。
图2是图示根据本公开实施例的神经网络模型的训练方法的流程图。容易理解的是,图2所示的根据本公开实施例的神经网络模型的训练方法对应于上述图1的(B)中示出的过程。
如图2所示,根据本公开实施例的神经网络模型的训练方法包括以下步骤。
在步骤S201中,通过第一神经网络模型提取训练图像的第一网络特征图。
在步骤S202中,通过待训练的第二神经网络模型提取训练图像的第二网络特征图。需要理解的是,在本公开各附图中,并不将本公开限制到各步骤编号所排列的执行顺序,而是各步骤可以以不同的顺序执行、并行执行、分解和/或重新组合来执行。这些以不同的顺序执行、并行执行、分解和/或重新组合来执行各步骤应视为本公开的等效方案。
在步骤S203中,基于第一网络特征图和第二网络特征图,确定像素级的分类损失函数。如下将进一步详细描述的,所述像素级的分类损失函数包括:所述第一神经网络模型和所述第二神经网络模型之间的相反像素级分类损失函数,所述第一神经网络模型和所述第二神经网络模型之间的像素级知识逼近损失函数,以及所述第二神经网络模型自身的像素级分类损失函数。
在步骤S204中,基于像素级的分类损失函数,训练第二神经网络模型。在本公开的实施例中,当像素级的分类损失函数在训练过程中收敛时,则可停止训练过程。
更一般地,根据本公开实施例的神经网络模型的训练方法是用于执行分类任务的神经网络模型的训练方法,其包括:通过第一神经网络模型提取训练数据的第一网络特征图;通过待训练的第二神经网络模型提取所述训练数据的第二网络特征图;基于所述第一网络特征图和所述第二网络特征图,确定分类损失函数;基于所述分类损失函数,训练所述第二神经网络模型。
如参照图1所示,在本公开的实施例中,第一神经网络模型是参数量大、速度慢、精度高的大型神经网络模型,第二神经网络模型是参数量小、速度快、精度低的小型神经网络模型。
以下,进一步参照图3和图4、图5和图6以及图7描述根据本公开实施例的神经网络模型的训练方法的进一步具体实施例。
图3和图4用于描述利用增大第一神经网络模型的像素级分类损失,使第二神经网络模型特征与第一神经网络模型更加同质化的具体实施例。
图3的步骤S301和S302分别与图2所示的相同。即,在步骤S301中,通过第一神经网络模型402提取训练图像401的第一网络特征图404。在步骤S302中,通过待训练的第二神经网络模型403提取训练图像401的第二网络特征图405。
在步骤S303中,基于第一网络特征图404生成第一网络注意力图406,并且基于第二网络特征图405生成第二网络注意力图407。具体地,注意力图的计算表达式为:
其中,|Fi|为第一网络注意力图406和第二网络注意力图407中的第i个特征图。
在步骤S304中,第一网络注意力图406与第二网络注意力图407相减,生成第一神经网络模型402和第二神经网络模型403之间的掩膜特征图408。具体地,掩膜特征图408的计算表达式为:
M=|As-At|, 表达式(2)
其中,As为第二网络注意力图407,At为第一网络注意力图406。
在步骤S305中,第二网络特征图405与掩膜特征图408相乘,生成掩膜后的第二网络特征图409。
在步骤S306中,第一网络特征图404与掩膜后的第二网络特征图409相加作为训练用特征图,训练输出分割结果410。在该训练步骤中,以第一神经网络模型402的逐像素的分类损失的相反数作为相反像素级分类损失函数411。具体地,相反像素级分类损失函数411的计算表达式为:
Lteacher=-H(softmax(Zteacher,y)), 表达式(3)
其中,H函数为交叉熵函数,Zteacher为第一神经网络模型402在利用softmax函数归一化前的预测分类置信度,y为训练数据中标注的分类真值。在本公开的实施例中,预测分类置信度和标注的分类真值为像素级,以获得该逐像素的分类损失。
在步骤S307中,基于像素级的分类损失函数411,训练第二神经网络模型403。在图3和图4所示的实施例中,像素级的分类损失函数411由表达式(3)表示。
具体地,当增大第一神经网络模型402像素级分类损失时,由于训练好的第一神经网络模型402全部参数均固定,损失梯度来自于第一神经网络模型402的全部像素分类交叉熵的相反数。第二神经网络模型403在此训练过程中,损失梯度从第一神经网络模型402的全部像素分类交叉熵相反数传播至掩膜后的第二网络特征图409,再通过掩膜传播至第二网络注意力图407,而后到第二网络特征图405(如图4中虚线所示)。也就是说,利用增大第二神经网络模型403像素级分类损失,能使掩膜后的第二网络特征图409与第一网络特征图404更加接近,该过程的损失梯度由于通过掩膜传导至第二网络特征图405,可使第二网络特征图405与第一网络特征图404更加接近。
图5和图6用于描述利用知识差距感知作为分割置信度模仿学习的权重,侧重让第二神经网络模型对于知识差距大的像素从第一神经网络模型学习更多知识的进一步具体实施例。
图5的步骤S501和S502分别与图2所示的相同。即,在步骤S501中,通过第一神经网络模型402提取训练图像401的第一网络特征图404。在步骤S502中,通过待训练的第二神经网络模型403提取训练图像401的第二网络特征图405。
在步骤S503中,利用第一网络特征图404确定第一神经网络模型402的逐像素的第一分类结果601,利用第二网络特征图405确定第二神经网络模型403的逐像素的第二分类结果602。
在步骤S504中,以逐像素的第一分类结果601和逐像素的第二分类结果602的交叉熵作为像素级知识逼近损失函数603,以逐像素的第二分类结果602和训练图像401的交叉熵作为自身的像素级分类损失函数604。
具体地,假定Zteacher和Zstudent为第一神经网络模型402和第二神经网络模型403在softmax函数归一化前的预测分类置信度,T为一个缩放第一神经网络模型402分类置信度的参数,则Pstudent=softmax(Zstudent)为第二神经网络模型403的第二分类结果602,Pteacher=softmax(Zteacher/T)为第一神经网络模型402用于指导第二神经网络模型403模仿学习的第二分类结果602。在训练第二神经网络模型403时,损失函数的计算表达式为:
其中,H为交叉熵函数,n为像素编号,N为总像素个数,μ为调节第二神经网络模型403从第一神经网络模型402处学习知识和从训练数据的标注真值中学习知识权重的系数,wn为大小网络知识差距感知的权重,其表达式为
如此,容易理解的是,μwnHsoft(Pstudent,n,Pteacher,n)表示以逐像素的第一分类结果601和逐像素的第二分类结果602的交叉熵作为像素级知识逼近损失函数603。Hhard(Pstudent,n,yn)表示以逐像素的第二分类结果602和训练图像401的交叉熵作为自身的像素级分类损失函数604。
在步骤S506中,基于像素级的分类损失函数411,训练第二神经网络模型403。在图5和图6所示的实施例中,像素级的分类损失函数411由表达式(4)表示。
以上图3和图4以及图5和图6分别示出了以第一神经网络模型402两种不同的方式指导训练第二神经网络模型403的示例。本公开的实施例不限于此,如图7所示,在根据本公开实施例的神经网络模型的训练方法中,可以同时可以利用增大第一神经网络模型402的像素级分类损失,使第二神经网络模型403特征与第一神经网络模型402更加同质化;以及利用知识差距感知作为分割置信度模仿学习的权重,侧重让第二神经网络模型403对于知识差距大的像素从第一神经网络模型402学习更多知识这两种优化的训练方式。
图8是图示根据本公开实施例的神经网络模型的训练装置的功能框图。如图8所示,根据本公开实施例的训练装置80包括特征图提取单元801、损失函数确定单元802、以及训练单元803。上述各模块可以分别执行如上参照图2到图7描述的根据本公开的实施例的神经网络模型的训练方法的各个步骤。本领域的技术人员理解:这些单元模块可以单独由硬件、单独由软件或者由其组合以各种方式实现,并且本公开不限于它们的任何一个。
特征图提取单元801用于通过第一神经网络模型提取训练图像的第一网络特征图,以及通过待训练的第二神经网络模型提取所述训练图像的第二网络特征图。
损失函数确定单元802用于基于所述第一网络特征图和所述第二网络特征图,确定像素级的分类损失函数。
训练单元803用于基于所述像素级的分类损失函数,训练第二神经网络模型。
更具体地,所述损失函数确定单元802用于:基于所述第一网络特征图生成第一网络注意力图,并且基于所述第二网络特征图生成第二网络注意力图;所述第一网络注意力图与所述第二网络注意力图相减,生成所述第一神经网络模型和所述第二神经网络模型之间的掩膜特征图;所述第二网络特征图与所述掩膜特征图相乘,生成掩膜后的第二网络特征图;所述第一网络特征图与所述掩膜后的第二网络特征图相加作为训练用特征图,以所述第一神经网络模型的逐像素的分类损失的相反数作为所述相反像素级分类损失函数。
此外,所述损失函数确定单元802用于:利用所述第一网络特征图确定所述第一神经网络模型的逐像素的第一分类结果,利用所述第二网络特征图确定所述第二神经网络模型的逐像素的第二分类结果;以所述逐像素的第一分类结果和所述逐像素的第二分类结果的交叉熵作为所述像素级知识逼近损失函数,以所述逐像素的第二分类结果和所述训练图像的交叉熵作为所述自身的像素级分类损失函数;以所述自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和作为所述像素级的分类损失函数。
所述训练单元803用于执行以下的任一训练:基于所述相反像素级分类损失函数,训练第二神经网络模型;基于所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练第二神经网络模型;以及基于所述相反像素级分类损失函数、以及所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练第二神经网络模型。
图9是图示根据本公开实施例的电子设备900的硬件框图。根据本公开实施例的电子设备至少包括处理器;以及存储器,用于存储计算机程序指令。当计算机程序指令由处理器加载并运行时,所述处理器执行如上所述的神经网络模型的训练方法。
图9所示的电子设备900具体地包括:中央处理单元(CPU)901、图形处理单元(GPU)902和主存储器903。这些单元通过总线904互相连接。中央处理单元(CPU)901和/或图形处理单元(GPU)902可以用作上述处理器,主存储器903可以用作上述存储计算机程序指令的存储器。此外,电子设备900还可以包括通信单元905、存储单元906、输出单元907、输入单元908和外部设备909,这些单元也连接到总线904。
图10是图示根据本公开的实施例的计算机可读存储介质的示意图。如图10所示,根据本公开实施例的计算机可读存储介质1000其上存储有计算机程序指令1001。当所述计算机程序指令1001由处理器运行时,执行参照以上附图描述的根据本公开实施例的神经网络模型的训练方法。所述计算机可读存储介质包括但不限于例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存、光盘、磁盘等。
以上,参照附图描述了根据本公开实施例的用于图像语义分割的神经网络模型的训练方法、训练装置、电子设备和计算机可读存储介质,通过在云端服务器训练好参数量大、速度慢、精度高的大型神经网络模型,将大型神经网络模型学习到的知识传递给参数量小、速度快、精度低的小型神经网络模型,指导小型神经网络模型训练,依靠与大型神经网络模型相同的训练数据,使得训练后的小型神经网络模型在参数量、运动速度不变的情况下,实现预测精度的显著提升。此外,通过指导小型神经网络模型着重学习大型神经网络模型对于训练图像中难像素区域的知识,使得小型神经网络模型有针对性地学习像素级的部分难像素区域,从而提升小型神经网络模型的预测精度。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
以上结合具体实施例描述了本公开的基本原理,但是,需要指出的是,在本公开中提及的优点、优势、效果等仅是示例而非限制,不能认为这些优点、优势、效果等是本公开的各个实施例必须具备的。另外,上述公开的具体细节仅是为了示例的作用和便于理解的作用,而非限制,上述细节并不限制本公开为必须采用上述具体的细节来实现。
本公开中涉及的器件、装置、设备、系统的方框图仅作为例示性的例子并且不意图要求或暗示必须按照方框图示出的方式进行连接、布置、配置。如本领域技术人员将认识到的,可以按任意方式连接、布置、配置这些器件、装置、设备、系统。诸如“包括”、“包含”、“具有”等等的词语是开放性词汇,指“包括但不限于”,且可与其互换使用。这里所使用的词汇“或”和“和”指词汇“和/或”,且可与其互换使用,除非上下文明确指示不是如此。这里所使用的词汇“诸如”指词组“诸如但不限于”,且可与其互换使用。
另外,如在此使用的,在以“至少一个”开始的项的列举中使用的“或”指示分离的列举,以便例如“A、B或C的至少一个”的列举意味着A或B或C,或AB或AC或BC,或ABC(即A和B和C)。此外,措辞“示例的”不意味着描述的例子是优选的或者比其他例子更好。
还需要指出的是,在本公开的系统和方法中,各部件或各步骤是可以分解和/或重新组合的。这些分解和/或重新组合应视为本公开的等效方案。
可以不脱离由所附权利要求定义的教导的技术而进行对在此所述的技术的各种改变、替换和更改。此外,本公开的权利要求的范围不限于以上所述的处理、机器、制造、事件的组成、手段、方法和动作的具体方面。可以利用与在此所述的相应方面进行基本相同的功能或者实现基本相同的结果的当前存在的或者稍后要开发的处理、机器、制造、事件的组成、手段、方法或动作。因而,所附权利要求包括在其范围内的这样的处理、机器、制造、事件的组成、手段、方法或动作。
提供所公开的方面的以上描述以使本领域的任何技术人员能够做出或者使用本公开。对这些方面的各种修改对于本领域技术人员而言是非常显而易见的,并且在此定义的一般原理可以应用于其他方面而不脱离本公开的范围。因此,本公开不意图被限制到在此示出的方面,而是按照与在此公开的原理和新颖的特征一致的最宽范围。
为了例示和描述的目的已经给出了以上描述。此外,此描述不意图将本公开的实施例限制到在此公开的形式。尽管以上已经讨论了多个示例方面和实施例,但是本领域技术人员将认识到其某些变型、修改、改变、添加和子组合。
Claims (10)
1.一种用于图像语义分割的神经网络模型的训练方法,包括:
通过第一神经网络模型提取训练图像的第一网络特征图;
通过待训练的第二神经网络模型提取所述训练图像的第二网络特征图;
基于所述第一网络特征图和所述第二网络特征图,确定像素级的分类损失函数;
基于所述像素级的分类损失函数,训练所述第二神经网络模型,
其中,所述像素级的分类损失函数包括相反像素级分类损失函数,
其中,所述确定像素级的分类损失函数包括:
基于所述第一网络特征图生成第一网络注意力图,并且基于所述第二网络特征图生成第二网络注意力图;
所述第一网络注意力图与所述第二网络注意力图相减,生成所述第一神经网络模型和所述第二神经网络模型之间的掩膜特征图;
所述第二网络特征图与所述掩膜特征图相乘,生成掩膜后的第二网络特征图;
所述第一网络特征图与所述掩膜后的第二网络特征图相加作为训练用特征图,以所述第一神经网络模型的逐像素的分类损失的相反数作为所述相反像素级分类损失函数。
2.如权利要求1所述的训练方法,其中,所述像素级的分类损失函数还包括:
所述第一神经网络模型和所述第二神经网络模型之间的像素级知识逼近损失函数,以及所述第二神经网络模型自身的像素级分类损失函数。
3.如权利要求2所述的训练方法,其中,所述确定像素级的分类损失函数还包括:
利用所述第一网络特征图确定所述第一神经网络模型的逐像素的第一分类结果,利用所述第二网络特征图确定所述第二神经网络模型的逐像素的第二分类结果;
以所述逐像素的第一分类结果和所述逐像素的第二分类结果的交叉熵作为所述像素级知识逼近损失函数,以所述逐像素的第二分类结果和所述训练图像的交叉熵作为所述自身的像素级分类损失函数;
以所述自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和作为所述像素级的分类损失函数。
4.如权利要求2或3所述的训练方法,其中,所述基于所述像素级的分类损失函数,训练所述第二神经网络模型包括以下的任一:
基于所述相反像素级分类损失函数,训练所述第二神经网络模型;
基于所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练所述第二神经网络模型;以及
基于所述相反像素级分类损失函数、以及所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练所述第二神经网络模型。
5.一种用于图像语义分割的神经网络模型的训练装置,包括:
特征图提取单元,用于通过第一神经网络模型提取训练图像的第一网络特征图,以及通过待训练的第二神经网络模型提取所述训练图像的第二网络特征图;
损失函数确定单元,用于基于所述第一网络特征图和所述第二网络特征图,确定像素级的分类损失函数;
训练单元,用于基于所述像素级的分类损失函数,训练所述第二神经网络模型,
其中,所述像素级的分类损失函数包括相反像素级分类损失函数,
其中,所述损失函数确定单元用于:
基于所述第一网络特征图生成第一网络注意力图,并且基于所述第二网络特征图生成第二网络注意力图;
所述第一网络注意力图与所述第二网络注意力图相减,生成所述第一神经网络模型和所述第二神经网络模型之间的掩膜特征图;
所述第二网络特征图与所述掩膜特征图相乘,生成掩膜后的第二网络特征图;
所述第一网络特征图与所述掩膜后的第二网络特征图相加作为训练用特征图,以所述第一神经网络模型的逐像素的分类损失的相反数作为所述相反像素级分类损失函数。
6.如权利要求5所述的训练装置,其中,所述像素级的分类损失函数还包括:
所述第一神经网络模型和所述第二神经网络模型之间的像素级知识逼近损失函数,以及所述第二神经网络模型自身的像素级分类损失函数。
7.如权利要求6所述的训练装置,其中,所述损失函数确定单元还用于:
利用所述第一网络特征图确定所述第一神经网络模型的逐像素的第一分类结果,利用所述第二网络特征图确定所述第二神经网络模型的逐像素的第二分类结果;
以所述逐像素的第一分类结果和所述逐像素的第二分类结果的交叉熵作为所述像素级知识逼近损失函数,以所述逐像素的第二分类结果和所述训练图像的交叉熵作为所述自身的像素级分类损失函数;
以所述自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和作为所述像素级的分类损失函数。
8.如权利要求5到7的任一项所述的训练装置,其中,所述训练单元用于执行以下的任一训练:
基于所述相反像素级分类损失函数,训练所述第二神经网络模型;
基于所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练所述第二神经网络模型;以及
基于所述相反像素级分类损失函数、以及所述第二神经网络模型自身的像素级分类损失函数与加权的所述像素级知识逼近损失函数的和,训练所述第二神经网络模型。
9.一种电子设备,包括:
处理器;以及
存储器,用于存储计算机程序指令;
其中,当所述计算机程序指令由所述处理器加载并运行时,所述处理器执行如权利要求1到4的任一项所述的训练方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序指令,其中,所述计算机程序指令被处理器加载并运行时,所述处理器执行如权利要求1到4的任一项所述的训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910228494.1A CN109961442B (zh) | 2019-03-25 | 2019-03-25 | 神经网络模型的训练方法、装置和电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910228494.1A CN109961442B (zh) | 2019-03-25 | 2019-03-25 | 神经网络模型的训练方法、装置和电子设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN109961442A CN109961442A (zh) | 2019-07-02 |
CN109961442B true CN109961442B (zh) | 2022-11-18 |
Family
ID=67024999
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910228494.1A Active CN109961442B (zh) | 2019-03-25 | 2019-03-25 | 神经网络模型的训练方法、装置和电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN109961442B (zh) |
Families Citing this family (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110378278B (zh) * | 2019-07-16 | 2021-11-02 | 北京地平线机器人技术研发有限公司 | 神经网络的训练方法、对象搜索方法、装置以及电子设备 |
CN110599492B (zh) * | 2019-09-19 | 2024-02-06 | 腾讯科技(深圳)有限公司 | 图像分割模型的训练方法、装置、电子设备及存储介质 |
CN111507210B (zh) * | 2020-03-31 | 2023-11-21 | 华为技术有限公司 | 交通信号灯的识别方法、系统、计算设备和智能车 |
CN113673533A (zh) * | 2020-05-15 | 2021-11-19 | 华为技术有限公司 | 一种模型训练方法及相关设备 |
CN111737429B (zh) * | 2020-06-16 | 2023-11-03 | 平安科技(深圳)有限公司 | 训练方法、ai面试方法及相关设备 |
CN113139956B (zh) * | 2021-05-12 | 2023-04-14 | 深圳大学 | 基于语言知识导向的切面识别模型的生成方法及识别方法 |
CN113139520B (zh) * | 2021-05-14 | 2022-07-29 | 江苏中天互联科技有限公司 | 用于工业互联网的设备膜片性能监测方法 |
CN113361602B (zh) * | 2021-06-04 | 2023-07-14 | 展讯通信(上海)有限公司 | 神经网络模型的训练方法、装置和电子设备 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107247989A (zh) * | 2017-06-15 | 2017-10-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及装置 |
CN107977707A (zh) * | 2017-11-23 | 2018-05-01 | 厦门美图之家科技有限公司 | 一种对抗蒸馏神经网络模型的方法及计算设备 |
CN108647684A (zh) * | 2018-05-02 | 2018-10-12 | 深圳市唯特视科技有限公司 | 一种基于引导注意力推理网络的弱监督语义分割方法 |
CN108805803A (zh) * | 2018-06-13 | 2018-11-13 | 衡阳师范学院 | 一种基于语义分割与深度卷积神经网络的肖像风格迁移方法 |
CN109034198A (zh) * | 2018-06-25 | 2018-12-18 | 中国科学院计算技术研究所 | 基于特征图恢复的场景分割方法和系统 |
CN109087303A (zh) * | 2018-08-15 | 2018-12-25 | 中山大学 | 基于迁移学习提升语义分割模型效果的框架 |
CN109377496A (zh) * | 2017-10-30 | 2019-02-22 | 北京昆仑医云科技有限公司 | 用于分割医学图像的系统和方法及介质 |
-
2019
- 2019-03-25 CN CN201910228494.1A patent/CN109961442B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107247989A (zh) * | 2017-06-15 | 2017-10-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及装置 |
CN109377496A (zh) * | 2017-10-30 | 2019-02-22 | 北京昆仑医云科技有限公司 | 用于分割医学图像的系统和方法及介质 |
CN107977707A (zh) * | 2017-11-23 | 2018-05-01 | 厦门美图之家科技有限公司 | 一种对抗蒸馏神经网络模型的方法及计算设备 |
CN108647684A (zh) * | 2018-05-02 | 2018-10-12 | 深圳市唯特视科技有限公司 | 一种基于引导注意力推理网络的弱监督语义分割方法 |
CN108805803A (zh) * | 2018-06-13 | 2018-11-13 | 衡阳师范学院 | 一种基于语义分割与深度卷积神经网络的肖像风格迁移方法 |
CN109034198A (zh) * | 2018-06-25 | 2018-12-18 | 中国科学院计算技术研究所 | 基于特征图恢复的场景分割方法和系统 |
CN109087303A (zh) * | 2018-08-15 | 2018-12-25 | 中山大学 | 基于迁移学习提升语义分割模型效果的框架 |
Non-Patent Citations (5)
Title |
---|
Knowledge Adaptation for Efficient Semantic Segmentation;Tong He 等;《arXiv》;20190312;第1-12页 * |
PAYING MORE ATTENTION TO ATTENTION:IMPROVING THE PERFORMANCE OF CONVOLUTIONAL NEURAL NETWORKS VIA ATTENTION TRANSFER;Sergey Zagoruyko 等;《arXiv》;20170212;第1-13页 * |
Structured Knowledge Distillation for Semantic Segmentation;Yifan Liu 等;《arXiv》;20190311;第1-10页 * |
基于深度学习的无人车夜视图像语义分割;高凯珺 等;《应用光学》;20170531;第38卷(第3期);第421-428页 * |
基于深度特征蒸馏的人脸识别;葛仕明 等;《北京交通大学学报》;20171231;第41卷(第6期);第27-33、41页 * |
Also Published As
Publication number | Publication date |
---|---|
CN109961442A (zh) | 2019-07-02 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109961442B (zh) | 神经网络模型的训练方法、装置和电子设备 | |
US10853726B2 (en) | Neural architecture search for dense image prediction tasks | |
US10991074B2 (en) | Transforming source domain images into target domain images | |
Matsubara et al. | Distilled split deep neural networks for edge-assisted real-time systems | |
US11062453B2 (en) | Method and system for scene parsing and storage medium | |
CN113033537B (zh) | 用于训练模型的方法、装置、设备、介质和程序产品 | |
US20180018555A1 (en) | System and method for building artificial neural network architectures | |
CN110622176A (zh) | 视频分区 | |
WO2021218517A1 (zh) | 获取神经网络模型的方法、图像处理方法及装置 | |
CN110990631A (zh) | 视频筛选方法、装置、电子设备和存储介质 | |
KR20190031318A (ko) | 도메인 분리 뉴럴 네트워크들 | |
US20180173997A1 (en) | Training device and training method for training image processing device | |
US11144782B2 (en) | Generating video frames using neural networks | |
US10936938B2 (en) | Method for visualizing neural network models | |
US20230162477A1 (en) | Method for training model based on knowledge distillation, and electronic device | |
CN110163052B (zh) | 视频动作识别方法、装置和机器设备 | |
CN115331275A (zh) | 图像处理的方法、计算机系统、电子设备和程序产品 | |
EP4018411A1 (en) | Multi-scale-factor image super resolution with micro-structured masks | |
CN116994021A (zh) | 图像检测方法、装置、计算机可读介质及电子设备 | |
KR20230132350A (ko) | 연합 감지 모델 트레이닝, 연합 감지 방법, 장치, 설비 및 매체 | |
US20220207861A1 (en) | Methods, devices, and computer readable storage media for image processing | |
KR20210064817A (ko) | 상이한 딥러닝 모델 간의 전이 학습방법 | |
US20220004849A1 (en) | Image processing neural networks with dynamic filter activation | |
CN115457365B (zh) | 一种模型的解释方法、装置、电子设备及存储介质 | |
CN115809688B (zh) | 一种模型调试方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |