CN117315722B - 一种基于知识迁移剪枝模型的行人检测方法 - Google Patents
一种基于知识迁移剪枝模型的行人检测方法 Download PDFInfo
- Publication number
- CN117315722B CN117315722B CN202311579036.5A CN202311579036A CN117315722B CN 117315722 B CN117315722 B CN 117315722B CN 202311579036 A CN202311579036 A CN 202311579036A CN 117315722 B CN117315722 B CN 117315722B
- Authority
- CN
- China
- Prior art keywords
- pruning
- model
- channel
- convolution
- output
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000013138 pruning Methods 0.000 title claims abstract description 131
- 238000001514 detection method Methods 0.000 title claims abstract description 49
- 238000013508 migration Methods 0.000 title claims abstract description 23
- 230000005012 migration Effects 0.000 title claims abstract description 23
- 238000012549 training Methods 0.000 claims abstract description 39
- 238000000034 method Methods 0.000 claims abstract description 25
- 230000008569 process Effects 0.000 claims abstract description 14
- 238000004364 calculation method Methods 0.000 claims abstract description 9
- 238000005259 measurement Methods 0.000 claims abstract description 9
- 230000006870 function Effects 0.000 claims description 8
- 125000004122 cyclic group Chemical group 0.000 claims description 7
- 238000012360 testing method Methods 0.000 claims description 7
- 230000009466 transformation Effects 0.000 claims description 6
- 230000000694 effects Effects 0.000 claims description 5
- 238000011156 evaluation Methods 0.000 claims description 5
- 238000010606 normalization Methods 0.000 claims description 5
- 238000010276 construction Methods 0.000 claims description 3
- 230000000750 progressive effect Effects 0.000 claims description 3
- 230000009467 reduction Effects 0.000 claims description 3
- 238000010008 shearing Methods 0.000 claims description 3
- 238000010200 validation analysis Methods 0.000 claims description 3
- 238000012795 verification Methods 0.000 claims description 3
- 244000141353 Prunus domestica Species 0.000 abstract description 2
- 238000010586 diagram Methods 0.000 description 4
- 238000000605 extraction Methods 0.000 description 4
- 230000003044 adaptive effect Effects 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 230000004927 fusion Effects 0.000 description 2
- 238000005070 sampling Methods 0.000 description 2
- 230000004931 aggregating effect Effects 0.000 description 1
- 230000002776 aggregation Effects 0.000 description 1
- 238000004220 aggregation Methods 0.000 description 1
- 230000003190 augmentative effect Effects 0.000 description 1
- 230000003542 behavioural effect Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000007500 overflow downdraw method Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V40/00—Recognition of biometric, human-related or animal-related patterns in image or video data
- G06V40/10—Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- Multimedia (AREA)
- Biomedical Technology (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Human Computer Interaction (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及行人检测技术领域,特别涉及一种基于知识迁移剪枝模型的行人检测方法,本发明采用通道剪枝方法,以YOLOv8网络作为基础检测模型,修剪YOLOv8网络卷积层中重要性低的卷积核,在模型剪枝重新训练的过程中计算剪枝模型和原模型输出之间的KL散度,作为剪枝和训练过程中的损失函数中的一部分,使剪枝模型的输出更接近与原模型,准确度也更接近于原模型,在降低剪枝模型参数量和计算量的同时,保持和原模型相当的性能表现;另外结合L1范数和批量标准化权重作为卷积核重要性的衡量标准,使模型的通道剪枝的选择更有效,解决了目前通道剪枝方法中衡量标准单一的问题,在保持行人检测准确性的同时大量降低算法的参数量和计算量,满足实时性的要求。
Description
技术领域
本发明属于行人检测技术领域,特别涉及一种基于知识迁移剪枝模型的行人检测方法。
背景技术
在深度学习中,行人检测是一项重要的任务,它的主要目标是识别图像或视频中的行人并将其与其他物体进行区分。行人检测技术在自动驾驶、增强现实及行人计数和行为分析等方面有着重要应用。
对于行人检测方法,目前大多都直接采用开源目标检测算法,如YOLO、DETR、Mask-RCNN、EfficientDet等性能较好的算法。但这些检测算法都是基于VOC或COCO等大型、多类别数据集上进行调整测试的,并不是特别针对行人检测进行开发,而对于行人检测任务,这些算法有着大量冗余的参数和计算量。在自动驾驶和实时监控系统等应用中,实时性是一个重要的考虑因素。然而,这些行人检测方法在速度上无法满足实时性的要求,因此我们需要提出一种基于知识迁移剪枝模型的行人检测方法来解决上述存在的问题。
发明内容
针对上述问题,本发明提供了一种基于知识迁移剪枝模型的行人检测方法,包括如下步骤:
S1、采集行人数据集,进行数据标注,构建行人数据集;
S2、基于YOLOv8网络确定通道需要剪枝的每个卷积层;
S3、根据L1范数和批量标准权重对卷积层通道进行重要性排序;
S4、将排序后的卷积层通道按照剪枝比例剪去相应通道;
S5、采用KL散度来进行模型训练,以衡量知识迁移损失;
S6、剪枝YOLOv8模型重建与匹配预训练权重,重建后的网络能够正常用于行人检测。
进一步的,步骤S1中,所述行人数据集包括训练集、验证集和测试集,所述行人数据集在进行标注时,将行人数据集以6:2:2的比例分成训练集、验证集和测试集。
进一步的,步骤S2中,所述YOLOv8网络含有64个卷积层,卷积层的通道剪枝考虑因素包括参数量减小程度、剪枝后模型正常推理进程、剪枝模型便于重建度。
进一步的,所述卷积层的参数维度为输出通道数、输入通道数、卷积核的高和卷积核的宽,对于减小卷积层的参数量,通道剪枝选择如下三种剪枝方案之一进行剪枝:
1)对输出通道进行剪枝;
2)对输入通道进行剪枝;
3)同时对输出通道和输入通道进行剪枝。
进一步的,所述输出通道进行剪枝的卷积层含有19层,其中卷积层按顺序有2、4、5、9、10、12、16、17、19、23、25、29、30、33、34 38、39、43、44;
所述输入通道进行剪枝的卷积层含有24层,其中卷积层按顺序为3、6、8、11、13、15、18、20、22、25、28、31、32、35、37、40、42、45、46、49、52、55、58、61;
所述输出通道和输入通道一起剪枝的卷积层含有7层,其中卷积层按顺序为7、14、21、26、27、36、41。
进一步的,步骤S3中,对所述卷积层进行重要性排序时,先确定对卷积层的哪些通道进行剪枝,再结合L1范数和批量标准权重作为卷积核重要性的衡量标准,一个卷积核的lp范数由如下公式计算得出:
其中i∈Nl+1表示第l个卷积层的第i个卷积核,Nl为该卷积层的输入通道数,Kl为卷积核大小,Fi l为范数,p为范数的次序,当p取1时,上述lp范数计算公式记为一个卷积核的L1范数;
卷积层的批量标准化定义如下:
其中μ和σ表示该卷积层输出的均值和标准差,γ和β为可学习参数,Zout为卷积层的输出提供可学习的线性变换,Zin为卷积层的输入提供可学习的线性变换,其中∈为批量标准化权重;z为卷积层的批量标准化定义值;
通道剪枝评价标准定义为:
由上述通道剪枝评价标准定义公式计算得到卷积层中每个通道的重要性大小,作为后续步骤每个卷积层剪枝的衡量标准。
进一步的,步骤S4中,所述卷积层通道进行剪枝时,将剪枝率设定为50%,通过剪掉整个模型一半的通道的方式大量减少网络参数,首先剪枝前,先计算网络中每个卷积层的通道的重要性值并进行排序,记录排名在后50%的通道的索引,然后,构建剪枝掩码,对排名在后50%部分的通道权重乘以0,使这部分权重在后续的输出不起作用,达到模型训练时剪枝的效果,其中,剪枝的过程采用循环递进剪枝策略,具体为每5个训练epoch进行一次剪枝操作,每次剪枝操作剪去重要性最低的5%的通道,循环直到50%的卷积层通道被剪掉。
进一步的,步骤S5中,在进行模型训练时,在损失函数中加入了知识迁移损失,使剪枝网络不但学习真实标签的分布,也学习原网络输出的分布。
进一步的,所述KL散度来衡量知识迁移损失的计算公式如下:
其中C表示网络输出的总通道数,Yo表示原网络的输出,YP表示剪枝网络的输出,YC表示网络输出的一个通道,i表示输出中一个通道中的每个位置,H和W分别表示网络输出的卷积核的高和卷积核的宽,T为一常数,用于调节softmax函数输出的分布,在剪枝网络的训练中,总损失等于原目标检测损失加上用KL散度衡量的剪枝损失。
进一步的,步骤S6中,所述剪枝YOLOv8模型重建时,需要用到训练时用的剪枝掩码,根据掩码判断每个卷积核所剩下的输出通道和输入通道,进行每个卷积层的构建,在模型训练完成后,进行模型重建,同时,使用for循环遍历训练时保存的权重,删减掉权重值为0的通道,保留权重值非0的通道,这时得到的新的权重和重建的网络便能实现匹配。
本发明的有益效果是:
1、本发明采用通道剪枝方法,以YOLOv8网络作为基础检测模型,修剪YOLOv8网络卷积层中重要性低的卷积核,在模型剪枝重新训练的过程中计算剪枝模型和原模型输出之间的KL散度,作为剪枝和训练过程中的损失函数中的一部分,使剪枝模型的输出更接近与原模型,准确度也更接近于原模型,在降低剪枝模型参数量和计算量的同时,保持和原模型相当的性能表现。
2、本发明结合L1范数和批量标准化权重作为卷积核重要性的衡量标准,使模型的通道剪枝的选择更有效,解决了目前通道剪枝方法中衡量标准单一的问题,在保持行人检测准确性的同时大量降低算法的参数量和计算量,满足实时性的要求。
本发明的其它特征和优点将在随后的说明书中阐述,并且,部分地从说明书中变得显而易见,或者通过实施本发明而了解。本发明的目的和其他优点可通过在说明书、权利要求书以及附图中所指出的结构来实现和获得。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作一简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出了根据本发明实施例的总体流程框图;
图2示出了根据本发明实施例的YOLOv8网络的细节示意图;
图3示出了根据本发明实施例的循环剪枝的流程示意图;
图4示出了根据本发明实施例的模型剪枝训练的解析示意图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地说明,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例提供了一种基于知识迁移剪枝模型的行人检测方法,如图1所示,包括如下步骤:
S1、采集行人数据集,进行数据标注,构建行人数据集;
所述行人数据集包括训练集、验证集和测试集,所述行人数据集在进行标注时,将行人数据集以6:2:2的比例分成训练集、验证集和测试集。
S2、基于YOLOv8网络确定通道需要剪枝的每个卷积层;
所述YOLOv8网络含有64个卷积层,卷积层的通道剪枝考虑因素包括参数量减小程度、剪枝后模型正常推理进程、剪枝模型便于重建度。
其中YOLOv8网络是目标检测算法YOLO(You Only Look Once)的第八个版本。YOLO是一种实时目标检测算法,其特点是能够在一次前向传播中同时完成目标定位和分类,并且具有较快的速度。YOLOv8在YOLOv3的基础上进行了改进和优化,以提高检测精度和速度。主要的改进包括:
使用Darknet作为基础网络架构:YOLOv8采用了Darknet作为卷积神经网络的基础。Darknet是一个轻量且高效的深度学习框架,具有较好的性能和可移植性。
基于FPN(Feature Pyramid Network)的特征融合:YOLOv8引入了FPN来融合不同尺度的特征图,以提高对不同大小目标的检测能力。
使用PANet(Path Aggregation Network)进行上下文特征融合:PANet是一种用于实现上下文感知的特征融合方法,通过在多个尺度上聚合特征图,改进了对小尺寸目标的检测性能。
采用自适应卷积进行特征提取:YOLOv8使用自适应卷积替代了常规的卷积操作,自适应卷积能够根据输入特征图的内容和大小来自动调整感受野,提升了特征提取的效果。
多尺度训练和推理:为了更好地处理不同大小的目标,YOLOv8采用了一种多尺度训练和推理策略,通过分别处理不同尺度的特征图,提高了对小目标和远距离目标的检测能力。
如图2所示,YOLOv8网络在确定通道需要剪枝的每个卷积层时,先输入图片,再进行特征提取网络,特征提取网络包含8倍下采样特征图、16倍采样特征图和32倍采样特征图,然后再进行特征整合网络,再多尺度预测模型,多尺度预测模型包括8倍下采样网络输出、16倍下采样网络输出和32倍下采样网络输出,最后进行行人检测。
所述卷积层的参数维度为输出通道数、输入通道数、卷积核的高和卷积核的宽,对于减小卷积层的参数量,通道剪枝选择如下三种剪枝方案之一进行剪枝:
1)对输出通道进行剪枝;所述输出通道进行剪枝的卷积层含有19层,其中卷积层按顺序有2、4、5、9、10、12、16、17、19、23、25、29、30、33、34 38、39、43、44。
2)对输入通道进行剪枝;所述输入通道进行剪枝的卷积层含有24层,其中卷积层按顺序为3、6、8、11、13、15、18、20、22、25、28、31、32、35、37、40、42、45、46、49、52、55、58、61。
3)同时对输出通道和输入通道进行剪枝,所述输出通道和输入通道一起剪枝的卷积层含有7层,其中卷积层按顺序为7、14、21、26、27、36、41。
S3、根据L1范数和批量标准权重对卷积层通道进行重要性排序;
对所述卷积层进行重要性排序时,先确定对卷积层的哪些通道进行剪枝,再结合L1范数和批量标准权重作为卷积核重要性的衡量标准,一个卷积核的lp范数由如下公式计算得出:
其中i∈Nl+1表示第l个卷积层的第i个卷积核,Nl为该卷积层的输入通道数,Kl为卷积核大小,Fi l为范数,p为范数的次序,当p取1时,上述公式记为一个卷积核的L1范数;
卷积层的批量标准化定义如下:
其中μ和σ表示该卷积层输出的均值和标准差,γ和β为可学习参数,Zout为卷积层的输出提供可学习的线性变换,Zin为卷积层的输入提供可学习的线性变换,其中∈为批量标准化权重;z为卷积层的批量标准化定义值;
通道剪枝评价标准定义为:
由上述公式计算得到卷积层中每个通道的重要性大小,作为后续步骤每个卷积层剪枝的衡量标准。
S4、将排序后的卷积层通道按照剪枝比例剪去相应通道;
所述卷积层通道进行剪枝时,将剪枝率设定为50%,通过剪掉整个模型一半的通道的方式大量减少网络参数,首先剪枝前,先计算网络中每个卷积层的通道的重要性值并进行排序,记录排名在后50%的通道的索引,然后,构建剪枝掩码,对排名在后50%部分的通道权重乘以0,使这部分权重在后续的输出不起作用,达到模型训练时剪枝的效果,其中,剪枝的过程采用循环递进剪枝策略,具体为每5个训练epoch进行一次剪枝操作,每次剪枝操作剪去重要性最低的5%的通道,循环直到50%的卷积层通道被剪掉,如图3所示,使用初始模型进行模型剪枝,再训练并调整权重,得到剪枝模型,若权重调整不在设定范围内时则需再次进行模型剪枝。
S5、采用KL散度来进行模型训练,以衡量知识迁移损失;
在进行模型训练时,在损失函数中加入了知识迁移损失,使剪枝网络不但学习真实标签的分布,也学习原网络输出的分布。
所述KL散度来衡量知识迁移损失的计算公式如下:
其中C表示网络输出的总通道数,Yo表示原网络的输出,YP表示剪枝网络的输出,YC表示网络输出的一个通道,i表示输出中一个通道中的每个位置,H和W分别表示网络输出的卷积核的高和卷积核的宽,T为一常数,用于调节softmax函数输出的分布,在剪枝网络的训练中,总损失等于原目标检测损失加上用KL散度衡量的剪枝损失。
S6、剪枝YOLOv8模型重建与匹配预训练权重,重建后的网络能够正常用于行人检测。
所述剪枝YOLOv8模型重建时,需要用到训练时用的剪枝掩码,根据掩码判断每个卷积核所剩下的输出通道和输入通道,进行每个卷积层的构建,在模型训练时只是对剪枝权重乘以0,并没有真正减少这部分参数,这时需要在模型训练完成后,进行模型重建,同时,使用for循环遍历训练时保存的权重,删减掉权重值为0的通道,保留权重值非0的通道,这时得到的新的权重和重建的网络便能实现匹配,达到迁移权重并实现模型剪枝的效果。
综上,如图4所示,先输入图片,将图片特征分成原网络和剪枝网络,原网络输出Softmax(T=t),然后计算KLDivLoss(剪枝损失),剪枝网络输出Softmax(T=t)和Softmax,然后通过Softmax(T=t)计算KLDivLoss(剪枝损失),通过Softmax计算行人检测损失,以获得真实标签,通过采用通道剪枝方法,以YOLOv8网络作为基础检测模型,修剪YOLOv8网络卷积层中重要性低的卷积核,在模型剪枝重新训练的过程中计算剪枝模型和原模型输出之间的KL散度,作为剪枝和训练过程中的损失函数中的一部分,使剪枝模型的输出更接近与原模型,准确度也更接近于原模型,在降低剪枝模型参数量和计算量的同时,保持和原模型相当的性能表现;另外结合L1范数和批量标准化权重作为卷积核重要性的衡量标准,使模型的通道剪枝的选择更有效,解决了目前通道剪枝方法中衡量标准单一的问题,在保持行人检测准确性的同时大量降低算法的参数量和计算量,满足实时性的要求。
尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (7)
1.一种基于知识迁移剪枝模型的行人检测方法,其特征在于:包括如下步骤:
S1、采集行人数据集,进行数据标注,构建行人数据集;
S2、基于YOLOv8网络确定通道需要剪枝的每个卷积层;
S3、根据L1范数和批量标准权重对卷积层通道进行重要性排序;
对所述卷积层进行重要性排序时,先确定对卷积层的哪些通道进行剪枝,再结合L1范数和批量标准权重作为卷积核重要性的衡量标准,一个卷积核的范数由如下公式计算得出:
,
其中表示第/>个卷积层的第/>个卷积核,/>为该卷积层的输入通道数,/>为卷积核大小,/>为范数,/>为范数的次序,当/>取1时,上述/>范数计算公式记为一个卷积核的L1范数;
卷积层的批量标准化定义如下:
,
其中和/>表示该卷积层输出的均值和标准差,/>和/>为可学习参数,/>为卷积层的输出提供可学习的线性变换,/>为卷积层的输入提供可学习的线性变换,其中/>为批量标准化权重;z为卷积层的批量标准化定义值;
通道剪枝评价标准定义为:
,由上述通道剪枝评价标准定义公式计算得到卷积层中每个通道的重要性大小,作为后续步骤每个卷积层剪枝的衡量标准;
S4、将排序后的卷积层通道按照剪枝比例剪去相应通道;
S5、采用KL散度来进行模型训练,以衡量知识迁移损失;
所述KL散度来衡量知识迁移损失的计算公式如下:
,/> ,
其中表示网络输出的总通道数,/>表示原网络的输出,/>表示剪枝网络的输出,/>表示网络输出的一个通道,i表示输出中一个通道中的每个位置,H和W分别表示网络输出的卷积核的高和卷积核的宽,/>为一常数,用于调节softmax函数输出的分布,在剪枝网络的训练中,总损失等于原目标检测损失加上用KL散度衡量的剪枝损失;
S6、剪枝YOLOv8模型重建与匹配预训练权重,重建后的网络能够正常用于行人检测;
所述剪枝YOLOv8模型重建时,需要用到训练时用的剪枝掩码,根据掩码判断每个卷积核所剩下的输出通道和输入通道,进行每个卷积层的构建,在模型训练完成后,进行模型重建,同时,使用for循环遍历训练时保存的权重,删减掉权重值为0的通道,保留权重值非0的通道,得到的新的权重和重建的网络便能实现匹配。
2.根据权利要求1所述的一种基于知识迁移剪枝模型的行人检测方法,其特征在于:步骤S1中,所述行人数据集包括训练集、验证集和测试集,所述行人数据集在进行标注时,将行人数据集以6:2:2的比例分成训练集、验证集和测试集。
3.根据权利要求2所述的一种基于知识迁移剪枝模型的行人检测方法,其特征在于:步骤S2中,所述YOLOv8网络含有64个卷积层,卷积层的通道剪枝考虑因素包括参数量减小程度、剪枝后模型正常推理进程、剪枝模型便于重建度。
4.根据权利要求3所述的一种基于知识迁移剪枝模型的行人检测方法,其特征在于:所述卷积层的参数维度为输出通道数、输入通道数、卷积核的高和卷积核的宽,对于减小卷积层的参数量,通道剪枝选择如下三种剪枝方案之一进行剪枝:
1)对输出通道进行剪枝;
2)对输入通道进行剪枝;
3)同时对输出通道和输入通道进行剪枝。
5.根据权利要求4所述的一种基于知识迁移剪枝模型的行人检测方法,其特征在于:所述输出通道进行剪枝的卷积层含有19层,其中卷积层按顺序有2、 4、 5、 9、 10、 12、 16、17、 19、 23、 25、 29、 30、 33、 34 38、 39、 43、 44;
所述输入通道进行剪枝的卷积层含有24层,其中卷积层按顺序为3、 6、 8、 11、 13、15、 18、 20、 22、 25、 28、 31、 32、 35、 37、 40、 42、 45、 46、 49、 52、 55、 58、 61;
所述输出通道和输入通道一起剪枝的卷积层含有7层,其中卷积层按顺序为7、 14、21、 26、 27、 36、 41。
6.根据权利要求5所述的一种基于知识迁移剪枝模型的行人检测方法,其特征在于:步骤S4中,所述卷积层通道进行剪枝时,将剪枝率设定为50%,通过剪掉整个模型一半的通道的方式大量减少网络参数,首先剪枝前,先计算网络中每个卷积层的通道的重要性值并进行排序,记录排名在后50%的通道的索引,然后,构建剪枝掩码,对排名在后50%部分的通道权重乘以0,使这部分权重在后续的输出不起作用,达到模型训练时剪枝的效果,其中,剪枝的过程采用循环递进剪枝策略,具体为每5个训练epoch进行一次剪枝操作,每次剪枝操作剪去重要性最低的5%的通道,循环直到50%的卷积层通道被剪掉。
7.根据权利要求6所述的一种基于知识迁移剪枝模型的行人检测方法,其特征在于:步骤S5中,在进行模型训练时,在损失函数中加入了知识迁移损失,使剪枝网络不但学习真实标签的分布,也学习原网络输出的分布。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311579036.5A CN117315722B (zh) | 2023-11-24 | 2023-11-24 | 一种基于知识迁移剪枝模型的行人检测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311579036.5A CN117315722B (zh) | 2023-11-24 | 2023-11-24 | 一种基于知识迁移剪枝模型的行人检测方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117315722A CN117315722A (zh) | 2023-12-29 |
CN117315722B true CN117315722B (zh) | 2024-03-15 |
Family
ID=89288644
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311579036.5A Active CN117315722B (zh) | 2023-11-24 | 2023-11-24 | 一种基于知识迁移剪枝模型的行人检测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117315722B (zh) |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110909667A (zh) * | 2019-11-20 | 2020-03-24 | 北京化工大学 | 面向多角度sar目标识别网络的轻量化设计方法 |
CN113128355A (zh) * | 2021-03-29 | 2021-07-16 | 南京航空航天大学 | 一种基于通道剪枝的无人机图像实时目标检测方法 |
CN114445332A (zh) * | 2021-12-21 | 2022-05-06 | 江西航天鄱湖云科技有限公司 | 基于faster-rcnn模型的多尺度检测方法 |
WO2023024407A1 (zh) * | 2021-08-24 | 2023-03-02 | 平安科技(深圳)有限公司 | 基于相邻卷积的模型剪枝方法、装置及存储介质 |
CN116502698A (zh) * | 2023-06-29 | 2023-07-28 | 中国人民解放军国防科技大学 | 网络通道剪枝率自适应调整方法、装置、设备和存储介质 |
-
2023
- 2023-11-24 CN CN202311579036.5A patent/CN117315722B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110909667A (zh) * | 2019-11-20 | 2020-03-24 | 北京化工大学 | 面向多角度sar目标识别网络的轻量化设计方法 |
CN113128355A (zh) * | 2021-03-29 | 2021-07-16 | 南京航空航天大学 | 一种基于通道剪枝的无人机图像实时目标检测方法 |
WO2023024407A1 (zh) * | 2021-08-24 | 2023-03-02 | 平安科技(深圳)有限公司 | 基于相邻卷积的模型剪枝方法、装置及存储介质 |
CN114445332A (zh) * | 2021-12-21 | 2022-05-06 | 江西航天鄱湖云科技有限公司 | 基于faster-rcnn模型的多尺度检测方法 |
CN116502698A (zh) * | 2023-06-29 | 2023-07-28 | 中国人民解放军国防科技大学 | 网络通道剪枝率自适应调整方法、装置、设备和存储介质 |
Non-Patent Citations (1)
Title |
---|
Lite-YOLOv3: a real-time object detector based on multi-scale slice depthwise convolution and lightweight attention mechanism;Yipeng Zhou et al;《 Journal of Real-Time Image Processing 》;1-10 * |
Also Published As
Publication number | Publication date |
---|---|
CN117315722A (zh) | 2023-12-29 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Zhang et al. | Identification of maize leaf diseases using improved deep convolutional neural networks | |
CN108764063B (zh) | 一种基于特征金字塔的遥感影像时敏目标识别系统及方法 | |
CN107945204B (zh) | 一种基于生成对抗网络的像素级人像抠图方法 | |
CN111739075A (zh) | 一种结合多尺度注意力的深层网络肺部纹理识别方法 | |
CN112541532B (zh) | 基于密集连接结构的目标检测方法 | |
CN113128355A (zh) | 一种基于通道剪枝的无人机图像实时目标检测方法 | |
CN112529146B (zh) | 神经网络模型训练的方法和装置 | |
CN113159048A (zh) | 一种基于深度学习的弱监督语义分割方法 | |
WO2022039675A1 (en) | Method and apparatus for forecasting weather, electronic device and storage medium thereof | |
CN112308825B (zh) | 一种基于SqueezeNet的农作物叶片病害识别方法 | |
CN115223063A (zh) | 基于深度学习的无人机遥感小麦新品种倒伏面积提取方法及系统 | |
CN115936177A (zh) | 一种基于神经网络的光伏输出功率预测方法及系统 | |
CN117315380B (zh) | 一种基于深度学习的肺炎ct图像分类方法及系统 | |
CN113627240B (zh) | 一种基于改进ssd学习模型的无人机树木种类识别方法 | |
CN114821299A (zh) | 一种遥感图像变化检测方法 | |
CN111783688B (zh) | 一种基于卷积神经网络的遥感图像场景分类方法 | |
Sari et al. | Daily rainfall prediction using one dimensional convolutional neural networks | |
CN117315722B (zh) | 一种基于知识迁移剪枝模型的行人检测方法 | |
CN116151479B (zh) | 一种航班延误预测方法及预测系统 | |
CN116403071A (zh) | 基于特征重构的少样本混凝土缺陷检测方法及装置 | |
CN115641498A (zh) | 基于空间多尺度卷积神经网络的中期降水预报后处理订正方法 | |
CN116992944B (zh) | 基于可学习重要性评判标准剪枝的图像处理方法及装置 | |
CN117611580B (zh) | 瑕疵检测方法、装置、计算机设备和存储介质 | |
Kunwar et al. | Prediction of Air Pollution using Deep Learning based LSTM and CNN Model | |
CN113688989B (zh) | 深度学习网络加速方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |