CN114936605A - 基于知识蒸馏的神经网络训练方法、设备及存储介质 - Google Patents
基于知识蒸馏的神经网络训练方法、设备及存储介质 Download PDFInfo
- Publication number
- CN114936605A CN114936605A CN202210646401.9A CN202210646401A CN114936605A CN 114936605 A CN114936605 A CN 114936605A CN 202210646401 A CN202210646401 A CN 202210646401A CN 114936605 A CN114936605 A CN 114936605A
- Authority
- CN
- China
- Prior art keywords
- student
- teacher
- network model
- loss function
- decoding
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Feedback Control In General (AREA)
Abstract
本发明公开了一种基于知识蒸馏的神经网络训练方法、设备及存储介质,方法包括如下步骤:构建未训练的学生网络模型和训练好的教师网络模型;根据训练样本、学生网络模型和教师网络模型,得到蒸馏损失函数组,其中,损失函数组包括编码损失函数、解码损失函数和预测结果损失函数;根据蒸馏损失函数组对学生网络模型进行训练,得到训练好的学生网络模型。本发明通过设置训练好的教师网络模型对学生网络模型进行知识蒸馏,可有效提高学生网络模型的场景信息获取能力以及数据泛化性能力,而且由多种损失函数组成的蒸馏损失函数组,再通过蒸馏损失函数组对学生网络模型进行训练,能够有效提高学生网络模型预测结果的准确性。
Description
技术领域
本发明涉及人工智能领域,特别涉及一种基于知识蒸馏的神经网络训练方法、设备及存储介质。
背景技术
随着深度学习的发展,三维平面恢复与重建技术是目前计算机视觉领域的研究任务之一,单张图片的三维平面与恢复需要从图像维度分割出场景的平面实例区域,同时估计出每个实例区域的平面参数,非平面区域会用网络模型估计的深度进行表示,三维平面恢复与重建技术在虚拟现实、增强现实、机器人等领域具有广阔的应用前景。
平面恢复重建是三维平面恢复与重建的一个重要研究方向,相关技术中,三维平面恢复与重建方法着重于重建精度,通过分析平面结构的边缘以及与场景的嵌入性来加强神经网络模型的准确性,但用于平面恢复与重建的神经网络模型存在丢失场景结构信息、缺乏数据泛化性的问题。
发明内容
本发明旨在至少解决现有技术中存在的技术问题之一。为此,本发明提供了一种基于知识蒸馏的神经网络训练方法、设备及存储介质,可提升神经网络模型的场景信息获取能力以及数据泛化性能。
本发明第一方面实施例提供一种基于知识蒸馏的神经网络训练方法,包括如下步骤:
构建未训练的学生网络模型和训练好的教师网络模型;
根据训练样本、学生网络模型和教师网络模型,得到蒸馏损失函数组,其中,蒸馏损失函数组包括编码损失函数、解码损失函数和预测结果损失函数;
根据蒸馏损失函数组对学生网络模型进行训练,得到训练好的学生网络模型。
根据本发明的上述实施例,至少具有如下有益效果:通过设置训练好的教师网络模型对学生网络模型进行知识蒸馏,可有效提高学生网络模型的场景信息获取能力以及数据泛化性能力,而且根据教师网络模型和学生网络模型获取由多种损失函数组成的蒸馏损失函数组,再通过蒸馏损失函数组对学生网络模型进行训练,能够确保学生网络模型中每个处理环节的可靠性,从而有效提高学生网络模型预测结果的准确性。
根据本发明第一方面的一些实施例,根据训练样本、学生网络模型和教师网络模型,得到蒸馏损失函数组,包括:
将训练样本输入到学生网络模型,得到学生特征组;
将训练样本输入到教师网络模型,得到教师特征组;
根据学生特征组和教师特征组,得到蒸馏损失函数组。
根据本发明第一方面的一些实施例,将训练样本输入到学生网络模型,得到学生特征组,包括:
将训练样本输入到学生网络模型,得到包括学生编码特征、学生解码特征和学生预测结果特征的学生特征组。
根据本发明第一方面的一些实施例,将训练样本输入到教师网络模型,得到教师特征组,包括:
将训练样本输入到教师网络模型,得到包括教师编码特征、教师解码特征和教师预测结果特征的教师特征组,其中,教师网络模型包括教师骨干网络模型、教师编码器和教师解码器,教师骨干网络模型输出教师编码特征,教师编码器输出教师解码特征,教师解码器输出教师预测结果特征。
根据本发明第一方面的一些实施例,根据学生特征组和教师特征组,得到蒸馏损失函数组,包括:
根据学生编码特征和教师编码特征,得到编码损失函数;
根据学生解码特征和教师解码特征,得到解码损失函数;
根据学生预测结果特征和教师预测结果特征,得到预测结果损失函数。
根据本发明第一方面的一些实施例,学生网络模型包括学生编码器和学生解码器;
将训练样本输入到学生网络模型,得到包括学生编码特征、学生解码特征和学生预测结果特征的学生特征组,包括:
将训练样本输入到学生编码器进行下采样编码,得到学生编码特征;
根据学生编码器的下采样编码过程,得到融合特征层;
对学生编码特征进行卷积,得到学生解码特征;
将学生解码特征和融合特征层输入到学生解码器,先进行融合再进行上采样解码,得到学生预测结果特征。
根据本发明第一方面的一些实施例,根据学生编码器的下采样编码过程,得到融合特征层,包括:
根据学生编码器的下采样编码过程形成的每一下采样中间特征图,得到每一尺度下的融合特征层。
根据本发明第一方面的一些实施例,将学生解码特征和融合特征层输入到学生解码器,先进行融合再进行上采样解码,得到学生预测结果特征,包括:
将学生解码特征输入到学生解码器进行上采样解码,并将每一融合特征层分别与学生解码器上采样解码过程中对应尺度下的上采样中间特征图进行融合,得到学生解码特征。
本发明第二方面实施例提供一种电子设备,包括:
存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现第一方面任意一项的基于知识蒸馏的神经网络训练方法。
由于第二方面实施例的电子设备应用第一方面任意一项的基于知识蒸馏的神经网络训练方法,因此具有本发明第一方面的所有有益效果。
根据本发明第三方面实施例提供的一种计算机存储介质,存储有计算机可执行指令,计算机可执行指令用于执行第一方面任意一项的基于知识蒸馏的神经网络训练方法。
由于第三方面实施例的计算机存储介质可执行第一方面任意一项的基于知识蒸馏的神经网络训练方法,因此具有本发明第一方面的所有有益效果。
本发明的附加方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本发明的实践了解到。
附图说明
本发明的上述和/或附加的方面和优点从结合下面附图对实施例的描述中将变得明显和容易理解,其中:
图1是本发明实施例的基于知识蒸馏的神经网络训练方法的主要步骤图;
图2是本发明实施例的基于知识蒸馏的神经网络训练方法中步骤S2000的具体步骤图;
图3是本发明实施例的基于知识蒸馏的神经网络训练方法中步骤S2100的具体步骤图;
图4是本发明实施例的基于知识蒸馏的神经网络训练方法中步骤S2300的具体步骤图;
图5是本发明实施例的基于知识蒸馏的神经网络训练方法中学生网络模型的工作原理图;
图6是本发明实施例的基于知识蒸馏的神经网络训练方法中教师网络模型的工作原理图;
图7是本发明实施例的基于知识蒸馏的神经网络训练方法的工作原理图。
具体实施方式
本发明的描述中,除非另有明确的限定,设置、安装、连接等词语应做广义理解,所属技术领域技术人员可以结合技术方案的具体内容合理确定上述词语在本发明中的具体含义。在本发明的描述中,若干的含义是一个或者多个,多个的含义是两个以上,大于、小于、超过等理解为不包括本数,以上、以下、以内等理解为包括本数。此外,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本发明的描述中,除非另有说明,“多个”的含义是两个或两个以上。
随着深度学习的发展,计算机视觉领域受到了越来有多的研究者关注,三维平面恢复与重建技术是目前计算机视觉领域的研究任务之一,单张图片的三维平面与恢复需要从图像维度分割出场景的平面实例区域,同时估计出每个实例区域的平面参数,非平面区域会用网络模型估计的深度进行表示,三维平面恢复与重建技术在虚拟现实、增强现实、机器人等领域具有广阔的应用前景。单张图片的平面检测与恢复方法需要同时对图像深度、平面法线、平面分割等展开研究,传统的基于人工提取特征的三维平面恢复重建方法仅提取了图像的浅层纹理信息,同时依赖于平面几何的先验条件,存在泛化能力较弱的缺点。而现实中室内场景十分复杂,复杂光线所产生的多重阴影以及各种折叠遮挡物都会影响平面恢复和重建的质量,导致传统方法难以应对复杂室内场景的平面恢复和重建的任务。
目前,三维重建的方法主要通过三维视觉方法生成点云数据,然后再通过拟合相关点生成非线性的场景表面,再通过全局推理优化整体的重建模型,而平面恢复重建是结合视觉实例分割方法识别场景的平面区域,并用笛卡尔坐标下的三个参数以及一个分割掩码表示平面,具有更好的重建精度和效果,通过分段式的平面恢复重建来实现三维重建。分段式平面恢复重建是多阶段的重建方法,平面识别和参数估计的精确度都会影响最终模型的结果。
相关技术中,平面预测重建可通过多种方法实现:卷积神经网络架构Planenet能够从单张RGB图片中推断固定数量的平面实例掩码以及平面参数;还能够通过预测固定数量的平面直接从平面结构诱导的深度模态中学习;两阶段Mask R-CNN框架用平面几何预测代替对象类别分类,然后用卷积神经网络对平面分割掩码进行细化;还可以预测逐像素平面参数,采用关联嵌入方法,训练网络参数将每个像素映射到嵌入空间,再将嵌入的像素聚类成平面实例;受曼哈顿世界假设约束的平面细化方法,通过限制平面实例间的集合关系来加强平面化参数的细化;从垂直和水平方向对全景图平面分割进行了分而治之的处理方法,针对于全景图与普通图像的像素分布差异,该方法能够恢复畸变的平面实例;基于transformer模块的方法PlaneTR,通过加入平面实例中心及边缘特征,能够有效提高平面检测的效率。Planenet是一种用于从单个RGB图片进行分段重建平面深度图的深度神经网络,Mask R-CNN是一种神经网络框架,transformer模块是一个基于多头注意力机制的模型,PlaneTR是一种用于提取场景中3D平面特征的模型。
平面恢复重建是三维平面恢复与重建的一个重要研究方向,目前三维重建方法大多是先通过三维视觉方法生成点云数据,再通过拟合相关点生成非线性的场景表面,再通过全局推理优化整体的重建模型,相关技术中,三维平面恢复与重建方法着重于重建精度,通过分析平面结构的边缘以及与场景的嵌入性来加强模型的准确性,用于平面恢复与重建的神经网络存在丢失场景结构信息、缺乏数据泛化性的问题。
为解决用于平面恢复与重建的神经网络存在的丢失场景信息、缺乏数据泛化性的问题,本发明通过知识蒸馏训练得到的学生网络模型用作平面恢复与重建,能够避免出现场景结构信息丢失、数据泛化性低的问题。
知识蒸馏是一种训练框架,学生网络模型利用强大的教师网络模型的softmax函数输出向量作为软标签进行学习,学生网络模型一般是轻量化的小型网络模型,通过蒸馏过程能够有效提高轻量化网络模型的预测精度,由于教师网络模型与学生网络模型之间的容量存在差异,对模型预测结果进行硬性关联会使学生网络模型在蒸馏过程中受到负正则化,即网络的过拟合,从而限制了蒸馏过程的有效性,而且由于神经网络提取特征随着深度迭加愈加抽象化,通过知识蒸馏得到的学生网络模型存在预测准确度低的问题。
从原始特征图推导出一个注意图来表达知识,通过匹配特征空间中的概率分布传递知识,引入因素作为一个更容易理解的中间表示形式,利用隐藏神经元的激活边界进行知识转移,知识转移的可塑性在最初几个训练阶段后迅速下降,知识蒸馏的有效性会被降低。
下面参照图1至图7描述本发明的基于知识蒸馏的神经网络训练方法、设备及存储介质,不仅能够提升所获学生网络模型的场景结构信息获取能力以及数据泛化性能力,还能够提高学生网络模型的预测精度。
参考图1所示,根据本发明第一方面实施例的基于知识蒸馏的神经网络训练方法,包括如下步骤:
步骤S1000、构建未训练的学生网络模型和训练好的教师网络模型;
步骤S2000、根据训练样本、学生网络模型和教师网络模型,得到蒸馏损失函数组,其中,蒸馏损失函数组包括编码损失函数、解码损失函数和预测结果损失函数;
步骤S3000、根据蒸馏损失函数组对学生网络模型进行训练,得到训练好的学生网络模型,其中,训练好的学生网络模型可用于实现平面恢复重建。
通过设置训练好的教师网络模型对学生网络模型进行知识蒸馏,可有效提高学生网络模型的场景信息获取能力以及数据泛化性能力,而且根据教师网络模型和学生网络模型获取由多种损失函数组成的蒸馏损失函数组,再通过蒸馏损失函数组对学生网络模型进行训练,能够确保学生网络模型中每个处理环节的可靠性,从而有效提高学生网络模型预测结果的准确性。
可以理解的是,参考图2所示,步骤S2000,根据训练样本、学生网络模型和教师网络模型,得到蒸馏损失函数组,包括但不限于以下步骤:
步骤S2100、将训练样本输入到学生网络模型,得到学生特征组;
步骤S2200、将训练样本输入到教师网络模型,得到教师特征组;
步骤S2300、根据学生特征组和教师特征组,得到蒸馏损失函数组。
利用同组训练样本,分别输入到学生网络模型和教师网络模型,通过知识蒸馏的方法提取两个模型的特征并构建蒸馏损失函数组,并利用具有多层感知的蒸馏损失函数组对学生网络模型进行训练,能够有效提高训练好的学生网络模型的性能,学生网络模型的预测精度高。
可以理解的是,学生网络模型包括学生编码器和学生解码器,学生特征组包括学生编码特征、学生解码特征和学生预测结果特征;步骤S2100,将训练样本输入到学生网络模型,得到学生特征组,包括但不限于以下步骤:
步骤S2110、将训练样本输入到学生网络模型,得到包括学生编码特征、学生解码特征和学生预测结果特征的学生特征组。
通过设置学生网络模型省略学生网络模型预训练的学生骨干模型,使用学生编码器和学生解码器组成学生网络模型,能够显著简化学生网络模型的结构,学生网络模型的轻量化程度高,在使用训练好的学生网络模型进行预测时,能够有效提高其预测速度,用于平面检测恢复时的速度快;利用网络中间特征也具有一定的学习潜质,通过包含了三组蒸馏损失函数的蒸馏损失函数组对学生网络模型进行迭代训练,即通过循序渐进的蒸馏过程有助于减轻硬性关联的负面影响,最终获得的训练好的学生网络模型能够同时满足实时以及高精度的预测性能。
可以理解的是,参考图3所示,步骤S2110,将训练样本输入到学生网络模型,得到包括学生编码特征、学生解码特征和学生预测结果特征的学生特征组,包括但不限于以下步骤:
步骤S2111、将训练样本输入到学生编码器进行下采样编码,得到学生编码特征;
步骤S2112、根据学生编码器的下采样编码过程,得到融合特征层;
步骤S2113、对学生编码特征进行卷积,得到学生解码特征;
步骤S2114、将学生解码特征和融合特征层输入到学生解码器,学生解码特征与融合特征层先进行融合再进行上采样解码,得到学生预测结果特征。
学生编码器对训练样本的输入数据进行下采样编码处理,下采样过程中,可以采用快速的下采样策略,通过具有足够大的感知域进行特征提取识别,能够有效提高识别速度,下采样操作会导致空间信息丢失,而这些丢失的信息在后续处理过程中是无法恢复的,通过在学生编码器下采样过程中提取相应的特征作为特征融合层,用于与学生解码器上采样解码过程中对应的特征进行融合,能够对下采样过程中丢失的空间信息作出相应的补偿,能够有效确保学生解码器上采样解码后得到的学生预测结果特征的可靠性。
可以理解的是,学生网络模型的工作原理参考图5所示,步骤S2112,根据学生编码器的下采样编码过程,得到融合特征层,包括但不限于以下步骤:
根据学生编码器的下采样编码过程形成的每一下采样中间特征图,得到每一尺度下的融合特征层。
步骤S2114,将学生解码特征和融合特征层输入到学生解码器,学生解码特征与融合特征层先进行融合再进行上采样解码,得到学生预测结果特征,包括但不限于以下步骤:
将学生解码特征输入到学生解码器进行上采样解码,并将每一融合特征层分别与学生解码器上采样解码过程中对应尺度下的上采样中间特征图进行融合,学生解码器最终输出得到学生预测结果特征。
在学生编码器对训练样本的输入数据进行下采样编码处理时,通常需要对训练样本的输入数据进行多层下采样操作,一旦下采样的层数过多,会导致下采样操作丢失大部分的空间信息,由于这些丢失的信息在后续处理过程中是无法恢复的,向上采样过程提供的数据失真严重,会严重影响最终的预测结果,通过将下采样的浅层特征作为融合特征层,将每一融合特征层与同一尺度上采样的深层特征进行融合,能够逐步恢复空间细节,进而能够有效保证学生解码器输出的学生解码特征的可靠性。
可以理解的是,教师网络模型的工作原理参考图6所示,步骤S2200,将训练样本输入到教师网络模型,得到教师特征组,包括但不限于以下步骤:
步骤S2210、将训练样本输入到教师网络模型,得到包括教师编码特征、教师解码特征和教师预测结果特征的教师特征组,其中,教师网络模型包括教师骨干网络模型、教师编码器和教师解码器,教师骨干网络模型输出教师编码特征,教师编码特征输入到教师编码器后输出教师解码特征,教师解码特征输入到教师解码器后输出教师预测结果特征,教师特征组包括教师编码特征、教师解码特征和教师预测结果特征。
具体的,教师网络模型中,训练样本输入到教师骨干网络模型中,教师骨干网络模型输出教师编码特征,教师编码特征输入到教师编码器中,教师编码器输出教师解码特征,教师解码特征输入到教师解码器中,教师解码器输出教师预测结果特征。
可以理解的是,参考图4所示,步骤S2300,根据学生特征组和教师特征组,得到蒸馏损失函数组,包括但不限于以下步骤:
步骤S2310、根据学生编码特征和教师编码特征,得到编码损失函数,其中,编码损失函数用于对学生编码器的下采样编码进行校正,以使学生编码器输出更准确的学生编码特征;
步骤S2320、根据学生解码特征和教师解码特征,得到解码损失函数,其中,解码损失函数用于对学生解码器前的卷积进行校正,以确保输入到学生解码器的学生解码特征的准确性;
步骤S2330、根据学生预测结果特征和教师预测结果特征,得到预测结果损失函数,其中,预测结果损失函数用于对学生解码器的上采样解码进行校正,以使学生网络模型输出更准确的学生预测结果特征。
基于知识蒸馏的神经网络训练方法的工作原理图参考图7所示,通过在教师网络模型中三个网络层与学生网络模型中对应网络层中,分别提取相应输出的特征,并生成相应的蒸馏损失函数,通过三种对应不同网络层的蒸馏损失函数对学生网络模型进行迭代训练,即通过学生网络模型与教师网络模型的相应层次之间实现直接且有效的一对一匹配,能够有效确保学生网络模型中对应网络层进行数据处理的准确性,从架构上有效提高学生网络模型的性能,能够有效确保学生网络模型的泛化性以及预测结果的准确性。
具有多个学生网络的知识蒸馏架构,采用批判性学习意识KD(KnowledgeDistillation)方案,确保关键连接的形成,允许有效地模仿教师的信息流,而不是仅仅学习一个学生,允许学生网络模型和教师网络模型的对应层次之间进行直接和有效的一对一匹配,将教师网络模型和学生网络模型自适应分为三份,赋予每一份所包含对应网络层的自适应参数,并进行知识蒸馏学习,通过浅层网络特征关联的语义校正来显著提高特征知识迁移的有效性,利用注意机制实现跨层整流,能够缓解语义不匹配的问题。
可以理解的是,步骤S3000、根据蒸馏损失函数组对学生网络模型进行训练,得到训练好的学生网络模型,包括但不限于以下步骤:
根据编码损失函数对学生编码器的下采样编码进行校正,根据解码损失函数对学生解码器之前的卷积进行校正,根据预测结果损失函数对学生解码器的上采样解码进行校正,通过上述方式对学生网络模型进行训练,得到训练好的学生网络模型。
学生网络在根据任务训练网络的同时,根据含有多个中间层损失函数的蒸馏损失函数组进行辅助训练,通过中间特征层的迁移学习能够加强学生网络的估计性能。具体的,在编码维度、解码维度以及预测结果的维度,保证了教师网络模型在容量下溢时为学生网络模型提供更可靠的参数学习。
可以理解的是,基于transformer模块的教师网络能够实现全局区域的检测,设置教师网络模型以transformer模块为基础进行搭建,以HR-Net模型作为特征提取的教师骨干网络模型,生成高维的低尺度特征作为块嵌入,HR-Net模型是一种高分辨率网络,块嵌入的尺寸为p,H×W的像素图片被分为特征块嵌入的集合S0∈RD等等,其中RD是教师骨干网络模型输出的特征空间,特征块数量为最后输入到共有12层的transformer模块中,教师网络模型包括深度估计分支,深度估计分支以教师骨干网络模型的多尺度特征以及教师编码特征作为输入源,通过自上而下的解码结构估计图像深度,该结构采用双线性插值的上采样模块,每次采样后的特征模块与教师骨干网络模型特征尺度相对应,即实行2倍上采样机制估计图像深度,教师骨干网络模型输出相应的特征维度。
可以理解的是,教师网络模型与学生网络模型的最终输出进行L2损失函数校正,由于网络中没有最大值函数,L2损失函数用于对相应网络模型中最后一层激活层之前的特征进行校正,训练学生网络模型时,蒸馏损失函数组和L2损失函数能够实现更可靠的校正效果,应用于平面恢复与重建时,训练好的学生网络模型的预测精度更高。
下面以一个具体的实施例来详细描述本发明第一方面的基于知识蒸馏的神经网络训练方法。值得理解的是,下述描述仅是示例性说明,而不是对发明的具体限制。
构建未训练的学生网络模型和训练好的教师网络模型,其中,教师网络模型基于transformer模块设计的教师网络模型,教师网络模型包括教师骨干网络模型、教师编码器和教师解码器,其中教师骨干网络模型采用HR-Net模型,学生网络模型包括学生编码器和学生解码器,学生网络模型省略了其骨干网络模型的设置;
将训练样本输入到学生编码器进行下采样编码,得到学生编码特征,mobilenet-v3是一种轻量级网络,使用mobilenet-v3作为特征提取器,根据学生编码器的下采样编码过程中的下采样中间特征图,得到融合特征层,根据融合特征层,将学生编码特征输入到学生解码器进行上采样解码,得到学生解码特征;
将训练样本输入到教师网络模型中,获取教师骨干网络模型的教师编码特征、教师编码器输出的教师解码特征、教师解码器输出的教师预测结果特征;获取学生编码器输出的学生编码特征、学生解码器前经过卷积处理后的学生解码特征、学生解码器输出的学生预测结果特征;
根据学生编码特征和教师编码特征,得到编码损失函数,根据学生解码特征和教师解码特征,得到解码损失函数,根据学生预测结果特征和教师预测结果特征,得到预测结果损失函数;
根据编码损失函数对学生编码器的下采样编码进行校正,根据解码损失函数对学生解码器之前的卷积进行校正,根据预测结果损失函数对学生解码器的上采样解码进行校正,通过上述方式对学生网络模型进行训练,得到训练好的学生网络模型,训练好的学生网络模型可用于实现平面恢复重建。
其中,学生网络模型在工作的过程中,训练样本输入到学生编码器进行下采样编码,得到学生编码特征;根据学生编码器下采样编码过程中生成的每一下采样中间特征图,得到每一尺度下的融合特征层;对学生编码特征进行卷积,得到学生解码特征;将学生解码特征输入到学生解码器进行上采样解码,并且将每一融合特征层分别与学生解码器中对应尺度下的上采样中间特征图进行融合,学生解码器输出得到学生预测结果特征。在学生解码器中的每个解码阶段,通过特征融合模块将通尺度浅层特征即融合特征层,与上采样解码过程中的上采样中间特征图进行串联融合,其分辨率分别为1/32、1/16、1/8、1/4和1/2,能够保证每一次特征融合后相同尺度的特征具有相同的特征通道,在最后学生编码特征、学生解码特征和学生预测结果特征分别进行迁移学习。
另外,本发明第二方面实施例还提供了一种电子设备,该电子设备包括:存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序。
处理器和存储器可以通过总线或者其他方式连接。
存储器作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序以及非暂态性计算机可执行程序。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施方式中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至该处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
实现上述第一方面实施例的基于知识蒸馏的神经网络训练方法所需的非暂态软件程序以及指令存储在存储器中,当被处理器执行时,执行上述实施例中的基于知识蒸馏的神经网络训练方法,例如,执行以上描述的方法步骤S1000至S3000、方法步骤S2100至S2300、方法步骤S2110、方法步骤S2111至S2114、方法步骤S2210、方法步骤S2310至S2330。
以上所描述的设备实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。
此外,本发明的一个实施例还提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机可执行指令,该计算机可执行指令被一个处理器或控制器执行,例如,被上述设备实施例中的一个处理器执行,可使得上述处理器执行上述实施例中的基于知识蒸馏的神经网络训练方法,例如,执行以上描述的方法步骤S1000至S3000、方法步骤S2100至S2300、方法步骤S2110、方法步骤S2111至S2114、方法步骤S2210、方法步骤S2310至S2330。
本领域普通技术人员可以理解,上文中所公开方法中的全部或某些步骤、系统可以被实施为软件、固件、硬件及其适当的组合。某些物理组件或所有物理组件可以被实施为由处理器,如中央处理器、数字信号处理器或微处理器执行的软件,或者被实施为硬件,或者被实施为集成电路,如专用集成电路。这样的软件可以分布在计算机可读介质上,计算机可读介质可以包括计算机存储介质(或非暂时性介质)和通信介质(或暂时性介质)。如本领域普通技术人员公知的,术语计算机存储介质包括在用于存储信息(诸如计算机可读指令、数据结构、程序模块或其他数据)的任何方法或技术中实施的易失性和非易失性、可移除和不可移除介质。计算机存储介质包括但不限于RAM、ROM、EEPROM、闪存或其他存储器技术、CD-ROM、数字多功能盘(DVD)或其他光盘存储、磁盒、磁带、磁盘存储或其他磁存储装置、或者可以用于存储期望的信息并且可以被计算机访问的任何其他的介质。此外,本领域普通技术人员公知的是,通信介质通常包含计算机可读指令、数据结构、程序模块或者诸如载波或其他传输机制之类的调制数据信号中的其他数据,并且可包括任何信息递送介质。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示意性实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不一定指的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任何的一个或多个实施例或示例中以合适的方式结合。
尽管已经示出和描述了本发明的实施例,本领域的普通技术人员可以理解:在不脱离本发明的原理和宗旨的情况下可以对这些实施例进行多种变化、修改、替换和变型,本发明的范围由权利要求及其等同物限定。
Claims (10)
1.一种基于知识蒸馏的神经网络训练方法,其特征在于,包括如下步骤:
构建未训练的学生网络模型和训练好的教师网络模型;
根据训练样本、所述学生网络模型和所述教师网络模型,得到蒸馏损失函数组,其中,所述蒸馏损失函数组包括编码损失函数、解码损失函数和预测结果损失函数;
根据所述蒸馏损失函数组对所述学生网络模型进行训练,得到训练好的学生网络模型。
2.根据权利要求1所述的基于知识蒸馏的神经网络训练方法,其特征在于,所述根据训练样本、所述学生网络模型和所述教师网络模型,得到蒸馏损失函数组,包括:
将所述训练样本输入到学生网络模型,得到学生特征组;
将所述训练样本输入到教师网络模型,得到教师特征组;
根据所述学生特征组和所述教师特征组,得到所述蒸馏损失函数组。
3.根据权利要求2所述的基于知识蒸馏的神经网络训练方法,其特征在于,所述将所述训练样本输入到学生网络模型,得到学生特征组,包括:
将所述训练样本输入到学生网络模型,得到包括学生编码特征、学生解码特征和学生预测结果特征的所述学生特征组。
4.根据权利要求3所述的基于知识蒸馏的神经网络训练方法,其特征在于,所述将所述训练样本输入到教师网络模型,得到教师特征组,包括:
将所述训练样本输入到教师网络模型,得到包括教师编码特征、教师解码特征和教师预测结果特征的所述教师特征组,其中,所述教师网络模型包括教师骨干网络模型、教师编码器和教师解码器,所述教师骨干网络模型输出所述教师编码特征,所述教师编码器输出所述教师解码特征,所述教师解码器输出所述教师预测结果特征。
5.根据权利要求4所述的基于知识蒸馏的神经网络训练方法,其特征在于,所述根据所述学生特征组和所述教师特征组,得到所述蒸馏损失函数组,包括:
根据所述学生编码特征和所述教师编码特征,得到编码损失函数;
根据所述学生解码特征和所述教师解码特征,得到解码损失函数;
根据所述学生预测结果特征和所述教师预测结果特征,得到预测结果损失函数。
6.根据权利要求3所述的基于知识蒸馏的神经网络训练方法,其特征在于,所述学生网络模型包括学生编码器和学生解码器;
将所述训练样本输入到学生网络模型,得到包括学生编码特征、学生解码特征和学生预测结果特征的所述学生特征组,包括:
将所述训练样本输入到所述学生编码器进行下采样编码,得到所述学生编码特征;
根据所述学生编码器的下采样编码过程,得到融合特征层;
对所述学生编码特征进行卷积,得到学生解码特征;
将所述学生解码特征和所述融合特征层输入到所述学生解码器,先进行融合再进行上采样解码,得到所述学生预测结果特征。
7.根据权利要求6所述的基于知识蒸馏的神经网络训练方法,其特征在于,所述根据所述学生编码器的下采样编码过程,得到融合特征层,包括:
根据所述学生编码器的下采样编码过程形成的每一下采样中间特征图,得到每一尺度下的所述融合特征层。
8.根据权利要求7所述的基于知识蒸馏的神经网络训练方法,其特征在于,所述将所述学生解码特征与所述融合特征层输入到所述学生解码器,先进行融合再进行上采样解码,得到所述学生预测结果特征,包括:
将所述学生解码特征输入到所述学生解码器进行上采样解码,并将每一所述融合特征层分别与所述学生解码器上采样解码过程中对应尺度下的上采样中间特征图进行融合,所述学生解码器输出得到所述学生预测结果特征。
9.一种电子设备,其特征在于,包括:
存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如权利要求1至8中任意一项所述的基于知识蒸馏的神经网络训练方法。
10.一种计算机存储介质,其特征在于,存储有计算机可执行指令,所述计算机可执行指令用于执行权利要求1至8中任意一项所述的基于知识蒸馏的神经网络训练方法。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210646401.9A CN114936605A (zh) | 2022-06-09 | 2022-06-09 | 基于知识蒸馏的神经网络训练方法、设备及存储介质 |
PCT/CN2022/098769 WO2023212997A1 (zh) | 2022-05-05 | 2022-06-14 | 基于知识蒸馏的神经网络训练方法、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210646401.9A CN114936605A (zh) | 2022-06-09 | 2022-06-09 | 基于知识蒸馏的神经网络训练方法、设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114936605A true CN114936605A (zh) | 2022-08-23 |
Family
ID=82865998
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210646401.9A Pending CN114936605A (zh) | 2022-05-05 | 2022-06-09 | 基于知识蒸馏的神经网络训练方法、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114936605A (zh) |
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115471645A (zh) * | 2022-11-15 | 2022-12-13 | 南京信息工程大学 | 一种基于u型学生网络的知识蒸馏异常检测方法 |
CN115774851A (zh) * | 2023-02-10 | 2023-03-10 | 四川大学 | 基于分级知识蒸馏的曲轴内部缺陷检测方法及其检测系统 |
CN115908253A (zh) * | 2022-10-18 | 2023-04-04 | 中科(黑龙江)数字经济研究院有限公司 | 一种基于知识蒸馏的跨域医学影像分割方法及装置 |
CN116028891A (zh) * | 2023-02-16 | 2023-04-28 | 之江实验室 | 一种基于多模型融合的工业异常检测模型训练方法和装置 |
CN116310667A (zh) * | 2023-05-15 | 2023-06-23 | 鹏城实验室 | 联合对比损失和重建损失的自监督视觉表征学习方法 |
CN116304029A (zh) * | 2023-02-22 | 2023-06-23 | 北京麦克斯泰科技有限公司 | 一种使用知识异构的深度学习模型蒸馏方法和系统 |
CN117425013A (zh) * | 2023-12-19 | 2024-01-19 | 杭州靖安防务科技有限公司 | 一种基于可逆架构的视频传输方法和系统 |
CN117521848A (zh) * | 2023-11-10 | 2024-02-06 | 中国科学院空天信息创新研究院 | 面向资源受限场景的遥感基础模型轻量化方法、装置 |
-
2022
- 2022-06-09 CN CN202210646401.9A patent/CN114936605A/zh active Pending
Cited By (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115908253A (zh) * | 2022-10-18 | 2023-04-04 | 中科(黑龙江)数字经济研究院有限公司 | 一种基于知识蒸馏的跨域医学影像分割方法及装置 |
CN115471645A (zh) * | 2022-11-15 | 2022-12-13 | 南京信息工程大学 | 一种基于u型学生网络的知识蒸馏异常检测方法 |
CN115774851A (zh) * | 2023-02-10 | 2023-03-10 | 四川大学 | 基于分级知识蒸馏的曲轴内部缺陷检测方法及其检测系统 |
CN115774851B (zh) * | 2023-02-10 | 2023-04-25 | 四川大学 | 基于分级知识蒸馏的曲轴内部缺陷检测方法及其检测系统 |
CN116028891A (zh) * | 2023-02-16 | 2023-04-28 | 之江实验室 | 一种基于多模型融合的工业异常检测模型训练方法和装置 |
CN116304029A (zh) * | 2023-02-22 | 2023-06-23 | 北京麦克斯泰科技有限公司 | 一种使用知识异构的深度学习模型蒸馏方法和系统 |
CN116304029B (zh) * | 2023-02-22 | 2023-10-13 | 北京麦克斯泰科技有限公司 | 一种使用知识异构的深度学习模型蒸馏方法和系统 |
CN116310667A (zh) * | 2023-05-15 | 2023-06-23 | 鹏城实验室 | 联合对比损失和重建损失的自监督视觉表征学习方法 |
CN116310667B (zh) * | 2023-05-15 | 2023-08-22 | 鹏城实验室 | 联合对比损失和重建损失的自监督视觉表征学习方法 |
CN117521848A (zh) * | 2023-11-10 | 2024-02-06 | 中国科学院空天信息创新研究院 | 面向资源受限场景的遥感基础模型轻量化方法、装置 |
CN117521848B (zh) * | 2023-11-10 | 2024-05-28 | 中国科学院空天信息创新研究院 | 面向资源受限场景的遥感基础模型轻量化方法、装置 |
CN117425013A (zh) * | 2023-12-19 | 2024-01-19 | 杭州靖安防务科技有限公司 | 一种基于可逆架构的视频传输方法和系统 |
CN117425013B (zh) * | 2023-12-19 | 2024-04-02 | 杭州靖安防务科技有限公司 | 一种基于可逆架构的视频传输方法和系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114936605A (zh) | 基于知识蒸馏的神经网络训练方法、设备及存储介质 | |
Wang et al. | SFNet-N: An improved SFNet algorithm for semantic segmentation of low-light autonomous driving road scenes | |
CN111539887B (zh) | 一种基于混合卷积的通道注意力机制和分层学习的神经网络图像去雾方法 | |
WO2023212997A1 (zh) | 基于知识蒸馏的神经网络训练方法、设备及存储介质 | |
CN109087258B (zh) | 一种基于深度学习的图像去雨方法及装置 | |
CN108734210B (zh) | 一种基于跨模态多尺度特征融合的对象检测方法 | |
CN113592736B (zh) | 一种基于融合注意力机制的半监督图像去模糊方法 | |
CN111079532A (zh) | 一种基于文本自编码器的视频内容描述方法 | |
CN113591968A (zh) | 一种基于非对称注意力特征融合的红外弱小目标检测方法 | |
CN112396645A (zh) | 一种基于卷积残差学习的单目图像深度估计方法和系统 | |
Zeng et al. | LEARD-Net: Semantic segmentation for large-scale point cloud scene | |
CN112581409B (zh) | 一种基于端到端的多重信息蒸馏网络的图像去雾方法 | |
CN113066025B (zh) | 一种基于增量学习与特征、注意力传递的图像去雾方法 | |
CN113066089B (zh) | 一种基于注意力引导机制的实时图像语义分割方法 | |
CN116311254B (zh) | 一种恶劣天气情况下的图像目标检测方法、系统及设备 | |
CN110852199A (zh) | 一种基于双帧编码解码模型的前景提取方法 | |
CN115908789A (zh) | 跨模态特征融合及渐近解码的显著性目标检测方法及装置 | |
CN114283352A (zh) | 一种视频语义分割装置、训练方法以及视频语义分割方法 | |
CN115035172A (zh) | 基于置信度分级及级间融合增强的深度估计方法及系统 | |
CN116258756B (zh) | 一种自监督单目深度估计方法及系统 | |
CN116597144A (zh) | 一种基于事件相机的图像语义分割方法 | |
CN113807354B (zh) | 图像语义分割方法、装置、设备和存储介质 | |
CN114219738A (zh) | 单幅图像多尺度超分辨重建网络结构及方法 | |
Zou et al. | Group‐Based Atrous Convolution Stereo Matching Network | |
CN114926734B (zh) | 基于特征聚合和注意融合的固体废弃物检测装置及方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |