CN115169556B - 模型剪枝方法及装置 - Google Patents
模型剪枝方法及装置 Download PDFInfo
- Publication number
- CN115169556B CN115169556B CN202210880632.6A CN202210880632A CN115169556B CN 115169556 B CN115169556 B CN 115169556B CN 202210880632 A CN202210880632 A CN 202210880632A CN 115169556 B CN115169556 B CN 115169556B
- Authority
- CN
- China
- Prior art keywords
- cloud data
- detection model
- point cloud
- sampling
- sample point
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/36—Creation of semantic tools, e.g. ontology or thesauri
- G06F16/367—Ontology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/80—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
- G06V10/806—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Life Sciences & Earth Sciences (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Databases & Information Systems (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Biophysics (AREA)
- Mathematical Physics (AREA)
- Medical Informatics (AREA)
- Multimedia (AREA)
- Animal Behavior & Ethology (AREA)
- Image Analysis (AREA)
Abstract
本申请涉及人工智能技术领域,提供一种模型剪枝方法及装置。方法包括:获取初始检测模型和样本点云数据;将所述样本点云数据输入所述初始检测模型进行稀疏训练,确定所述初始检测模型的通道重要性分数;其中,所述通道重要性分数基于所述样本点云数据的采样点特征、所述样本点云数据的未采样点特征和所述样本点云数据的采样点坐标特征确定;基于所述通道重要性分数对所述初始检测模型进行剪枝,确定目标检测模型。本申请提出一种基于空间信息和特征融合的剪枝重要性指示器,设计了一种知识再利用的剪枝方式,通过未采样点来增强剪枝方法的鲁棒性。
Description
技术领域
本申请涉及人工智能技术领域,尤其涉及一种模型剪枝方法及装置。
背景技术
随着人工智能技术的发展,神经网络模型的应用也越来越广泛。考虑到网络模型的参数多、运算量大,为了提高模型的运算速度,需要对模型进行剪枝。剪枝即压缩网络模型,以达到减小模型尺寸,降低资源消耗并提升响应时间的目的。
发明内容
本申请旨在至少解决相关技术中存在的技术问题之一。为此,本申请提出一种模型剪枝方法,针对初始检测模型,通过设计通道重要性指标有效地指导了剪枝过程,引入采样点坐标特征后,可以利用3D点云包含坐标信息的特点辅助对通道的选择,样本点云的未采样点特征再利用,能将浪费的信息引入了通道选择,提升了剪枝后模型的鲁棒性。
本申请还提出一种模型剪枝装置。
本申请还提出一种电子设备。
本申请还提出一种非暂态计算机可读存储介质。
本申请还提出一种计算机程序产品。
根据本申请第一方面实施例的模型剪枝方法,包括:
获取初始检测模型和样本点云数据;
将所述样本点云数据输入所述初始检测模型进行稀疏训练,确定所述初始检测模型的通道重要性分数;
其中,所述通道重要性分数基于所述样本点云数据的采样点特征、所述样本点云数据的未采样点特征和所述样本点云数据的采样点坐标特征确定;
基于所述通道重要性分数对所述初始检测模型进行剪枝,确定目标检测模型。
根据本申请实施例的模型剪枝方法,针对基于3D点云搭建的初始检测模型,通过设计通道重要性指标有效地指导了剪枝过程,获得目标检测模型。引入采样点坐标特征后,可以利用3D点云包含坐标信息的特点辅助对通道的选择,样本点云的未采样点特征再利用,能将浪费的信息引入了通道选择,提升了剪枝后模型的鲁棒性。
根据本申请的一个实施例,构建所述采样点坐标特征包括:
对所述样本点云数据进行降采样,确定采样点;
对所述采样点进行特征提取,确定所述样本点云数据的采样点特征;
对所述采样点的坐标信息进行特征提取,确定所述采样点的初始坐标特征;
对所述样本点云数据的采样点特征和所述采样点的初始坐标特征进行特征融合,确定所述采样点坐标特征。
根据本申请的一个实施例,所述对所述采样点特征和所述采样点的初始坐标特征进行特征融合,包括:
通过交叉注意力机制对所述采样点特征和所述采样点的初始坐标特征进行特征融合。
根据本申请的一个实施例,构建所述未采样点特征,包括:
确定所述样本点云数据中未被降采样的未采样点;
提取所述未采样点的特征信息,确定所述未采样点特征。
根据本申请的一个实施例,确定所述初始检测模型的通道重要性分数,包括:
对所述样本点云数据的采样点特征、所述样本点云数据的未采样点特征和所述样本点云数据的采样点坐标特征分别进行归一化,确定采样点特征分数、未采样点特征分数和采样点坐标特征分数;
将所述采样点特征分数、所述未采样点特征分数和所述采样点坐标特征分数进行加权计算,确定通道重要性分数。
根据本申请的一个实施例,所述基于所述通道重要性分数对所述初始检测型进行剪枝,确定目标检测模型,包括:
基于所述通道重要性分数对所述初始检测模型进行通道选择,确定目标掩膜;
基于所述目标掩膜和稀疏训练后的所述初始检测模型,确定目标检测模型。
根据本申请的一个实施例,所述基于所述目标掩膜和稀疏训练后的所述初始检测模型,确定目标检测模型,包括:
基于所述目标掩膜对所述稀疏训练后的初始检测模型的压缩层梯度进行掩码,确定掩码后的压缩层;
对所述稀疏训练后初始检测模型的卷积层和掩码后的压缩层进行融合,确定目标检测模型。
根据本申请第二方面实施例的模型剪枝装置,包括:
准备模块,用于获取初始检测模型和样本点云数据;
训练模块,用于将所述样本点云数据输入所述初始检测模型进行稀疏训练,确定所述初始检测模型的通道重要性分数;
其中,所述通道重要性分数基于所述样本点云数据的采样点特征、所述样本点云数据的未采样点特征和所述样本点云数据的采样点坐标特征确定;
剪枝模块,用于基于所述通道重要性分数对所述初始检测模型进行剪枝,确定目标检测模型。
根据本申请实施例的模型剪枝装置,针对基于3D点云搭建的初始检测模型,通过设计通道重要性指标有效地指导了剪枝过程,获得目标检测模型。引入采样点坐标特征后,可以利用3D点云包含坐标信息的特点辅助对通道的选择,样本点云数据的未采样点特征再利用,能将浪费的信息引入了通道选择,提升了剪枝后模型的鲁棒性。
根据本申请第三方面实施例的一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现所述场景推荐方法或家电知识图谱构建方法。
根据本申请第四方面实施例的一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现所述场景推荐方法或家电知识图谱构建方法。
根据本申请第五方面实施例的一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现所述场景推荐方法或家电知识图谱构建方法。
本申请实施例中的上述一个或多个技术方案,至少具有如下技术效果之一:结合3D点云数据的特点,将采样点的坐标信息和采样点的特征信息进行融合,从而增强坐标信息在通道重要性评估时的重要性。
进一步的,设计了一种知识再利用的剪枝方式,将未被采样的遗弃点进行推理获得特征信息,通过未采样点特征信息来增强剪枝方法的鲁棒性。
更进一步的,通过交叉注意力机制融合采样点的坐标信息和采样点的特征信息,能够提取对当前任务更关键的信息。
本申请的附加方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本申请的实践了解到。
附图说明
为了更清楚地说明本申请实施例或相关技术中的技术方案,下面将对实施例或相关技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的模型剪枝方法流程示意图;
图2是本申请实施例提供的初始检测模型结构示意图;
图3是本申请实施例提供的目标检测方法流程示意图;
图4是本申请实施例提供的模型剪枝装置结构示意图;
图5是本申请实施例提供的目标检测装置结构示意图;
图6是本申请实施例提供的电子设备的结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合本申请中的附图,对本申请中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本申请实施例的至少一个实施例或示例中。此外,术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
根据发明人研究发现,剪枝算法主要目的是减小模型尺寸,降低资源消耗提升响应时间,针对3D剪枝任务,3D数据主要分为两个主流趋势,一种为RGB-D,另一种方式是点云。RGB-D是广泛使用的3D格式。相比于RGB-D,点云表示保留了三维空间中原始的几何信息,不进行离散化。点云存储了点的坐标信息XYZ和颜色信息。传统的模型剪枝方法适用于基于2D图像训练出的模型,与传统的2D图像相比,3D点云数据信息更丰富,包括了坐标信息。因此,基于2D图像的剪枝方法在基于3D点云训练得到的模型上并不适用。
本申请提出了一种针对3D任务的模型剪枝方法,有效的利用了3D点云数据以及模型设计的特点,可以实现高压缩率的同时不降低模型精度。下面结合附图和实施例对本申请的实施方式作进一步详细描述。以下实施例用于说明本申请,但不能用来限制本申请的范围。
如图1所示,本申请一种模型剪枝方法,包括:
步骤101、获取初始检测模型和样本点云数据;
步骤102、将样本点云数据输入初始检测模型进行稀疏训练,确定初始检测模型的通道重要性分数;
其中,通道重要性分数基于样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征确定;
步骤103、基于通道重要性分数对初始检测模型进行剪枝,确定目标检测模型。
针对步骤101,需要说明的是,初始检测模型可以为3D目标检测的通用框架,例如:PointNet++。检测框架包括骨干网络(Backbone)和检测头(Detection Head)。骨干网络中通常包括了降采样操作层(Set Abstraction,SA)和上采样操作层(Feature propagation,FP)。其中,SA层主要是进行降采样,由:采样模块(sampling),分组模块(grouping),特征提取模块(pointnet)三个模块构成。sampling模块用于从输入点云中选择一系列点,定义这些局部区域的中心点,grouping模块用于通过寻找中心点的临近点,将他们组合成局部区域的点集,pointnet模块用于将局部区域内的点的坐标转换为相对该区域中心点的坐标,并作为卷积的输入,得到局部特征。在降采样过程中,经过sampling模块,部分点被选择进行训练,其余的被遗弃,称为遗弃点。FP层是针对检测和分割任务设计的。
针对步骤102,需要说明的是,样本点云数据中包括样本三维图像中点的三维坐标信息以及颜色信息。稀疏训练是一种针对神经网络模型进行剪枝的有效方法,其目的是根据指令条件对神经网络第i层包含的N个神经元对应的梯度值进行选择性的置零,利用未置零的梯度值进行训练运算,梯度值的变化对应模型中不同通道的权重。因此,稀疏训练可以剪除不重要通道,压缩模型体积。样本点云数据的采样点特征即对样本点云数据进行传统剪枝得到的原始采样点特征。样本点云数据的未采样点特征即降采样过程中,对未被采样的遗弃点提取到的特征。样本点云数据的采样点坐标特征是指对采样点的点云坐标编码后提取到的坐标信息特征。另外,本申请实施例中,可以采用Votenet、GroupFree3D等3D目标检测框架来提取采样点的特征。
针对步骤103,需要说明的是,通道重要性分数可以作为剪枝重要性指示器,从而指导网络模型中各个通道剪枝时的去留。目标检测模型是经过剪枝后的初始检测模型。
另外,需要说明的是,样本点云数据在初始检测模型中进行稀疏训练时,会分成三个支路分别进行训练,三条支路分别为原始压缩路径、坐标加强路径和丢弃点信息再利用路径。原始压缩路径中,样本点云数据经过特征提取生成采样点特征。坐标加强路径中,将样本点云数据坐标信息进行特征提取后生成初始坐标特征,再结合采样点特征融合生成采样点坐标特征,丢弃点信息再利用路径中,会对样本点云数据中采样时被丢弃的未采样点进行特征提取,获得未采样点特征。获得了样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征之后,将三种特征进行归一化,生成对应的采样点特征分数、采样点坐标特征分数以及未采样点特征分数。获得了三个特征分数后,将三个特征分数相加,最终获得通道重要性分数。
本申请实施例的模型剪枝方法,针对基于3D点云搭建的初始检测模型,通过设计通道重要性指标有效地指导了剪枝过程,获得目标检测模型。与2D数据相比,由于3D数据中坐标信息更为重要,引入采样点坐标特征后,可以利用3D点云包含坐标信息的特点辅助对通道的选择,通过提出的坐标加强的评估分数有效地指导了剪枝过程中对通道的选择。样本点云数据的未采样点特征再利用,能将浪费的信息引入了通道选择,提升了剪枝后模型的鲁棒性。
可以理解的是,构建采样点坐标特征,包括:
对样本点云数据进行降采样,确定采样点;
对采样点进行特征提取,确定样本点云数据的采样点特征;
对采样点的坐标信息进行特征提取,确定采样点的初始坐标特征;
对样本点云数据的采样点特征和采样点的初始坐标特征进行特征融合,确定采样点坐标特征。
需要说明的是,对样本点云数据进行降采样,降采样时被选择的点即为采样点,采样点的特征信息是直接对采样点进行特征提取得到的,采样点的坐标信息是对采样点的点云坐标进行坐标编码后得到的,坐标编码是对采样点信息的一种维度扩展方式。
本申请实施例的模型剪枝方法,提取到的采样点坐标特征,是在原始的采样点特征的基础上加强了坐标信息,实现了特征信息和坐标信息的融合,通过坐标加强(coordinate enhancing module)的采样点坐标特征可以影响通道重要性分数的取值,从而指导了剪枝过程中对通道的选择。
可以理解的是,对采样点的特征信息和坐标信息进行特征融合,包括:
通过交叉注意力机制对采样点的特征信息和坐标信息进行特征融合。
需要说明的是,交叉注意力机制可以对学习到的特征信息和坐标信息进行融合,增强坐标信息在采样点特征上的表达,可以使模型能自适应感知与坐标信息更相关的特征,使特征信息得到合理有效的处理,增强了模型的表示能力。另外,通过引入注意机制,可以缓解梯度消失问题,降低网络深度选择的难度,减轻可能出现的过拟合问题。
可以理解的是,构建未采样点特征,包括:
确定样本点云数据中未被降采样的未采样点;
提取未采样点的特征信息,确定未采样点特征。
需要说明的是,在降采样过程中,部分点被选择进行训练,其余的被遗弃,称为遗弃点或未采样点。在这些遗弃点中仍然包含了很多有效信息,将遗弃点信息进行再利用通过卷积、归一化和压缩,可以生成未采样点特征。
本申请实施例中,设计了一种知识再利用的剪枝方式,将未采样点信息进行知识再利用,避免了传统方法只对采样点信息进行特征提取后指示通道选择时,可能会出现的漏选或错选的问题。
可以理解的是,通道重要性分数的确定,包括:
对样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征分别进行归一化,确定采样点特征分数、未采样点特征分数和采样点坐标特征分数;
将采样点特征分数、未采样点特征分数和采样点坐标特征分数进行加权计算,确定通道重要性分数。
需要说明的是,样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征是对采样点或未采样点的信息进行推理得到的特征图。对于这些特征进一步进行归一化后,得到采样点特征分数、未采样点特征分数和采样点坐标特征分数,归一化的方式例如L1范数或L2范数(L2 norm)等。
具体的,本申请实施例通过基于坐标加强的采样点坐标特征分数tce、基于遗弃点特征的未采样点特征分数tkc与原始的采样点特征分数to结合进行加权计算,得到最终的通道重要性分数tf如式1所示,然后根据该分数进行通道选择。
其中,i表示第i层卷积,L表示卷积层的总数,tf i、tce i、tkc i和to i分别为第i层卷积对应的各特征分数。在本申请实施例中tce i、tkc i和to i前的权重系数为1,在实际应用中,权重系数对应的超参数会根据模型呈现的实际效果进行调整。
可以理解的是,基于通道重要性分数对初始检测型进行剪枝,确定目标检测模型,包括:
基于通道重要性分数对初始检测模型进行通道选择,确定目标掩膜;
基于目标掩膜和稀疏训练后的初始检测模型,确定目标检测模型。
需要说明的是,在本申请实施例中初始检测模型包括降采样层、卷积层、归一化层和压缩层,其中,压缩层用于通过掩膜对模型进行剪枝。本申请实施例通过通道重要性分数,可以对初始检测模型中的压缩层进行通道选择,通道重要性分数超过重要性阈值的通道保留,将未超过重要性阈值的通道删除,最终确定目标掩膜。
具体的,重要性阈值依据剪枝百分比确定,即需要剪枝的通道数占总通道数的百分比。剪枝百分比根据总稀疏率确定的,总稀疏率和flops裁剪度成正比。例如归一化后得到的特征分数是{1,2,3,4},剪枝50%,重要性阈值就是2,则特征分数为1,2的通道掩码就变成0。
可以理解的是,基于目标掩膜和稀疏训练后的初始检测模型,确定目标检测模型,包括:
基于目标掩膜对稀疏训练后的初始检测模型的压缩层梯度进行掩码,确定掩码后的压缩层;
对稀疏训练后初始检测模型的卷积层和掩码后的压缩层进行融合,确定目标检测模型。
需要说明的是,本申请实施例根据通道重要性评估结果进行反向传播梯度掩码,在反向传播时,根据目标掩膜对压缩层的梯度进行掩码,对掩码为0的部分的梯度清0。
本申请实施例的模型剪枝方法,在传统的压缩层的基础之上,通过融合了遗弃点特征和坐标特征的通道重要性分数创建了新的目标掩膜,从而基于目标掩膜更新压缩层,将稀疏训练后的卷积层与更新后的压缩层进行融合,确定目标检测模型。
本申请实施例的初始检测模型结构包括:降采样模块和特征提取模块,其中特征提取模块包括卷积、归一化和压缩模块。如图2所示,某次稀疏训练时输入的样本点云数据的维度为N×3,N表示本次训练样本点云数据中点的数量,3表示XYZ坐标对应的通道数为3,将样本点云数据输入降采样模块后基于第l-1层局部特征Nl-1×Cl-1得到第l层局部特征为Nl×(3+Cl-1)。
将第l层局部特征为Nl×(3+Cl-1)分为3路,分别为原始压缩路径、坐标加强路径和丢弃点信息再利用路径。原始压缩路径中,第l层局部特征为Nl×(3+Cl-1)经特征提取模块通过卷积、归一化和压缩模块生成采样点特征和采样点特征分数。坐标加强路径中,第1层点云坐标Nl×3将会通过坐标加强支路进行坐标编码生成维度为Nl×Cl的初始坐标特征,在经过交互注意力机制结合采样点特征生成采样点坐标特征。经过归一化后生成经过坐标信息加强的采样点坐标特征分数。丢弃点信息再利用路径中将第l层提取局部特征之外的其余采样点特征经过特征提取模块通过卷积、归一化和压缩模块后生成未采样点特征和未采样点特征分数。
基于坐标信息增强的特征分数和丢弃知识回收的特征分数,获得了通道重要性分数。并基于通道重要性分数获得目标掩膜,对压缩模块的梯度进行掩码,得到更新的压缩模块,最后将更新的压缩模块和稀疏训练后的卷积进行融合得到了目标检测模型。融合得到的目标检测模型如式2所示:
convM(x)==convP(convA(x)) 式2
其中,convM表示融合得到的目标检测模型,convA表示稀疏训练后的卷积层,convP表示更新的压缩模块。
经过等效计算,融合得到的目标检测模型的权重如式3所示:
M.weight=conv2d(convA.weight,convP.weight) 式3
其中,M.weight表示目标检测模型的权重,convA.weight表示稀疏训练后的卷积层的权重,convP.weight表示更新的压缩模块的权重,conv2d()表示二维卷积。
下面对本申请提供的模型剪枝装置进行描述,下文描述的模型剪枝装置与上文描述的模型剪枝方法可相互对应参照。如图4所示,本申请实施例公开了一种模型剪枝装置,包括:
准备模块401,用于获取初始检测模型和样本点云数据;
训练模块402,用于将样本点云数据输入初始检测模型进行稀疏训练,确定初始检测模型的通道重要性分数;
其中,通道重要性分数基于样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征确定;
剪枝模块403,用于基于通道重要性分数对初始检测模型进行剪枝,确定目标检测模型。
本申请实施例的模型剪枝装置,针对基于3D点云搭建的初始检测模型,通过设计通道重要性指标有效地指导了剪枝过程,获得目标检测模型。与2D数据相比,由于3D数据中坐标信息更为重要,引入采样点坐标特征后,可以利用3D点云包含坐标信息的特点辅助对通道的选择,通过提出的坐标加强的评估分数有效地指导了剪枝过程中对通道的选择。样本点云数据的未采样点特征再利用,能将浪费的信息引入了通道选择,提升了剪枝后模型的鲁棒性。
可以理解的是,训练模块402中采样点坐标特征的构建包括:
对样本点云数据进行降采样,确定采样点;
提取采样点特征和采样点的初始坐标特征;
对采样点的特征信息和坐标信息进行特征融合,确定采样点坐标特征。
可以理解的是,训练模块402中对采样点的特征信息和坐标信息进行特征融合,包括:
通过交叉注意力机制对采样点的特征信息和坐标信息进行特征融合。
可以理解的是,训练模块402中未采样点特征的构建,包括:
确定样本点云数据中未被降采样的未采样点;
提取未采样点的特征信息,确定未采样点特征。
可以理解的是,训练模块402中通道重要性分数的确定,包括:
对样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征分别进行归一化,确定采样点特征分数、未采样点特征分数和采样点坐标特征分数;
将采样点特征分数、未采样点特征分数和采样点坐标特征分数进行加权计算,确定通道重要性分数。
可以理解的是,剪枝模块403包括:
基于通道重要性分数对初始检测模型进行通道选择,确定目标掩膜;
基于目标掩膜和稀疏训练后的初始检测模型,确定目标检测模型。
可以理解的是,剪枝模块403中基于目标掩膜和稀疏训练后的初始检测模型,确定目标检测模型,包括:
基于目标掩膜对稀疏训练后的初始检测模型的压缩层梯度进行掩码,确定掩码后的压缩层;
对稀疏训练后初始检测模型的卷积层和掩码后的压缩层进行融合,确定目标检测模型。
随着人工智能的发展,目标检测在生活及生产中的应用越来越广泛。目标检测模型常基于2D图像构建,为了提高模型检测的准确度,在建模时会引入具有更丰富信息的3D数据。3D数据主要分为两个主流趋势,一种为RGB-D,另一种方式是点云。3D点云存储了点的坐标信息和颜色信息,与RGB-D相比,保留了三维空间中原始的几何信息,不用进行离散化,更适用于目标检测任务。
基于点云数据的检测相比于传统方法,虽然提高了检测精度,但是往往具有检测模型尺寸大、资源消耗高且响应时间慢等问题。
如图3所示,本申请实施例公开了一种目标检测方法,包括:
步骤301、采集待检测点云数据;
步骤302、将待检测点云数据输入目标检测模型,得到目标检测模型输出的目标信息;
其中,目标检测模型是本申请上述实施例中任一种模型剪枝方法确定的。
需要说明的是,本申请可应用的场景广泛,包含家居场景、交通场景等。在家居场景下,输入点云数据给目标检测模型,模型可以快速检测出室内的物体位置及类别。在交通场景下,输入点云数据给目标检测模型,模型可以快速检测出路面车辆位置和类别、障碍物位置和类别等。
根据本申请实施例的目标检测方法,通过经过了剪枝的目标检测模型,实现对待测3D点云的目标识别。相比于传统的3D数据的目标检测,本方法由于使用更轻便的模型,检测速度大大提升,更适应于需要快速检测目标的场景。
可以理解的是,在家用场景下,将待检测点云数据输入目标检测模型,得到目标检测模型输出的目标信息,包括:
采集卧室场景下的点云数据作为待检测点云数据;
使用本申请的目标检测模型识别卧室场景下的待检测数据,输出卧室中的物体的目标框,并标注有类别信息,例如床或床头柜。
需要说明的是,本实施例中的目标检测模型,由于进行了剪枝,剪枝后的模型占用资源更少,运行速度更快,在对识别速度要求较高的场景下可以获得比传统3D目标检测模型更好的效果。例如,对于交通场景,车辆移动速度很快,相关部门进行车辆信息采集时,需要快速给出评定结果,此时本发明剪枝后的目标检测模型将会明显由于传统3D目标检测算法,给出实时位置识别结果。
如图5所示,本申请实施例公开了一种目标检测装置,包括:
采集模块501,用于采集待检测点云数据;
检测模块502,用于将待检测点云数据输入目标检测模型,得到目标检测模型输出的目标信息;
其中,目标检测模型是根据本申请上述实施例中任一种模型剪枝方法确定的。
根据本申请实施例的目标检测装置,通过经过了剪枝的目标检测模型,实现对待测3D点云的目标识别。相比于传统的3D数据的目标检测,本方法由于使用更轻便的模型,检测速度大大提升,更适应于需要快速检测目标的场景。
图6示例了一种电子设备的实体结构示意图,如图6所示,该电子设备可以包括:处理器(processor)610、通信接口(Communications Interface)620、存储器(memory)630和通信总线640,其中,处理器610,通信接口620,存储器630通过通信总线640完成相互间的通信。处理器610可以调用存储器630中的逻辑指令,以执行如下方法:
获取初始检测模型和样本点云数据;
将样本点云数据输入初始检测模型进行稀疏训练,确定初始检测模型的通道重要性分数;
其中,通道重要性分数基于样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征确定;
基于通道重要性分数对初始检测模型进行剪枝,确定目标检测模型。
或执行:
采集待检测点云数据;
将待检测点云数据输入目标检测模型,得到目标检测模型输出的目标信息。
此外,上述的存储器630中的逻辑指令可以通过软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对相关技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
另一方面,本申请实施例公开一种计算机程序产品,计算机程序产品包括存储在非暂态计算机可读存储介质上的计算机程序,计算机程序包括程序指令,当程序指令被计算机执行时,计算机能够执行上述各方法实施例所提供的方法,例如包括:
获取初始检测模型和样本点云数据;
将样本点云数据输入初始检测模型进行稀疏训练,确定初始检测模型的通道重要性分数;
其中,通道重要性分数基于样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征确定;
基于通道重要性分数对初始检测模型进行剪枝,确定目标检测模型。
或执行:
采集待检测点云数据;
将待检测点云数据输入目标检测模型,得到目标检测模型输出的目标信息。
又一方面,本申请实施例还提供一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现以执行上述各实施例提供的传输方法,例如包括:
获取初始检测模型和样本点云数据;
将样本点云数据输入初始检测模型进行稀疏训练,确定初始检测模型的通道重要性分数;
其中,通道重要性分数基于样本点云数据的采样点特征、样本点云数据的未采样点特征和样本点云数据的采样点坐标特征确定;
基于通道重要性分数对初始检测模型进行剪枝,确定目标检测模型。
或执行:
采集待检测点云数据;
将待检测点云数据输入目标检测模型,得到目标检测模型输出的目标信息。
以上所描述的装置实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性的劳动的情况下,即可以理解并实施。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到各实施方式可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件。基于这样的理解,上述技术方案本质上或者说对相关技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品可以存储在计算机可读存储介质中,如ROM/RAM、磁碟、光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行各个实施例或者实施例的某些部分的方法。
最后应说明的是,以上实施方式仅用于说明本申请,而非对本申请的限制。尽管参照实施例对本申请进行了详细说明,本领域的普通技术人员应当理解,对本申请的技术方案进行各种组合、修改或者等同替换,都不脱离本申请技术方案的精神和范围,均应涵盖在本申请的范围中。
Claims (9)
1.一种模型剪枝方法,其特征在于,包括:
获取初始检测模型和样本点云数据,其中,所述样本点云数据中包括样本三维图像中点的三维坐标信息;
将所述样本点云数据输入所述初始检测模型进行稀疏训练,确定所述初始检测模型的通道重要性分数;
其中,所述通道重要性分数基于所述样本点云数据的采样点特征、所述样本点云数据的未采样点特征和所述样本点云数据的采样点坐标特征确定;
基于所述通道重要性分数对所述初始检测模型进行剪枝,确定目标检测模型。
2.根据权利要求1所述的模型剪枝方法,其特征在于,构建所述采样点坐标特征,包括:
对所述样本点云数据进行降采样,确定采样点;
对所述采样点进行特征提取,确定所述样本点云数据的采样点特征;
对所述采样点的坐标信息进行特征提取,确定所述采样点的初始坐标特征;
对所述样本点云数据的采样点特征和所述采样点的初始坐标特征进行特征融合,确定所述样本点云数据的采样点坐标特征。
3.根据权利要求2所述的模型剪枝方法,其特征在于,所述对所述样本点云数据的采样点特征和所述采样点的初始坐标特征进行特征融合,包括:
通过交叉注意力机制对所述采样点特征和所述采样点的初始坐标特征进行特征融合。
4.根据权利要求2所述的模型剪枝方法,其特征在于,构建所述未采样点特征,包括:
确定所述样本点云数据中未被降采样的未采样点;
提取所述未采样点的特征信息,确定所述未采样点特征。
5.根据权利要求4所述的模型剪枝方法,其特征在于,确定所述初始检测模型的通道重要性分数,包括:
对所述样本点云数据的采样点特征、所述样本点云数据的未采样点特征和所述样本点云数据的采样点坐标特征分别进行归一化,确定采样点特征分数、未采样点特征分数和采样点坐标特征分数;
将所述采样点特征分数、所述未采样点特征分数和所述采样点坐标特征分数进行加权计算,确定所述通道重要性分数。
6.根据权利要求1至5任一所述的模型剪枝方法,其特征在于,所述基于所述通道重要性分数对所述初始检测型进行剪枝,确定目标检测模型,包括:
基于所述通道重要性分数对所述初始检测模型进行通道选择,确定目标掩膜;
基于所述目标掩膜和稀疏训练后的所述初始检测模型,确定目标检测模型。
7.根据权利要求6所述的模型剪枝方法,其特征在于,所述基于所述目标掩膜和稀疏训练后的所述初始检测模型,确定目标检测模型,包括:
基于所述目标掩膜对稀疏训练后的初始检测模型的压缩层梯度进行掩码,确定掩码后的压缩层;
对所述稀疏训练后初始检测模型的卷积层和掩码后的压缩层进行融合,确定目标检测模型。
8.一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1至7任一项所述的模型剪枝方法。
9.一种非暂态计算机可读存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行时实现如权利要求1至7任一项所述的模型剪枝方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210880632.6A CN115169556B (zh) | 2022-07-25 | 2022-07-25 | 模型剪枝方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210880632.6A CN115169556B (zh) | 2022-07-25 | 2022-07-25 | 模型剪枝方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN115169556A CN115169556A (zh) | 2022-10-11 |
CN115169556B true CN115169556B (zh) | 2023-08-04 |
Family
ID=83496841
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210880632.6A Active CN115169556B (zh) | 2022-07-25 | 2022-07-25 | 模型剪枝方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115169556B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116468101A (zh) * | 2023-03-21 | 2023-07-21 | 美的集团(上海)有限公司 | 模型剪枝方法、装置、电子设备和可读存储介质 |
Citations (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CA2652710A1 (en) * | 2008-02-05 | 2009-08-05 | Solido Design Automation Inc. | Pruning-based variation-aware design |
CN111461212A (zh) * | 2020-03-31 | 2020-07-28 | 中国科学院计算技术研究所 | 一种用于点云目标检测模型的压缩方法 |
CN111932690A (zh) * | 2020-09-17 | 2020-11-13 | 北京主线科技有限公司 | 基于3d点云神经网络模型的剪枝方法及装置 |
CN112396179A (zh) * | 2020-11-20 | 2021-02-23 | 浙江工业大学 | 一种基于通道梯度剪枝的柔性深度学习网络模型压缩方法 |
CN112446476A (zh) * | 2019-09-04 | 2021-03-05 | 华为技术有限公司 | 神经网络模型压缩的方法、装置、存储介质和芯片 |
CN112465114A (zh) * | 2020-11-25 | 2021-03-09 | 重庆大学 | 基于优化通道剪枝的快速目标检测方法及系统 |
CN112668630A (zh) * | 2020-12-24 | 2021-04-16 | 华中师范大学 | 一种基于模型剪枝的轻量化图像分类方法、系统及设备 |
CN113011430A (zh) * | 2021-03-23 | 2021-06-22 | 中国科学院自动化研究所 | 大规模点云语义分割方法及系统 |
AU2021103976A4 (en) * | 2021-03-22 | 2021-09-09 | Jiangsu University | Asthma diagnosis system based on decision tree and improved SMOTE algorithm |
CN113408561A (zh) * | 2020-03-17 | 2021-09-17 | 北京京东乾石科技有限公司 | 模型生成方法、目标检测方法、装置、设备及存储介质 |
CN113766228A (zh) * | 2020-06-05 | 2021-12-07 | Oppo广东移动通信有限公司 | 点云压缩方法、编码器、解码器及存储介质 |
CN114286103A (zh) * | 2021-12-24 | 2022-04-05 | 复旦大学 | 一种基于深度学习的保留密度的点云压缩方法 |
CN114419732A (zh) * | 2022-01-11 | 2022-04-29 | 江南大学 | 基于注意力机制优化的HRNet人体姿态识别方法 |
Family Cites Families (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US10671082B2 (en) * | 2017-07-03 | 2020-06-02 | Baidu Usa Llc | High resolution 3D point clouds generation based on CNN and CRF models |
WO2019099899A1 (en) * | 2017-11-17 | 2019-05-23 | Facebook, Inc. | Analyzing spatially-sparse data based on submanifold sparse convolutional neural networks |
CN110349230A (zh) * | 2019-07-15 | 2019-10-18 | 北京大学深圳研究生院 | 一种基于深度自编码器的点云几何压缩的方法 |
US20210090328A1 (en) * | 2020-12-07 | 2021-03-25 | Intel Corporation | Tile-based sparsity aware dataflow optimization for sparse data |
-
2022
- 2022-07-25 CN CN202210880632.6A patent/CN115169556B/zh active Active
Patent Citations (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CA2652710A1 (en) * | 2008-02-05 | 2009-08-05 | Solido Design Automation Inc. | Pruning-based variation-aware design |
CN112446476A (zh) * | 2019-09-04 | 2021-03-05 | 华为技术有限公司 | 神经网络模型压缩的方法、装置、存储介质和芯片 |
CN113408561A (zh) * | 2020-03-17 | 2021-09-17 | 北京京东乾石科技有限公司 | 模型生成方法、目标检测方法、装置、设备及存储介质 |
CN111461212A (zh) * | 2020-03-31 | 2020-07-28 | 中国科学院计算技术研究所 | 一种用于点云目标检测模型的压缩方法 |
CN113766228A (zh) * | 2020-06-05 | 2021-12-07 | Oppo广东移动通信有限公司 | 点云压缩方法、编码器、解码器及存储介质 |
CN111932690A (zh) * | 2020-09-17 | 2020-11-13 | 北京主线科技有限公司 | 基于3d点云神经网络模型的剪枝方法及装置 |
CN112396179A (zh) * | 2020-11-20 | 2021-02-23 | 浙江工业大学 | 一种基于通道梯度剪枝的柔性深度学习网络模型压缩方法 |
CN112465114A (zh) * | 2020-11-25 | 2021-03-09 | 重庆大学 | 基于优化通道剪枝的快速目标检测方法及系统 |
CN112668630A (zh) * | 2020-12-24 | 2021-04-16 | 华中师范大学 | 一种基于模型剪枝的轻量化图像分类方法、系统及设备 |
AU2021103976A4 (en) * | 2021-03-22 | 2021-09-09 | Jiangsu University | Asthma diagnosis system based on decision tree and improved SMOTE algorithm |
CN113011430A (zh) * | 2021-03-23 | 2021-06-22 | 中国科学院自动化研究所 | 大规模点云语义分割方法及系统 |
CN114286103A (zh) * | 2021-12-24 | 2022-04-05 | 复旦大学 | 一种基于深度学习的保留密度的点云压缩方法 |
CN114419732A (zh) * | 2022-01-11 | 2022-04-29 | 江南大学 | 基于注意力机制优化的HRNet人体姿态识别方法 |
Also Published As
Publication number | Publication date |
---|---|
CN115169556A (zh) | 2022-10-11 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
JP6980958B1 (ja) | 深層学習に基づく農村地域分けゴミ識別方法 | |
CN109086773B (zh) | 基于全卷积神经网络的断层面识别方法 | |
CN110689599B (zh) | 基于非局部增强的生成对抗网络的3d视觉显著性预测方法 | |
CN108564097A (zh) | 一种基于深度卷积神经网络的多尺度目标检测方法 | |
CN112017189A (zh) | 图像分割方法、装置、计算机设备和存储介质 | |
CN110111366A (zh) | 一种基于多级损失量的端到端光流估计方法 | |
CN113240691A (zh) | 一种基于u型网络的医学图像分割方法 | |
CN107506761A (zh) | 基于显著性学习卷积神经网络的脑部图像分割方法及系统 | |
CN109784283A (zh) | 基于场景识别任务下的遥感图像目标提取方法 | |
CN106462771A (zh) | 一种3d图像的显著性检测方法 | |
CN112233129B (zh) | 基于深度学习的并行多尺度注意力机制语义分割方法及装置 | |
CN110110646A (zh) | 一种基于深度学习的手势图像关键帧提取方法 | |
CN113223005B (zh) | 一种甲状腺结节自动分割及分级的智能系统 | |
CN114266794B (zh) | 基于全卷积神经网络的病理切片图像癌症区域分割系统 | |
CN113034444A (zh) | 一种基于MobileNet-PSPNet神经网络模型的路面裂缝检测方法 | |
CN115169556B (zh) | 模型剪枝方法及装置 | |
CN112288749A (zh) | 一种基于深度迭代融合深度学习模型的颅骨图像分割方法 | |
CN112927237A (zh) | 基于改进SCB-Unet网络的蜂窝肺病灶分割方法 | |
CN113724286A (zh) | 显著性目标的检测方法、检测设备及计算机可读存储介质 | |
CN111860465A (zh) | 基于超像素的遥感图像提取方法、装置、设备及存储介质 | |
CN115797929A (zh) | 基于双注意力机制的小型农田图像分割方法、装置 | |
CN111462090A (zh) | 一种多尺度图像目标检测方法 | |
CN113554656A (zh) | 基于图神经网络的光学遥感图像实例分割方法及装置 | |
CN111223113A (zh) | 基于双重密集上下文感知网络的核磁共振海马体分割算法 | |
CN116630850A (zh) | 基于多注意力任务融合与边界框编码的孪生目标跟踪方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |