CN113408709B - 基于单元重要度的条件计算方法 - Google Patents

基于单元重要度的条件计算方法 Download PDF

Info

Publication number
CN113408709B
CN113408709B CN202110785452.5A CN202110785452A CN113408709B CN 113408709 B CN113408709 B CN 113408709B CN 202110785452 A CN202110785452 A CN 202110785452A CN 113408709 B CN113408709 B CN 113408709B
Authority
CN
China
Prior art keywords
network
residual
importance
unit
input image
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110785452.5A
Other languages
English (en)
Other versions
CN113408709A (zh
Inventor
周泓
杨涛
楼震宇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU filed Critical Zhejiang University ZJU
Priority to CN202110785452.5A priority Critical patent/CN113408709B/zh
Publication of CN113408709A publication Critical patent/CN113408709A/zh
Application granted granted Critical
Publication of CN113408709B publication Critical patent/CN113408709B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Abstract

本发明公开了一种基于单元重要度的条件计算方法,包含:S1:预先训练主干残差网络M;S2:构建门控网络G;S3:计算所述主干残差网络M中每个所述残差单元对每一张输入图像的重要度;S4:将所述输入图像及其对应的各所述残差单元的重要度组成为输入‑标签对,构建数据集,通过所述数据集训练所述门控网络G;S5:对所述主干残差网络M进行微调以适应动态裁剪;S6:重复步骤S3‑S5直到模型的裁剪率和精度满足预设条件。本发明的基于单元重要度的条件计算方法,计算主干残差网络M中每个残差单元对每一张输入图像的重要度,并以此构建数据集用于训练门控网络G,使门控网络G能够根据输入图像与中间特征图,预测出不同残差单元的重要度。

Description

基于单元重要度的条件计算方法
技术领域
本发明涉及一种基于单元重要度的条件计算方法。
背景技术
目前,深度学习模型压缩主要包括裁剪、量化、知识蒸馏等。其中,裁剪按照粒度区分可分为神经元级裁剪、滤波器级裁剪、甚至残差单元级裁剪等,考虑实际应用场景中通用处理器的实际推理加速效果,通常采用的是滤波器级或残差单元级裁剪。常见的裁剪方案中通常设计一个滤波器或一个残差单元重要性评估指标,然后衡量每一个裁剪候选单元的重要性,并裁去重要度较低的直到模型的计算复杂度符合要求。
条件计算是一种较新的深度学习模型压缩手段,它利用了不同滤波器或不同残差单元提取的特征互不相同,以及不同输入图像拥有不同特征的特点,个性化地根据输入图像的不同决策出恰当的计算路径。现有的条件计算方法主要为残差单元级粒度的条件计算,通常通过强化学习训练一个小型门控网络用于根据输入或中间特征图预测各个残差单元的开闭。
但现有的条件计算方法大多采用强化学习,根据分类的交叉熵损失与裁剪率构建强化学习的reward,并将该reward返回给所有门控的输出进行训练。这使得门控网络的搜索空间非常大,在数据集容量有限的情况下很难实现良好的动态裁剪。
发明内容
本发明提供了一种基于单元重要度的条件计算方法,采用如下的技术方案:
一种基于单元重要度的条件计算方法,包含以下步骤:
S1:预先训练主干残差网络M,主干残差网络M包含n个残差单元;
S2:为预训练好的主干残差网络M构建门控网络G;
S3:计算主干残差网络M中每个残差单元对每一张输入图像的重要度;
S4:将输入图像及其对应的各残差单元的重要度组成为输入-标签对,构建数据集,固定主干残差网络M,通过数据集训练门控网络G;
S5:在训练好门控网络G后固定门控网络G,对主干残差网络M进行微调以适应动态裁剪;
S6:重复步骤S3-S5直到模型的裁剪率和精度满足预设条件。
进一步地,在步骤S3中计算主干残差网络M中每个残差单元对每一张输入图像的重要度的具体方法为通过下述公式进行计算:
imp(x,i)=loss(M-Block[i],x)-loss(M,x)
其中,x为输入图像,M-Block[i]为M中第i个残差单元被裁去时的剩余n-1个残差单元构成的子网络,function为给定的当前任务的目标函数,imp(x,i)为M中第i个残差单元对输入x的重要度。
进一步地,在步骤S4中,将重要度标注作为reward,门控网络G的输出G(x)作为各个门控的预测值,经过Sigmoid函数将门控预测值转化为开启概率后,使用类强化学习的算法对门控网络G进行训练。
进一步地,步骤S4中的目标函数通过下述公式进行计算,
Figure BDA0003158564610000021
其中,G(x)为各个门控的预测值,训练采用梯度上升以最大化目标函数。
进一步地,在步骤S5中对主干残差网络M进行微调时,使得每个输入图像都只经过所有n个残差单元的特定子集,对于某个输入图像,主干残差网络M的微调只对特定子集中的残差单元进行。
进一步地,步骤S2中构建的门控网络为ResNet8卷积神经网络,或以LSTM循环神经网络为主体的神经网络,或者是n个独立的MLP,每个MLP对应于一个残差单元。
本发明的有益之处在于所提供的基于单元重要度的条件计算方法,首先预先训练主干残差网络M,然后为预训练的主干残差网络M构建门控网络G用于预测所有主干残差网络M中残差单元的重要度与开闭。为了训练门控网络M,计算训练集上主干残差网络M中每个残差单元对每一张输入图像的重要度,并以此构建数据集用于训练门控网络G,使门控网络G能够根据输入图像与中间特征图,预测出不同残差单元的重要度。从而能够在推理阶段动态地对不同输入裁去重要度低,或者对当前输入无效甚至有害的残差单元以实现模型裁剪与精度提升。
附图说明
图1是本发明的基于单元重要度的条件计算方法的示意图。
具体实施方式
以下结合附图和具体实施例对本发明作具体的介绍。
如图1所示为本发明一种基于单元重要度的条件计算方法,主要包含以下步骤:步骤S1:预先训练主干残差网络M,主干残差网络M包含n个残差单元。步骤S2:为预训练好的主干残差网络M构建门控网络G。门控网络G用于控制主干残差网络M中n个残差单元的开闭。若残差单元开启,则前向推理时该残差单元被正常计算;若残差单元关闭,则前向推理时该残差单元中只有短连接被经过,残差单元被裁去而不需要做任何计算。步骤S3:选取若干输入图像,计算主干残差网络M中每个残差单元对每一张输入图像的重要度。步骤S4:将输入图像及其对应的各残差单元的重要度组成为输入-标签对,构建数据集,固定主干残差网络M,通过数据集训练门控网络G。步骤S5:在训练好门控网络G后固定门控网络G,对主干残差网络M进行微调以适应动态裁剪。步骤S6:重复步骤S3-S5直到模型的裁剪率和精度满足预设条件。通过上述步骤,首先预先训练主干残差网络M,然后为预训练的主干残差网络M构建门控网络G用于预测所有主干残差网络M中残差单元的重要度与开闭。为了训练门控网络M,计算训练集上主干残差网络M中每个残差单元对每一张输入图像的重要度,并以此构建数据集用于训练门控网络G,使门控网络G能够根据输入图像与中间特征图,预测出不同残差单元的重要度。从而能够在推理阶段动态地对不同输入裁去重要度低,或者对当前输入无效甚至有害的残差单元以实现模型裁剪与精度提升。
作为一种优选的实施方式,在步骤S3中,计算主干残差网络M中每个残差单元对每一张输入图像的重要度的具体方法为通过下述公式进行计算:
imp(x,i)=loss(M-Block[i],x)-loss(M,x)
其中,x为输入图像,M-Block[i]为M中第i个残差单元被裁去时的剩余n-1个残差单元构成的子网络,loss为给定的当前任务的损失函数,imp(x,i)为M中第i个残差单元对输入x的重要度。
作为一种优选的实施方式,在步骤S4中,将重要度标注作为reward,门控网络G的输出G(x)作为各个门控的预测值,经过Sigmoid函数将门控预测值转化为开启概率后,使用类强化学习的算法对门控网络G进行训练。
作为一种优选的实施方式,步骤S4中的损失函数通过下述公式进行计算,
Figure BDA0003158564610000031
其中,G(x)为各个门控的预测值,训练采用梯度上升以最大化目标函数。
作为一种优选的实施方式,在步骤S5中对主干残差网络M进行微调时,使得每个输入图像都只经过所有n个残差单元的特定子集,对于某个输入图像,主干残差网络M的微调只对特定子集中的残差单元进行。
具体而言,由于残差单元粒度的裁剪破坏了主干残差网络M预训练过程中BN层统计的数据分布信息,包括running_mean、running_var等。在正式适用门控网络G进行动态裁剪前我们还需要固定门控网络G,并在门控网络G的指导下进行动态裁剪,使得每个输入图像x都只经过所有n个残差单元的特定子集。譬如对于输入x0,在门控网络G的指导下我们裁掉了第3、第6个残差单元,此时x0对应的需要通过的残差单元子集为U={Block[1]、Block[2]、Block[4]、Block[5]、Block[7]、…、Block[n]},且在整个步骤S5这一步的微调环节中,图像x0都只会利用U中的残差单元进行推理,对于图像x0,主干残差网络的微调也只会对U中的残差单元进行。
作为一种优选的实施方式,步骤S2中构建的门控网络为卷积神经网络。卷积神经网络为ResNet8。门控网络G独立于主干残差网络M,直接接收输入图像为网络输入,并在全连接层输出所有门控的预测结果。
使用卷积神经网络型的门控网络能够使我们在主干残差网络运作之前就能一次性获得所有门控的预测结果,便于我们事先进行单元裁剪的决策,同时门控网络的开销不会随主干网络容量的增大而变大。
当采用卷积神经网络型的门控网络时,由于能够事先获得所有门控的预测结果,因此可以直接使用贪心法:找出重要度最低的一个或若干个残差单元裁去。也可以采用阈值法:设置阈值α,裁去所有-G(x)>α的单元并保留-G(x)<α;也可首先计算Softmax(-G(x))后设定阈值α,裁去Softmax(-G(x))>α的单元并保留Softmax(-G(x))<α的单元。
作为一种可选的实施方式,作为一种优选的实施方式,步骤S2中构建的门控网络为以L循环神经网络为主体的神经网络。作为一种优选的实施方式,循环神经网络为LSTM。使用LSTM等循环神经网络作为门控网络,将主干残差网络中每个残差单元的输入特征图组成序列,降维后输入门控网络,门控网络逐个对序列中每个残差单元对应的门控进行预测。
使用循环神经网络型的门控网络能够利用浅层所有残差单元的序列信息协同进行下一个残差单元门控的预测。
作为另一种可选的实施方式,步骤S2中构建的门控网络为n个独立的MLP(Multilayer Perceptron,多层感知器),每个MLP对应于一个残差单元。使用MLP型的门控网络,每个单元分配独立的门控单元使得门控网络的训练更加容易且稳定。
当采用循环神经网络型或MLP型门控网络时,由于无法事先获得所有门控的预测结果,需要在主干残差网络M前向推理的过程中同时进行动态裁剪的决策,只能采用阈值法。在精度敏感且计算开销限制较低的场合下,也可以首先对主干残差网络M进行一次前向推理用于收集所有门控的预测结果后,使用贪心法指导动态裁剪并重新进行一次主干残差网络M的前向推理。
以上显示和描述了本发明的基本原理、主要特征和优点。本行业的技术人员应该了解,上述实施例不以任何形式限制本发明,凡采用等同替换或等效变换的方式所获得的技术方案,均落在本发明的保护范围内。

Claims (3)

1.一种基于单元重要度的条件计算方法,其特征在于,包含以下步骤:
S1:预先训练主干残差网络M,所述主干残差网络M包含n个残差单元;
S2:为预训练好的所述主干残差网络M构建门控网络G;
S3:计算所述主干残差网络M中每个所述残差单元对每一张输入图像的重要度;
S4:将所述输入图像及其对应的各所述残差单元的重要度组成为输入-标签对,构建数据集,固定所述主干残差网络M,通过所述数据集训练所述门控网络G;
S5:在训练好所述门控网络G后固定所述门控网络G,对所述主干残差网络M进行微调以适应动态裁剪;
S6:重复步骤S3-S5直到模型的裁剪率和精度满足预设条件;
在所述步骤S3中计算所述主干残差网络M中每个残差单元对每一张输入图像的重要度的具体方法为通过下述公式进行计算:
imp(x,i)=loss(M-Block[i],x)-loss(M,x)
其中,x为输入图像,M-Block[i]为M中第i个残差单元被裁去时的剩余n-1个残差单元构成的子网络,loss为给定的当前任务的损失函数,imp(x,i)为M中第i个残差单元对输入x的重要度;
在所述步骤S4中,将重要度标注作为reward,所述门控网络G的输出G(x)作为各个门控的预测值,经过Sigmoid函数将门控预测值转化为开启概率后,使用类强化学习的算法对所述门控网络G进行训练;
步骤S4中的目标函数通过下述公式进行计算,
Figure FDA0003952353190000011
其中,G(x)为各个门控的预测值,训练采用梯度上升以最大化目标函数。
2.根据权利要求1所述的基于单元重要度的条件计算方法,其特征在于,
在所述步骤S5中对所述主干残差网络M进行微调时,使得每个所述输入图像都只经过所有n个所述残差单元的特定子集,对于某个所述输入图像,所述主干残差网络M的微调只对所述特定子集中的所述残差单元进行。
3.根据权利要求1所述的基于单元重要度的条件计算方法,其特征在于,
所述步骤S2中构建的所述门控网络为ResNet8卷积神经网络,或以LSTM循环神经网络为主体的神经网络,或者是n个独立的MLP,每个MLP对应于一个所述残差单元。
CN202110785452.5A 2021-07-12 2021-07-12 基于单元重要度的条件计算方法 Active CN113408709B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110785452.5A CN113408709B (zh) 2021-07-12 2021-07-12 基于单元重要度的条件计算方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110785452.5A CN113408709B (zh) 2021-07-12 2021-07-12 基于单元重要度的条件计算方法

Publications (2)

Publication Number Publication Date
CN113408709A CN113408709A (zh) 2021-09-17
CN113408709B true CN113408709B (zh) 2023-04-07

Family

ID=77686131

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110785452.5A Active CN113408709B (zh) 2021-07-12 2021-07-12 基于单元重要度的条件计算方法

Country Status (1)

Country Link
CN (1) CN113408709B (zh)

Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110276316A (zh) * 2019-06-26 2019-09-24 电子科技大学 一种基于深度学习的人体关键点检测方法

Family Cites Families (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111314694B (zh) * 2013-12-27 2023-05-05 索尼公司 图像处理装置及方法
CN108764471B (zh) * 2018-05-17 2020-04-14 西安电子科技大学 基于特征冗余分析的神经网络跨层剪枝方法
US20200160565A1 (en) * 2018-11-19 2020-05-21 Zhan Ma Methods And Apparatuses For Learned Image Compression
CN109785847B (zh) * 2019-01-25 2021-04-30 东华大学 基于动态残差网络的音频压缩算法
CN111598233A (zh) * 2020-05-11 2020-08-28 浙江大学 深度学习模型的压缩方法、装置及设备
CN111898591B (zh) * 2020-08-28 2022-06-24 电子科技大学 一种基于剪枝残差网络的调制信号识别方法
CN112052951A (zh) * 2020-08-31 2020-12-08 北京中科慧眼科技有限公司 一种剪枝神经网络方法、系统、设备及可读存储介质

Patent Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110276316A (zh) * 2019-06-26 2019-09-24 电子科技大学 一种基于深度学习的人体关键点检测方法

Also Published As

Publication number Publication date
CN113408709A (zh) 2021-09-17

Similar Documents

Publication Publication Date Title
CN110223517B (zh) 基于时空相关性的短时交通流量预测方法
CN110007652B (zh) 一种水电机组劣化趋势区间预测方法与系统
CN113723007B (zh) 基于drsn和麻雀搜索优化的设备剩余寿命预测方法
CN110889085A (zh) 基于复杂网络多元在线回归的废水智能监控方法及系统
CN109886343B (zh) 图像分类方法及装置、设备、存储介质
CN111553535B (zh) 一种基于ae-lstm-bo车流量预测的导航参考方法
CN112766496B (zh) 基于强化学习的深度学习模型安全性保障压缩方法与装置
CN113993172B (zh) 一种基于用户移动行为预测的超密集网络切换方法
CN108399434A (zh) 基于特征提取的高维时间序列数据的分析预测方法
CN112766603A (zh) 一种交通流量预测方法、系统、计算机设备及存储介质
CN116244647A (zh) 一种无人机集群的运行状态估计方法
CN114202065B (zh) 一种基于增量式演化lstm的流数据预测方法及装置
CN113408709B (zh) 基于单元重要度的条件计算方法
CN113035348A (zh) 一种基于gru特征融合的糖尿病诊断方法
Sun et al. Ada-STNet: A Dynamic AdaBoost Spatio-Temporal Network for Traffic Flow Prediction
Zineb et al. Cognitive radio networks management using an ANFIS approach with QoS/QoE mapping scheme
CN115796017A (zh) 一种基于模糊理论的可解释交通认知方法
CN115293249A (zh) 一种基于动态时序预测的电力系统典型场景概率预测方法
CN113255963A (zh) 基于路元拆分和深度学习模型lstm的路面使用性能预测方法
Bi et al. Multi-indicator Water Time Series Imputation with Autoregressive Generative Adversarial Networks
CN114386602B (zh) 一种面向多路服务器负载数据的htm预测分析方法
CN116957166B (zh) 一种基于鸿蒙系统的隧道交通情况预测方法及系统
CN117273225B (zh) 一种基于时空特征的行人路径预测方法
Mei et al. Research on short-term urban traffic congestion based on fuzzy comprehensive evaluation and machine learning
CN115063975B (zh) 短时交通流数据预测方法、系统、计算机设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant