CN114881136A - 基于剪枝卷积神经网络的分类方法及相关设备 - Google Patents
基于剪枝卷积神经网络的分类方法及相关设备 Download PDFInfo
- Publication number
- CN114881136A CN114881136A CN202210458105.6A CN202210458105A CN114881136A CN 114881136 A CN114881136 A CN 114881136A CN 202210458105 A CN202210458105 A CN 202210458105A CN 114881136 A CN114881136 A CN 114881136A
- Authority
- CN
- China
- Prior art keywords
- pruning
- weight
- classification
- classification model
- initial
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Abstract
本发明提供一种基于剪枝卷积神经网络的分类方法及相关设备,包括:获取待分类的图片;将待分类的图片输入剪枝后的分类模型,得到对应的分类结果;其中,剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、训练图片集对应的标签集以及权重掩码训练集训练得到。本发明能够大大节约剪枝耗时。
Description
技术领域
本发明涉及模型剪枝技术领域,尤其涉及一种基于剪枝卷积神经网络的分类方法及相关设备。
背景技术
目前,深度学习模型需要大量算力、内存和电量。当需要执行实时推断、在设备端运行模型、在计算资源有限的情况下运行深度学习模型时,需要体积小且准确率高的深度学习模型,因而模型压缩则可以实现这一目标,模型剪枝正是模型压缩中的一种。
模型剪枝其主要用来减少卷积神经网络中的计算量,通常情况下是通过裁剪掉神经网络权重中不重要的张量来达到降低整个神经网络的计算量的目的。在对不重要的张量进行剪枝前,需要确定模型各层的稀疏率从而确定出不重要的张量。
现有的稀疏率确定方法包括:各层稀疏一致的方法与对各层分析其敏感度的方法。其中,敏感度分析方法的主要思路为:依次分析各层剪枝后的模型效果变化,以此判断各层敏感度。而要想更好地分析各层剪枝后的敏感度,需要在剪枝后对模型进行微调训练(fine-tuning训练),如果想获得给定目标下(如特定计算量要求下)的剪枝配置,需分析所有层的剪枝敏感度,从而确定不重要的张量,剪枝过程耗时久。
发明内容
本发明提供一种基于剪枝卷积神经网络的分类方法及相关设备,用以解决上述问题。
本发明提供一种基于剪枝卷积神经网络的分类方法,包括:
获取待分类的图片;
将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果;
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
根据本发明提供的一种基于剪枝卷积神经网络的分类方法,所述方法还包括:
为初始分类模型中每个待分析网络层构建初始剪枝敏感度分析模型;
将权重掩码训练集输入所述初始剪枝敏感度分析模型,得到待分析网络层的初始权重;所述权重掩码训练集通过随机数生成方法生成,且所述权重掩码训练集中的每项权重掩码训练数据与剪枝率对应;
将训练图片集输入所述初始分类模型,基于每个待分析网络层的初始权重以及前向计算得到初始预测结果;
根据所述初始预测结果以及所述标签集,利用反向传播算法对所述每个待分析网络层的初始权重进行更新,得到待分析网络层的更新后权重;
根据所述待分析网络层的更新后权重,利用反向传播算法的链式法则,对所述初始剪枝敏感度分析模型中的权重进行更新,直到所述初始分类模型收敛,从而得到预先训练好的剪枝敏感度分析模型以及训练好的分类模型。
根据本发明提供的一种基于剪枝卷积神经网络的分类方法,所述将权重掩码训练集输入所述初始剪枝敏感度分析模型,得到待分析网络层的初始权重,包括:
将权重掩码训练集中的权重掩码训练数据输入多个卷积网络,以生成特征图;以及
融合特征图与权重掩码训练数据,以生成待分析网络层的初始权重。
根据本发明提供的一种基于剪枝卷积神经网络的分类方法,所述方法还包括:
步骤201,在预设的剪枝率搜索空间内穷举所述训练好的分类模型中每个网络层的剪枝率,得到剪枝率集合;在所述剪枝率集合中筛选出满足预设的剪枝目标的剪枝率,作为待分析剪枝率;其中,所述预设的剪枝目标为所述训练好的分类模型的剪枝目标;
步骤202,根据剪枝率与权重掩码的对应关系,确定所述待分析剪枝率所对应的权重掩码,作为待分析权重掩码;
步骤203,将所述待分析权重掩码输入所述预先训练好的剪枝敏感度分析模型中,得到所述训练好的分类模型对应的剪枝后权重;
步骤204,基于所述剪枝后权重对所述训练好的分类模型进行性能评价,得到性能评价指标值;
步骤205,重复所述步骤203至所述步骤204,直到穷尽所有待分析剪枝率,从而得到多个性能评价指标值;
步骤206,从所述多个性能评价指标值中确定出最大的性能评价指标值作为最优灵敏度,将所述最优灵敏度对应的剪枝后权重作为剪枝后的分类模型中的分类模型权重。
根据本发明提供的一种基于剪枝卷积神经网络的分类方法,所述剪枝率与权重掩码的对应关系为权重掩码mask基于剪枝率p与权重维度信息得到,其中,所述权重维度信息包括通道数C:
mask[0:C*p]=0,mask[C*p:C]=1
其中,mask[0:C*p]=0为C个通道中前C*p个通道对应的权重掩码mask为0,mask[C*p:C]=1为C个通道中后C-C*p个通道对应的权重掩码mask为1。
根据本发明提供的一种基于剪枝卷积神经网络的分类方法,所述在预设的剪枝率搜索空间内穷举所述训练好的分类模型中每个网络层的剪枝率,得到剪枝率集合,包括:
根据预先设定的剪枝率的取值范围以及剪枝率取值步长,在所述预先设定的剪枝率的取值范围内,穷举得到所有符合所述剪枝率取值步长的剪枝率,从而得到剪枝率集合。
根据本发明提供的一种基于剪枝卷积神经网络的分类方法,所述预设的剪枝目标包括以下至少一项:所述训练好的分类模型的目标计算力;所述训练好的分类模型的目标参数量。
本发明还提供一种基于剪枝卷积神经网络的分类装置,包括:
图片获取模块,用于获取待分类的图片;
分类模块,用于将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果;
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
本发明还提供一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行程序时实现如上述任一种基于剪枝卷积神经网络的分类方法。
本发明还提供一种非暂态计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现如上述任一种基于剪枝卷积神经网络的分类方法。
本发明还提供一种计算机程序产品,包括计算机程序,计算机程序被处理器执行时实现如上述任一种基于剪枝卷积神经网络的分类方法。
本发明提供的基于剪枝卷积神经网络的分类方法及相关设备,通过剪枝后的分类模型对待分类的图片进行分类识别,其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;而所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到,因此,分类模型在剪枝后并不需要对模型进行微调训练,其对应的分类模型权重直接根据预先训练好的剪枝敏感度分析模型得到,从而大大节省了剪枝耗时。
附图说明
为了更清楚地说明本发明或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作一简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例提供的基于剪枝卷积神经网络的分类方法的流程示意图;
图2是本发明实施例提供的初始分类模型的网络结构示意图;
图3是本发明实施例提供的初始剪枝敏感度分析模型的网络结构示意图;
图4为本发明实施例提供的基于剪枝卷积神经网络的分类装置结构示意图;
图5为本发明实施例提供的一种电子设备的实体结构示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,下面将结合本发明中的附图,对本发明中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
图1是本发明实施例提供的基于剪枝卷积神经网络的分类方法的流程示意图;如图1所示,一种基于剪枝卷积神经网络的分类方法,包括如下步骤:
S101,获取待分类的图片。
本步骤中,待分类的图片为待分类车型图片,该待分类车型图片可以直接拍摄获取,也可以是网上获取的图片,或者是车型图片数据库中的影像。
S102,将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果。
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到。
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
本实施例中,与上述待分类车型图片对应,剪枝后的分类模型为剪枝后的车型分类模型,对应地,经过剪枝后的车型分类模型前向推理后得到的分类结果包括轿车、卡车、货车、救护车、公交车、自行车、三轮车等。
同样地,预先训练好的剪枝敏感度分析模型是基于初始车型分类模型、车型训练图片集、所述车型训练图片集对应的车型标签集以及权重掩码训练集训练得到。
具体地,预先训练好的剪枝敏感度分析模型对应的输入为权重掩码,输出结果则为剪枝后的分类模型的分类模型权重,其中,输入权重掩码又与剪枝率对应,也即通过预先训练好的剪枝敏感度分析模型能够得到预定剪枝率下的分类模型权重,基于该分类模型权重直接得到预定剪枝率下的分类模型。
最优的分类模型权重则通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到,即通过神经网络结构搜索方法能够获得剪枝敏感度分析模型中敏感度最优的分类模型权重,将最优的分类模型权重作为分类模型中的权重,从而得到剪枝后的分类模型。
本发明实施例提供的基于剪枝卷积神经网络的分类方法,通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到敏感度最优的权重,基于该敏感度最优的权重直接获得剪枝后的分类模型,无需在分类模型剪枝后进行微调训练,剪枝过程中的敏感度分析时间变为仅需训练一次的剪枝敏感度分析模型的,从而大大减少了剪枝耗时。
进一步地,所述方法还包括:
为初始分类模型中每个待分析网络层构建初始剪枝敏感度分析模型;
将权重掩码训练集输入所述初始剪枝敏感度分析模型,得到待分析网络层的初始权重;所述权重掩码训练集通过随机数生成方法生成,且所述权重掩码训练集中的每项权重掩码训练数据与剪枝率对应;
将训练图片集输入所述初始分类模型,基于每个待分析网络层的初始权重以及前向计算得到初始预测结果;
根据所述初始预测结果以及所述标签集,利用反向传播算法对所述每个待分析网络层的初始权重进行更新,得到待分析网络层的更新后权重;
根据所述待分析网络层的更新后权重,利用反向传播算法的链式法则,对所述初始剪枝敏感度分析模型中的权重进行更新,直到所述初始分类模型收敛,从而得到预先训练好的剪枝敏感度分析模型以及训练好的分类模型。
图2是本发明实施例提供的初始分类模型的网络结构示意图,如图2所示,初始分类模型包括卷积层A、卷积层B、卷积层C等等,而初始分类模型的输入为训练图片集中预先标定好的图片以及对应的类别标签,输出则为预先标定好的图片对应的预测类别。
初始分类模型中每个网络层均需要进行敏感度分析,因此每个网络层均为待分析网络层。由于每个待分析网络层均需要进行敏感度分析,因此每个待分析网络层均对应有一个初始剪枝敏感度分析模型。
图3是本发明实施例提供的初始剪枝敏感度分析模型的网络结构示意图,如图3所示,初始剪枝敏感度分析模型对应的网络层包括卷积层1与卷积层2,初始剪枝敏感度分析模型的输入为权重掩码训练集中随机生成的权重掩码,权重掩码经过卷积层1以及卷积层2的卷积操作后得到特征图与开始输入的权重掩码进行融合后得到输出结果,该输出结果直接作为待分析层的权重。
其中,初始剪枝敏感度分析模型的输入input为权重掩码mask,shape为(待分析网络层的输出大小layer-out,待分析网络层的输入大小layer-in,k,k),k为卷积核的大小;Conv1的shape为(layer-in*2,layer-in,3,3);Conv2的shape为(layer-out,layer-in*2,3,3);输出output的shape为(layer-out,layer-in,k,k),该输出output即为待分析网络层的权重。
上述权重掩码与剪枝率对应,即剪枝率的大小影响权重掩码取0的多少,若是剪枝率小,则通道中权重掩码取0的数量少,当权重掩码为0时,对应的通道中权重值不参与计算;当权重掩码取1,对应的通道中权重值参与计算。另外,上述权重掩码训练集利用现有的随机算法随机生成。
根据上述初始分类模型与初始剪枝敏感度分析模型的特性以及关联性进行模型的联合训练,具体如下:
将训练图片集中预先标定好的图片输入初始分类模型,初始分类模型中的各个网络层的计算之后得到输出结果(即前向计算),即预测的类别。然后,通过损失函数对预测的类别与预先标定好的标签进行损失计算(即loss计算),基于计算得到的损失值,并利用反向传播算法来更新整个初始分类模型中网络层的权重值大小。当初始分类模型收敛时,其预测的类别与对应标签匹配度最高,整个初始分类模型为训练好的分类模型。
在上述初始分类模型中每个待分析网络层的分类模型权重更新过程,更新的分类模型权重作为初始剪枝敏感度分析模型输出结果的标签,同样利用损失函数计算初始剪枝敏感度分析模型输出结果的标签与在将权重掩码输入初始剪枝敏感度分析模型后得到的输出结果进行损失计算,并根据损失值以及反向传播算法训练初始剪枝敏感度分析模型,直到每个初始剪枝敏感度分析模型收敛,从而得到训练好的剪枝敏感度分析模型。
本发明实施例提供的基于剪枝卷积神经网络的分类方法,利用反向传播的链式法则训练得到剪枝敏感度分析模型与分类模型,从而节约了剪枝耗时。
进一步地,所述将权重掩码训练集输入所述初始剪枝敏感度分析模型,得到待分析网络层的初始权重,包括:
将权重掩码训练集中的权重掩码训练数据输入多个卷积网络,以生成特征图;以及
融合特征图与权重掩码训练数据,以生成待分析网络层的初始权重。
具体地,将随机生成的权重掩码输入初始剪枝敏感度分析模型,经过初始剪枝敏感度分析模型中权重参数的前向计算,预测得到与权重掩码对应的剪枝敏感度分析结果,将输出结果作为待分析网络层的初始权重,也即在权重掩码确定的情况下,待分析网络层中的哪些权重保持,哪些权重被置0可通过初始剪枝敏感度分析模型预测得到。
通过初始剪枝敏感度分析模型与初始分类模型的联合训练能够准确获得权重掩码与分类模型中权重之间的对应关系。
进一步地,所述方法还包括:
步骤201,在预设的剪枝率搜索空间内穷举所述训练好的分类模型中每个网络层的剪枝率,得到剪枝率集合。在所述剪枝率集合中筛选出满足预设的剪枝目标的剪枝率,作为待分析剪枝率。
其中,所述预设的剪枝目标为所述训练好的分类模型的剪枝目标。
在本步骤中,剪枝率搜索空间为设定的剪枝率取值范围,以特定的步长来在上述剪枝率取值范围内可以穷尽所有可能的剪枝率,从而得到剪枝率集合。
在本步骤中,预设的剪枝目标是指用户希望剪枝后的分类模型具有何种性能,例如,希望剪枝后的分类模型的参数量达到什么目标,计算力(即FLOPs)达到什么目标,从而来衡量压缩效果。
综合所有待分析网络层对应的剪枝率集合,筛选出能够满足剪枝目标的剪枝率组合,即每个待分析网络层筛选出的剪枝率组合在一起能够满足上述剪枝目标。
步骤202,根据剪枝率与权重掩码的对应关系,确定所述待分析剪枝率所对应的权重掩码,作为待分析权重掩码。
在本步骤中,根据上述筛选出的待分析剪枝率以及剪枝率与权重掩码之间的对应关系,确定待分析剪枝率对应的权重掩码,并作为待分析权重掩码。
步骤203,将所述待分析权重掩码输入所述预先训练好的剪枝敏感度分析模型中,得到所述训练好的分类模型对应的剪枝后权重。
在本步骤中,将每个待分析网络层对应的待分析权重掩码输入对应的训练好的剪枝敏感度分析模型,得到的输出结果即为待分析网络层的剪枝后权重,综合所有待分析网络层的剪枝后权重,得到整个分类模型的剪枝后权重。
步骤204,基于所述剪枝后权重对所述训练好的分类模型进行性能评价,得到性能评价指标值。
在本步骤中,根据上述整个分类模型的剪枝后权重来评价分类模型性能,从而得到性能评价指标值。
本实施例中,性能评价指标包括准确率ACC;在本发明的其他实施例中,性能评价指标也可以是误检率false positve、精确率precision、召回率recall等等其他模型性能评价指标。
步骤205,重复所述步骤203至所述步骤204,直到穷尽所有待分析剪枝率,从而得到多个性能评价指标值.
在本步骤中,由于满足剪枝目标的待分析剪枝率组合会有很多钟,因而需要穷尽所有待分析剪枝率,从而得到多个待分析剪枝率组合下对应的分类模型的性能评价指标值。
步骤206,从所述多个性能评价指标值中确定出最大的性能评价指标值作为最优灵敏度,将所述最优灵敏度对应的剪枝后权重作为剪枝后的分类模型中的分类模型权重。
在本步骤中,将性能评价指标值最好的分类模型作为剪枝敏感度最优的模型,其对应的剪枝后权重最终作为剪枝后的分类模型中的分类模型权重。
本发明实施例提供的基于剪枝卷积神经网络的分类方法,根据预设的剪枝目标获取所有剪枝率可能,进一步获取所有分类模型的剪枝可能,再通过模型的性能评价来衡量分类模型的剪枝敏感度,最后将性能评价指标值最高的分类模型作为敏感度最优的模型,其对应的剪枝后权重以及剪枝率即为最优剪枝后权重与最优剪枝率。从而得到剪枝后的分类模型,因此,在剪枝后不需要微调训练,也不需要逐层剪枝敏感度分析,能在短时间能完成预定目标的剪枝,提高了剪枝效率。
进一步地,所述剪枝率与权重掩码的对应关系为权重掩码mask基于剪枝率p与权重维度信息得到,其中,所述权重维度信息包括通道数C:
mask[0:C*p]=0,mask[C*p:C]=1
其中,mask[0:C*p]=0为C个通道中前C*p个通道对应的权重掩码mask为0,mask[C*p:C]=1为C个通道中后C-C*p个通道对应的权重掩码mask为1,剪枝率p∈[0,1]。
具体地,权重维度信息包括通道数C、卷积核大小k。假设权重维度信息为(c1,c2,k,k),其中,c1与c2分别为第一维度通道数与第二维度通道数。则在第一维度c1中,权重掩码mask[0:c1*p]=0,mask[c1*p:c1]=1,即c1个通道中前c1*p个通道对应的权重掩码数值为0,后c1-c1*p个通道中对应的权重掩码数值为1;当权重掩码数值为0时,表示前c1*p个通道中的权重不参与计算,即完成剪枝;当权重掩码数值为1时,表示后c1-c1*p个通道中对应的权重数值保持不变,继续参与计算,即被保留。
同样地,在第二维度c2中,权重掩码mask[0:c2*p]=0,mask[c2*p:c2]=1,即c2个通道中前c2*p个通道对应的权重掩码数值为0,后c2-c2*p个通道中对应的权重掩码数值为1.
进一步地,所述在预设的剪枝率搜索空间内穷举所述训练好的分类模型中每个网络层的剪枝率,得到剪枝率集合,包括:
根据预先设定的剪枝率的取值范围以及剪枝率取值步长,在所述预先设定的剪枝率的取值范围内,穷举得到所有符合所述剪枝率取值步长的剪枝率,从而得到剪枝率集合。
具体地,假设设定的剪枝率的取值范围为(0.1,0.9),剪枝率取值步长为0.1,则剪枝率集合为(0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9)。
进一步地,所述预设的剪枝目标包括以下至少一项:所述训练好的分类模型的目标计算力;所述训练好的分类模型的目标参数量。
具体地,若未经过剪枝的分类模型的计算量为12GFLOPs,参数量为2M,则剪枝目标为:剪枝后的计算量为2GFLOPs,参数量为0.5M,即在上述剪枝率筛选过程中,筛选得到的待分析剪枝率需要使得剪枝后分类模型的计算量不大于2GFLOPs,参数量不大于0.5M,从而实现模型的加速与压缩,便于直接应用在自动驾驶场景中。
下面对本发明提供的基于剪枝卷积神经网络的分类装置进行描述,下文描述的基于剪枝卷积神经网络的分类装置与上文描述的基于剪枝卷积神经网络的分类方法可相互对应参照。
图4为本发明实施例提供的基于剪枝卷积神经网络的分类装置结构示意图,如图4所示,一种基于剪枝卷积神经网络的分类装置,包括:
图片获取模块401,用于获取待分类的图片。
在本模块中,待分类的图片为待分类车型图片,该待分类车型图片可以直接拍摄获取,也可以是网上获取的图片,或者是车型图片数据库中的影像。
分类模块402,用于将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果。
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到。
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
本实施例中,与上述待分类车型图片对应,剪枝后的分类模型为剪枝后的车型分类模型,对应地,经过剪枝后的车型分类模型前向推理后得到的分类结果包括轿车、卡车、货车、救护车、公交车、自行车、三轮车等。
同样地,预先训练好的剪枝敏感度分析模型是基于初始车型分类模型、车型训练图片集、所述车型训练图片集对应的车型标签集以及权重掩码训练集训练得到。
具体地,预先训练好的剪枝敏感度分析模型对应的输入为权重掩码,输出结果则为剪枝后的分类模型的分类模型权重,其中,输入权重掩码又与剪枝率对应,也即通过预先训练好的剪枝敏感度分析模型能够得到预定剪枝率下的分类模型权重,基于该分类模型权重直接得到预定剪枝率下的分类模型。
最优的分类模型权重则通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到,即通过神经网络结构搜索方法能够获得剪枝敏感度分析模型中敏感度最优的分类模型权重,将最优的分类模型权重作为分类模型中的权重,从而得到剪枝后的分类模型。
上述分类结果包括轿车、卡车、货车、救护车、公交车、自行车、三轮车等。
本发明实施例提供的基于剪枝卷积神经网络的分类装置,通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到敏感度最优的权重,基于该敏感度最优的权重直接获得剪枝后的分类模型,无需在分类模型剪枝后进行微调训练,剪枝过程中的敏感度分析时间变为仅需训练一次的剪枝敏感度分析模型的,从而大大减少了剪枝耗时。
图5为本发明实施例提供的一种电子设备的实体结构示意图,如图5所示,该电子设备可以包括:处理器(processor)510、通信接口(Communications Interface)520、存储器(memory)530和通信总线540,其中,处理器510,通信接口520,存储器530通过通信总线540完成相互间的通信。处理器510可以调用存储器530中的逻辑指令,以执行基于剪枝卷积神经网络的分类方法,所述基于剪枝卷积神经网络的分类方法,包括:
获取待分类的图片;
将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果;
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
此外,上述的存储器530中的逻辑指令可以通过软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
另一方面,本发明还提供一种计算机程序产品,计算机程序产品包括计算机程序,计算机程序可存储在非暂态计算机可读存储介质上,计算机程序被处理器执行时,计算机能够执行上述各方法所提供的一种基于剪枝卷积神经网络的分类方法,所述基于剪枝卷积神经网络的分类方法,包括:
获取待分类的图片;
将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果;
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
又一方面,本发明还提供一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现以执行上述个方法所提供的基于剪枝卷积神经网络的分类方法,所述基于剪枝卷积神经网络的分类方法,包括:
获取待分类的图片;
将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果;
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
以上所描述的装置实施例仅仅是示意性的,其中作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性的劳动的情况下,即可以理解并实施。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到各实施方式可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件。基于这样的理解,上述技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品可以存储在计算机可读存储介质中,如ROM/RAM、磁碟、光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行各个实施例或者实施例的某些部分的方法。
最后应说明的是:以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (11)
1.一种基于剪枝卷积神经网络的分类方法,其特征在于,包括:
获取待分类的图片;
将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果;
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
2.根据权利要求1所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述方法还包括:
为初始分类模型中每个待分析网络层构建初始剪枝敏感度分析模型;
将权重掩码训练集输入所述初始剪枝敏感度分析模型,得到待分析网络层的初始权重;所述权重掩码训练集通过随机数生成方法生成,且所述权重掩码训练集中的每项权重掩码训练数据与剪枝率对应;
将训练图片集输入所述初始分类模型,基于每个待分析网络层的初始权重以及前向计算得到初始预测结果;
根据所述初始预测结果以及所述标签集,利用反向传播算法对所述每个待分析网络层的初始权重进行更新,得到待分析网络层的更新后权重;
根据所述待分析网络层的更新后权重,利用反向传播算法的链式法则,对所述初始剪枝敏感度分析模型中的权重进行更新,直到所述初始分类模型收敛,从而得到预先训练好的剪枝敏感度分析模型以及训练好的分类模型。
3.根据权利要求2所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述将权重掩码训练集输入所述初始剪枝敏感度分析模型,得到待分析网络层的初始权重,包括:
将所述权重掩码训练集中的权重掩码训练数据输入多个卷积网络,以生成特征图;以及
融合所述特征图与所述权重掩码训练数据,以生成所述待分析网络层的所述初始权重。
4.根据权利要求2所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述方法还包括:
步骤201,在预设的剪枝率搜索空间内穷举所述训练好的分类模型中每个网络层的剪枝率,得到剪枝率集合;在所述剪枝率集合中筛选出满足预设的剪枝目标的剪枝率,作为待分析剪枝率;其中,所述预设的剪枝目标为所述训练好的分类模型的剪枝目标;
步骤202,根据剪枝率与权重掩码的对应关系,确定所述待分析剪枝率所对应的权重掩码,作为待分析权重掩码;
步骤203,将所述待分析权重掩码输入所述预先训练好的剪枝敏感度分析模型中,得到所述训练好的分类模型对应的剪枝后权重;
步骤204,基于所述剪枝后权重对所述训练好的分类模型进行性能评价,得到性能评价指标值;
步骤205,重复所述步骤203至所述步骤204,直到穷尽所有待分析剪枝率,从而得到多个性能评价指标值;
步骤206,从所述多个性能评价指标值中确定出最大的性能评价指标值作为最优灵敏度,将所述最优灵敏度对应的剪枝后权重作为剪枝后的分类模型中的分类模型权重。
5.根据权利要求4所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述剪枝率与权重掩码的对应关系为权重掩码mask基于剪枝率p与权重维度信息得到,其中,所述权重维度信息包括通道数C:
mask[0:C*p]=0,mask[C*p:C]=1
其中,mask[0:C*p]=0为C个通道中前C*p个通道对应的权重掩码mask为0,mask[C*p:C]=1为C个通道中后C-C*p个通道对应的权重掩码mask为1。
6.根据权利要求4所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述在预设的剪枝率搜索空间内穷举所述训练好的分类模型中每个网络层的剪枝率,得到剪枝率集合,包括:
根据预先设定的剪枝率的取值范围以及剪枝率取值步长,在所述预先设定的剪枝率的取值范围内,穷举得到所有符合所述剪枝率取值步长的剪枝率,从而得到剪枝率集合。
7.根据权利要求4-6任一所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述预设的剪枝目标包括以下至少一项:所述训练好的分类模型的目标计算力;所述训练好的分类模型的目标参数量。
8.一种基于剪枝卷积神经网络的分类装置,其特征在于,包括:
图片获取模块,用于获取待分类的图片;
分类模块,用于将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果;
其中,所述剪枝后的分类模型中的分类模型权重是通过预定义的神经网络结构搜索方法对预先训练好的剪枝敏感度分析模型搜索得到;
所述预先训练好的剪枝敏感度分析模型是基于初始分类模型、训练图片集、所述训练图片集对应的标签集以及权重掩码训练集训练得到。
9.一种电子设备,包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1至7任一项所述基于剪枝卷积神经网络的分类方法。
10.一种非暂态计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述基于剪枝卷积神经网络的分类方法。
11.一种计算机程序产品,包括计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述基于剪枝卷积神经网络的分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210458105.6A CN114881136A (zh) | 2022-04-27 | 2022-04-27 | 基于剪枝卷积神经网络的分类方法及相关设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210458105.6A CN114881136A (zh) | 2022-04-27 | 2022-04-27 | 基于剪枝卷积神经网络的分类方法及相关设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114881136A true CN114881136A (zh) | 2022-08-09 |
Family
ID=82671193
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210458105.6A Pending CN114881136A (zh) | 2022-04-27 | 2022-04-27 | 基于剪枝卷积神经网络的分类方法及相关设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114881136A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117131920A (zh) * | 2023-10-26 | 2023-11-28 | 北京市智慧水务发展研究院 | 一种基于网络结构搜索的模型剪枝方法 |
-
2022
- 2022-04-27 CN CN202210458105.6A patent/CN114881136A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117131920A (zh) * | 2023-10-26 | 2023-11-28 | 北京市智慧水务发展研究院 | 一种基于网络结构搜索的模型剪枝方法 |
CN117131920B (zh) * | 2023-10-26 | 2024-01-30 | 北京市智慧水务发展研究院 | 一种基于网络结构搜索的模型剪枝方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109120462B (zh) | 机会网络链路的预测方法、装置及可读存储介质 | |
CN107562795A (zh) | 基于异构信息网络的推荐方法及装置 | |
CN110287332B (zh) | 云环境下仿真模型选择方法与装置 | |
CN112434188B (zh) | 一种异构数据库的数据集成方法、装置及存储介质 | |
CN114418035A (zh) | 决策树模型生成方法、基于决策树模型的数据推荐方法 | |
CN111243682A (zh) | 药物的毒性预测方法及装置、介质和设备 | |
CN114220458B (zh) | 基于阵列水听器的声音识别方法和装置 | |
CN116596095B (zh) | 基于机器学习的碳排放量预测模型的训练方法及装置 | |
WO2023124386A1 (zh) | 神经网络架构搜索的方法、装置、设备和存储介质 | |
CN111695824A (zh) | 风险尾端客户分析方法、装置、设备及计算机存储介质 | |
CN114881136A (zh) | 基于剪枝卷积神经网络的分类方法及相关设备 | |
CN116489038A (zh) | 网络流量的预测方法、装置、设备和介质 | |
CN113421264B (zh) | 轮毂质量检测方法、设备、介质及计算机程序产品 | |
US20240095529A1 (en) | Neural Network Optimization Method and Apparatus | |
CN108470251B (zh) | 基于平均互信息的社区划分质量评价方法及系统 | |
CN116223962B (zh) | 线束电磁兼容性预测方法、装置、设备及介质 | |
CN111105127B (zh) | 一种基于数据驱动的模块化产品设计评价方法 | |
CN111815209A (zh) | 应用于风控模型的数据降维方法及装置 | |
CN116451081A (zh) | 数据漂移的检测方法、装置、终端及存储介质 | |
CN116796821A (zh) | 面向3d目标检测算法的高效神经网络架构搜索方法及装置 | |
CN113393023B (zh) | 模具质量评估方法、装置、设备及存储介质 | |
CN114972950A (zh) | 多目标检测方法、装置、设备、介质及产品 | |
CN114860617A (zh) | 一种智能压力测试方法及系统 | |
CN110222842B (zh) | 一种网络模型训练方法、装置及存储介质 | |
CN113850523A (zh) | 基于数据补全的esg指数确定方法及相关产品 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |