CN115880486B - 一种目标检测网络蒸馏方法、装置、电子设备及存储介质 - Google Patents
一种目标检测网络蒸馏方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN115880486B CN115880486B CN202310169069.6A CN202310169069A CN115880486B CN 115880486 B CN115880486 B CN 115880486B CN 202310169069 A CN202310169069 A CN 202310169069A CN 115880486 B CN115880486 B CN 115880486B
- Authority
- CN
- China
- Prior art keywords
- intermediate feature
- target detection
- network
- distillation
- detection network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000004821 distillation Methods 0.000 title claims abstract description 157
- 238000001514 detection method Methods 0.000 title claims abstract description 144
- 238000000034 method Methods 0.000 title claims abstract description 49
- 238000012549 training Methods 0.000 claims abstract description 66
- 238000013138 pruning Methods 0.000 claims abstract description 18
- 239000011159 matrix material Substances 0.000 claims description 22
- 238000004364 calculation method Methods 0.000 claims description 18
- 238000010606 normalization Methods 0.000 claims description 9
- 230000008569 process Effects 0.000 claims description 9
- 238000005457 optimization Methods 0.000 claims 1
- 230000000694 effects Effects 0.000 abstract description 7
- 238000010586 diagram Methods 0.000 description 11
- 238000004590 computer program Methods 0.000 description 7
- 238000013140 knowledge distillation Methods 0.000 description 7
- 238000013139 quantization Methods 0.000 description 6
- 230000004044 response Effects 0.000 description 5
- 102100030148 Integrator complex subunit 8 Human genes 0.000 description 4
- 101710092891 Integrator complex subunit 8 Proteins 0.000 description 4
- 230000006870 function Effects 0.000 description 4
- 238000012545 processing Methods 0.000 description 4
- 230000009471 action Effects 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000006835 compression Effects 0.000 description 2
- 238000007906 compression Methods 0.000 description 2
- 230000002950 deficient Effects 0.000 description 2
- 238000001914 filtration Methods 0.000 description 2
- 241000282472 Canis lupus familiaris Species 0.000 description 1
- 230000001133 acceleration Effects 0.000 description 1
- 230000004075 alteration Effects 0.000 description 1
- 230000000295 complement effect Effects 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 238000011002 quantification Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Images
Classifications
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Image Analysis (AREA)
Abstract
本发明公开了一种目标检测网络蒸馏方法、装置、电子设备及存储介质,用于解决现有的目标检测网络蒸馏方式蒸馏效果较差的技术问题。本发明包括:获取预训练目标检测网络;对所述预训练目标检测网络进行剪枝,得到学生网络;将预设检测图像输入所述预训练目标检测网络,得到第一中间特征;将所述预设检测图像输入所述学生网络,得到第二中间特征;根据所述第一中间特征和所述第二中间特征计算蒸馏损失;根据所述蒸馏损失优化所述学生网络,得到目标检测网络。
Description
技术领域
本发明涉及知识蒸馏技术领域,尤其涉及一种目标检测网络蒸馏方法、装置、电子设备及存储介质。
背景技术
目标检测是当前计算机视觉领域的一个重要分支,具有极广泛的应用场景,相比于分类网络,它的参数量更多,模型结构更加复杂。自知识蒸馏被提出以来,以被广泛应用于模型压缩领域,其操作简单,相比于直接训练一个小的模型,知识蒸馏只需要先训练一个大的教师网络,再使用这个教师网络蒸馏小的学生网络,便可带来性能的提升。知识蒸馏问题的核心在于教师网络对学生网络的监督损失,即教师网络如何将最关键的信息传递给学生网络。现有的大多数知识蒸馏技术都是使用一个强大的教师网络训练一个较弱的学生网络。但如果两个模型的容量差异过大,可能导致蒸馏损失主导学生网络的训练方向,从而导致网络欠拟合;一些较新的技术采用中间特征作为教师网络和学生网络之间知识传递的桥梁。但是,中间特征位置的选取、特征损失的权重等还缺乏令人信服的解释,导致蒸馏效果较差。
发明内容
本发明提供了一种目标检测网络蒸馏方法、装置、电子设备及存储介质,用于解决现有的目标检测网络蒸馏方式蒸馏效果较差的技术问题。
本发明提供了一种目标检测网络蒸馏方法,包括:
获取预训练目标检测网络;
对所述预训练目标检测网络进行剪枝,得到学生网络;
将预设检测图像输入所述预训练目标检测网络,得到第一中间特征;
将所述预设检测图像输入所述学生网络,得到第二中间特征;
根据所述第一中间特征和所述第二中间特征计算蒸馏损失;
根据所述蒸馏损失优化所述学生网络,得到目标检测网络。
可选地,所述对所述预训练目标检测网络进行剪枝,得到学生网络的步骤,包括:
获取所述预训练目标检测网络的拟归一化层的缩放因子;
对所述缩放因子进行稀疏化,确定所述预训练目标检测网络各通道的绝对值;
移除绝对值小于预设阈值的通道,得到学生网络。
可选地,所述蒸馏损失包括第一蒸馏损失;所述根据所述第一中间特征和所述第二中间特征计算蒸馏损失的步骤,包括:
获取所述预训练目标检测网络中各特征的位置信息;
采用所述位置信息生成二维关键性矩阵;
以所述二维关键性矩阵作为权重,结合所述第一中间特征和所述第二中间特征计算第一蒸馏损失。
可选地,所述蒸馏损失还包括第二蒸馏损失,所述根据所述第一中间特征和所述第二中间特征计算蒸馏损失的步骤,还包括:
根据所述第一中间特征获取第一关键性向量;
以所述第一关键性向量作为损失系数,结合所述第一中间特征和所述第二中间特征计算第二蒸馏损失。
可选地,所述第一关键性向量包括第一通道关键性向量、第一空间关键性向量和第一逐点关键性向量。
可选地,所述蒸馏损失还包括第三蒸馏损失,所述根据所述第一中间特征和所述第二中间特征计算蒸馏损失的步骤,还包括:
获取所述第一中间特征的第一梯度,以及获取所述第二中间特征的第二梯度;
采用所述第一梯度计算第二关键性向量;
以所述第二关键性向量作为损失系数,结合所述第一梯度和所述第二梯度计算第三蒸馏损失。
本发明还提供了一种目标检测网络蒸馏装置,包括:
预训练目标检测网络获取模块,用于获取预训练目标检测网络;
剪枝模块,用于对所述预训练目标检测网络进行剪枝,得到学生网络;
第一中间特征获取模块,用于将预设检测图像输入所述预训练目标检测网络,得到第一中间特征;
第二中间特征获取模块,用于将所述预设检测图像输入所述学生网络,得到第二中间特征;
蒸馏损失计算模块,用于根据所述第一中间特征和所述第二中间特征计算蒸馏损失;
优化模块,用于根据所述蒸馏损失优化所述学生网络,得到目标检测网络。
可选地,所述剪枝模块,包括:
缩放因子获取子模块,用于获取所述预训练目标检测网络的拟归一化层的缩放因子;
绝对值确定子模块,用于对所述缩放因子进行稀疏化,确定所述预训练目标检测网络各通道的绝对值;
学生网络获取子模块,用于移除绝对值小于预设阈值的通道,得到学生网络。
本发明还提供了一种电子设备,所述设备包括处理器以及存储器:
所述存储器用于存储程序代码,并将所述程序代码传输给所述处理器;
所述处理器用于根据所述程序代码中的指令执行如上任一项所述的目标检测网络蒸馏方法。
本发明还提供了一种计算机可读存储介质,所述计算机可读存储介质用于存储程序代码,所述程序代码用于执行如上任一项所述的目标检测网络蒸馏方法。
从以上技术方案可以看出,本发明具有以下优点:本发明提供了一种目标检测网络蒸馏方法,包括:获取预训练目标检测网络;对预训练目标检测网络进行剪枝,得到学生网络;将预设检测图像输入预训练目标检测网络,得到第一中间特征;将预设检测图像输入学生网络,得到第二中间特征;根据第一中间特征和第二中间特征计算蒸馏损失;根据蒸馏损失优化学生网络,得到目标检测网络。本发明通过中间特征来计算蒸馏损失,再根据蒸馏损失对学生网络进行优化,得到优化后的目标检测网络,从而提高了目标检测网络的蒸馏效果。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其它的附图。
图1为本发明实施例提供的一种目标检测网络蒸馏方法的步骤流程图;
图2为本发明另一实施例提供的一种目标检测网络蒸馏方法的步骤流程图;
图3为图像及生成的二维关键性矩阵示意图;
图4为普通卷积流程的流程示意图;
图5为量化卷积流程的流程示意图;
图6为本发明实施例提供的一种目标检测网络蒸馏装置的结构框图。
具体实施方式
本发明实施例提供了一种目标检测网络蒸馏方法、装置、电子设备及存储介质,用于解决现有的目标检测网络蒸馏方式蒸馏效果较差的技术问题。
为使得本发明的发明目的、特征、优点能够更加的明显和易懂,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,下面所描述的实施例仅仅是本发明一部分实施例,而非全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
请参阅图1,图1为本发明实施例提供的一种目标检测网络蒸馏方法的步骤流程图。
本发明提供的一种目标检测网络蒸馏方法,具体可以包括以下步骤:
步骤101,获取预训练目标检测网络;
在本发明实施例中,可以通过采集图像数据,以人工分拣的方式,挑选出所需的含有欲进行目标检测的图像,生成目标检测数据集。然后利用标注工具,对图像中的正常设备和缺陷设备进行标注,使用ImageNet上预训练的ResNet50和MobileNetV2作为主干网络,进行RetinaNet检测网络的预训练,得到预训练目标检测网络。
步骤102,对预训练目标检测网络进行剪枝,得到学生网络;
在本发明实施例中,可以通过对预训练目标检测网络进行剪枝来得到学生网络。
在实际应用中,由于BN广泛应用于各种网络结构当中,因此在进行剪枝时,就不需要添加额外的算子来达到剪枝的目的。对于一个现有的网络,只需要直接对BN的缩放因子进行稀疏化,然后将绝对值趋向于0的通道剪去即可。预训练目标检测网络有大量的卷积层和BN层进行直连,因此可以利用BN层的权重系数来判断对应通道的重要性,移除掉不重要的通道,从而减少前面一层卷积层的参数。
BN层(Batch Normalization layer):拟归一化层,神经网络中一种用于改善人工神经网络的性能和稳定性的结构,是一种为神经网络中的任何层提供零均值/单位方差输入的技术结构。
步骤103,将预设检测图像输入预训练目标检测网络,得到第一中间特征;
步骤104,将预设检测图像输入学生网络,得到第二中间特征;
在目标检测中,对于一张输入的预设检测图像,目标检测网络应更加关注特定位置上的目标,而大部分背景对于网络来说都属于噪声,因此在蒸馏时,将背景过滤掉可以减少噪声或是无效信息对学生网络的干扰。
在具体实现中,可以通过中间特征来筛除噪声。
步骤105,根据第一中间特征和第二中间特征计算蒸馏损失;
蒸馏损失,是知识蒸馏过程中教师网络与学生网络之间的差异性。蒸馏损失越小,表征学生网络与教师网络(预训练目标检测网络)的性能差异越小。
知识蒸馏是模型压缩的一种常用方法,是通过构建一个轻量化的小模型(学生网络),利用性能更好的大模型的监督信息(预训练目标检测网络)来训练这个小模型。来自大模型输出的监督信息称之为“知识”,而小模型学习迁移来自教师网络的监督信息的过程称之为“蒸馏”。
在具体实现中,在获取到预训练目标检测网络的第一中间特征和学生网络的第二中间特征后,可以基于第一中间特征和第二中间特征来计算蒸馏损失,以根据蒸馏损失来判断预训练目标检测网络与学生网络之间的差异。
步骤106,根据蒸馏损失优化学生网络,得到目标检测网络。
在本发明实施例中,在获取到蒸馏损失后,可以往减小蒸馏损失的方向去调整学生网络的参数,通过多次迭代调整,得到蒸馏损失尽可能小的目标检测网络。
本发明通过中间特征来计算蒸馏损失,再根据蒸馏损失对学生网络进行优化,得到优化后的目标检测网络,从而提高了目标检测网络的蒸馏效果。
请参阅图2,图2为本发明另一实施例提供的一种目标检测网络蒸馏方法的步骤流程图。
步骤201,获取预训练目标检测网络;
在本发明实施例中,可以通过采集图像数据,以人工分拣的方式,挑选出所需的含有欲进行目标检测的图像,生成目标检测数据集。然后利用标注工具,对图像中的正常设备和缺陷设备进行标注,使用ImageNet上预训练的ResNet50和MobileNetV2作为主干网络,进行RetinaNet检测网络的预训练,得到预训练目标检测网络。
步骤202,获取预训练目标检测网络的拟归一化层的缩放因子;
步骤203,对缩放因子进行稀疏化,确定预训练目标检测网络各通道的绝对值;
步骤204,移除绝对值小于预设阈值的通道,得到学生网络;
在本发明实施例中,可以获取预训练目标检测网络的拟归一化层的缩放因子,并对缩放因子进行稀疏化,根据稀疏结果确定预训练目标检测网络各通道的绝对值,并将绝对值小于预设阈值(该预设阈值可选趋向于0的某个数值,具体选择哪个数值,本发明实施例不作具体限制)的通道移除,得到学生网络。
步骤205,将预设检测图像输入预训练目标检测网络,得到第一中间特征;
步骤206,将预设检测图像输入学生网络,得到第二中间特征;
在目标检测中,对于一张输入的预设检测图像,目标检测网络应更加关注特定位置上的目标,而大部分背景对于网络来说都属于噪声,因此在蒸馏时,将背景过滤掉可以减少噪声或是无效信息对学生网络的干扰。
在具体实现中,可以通过中间特征来筛除噪声。
步骤207,根据第一中间特征和第二中间特征计算蒸馏损失;
在具体实现中,在获取到预训练目标检测网络的第一中间特征和学生网络的第二中间特征后,可以基于第一中间特征和第二中间特征来计算蒸馏损失,以根据蒸馏损失来判断预训练目标检测网络与学生网络之间的差异。
其中,C、H、W分别表示特征的通道数、高、宽。
在一个示例中,蒸馏损失可以包括第一蒸馏损失,根据第一中间特征和第二中间特征计算蒸馏损失的步骤,可以包括以下子步骤:
S71,获取预训练目标检测网络中特征的位置信息;
S72,采用位置信息生成二维关键性矩阵;
S73,以二维关键性矩阵作为权重,结合第一中间特征和第二中间特征计算第一蒸馏损失。
如图3所示,(a)表示输入到预训练目标检测网络中的图像,(b)表示根据目标所在的位置生成的二维关键性矩阵,可以观察到,通过这种方法可以让学生网络在蒸馏过程中不去关注树木、墙壁等背景,更加聚焦于对狗、自行车和汽车三个目标特征的学习。
在另一个示例中,还可以基于特征显著性进行关键信息蒸馏,在本发明实施例中,蒸馏损失还可以包括第二蒸馏损失;根据第一中间特征和第二中间特征计算蒸馏损失的步骤,还可以包括以下子步骤:
S74,根据第一中间特征获取第一关键性向量;
S75,以第一关键性向量作为损失系数,结合第一中间特征和第二中间特征计算第二蒸馏损失。
在实际应用中,网络的中间输出在目标区域有更高的响应值,在背景区域的响应值较小,因此可以直接根据响应值的大小来评估特征的关键性程度。本发明实施例根据中间特征生成一个范围在0到1之间的矩阵作为关键性的衡量指标。特征显著性不直接使用硬性的界定方法,而是使用一个权重来对信息关键性程度进行评估。因为往往一个目标并不是完全脱离背景而单独存在的,背景中也可能包含一些可以辅助目标分类定位的信息,此外,对于某一个单独的目标,网络对于目标上不同位置的关注度也是不同的,显然通过人工标注找到这些关键的位置并不现实,因此本发明实施例借助于网络自己的中间特征进行一种自监督。一般而言,重要的关键位置在中间的特征图上往往具有较高的响应值,从而对网络的预测提供更加准确的信息。考虑到教师网络(预训练目标检测网络)具有更加准确的预测效果,本步骤使用教师网络的中间特征来衡量关键性。
在本发明实施例中,可以从三个角度来获取第一关键性指标;其中,第一关键性指标可以包括通道关键性、空间关键性和逐点关键性。
则第c个通道的关键性大小为:
以空间关键性为例,空间关键性即根据每个位置的特征显著性,生成二维空间的关键性矩阵。在生成时对同一个空间位置的所有通道的特征大小求和并归一化到0到1之间,视为这一位置的第一关键性向量(第一空间关键性向量),即:
此外,以逐点关键性为例,卷积神经网络中间层的输出特征的每一个通道都包含不同的关键信息,多个通道的信息是互补的,不能对所有通道使用一个相同的二维关键性评价指标,而应该对于不同通道的不同位置分别逐特征点进行讨论,因此,其第一关键性向量(第一逐点关键性向量)可以为:
相应的第二蒸馏损失则为:
在另一个实施例中,还可以基于梯度显著性进行关键信息蒸馏以计算蒸馏损失;蒸馏损失还包括第三蒸馏损失,根据第一中间特征和第二中间特征计算蒸馏损失的步骤,还包括:
S76,获取第一中间特征的第一梯度,以及获取第二中间特征的第二梯度;
S77,采用第一梯度计算第二关键性向量;
S78,以第二关键性向量作为损失系数,结合第一梯度和第二梯度计算第三蒸馏损失。
在本发明实施例中,考虑到教师网络(预训练目标检测网络)具有更高的准确性,本发明实施例使用教师网络来衡量关键性。由于中间特征的梯度本身反映的是中间特征相比于理想情况存在的偏差,无论是正向偏差还是负向偏差,均表明这块区域是需要重点关注的,因此在计算关键性指标时,首先对梯度取绝对值。
在本发明实施例中,第一中间特征的第一梯度为:
第二中间特征的第二梯度为:
在获取到第一梯度后,可以根据第一梯度计算第二关键性向量。
其中,第二关键性向量可以包括第二通道关键性向量、第二空间关键性向量和第二逐点关键性向量。
在计算得到第一蒸馏损失、第二蒸馏损失和第三蒸馏损失后,可以将第一蒸馏损失、第二蒸馏损失和第三蒸馏损失进行加权求和,得到完整的蒸馏损失。
步骤208,根据蒸馏损失优化学生网络,得到目标检测网络。
在本发明实施例中,在获取到蒸馏损失后,可以往减小蒸馏损失的方向去调整学生网络的参数,通过多次迭代调整,得到蒸馏损失尽可能小的目标检测网络。
在完成对目标检测网络的蒸馏后,可以基于NCNN加速推理框架对模型进行量化。如图4和图5所示,图4为普通卷积流程,图5为量化卷积流程。相比于图4的普通卷积流程,图5的量化卷积流程主要添加了量化和反量化操作,对于一个训练好的模型,先直接离线将其量化为INT8类型,然后使用这个模型进行推理,在推理过程中,每一层的输出在输入到下一层进行卷积操作之前,都需要在线转化为INT8类型,然后两个INT8类型的矩阵进行运算,得到的结果再反量化为 FP32,与偏置进行求和,即可完成对目标检测模型的量化。需要注意的是,由于偏置项占有的参数量和计算量均很小,所以未对偏置项进行量化。
进一步地,在完成对目标检测网络的量化后,可以将量化后的目标检测网络部署到ARM平台。ARM是一个精简指令集处理器构架家族,其具有低成本、高效能、低能耗的特性,目前广泛使用于各种嵌入式硬件中。Arm Neon 技术是Arm Cortex-A处理器的单指令多数据架构的扩展,它提供了16 个 128 位向量寄存器,每个寄存器可以存放4个32位或8个16位的操作数,在进行计算时可以将一个寄存器中的多个操作数使用一个指令完成计算,实现并行计算,加快计算速度,在步骤4中,完成NCNN的INT8量化流程后,便可将量化后的目标检测模型部署到到手机端的ARM处理器上。
本发明通过中间特征来计算蒸馏损失,再根据蒸馏损失对学生网络进行优化,得到优化后的目标检测网络,从而提高了目标检测网络的蒸馏效果。
请参阅图6,图6为本发明实施例提供的一种目标检测网络蒸馏装置的结构框图。
本发明实施例提供了一种目标检测网络蒸馏装置,包括:
预训练目标检测网络获取模块601,用于获取预训练目标检测网络;
剪枝模块602,用于对预训练目标检测网络进行剪枝,得到学生网络;
第一中间特征获取模块603,用于将预设检测图像输入预训练目标检测网络,得到第一中间特征;
第二中间特征获取模块604,用于将预设检测图像输入学生网络,得到第二中间特征;
蒸馏损失计算模块605,用于根据第一中间特征和第二中间特征计算蒸馏损失;
优化模块606,用于根据蒸馏损失优化学生网络,得到目标检测网络。
在本发明实施例中,剪枝模块602,包括:
缩放因子获取子模块,用于获取预训练目标检测网络的拟归一化层的缩放因子;
绝对值确定子模块,用于对缩放因子进行稀疏化,确定预训练目标检测网络各通道的绝对值;
学生网络获取子模块,用于移除绝对值小于预设阈值的通道,得到学生网络。
在本发明实施例中,蒸馏损失包括第一蒸馏损失;蒸馏损失计算模块605,包括:
位置信息获取子模块,用于获取预训练目标检测网络中各特征的位置信息;
二维关键性矩阵生成子模块,用于采用位置信息生成二维关键性矩阵;
第一蒸馏损失计算子模块,用于以二维关键性矩阵作为权重,结合第一中间特征和第二中间特征计算第一蒸馏损失。
在本发明实施例中,蒸馏损失还包括第二蒸馏损失,蒸馏损失计算模块605,还包括:
第一关键性向量获取子模块,用于根据第一中间特征获取第一关键性向量;
第二蒸馏损失计算子模块,用于以第一关键性向量作为损失系数,结合第一中间特征和第二中间特征计算第二蒸馏损失。
在本发明实施例中,第一关键性向量包括第一通道关键性向量、第一空间关键性向量和第一逐点关键性向量。
在本发明实施例中,蒸馏损失还包括第三蒸馏损失,蒸馏损失计算模块605,还包括:
梯度获取子模块,用于获取第一中间特征的第一梯度,以及获取第二中间特征的第二梯度;
第二关键性向量计算子模块,用于采用第一梯度计算第二关键性向量;
第三蒸馏损失计算子模块,用于以第二关键性向量作为损失系数,结合第一梯度和第二梯度计算第三蒸馏损失。
本发明实施例还提供了一种电子设备,设备包括处理器以及存储器:
存储器用于存储程序代码,并将程序代码传输给处理器;
处理器用于根据程序代码中的指令执行本发明实施例的目标检测网络蒸馏方法。
本发明实施例还提供了一种计算机可读存储介质,计算机可读存储介质用于存储程序代码,程序代码用于执行本发明实施例的目标检测网络蒸馏方法。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统,装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
本说明书中的各个实施例均采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似的部分互相参见即可。
本领域内的技术人员应明白,本发明实施例的实施例可提供为方法、装置、或计算机程序产品。因此,本发明实施例可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明实施例可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明实施例是参照根据本发明实施例的方法、终端设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理终端设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理终端设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理终端设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理终端设备上,使得在计算机或其他可编程终端设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程终端设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
尽管已描述了本发明实施例的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例做出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本发明实施例范围的所有变更和修改。
最后,还需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者终端设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者终端设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者终端设备中还存在另外的相同要素。
以上所述,以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (7)
1.一种目标检测网络蒸馏方法,其特征在于,包括:
获取预训练目标检测网络;
对所述预训练目标检测网络进行剪枝,得到学生网络;
将预设检测图像输入所述预训练目标检测网络,得到第一中间特征;
将所述预设检测图像输入所述学生网络,得到第二中间特征;
根据所述第一中间特征和所述第二中间特征计算蒸馏损失;
根据所述蒸馏损失优化所述学生网络,得到目标检测网络;
其中,所述对所述预训练目标检测网络进行剪枝,得到学生网络的步骤,包括:
获取所述预训练目标检测网络的拟归一化层的缩放因子;
对所述缩放因子进行稀疏化,确定所述预训练目标检测网络各通道的绝对值;
移除绝对值小于预设阈值的通道,得到学生网络;
其中,所述蒸馏损失包括第一蒸馏损失;所述根据所述第一中间特征和所述第二中间特征计算蒸馏损失的步骤,包括:
获取所述预训练目标检测网络中各特征的位置信息;
采用所述位置信息生成二维关键性矩阵;
以所述二维关键性矩阵作为权重,结合所述第一中间特征和所述第二中间特征计算第一蒸馏损失;
所述第一蒸馏损失计算过程为:
2.根据权利要求1所述的方法,其特征在于,所述蒸馏损失还包括第二蒸馏损失,所述根据所述第一中间特征和所述第二中间特征计算蒸馏损失的步骤,还包括:
根据所述第一中间特征获取第一关键性向量;
以所述第一关键性向量作为损失系数,结合所述第一中间特征和所述第二中间特征计算第二蒸馏损失。
3.根据权利要求2所述的方法,其特征在于,所述第一关键性向量包括第一通道关键性向量、第一空间关键性向量和第一逐点关键性向量。
4.根据权利要求1-3任一项所述的方法,其特征在于,所述蒸馏损失还包括第三蒸馏损失,所述根据所述第一中间特征和所述第二中间特征计算蒸馏损失的步骤,还包括:
获取所述第一中间特征的第一梯度,以及获取所述第二中间特征的第二梯度;
采用所述第一梯度计算第二关键性向量;
以所述第二关键性向量作为损失系数,结合所述第一梯度和所述第二梯度计算第三蒸馏损失。
5.一种目标检测网络蒸馏装置,其特征在于,包括:
预训练目标检测网络获取模块,用于获取预训练目标检测网络;
剪枝模块,用于对所述预训练目标检测网络进行剪枝,得到学生网络;
第一中间特征获取模块,用于将预设检测图像输入所述预训练目标检测网络,得到第一中间特征;
第二中间特征获取模块,用于将所述预设检测图像输入所述学生网络,得到第二中间特征;
蒸馏损失计算模块,用于根据所述第一中间特征和所述第二中间特征计算蒸馏损失;
优化模块,用于根据所述蒸馏损失优化所述学生网络,得到目标检测网络;
其中,所述剪枝模块,包括:
缩放因子获取子模块,用于获取所述预训练目标检测网络的拟归一化层的缩放因子;
绝对值确定子模块,用于对所述缩放因子进行稀疏化,确定所述预训练目标检测网络各通道的绝对值;
学生网络获取子模块,用于移除绝对值小于预设阈值的通道,得到学生网络;
其中,蒸馏损失包括第一蒸馏损失;蒸馏损失计算模块,包括:
位置信息获取子模块,用于获取预训练目标检测网络中各特征的位置信息;
二维关键性矩阵生成子模块,用于采用位置信息生成二维关键性矩阵;
第一蒸馏损失计算子模块,用于以二维关键性矩阵作为权重,结合第一中间特征和第二中间特征计算第一蒸馏损失;
所述第一蒸馏损失计算过程为:
6.一种电子设备,其特征在于,所述设备包括处理器以及存储器:
所述存储器用于存储程序代码,并将所述程序代码传输给所述处理器;
所述处理器用于根据所述程序代码中的指令执行权利要求1-4任一项所述的目标检测网络蒸馏方法。
7.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质用于存储程序代码,所述程序代码用于执行权利要求1-4任一项所述的目标检测网络蒸馏方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310169069.6A CN115880486B (zh) | 2023-02-27 | 2023-02-27 | 一种目标检测网络蒸馏方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310169069.6A CN115880486B (zh) | 2023-02-27 | 2023-02-27 | 一种目标检测网络蒸馏方法、装置、电子设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN115880486A CN115880486A (zh) | 2023-03-31 |
CN115880486B true CN115880486B (zh) | 2023-06-02 |
Family
ID=85761666
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310169069.6A Active CN115880486B (zh) | 2023-02-27 | 2023-02-27 | 一种目标检测网络蒸馏方法、装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115880486B (zh) |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114519717A (zh) * | 2021-12-31 | 2022-05-20 | 深圳云天励飞技术股份有限公司 | 一种图像处理方法及装置、计算机设备、存储介质 |
Family Cites Families (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CA3033014A1 (en) * | 2018-02-07 | 2019-08-07 | Royal Bank Of Canada | Robust pruned neural networks via adversarial training |
CN113159173B (zh) * | 2021-04-20 | 2024-04-26 | 北京邮电大学 | 一种结合剪枝与知识蒸馏的卷积神经网络模型压缩方法 |
CN113574566A (zh) * | 2021-05-14 | 2021-10-29 | 北京大学深圳研究生院 | 目标检测网络构建优化方法、装置、设备、介质及产品 |
KR20220160814A (ko) * | 2021-05-28 | 2022-12-06 | 삼성에스디에스 주식회사 | 회귀 태스크 기반의 지식 증류 방법 및 이를 수행하기 위한 컴퓨팅 장치 |
CN113343817A (zh) * | 2021-05-31 | 2021-09-03 | 扬州大学 | 一种面向目标区域的无人车路径检测方法、装置及介质 |
CN115511071A (zh) * | 2021-06-23 | 2022-12-23 | 北京字跳网络技术有限公司 | 模型训练方法、装置及可读存储介质 |
CN113743514A (zh) * | 2021-09-08 | 2021-12-03 | 庆阳瑞华能源有限公司 | 一种基于知识蒸馏的目标检测方法及目标检测终端 |
CN114049512A (zh) * | 2021-09-22 | 2022-02-15 | 北京旷视科技有限公司 | 模型蒸馏方法、目标检测方法、装置及电子设备 |
CN114139703A (zh) * | 2021-11-26 | 2022-03-04 | 上海瑾盛通信科技有限公司 | 知识蒸馏方法及装置、存储介质及电子设备 |
CN114187435A (zh) * | 2021-12-10 | 2022-03-15 | 北京百度网讯科技有限公司 | 文本识别方法、装置、设备以及存储介质 |
CN114819135A (zh) * | 2022-03-18 | 2022-07-29 | 上海高仙自动化科技发展有限公司 | 检测模型的训练方法、目标检测方法、装置和存储介质 |
CN114663848A (zh) * | 2022-03-23 | 2022-06-24 | 京东鲲鹏(江苏)科技有限公司 | 一种基于知识蒸馏的目标检测方法和装置 |
CN114842449A (zh) * | 2022-05-10 | 2022-08-02 | 安徽蔚来智驾科技有限公司 | 目标检测方法、电子设备、介质及车辆 |
-
2023
- 2023-02-27 CN CN202310169069.6A patent/CN115880486B/zh active Active
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114519717A (zh) * | 2021-12-31 | 2022-05-20 | 深圳云天励飞技术股份有限公司 | 一种图像处理方法及装置、计算机设备、存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN115880486A (zh) | 2023-03-31 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111105017B (zh) | 神经网络量化方法、装置及电子设备 | |
CN111368656A (zh) | 一种视频内容描述方法和视频内容描述装置 | |
EP4318313A1 (en) | Data processing method, training method for neural network model, and apparatus | |
CN110781686B (zh) | 一种语句相似度计算方法、装置及计算机设备 | |
CN110647974A (zh) | 深度神经网络中的网络层运算方法及装置 | |
CN112990438A (zh) | 基于移位量化操作的全定点卷积计算方法、系统及设备 | |
CN111950633A (zh) | 神经网络的训练、目标检测方法及装置和存储介质 | |
CN115393633A (zh) | 数据处理方法、电子设备、存储介质及程序产品 | |
CN115147598A (zh) | 目标检测分割方法、装置、智能终端及存储介质 | |
CN113191318A (zh) | 目标检测方法、装置、电子设备及存储介质 | |
CN114565196B (zh) | 基于政务热线的多事件趋势预判方法、装置、设备及介质 | |
CN115439694A (zh) | 一种基于深度学习的高精度点云补全方法及装置 | |
Lee et al. | Channel pruning via gradient of mutual information for light-weight convolutional neural networks | |
WO2022100607A1 (zh) | 一种神经网络结构确定方法及其装置 | |
CN114491289A (zh) | 一种双向门控卷积网络的社交内容抑郁检测方法 | |
CN116805387B (zh) | 基于知识蒸馏的模型训练方法、质检方法和相关设备 | |
CN116523888B (zh) | 路面裂缝的检测方法、装置、设备及介质 | |
CN115880486B (zh) | 一种目标检测网络蒸馏方法、装置、电子设备及存储介质 | |
CN114820755B (zh) | 一种深度图估计方法及系统 | |
CN114155388B (zh) | 一种图像识别方法、装置、计算机设备和存储介质 | |
CN114267422B (zh) | 地表水质参数预测方法、系统、计算机设备及存储介质 | |
CN112561050B (zh) | 一种神经网络模型训练方法及装置 | |
CN116959489B (zh) | 语音模型的量化方法、装置、服务器及存储介质 | |
CN116030347B (zh) | 一种基于注意力网络的高分辨率遥感影像建筑物提取方法 | |
CN116206212A (zh) | 一种基于点特征的sar图像目标检测方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |