CN114897160A - 模型训练方法、系统及计算机存储介质 - Google Patents
模型训练方法、系统及计算机存储介质 Download PDFInfo
- Publication number
- CN114897160A CN114897160A CN202210540240.5A CN202210540240A CN114897160A CN 114897160 A CN114897160 A CN 114897160A CN 202210540240 A CN202210540240 A CN 202210540240A CN 114897160 A CN114897160 A CN 114897160A
- Authority
- CN
- China
- Prior art keywords
- student
- teacher
- network
- feature map
- characteristic diagram
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/20—Image preprocessing
- G06V10/25—Determination of region of interest [ROI] or a volume of interest [VOI]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/80—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
- G06V10/806—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Multimedia (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- General Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
本申请提供一种模型训练方法、装置及计算机存储介质,包括利用教师网络和学生网络分别针对样本图像执行特征提取,获取样本图像的教师特征图和学生特征图;根据样本图像的真值框、学生特征图,确定学生特征图的蒸馏区域;基于蒸馏区域比对教师特征图和学生特征图,确定损失函数;基于损失函数训练学生网络,以获得训练好的学生网络。借此,本申请可在大幅降低系统计算量的同时,提高模型的训练效率。
Description
技术领域
本申请实施例涉及人工智能技术领域,特别涉及一种模型训练方法、装置及计算机存储介质。
背景技术
目标检测任务是计算机视觉的最基本任务之一,同时也是应用最广泛的任务。相比较于普通的分类任务,目标检测任务更加复杂,需要同时针对目标对象的位置的类别进行预测,因此目标检测网络通常来说也更为复杂,不仅包括分类网络的特征提取网络,还包括特有的特征融合网络以及检测框回归层,此将导致目标检测网络的参数量和计算量较大,并在实际应用场景对硬件算力要求较高。
有鉴于此,目标检测网络的轻量化即成为了业界研究的重要方向。
发明内容
鉴于上述问题,本申请提供一种模型训练方法、装置及计算机存储介质,可在降低系统计算量的同时,提高模型训练效果。
本申请第一方面提供一种模型训练方法,包括:利用教师网络和学生网络分别针对样本图像执行特征提取,获取所述样本图像的教师特征图和学生特征图;根据所述样本图像的真值框、所述学生特征图,确定所述学生特征图的蒸馏区域;基于所述蒸馏区域比对所述教师特征图和所述学生特征图,确定损失函数,并基于所述损失函数训练所述学生网络,以获得训练好的学生网络。
本申请第二方面提供一种模型训练装置,包括:特征图生成模块,用于利用教师网络和学生网络分别针对样本图像执行特征提取,获取所述样本图像的教师特征图和学生特征图;蒸馏区域分析模块,用于根据所述样本图像的真值框、所述学生特征图,确定所述学生特征图的蒸馏区域;训练模块,用于基于所述蒸馏区域比对所述教师特征图和所述学生特征图,确定损失函数,并基于所述损失函数训练所述学生网络,以获得训练好的学生网络。
本申请第三方面提供一种计算机存储介质,所述计算机存储介质中存储有用于执行第一方面所述的方法中各步骤的各指令。
综上所述,本申请提供的模型训练方案,利用教师网络和学生网络分别获取样本图像的教师特征图和学生特征图,并根据样本图像的真值框,确定学生特征图中的蒸馏区域,以基于蒸馏区域比对教师特征图和学生特征图,并根据比对结果训练学生网络,据此,本申请提供了一种结合知识蒸馏技术的模型训练方案,可在大幅降低系统计算量的同时,快速缩小学生网络和教师网络之间的精度差距,以提高模型训练效率。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请实施例中记载的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的附图。
图1为本申请示例性实施例的模型训练方法的处理流程图。
图2为本申请另一示例性实施例的模型训练方法的处理流程图。
图3为本申请另一示例性实施例的模型训练方法的处理流程图。
图4为本申请示例性实施例的模型训练装置的结构框图。
具体实施方式
为了使本领域的人员更好地理解本申请实施例中的技术方案,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请实施例一部分实施例,而不是全部的实施例。基于本申请实施例中的实施例,本领域普通技术人员所获得的所有其他实施例,都应当属于本申请实施例保护的范围。
目标的检测网络不仅包括了分类网络的特征提取网络,还包括了特有的特征融合网络以及检测框回归层,此复杂的网络结构设计导致了目标检测网络的参数量和计算量都比较大,且在实际应用场景中,对于硬件的算力要求也较高。因此,目标检测网络的轻量化设计已成为了当前相关研究业者研究的热点之一。
目前的模型轻量化的设计方案可包括模型剪枝以及INT8量化。其中,模型剪枝是通过裁剪掉模型中冗余的层,在保证精度的同时,来减少模型的计算量,然而,此实现过程需要分为三个阶段,先训练好模型,再剪枝,再微调,存在着实施过程复杂,精度与原始网络有天然的差距,且剪枝算法主要是去除冗余参数,而在轻量化模型中,冗余参数较少,因此作用并不明显。另外,INT8量化是指将高精度浮点数的模型参数按照一定的映射规则映射到低精度INT8类型,以此加速网络的计算,同时保证网络的精度,缺点在于量化表的建立比较复杂,且对最终的精度影响较大。
为实现目标检测网络的轻量化设计,目前业界推出了一些更为精简高效的目标检测网络,如SSD、YOLO等,这些网络能以较小的参数量实现较高的精度。同时,知识蒸馏技术目前也越来越多地被应用于目标检测任务中进行探索。然而,不同于分类任务输出的是一维特征向量,检测任务输出的通常是二维特征向量,因此,不能简单地将分类任务中的相关经验直接进行迁移,而如何将知识蒸馏技术应用于目标检测网络的训练任务中,以提高模型训练效率,成为了其中的难点和重点。
有鉴于此,本申请提供一种结合知识蒸馏技术的模型训练方法,可至少部分地解决现有技术存在的种种问题。
以下将结合各附图详细描述本申请的各实施例。
图1示出了本申请示例性实施例的模型训练方法的处理流程图。如图所示,本实施例的模型训练方法主要包括:
步骤S102,利用教师网络和学生网络分别针对样本图像执行特征提取,获取样本图像的教师特征图和学生特征图。
可选地,教师网络和学生网络可包括目标检测网络,例如YOLO网络等。
可选地,可将教师网络的网络尺度因子缩小预设比值,以获取学生网络,例如,可将教师网络YOLOV5s的网络尺度因子缩小1/4,获取学生网络YOLOV5xs,其中,学生网络YOLOV5xs的参数量为教师网络YOLOV5s的1/4。
可选地,可利用教师网络,基于预设维度针对样本图像执行特征提取,获取对应于预设维度的教师特征图,并利用学生网络,基于预设维度针对样本图像执行特征提取,获取对应于预设维度的学生特征图。
可选地,预设维度可包括8、16、32中的至少一个。
例如,可根据样本图像,分别提取教师网络和学生网络的stride=8,stride=16,stride=32的三个教师特征图和三个学生特征图。
可选地,教师特征图的通道数量与学生特征图的通道数量不相同。
可选地,可针对学生特征图的通道数执行适配处理,以使所生成的学生特征图的通道数量与教师特征图的通道数量相适配。
可选地,可利用预设卷积核针对学生特征图的通道数执行卷积处理。
于本实施例中,在学生特征图的大小与教师特征图的大小为相同的情况下,预设卷积核尺寸可设置为1×1×C2×C1,其中,1×1为学生特征图和教师特征图的高度与宽度,C2为学生特征图的通道数量;C1为教师特征图的通道数量。
步骤S104,根据样本图像的真值框、学生特征图,确定学生特征图的蒸馏区域。
于本实施例中,可计算样本图像中的目标对象的真值框(亦可称为真值框)相较于学生特征中的每一个参考锚框之间的交并比,以确定各参考锚框中的至少一个候选锚框,并基于各候选锚框构成的并集,确定学生特征图的蒸馏区域。
可选地,可根据单位锚框设置参数,在学生特征图的每一个网格单元中生成满足单位锚框设置参数的各参考锚框。
例如,若单位锚框设置参数为3,则在学生特征图的每一个网格单元中生成3个参考锚框。
步骤S106,基于蒸馏区域比对教师特征图和学生特征图,确定损失函数,并基于损失函数训练学生网络,以获得训练好的学生网络。
可选地,损失函数可包括模仿损失子函数。
可选地,可基于蒸馏区域比对教师特征图和学生特征图,确定模仿损失子函数。
具体地,可将具有相同特征维度的学生特征图和教师特征图设置为一个特征图组,并基于蒸馏区域计算每一个特征图组中学生特征图与教师特征图之间的特征向量的均方差,再累加每一个特征图组的均方差,以确定模仿损失子函数。
可选地,损失函数还可包括训练损失子函数。
于本实施例中,可根据样本图像的样本标签和学生网络针对样本图像输出的预测结果,确定训练损失子函数。
可选地,可基于损失函数优化更新学生网络的网络参数,并返回执行步骤S102,直至损失函数满足收敛条件,以完成学生网络的训练。
于本实施例中,可在损失函数收敛至预设值时,获得损失函数满足预设收敛条件的判断结果,或者可在损失函数收敛至稳定时,获得损失函数满足预设收敛条件的判断结果。
综上所述,本申请实施例提供的模型训练方法,利用教师网络和学生网络分别预测样本图像的教师特征图和学生特征图,根据样本图像的真值框确定学生特征图的蒸馏区域,并基于所确定的蒸馏区域比对教师特征图和学生特征图,并根据比对结果训练学生网络,借由此种结合知识蒸馏技术所执行的模型训练方案,可在大幅减少模型参数量的同时,快速缩小学生网络和教师网络之间的差距,以提高模型训练效率。
图2示出了本申请另一示例性实施例的模型训练方法的处理流程图。本实施例为上述步骤S104的具体实施方案,如图所示,本实施例主要包括以下步骤:
步骤S202,根据样本图像中的目标对象的真值框、学生特征图中的各参考锚框,计算真值框相对于每一个参考锚框之间的交并比,获得每一个参考锚框的交并比值。
于本实施例中,真值框可由人工标注生成,用于标注样本图像的中的目标对象。
可选地,可依次将学生特征图中的一个参考锚框作为当前锚框,计算真值框与当前锚框的交并比(IOU),也就是将真值框与当前锚框的交集除以真值框与当前锚框之间的并集,以获得当前锚框的交并比值,从重复执行获得当前锚框的交并比值的步骤,以获得学生特征图中每一个参考锚框的交并比值。
步骤S204,根据每一个参考锚框的交并比值、学生特征图的交并比阈值,确定各参考锚框中的至少一个候选锚框。
可选地,可根据每一个参考锚框的交并比值中的最高者、预设全局阈值,确定学生特征图的交并比阈值。
可选地,预设全局阈值可介于0.2至0.5之间,较佳为0.3。
具体地,可根据各交并比值中的最高值与预设全局阈值的乘积,确定交并比阈值。
需说明的是,上述交并比阈值除可根据各参考锚框对应的各交并比值的实际计算结果进行动态调整之外,亦可设置为一个定值,根据实际训练需求而定,本申请对此不作限制。
步骤S206,根据每一个候选锚框的并集,确定学生特征图的蒸馏区域。
可选地,根据每一个候选锚框的并集计算结果,可将学生特征图划分为蒸馏区域和非蒸馏区域,其中,蒸馏区域可对应于样本图像中包含目标对象的目标区域。
综上所述,本实施例通过确定学生特征图中的蒸馏区域,可供学生网络针对样本图像的目标对象的附近区域的特征进行模仿,而非模仿整个特征图,借以避免样本图像中的目标信息在蒸馏过程中淹没在背景信息中,可以大大提高对于正例目标的召回。此外,由于模型的损失函数中还同时考量了学生网络的普通损失函数(即训练损失子函数),因此,对于背景部分可有常规的抑制能力,并最终实现检测精度的提升。
图3示出了本申请另一实施例的模型训练方法的处理流程图。本实施例为上述步骤S106的具体实施方案,如图所示,本实施例主要包括以下步骤:
步骤S302,基于蒸馏区域比对教师特征图和学生特征图,确定模仿损失子函数。
可选地,可根据蒸馏区域、教师特征图、学生特征图、预设模仿损失换算公式,确定模仿损失子函数。
于本实施例中,预设模仿损失换算公式表示为:
其中,Limitation表示模仿损失子函数,teacheri表示教师网络针对样本图像输出的第i个教师特征图;studenti表示学生网络针对样本图像输出的第i个学生特征图,Maski为第i个学生特征图对应的蒸馏区域的特征掩膜;h为第i个教师特征图或第i个学生特征图的高度值,w为第i个教师特征图或第i个学生特征图的宽度值;fadap为针对第i个学生特征图的通道数执行适配处理。
步骤S304,根据样本图像的真实标签、学生网络针对样本图像输出的预测标签,确定训练损失子函数。
于本实施例中,可获取学生网络针对样本图像执行分类预测所获取的预测标签(例如目标对象的预测位置和预测类别),并将预测标签与学生网络的真实标签进行比对,以确定训练损失子函数。
步骤S306,基于模仿损失子函数和训练损失子函数,确定损失函数。
于本实施例中,可基于模仿损失子函数、训练损失子函数、预设损失函数换算公式,确定损失函数。
于本实施例中,预设损失函数换算公式表示为:
L=Lgt+εLimitation
其中,L表示损失函数,Lgt表示训练损失子函数;ε为平衡参数;Limitation表示模仿损失子函数。
综上所述,本申请实施例基于模仿损失子函数和训练损失子函数执行学生网络的训练处理,不仅可着重于样本图像中的目标特征信息,使学生网络能够较好地模仿教师网络在真值框附近的响应,同时学生网络还可受到普通训练过程的监督,能够有效学习到样本图像中的背景特征信息,借以提高模型的整体训练效果。
图4示出了本申请示例性实施例的模型训练装置的结构框图。如图所示,本实施例的模型训练装置400主要包括:特征图生成模块402、蒸馏区域分析模块404、训练模块406。
特征图生成模块402,用于利用教师网络和学生网络分别针对样本图像执行特征提取,获取所述样本图像的教师特征图和学生特征图;
蒸馏区域分析模块404,用于根据所述样本图像的真值框、所述学生特征图,确定所述学生特征图的蒸馏区域;
训练模块406,用于基于所述蒸馏区域比对所述教师特征图和所述学生特征图,确定损失函数,并基于所述损失函数训练所述学生网络,以获得训练好的学生网络。
可选地,特征图生成模块402还用于:利用所述教师网络,基于预设维度针对所述样本图像执行特征提取,获取对应于所述预设维度的所述教师特征图;利用所述学生网络,基于所述预设维度针对所述样本图像执行特征提取,获取对应于所述预设维度的所述学生特征图;其中,所述预设维度包括8、16、32中的至少一个。
可选地,所述教师特征图的通道数量与所述学生特征图的通道数量不相同,特征图生成模块402还用于:针对所述学生特征图的通道数执行适配处理,以使所述学生特征图的通道数量与所述教师特征图的通道数量相适配。
可选地,特征图生成模块402还用于:利用预设卷积核针对所述学生特征图的通道数执行卷积处理;所述预设卷积核尺寸为1×1×C2×C1;其中,所述C2为所述学生特征图的通道数量;所述C1为所述教师特征图的通道数量。
可选地,蒸馏区域分析模块404还用于:根据所述样本图像中的目标对象的真值框、所述学生特征图中的各参考锚框,计算所述真值框相对于每一个参考锚框之间的交并比,获得每一个参考锚框的交并比值;根据每一个参考锚框的交并比值、所述学生特征图的交并比阈值,确定各参考锚框中的至少一个候选锚框;根据每一个候选锚框的并集,确定所述学生特征图的蒸馏区域。
可选地,所述学生特征图中的各参考锚框可通过以下方式生成:根据单位锚框设置参数,在所述学生特征图的每一个网格单元中生成满足所述单位锚框设置参数的各参考锚框。
可选地,蒸馏区域分析模块404还用于:根据各参考锚框的交并比值中的最高值、预设全局阈值,确定所述学生特征图的交并比阈值;其中,所述预设全局阈值可介于0.2至0.5之间,较佳为0.3。
可选地,训练模块406还用于:基于所述蒸馏区域比对所述教师特征图和所述学生特征图,确定模仿损失子函数;根据所述样本图像的真实标签、所述学生网络针对所述样本图像输出的预测标签,确定训练损失子函数;基于所述模仿损失子函数和所述训练损失子函数,确定所述损失函数。
可选地,训练模块406还用于:根据所述蒸馏区域、所述教师特征图、所述学生特征图、预设模仿损失换算公式,确定所述模仿损失子函数;所述预设模仿损失换算公式表示为:
其中,所述Limitation表示所述模仿损失子函数,所述teacheri表示所述教师网络针对所述样本图像输出的第i个教师特征图;所述studenti表示所述学生网络针对所述样本图像输出的第i个学生特征图,所述Maski为所述第i个学生特征图对应的蒸馏区域的特征掩膜;所述h为所述第i个教师特征图或所述第i个学生特征图的高度值,所述w为所述第i个教师特征图或所述第i个学生特征图的宽度值;所述fadap为针对所述第i个学生特征图的通道数执行适配处理。
可选地,训练模块406还用于:基于所述模仿损失子函数、所述训练损失子函数、预设损失函数换算公式,确定所述损失函数;所述预设损失函数换算公式表示为:L=Lgt+εLimitation;
其中,所述L表示所述损失函数;所述Lgt表示所述训练损失子函数;所述ε为平衡参数;所述Limitation表示所述模仿损失子函数。
可选地,训练模块406还用于:基于所述损失函数优化更新所述学生网络的网络参数,并返回执行所述利用教师网络和学生网络分别针对样本图像执行特征提取的步骤,直至所述损失函数满足收敛条件,以完成所述学生网络的训练。
本申请另一实施例还提供一种计算机存储介质,所述计算机存储介质中存储有用于执行上述各方法实施例中各步骤的各指令。
综上所述,本申请各实施例针对模型剪枝和INT8量化中存在的过程复杂,模型精度损失严重的问题,设计了一种基于细粒度特征模仿的目标检测知识蒸馏方案,可着重蒸馏特征图上真值框附近的区域,使得学生网络能够较好地模仿教师网络在特征图真值框附近的响应,同时学生网络还受到普通训练过程的监督,能够学习到背景信息,以最终在大幅度减少计算量的同时,能够保持较高的精度。
最后应说明的是:以上实施例仅用以说明本申请实施例的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围。
Claims (13)
1.一种模型训练方法,包括:
利用教师网络和学生网络分别针对样本图像执行特征提取,获取所述样本图像的教师特征图和学生特征图;
根据所述样本图像的真值框、所述学生特征图,确定所述学生特征图的蒸馏区域;
基于所述蒸馏区域比对所述教师特征图和所述学生特征图,确定损失函数,并基于所述损失函数训练所述学生网络,以获得训练好的学生网络。
2.根据权利要求1所述的模型训练方法,其中,所述利用教师网络和学生网络分别针对样本图像执行特征提取,获取所述样本图像的教师特征图和学生特征图,包括:
利用所述教师网络,基于预设维度针对所述样本图像执行特征提取,获取对应于所述预设维度的所述教师特征图;
利用所述学生网络,基于所述预设维度针对所述样本图像执行特征提取,获取对应于所述预设维度的所述学生特征图;
其中,所述预设维度包括8、16、32中的至少一个。
3.根据权利要求2所述的模型训练方法,其中,所述教师特征图的通道数量与所述学生特征图的通道数量不相同,且所述方法还包括:
针对所述学生特征图的通道数执行适配处理,以使所述学生特征图的通道数量与所述教师特征图的通道数量相适配。
4.根据权利要求3所述的模型训练方法,其中,所述方法还包括:
利用预设卷积核针对所述学生特征图的通道数执行卷积处理;
所述预设卷积核尺寸为1×1×C2×C1;
其中,所述C2为所述学生特征图的通道数量;所述C1为所述教师特征图的通道数量。
5.根据权利要求1所述的模型训练方法,其中,所述根据所述样本图像的真值框、所述学生特征图,确定所述学生特征图的蒸馏区域,包括:
根据所述样本图像中的目标对象的真值框、所述学生特征图中的各参考锚框,计算所述真值框相对于每一个参考锚框之间的交并比,获得每一个参考锚框的交并比值;
根据每一个参考锚框的交并比值、所述学生特征图的交并比阈值,确定各参考锚框中的至少一个候选锚框;
根据每一个候选锚框的并集,确定所述学生特征图的蒸馏区域。
6.根据权利要求5所述的模型训练方法,其中,所述学生特征图中的各参考锚框可通过以下方式生成:
根据单位锚框设置参数,在所述学生特征图的每一个网格单元中生成满足所述单位锚框设置参数的各参考锚框。
7.根据权利要求5所述的模型训练方法,其中,所述方法还包括:
根据各参考锚框的交并比值中的最高值、预设全局阈值,确定所述学生特征图的交并比阈值;
其中,所述预设全局阈值可介于0.2至0.5之间,较佳为0.3。
8.根据权利要求1所述的模型训练方法,其中,所述基于所述蒸馏区域比对所述教师特征图和学生特征图,确定损失函数,包括:
基于所述蒸馏区域比对所述教师特征图和所述学生特征图,确定模仿损失子函数;
根据所述样本图像的真实标签、所述学生网络针对所述样本图像输出的预测标签,确定训练损失子函数;
基于所述模仿损失子函数和所述训练损失子函数,确定所述损失函数。
9.根据权利要求8所述的模型训练方法,其中,所述基于所述蒸馏区域比对所述教师特征图和所述学生特征图,确定模仿损失子函数,包括:
根据所述蒸馏区域、所述教师特征图、所述学生特征图、预设模仿损失换算公式,确定所述模仿损失子函数;
所述预设模仿损失换算公式表示为:
其中,所述Limitation表示所述模仿损失子函数,所述teacheri表示所述教师网络针对所述样本图像输出的第i个教师特征图;所述studenti表示所述学生网络针对所述样本图像输出的第i个学生特征图,所述Maski为所述第i个学生特征图对应的蒸馏区域的特征掩膜;所述h为所述第i个教师特征图或所述第i个学生特征图的高度值,所述w为所述第i个教师特征图或所述第i个学生特征图的宽度值;所述fadap为针对所述第i个学生特征图的通道数执行适配处理。
10.根据权利要求9所述的模型训练方法,其中,所述基于所述模仿损失子函数和所述训练损失子函数,确定所述损失函数,包括:
基于所述模仿损失子函数、所述训练损失子函数、预设损失函数换算公式,确定所述损失函数;
所述预设损失函数换算公式表示为:
L=Lgt+εLimitation
其中,所述L表示所述损失函数;所述Lgt表示所述训练损失子函数;所述ε为平衡参数;所述Limitation表示所述模仿损失子函数。
11.根据权利要求1所述的模型训练方法,其中,所述基于所述损失函数训练所述学生网络,以获得训练好的学生网络,包括:
基于所述损失函数优化更新所述学生网络的网络参数,并返回执行所述利用教师网络和学生网络分别针对样本图像执行特征提取的步骤,直至所述损失函数满足收敛条件,以完成所述学生网络的训练。
12.一种模型训练装置,包括:
特征图生成模块,用于利用教师网络和学生网络分别针对样本图像执行特征提取,获取所述样本图像的教师特征图和学生特征图;
蒸馏区域分析模块,用于根据所述样本图像的真值框、所述学生特征图,确定所述学生特征图的蒸馏区域;
训练模块,用于基于所述蒸馏区域比对所述教师特征图和所述学生特征图,确定损失函数,并基于所述损失函数训练所述学生网络,以获得训练好的学生网络。
13.一种计算机存储介质,其中,所述计算机存储介质中存储有用于执行上述权利要求1至11中任一项所述的方法中各步骤的各指令。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210540240.5A CN114897160A (zh) | 2022-05-18 | 2022-05-18 | 模型训练方法、系统及计算机存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210540240.5A CN114897160A (zh) | 2022-05-18 | 2022-05-18 | 模型训练方法、系统及计算机存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114897160A true CN114897160A (zh) | 2022-08-12 |
Family
ID=82723618
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210540240.5A Pending CN114897160A (zh) | 2022-05-18 | 2022-05-18 | 模型训练方法、系统及计算机存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114897160A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115456167A (zh) * | 2022-08-30 | 2022-12-09 | 北京百度网讯科技有限公司 | 轻量级模型训练方法、图像处理方法、装置及电子设备 |
CN115829983A (zh) * | 2022-12-13 | 2023-03-21 | 广东工业大学 | 一种基于知识蒸馏的高速工业场景视觉质量检测方法 |
CN116385274A (zh) * | 2023-06-06 | 2023-07-04 | 中国科学院自动化研究所 | 多模态影像引导的脑血管造影质量增强方法和装置 |
-
2022
- 2022-05-18 CN CN202210540240.5A patent/CN114897160A/zh active Pending
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115456167A (zh) * | 2022-08-30 | 2022-12-09 | 北京百度网讯科技有限公司 | 轻量级模型训练方法、图像处理方法、装置及电子设备 |
CN115456167B (zh) * | 2022-08-30 | 2024-03-12 | 北京百度网讯科技有限公司 | 轻量级模型训练方法、图像处理方法、装置及电子设备 |
CN115829983A (zh) * | 2022-12-13 | 2023-03-21 | 广东工业大学 | 一种基于知识蒸馏的高速工业场景视觉质量检测方法 |
CN115829983B (zh) * | 2022-12-13 | 2024-05-03 | 广东工业大学 | 一种基于知识蒸馏的高速工业场景视觉质量检测方法 |
CN116385274A (zh) * | 2023-06-06 | 2023-07-04 | 中国科学院自动化研究所 | 多模态影像引导的脑血管造影质量增强方法和装置 |
CN116385274B (zh) * | 2023-06-06 | 2023-09-12 | 中国科学院自动化研究所 | 多模态影像引导的脑血管造影质量增强方法和装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111062951B (zh) | 一种基于语义分割类内特征差异性的知识蒸馏方法 | |
CN107945204B (zh) | 一种基于生成对抗网络的像素级人像抠图方法 | |
CN114897160A (zh) | 模型训练方法、系统及计算机存储介质 | |
CN111444760A (zh) | 一种基于剪枝与知识蒸馏的交通标志检测与识别方法 | |
CN112150821B (zh) | 轻量化车辆检测模型构建方法、系统及装置 | |
CN112163628A (zh) | 一种适用于嵌入式设备的改进目标实时识别网络结构的方法 | |
CN113486726A (zh) | 一种基于改进卷积神经网络的轨道交通障碍物检测方法 | |
US20230260255A1 (en) | Three-dimensional object detection framework based on multi-source data knowledge transfer | |
CN111104831B (zh) | 一种视觉追踪方法、装置、计算机设备以及介质 | |
CN115223049A (zh) | 面向电力场景边缘计算大模型压缩的知识蒸馏与量化技术 | |
CN113516133A (zh) | 一种多模态图像分类方法及系统 | |
CN111488786A (zh) | 基于cnn的监视用客体检测器的方法及装置 | |
CN115240052A (zh) | 一种目标检测模型的构建方法及装置 | |
CN117217280A (zh) | 神经网络模型优化方法、装置及计算设备 | |
Blier-Wong et al. | Geographic ratemaking with spatial embeddings | |
CN113962388A (zh) | 一种硬件加速感知的神经网络通道剪枝方法 | |
CN109492697A (zh) | 图片检测网络训练方法及图片检测网络训练装置 | |
CN113536944A (zh) | 基于图像识别的配电线路巡检数据识别及分析方法 | |
CN112288084A (zh) | 一种基于特征图通道重要性程度的深度学习目标检测网络压缩方法 | |
CN117576149A (zh) | 一种基于注意力机制的单目标跟踪方法 | |
CN116310328A (zh) | 基于跨图像相似度关系的语义分割知识蒸馏方法及系统 | |
CN117132890A (zh) | 一种基于Kubernetes边缘计算集群的遥感图像目标检测方法和系统 | |
CN115953668A (zh) | 一种基于YOLOv5算法的迷彩伪装目标检测方法及系统 | |
CN114648762A (zh) | 语义分割方法、装置、电子设备和计算机可读存储介质 | |
CN114972429A (zh) | 云边协同自适应推理路径规划的目标追踪方法和系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |