CN110135562B - 基于特征空间变化的蒸馏学习方法、系统、装置 - Google Patents
基于特征空间变化的蒸馏学习方法、系统、装置 Download PDFInfo
- Publication number
- CN110135562B CN110135562B CN201910360632.1A CN201910360632A CN110135562B CN 110135562 B CN110135562 B CN 110135562B CN 201910360632 A CN201910360632 A CN 201910360632A CN 110135562 B CN110135562 B CN 110135562B
- Authority
- CN
- China
- Prior art keywords
- network
- layer
- teacher
- student
- student network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 230000008859 change Effects 0.000 title claims abstract description 70
- 238000004821 distillation Methods 0.000 title claims abstract description 61
- 238000000034 method Methods 0.000 title claims abstract description 60
- 230000006870 function Effects 0.000 claims abstract description 40
- 239000011159 matrix material Substances 0.000 claims abstract description 17
- 238000004364 calculation method Methods 0.000 claims description 17
- 238000010276 construction Methods 0.000 claims description 6
- 238000011176 pooling Methods 0.000 claims description 6
- 239000004576 sand Substances 0.000 claims description 3
- 230000008030 elimination Effects 0.000 claims 1
- 238000003379 elimination reaction Methods 0.000 claims 1
- 238000010801 machine learning Methods 0.000 abstract description 3
- 230000008569 process Effects 0.000 description 8
- 230000007547 defect Effects 0.000 description 7
- 238000012549 training Methods 0.000 description 6
- 230000009286 beneficial effect Effects 0.000 description 4
- 238000003062 neural network model Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 230000005540 biological transmission Effects 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 230000001133 acceleration Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000004880 explosion Methods 0.000 description 1
- 239000012212 insulator Substances 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 101150049349 setA gene Proteins 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 239000013598 vector Substances 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明属于计算机视觉及机器学习领域,具体涉及了一种基于特征空间变化的蒸馏学习方法、系统、装置,旨在解决学生网络无法学习教师网络全局知识的问题。本发明方法包括:按照蒸馏学习教师网络的参数结构构建对应的学生网络;分别选取预设的网络层,计算每一层的特征空间表示以及特定两个层间的跨层特征空间变化矩阵;计算基于特征空间变化的损失函数,根据真实标签计算分类损失函数;通过两个损失函数的加权将教师网络的特征空间变化作为知识迁移到学生网络中。本发明将教师网络层与层之间的特征空间变化刻画为一种新的知识,从而,使得学生网络在学习层与层之间的特征空间变化时,就学习到整个教师网络全局的知识。
Description
技术领域
本发明属于计算机视觉及机器学习领域,具体涉及了一种基于特征空间变化的蒸馏学习方法、系统、装置。
背景技术
蒸馏学习是计算机视觉和机器学习的一个重要研究领域。在蒸馏学习中,包含两个网络:一个为预训练好的,具有较强性能,但是计算复杂度高、要求存储空间大的教师网络;一个为待训练的,但是往往具有远低于教师网络的计算复杂度以及存储空间要求的学生网络。蒸馏学习旨在从教师网络中提取出有用的信息和知识来作为学生网络训练过程中的指导。在教师网络的指导下进行训练学习,学生网络可以获得比单独训练更加优良的性能。如此一来,蒸馏学习可以得到高性能、低计算复杂度、低存储消耗的学生网络。此方法特别适用于算力有限的移动设备和嵌入式设备。
蒸馏学习能够训练得到性能更好的学生网络,其根本原因在于性能强大的教师网络在训练过程中提供了除了数据以外的额外有用信息。因此,如何从教师网络中高效地挖掘对于学生有利的信息便成为了蒸馏学习的关键研究课题。由于蒸馏学习的首次提出是在2012年,其发展时间还十分有限,使用的方法目前也十分单一。目前流行的蒸馏学习方法几乎都使用教师网络某一层(或者某几层)的输出特征作为额外信息。在学习过程中,学生网络通过最小化欧式距离的方法,来学习对应层的特征,从而从教师网络中迁移知识。例如,最为流行的方法将教师网络输出的预测结果作为软标签,作为数据自带的真实标签以外的额外标签,来给学生网络提供知识。另外,也有方法将教师网络的中间层特征抽象化为注意力图谱,利用这些图谱来指导学生网络也可以获得更佳的性能。
但是,目前流行的蒸馏学习方法将教师网络每一层的特征视为互相独立的特征空间,却忽略了层与层之间的相关性。因此,上述的基于层输出特征的学习方法,仅仅学习到了教师网络的一部分知识。如果要学习教师网络所有层的知识,需要学生网络对所有层的知识进行同时学习。然而,直接强行使学生网络拟合教师网络所有层的特征,往往没法得到性能更佳的学生网络,甚至无法使其收敛,说明直接对教师网络的多层空间进行拟合是一个不利于学习的强约束。
发明内容
为了解决现有技术中的上述问题,即学生网络无法学习教师网络全局知识的问题,本发明提供了一种基于特征空间变化的蒸馏学习方法,包括:
步骤S10,根据蒸馏学习的教师网络的通道数、计算复杂度、存储空间要求,构建蒸馏学习的学生网络;
步骤S20,选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示;
步骤S30,基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵;
步骤S40,基于所述教师网络、学生网络的跨层特征空间变化矩阵,计算目标损失函数并通过所述目标损失函数将所述教师网络的跨层特征空间变化作为知识迁移到所述学生网络中,获得学习后的学生网络。
在一些优选的实施例中,步骤S10中“根据蒸馏学习的教师网络的通道数、计算复杂度、存储空间要求,构建蒸馏学习的学生网络”,其方法为:
步骤S11,提取教师网络的通道数、计算复杂度、存储空间要求:
T={CT,ST,NT}
其中,T代表教师网络,CT代表教师网络计算复杂度,ST代表教师网络存储空间要求,NT代表教师网络通道数;
步骤S12,所述学生网络采用与所述教师网络相同的网络结构,学生网络的计算复杂度要求为CS,学生网络的存储空间消耗要求SS,学生网络的通道数NS,构建学生网络:
S={CS,SS,NS}
其中,CS、SS为预先设定的,NS根据教师网络通道数NT计算:
NS=min{NT*(CS/CT),NT*sqrt(SS/ST)}
其中,sqrt()代表平方根计算,min()代表求最小数。
在一些优选的实施例中,步骤S20中“选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示”,其方法为:
步骤S21,选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别提取样本蔟中每个样本的样本特征;
步骤S22,分别对所述样本特征进行全局平均池化,获得每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示。
在一些优选的实施例中,步骤S30中“基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵”,其方法为:
步骤S31,基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络样本簇中每个样本的跨层特征空间变化;
步骤S32,基于所述教师网络、学生网络样本簇中每个样本的跨层特征空间变化,分别计算所述教师网络、学生网络的特征空间变化矩阵。
在一些优选的实施例中,所述目标损失函数,其计算方法为:
LossTotal=LossGT+λLossTrans
其中,LossTotal代表目标损失函数;LossGT代表分类损失函数;LossTrans代表基于特征空间变化的损失函数,λ为空间变化损失函数的权重。
在一些优选的实施例中,“基于特征空间变化矩阵的损失函数”,其计算方法为:
本发明的另一方面,提出了一种基于特征空间变化的蒸馏学习系统,包括网络构建模块、网络特征表示模块、跨层网络特征表示模块、蒸馏学习模块、输出模块;
所述网络构建模块,配置为根据训练好的教师网络的通道数、计算复杂度、存储空间要求,构建学生网络;
所述网络特征表示模块,配置为选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示;
所述跨层网络特征表示模块,配置为基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵;
所述蒸馏学习模块,配置为基于所述教师网络、学生网络的跨层特征空间变化矩阵,计算目标损失函数并通过所述目标损失函数将所述教师网络的跨层特征空间变化作为知识迁移到所述学生网络中,获得学习后的学生网络;
所述输出模块,配置为将获取的学习后的学生网络输出。
本发明的第三方面,提出了一种存储装置,其中存储有多条程序,所述程序适于由处理器加载并执行以实现上述的基于特征空间变化的蒸馏学习方法。
本发明的第四方面,提出了一种处理装置,包括处理器、存储装置;所述处理器,适于执行各条程序;所述存储装置,适于存储多条程序;所述程序适于由处理器加载并执行以实现上述的基于特征空间变化的蒸馏学习方法。
本发明的有益效果:
(1)本发明基于特征空间变化的蒸馏学习方法,通过刻画教师网络层与层之间特征空间的变化来指导学生网络进行蒸馏学习和知识迁移,可以在不拟合所有层特征的同时,学习到教师网络的全局知识,从而得到更优性能的学生网络。
(2)本发明基于特征空间变化的蒸馏学习方法,不用直接强行使学生网络拟合教师网络所有层的特征,而是将层与层之间的特征空间变化刻画为一种新的知识,使得学生网络在学习层与层之间的特征空间变化时,就学习到整个教师网络全局的知识。
附图说明
通过阅读参照以下附图所作的对非限制性实施例所作的详细描述,本申请的其它特征、目的和优点将会变得更明显:
图1是本发明基于特征空间变化的蒸馏学习方法的流程示意图;
图2是本发明基于特征空间变化的蒸馏学习方法的算法框架图。
具体实施方式
下面结合附图和实施例对本申请作进一步的详细说明。可以理解的是,此处所描述的具体实施例仅用于解释相关发明,而非对该发明的限定。另外还需要说明的是,为了便于描述,附图中仅示出了与有关发明相关的部分。
需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本申请。
本发明的一种基于特征空间变化的蒸馏学习方法,包括:
步骤S10,根据蒸馏学习的教师网络的通道数、计算复杂度、存储空间要求,构建蒸馏学习的学生网络;
步骤S20,选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示;
步骤S30,基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵;
步骤S40,基于所述教师网络、学生网络的跨层特征空间变化矩阵,计算目标损失函数并通过所述目标损失函数将所述教师网络的跨层特征空间变化作为知识迁移到所述学生网络中,获得学习后的学生网络。
为了更清晰地对本发明基于特征空间变化的蒸馏学习方法进行说明,下面结合图1对本发明方法实施例中各步骤展开详述。
本发明一种实施例的基于特征空间变化的蒸馏学习方法,包括步骤S10-步骤S40,各步骤详细描述如下:
步骤S10,根据蒸馏学习的教师网络的通道数、计算复杂度、存储空间要求,构建蒸馏学习的学生网络。
在利用深度网络解决问题的时候人们常常倾向于设计更为复杂的网络收集更多的数据以期获得更好的效果,但随之而来的是模型的复杂度急剧提升,直观的表现是模参数越来越多、规模越来越大,需要的硬件资源(内存、GPU)越来越高,不利于模型的部署和应用向移动端的推广。
蒸馏学习采用的是迁移学习,通过采用预先训练好的复杂网络-教师网络模型(Teacher model)的输出作为监督信号去训练另外一个简单的网络-学生网络模型(Student model),获得的学生网络精简且复杂度低,同时具有教师网络的知识,利于模型的部署和应用向移动端的推广。
步骤S11,提取教师网络的通道数、计算复杂度、存储空间要求,如式(1)所示:
T={CT,ST,NT} 式(1)
其中,T代表教师网络,CT代表教师网络计算复杂度,ST代表教师网络存储空间要求,NT代表教师网络通道数。
步骤S12,所述学生网络采用与所述教师网络相同的网络结构,学生网络的计算复杂度要求为CS,学生网络的存储空间消耗要求SS,学生网络的通道数NS,构建学生网络,如式(2)所示:
S={CS,SS,NS} 式(2)
其中,CS、SS为预先设定的,NS根据教师网络通道数NT计算,如式(3)所示:
NS=min{NT*(CS/CT),NT*sqrt(SS/ST)} 式(3)
其中,sqrt()代表平方根计算,min()代表求最小数。
步骤S20,选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示。
步骤S21,选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别提取样本蔟中每个样本的样本特征,其方法为:
其中,N表示一个样本蔟中样本的数量,Fi l表示样本蔟中第i个样本在教师网络第l层的特征,fi l表示样本蔟中第i个样本在学生网络第l层的特征。本发明一个具体的实例中,分别选取教师网络和学生网络6个层作为样本蔟,选取层的策略可以根据具体需求做相应调整。
步骤S22,分别对所述样本特征进行全局平均池化,获得每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示。
分别对所述样本特征进行全局平均池化,如式(6)所示:
其中,X代表一个单通道宽为W、高为H的特征图谱,每个特征图谱分辨率对应两个层。
经过全局池化平均后,X的对应输出为一个标量。以此类推,所述教师网络、学生网络样本蔟在网络中某个层的输出特征图谱(宽为W、高为H、通道数为C),分别经过全局池化操作后,得到对应的长度为C的特征向量fT、fS,如式(7)和是(8)所示:
步骤S30,基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵。
步骤S31,基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络样本蔟中每个样本的跨层特征空间变化,以l1层到l2层间的变化为例,如式(9)和式(10)所示:
以此类推,本发明一个实施例中,教师网络和学生网络l3层到l4层间的变化、l5层到l6层间的变化,分别如式(11)、式(12)、式(13)、式(14)所示:
步骤S32,基于所述教师网络、学生网络样本簇中每个样本的跨层特征空间变化,分别计算所述教师网络、学生网络的特征空间变化矩阵,以l1层到l2层为例,如式(15)和式(16)所示:
以此类推,本发明一个实施例中,教师网络和学生网络l3层到l4层间的变化、l5层到l6层的特征空间变化矩阵,分别如式(17)、式(18)、式(19)、式(20)所示:
步骤S40,基于所述教师网络、学生网络的跨层特征空间变化矩阵,计算目标损失函数并通过所述目标损失函数将所述教师网络的跨层特征空间变化作为知识迁移到所述学生网络中,获得学习后的学生网络。
所述目标损失函数,其计算方法如式(21)所示:
LossTotal=LossGT+λLossTrans 式(21)
其中,LossTotal代表目标损失函数;LossGT代表分类损失函数;LossTrans代表基于特征空间变化的损失函数,λ为空间变化损失函数的权重。
本发明一个实施例中,分别选取教师网络和学生网络中l1层到l6层,基于特征空间变化矩阵的损失函数,其计算方法如式(22)所示:
其中,l1、l2、l3、l4、l5、l6代表步骤S30选择的层; 分别为教师网络、学生网络l1与l2之间的特征空间变化矩阵; 分别为教师网络、学生网络l3与l4之间的特征空间变化矩阵; 分别为教师网络、学生网络l5与l6之间的特征空间变化矩阵。
如图2所示,为本发明基于特征空间变化的蒸馏学习方法的算法框架图,教师网络为预先训练好的复杂网络,按照教师网络的通道数、计算复杂度、存储空间要求构建对应的学生网络;分别选取预设的网络层,计算每一层的特征空间表示以及特定两个层间的跨层特征空间变化矩阵;根据计算获得的跨层特征空间变化矩阵计算基于特征空间变化的损失函数,根据真实标签计算分类损失函数;通过两个损失函数的加权将教师网络的特征空间变化作为知识迁移到学生网络中,从而使学生网络学习到教师网络全局的知识。其中,N1、N2…Nn为教师网络结构中的模块编号,N1′、N2′…Nn′为学生网络结构中的模块编号,箭头代表监督信号。
本发明一个应用为直升机航巡数据图像识别技术中的模型压缩与加速。直升机航巡数据的图像识别,即通过基于深度学习的图像智能识别技术,代替人工对输电线路的缺陷进行查找,提高图像缺陷发现率的同时辅助人工提高工作效率。具体而言,在输电线路上可能存在各种缺陷和隐患(如绝缘子自爆、螺栓缺销子等),需要及时排查。采用基于深度学习的智能识别技术能够自动识别出可能存在缺陷和隐患:首先,对已有缺陷图像数据进行数据标注,作为模型训练的数据支撑;然后,将数据输入到缺陷识别的卷积神经网络中进行训练,最终得到能够预测缺陷类别的深度神经网络模型。现有深度神经网络模型为了获取更高的精确度,常常拥有大量参数,使得算法响应速度难以满足实际应用。为了进一步提升识别模型的响应速度,精简模型,采用本发明基于特征空间变化的蒸馏学习方法,获得具有深度神经网络模型全局知识的精简模型,大大提高了模型的相应速度同时模型的精准度能够满足要求。本发明方法可以适用于所有卷积神经网络的蒸馏,不仅限于此应用,仅以本实例作为一个示例说明本发明方法的应用。
本发明第二实施例的基于特征空间变化的蒸馏学习系统,包括网络构建模块、网络特征表示模块、跨层网络特征表示模块、蒸馏学习模块、输出模块;
所述网络构建模块,配置为根据训练好的教师网络的通道数、计算复杂度、存储空间要求,构建学生网络;
所述网络特征表示模块,配置为选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示;
所述跨层网络特征表示模块,配置为基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵;
所述蒸馏学习模块,配置为基于所述教师网络、学生网络的跨层特征空间变化矩阵,计算目标损失函数并通过所述目标损失函数将所述教师网络的跨层特征空间变化作为知识迁移到所述学生网络中,获得学习后的学生网络;
所述输出模块,配置为将获取的学习后的学生网络输出。
所属技术领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统的具体工作过程及有关说明,可以参考前述方法实施例中的对应过程,在此不再赘述。
需要说明的是,上述实施例提供的基于特征空间变化的蒸馏学习系统,仅以上述各功能模块的划分进行举例说明,在实际应用中,可以根据需要而将上述功能分配由不同的功能模块来完成,即将本发明实施例中的模块或者步骤再分解或者组合,例如,上述实施例的模块可以合并为一个模块,也可以进一步拆分成多个子模块,以完成以上描述的全部或者部分功能。对于本发明实施例中涉及的模块、步骤的名称,仅仅是为了区分各个模块或者步骤,不视为对本发明的不当限定。
本发明第三实施例的一种存储装置,其中存储有多条程序,所述程序适于由处理器加载并执行以实现上述的基于特征空间变化的蒸馏学习方法。
本发明第四实施例的一种处理装置,包括处理器、存储装置;处理器,适于执行各条程序;存储装置,适于存储多条程序;所述程序适于由处理器加载并执行以实现上述的基于特征空间变化的蒸馏学习方法。
所属技术领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的存储装置、处理装置的具体工作过程及有关说明,可以参考前述方法实施例中的对应过程,在此不再赘述。
本领域技术人员应该能够意识到,结合本文中所公开的实施例描述的各示例的模块、方法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,软件模块、方法步骤对应的程序可以置于随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、硬盘、可移动磁盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质中。为了清楚地说明电子硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以电子硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。本领域技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
术语“包括”或者任何其它类似用语旨在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备/装置不仅包括那些要素,而且还包括没有明确列出的其它要素,或者还包括这些过程、方法、物品或者设备/装置所固有的要素。
至此,已经结合附图所示的优选实施方式描述了本发明的技术方案,但是,本领域技术人员容易理解的是,本发明的保护范围显然不局限于这些具体实施方式。在不偏离本发明的原理的前提下,本领域技术人员可以对相关技术特征作出等同的更改或替换,这些更改或替换之后的技术方案都将落入本发明的保护范围之内。
Claims (9)
1.一种适用于嵌入式设备的基于特征空间变化的蒸馏学习方法,其特征在于,包括:
步骤S10,根据蒸馏学习的教师网络的通道数、计算复杂度、存储空间要求,构建蒸馏学习的学生网络;
步骤S20,选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示;
步骤S30,基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵;
步骤S40,基于所述教师网络、学生网络的跨层特征空间变化矩阵,计算目标损失函数并通过所述目标损失函数将所述教师网络的跨层特征空间变化作为知识迁移到所述学生网络中,获得学习后的学生网络。
2.根据权利要求1所述的适用于嵌入式设备的基于特征空间变化的蒸馏学习方法,其特征在于,步骤S10中“根据蒸馏学习的教师网络的通道数、计算复杂度、存储空间要求,构建蒸馏学习的学生网络”,其方法为:
步骤S11,提取教师网络的通道数、计算复杂度、存储空间要求:
T={CT,ST,NT}
其中,T代表教师网络,CT代表教师网络计算复杂度,ST代表教师网络存储空间要求,NT代表教师网络通道数;
步骤S12,所述学生网络采用与所述教师网络相同的网络结构,学生网络的计算复杂度要求为CS,学生网络的存储空间消耗要求SS,学生网络的通道数NS,构建学生网络:
S={CS,SS,NS}
其中,CS、SS为预先设定的,NS根据教师网络通道数NT计算:
NS=min{NT*(CS/CT),NT*sqrt(SS/ST)}
其中,sqrt()代表平方根计算,min()代表求最小数。
3.根据权利要求1所述的适用于嵌入式设备的基于特征空间变化的蒸馏学习方法,其特征在于,步骤S20中“选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示”,其方法为:
步骤S21,选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别提取样本蔟中每个样本的样本特征;
步骤S22,分别对所述样本特征进行全局平均池化,获得每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示。
4.根据权利要求1所述的适用于嵌入式设备的基于特征空间变化的蒸馏学习方法,其特征在于,步骤S30中“基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵”,其方法为:
步骤S31,基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络样本簇中每个样本的跨层特征空间变化;
步骤S32,基于所述教师网络、学生网络样本簇中每个样本的跨层特征空间变化,分别计算所述教师网络、学生网络的特征空间变化矩阵。
5.根据权利要求1所述的适用于嵌入式设备的基于特征空间变化的蒸馏学习方法,其特征在于,所述目标损失函数,其计算方法为:
LossTotal=LossGT+λLossTrans
其中,LossTotal代表目标损失函数;LossGT代表分类损失函数;LossTrans代表基于特征空间变化的损失函数,λ为空间变化损失函数的权重。
7.一种适用于嵌入式设备的基于特征空间变化的蒸馏学习系统,其特征在于,包括网络构建模块、网络特征表示模块、跨层网络特征表示模块、蒸馏学习模块、输出模块;
所述网络构建模块,配置为根据训练好的教师网络的通道数、计算复杂度、存储空间要求,构建学生网络;
所述网络特征表示模块,配置为选取所述教师网络预设层、所述学生网络的相应层作为样本蔟,分别计算每个样本蔟的样本在所述教师网络、学生网络中每一层的特征空间表示;
所述跨层网络特征表示模块,配置为基于所述每个样本蔟的样本在教师网络、学生网络中每一层的特征空间表示,分别计算所述教师网络、学生网络的跨层特征空间变化矩阵;
所述蒸馏学习模块,配置为基于所述教师网络、学生网络的跨层特征空间变化矩阵,计算目标损失函数并通过所述目标损失函数将所述教师网络的跨层特征空间变化作为知识迁移到所述学生网络中,获得学习后的学生网络;
所述输出模块,配置为将获取的学习后的学生网络输出。
8.一种存储装置,其中存储有多条程序,其特征在于,所述程序适于由处理器加载并执行以实现权利要求1-6任一项所述的适用于嵌入式设备的基于特征空间变化的蒸馏学习方法。
9.一种处理装置,包括
处理器,适于执行各条程序;以及
存储装置,适于存储多条程序;
其特征在于,所述程序适于由处理器加载并执行以实现:
权利要求1-6任一项所述的适用于嵌入式设备的基于特征空间变化的蒸馏学习方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910360632.1A CN110135562B (zh) | 2019-04-30 | 2019-04-30 | 基于特征空间变化的蒸馏学习方法、系统、装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910360632.1A CN110135562B (zh) | 2019-04-30 | 2019-04-30 | 基于特征空间变化的蒸馏学习方法、系统、装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110135562A CN110135562A (zh) | 2019-08-16 |
CN110135562B true CN110135562B (zh) | 2020-12-01 |
Family
ID=67575888
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910360632.1A Active CN110135562B (zh) | 2019-04-30 | 2019-04-30 | 基于特征空间变化的蒸馏学习方法、系统、装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110135562B (zh) |
Families Citing this family (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110490136B (zh) * | 2019-08-20 | 2023-03-24 | 电子科技大学 | 一种基于知识蒸馏的人体行为预测方法 |
CN112487182B (zh) * | 2019-09-12 | 2024-04-12 | 华为技术有限公司 | 文本处理模型的训练方法、文本处理方法及装置 |
CN111275183B (zh) * | 2020-01-14 | 2023-06-16 | 北京迈格威科技有限公司 | 视觉任务的处理方法、装置和电子系统 |
CN111260056B (zh) * | 2020-01-17 | 2024-03-12 | 北京爱笔科技有限公司 | 一种网络模型蒸馏方法及装置 |
CN111544855B (zh) * | 2020-04-30 | 2021-08-31 | 天津大学 | 基于蒸馏学习和深度学习纯意念控制智能康复方法及应用 |
CN111753878A (zh) * | 2020-05-20 | 2020-10-09 | 济南浪潮高新科技投资发展有限公司 | 一种网络模型部署方法、设备及介质 |
CN113536970A (zh) * | 2021-06-25 | 2021-10-22 | 华为技术有限公司 | 一种视频分类模型的训练方法及相关装置 |
CN113947590B (zh) * | 2021-10-26 | 2023-05-23 | 四川大学 | 一种基于多尺度注意力引导和知识蒸馏的表面缺陷检测方法 |
CN115631178B (zh) * | 2022-11-03 | 2023-11-10 | 昆山润石智能科技有限公司 | 自动晶圆缺陷检测方法、系统、设备及存储介质 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107247989A (zh) * | 2017-06-15 | 2017-10-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及装置 |
CN108921294A (zh) * | 2018-07-11 | 2018-11-30 | 浙江大学 | 一种用于神经网络加速的渐进式块知识蒸馏方法 |
CN109299657A (zh) * | 2018-08-14 | 2019-02-01 | 清华大学 | 基于语义注意力保留机制的群体行为识别方法及装置 |
CN109409500A (zh) * | 2018-09-21 | 2019-03-01 | 清华大学 | 基于知识蒸馏与非参数卷积的模型加速方法及装置 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11195093B2 (en) * | 2017-05-18 | 2021-12-07 | Samsung Electronics Co., Ltd | Apparatus and method for student-teacher transfer learning network using knowledge bridge |
-
2019
- 2019-04-30 CN CN201910360632.1A patent/CN110135562B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107247989A (zh) * | 2017-06-15 | 2017-10-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及装置 |
CN108921294A (zh) * | 2018-07-11 | 2018-11-30 | 浙江大学 | 一种用于神经网络加速的渐进式块知识蒸馏方法 |
CN109299657A (zh) * | 2018-08-14 | 2019-02-01 | 清华大学 | 基于语义注意力保留机制的群体行为识别方法及装置 |
CN109409500A (zh) * | 2018-09-21 | 2019-03-01 | 清华大学 | 基于知识蒸馏与非参数卷积的模型加速方法及装置 |
Non-Patent Citations (2)
Title |
---|
Distilling the Knowledge in a Neural Network;Geoffrey Hinton et al.;《arXiv:1503.02531v1[stat.ML]》;20150309;第1-9页 * |
基于移动端的高效人脸识别算法;魏彪等;《现代计算机》;20190305(第7期);第61-66页 * |
Also Published As
Publication number | Publication date |
---|---|
CN110135562A (zh) | 2019-08-16 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110135562B (zh) | 基于特征空间变化的蒸馏学习方法、系统、装置 | |
CN114241282B (zh) | 一种基于知识蒸馏的边缘设备场景识别方法及装置 | |
CN111967294A (zh) | 一种无监督域自适应的行人重识别方法 | |
CN111126258A (zh) | 图像识别方法及相关装置 | |
CN108345875A (zh) | 可行驶区域检测模型训练方法、检测方法和装置 | |
CN108288014A (zh) | 道路智能提取方法和装置、提取模型构建方法及混合导航系统 | |
CN113963165B (zh) | 一种基于自监督学习的小样本图像分类方法及系统 | |
CN107403426A (zh) | 一种目标物体检测方法及设备 | |
CN113807399A (zh) | 一种神经网络训练方法、检测方法以及装置 | |
CN115187772A (zh) | 目标检测网络的训练及目标检测方法、装置及设备 | |
CN110866564A (zh) | 多重半监督图像的季节分类方法、系统、电子设备和介质 | |
CN117056452B (zh) | 知识点学习路径构建方法、装置、设备以及存储介质 | |
CN115546196A (zh) | 一种基于知识蒸馏的轻量级遥感影像变化检测方法 | |
CN111104831A (zh) | 一种视觉追踪方法、装置、计算机设备以及介质 | |
Demertzis et al. | A machine hearing framework for real-time streaming analytics using Lambda architecture | |
CN117036843A (zh) | 目标检测模型训练方法、目标检测方法和装置 | |
CN113408621A (zh) | 面向机器人技能学习的快速模仿学习方法、系统、设备 | |
CN116912624A (zh) | 一种伪标签无监督数据训练方法、装置、设备及介质 | |
CN114972904B (zh) | 一种基于对抗三元组损失的零样本知识蒸馏方法及系统 | |
CN115690568A (zh) | 一种基于增量学习的无人艇目标检测方法 | |
CN110321818A (zh) | 一种复杂场景中的行人检测方法 | |
CN114330554A (zh) | 一种面向智能安防的视觉深度模型知识重组方法 | |
CN114445684A (zh) | 车道线分割模型的训练方法、装置、设备及存储介质 | |
CN111178370B (zh) | 车辆检索方法及相关装置 | |
CN114708307B (zh) | 基于相关滤波器的目标跟踪方法、系统、存储介质及设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |