CN117875406A - 基于特征丰富度的知识蒸馏方法、系统、电子设备和介质 - Google Patents
基于特征丰富度的知识蒸馏方法、系统、电子设备和介质 Download PDFInfo
- Publication number
- CN117875406A CN117875406A CN202311408418.1A CN202311408418A CN117875406A CN 117875406 A CN117875406 A CN 117875406A CN 202311408418 A CN202311408418 A CN 202311408418A CN 117875406 A CN117875406 A CN 117875406A
- Authority
- CN
- China
- Prior art keywords
- network model
- feature map
- loss
- distillation
- feature
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 54
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 31
- 238000001514 detection method Methods 0.000 claims abstract description 111
- 238000004821 distillation Methods 0.000 claims abstract description 67
- 239000011159 matrix material Substances 0.000 claims abstract description 30
- 238000004364 calculation method Methods 0.000 claims abstract description 24
- 238000012549 training Methods 0.000 claims abstract description 24
- 238000010586 diagram Methods 0.000 claims description 26
- 230000006870 function Effects 0.000 claims description 13
- 238000005457 optimization Methods 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 12
- 230000002708 enhancing effect Effects 0.000 claims description 8
- 238000013507 mapping Methods 0.000 claims description 6
- 238000010276 construction Methods 0.000 claims description 4
- 238000012545 processing Methods 0.000 claims description 4
- 238000010606 normalization Methods 0.000 claims description 3
- 230000008569 process Effects 0.000 abstract description 12
- 238000002372 labelling Methods 0.000 abstract description 6
- 230000007246 mechanism Effects 0.000 abstract description 4
- 238000005070 sampling Methods 0.000 description 4
- 238000012795 verification Methods 0.000 description 4
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 230000001419 dependent effect Effects 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000008014 freezing Effects 0.000 description 2
- 238000007710 freezing Methods 0.000 description 2
- 239000007787 solid Substances 0.000 description 2
- 230000001133 acceleration Effects 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 230000008094 contradictory effect Effects 0.000 description 1
- 238000013500 data storage Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 238000011176 pooling Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000004088 simulation Methods 0.000 description 1
- 238000011895 specific detection Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 238000010200 validation analysis Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0495—Quantised networks; Sparse networks; Compressed networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/44—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
- G06V10/443—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components by matching or filtering
- G06V10/449—Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters
- G06V10/451—Biologically inspired filters, e.g. difference of Gaussians [DoG] or Gabor filters with interaction between the filter responses, e.g. cortical complex cells
- G06V10/454—Integrating the filters into a hierarchical structure, e.g. convolutional neural networks [CNN]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/52—Scale-space analysis, e.g. wavelet analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/771—Feature selection, e.g. selecting representative features from a multi-dimensional feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Biomedical Technology (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Biodiversity & Conservation Biology (AREA)
- Image Analysis (AREA)
Abstract
本申请提供一种基于特征丰富度的知识蒸馏方法、系统、电子设备和介质,所述方法采用特征丰富度评分机制,能够有效地选择并利用目标框内外的重要特征,帮助学生网络模型关注在多尺度特征图中具有显著差异的位置,从而可以提高知识蒸馏的效率,使学生网络模型能够更全面地学习教师网络模型的知识,提高检测性能;特征丰富度评分机制能够有效消除目标框内的有害特征,解决了由于标注信息错误或者漏标而导致蒸馏区域不准确的问题,从而更准确地指导学生网络模型的学习,提高了学生网络模型的泛化能力;特征丰富度矩阵存储代价很低,能够在训练过程中节省训练计算资源和内存空间的消耗,在实际部署过程中,小模型在资源受限的设备上也能够高效运行。
Description
技术领域
本申请属于深度学习技术领域,涉及一种基于特征丰富度的知识蒸馏方法、系统、电子设备和介质。
背景技术
近年来,深度学习技术在目标检测任务上取得了显著进展,尤其是引入了深度卷积神经网络。许多目标检测模型在具有巨大的参数和复杂结构的情况下,表现出优异的检测性能。然而在实际应用中,将这些模型部署到计算资源有限的设备(如移动设备、边缘设备或嵌入式系统)面临着巨大的挑战。尽管这些目标检测模型在准确性方面表现优异,但需要一种方法在保持高检测性能的同时,减少模型的复杂性和资源需求。
知识蒸馏是一种模型压缩和加速技术,可以将大模型的知识转迁移到小模型中,从而在保持一定准确性的同时,降低模型的计算和存储开销,同时能够在资源受限设备上获得更好的泛化性能。然而当前目标检测蒸馏方法主要关注近似目标框周围的特征,往往忽略了目标框外的有益特征,这可能导致模型对目标的理解不够全面和准确、在复杂场景中的检测性能下降,并限制其在未见过的数据上的泛化能力。
发明内容
本申请的目的在于提供一种基于特征丰富度的知识蒸馏方法、系统、电子设备和介质,用于解决现有技术无法有效利用目标框内外的特征,导致蒸馏后小模型的泛化能力和性能下降的技术问题。
第一方面,本申请提供一种基于特征丰富度的知识蒸馏方法,其特征在于,包括:
获取样本图像,将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果;
利用所述第一检测结果构建特征丰富度矩阵;
将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果;
基于所述特征丰富度矩阵、所述第一特征图和所述第二特征图,计算加权后的特征蒸馏损失;
基于所述特征丰富度矩阵、所述第一检测结果和所述第二检测结果,计算加权后的分类蒸馏损失;
获取所述学生网络模型的检测任务相关的损失;
联合所述加权后的特征蒸馏损失、所述加权后的分类蒸馏损失和所述学生网络模型的检测任务相关的损失,以计算蒸馏总损失;
基于所述蒸馏总损失对所述学生网络模型进行迭代优化,直至所述学生网络模型收敛。
在第一方面的一种实现方式中,将所述样本图像输入所述教师网络模型之前,还包括:
训练所述教师网络模型,并固定训练好的所述教师网络模型的训练权重。
在第一方面的一种实现方式中,所述将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果包括:
从所述样本图像中提取多尺度特征;
对所述多尺度特征进行提取和增强处理,以生成所述第一特征图;
将所述第一特征图映射转化为所述第一检测结果;所述第一检测结果包括目标位置信息和目标类别概率得分。
在第一方面的一种实现方式中,所述将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果包括:
从所述样本图像中提取多尺度特征;
对所述多尺度特征进行提取和增强处理,以生成所述第二特征图;
将所述第二特征图映射转化为所述第二检测结果;所述第二检测结果包括目标位置信息和目标类别概率得分。
在第一方面的一种实现方式中,采用如下公式计算所述加权后的特征蒸馏损失:
其中L表示所述第一特征图或所述第二特征图的层级数;
W表示所述第一特征图或所述第二特征图的宽;
H表示所述第一特征图或所述第二特征图的高;
C'表示所述教师网络模型分类的类别数目;
C表示所述第一特征图或所述第二特征图的通道数;
表示所述教师网络模型输出的第l层特征图上位于第k个通道第i行第j列元素的特征值;
表示所述学生网络模型输出的第l层特征图上位于第k个通道第i行第j列元素的特征值;l∈[1,L],i∈[1,W],j∈[1,H],k∈[1,C];
表示对所述第二特征增强网络输出的第l层特征图进行卷积和批量归一化处理;M表示所述特征丰富度矩阵;
yt表示所述第一检测结果中的目标类别概率得分。
在第一方面的一种实现方式中,采用如下公式计算所述加权后的分类蒸馏损失:
其中L表示所述第一特征图或所述第二特征图的层级数
W表示所述第一特征图或所述第二特征图的宽;
H表示所述第一特征图或所述第二特征图的高;
C'表示所述教师网络模型分类的类别数目;
C表示所述第一特征图或所述第二特征图的通道数;
表示所述教师网络模型输出的第l层特征图上第k个通道第i行第j列元素的目标类别概率得分;
表示所述学生网络模型输出的第l层特征图上第k个通道第i行第j列元素的目标类别概率得分;l∈[1,L],i∈[1,W],j∈[1,H],c∈[1,C];
BCEloss表示二值交叉熵损失函数;
M表示所述特征丰富度矩阵;
yt表示所述第一检测结果中的目标类别概率得分。
在第一方面的一种实现方式中,采用如下公式计算所述蒸馏总损失:
L=Ltask+L1+L2
其中Ltask表示所述学生网络模型的检测任务相关的损失函数;
L1表示所述加权后的特征蒸馏损失;
L2表示所述加权后的分类蒸馏损失。
第二方面,本申请提供一种基于特征丰富度的知识蒸馏系统,包括:
第一检测模块,用于获取样本图像,将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果;
构建模块,用于利用所述第一检测结果构建特征丰富度矩阵;
第二检测模块,用于将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果;
第一损失计算模块,用于基于所述特征丰富度矩阵、所述第一特征图和所述第二特征图,计算加权后的特征蒸馏损失;
第二损失计算模块,用于基于所述特征丰富度矩阵、所述第一检测结果和所述第二检测结果,计算加权后的分类蒸馏损失;
第三损失计算模块,用于获取所述学生网络模型的检测任务相关的损失;
总损失计算模块,用于联合所述加权后的特征蒸馏损失、所述加权后的分类蒸馏损失和所述学生网络模型的检测任务相关的损失,以计算蒸馏总损失;
优化模块,用于基于所述蒸馏总损失对所述学生网络模型进行迭代优化,直至所述学生网络模型收敛。
第三方面,本申请提供一种电子设备,包括:处理器及存储器;
所述存储器用于存储计算机程序;
所述处理器用于执行所述存储器存储的计算机程序,以使所述电子设备执行上述任一项所述的方法。
第四方面,本申请提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现上述任一项所述的方法。
如上所述,本申请所述的基于特征丰富度的知识蒸馏方法、系统、电子设备和介质,具有以下有益效果:
(1)采用特征丰富度评分机制,能够有效地选择并利用目标框内外的重要特征,帮助学生网络模型关注于那些在多尺度特征图中具有显著差异的位置,从而可以提高知识蒸馏的效率,使小模型能够更全面地学习教师网络模型的知识,提高检测性能;
(2)特征丰富度评分机制能够有效消除目标框内的有害特征,解决了由于标注信息错误或者漏标而导致蒸馏区域不准确的问题,从而能够更准确地指导学生网络模型的学习,避免学生网络模型模仿不合理的特征,提高了学生网络模型的泛化能力;
(3)特征丰富度矩阵存储代价很低,能够在训练过程中节省训练计算资源和内存空间的消耗,在实际部署过程中,小模型在资源受限的设备上也能够高效运行。
附图说明
图1显示为本申请所述的基于特征丰富度的知识蒸馏方法于一实施例中的流程图。
图2显示为本申请所述的基于特征丰富度的知识蒸馏方法于另一实施例中的流程图。
图3显示为本申请所述的目标检测方法于一实施例的流程图。
图4显示为本申请所述的基于特征丰富度的知识蒸馏系统于一实施例中的结构示意图。
图5显示为本申请所述的电子设备于一实施例的结构示意图。
元件标号说明
41 第一检测模块
42 构建模块
43 第二检测模块
44 第一损失计算模块
45 第二损失计算模块
46 第三损失计算模块
47 总损失计算模块
48 优化模块
51 处理器
52 存储器
具体实施方式
以下通过特定的具体实例说明本申请的实施方式,本领域技术人员可由本说明书所揭露的内容轻易地了解本申请的其他优点与功效。本申请还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本申请的精神下进行各种修饰或改变。需说明的是,在不冲突的情况下,以下实施例及实施例中的特征可以相互组合。
需要说明的是,以下实施例中所提供的图示仅以示意方式说明本申请的基本构想,遂图式中仅显示与本申请中有关的组件而非按照实际实施时的组件数目、形状及尺寸绘制,其实际实施时各组件的型态、数量及比例可为一种随意的改变,且其组件布局型态也可能更为复杂。
另外,在本申请中如涉及“第一”、“第二”等的描述仅用于描述目的,而不能理解为指示或暗示其相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括至少一个该特征。另外,各个实施例之间的技术方案可以相互结合,但是必须是以本领域普通技术人员能够实现为基础,当技术方案的结合出现相互矛盾或无法实现时应当认为这种技术方案的结合不存在,也不在本申请要求的保护范围之内。
本申请以下实施例提供了基于特征丰富度的知识蒸馏方法、系统、电子设备和介质。所述知识蒸馏的基本框架包括教师网络模型和学生网络模型,本申请中所述教师网络模型和所述学生网络模型均为满足端侧部署的目标检测器,其中教师网络模型为高精度的大模型,具有较多参数和较大容量,如ResNet、VGG、Inception、YOLOv8等;学生网络模型为轻量级的小模型,具有较少参数和较低计算复杂度,如MobileNet、ShuffleNet、EfficientNet等。所述教师网络模型的学习能力等性能强于所述学生网络模型,因此可以将教师网络模型学习到的知识迁移到学习能力相对较弱的学生网络模型,以此增强学生网络模型的泛化能力。下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行详细描述。
如图1所示,本实施例提供一种基于特征丰富度的知识蒸馏方法,包括如下步骤S11至步骤S18。
步骤S11、获取样本图像,将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果。
于一实施例中,获取样本图像包括:收集真实场景的图像,以创建图像数据集;对所述图像数据集进行检测框和类别标注,得到标注好的图像和对应的标注文件;确定训练集、验证集和测试集的划分比例;将标注好的图像和对应的标注文件随机分配到所述训练集、验证集和测试集中。
具体地,可以从不同来源获取图像数据集,以确保数据集的多样性和代表性。常见的比例是70%的数据用于训练,10-15%的数据用于验证,剩余的15-20%的数据用于测试。这个比例可以根据实际情况进行调整。通常每个样本会有对应的图像文件和标注文件,标注文件包含类别标签和对应的检测框信息等。
所述将所述样本图像输入所述教师网络模型之前,还包括:训练所述教师网络模型,并固定训练好的所述教师网络模型的训练权重。
于一实施例中,可以使用所述验证集评估所述教师网络模型的性能;当验证集上的性能达到一个满意的水平时,固定所述教师网络模型的训练权重;
于另一实施例中,可以观察所述教师网络模型的训练误差曲线,当误差趋于稳定或收敛到一个较小的值时,固定所述教师网络模型的训练权重;
于再一实施例中,还可以设定一个固定的训练轮数或迭代次数,当教师网络模型达到指定的轮数时,固定所述教师网络模型的训练权重。
需要说明的是,可以根据具体任务和数据集情况选择最佳的冻结时机。一般来说,可以在教师网络模型达到一定的准确性和稳定性之后进行冻结。
本实现方式中,通过冻结教师网络模型的训练权重,可以使其在学生网络模型的训练过程中不再更新,从而可以有效地引导学生网络模型的训练,提高学生模型的性能和泛化能力。学生网络模型也可以利用教师网络模型的预测结果作为辅助信息,从而可以更好地学习教师网络模型的知识。
于一实施例中,所述将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果包括:从所述样本图像中提取多尺度特征;对所述多尺度特征进行提取和增强处理,以生成所述第一特征图;将所述第一特征图映射转化为所述第一检测结果;所述第一检测结果包括目标位置信息和目标类别概率得分。
本实施例中,所述教师网络模型为预先训练好的教师网络模型。
图2显示为本申请所述的基于特征丰富度的知识蒸馏方法于另一实施例的流程图。
如图2所示,所述教师网络模型包括第一主干网络、第一特征增强网络和第一检测头。
具体地,所述第一主干网络可以理解为图像的特征提取器,它将输入的样本图像逐层处理,提取出不同层次(不同尺度或分辨率)的特征表示。所述第一主干网络的设计可以是基于经典的卷积神经网络(如VGG和ResNet等),也可以是自定义的网络结构。例如,所述第一主干网络可以通过一系列的卷积层从所述待训练的图片数据集中提取不同尺度的特征。
所述第一特征增强网络包括一些附加的卷积层、池化层或上采样层,用于对所述不同尺度的特征进行进一步提取和增强处理,以生成第一特征图。所述第一特征增强网络可以通过上采样、下采样、跳跃连接等方式将这些不同尺度的特征图进行融合,以获取更全面和丰富的特征表示,以便更好地进行目标检测。
所述第一检测头包括第一分类分支和第一回归分支。所述第一分类分支用于根据所述第一特征图预测每个位置(每个锚框/检测框)的目标类别得分。所述第一回归分支用于根据所述第一特征图预测每个位置(每个锚框/检测框)的目标位置信息。
步骤S12、利用所述第一检测结果构建特征丰富度矩阵。
所述特征丰富度矩阵的大小与所述样本图像的大小相对应,矩阵中每个位置上的值表示该位置的所述目标类别概率得分。所述特征丰富度矩阵可以体现类别的分布差异,用来指导后续的蒸馏损失的计算。
本实现方式中,利用教师网络模型中第一检测头的分类分支输出的目标类别概率得分对多尺度特征图进行蒸馏处理,能有效地选择并利用目标框内外的重要特征,可以帮助学生网络模型关注于那些在多尺度特征图中具有显著差异的位置,从而可以提高知识蒸馏的效率,使小模型能够更全面地学习教师模型的知识,提高检测性能。
步骤S13、将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果。
于一实施例中,所述将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果包括:从所述样本图像中提取多尺度特征;对所述多尺度特征进行提取和增强处理,以生成所述第二特征图;将所述第二特征图映射转化为所述第二检测结果;所述第二检测结果包括目标位置信息和目标类别概率得分。
本实施例中,所述学生网络模型包括第二主干网络、第二特征增强网络和第二检测头。
具体地,所述第二主干网络可以理解为图像的特征提取器,它将输入的样本图像逐层处理,提取出不同层次(不同尺度或分辨率)的特征表示。
所述第二特征增强网络用于对所述不同尺度的特征进行进一步提取和增强处理,以生成第二特征图。所述第二特征增强网络可以通过上采样、下采样、跳跃连接等方式将这些不同尺度的特征图进行融合,以获取更全面和丰富的特征表示,以便更好地进行目标检测。
所述第二检测头包括第二分类分支和第二回归分支。所述第二分类分支用于根据所述第二特征图预测每个位置(每个锚框/检测框)的目标类别得分。所述第二回归分支用于根据所述第二特征图预测每个位置(每个锚框/检测框)的目标位置信息。
步骤S14、基于所述特征丰富度矩阵、所述第一特征图和所述第二特征图,计算加权后的特征蒸馏损失。
本申请的加权后的特征蒸馏损失为L2损失,于一实施例中,采用如下公式计算所述加权后的特征蒸馏损失:
其中L表示所述第一特征图或所述第二特征图的层级数;W表示所述第一特征图或所述第二特征图的宽;H表示所述第一特征图或所述第二特征图的高;C'表示所述教师网络模型分类的类别数目;C表示所述第一特征图或所述第二特征图的通道数;表示所述教师网络模型输出的第l层特征图上位于第k个通道第i行第j列元素的特征值;/>表示所述学生网络模型输出的第l层特征图上位于第k个通道第i行第j列元素的特征值;l∈[1,L],i∈[1,W],j∈[1,H],k∈[1,C];/>表示对所述第二特征增强网络输出的第l层特征图进行卷积和批量归一化处理;M表示所述特征丰富度矩阵;yt表示所述第一检测结果中的目标类别概率得分。
需要说明的是,所述第一特征图和所述第二特征图的层级数、宽、高和通道数均相等。
本申请中,计算丰富度矩阵等同于计算所述教师网络模型的分类得分最大值。
步骤S15、基于所述特征丰富度矩阵、所述第一检测结果和所述第二检测结果,计算加权后的分类蒸馏(KD)损失。
于一实施例中,采用如下公式计算所述加权后的分类蒸馏损失:
其中L表示所述第一特征图或所述第二特征图的层级数;W表示所述第一特征图或所述第二特征图的宽;H表示所述第一特征图或所述第二特征图的高;C'表示所述教师网络模型分类的类别数目;C表示所述第一特征图或所述第二特征图的通道数;表示所述教师网络模型输出的第l层特征图上第k个通道第i行第j列元素的目标类别概率得分;表示所述学生网络模型输出的第l层特征图上第k个通道第i行第j列元素的目标类别概率得分;l∈[1,L],i∈[1,W],j∈[1,H],c∈[1,C];BCEloss表示二值交叉熵损失函数;M表示所述特征丰富度矩阵;yt表示所述第一检测结果中的目标类别概率得分。
需要说明的是,所述第一特征图和所述第二特征图的层级数、宽、高和通道数均相等。
步骤S16、获取所述学生网络模型的检测任务相关的损失。
在训练过程中,除了使用教师网络模型的输出结果来计算特征蒸馏损失和分类蒸馏损失之外,还需要使用学生网络模型的输出结果来计算学生网络模型的检测任务相关的损失。
具体地,学生网络模型的检测任务相关的损失是通过比较学生网络模型的检测结果和真实标签来计算的。通常使用一些常见的损失函数,如交叉熵损失函数或均方误差损失函数。
需要说明的是,在计算学生网络模型的检测任务相关的损失时,可以根据具体的检测任务选择合适的损失函数。例如,在目标检测任务中,可以使用目标的边界框坐标和类别标签之间的差异来计算损失。
本实现方式中,通过对学生网络模型的检测任务相关的损失进行优化,可以使得学生网络模型在检测任务上的性能逐渐接近于教师网络模型,从而实现知识蒸馏的目标。
步骤S17、联合所述加权后的特征蒸馏损失、所述加权后的分类蒸馏损失和所述学生网络模型的检测任务相关的损失,以计算蒸馏总损失。
于一实施例中,采用如下公式计算所述蒸馏总损失:
L=Ltask+L1+L2
其中Ltask表示所述学生网络模型的检测任务相关的损失函数;L1表示所述加权后的特征蒸馏损失;L2表示所述加权后的分类蒸馏损失。
需要说明的是,也可以调整所述加权后的特征蒸馏损失、所述加权后的分类蒸馏损失和所述学生网络模型的检测任务相关的损失参与蒸馏总损失计算的比例,例如,采用如下公式计算所述蒸馏总损失:
L=αLtask+βL1+γL2
其中α+β+γ=1。
如果希望所述加权后的特征蒸馏损失、所述加权后的分类蒸馏损失和所述学生网络模型的检测任务相关的损失中任一项在蒸馏总损失中占比较大,可以将其权重系数设置为较大的值。通过调整不同损失的权重,可以根据具体问题和需求来平衡不同损失的重要性,从而得到更好的蒸馏效果。
本申请的蒸馏总损失可以通过反向传播算法来更新学生网络模型的参数,以使得学生网络模型能够更好地逼近教师网络模型的检测结果。
步骤S18、基于所述蒸馏总损失对所述学生网络模型进行迭代优化,直至所述学生网络模型收敛。
当模型的损失迭代优化收敛之后,只需要用学生网络模型作为实际部署的模型。这是因为,学生网络模型则通常具有较少的参数量,通过将教师网络模型的知识转移到学生网络模型中,可以减小模型的体积,从而降低模型在部署过程中的存储和传输成本。此外,较小的学生模型通常具有更快的推理速度。经过了长时间的训练和复杂的优化过程后,学生网络模型的性能和鲁棒性提高,可以更好地满足实际应用的需求。
需要说明的是,本申请实施例所述的基于特征丰富度的知识蒸馏方法的保护范围不限于本实施例列举的步骤执行顺序,凡是根据本申请的原理所做的现有技术的步骤增减、步骤替换所实现的方案都包括在本申请的保护范围内。
如图3所示,本申请还提供一种目标检测方法,包括:步骤S21和S22。
步骤S21、获取待检测图像。
步骤S22、将所述待检测图像输入迭代优化后的学生网络模型,得到目标检测结果。
需要说明的是,所述迭代优化后的学生网络模型为通过上述任一项所述的基于特征丰富度的知识蒸馏方法得到。
如图4所示,本申请还提供一种基于特征丰富度的知识蒸馏系统,包括:
第一检测模块41,用于获取样本图像,将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果。
构建模块42,用于利用所述第一检测结果构建特征丰富度矩阵。
第二检测模块43,用于将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果。
第一损失计算模块44,用于基于所述特征丰富度矩阵、所述第一特征图和所述第二特征图,计算加权后的特征蒸馏损失。
第二损失计算模块45,用于基于所述特征丰富度矩阵、所述第一检测结果和所述第二检测结果,计算加权后的分类蒸馏损失。
第三损失计算模块46,用于获取所述学生网络模型的检测任务相关的损失。
总损失计算模块47,用于联合所述加权后的特征蒸馏损失、所述加权后的分类蒸馏损失和所述学生网络模型的检测任务相关的损失,以计算蒸馏总损失。
优化模块48,用于基于所述蒸馏总损失对所述学生网络模型进行迭代优化,直至所述学生网络模型收敛。
需要说明的是,本实施例中的所述第一检测模块41、构建模块42、第二检测模块43、第一损失计算模块44、第二损失计算模块45、第三损失计算模块46、总损失计算模块47和优化模块48的结构和原理与上述基于特征丰富度的知识蒸馏方法中的步骤一一对应,故在此不再赘述。
本申请实施例提供的基于特征丰富度的知识蒸馏系统可以实现本申请所述基于特征丰富度的知识蒸馏方法,但本申请所述的基于特征丰富度的知识蒸馏方法的实现装置包括但不限于本实施例列举的基于特征丰富度的知识蒸馏系统的结构,凡是根据本申请的原理所做的现有技术的结构变形和替换,都包括在本申请的保护范围内。
如图5所示,本实施例提供一种电子设备,包括:处理器51及存储器52。
所述存储器52用于存储计算机程序。
所述处理器51用于执行所述存储器52存储的计算机程序,以使所述电子设备执行上述任一项所述的方法。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统、装置或方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅是示意性的,例如,模块/单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个模块或单元可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或模块或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
作为分离部件说明的模块/单元可以是或者也可以不是物理上分开的,作为模块/单元显示的部件可以是或者也可以不是物理模块,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块/单元来实现本申请实施例的目的。例如,在本申请各个实施例中的各功能模块/单元可以集成在一个处理模块中,也可以是各个模块/单元单独物理存在,也可以两个或两个以上模块/单元集成在一个模块/单元中。
本领域普通技术人员应该还可以进一步意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
本申请实施例还提供了一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现上述任一项所述的方法。本领域普通技术人员可以理解实现上述实施例的方法中的全部或部分步骤是可以通过程序来指令处理器完成,所述的程序可以存储于计算机可读存储介质中,所述存储介质是非短暂性(non-transitory)介质,例如随机存取存储器,只读存储器,快闪存储器,硬盘,固态硬盘,磁带(magnetic tape),软盘(floppydisk),光盘(optical disc)及其任意组合。上述存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。该可用介质可以是磁性介质(例如,软盘、硬盘、磁带)、光介质(例如数字视频光盘(digital videodisc,DVD))、或者半导体介质(例如固态硬盘(solid state disk,SSD))等。
本申请实施例还可以提供一种计算机程序产品,所述计算机程序产品包括一个或多个计算机指令。在计算设备上加载和执行所述计算机指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机或数据中心进行传输。
所述计算机程序产品被计算机执行时,所述计算机执行前述方法实施例所述的方法。该计算机程序产品可以为一个软件安装包,在需要使用前述方法的情况下,可以下载该计算机程序产品并在计算机上执行该计算机程序产品。
上述各个附图对应的流程或结构的描述各有侧重,某个流程或结构中没有详述的部分,可以参见其他流程或结构的相关描述。
上述实施例仅例示性说明本申请的原理及其功效,而非用于限制本申请。任何熟悉此技术的人士皆可在不违背本申请的精神及范畴下,对上述实施例进行修饰或改变。因此,举凡所属技术领域中具有通常知识者在未脱离本申请所揭示的精神与技术思想下所完成的一切等效修饰或改变,仍应由本申请的权利要求所涵盖。
Claims (10)
1.一种基于特征丰富度的知识蒸馏方法,其特征在于,包括:
获取样本图像,将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果;
利用所述第一检测结果构建特征丰富度矩阵;
将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果;
基于所述特征丰富度矩阵、所述第一特征图和所述第二特征图,计算加权后的特征蒸馏损失;
基于所述特征丰富度矩阵、所述第一检测结果和所述第二检测结果,计算加权后的分类蒸馏损失;
获取所述学生网络模型的检测任务相关的损失;
联合所述加权后的特征蒸馏损失、所述加权后的分类蒸馏损失和所述学生网络模型的检测任务相关的损失,以计算蒸馏总损失;
基于所述蒸馏总损失对所述学生网络模型进行迭代优化,直至所述学生网络模型收敛。
2.根据权利要求1所述的方法,其特征在于,所述将所述样本图像输入所述教师网络模型之前,还包括:
训练所述教师网络模型,并固定训练好的所述教师网络模型的训练权重。
3.根据权利要求1所述的方法,其特征在于,所述将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果包括:
从所述样本图像中提取多尺度特征;
对所述多尺度特征进行提取和增强处理,以生成所述第一特征图;
将所述第一特征图映射转化为所述第一检测结果;所述第一检测结果包括目标位置信息和目标类别概率得分。
4.根据权利要求1所述的方法,其特征在于,所述将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果包括:
从所述样本图像中提取多尺度特征;
对所述多尺度特征进行提取和增强处理,以生成所述第二特征图;
将所述第二特征图映射转化为所述第二检测结果;所述第二检测结果包括目标位置信息和目标类别概率得分。
5.根据权利要求4所述的方法,其特征在于,采用如下公式计算所述加权后的特征蒸馏损失:
其中L表示所述第一特征图或所述第二特征图的层级数;
W表示所述第一特征图或所述第二特征图的宽;
H表示所述第一特征图或所述第二特征图的高;
C'表示所述教师网络模型分类的类别总数;
C表示所述第一特征图或所述第二特征图的通道数;
表示所述教师网络模型输出的第l层特征图上位于第k个通道第i行第j列元素的特征值;
表示所述学生网络模型输出的第l层特征图上位于第k个通道第i行第j列元素的特征值;l∈[1,L],i∈[1,W],j∈[1,H],k∈[1,C];
表示对所述第二特征增强网络输出的第l层特征图进行卷积和批量归一化处理;
M表示所述特征丰富度矩阵;
yt表示所述第一检测结果中的目标类别概率得分。
6.根据权利要求4所述的方法,其特征在于,采用如下公式计算所述加权后的分类蒸馏损失:
其中L表示所述第一特征图或所述第二特征图的层级数;
W表示所述第一特征图或所述第二特征图的宽;
H表示所述第一特征图或所述第二特征图的高;
C'表示所述教师网络模型分类的类别总数;
C表示所述第一特征图或所述第二特征图的通道数;
表示所述教师网络模型输出的第l层特征图上第k个通道第i行第j列元素的目标类别概率得分;
表示所述学生网络模型输出的第l层特征图上第k个通道第i行第j列元素的目标类别概率得分;l∈[1,L],i∈[1,W],j∈[1,H],k∈[1,C];
BCEloss表示二值交叉熵损失函数;
M表示所述特征丰富度矩阵;
yt表示所述第一检测结果中的目标类别概率得分。
7.根据权利要求1所述的方法,其特征在于,采用如下公式计算所述蒸馏总损失:
L=Ltask+L1+L2
其中Ltask表示所述学生网络模型的检测任务相关的损失函数;
L1表示所述加权后的特征蒸馏损失;
L2表示所述加权后的分类蒸馏损失。
8.一种基于特征丰富度的知识蒸馏系统,其特征在于,包括:
第一检测模块,用于获取样本图像,将所述样本图像输入教师网络模型,得到第一特征图和第一检测结果;
构建模块,用于利用所述第一检测结果构建特征丰富度矩阵;
第二检测模块,用于将所述样本图像输入学生网络模型,得到第二特征图和第二检测结果;
第一损失计算模块,用于基于所述特征丰富度矩阵、所述第一特征图和所述第二特征图,计算加权后的特征蒸馏损失;
第二损失计算模块,用于基于所述特征丰富度矩阵、所述第一检测结果和所述第二检测结果,计算加权后的分类蒸馏损失;
第三损失计算模块,用于获取所述学生网络模型的检测任务相关的损失;
总损失计算模块,用于联合所述加权后的特征蒸馏损失、所述加权后的分类蒸馏损失和所述学生网络模型的检测任务相关的损失,以计算蒸馏总损失;
优化模块,用于基于所述蒸馏总损失对所述学生网络模型进行迭代优化,直至所述学生网络模型收敛。
9.一种电子设备,其特征在于,包括:处理器及存储器;
所述存储器用于存储计算机程序;
所述处理器用于执行所述存储器存储的计算机程序,以使所述电子设备执行权利要求1至7中任一项所述的方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现权利要求1至7中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311408418.1A CN117875406A (zh) | 2023-10-27 | 2023-10-27 | 基于特征丰富度的知识蒸馏方法、系统、电子设备和介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311408418.1A CN117875406A (zh) | 2023-10-27 | 2023-10-27 | 基于特征丰富度的知识蒸馏方法、系统、电子设备和介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117875406A true CN117875406A (zh) | 2024-04-12 |
Family
ID=90580020
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311408418.1A Pending CN117875406A (zh) | 2023-10-27 | 2023-10-27 | 基于特征丰富度的知识蒸馏方法、系统、电子设备和介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117875406A (zh) |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112200062A (zh) * | 2020-09-30 | 2021-01-08 | 广州云从人工智能技术有限公司 | 一种基于神经网络的目标检测方法、装置、机器可读介质及设备 |
CN112766087A (zh) * | 2021-01-04 | 2021-05-07 | 武汉大学 | 一种基于知识蒸馏的光学遥感图像舰船检测方法 |
CN113936295A (zh) * | 2021-09-18 | 2022-01-14 | 中国科学院计算技术研究所 | 基于迁移学习的人物检测方法和系统 |
WO2023279693A1 (zh) * | 2021-07-09 | 2023-01-12 | 平安科技(深圳)有限公司 | 知识蒸馏方法、装置、终端设备及介质 |
CN115861997A (zh) * | 2023-02-27 | 2023-03-28 | 松立控股集团股份有限公司 | 一种关键前景特征引导知识蒸馏的车牌检测识别方法 |
CN116824336A (zh) * | 2023-06-27 | 2023-09-29 | 上海艺冉医疗科技股份有限公司 | 基于误差修正与区域权重引导知识蒸馏的影像检测方法 |
-
2023
- 2023-10-27 CN CN202311408418.1A patent/CN117875406A/zh active Pending
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112200062A (zh) * | 2020-09-30 | 2021-01-08 | 广州云从人工智能技术有限公司 | 一种基于神经网络的目标检测方法、装置、机器可读介质及设备 |
CN112766087A (zh) * | 2021-01-04 | 2021-05-07 | 武汉大学 | 一种基于知识蒸馏的光学遥感图像舰船检测方法 |
WO2023279693A1 (zh) * | 2021-07-09 | 2023-01-12 | 平安科技(深圳)有限公司 | 知识蒸馏方法、装置、终端设备及介质 |
CN113936295A (zh) * | 2021-09-18 | 2022-01-14 | 中国科学院计算技术研究所 | 基于迁移学习的人物检测方法和系统 |
CN114419667A (zh) * | 2021-09-18 | 2022-04-29 | 中国科学院计算技术研究所 | 基于迁移学习的人物检测方法和系统 |
CN115861997A (zh) * | 2023-02-27 | 2023-03-28 | 松立控股集团股份有限公司 | 一种关键前景特征引导知识蒸馏的车牌检测识别方法 |
CN116824336A (zh) * | 2023-06-27 | 2023-09-29 | 上海艺冉医疗科技股份有限公司 | 基于误差修正与区域权重引导知识蒸馏的影像检测方法 |
Non-Patent Citations (1)
Title |
---|
张云峰: "机器学习算法理论与应用", 30 June 2022, 中国海洋大学出版社, pages: 41 - 42 * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11100266B2 (en) | Generating integrated circuit floorplans using neural networks | |
WO2022083536A1 (zh) | 一种神经网络构建方法以及装置 | |
CN110852447B (zh) | 元学习方法和装置、初始化方法、计算设备和存储介质 | |
CN116261731A (zh) | 基于多跳注意力图神经网络的关系学习方法与系统 | |
CN106960219A (zh) | 图片识别方法及装置、计算机设备及计算机可读介质 | |
CN113065013B (zh) | 图像标注模型训练和图像标注方法、系统、设备及介质 | |
CN111612010A (zh) | 图像处理方法、装置、设备以及计算机可读存储介质 | |
US20220383036A1 (en) | Clustering data using neural networks based on normalized cuts | |
CN114556364B (zh) | 用于执行神经网络架构搜索的计算机实现方法 | |
WO2020019102A1 (en) | Methods, systems, articles of manufacture and apparatus to train a neural network | |
CN114972877B (zh) | 一种图像分类模型训练方法、装置及电子设备 | |
CN114417058A (zh) | 一种视频素材的筛选方法、装置、计算机设备和存储介质 | |
CN116362325A (zh) | 一种基于模型压缩的电力图像识别模型轻量化应用方法 | |
CN114781611A (zh) | 自然语言处理方法、语言模型训练方法及其相关设备 | |
CN113592008B (zh) | 小样本图像分类的系统、方法、设备及存储介质 | |
CN114078203A (zh) | 一种基于改进pate的图像识别方法和系统 | |
CN112966743A (zh) | 基于多维度注意力的图片分类方法、系统、设备及介质 | |
CN112633246A (zh) | 开放场景中多场景识别方法、系统、设备及存储介质 | |
Zerrouk et al. | Evolutionary algorithm for optimized CNN architecture search applied to real-time boat detection in aerial images | |
CN117875406A (zh) | 基于特征丰富度的知识蒸馏方法、系统、电子设备和介质 | |
CN117010480A (zh) | 模型训练方法、装置、设备、存储介质及程序产品 | |
CN112348161B (zh) | 神经网络的训练方法、神经网络的训练装置和电子设备 | |
CN114444654A (zh) | 一种面向nas的免训练神经网络性能评估方法、装置和设备 | |
CN111062477B (zh) | 一种数据处理方法、装置及存储介质 | |
CN112348045B (zh) | 神经网络的训练方法、训练装置和电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |