CN116384471A - 模型剪枝方法、装置、计算机设备、存储介质和程序产品 - Google Patents
模型剪枝方法、装置、计算机设备、存储介质和程序产品 Download PDFInfo
- Publication number
- CN116384471A CN116384471A CN202310227819.0A CN202310227819A CN116384471A CN 116384471 A CN116384471 A CN 116384471A CN 202310227819 A CN202310227819 A CN 202310227819A CN 116384471 A CN116384471 A CN 116384471A
- Authority
- CN
- China
- Prior art keywords
- model
- pruning
- target
- initial
- source
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000013138 pruning Methods 0.000 title claims abstract description 295
- 238000000034 method Methods 0.000 title claims abstract description 112
- 238000004422 calculation algorithm Methods 0.000 claims abstract description 113
- 238000012549 training Methods 0.000 claims abstract description 54
- 238000012216 screening Methods 0.000 claims abstract description 14
- 238000012804 iterative process Methods 0.000 claims abstract description 6
- 230000008569 process Effects 0.000 claims description 54
- 238000012545 processing Methods 0.000 claims description 39
- 230000006835 compression Effects 0.000 claims description 22
- 238000007906 compression Methods 0.000 claims description 22
- 230000006870 function Effects 0.000 claims description 22
- 238000004821 distillation Methods 0.000 claims description 19
- 238000004590 computer program Methods 0.000 claims description 15
- 238000012821 model calculation Methods 0.000 claims description 14
- 238000010801 machine learning Methods 0.000 claims description 12
- 238000010586 diagram Methods 0.000 description 12
- 238000005516 engineering process Methods 0.000 description 7
- 238000003062 neural network model Methods 0.000 description 7
- 230000000694 effects Effects 0.000 description 6
- 238000004364 calculation method Methods 0.000 description 5
- 230000008859 change Effects 0.000 description 4
- 238000013145 classification model Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 238000001514 detection method Methods 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 2
- 229910052799 carbon Inorganic materials 0.000 description 2
- 230000000670 limiting effect Effects 0.000 description 2
- 230000002829 reductive effect Effects 0.000 description 2
- 230000003068 static effect Effects 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- OKTJSMMVPCPJKN-UHFFFAOYSA-N Carbon Chemical compound [C] OKTJSMMVPCPJKN-UHFFFAOYSA-N 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000003066 decision tree Methods 0.000 description 1
- 229910021389 graphene Inorganic materials 0.000 description 1
- 230000005484 gravity Effects 0.000 description 1
- 238000013140 knowledge distillation Methods 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000036961 partial effect Effects 0.000 description 1
- 230000004044 response Effects 0.000 description 1
- 230000002441 reversible effect Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Feedback Control In General (AREA)
Abstract
本申请涉及一种模型剪枝方法、装置、计算机设备、存储介质和程序产品。所述方法包括:获取待剪枝处理的源模型对应的初始模型集合,所述初始模型集合中的各初始模型是利用各初始剪枝算法分别对所述源模型进行剪枝处理得到的;将所述初始模型集合作为当前模型集合,进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用所述参考模型训练所述目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合;基于目标模型集合和筛选条件,确定源模型对应的剪枝处理后的模型。采用本方法能够在提升剪枝效率的同时保证剪枝后的模型的性能。
Description
技术领域
本申请涉及模型剪枝技术领域,特别是涉及一种模型剪枝方法、装置、计算机设备、存储介质和程序产品。
背景技术
模型剪枝(Model Pruning)技术可以减少神经网络模型(以下简称为源模型)的参数量,在模型的轻量化部署中有着广泛应用。模型剪枝的主要流程是:按照压缩比要求为源模型中需要剪枝的权重层分配压缩比,再衡量每个权重层中不同维度的权重的重要性,删除重要性较低的权重,保留重要性较高的权重。
目前,已经存在多种剪枝算法,若需要为某个源模型进行模型剪枝,相关技术中,通常分别训练各个剪枝算法,训练完成后再从中选择最优解作为剪枝模型对源模型进行模型剪枝。
然而,在具体的业务场景中,不仅要考虑剪枝处理后的模型的性能,还需要兼顾剪枝效率,上述模型剪枝方式存在剪枝效率低的问题。
发明内容
本申请实施例提供了一种模型剪枝方法、装置、计算机设备、存储介质和程序产品,可以在提升剪枝效率的同时保证剪枝后的模型的性能。
第一方面,提供了一种模型剪枝方法,该方法包括:
获取待剪枝处理的源模型对应的初始模型集合,所述初始模型集合中的各初始模型是利用各初始剪枝算法分别对所述源模型进行剪枝处理得到的;
将所述初始模型集合作为当前模型集合,进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用所述参考模型训练所述目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合;
基于所述目标模型集合和筛选条件,确定所述源模型对应的剪枝处理后的模型。
第二方面,提供了一种模型剪枝装置,该装置包括:
获取模块,用于获取待剪枝处理的源模型对应的初始模型集合,所述初始模型集合中的各初始模型是利用各初始剪枝算法分别对所述源模型进行剪枝处理得到的;
迭代模块,用于将所述初始模型集合作为当前模型集合,进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用所述参考模型训练所述目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合;
确定模块,用于基于所述目标模型集合和筛选条件,确定所述源模型对应的剪枝处理后的模型。
第三方面,提供了一种计算机设备,包括存储器及处理器,所述存储器中储存有计算机程序,所述计算机程序被所述处理器执行时,使得所述处理器执行如上述第一方面所述的方法的步骤。
第四方面,提供了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现如上述第一方面所述的方法的步骤。
第五方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现如上述第一方面所述的方法的步骤。
本申请实施例提供的技术方案带来的有益效果至少包括:
通过获取待剪枝处理的源模型对应的初始模型集合,该初始模型集合中的各初始模型是利用各初始剪枝算法分别对源模型进行剪枝处理得到的,而后,将初始模型集合作为当前模型集合进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用参考模型训练目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合,再基于目标模型集合和筛选条件,确定源模型对应的剪枝处理后的模型,该筛选条件例如可以是从目标模型集合中选择出性能最好的剪枝模型作为源模型对应的剪枝处理后的模型,该筛选条件例如还可以是从目标模型集合中选择出性能最好的剪枝模型后再对其进行重训练处理,从而得到源模型对应的剪枝处理后的模型,该筛选条件例如还可以是从目标模型集合中选择出性能最好的多个剪枝模型再分别对其进行重训练处理,再从重训练结果中选择性能最好的模型作为源模型对应的剪枝处理后的模型,等等,这样,本申请实施例通过迭代学习不断更新优化当前模型集合,迭代完成后再基于上述筛选条件从中确定性能最好的模型作为源模型对应的剪枝处理后的模型,从而确保该剪枝处理后的模型的性能;另外,相较于传统技术中利用训练样本分别训练各个参数初始化的剪枝算法的方式而言,传统技术中每个参数初始化的剪枝算法在训练过程中需要循环迭代多次,其数据处理量庞大,训练效率低下,导致剪枝效率低,而本申请实施例中,直接利用参数初始化的各初始剪枝算法分别对源模型进行剪枝处理,不必分别训练各个参数初始化的剪枝算法,且本申请实施例利用参考模型训练目标模型来对当前模型集合进行迭代更新的方式,由于迭代次数少等原因(迭代次数例如可以和初始模型集合中各初始模型的个数相同),数据处理量也远远小于利用训练样本分别训练各个参数初始化的剪枝算法的数据处理量,从而本申请实施例也提升了剪枝效率。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为一个实施例中剪枝算法的分类示意图;
图2为一个实施例中模型剪枝方法的流程图;
图3为一个实施例中利用参考模型训练目标模型的流程图;
图4为一个实施例中从当前模型集合中确定目标模型和参考模型的流程图;
图5为一个实施例中步骤201的流程图;
图6为一个实施例中一种示例性地源模型的目标卷积层和目标源模型的目标卷积层的卷积核数量对比示意图;
图7为一个实施例中从目标模型集合中确定目标剪枝模型的流程图;
图8为一个实施例中一种示例性地模型剪枝方法的整体框架示意图;
图9为一个实施例中一种示例性地利用各初始剪枝算法分别对源模型进行剪枝处理的示意图;
图10为一个实施例中模型剪枝装置的结构框图;
图11为一个实施例中计算机设备的内部结构图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
模型剪枝(Model Pruning)作为神经网络模型的压缩和优化的一项技术,由于其可以减少神经网络模型的参数量,因此在神经网络模型的轻量化部署和实际落地中有着广泛应用
模型剪枝的主要依据是训练完成的神经网络模型(为了描述方便,以下简称为源模型)存在一定的冗余度,可以在保持性能的前提下去掉部分权重。模型剪枝的主要流程是:按照压缩任务的压缩比要求,根据一定的分配算法为源模型中需要剪枝的所有权重层(如卷积层Conv、全连接层Dense等)分配压缩比,再根据一定的度量标准衡量每个权重层中不同维度的权重的重要性,删除重要性较低的权重,保留重要性较高的权重,从而使源模型成为一个参数量更少的新模型,随后通过一定的方法对新模型的权重进行重新训练微调,使其性能表现逼近甚至优于源模型。
目前,模型剪枝领域已经进化出了大量的剪枝算法,以衡量每个权重层中不同维度的权重的重要性为例,有的剪枝算法基于权重层的不同范数(L1范数,L2范数等)选取重要通道,有的剪枝算法基于几何中心数选取重要通道,有的剪枝算法基于BN层的训练参数选取对应权重层的通道,有的剪枝算法基于权重层输出特征张量中的均值或统计零值等选取通道,还有的剪枝算法根据权重层反向传导梯度的泰勒展开估算值选取重要的通道,等等。
以2D卷积操作(其对应的卷积核为4维张量)为例,参见图1,剪枝算法从维度上可以分为4种:
1)细粒度剪枝(Fine-grained),其是从单个元素上剪枝,利用0-D表示,即对连接或者神经元进行剪枝,它是粒度最小的剪枝。
2)向量剪枝(Vector-level),是减去整列的剪枝方式,利用1-D表示,它相对于细粒度剪枝粒度更大,属于对卷积核内部的剪枝。
3)核剪枝(Kernel-level),是减去整个卷积核的剪枝方式,利用2-D表示,即去除某个卷积核,它将丢弃对输入通道中对应计算通道的响应。
4)滤波器剪枝(Filter-level),是对整个卷积核组进行剪枝,利用3-D表示,会造成推理过程中输出特征通道数的改变。
上述4种剪枝算法的结构化程度依次提高,越结构化的剪枝方案对神经网络模型的部署越友好。
在实际任务中,各个剪枝算法效果各异,针对某个任务训练的剪枝算法很难在其他任务上表现优异,况且多数剪枝算法本身有超参数,参数大小的变化调整也会导致不同的效果也影响其性能。若需要为某个源模型进行模型剪枝,只能分别训练各个参数初始化的剪枝算法,训练完成后再从中选择最优解作为剪枝模型对源模型进行模型剪枝。但是,这种方式中,每个参数初始化的剪枝算法在训练过程中需要循环迭代多次,其数据处理量庞大,训练效率低下,导致剪枝效率低。
而且,在具体的业务场景中,不仅要考虑兼顾剪枝效率,还要兼顾剪枝处理后的模型的性能。
为解决上述问题,本申请实施例提供一种模型剪枝方法,能够在提升剪枝效率的同时保证剪枝后的模型的性能。以下,结合本申请实施例模型剪枝方法所应用的应用环境,对本申请实施例的实施过程进行介绍。
本申请实施例提供的模型剪枝方法,可以应用于计算机设备中,该计算机设备可以是服务器,服务器可以用独立的服务器或者是多个服务器组成的服务器集群来实现。
在一个实施例中,如图2所示,提供了一种模型剪枝方法,包括以下步骤:
步骤201,获取待剪枝处理的源模型对应的初始模型集合。
源模型可以是任意需要进行模型剪枝的神经网络模型,例如,源模型可以是分类任务中的分类模型、目标检测任务中的目标检测模型,等等。计算机设备获取该源模型对应的初始模型集合,该初始模型集合中包括多个初始模型。以下,对各初始模型的概念以及计算机设备获取各初始模型的方式进行示例性地介绍。
本申请实施例中,该初始模型集合中的各初始模型是利用各初始剪枝算法分别对源模型进行剪枝处理得到的,各初始剪枝算法是指各参数初始化的剪枝算法,例如,以决策树中常见的剪枝算法为例,各初始剪枝算法为参数初始化的错误率降低剪枝算法(Reduced-Error Pruning,REP)、参数初始化的悲观错误剪枝算法(Pesimistic-ErrorPruning,PEP)、参数初始化的代价复杂度剪枝算法(Cost-Complexity Pruning,CCP),等等。参数初始化可以理解为对剪枝算法的超参数赋以随机值。
在步骤201一种可能的实施方式中,计算机设备获取该源模型,并选定多个剪枝算法,计算机设备将选定的各剪枝算法的超参数设置为随机值,从而得到多个初始剪枝算法,然后,对于每个初始剪枝算法,计算机设备利用该初始剪枝算法对源模型进行剪枝处理,得到该初始剪枝算法对应的初始模型,各个初始剪枝算法对应的初始模型则组成初始模型集合。
在步骤201另一种可能的实施方式中,上述进行剪枝处理得到各初始模型的过程也可以在其他设备中实现,本申请实施例计算机设备从该其他设备中获取最终的初始模型集合。在此对初始模型集合的具体获取方式不做限定。
本申请实施例中,利用各初始剪枝算法分别对源模型进行剪枝处理的方式,可以是减去整个卷积核的方式(即核剪枝的方式),还可以是对整个卷积核组进行剪枝的方式(即滤波器剪枝的方式),相较于细粒度剪枝和向量剪枝而言,核剪枝和滤波器剪枝的结构化程度更高,有利于最终得到的源模型对应的剪枝处理后的模型的部署。
步骤202,将初始模型集合作为当前模型集合,进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用参考模型训练目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合。
计算机设备将初始模型集合作为当前模型集合,是指计算机设备将初始模型集合作为第一次迭代过程中的当前模型集合。
对于第一次迭代过程,计算机设备从初始模型集合中确定目标模型和参考模型,然后,利用该参考模型训练目标模型得到训练后的目标模型,计算机设备再利用该训练后的目标模型更新初始模型集合,得到第一次迭代过程对应的更新后的模型集合。
对于第二次迭代过程,计算机设备将上一次迭代过程(即第一次迭代过程)对应的更新后的模型集合作为第二次迭代过程的当前模型集合,然后,计算机设备从该当前模型集合中确定目标模型和参考模型,并继续与上一次迭代过程同样的训练和更新过程,得到第二次迭代过程对应的更新后的模型集合。
同理,后续迭代过程以此类推,计算机设备最终得到最后一次迭代过程对应的更新后的模型集合,即得到目标模型集合。
上述计算机设备利用参考模型训练目标模型的过程,可以是基于蒸馏学习算法进行训练,也可以是基于迁移学习算法进行训练,当然还可以是其他具有类似效果(即以参考模型指导目标模型训练的效果)的机器学习算法。
以蒸馏学习算法为例,计算机设备可以将参考模型作为教师网络,将目标模型作为学生网路,利用教师网络指导学生网络的输出,在训练过程中,可以将参考模型的输出和真实标签一并作为目标模型输出的标签进行训练,这样,通过参考模型的输出和真实标签的共同监督,使得训练后的目标模型的模型性能优于训练之前的目标模型的性能。
以下,对计算机设备利用训练后的目标模型更新当前模型集合的过程进行示例性地说明。
在一种可能的实施方式中,计算机设备可以在当前模型集合中确定候选模型,该候选模型的性能指标值小于训练后的目标模型的性能指标值,然后在当前模型集合中利用训练后的目标模型替换候选模型,也即,计算机设备从当前模型集合选出模型性能比该训练后的目标模型的模型性能差的候选模型,该候选模型的数量可以是一个或多个,然后利用该训练后的目标模型替换该一个或多个候选模型。
在另一种可能的实施方式中,计算机设备利用训练后的目标模型更新当前模型集合,还可以是从当前模型集合选出模型性能最差的候选模型,然后利用该训练后的目标模型替换该模型性能最差的候选模型。
可以看出,本申请实施例通过不断淘汰模型性能差的候选模型,不断迭代更新当前模型集合,使得当前模型集合中模型的模型性能越来越优异。而且,随着当前模型集合中模型的模型性能越来越优异,计算机设备所选择的参考模型的模型性能也越来越优异,进一步提升对目标模型的监督效果,提升训练后的目标模型的模型性能。
另外,本申请实施例迭代学习的迭代次数可以与初始模型集合中的各初始模型的个数相同,当然,迭代学习的迭代次数在实际实施时也可以自行设置。相较于传统技术中利用训练样本分别训练各个参数初始化的剪枝算法,每个剪枝算法都需要重复迭代循环多次,导致数据处理量巨大,本申请实施例由于迭代次数少,且迭代学习过程简单,从而本申请实施例迭代学习的数据处理量也远远小于传统技术中上述训练方式的数据处理量,有利于提升本申请实施例的剪枝效率。
步骤203,基于目标模型集合和筛选条件,确定源模型对应的剪枝处理后的模型。
如上文所述,目标模型集合是最后一次迭代过程对应的更新后的模型集合,相较于初始模型集合而言,目标模型集合中的部分或者全部剪枝模型都是利用相应的参考模型训练后的性能优异的模型。
在目标模型集合的基础上,计算机设备基于筛选条件,确定源模型对应的剪枝处理后的模型。
在一种可能的实施方式中,计算机设备可以从目标模型集合中选择出性能最好的剪枝模型,并将该剪枝模型作为源模型对应的剪枝处理后的模型。
在另一种可能的实施方式中,计算机设备还可以从目标模型集合中确定性能最好的剪枝模型后,再对其进行重训练处理,并将重训练处理后的模型作为源模型对应的剪枝处理后的模型。
在其他可能的实施方式中,计算机设备还可以从目标模型集合中选择出性能最好的多个剪枝模型,再分别对其进行重训练处理,最后再从重训练结果中选择性能最好的模型作为源模型对应的剪枝处理后的模型。
上述对剪枝模型进行重训练处理,可以是基于蒸馏学习算法,将源模型作为教师网络,将剪枝模型作为学生网络,利用源模型指导剪枝模型的输出,从而通过重训练处理可以进一步调整剪枝模型的模型参数,使其性能逼近源模型,甚至超越源模型。
整体而言,上述实施例通过获取待剪枝处理的源模型对应的初始模型集合,该初始模型集合中的各初始模型是利用各初始剪枝算法分别对源模型进行剪枝处理得到的,而后,将初始模型集合作为当前模型集合进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用参考模型训练目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合,再基于目标模型集合和筛选条件,确定源模型对应的剪枝处理后的模型,该筛选条件例如可以是从目标模型集合中选择出性能最好的剪枝模型作为源模型对应的剪枝处理后的模型,该筛选条件例如还可以是从目标模型集合中选择出性能最好的剪枝模型后再对其进行重训练处理,从而得到源模型对应的剪枝处理后的模型,该筛选条件例如还可以是从目标模型集合中选择出性能最好的多个剪枝模型再分别对其进行重训练处理,再从重训练结果中选择性能最好的模型作为源模型对应的剪枝处理后的模型,等等,这样,本申请实施例通过迭代学习不断更新优化当前模型集合,迭代完成后再基于上述筛选条件从中确定性能最好的模型作为源模型对应的剪枝处理后的模型,从而确保该剪枝处理后的模型的性能;另外,相较于传统技术中利用训练样本分别训练各个参数初始化的剪枝算法的方式而言,传统技术中每个参数初始化的剪枝算法在训练过程中需要循环迭代多次,其数据处理量庞大,训练效率低下,导致剪枝效率低,而本申请实施例中,直接利用参数初始化的各初始剪枝算法分别对源模型进行剪枝处理,不必分别训练各个参数初始化的剪枝算法,且本申请实施例利用参考模型训练目标模型来对当前模型集合进行迭代更新的方式,由于迭代次数少等原因(迭代次数例如可以和初始模型集合中各初始模型的个数相同),数据处理量也远远小于利用训练样本分别训练各个参数初始化的剪枝算法的数据处理量,从而本申请实施例也提升了剪枝效率。
在一个实施例中,基于图2所示的实施例,参见图3,本实施例涉及的是计算机设备如何利用参考模型训练目标模型的过程。如图3所示,该过程包括步骤301和步骤302:
步骤301,获取参考模型针对目标输入样本的参考输出结果,并获取目标模型针对目标输入样本的目标输出结果。
目标输入样本即为符合参考模型和目标模型的输入要求的样本,例如,参考模型和目标模型均为图像分类模型,则目标输入样本为任意需要分类的图像样本。
计算机设备将目标输入样本分别输入至参考模型和目标模型中,得到参考模型输出的参考输出结果,以及目标模型输出的目标输出结果。
步骤302,根据参考输出结果、目标输出结果和预设机器学习算法训练目标模型。
以预设机器学习算法为蒸馏学习算法为例,对步骤302的实施过程进行示例性地介绍。
在一种可能的实施方式中,步骤302可以包括如下步骤A1和步骤A2:
步骤A1,根据参考输出结果、目标输出结果和预设机器学习算法对应的损失函数计算损失值。
本申请实施例中,利用选出的参考模型对目标模型做知识蒸馏,优化目标模型的快速训练。蒸馏学习算法的损失函数使用交叉熵(CE,Cross Entropy)和KL(Kullback-Leibler)散度的组合在BP(Back Propagation,反向传播)过程中求梯度,蒸馏学习算法的损失函数的表达式如下:
其中,x,y分别表示训练集中的输入样本(即上述目标输入样本)和真实标签,p*表示参考模型,p表示目标函数,wp*表示p*的模型参数(或者称之为权重),wp表示p的模型参数,p*(x,wp*)表示p*针对输入x的输出(即上述参考输出结果),p(x,wp)表示p针对输入x的输出(即上述目标输出结果),LCE(y,p(x,wp))为损失函数的基础项,表征p的输出与真实标签之间的交叉熵,LKD(p*(x,wp*),p(x,wp))是蒸馏学习的附加项,表征p的输出与p*的输出之间的KL散度,η表示学习率,表示基于wp的梯度。
步骤A2,根据损失值调整目标模型的模型参数。
这样,将参考输出结果、目标输出结果输入至上述损失函数则可计算得到损失值,再根据该损失值反向传播调整目标模型p的模型参数。
上述实施例在蒸馏学习的过程中,通过真实标签y和参考输出结果共同监督目标模型的训练,使得训练后的目标模型的模型性能优于训练之前的目标模型的性能。
在另一种可能的实施方式中,步骤302可以包括如下步骤B1、步骤B2和步骤A2,即上述步骤A1可以通过如下步骤B1和步骤B2实现:
步骤B1,获取参考模型和目标模型之间的匹配度。
步骤B2,将参考输出结果、目标输出结果和匹配度输入至损失函数中,得到损失值。
步骤A2,根据损失值调整目标模型的模型参数。
参考模型和目标模型之间的匹配度,计算机设备可以通过匹配网络进行计算,匹配网络作为单独训练的网络,可以由简单的双层全连接网络构成。匹配网络计算的表达式如下所示:
ρ(p,p*)=Meta((p(x,wp)-p*(x,wp*)),θ)
其中,ρ(p,p*)为参考模型和目标模型之间的匹配度,θ为匹配网络中取值固定的参数,其余字母含义参见上文。
通过该匹配网络计算的表达式可以看出,计算机设备将上述参考输出结果和目标输出结果之间的差值输入至匹配网络,得到匹配网络输出的参考模型和目标模型之间的匹配度,该匹配度是归一化到(0,1)的标量。
计算机设备根据参考输出结果、目标输出结果和匹配度计算损失值,并根据损失值调整目标模型的模型参数。目标模型的模型参数wp更新的表达式如下
即基于上述蒸馏学习算法的损失函数的表达式,计算机设备将该匹配度作为该损失函数中蒸馏学习的附加项LKD(p*(x,wp*),p(x,wp))的调节系数,计算损失值,并根据该损失值反向传播调整目标模型p的模型参数wp。
本申请实施例中,匹配度的大小与损失值的大小负相关,即参考模型和目标模型之间的匹配程度越高,LKD(p*(x,wp*),p(x,wp))函数值在损失值中的比重越大,损失值越小,目标模型越快收敛,参考模型的指导意义越大,以得到模型性能优异的训练后的目标模型,从而得到模型性能优异的目标模型集合。
然后,计算机设备基于该目标模型集合和筛选条件,确定源模型对应的剪枝处理后的模型,这样,本申请实施例模型剪枝方法综合了现有的剪枝算法,结合蒸馏学习算法来检索最优的剪枝模型,蒸馏学习算法作为模型剪枝中的一环,既有利于为源模型选择合适的剪枝算法及超参数,同时也在参数更新过程中使用优异的参考模型辅助现有模型(即目标模型),通过蒸馏学习的方式提高训练后的目标模型。
在一个实施例中,基于图2所示的实施例,参见图4,本实施例涉及的是计算机设备如何从当前模型集合中确定目标模型和参考模型的过程。如图4所示,该过程包括步骤401、步骤402和步骤403:
步骤401,从当前模型集合中随机选取出目标模型。
如上文所述,对于第N次迭代过程,计算机设备将第N-1次迭代过程对应的更新后的模型集合,作为第N次迭代过程的当前模型集合。然后,计算机设备从当前模型集合中随机选取一个模型作为当前参与迭代学习的目标模型。
步骤402,对于当前模型集合中的每个其他模型,利用匹配网络计算目标模型和其他模型之间的匹配度。
计算机设备利用匹配网络计算目标模型和其他模型之间的匹配度的过程,可以参见上述步骤B1中获取匹配度的过程,在此不再赘述。
步骤403,若匹配度大于预设匹配度阈值,则确定其他模型为参考模型。
该预设匹配度阈值的设定,可以使得参考模型对应的匹配度是各其他模型对应的匹配度中的最大匹配度,也即,从当前模型集合中的各个其他模型中选择与目标模型匹配度最高的模型作为参考模型。
这样,计算机设备确定出目标模型和参考模型后,将参考模型作为后续训练目标模型的教师网络,指导目标模型的训练,具体训练过程参见上文所述的实施例,在此不再赘述。
基于图4所示的实施例,本申请实施例模型剪枝方法还包括:根据目标模型的输出和参考模型的输出,对匹配网络的网络参数进行更新。
以下,对匹配网络的网络参数的更新过程进行介绍。
本申请实施例中,利用上述蒸馏学习算法的损失函数来指导匹配网络的参数更新,匹配网络的损失函数Lmeta的表达式如下:
其中,ρ(θ)为基于θ相关的表达式,其余字母含义参见上文描述。
匹配网络的损失函数隐含的意义在于,如果匹配网络选出的参考模型对于目标模型的指导意义更大,那么目标模型更新后在验证集上所取得的损失值就更小,因此,匹配网络的损失函数的目标在于引导匹配网络为目标网络匹配出最为匹配的参考模型。这样,通过目标模型的输出和参考模型的输出即可计算得到R值,进一步得到匹配网络对应的损失值,计算机设备基于该损失值反向传播调整匹配网络的网络参数。
本申请实施例中,对匹配网络的网络参数进行更新的更新频率大于或者等于迭代学习的迭代频率,这是由于计算机设备基于该损失值反向传播调整匹配网络的网络参数时,反向推导的过程涉及求解梯度的梯度,矩阵展开比较复杂且耗时,因此,可以设定匹配网络每隔一定周期做一次参数更新,每隔一定周期是指每隔若干个迭代学习过程,而非每个迭代学习过程都更新匹配网络的网络参数,从而提升了匹配网络的网络参数的更新速率,有利于提升本申请实施例剪枝效率。
在一个实施例中,基于图2所示的实施例,参见图5,本实施例涉及的是计算机设备如何获取待剪枝处理的源模型对应的初始模型集合的过程。如图5所示,步骤201包括:
步骤2011,将各初始剪枝算法的超参数设置为随机值。
计算机设备选定多个基础的剪枝算法,剪枝算法如上文所述的REP、PEP、CCP等,然后,计算机设备将各剪枝算法的超参数设置为随机值。
步骤2012,对于每个初始剪枝算法,利用初始剪枝算法和目标压缩比对源模型进行剪枝处理得到初始模型,以得到初始模型集合。
目标压缩比可以是源模型对应当前压缩任务所需要的压缩比,例如为50%,即表征需要对源模型压缩50%的参数量。
然后,对于每个初始剪枝算法,计算机设备利用该初始剪枝算法和目标压缩比对源模型进行剪枝处理,得到该初始剪枝算法对应的初始模型。
请继续参见图5,步骤201还包括:
步骤2013,对于每个初始模型,检测初始模型的模型参数量和模型算力是否满足预设条件。
模型参数量是指模型参数的数量,模型算力是指模型使用时点乘的次数,本申请实施例中,为了避免利用初始剪枝算法和目标压缩比对源模型进行剪枝处理得到的初始模型不符合压缩任务的要求,则设置预设条件对初始模型的模型参数量和模型算力进行检测,该预设条件可以是初始模型的模型参数量小于参数量限制阈值,该初始模型的模型算力小于算力限制阈值。
步骤2014,若初始模型的模型参数量和模型算力满足预设条件,则将初始模型添加至初始模型集合中。
只有在初始模型的模型参数量和模型算力满足预设条件时,才将其添加在初始模型集合中,换言之,初始模型集合中的初始模型的模型参数量和模型算力都满足上述预设条件,从而提升了各初始模型与源模型对应的压缩业务的匹配度,提升了模型剪枝的效果。
在一种可能的实施方式中,步骤2012之前,本申请实施例模型剪枝方法还包括:对源模型的目标卷积层进行卷积核增加处理,得到目标源模型。
示例性地,参见图6,图6为一种示例性地源模型的目标卷积层和目标源模型的目标卷积层的卷积核数量对比示意图。
目标卷积层的层序号小于预设层序号阈值,即目标卷积层为浅层卷积层,以目标卷积层是二维卷积层为例,源模型的目标卷积层的卷积核维度假设为(H,W,Cin,Cout),针对输出通道数Cout增加n个卷积核,则目标源模型的目标卷积层的卷积核维度为(H,W,Cin,Cout+n)。
相应地,计算机设备利用初始剪枝算法和目标压缩比对目标源模型进行剪枝处理,实现利用初始剪枝算法和目标压缩比对源模型进行剪枝处理得到初始模型的过程。
这样,通过对浅层卷积层做通道扩充,一方面能够提高其原始性能,从而间接提高剪枝处理后的模型的性能,另一方面利用初始剪枝算法对其进行剪枝处理时,使得剪枝路径的检索空间扩大,增加可选择的余地。
在一个实施例中,基于图2所示的实施例,参见图7,本实施例涉及的是计算机设备如何基于目标模型集合和筛选条件,确定源模型对应的剪枝处理后的模型的过程。如图7所示,该过程包括步骤701和步骤702:
步骤701,获取目标模型集合中各剪枝模型的性能指标值。
性能指标值是剪枝模型的模型性能的量化值,例如,剪枝模型为分类模型,则剪枝模型的性能指标值可以是剪枝模型的分类化准确性。
本申请实施例中,计算机设备可以基于预设的验证集,逐一测试各个剪枝模型的性能,得到各剪枝模型的性能指标值。
步骤702,根据各剪枝模型的性能指标值,确定剪枝处理后的模型。
在步骤702一种可能的实施方式中,计算机设备可以根据各剪枝模型的性能指标值,从目标模型集合中选择出性能最好的剪枝模型作为源模型对应的剪枝处理后的模型。
在步骤702另一种可能的实施方式中,计算机设备可以从各剪枝模型中确定性能指标值最大的剪枝模型,即从目标模型集合中确定性能最好的剪枝模型,然后对该性能指标值最大的剪枝模型进行重训练处理,即进一步训练直至收敛,得到源模型对应的剪枝处理后的模型。
在步骤702另一种可能的实施方式中,计算机设备还可以按照性能指标值由大到小的顺序,从各剪枝模型中确定多个候选剪枝模型,如选择top n(n为大于1的整数)个性能最优的剪枝模型(即多个候选剪枝模型),再对各候选剪枝模型分别进行重训练处理,收敛后得到多个重训练剪枝模型,最后再将各重训练剪枝模型中性能指标值最大(即性能最好)的重训练剪枝模型作为源模型对应的剪枝处理后的模型。
这样,本申请实施例在模型的重训练阶段,直接选择性能最好一个或者多个剪枝模型进行重训练处理,减少了重训练的剪枝模型的数量,从而减少了训练过程中的数据处理量,提升了模型剪枝的效率。
参见图8,图8为一种示例性地模型剪枝方法的整体框架示意图,图9为一种示例性地利用各初始剪枝算法分别对源模型进行剪枝处理的示意图。以下,结合图8和图9,对本申请实施例模型剪枝方法的整体过程进行简要介绍。该模型剪枝方法包括:
1)计算机设备对源模型的目标卷积层进行卷积核增加处理,得到目标源模型,目标卷积层的层序号小于预设层序号阈值。
2)计算机设备将各初始剪枝算法的超参数设置为随机值。
3)计算机设备对于每个初始剪枝算法,利用初始剪枝算法和目标压缩比对目标源模型进行剪枝处理得到初始模型。
请参见图8和图9,利用Ai表示初始剪枝算法,i=1,2,...,N,N为各个初始剪枝法的个数,ai表示对应初始剪枝算法的超参数,ai具有一定的检索空间(即ai的值可以调整)。
计算机设备每次随机选择初始剪枝算法Ai,并且将Ai的超参数设置为随机值,然后利用该初始剪枝算法Ai对目标源模型M(V)进行剪枝处理得到一个初始模型。
4)对于每个初始模型,计算机设备检测初始模型的模型参数量和模型算力是否满足预设条件。
用C表示选取初始模型的限制条件,Count_params()表示计算初始模型的模型参数量的函数,Cparam表示参数量限制阈值,Count_flops()表示计算初始模型的模型算力的函数,Cflops表示算力限制阈值。
C={Cparam,Cflops},预设条件即为:
Count_params(p)<Cparam,且Count_flops(p)<Cflops
5)若初始模型的模型参数量和模型算力满足预设条件,计算机设备则将初始模型添加至初始模型集合中。
6)计算机设备将初始模型集合作为当前模型集合,进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用参考模型训练目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合。
在第一次迭代过程时,图8所示的子模型池即为初始模型集合,随着迭代学习的不断进行,子模型池也在不断更新,迭代完成后,图8所示的子模型池即为目标模型集合。
其中,从当前模型集合中确定目标模型和参考模型,包括:计算机设备从当前模型集合中随机选取出目标模型;对于当前模型集合中的每个其他模型,利用匹配网络计算目标模型和其他模型之间的匹配度;若匹配度大于预设匹配度阈值,则确定其他模型为参考模型。
可选地,计算机设备还可以根据目标模型的输出和参考模型的输出,对匹配网络的网络参数进行更新,且对匹配网络的网络参数进行更新的更新频率大于或者等于迭代学习的迭代频率。
其中,利用参考模型训练目标模型,包括:计算机设备获取参考模型针对目标输入样本的参考输出结果,并获取目标模型针对目标输入样本的目标输出结果;获取参考模型和目标模型之间的匹配度;将参考输出结果、目标输出结果和匹配度输入至蒸馏学习算法对应的损失函数中,得到损失值,匹配度的大小与损失值的大小负相关;根据损失值调整目标模型的模型参数。
其中,计算机设备利用训练后的目标模型更新当前模型集合,包括:在当前模型集合中确定候选模型,候选模型的性能指标值小于训练后的目标模型的性能指标值;在当前模型集合中利用训练后的目标模型替换候选模型。
7)计算机设备获取剪枝模型集合中各剪枝模型的性能指标值;
8)计算机设备从各剪枝模型中确定性能指标值最大的剪枝模型,对性能指标值最大的剪枝模型进行重训练处理,得到源模型对应的剪枝处理后的模型;或者,
计算机设备按照性能指标值由大到小的顺序,从各剪枝模型中确定多个候选剪枝模型;对各候选剪枝模型分别进行重训练处理,得到多个重训练剪枝模型;将各重训练剪枝模型中性能指标值最大的重训练剪枝模型作为源模型对应的剪枝处理后的模型。
请继续参见图8,整体而言,本申请实施例的子模型池用于存储表现优秀的剪枝处理后的模型,M(V)表示源模型(或者表示目标源模型),通过初始剪枝算法根据M(V)得到模型p,检验p是否符合算力和参数量限制,若符合则添加在子模型池,Meta(θ)作为匹配网络,从子模型池中选择与p匹配的p*,对p做快速训练,训练过程中p*用以蒸馏辅助,随后利用训练后的p对子模型池进行替换更新,对匹配网络的网络参数也保持定期更新,经过一定的迭代周期后,从子模型池中选择表现最好的剪枝模型作为最终输出。
本申请实施例整合了多种基础剪枝算法(即各初始剪枝算法),将剪枝的调优过程分为检索和重训练两部分,在检索过程中将剪枝算法以及参数选择和模型训练相结合,动态选择最优模型。
本申请实施例建立的子模型池,在模型检索和调优阶段,可以提供匹配的p*用于对p的蒸馏学习,而且保存了检索过程中的优异模型,在模型的重训练阶段,可以直接选择性能最优异的模型作为最终的输出,避免了重新检索和定位。
应该理解的是,虽然如上所述的各实施例所涉及的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,如上所述的各实施例所涉及的流程图中的至少一部分步骤可以包括多个步骤或者多个阶段,这些步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤中的步骤或者阶段的至少一部分轮流或者交替地执行。
基于同样的发明构思,本申请实施例还提供了一种用于实现上述所涉及的模型剪枝方法的模型剪枝装置。该装置所提供的解决问题的实现方案与上述方法中所记载的实现方案相似,故下面所提供的一个或多个模型剪枝装置实施例中的具体限定可以参见上文中对于模型剪枝方法的限定,在此不再赘述。
在一个实施例中,如图10所示,提供了一种模型剪枝装置,包括:
获取模块1001,用于获取待剪枝处理的源模型对应的初始模型集合,所述初始模型集合中的各初始模型是利用各初始剪枝算法分别对所述源模型进行剪枝处理得到的;
迭代模块1002,用于将所述初始模型集合作为当前模型集合,进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用所述参考模型训练所述目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合;
确定模块1003,用于基于所述目标模型集合和筛选条件,确定所述源模型对应的剪枝处理后的模型。
在一个实施例中,所述迭代模块1002,包括:
第一获取单元,用于获取所述参考模型针对目标输入样本的参考输出结果,并获取所述目标模型针对所述目标输入样本的目标输出结果;
训练单元,用于根据所述参考输出结果、所述目标输出结果和预设机器学习算法训练所述目标模型。
在一个实施例中,所述训练单元具体用于根据所述参考输出结果、所述目标输出结果和所述预设机器学习算法对应的损失函数计算损失值;根据所述损失值调整所述目标模型的模型参数。
在一个实施例中,所述训练单元具体用于获取所述参考模型和所述目标模型之间的匹配度;将所述参考输出结果、所述目标输出结果和所述匹配度输入至所述损失函数中,得到所述损失值,所述匹配度的大小与所述损失值的大小负相关。
在一个实施例中,所述预设机器学习算法为蒸馏学习算法。
在一个实施例中,所述迭代模块1002,还包括:
选取单元,用于从当前模型集合中随机选取出所述目标模型;
匹配单元,用于对于当前模型集合中的每个其他模型,利用匹配网络计算所述目标模型和所述其他模型之间的匹配度;
第一确定单元,用于若所述匹配度大于预设匹配度阈值,则确定所述其他模型为所述参考模型。
在一个实施例中,所述迭代模块1002,还包括:
更新单元,用于根据所述目标模型的输出和所述参考模型的输出,对所述匹配网络的网络参数进行更新。
在一个实施例中,对所述匹配网络的网络参数进行更新的更新频率大于或者等于所述迭代学习的迭代频率。
在一个实施例中,所述迭代模块1002,还包括:
第二确定单元,用于在当前模型集合中确定候选模型,所述候选模型的性能指标值小于所述训练后的目标模型的性能指标值;
替换单元,用于在当前模型集合中利用所述训练后的目标模型替换所述候选模型。
在一个实施例中,所述获取模块1001,包括:
设置单元,用于将各所述初始剪枝算法的超参数设置为随机值;
剪枝单元,用于对于每个所述初始剪枝算法,利用所述初始剪枝算法和目标压缩比对所述源模型进行剪枝处理得到初始模型,以得到所述初始模型集合。
在一个实施例中,所述获取模块1001还包括:
增加单元,用于对所述源模型的目标卷积层进行卷积核增加处理,得到目标源模型,所述目标卷积层的层序号小于预设层序号阈值;
所述剪枝单元具体用于利用所述初始剪枝算法和所述目标压缩比对所述目标源模型进行剪枝处理得到初始模型。
在一个实施例中,所述装置还包括:
检测模块,用于对于每个所述初始模型,检测所述初始模型的模型参数量和模型算力是否满足预设条件;
添加模块,用于若所述初始模型的模型参数量和模型算力满足所述预设条件,则将所述初始模型添加至所述初始模型集合中。
在一个实施例中,所述确定模块1003,包括:
第二获取单元,用于获取所述剪枝模型集合中各剪枝模型的性能指标值;
第三确定单元,用于根据各所述剪枝模型的性能指标值,确定所述剪枝处理后的模型。
在一个实施例中,所述第三确定单元具体用于从各所述剪枝模型中确定性能指标值最大的剪枝模型;对所述性能指标值最大的剪枝模型进行重训练处理,得到所述剪枝处理后的模型。
在一个实施例中,所述第三确定单元具体用于按照性能指标值由大到小的顺序,从各所述剪枝模型中确定多个候选剪枝模型;对各所述候选剪枝模型分别进行重训练处理,得到多个重训练剪枝模型;将各所述重训练剪枝模型中性能指标值最大的重训练剪枝模型作为所述剪枝处理后的模型。
上述模型剪枝装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图11所示。该计算机设备包括处理器、存储器、输入/输出接口(Input/Output,简称I/O)和通信接口。其中,处理器、存储器和输入/输出接口通过系统总线连接,通信接口通过输入/输出接口连接到系统总线。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质和内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储模型剪枝数据。该计算机设备的输入/输出接口用于处理器与外部设备之间交换信息。该计算机设备的通信接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种模型剪枝方法。
本领域技术人员可以理解,图11中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
本申请实施例还提供了一种计算机可读存储介质。一个或多个包含计算机可执行指令的非易失性计算机可读存储介质,当所述计算机可执行指令被一个或多个处理器执行时,使得所述处理器执行模型剪枝方法的步骤。
本申请实施例还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行模型剪枝方法。
需要说明的是,本申请所涉及的用户信息(包括但不限于用户设备信息、用户个人信息等)和数据(包括但不限于用于分析的数据、存储的数据、展示的数据等),均为经用户授权或者经过各方充分授权的信息和数据,且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、数据库或其它介质的任何引用,均可包括非易失性和易失性存储器中的至少一种。非易失性存储器可包括只读存储器(Read-OnlyMemory,ROM)、磁带、软盘、闪存、光存储器、高密度嵌入式非易失性存储器、阻变存储器(ReRAM)、磁变存储器(Magnetoresistive Random Access Memory,MRAM)、铁电存储器(Ferroelectric Random Access Memory,FRAM)、相变存储器(Phase Change Memory,PCM)、石墨烯存储器等。易失性存储器可包括随机存取存储器(Random Access Memory,RAM)或外部高速缓冲存储器等。作为说明而非局限,RAM可以是多种形式,比如静态随机存取存储器(Static Random Access Memory,SRAM)或动态随机存取存储器(Dynamic RandomAccess Memory,DRAM)等。本申请所提供的各实施例中所涉及的数据库可包括关系型数据库和非关系型数据库中至少一种。非关系型数据库可包括基于区块链的分布式数据库等,不限于此。本申请所提供的各实施例中所涉及的处理器可为通用处理器、中央处理器、图形处理器、数字信号处理器、可编程逻辑器、基于量子计算的数据处理逻辑器等,不限于此。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对本申请专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请的保护范围应以所附权利要求为准。
Claims (19)
1.一种模型剪枝方法,其特征在于,包括:
获取待剪枝处理的源模型对应的初始模型集合,所述初始模型集合中的各初始模型是利用各初始剪枝算法分别对所述源模型进行剪枝处理得到的;
将所述初始模型集合作为当前模型集合,进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用所述参考模型训练所述目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合;
基于所述目标模型集合和筛选条件,确定所述源模型对应的剪枝处理后的模型。
2.根据权利要求1所述的方法,其特征在于,所述利用所述参考模型训练所述目标模型,包括:
获取所述参考模型针对目标输入样本的参考输出结果,并获取所述目标模型针对所述目标输入样本的目标输出结果;
根据所述参考输出结果、所述目标输出结果和预设机器学习算法训练所述目标模型。
3.根据权利要求2所述的方法,其特征在于,所述根据所述参考输出结果、所述目标输出结果和预设机器学习算法训练所述目标模型,包括:
根据所述参考输出结果、所述目标输出结果和所述预设机器学习算法对应的损失函数计算损失值;
根据所述损失值调整所述目标模型的模型参数。
4.根据权利要求3所述的方法,其特征在于,所述根据所述参考输出结果、所述目标输出结果和所述预设机器学习算法对应的损失函数计算损失值,包括:
获取所述参考模型和所述目标模型之间的匹配度;
将所述参考输出结果、所述目标输出结果和所述匹配度输入至所述损失函数中,得到所述损失值,所述匹配度的大小与所述损失值的大小负相关。
5.根据权利要求2-4任一项所述的方法,其特征在于,所述预设机器学习算法为蒸馏学习算法。
6.根据权利要求1所述的方法,其特征在于,所述从当前模型集合中确定目标模型和参考模型,包括:
从当前模型集合中随机选取出所述目标模型;
对于当前模型集合中的每个其他模型,利用匹配网络计算所述目标模型和所述其他模型之间的匹配度;
若所述匹配度大于预设匹配度阈值,则确定所述其他模型为所述参考模型。
7.根据权利要求6所述的方法,其特征在于,所述方法还包括:
根据所述目标模型的输出和所述参考模型的输出,对所述匹配网络的网络参数进行更新。
8.根据权利要求7所述的方法,其特征在于,对所述匹配网络的网络参数进行更新的更新频率大于或者等于所述迭代学习的迭代频率。
9.根据权利要求1所述的方法,其特征在于,所述利用训练后的目标模型更新当前模型集合,包括:
在当前模型集合中确定候选模型,所述候选模型的性能指标值小于所述训练后的目标模型的性能指标值;
在当前模型集合中利用所述训练后的目标模型替换所述候选模型。
10.根据权利要求1所述的方法,其特征在于,所述获取待剪枝处理的源模型对应的初始模型集合,包括:
将各所述初始剪枝算法的超参数设置为随机值;
对于每个所述初始剪枝算法,利用所述初始剪枝算法和目标压缩比对所述源模型进行剪枝处理得到初始模型,以得到所述初始模型集合。
11.根据权利要求10所述的方法,其特征在于,所述利用所述初始剪枝算法和目标压缩比对所述源模型进行剪枝处理得到初始模型之前,所述方法还包括:
对所述源模型的目标卷积层进行卷积核增加处理,得到目标源模型,所述目标卷积层的层序号小于预设层序号阈值;
所述利用所述初始剪枝算法和目标压缩比对所述源模型进行剪枝处理得到初始模型,包括:
利用所述初始剪枝算法和所述目标压缩比对所述目标源模型进行剪枝处理得到初始模型。
12.根据权利要求10所述的方法,其特征在于,所述方法还包括:
对于每个所述初始模型,检测所述初始模型的模型参数量和模型算力是否满足预设条件;
若所述初始模型的模型参数量和模型算力满足所述预设条件,则将所述初始模型添加至所述初始模型集合中。
13.根据权利要求1所述的方法,其特征在于,所述基于所述目标模型集合和筛选条件,确定所述源模型对应的剪枝处理后的模型,包括:
获取所述目标模型集合中各剪枝模型的性能指标值;
根据各所述剪枝模型的性能指标值,确定所述剪枝处理后的模型。
14.根据权利要求13所述的方法,其特征在于,所述根据各所述剪枝模型的性能指标值,确定所述剪枝处理后的模型,包括:
从各所述剪枝模型中确定性能指标值最大的剪枝模型;
对所述性能指标值最大的剪枝模型进行重训练处理,得到所述剪枝处理后的模型。
15.根据权利要求13所述的方法,其特征在于,所述根据各所述剪枝模型的性能指标值,确定所述剪枝处理后的模型,包括:
按照性能指标值由大到小的顺序,从各所述剪枝模型中确定多个候选剪枝模型;
对各所述候选剪枝模型分别进行重训练处理,得到多个重训练剪枝模型;
将各所述重训练剪枝模型中性能指标值最大的重训练剪枝模型作为所述剪枝处理后的模型。
16.一种模型剪枝装置,其特征在于,包括:
获取模块,用于获取待剪枝处理的源模型对应的初始模型集合,所述初始模型集合中的各初始模型是利用各初始剪枝算法分别对所述源模型进行剪枝处理得到的;
迭代模块,用于将所述初始模型集合作为当前模型集合,进行迭代学习,对于每次迭代过程,从当前模型集合中确定目标模型和参考模型,利用所述参考模型训练所述目标模型,并利用训练后的目标模型更新当前模型集合,直至得到目标模型集合;
确定模块,用于基于所述目标模型集合和筛选条件,确定所述源模型对应的剪枝处理后的模型。
17.一种计算机设备,包括存储器及处理器,所述存储器中储存有计算机程序,其特征在于,所述计算机程序被所述处理器执行时,使得所述处理器执行如权利要求1至15中任一项所述的方法的步骤。
18.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至15中任一项所述的方法的步骤。
19.一种计算机程序产品,包括计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至15中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310227819.0A CN116384471A (zh) | 2023-03-09 | 2023-03-09 | 模型剪枝方法、装置、计算机设备、存储介质和程序产品 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310227819.0A CN116384471A (zh) | 2023-03-09 | 2023-03-09 | 模型剪枝方法、装置、计算机设备、存储介质和程序产品 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116384471A true CN116384471A (zh) | 2023-07-04 |
Family
ID=86970350
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310227819.0A Pending CN116384471A (zh) | 2023-03-09 | 2023-03-09 | 模型剪枝方法、装置、计算机设备、存储介质和程序产品 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116384471A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117910536A (zh) * | 2024-03-19 | 2024-04-19 | 浪潮电子信息产业股份有限公司 | 文本生成方法及其模型梯度剪枝方法、装置、设备、介质 |
-
2023
- 2023-03-09 CN CN202310227819.0A patent/CN116384471A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117910536A (zh) * | 2024-03-19 | 2024-04-19 | 浪潮电子信息产业股份有限公司 | 文本生成方法及其模型梯度剪枝方法、装置、设备、介质 |
CN117910536B (zh) * | 2024-03-19 | 2024-06-07 | 浪潮电子信息产业股份有限公司 | 文本生成方法及其模型梯度剪枝方法、装置、设备、介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
He et al. | Asymptotic soft filter pruning for deep convolutional neural networks | |
Li et al. | Towards compact cnns via collaborative compression | |
Li et al. | Improved techniques for training adaptive deep networks | |
US20190340533A1 (en) | Systems and methods for preparing data for use by machine learning algorithms | |
CN110909926A (zh) | 基于tcn-lstm的太阳能光伏发电预测方法 | |
Tang et al. | Automatic sparse connectivity learning for neural networks | |
WO2022042123A1 (zh) | 图像识别模型生成方法、装置、计算机设备和存储介质 | |
CN109784474A (zh) | 一种深度学习模型压缩方法、装置、存储介质及终端设备 | |
CN111079899A (zh) | 神经网络模型压缩方法、系统、设备及介质 | |
CN114118402A (zh) | 基于分组注意力机制的自适应剪枝模型压缩算法 | |
Yang et al. | Channel pruning based on convolutional neural network sensitivity | |
CN116384471A (zh) | 模型剪枝方法、装置、计算机设备、存储介质和程序产品 | |
CN116188878A (zh) | 基于神经网络结构微调的图像分类方法、装置和存储介质 | |
CN115311506A (zh) | 基于阻变存储器的量化因子优化的图像分类方法及装置 | |
JP7235836B2 (ja) | クラスタ接続ニューラルネットワーク | |
Li et al. | Using feature entropy to guide filter pruning for efficient convolutional networks | |
Parada-Mayorga et al. | Graphon pooling for reducing dimensionality of signals and convolutional operators on graphs | |
CN112734025B (zh) | 基于固定基正则化的神经网络参数稀疏化方法 | |
CN110288002B (zh) | 一种基于稀疏正交神经网络的图像分类方法 | |
WO2022095984A1 (en) | Method and system for convolution with workload-balanced activation sparsity | |
Tang et al. | Training Compact DNNs with ℓ1/2 Regularization | |
CN114677535A (zh) | 域适应图像分类网络的训练方法、图像分类方法及装置 | |
Zhang et al. | Compressing knowledge graph embedding with relational graph auto-encoder | |
Kumar et al. | Structure level pruning of efficient convolutional neural networks with sparse group LASSO | |
Dai et al. | Deep Learning Model Compression With Rank Reduction in Tensor Decomposition |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |