CN111738436A - 一种模型蒸馏方法、装置、电子设备及存储介质 - Google Patents
一种模型蒸馏方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN111738436A CN111738436A CN202010607520.4A CN202010607520A CN111738436A CN 111738436 A CN111738436 A CN 111738436A CN 202010607520 A CN202010607520 A CN 202010607520A CN 111738436 A CN111738436 A CN 111738436A
- Authority
- CN
- China
- Prior art keywords
- network model
- parameter
- network
- training
- distillation
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000004821 distillation Methods 0.000 title claims abstract description 117
- 238000000034 method Methods 0.000 title claims abstract description 59
- 238000012549 training Methods 0.000 claims abstract description 153
- 230000006870 function Effects 0.000 claims description 81
- 238000013528 artificial neural network Methods 0.000 claims description 67
- 238000010606 normalization Methods 0.000 claims description 23
- 238000004364 calculation method Methods 0.000 claims description 12
- 230000004913 activation Effects 0.000 claims description 11
- 238000000605 extraction Methods 0.000 claims description 8
- 238000004590 computer program Methods 0.000 claims description 7
- 238000012216 screening Methods 0.000 claims description 7
- 230000009466 transformation Effects 0.000 claims description 4
- 238000003062 neural network model Methods 0.000 abstract description 26
- 230000000875 corresponding effect Effects 0.000 description 88
- 230000008569 process Effects 0.000 description 21
- 238000010801 machine learning Methods 0.000 description 15
- 238000010586 diagram Methods 0.000 description 11
- 238000012545 processing Methods 0.000 description 10
- 238000011176 pooling Methods 0.000 description 7
- 238000010276 construction Methods 0.000 description 6
- 239000010985 leather Substances 0.000 description 6
- 238000001514 detection method Methods 0.000 description 5
- 238000013473 artificial intelligence Methods 0.000 description 4
- 230000003416 augmentation Effects 0.000 description 4
- 238000013527 convolutional neural network Methods 0.000 description 4
- 230000002829 reductive effect Effects 0.000 description 4
- 230000009471 action Effects 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 3
- 238000005070 sampling Methods 0.000 description 3
- 230000019771 cognition Effects 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 239000011159 matrix material Substances 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 230000036961 partial effect Effects 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- 241000287196 Asthenes Species 0.000 description 1
- 241001465754 Metazoa Species 0.000 description 1
- 230000003190 augmentative effect Effects 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 238000013529 biological neural network Methods 0.000 description 1
- 210000004556 brain Anatomy 0.000 description 1
- 210000003169 central nervous system Anatomy 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 230000001149 cognitive effect Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000005094 computer simulation Methods 0.000 description 1
- 230000002596 correlated effect Effects 0.000 description 1
- 125000004122 cyclic group Chemical group 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000001914 filtration Methods 0.000 description 1
- 230000008014 freezing Effects 0.000 description 1
- 238000007710 freezing Methods 0.000 description 1
- 239000011521 glass Substances 0.000 description 1
- 230000006698 induction Effects 0.000 description 1
- 238000013140 knowledge distillation Methods 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 230000000670 limiting effect Effects 0.000 description 1
- 238000013178 mathematical model Methods 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000002441 reversible effect Effects 0.000 description 1
- 239000004576 sand Substances 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 230000001131 transforming effect Effects 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本申请提供一种模型蒸馏方法、装置、电子设备及存储介质,用于快速有效地将复杂神经网络模型的参数直接迁移到简化神经网络模型上。该方法包括:获得预先训练的第一网络模型和未经训练的第二网络模型,第一网络模型的网络参数多于第二网络模型的网络参数;从第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,第一参数和第二参数均是可学习的,第一参数影响网络模型的特征分布的方差,第二参数影响网络模型的特征分布的均值;根据第一参数和第二参数对第二网络模型中的第二批量正则化层进行初始化,获得初始化后的第二网络模型;使用第一网络模型对初始化后的第二网络模型进行蒸馏训练,获得蒸馏训练后的第二网络模型。
Description
技术领域
本申请涉及人工智能、机器学习和蒸馏学习的技术领域,具体而言,涉及一种模型蒸馏方法、装置、电子设备及存储介质。
背景技术
人工智能(Artificial Intelligence,AI),是指研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学;人工智能是计算机科学的一个分支,人工智能企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。
机器学习(Machine Learning,ML),是指人工智能领域中研究人类学习行为的一个分支。借鉴认知科学、生物学、哲学、统计学、信息论、控制论、计算复杂性等学科或理论的观点,通过归纳、一般化、特殊化、类比等基本方法探索人类的认识规律和学习过程,建立各种能通过经验自动改进的算法,使计算机系统能够具有自动学习特定知识和技能的能力。
知识蒸馏(Knowledge Distillation),又被称为模型蒸馏、暗知识提取、蒸馏训练或蒸馏学习,是指将知识从一个复杂的机器学习模型迁移到另一个简化的机器学习模型,从而在保持原复杂的机器学习模型的计算准确率基本不变的情况下,简化机器学习模型的网络结构,以减小机器学习模型在实际应用中的运算量,从而提高机器学习模型的运算速度,让简化后的机器学习模型能够运行在更多计算性能不强的终端设备上。
在具体的实践中发现,目前在保证简化神经网络模型的性能几乎不受影响的情况下,难以快速有效地将复杂神经网络模型的参数直接迁移到简化神经网络模型上。
发明内容
本申请实施例的目的在于提供一种模型蒸馏方法、装置、电子设备及存储介质,用于快速有效地将复杂神经网络模型的参数直接迁移到简化神经网络模型上。
本申请实施例提供了一种模型蒸馏方法,包括:获得预先训练的第一网络模型和未经训练的第二网络模型,第一网络模型的网络参数多于第二网络模型的网络参数;从第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,第一参数和第二参数均是可学习的,第一参数影响网络模型的特征分布的方差,第二参数影响网络模型的特征分布的均值;根据第一参数和第二参数对第二网络模型中的第二批量正则化层进行初始化,获得初始化后的第二网络模型;使用第一网络模型对初始化后的第二网络模型进行蒸馏训练,获得蒸馏训练后的第二网络模型。在上述的实现过程中,通过将复杂神经网络模型中的特征分布规律直接赋值给简化神经网络模型,从而极大地提升了模型蒸馏的有效性,即在保证简化神经网络模型的性能几乎不受影响的情况下,快速有效地将复杂神经网络模型的参数直接迁移到简化神经网络模型上。
可选地,在本申请实施例中,获得预先训练的第一网络模型,包括:获得多个训练图像和多个训练图像对应的标签表,多个训练图像包括目标对象的原始图像和对原始图像进行空间变换获得的变换图像,标签表包括目标对象的至少一个标签;以多个训练图像为训练数据,以多个训练图像对应的标签表为训练标签,对预先构建的第一神经网络进行训练,获得训练后的第一网络模型。在上述的实现过程中,通过获得多个训练图像和多个训练图像对应的标签表;以多个训练图像为训练数据,以多个训练图像对应的标签表为训练标签,对预先构建的第一神经网络进行训练,获得训练后的第一网络模型;从而有效地提高了第一网络模型对训练图像中的多标签识别能力。
可选地,在本申请实施例中,在对预先构建的第一神经网络进行训练之前,还包括:获得分类神经网络,分类神经网络包括:特征识别网络和归一化指数层;从分类神经网络中删除归一化指数层,获得特征识别网络;根据特征识别网络和全连接层构建第一神经网络。
可选地,在本申请实施例中,从第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,包括:从第一网络模型中的多个批量正则化层筛选出至少一个第一批量正则化层,第一批量正则化层为跨步卷积计算之前的正则化层;从第一批量正则化层中提取出第一参数和第二参数。
可选地,在本申请实施例中,根据第一参数和第二参数对第二网络模型中的第二批量正则化层进行初始化,包括:判断第一批量正则化层对应模块的通道数量是否大于第二批量正则化层对应模块的通道数量;若是,则使用第一批量正则化层中的第一参数对第二批量正则化层中的第一参数进行赋值,并使用第一批量正则化层中的第二参数对第一批量正则化层中的第二参数进行赋值。在上述的实现过程中,若第一批量正则化层对应模块的通道数量大于第二批量正则化层对应模块的通道数量,则使用第一批量正则化层中的第一参数对第二批量正则化层中的第一参数进行赋值,并使用第一批量正则化层中的第二参数对第一批量正则化层中的第二参数进行赋值,从而有效地改善第一批量正则化层对应模块的通道数量与第二批量正则化层对应模块的通道数量不一致的问题。
可选地,在本申请实施例中,使用第一网络模型对初始化后的第二网络模型进行蒸馏训练,包括:根据第一批量正则化层对应的特征值和第二批量正则化层对应的特征值构建蒸馏损失函数,蒸馏损失函数表征第一网络模型和第二网络模型的蒸馏损失,第一批量正则化层对应的特征值和第二批量正则化层对应的特征值均在批量正则化层之后且在激活函数计算之前的特征值;根据第一网络模型的分类损失函数和蒸馏损失函数对初始化后的第二网络模型进行蒸馏训练,分类损失函数表征第一网络模型对输入数据进行预测的分类标签与训练标签的分类任务损失。在上述的实现过程中,通过根据第一批量正则化层对应的特征值和第二批量正则化层对应的特征值构建蒸馏损失函数,蒸馏损失函数表征第一网络模型和第二网络模型的蒸馏损失,第一批量正则化层对应的特征值和第二批量正则化层对应的特征值均在批量正则化层之后且在激活函数计算之前的特征值;根据第一网络模型的分类损失函数和蒸馏损失函数对初始化后的第二网络模型进行蒸馏训练,从而极大地提升了模型蒸馏的有效性,即在保证简化神经网络模型的性能几乎不受影响的情况下,快速有效地将复杂神经网络模型的参数直接迁移到简化神经网络模型上。
可选地,在本申请实施例中,在获得蒸馏训练后的第二网络模型之后,还包括:对待预测图像进行归一化处理,获得归一化后的图像;使用蒸馏训练后的第二网络模型对正则化后的图像进行预测,获得待预测图像对应的预测结果。在上述的实现过程中,通过对待预测图像进行归一化处理,获得归一化后的图像;使用蒸馏训练后的第二网络模型对正则化后的图像进行预测,获得待预测图像对应的预测结果;也就是说,通过使用蒸馏训练后的第二网络模型预测待预测图像,从而有效地提高了获得待预测图像对应的预测结果的准确率。
本申请实施例还提供了一种模型蒸馏装置,包括:模型获得模块,用于获得预先训练的第一网络模型和未经训练的第二网络模型,第一网络模型的网络参数多于第二网络模型的网络参数;参数提取模块,用于从第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,第一参数和第二参数均是可学习的,第一参数影响网络模型的特征分布的方差,第二参数影响网络模型的特征分布的均值;层初始化模块,用于根据第一参数和第二参数对第二网络模型中的第二批量正则化层进行初始化,获得初始化后的第二网络模型;蒸馏训练模块,用于使用第一网络模型对初始化后的第二网络模型进行蒸馏训练,获得蒸馏训练后的第二网络模型。
可选地,在本申请实施例中,模型获得模块,包括:图像标签获得模块,用于获得多个训练图像和多个训练图像对应的标签表,多个训练图像包括目标对象的原始图像和对原始图像进行空间变换获得的变换图像,标签表包括目标对象的至少一个标签;神经网络训练模块,用于以多个训练图像为训练数据,以多个训练图像对应的标签表为训练标签,对预先构建的第一神经网络进行训练,获得训练后的第一网络模型。
可选地,在本申请实施例中,模型蒸馏装置,还包括:分类网络获得模块,用于获得分类神经网络,分类神经网络包括:特征识别网络和归一化指数层;特征网络获得模块,用于从分类神经网络中删除归一化指数层,获得特征识别网络;神经网络构建模块,用于根据特征识别网络和全连接层构建第一神经网络。
可选地,在本申请实施例中,参数提取模块,包括:正则化层筛选模块,用于从第一网络模型中的多个批量正则化层筛选出至少一个第一批量正则化层,第一批量正则化层为跨步卷积计算之前的正则化层;正则化层提取模块,用于从第一批量正则化层中提取出第一参数和第二参数。
可选地,在本申请实施例中,层初始化模块,包括:通道数量判断模块,用于判断第一批量正则化层对应模块的通道数量是否大于第二批量正则化层对应模块的通道数量;正则化层赋值模块,用于若第一批量正则化层对应模块的通道数量大于第二批量正则化层对应模块的通道数量,则使用第一批量正则化层中的第一参数对第二批量正则化层中的第一参数进行赋值,并使用第一批量正则化层中的第二参数对第一批量正则化层中的第二参数进行赋值。
可选地,在本申请实施例中,蒸馏训练模块,包括:损失函数构建模块,用于根据第一批量正则化层对应的特征值和第二批量正则化层对应的特征值构建蒸馏损失函数,蒸馏损失函数表征第一网络模型和第二网络模型的蒸馏损失,第一批量正则化层对应的特征值和第二批量正则化层对应的特征值均在批量正则化层之后且在激活函数计算之前的特征值;模型蒸馏训练模块,用于根据第一网络模型的分类损失函数和蒸馏损失函数对初始化后的第二网络模型进行蒸馏训练,分类损失函数表征第一网络模型对输入数据进行预测的分类标签与训练标签的分类任务损失。
可选地,在本申请实施例中,模型蒸馏装置,还包括:图像正则处理模块,用于对待预测图像进行归一化处理,获得归一化后的图像;预测结果获得模块,用于使用蒸馏训练后的第二网络模型对正则化后的图像进行预测,获得待预测图像对应的预测结果。
本申请实施例还提供了一种电子设备,包括:处理器和存储器,存储器存储有处理器可执行的机器可读指令,机器可读指令被处理器执行时执行如上面描述的方法。
本申请实施例还提供了一种存储介质,该存储介质上存储有计算机程序,该计算机程序被处理器运行时执行如上面描述的方法。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1示出的本申请实施例提供的模型蒸馏方法的流程示意图;
图2示出的本申请实施例提供的从第一网络模型蒸馏至第二网络模型的蒸馏位置示意图;
图3示出的本申请实施例提供的模型蒸馏装置的结构示意图;
图4示出的本申请实施例提供的电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整的描述。
在介绍本申请实施例提供的模型蒸馏方法之前,先介绍本申请实施例所涉及的一些概念:
深度学习(Deep Learning),是机器学习中一种基于对数据进行表征学习的算法,深度学习是机器学习的分支,也是一种以人工神经网络为架构,对数据进行表征学习的算法。
归一化指数(Softmax)层,又被称为归一化指数函数层、softmax分类器、softmax层或Softmax函数,实际上是有限项离散概率分布的梯度对数归一化;在数学中,尤其是概率论和相关领域中,归一化指数函数,或称Softmax函数,是逻辑函数的一种推广;归一化指数函数能将一个含任意实数的K维向量z“压缩”到另一个K维实向量σ(z)中,使得每一个元素的范围都在(0,1)之间,并且所有元素的和为1。
全连接层(Fully Connected Layer,FC),是指将将经过多个卷积层和池化层的图像特征图中的特征进行整合的线性运算单元层。全连接层将卷积层产生的特征图映射成一个固定长度的特征向量,这里的固定长度一般是指输入图像数据集中的图像类别数。
全局平均池化(global average pooling,GAP)层,是指将特征图所有像素值相加求平局的神经网络层,使用GAP计算可以得到一个数值,即用该数值来表示对应特征图。
图像增广,又称扩增训练数据集或图像扩增,是指对现有的训练图像进行图像增强操作,以获得更多的训练图像,图像增强操作具体例如:改变背景颜色或亮度、旋转图像角度或者裁剪图像大小等等。
归一化处理,具体有两种形式,一种是把数变为(0,1)之间的小数,一种是把有量纲表达式变为无量纲表达式;归一化处理是要把需要处理的数据经过处理后限制在需要的一定范围内;归一化的具体作用是归纳统一样本的统计分布性;归一化在0到1之间是统计的概率分布,归一化在某个区间上是统计的坐标分布。
损失函数(loss function),又被称为成本函数,是指一种将一个事件(即在一个样本空间中的一个元素)映射到一个表达与其事件相关的经济成本或机会成本的实数上的一种函数,借此直观表示的一些“成本”与事件的关联;损失函数可以决定训练过程如何来“惩罚”网络的预测结果和真实结果之间的差异,各种不同的损失函数适用于不同类型的任务。
批量正则化(Batch Normalization,BN),又被称为批量归一化,是指在神经网络的计算过程中,对每一批数据进行归一化处理,对于训练中某一个批量(batch)的数据,注意这个数据是可以输入也可以是网络中间的某一层输出。
服务器是指通过网络提供计算服务的设备,服务器例如:x86服务器以及非x86服务器,非x86服务器包括:大型机、小型机和UNIX服务器。当然在具体的实施过程中,上述的服务器可以具体选择大型机或者小型机,这里的小型机是指采用精简指令集计算(ReducedInstruction Set Computing,RISC)、单字长定点指令平均执行速度(MillionInstructions Per Second,MIPS)等专用处理器,主要支持UNIX操作系统的封闭且专用的提供计算服务的设备;这里的大型机,又名大型主机,是指使用专用的处理器指令集、操作系统和应用软件来提供计算服务的设备。
需要说明的是,本申请实施例提供的模型蒸馏方法可以被电子设备执行,这里的电子设备是指具有执行计算机程序功能的设备终端或者上述的服务器,设备终端例如:智能手机、个人电脑(personal computer,PC)、平板电脑、个人数字助理(personal digitalassistant,PDA)、移动上网设备(mobile Internet device,MID)、网络交换机或网络路由器等。
在介绍本申请实施例提供的模型蒸馏方法之前,先介绍该模型蒸馏方法适用的应用场景,这里的应用场景包括但不限于:使用该模型蒸馏方法将一个复杂的机器学习模型的具体能力迁移到另一个简化的机器学习模型上,将模型进行压缩,使得压缩后的模型能够在终端设备上运行,或者提供模型压缩服务等;这里的具体能力例如:对文本内容或图像视频进行分类、图像中的内容识别和自然语言处理处理任务等能力,常见的自然语言处理处理任务例如:命名实体识别和词性标注等等。
请参见图1示出的本申请实施例提供的模型蒸馏方法的流程示意图;该模型蒸馏方法的主要思路为,从复杂的第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,第一参数和第二参数均是可学习的,再根据第一参数和第二参数初始化第二网络模型中的第二批量正则化层,最后使用第一网络模型对初始化后的第二网络模型进行蒸馏训练;也就是说,通过将复杂神经网络模型中能够体现特征分布规律的参数直接赋值给简化神经网络模型,从而极大地提升了模型蒸馏的有效性,即在保证简化神经网络模型的性能几乎不受影响的情况下,快速有效地将复杂神经网络模型的参数直接迁移到简化神经网络模型上,上述的模型蒸馏方法可以包括:
步骤S110:获得预先训练的第一网络模型和未经训练的第二网络模型,第一网络模型的网络参数多于第二网络模型的网络参数。
第一网络模型,是指网络参数多于第二网络模型的网络模型,也可以理解为蒸馏学习中的教师网络(teachernetwork);第一网络模型具体可以是卷积神经网络(Convolutional neural network,CNN)、循环神经网络(Recurrent Neural Network,RNN)和深度神经网络(Deep Neural Networks,DNN)等等,也可以是目标检测网络,这里的目标检测网络具体可以是区域卷积神经网络(Region Convolutional Neural Network,RCNN)等。在具体的实践过程中,可以采用mobilenetv3-large作为教师网络的骨干网,即第一网络模型可以为mobilenetv3-large。
第二网络模型,是指网络参数少于第一网络模型的网络模型,也可以理解为蒸馏学习中的学生网络(student network);第二网络模型具体可以是CNN、RNN或DNN;也可以是目标检测网络,这里的目标检测网络具体可以是区域卷积神经网络。在具体的实践过程中,可以采用mobilenetv3-small作为学生网络的骨干网,即第二网络模型可以为mobilenetv3-small。
在具体的蒸馏学习或者模型蒸馏的过程中,需要在保持第二网络模型的准确率基本和第一网络模型基本不变的情况下,根据第一网络模型的权重参数来训练第二网络模型的权重参数,即将知识从第一网络模型迁移到第二网络模型中,上述的蒸馏学习任务可以选择人体图像多属性分类任务,也就是说,第一网络模型的和第二网络模型均可以是人体图像多属性分类网络模型,即该模型是针对人体图像中的多个属性进行分类。
上述步骤S110中的获得第一网络模型包括:第一网络模型的构建阶段、第一网络模型的训练阶段和第二网络模型的构建阶段;其中,第一网络模型可以使用已经构建好的网络模型,具体例如:从文件系统中获取或者从数据库中获取已经构建好的网络模型,使用浏览器等软件获取互联网上的已经构建好的网络模型,或者使用其它应用程序访问互联网获得已经构建好的网络模型;当然,也可以从头开始构建第一网络模型,那么第一网络模型的构建阶段的实施方式可以包括:
步骤S111:获得分类神经网络,分类神经网络包括:特征识别网络和归一化指数层。
分类神经网络,这里的分类神经网络又被称为多标签分类神经网络(multi-labelclassification neural network),或者多属性分类神经网络,是指对神经网络进行训练后获得的用于对图像的多个属性或者多个标签进行分类的神经网络,即将图像作为图像分类神经网络模型的输入获得该图像对应的多个属性或者多个标签的概率列表。
上述步骤S111的实施方式包括:第一种方式,获取预先存储的分类神经网络,具体例如:从文件系统中获取分类神经网络,或者从数据库中获取分类神经网络;第二种方式,其他终端设备向电子设备发送分类神经网络,然后电子设备接收其他终端设备发送的分类神经网络;第三种方式,使用浏览器等软件获取互联网上的分类神经网络,或者使用其它应用程序访问互联网获得分类神经网络;当然也可以直接将普通的多分类神经网络的尾部改为多属性分类网络模型。
步骤S112:从分类神经网络中删除归一化指数层,获得特征识别网络。
特征识别网络,是指对数据中的特征进行识别的神经网络,这里的神经网络(Neural Network,NN),又被称为类神经网络,在机器学习和认知科学领域,是一种模仿生物神经网络(例如:动物的中枢神经系统,可以是大脑)的结构和功能的数学模型或计算模型,用于对函数进行估计或近似;这里的神经网络由大量的人工神经元联结进行计算。
上述步骤S112的实施方式例如:大部分的普通分类神经网络包括:特征识别网络和归一化指数层;那么从分类神经网络中删除归一化指数层,即可获得特征识别网络。
步骤S113:根据特征识别网络和全连接层构建第一神经网络。
上述步骤S113的实施方式包括两种情况:第一种情况,特征识别网络中包括全局平均池化层的情况,若特征识别网络中包括全局平均池化层,则可以将特征识别网络和全连接层进行连接,获得第一神经网络;第二种情况,特征识别网络中不包括全局平均池化层,若特征识别网络中不包括全局平均池化层,则可以将特征识别网络、全局平均池化层和全连接层进行连接,获得第一神经网络。在具体的实践过程中,对第一网络模型的构建方式又例如:分别在第一网络模型的最后一层卷积层上添加GAP,获得维度为C的一维特征FC;然后分别在教师网络和学生网络的一维特征FC上添加个全连接层,获得第一网络模型。
上述的第一网络模型的训练阶段的实施方式可以包括:
步骤S114:获得多个训练图像和多个训练图像对应的标签表。
训练图像,是指对第一神经网络进行训练的图像,这里的训练图像可以是包括人体的不同属性或者标签的图像,即该训练图像包括多个标签或者属性,这里的多个标签具体例如:第一标签为图像中的人戴了帽子,图像中的人戴帽子的概率为0.7;第二标签为图像中的人穿了皮鞋,图像中的人穿皮鞋的概率为0.9,更多的标签分类概率依此类推;这里的训练图像包括目标对象的原始图像和对原始图像进行空间变换获得的变换图像。
上述步骤S114中的获得多个训练图像的实施方式包括:第一种方式,使用图像采集装置对目标人体进行采集,获得采集的人体图像,将人体图像作为训练图像;第二种方式,使用浏览器等软件获取互联网上的训练图像,或者使用其它应用程序访问互联网获得训练图像,例如可以使用imagenet数据集,或者使用公开数据集Wider Attribute作为训练图像数据集;第三种方式,使用图像增广的方式对已经获得的训练图像进行扩充,从而获得训练图像数据集。
标签表,是指上述训练图像中目标对象对应的多个标签或者多个属性构成的数据表;标签表包括目标对象的至少一个标签。这里的标签表包括目标对象的至少一个标签,上述的标签(label)有时候也被称为属性(attribute),具体的属性例如:戴帽、戴眼镜和穿皮鞋等等。这里的标签表中的标签对应具体值的设置可以根据具体情况进行设置,例如:若在人体图像中的某属性的位置被遮挡或没有被拍摄到,则该属性对应数值可以设置为-1;若图像中存在该属性,则对应数值可以设置为1,若图像中不存在该属性,则可以设置为0;具体以脚上是否穿着皮鞋为例,也就是说,若根本没有拍摄到脚或鞋,那么可以将标签具体值设置为-1,若训练图像中存在穿着皮鞋的脚,那么可以将该标签具体值设置为1,若训练图像中不存在穿着皮鞋的脚,那么可以将该标签具体值设置为0;另外,人体图像数据集中的图像应该包含一个或多个属性的相关区域。
上述步骤S114的实施方式例如:上述的训练图像和标签表可以分开获取,具体例如:人工的搜集训练图像,并人工地识别训练图像的标签表;当然,也可以将训练图像和标签表打包为数据集一起获取,这里以数据集一起获取为例进行说明;上述数据集的获得方式包括:第一种方式,获取预先存储的数据集,具体例如:从文件系统中获取数据集,或者从数据库中获取数据集;第二种方式,其他终端设备向电子设备发送数据集,然后电子设备接收其他终端设备发送的数据集;第三种方式,使用浏览器等软件获取互联网上的数据集,或者使用其它应用程序访问互联网获得数据集。在具体的实施过程中,上述的数据集可以选择公开的数据集,公开的数据集例如:Wider Attribute数据集。
步骤S115:以多个训练图像为训练数据,以多个训练图像对应的标签表为训练标签,对预先构建的第一神经网络进行训练,获得训练后的第一网络模型。
可以理解的是,在对预先构建的第一神经网络进行训练之前,若第一神经网络被预先训练过,那么可以先加载在原数据集上训练过的权重参数,具体例如:若预先使用ImageNet图像数据集对第一神经网络训练过,那么可以加载之前使用ImageNet图像数据集预训练过的权重参数。
上述步骤S115的实施方式可以包括:
首先,将训练数据中的图像分为训练图像和测试图像,训练图像占训练数据中的预设比例,这里的预设比例具体可以为70%,并使用所有训练图像的均值对输入图像进行归一化处理;使用人工打标签的方式或者程序生成的方式构建标签表。
其次,设置训练模型的超参数,这里的训练模型的超参数例如:将训练批量大小(batch size)设置为128;将网络优化器设置为随机梯度下降(stochastic gradientdescent,SGD);将动量(momentum)设置为0.9;将学习率(learn rate)的初始值设置为1e-3,在训练过程中,每增加5个时期(epoch)时,将学习率下降到原来的10%。与此同时,为了减少网络过拟合的概率或可能性,可以将权重衰减(weight decay)设置为1e-4。
然后,设置训练模型的损失函数,当然,训练模型的超参数和损失函数可以同时设置,也可以不分先后的设置,可以根据具体情况先后设置或者同时设置。这里的损失函数具体可以采用交叉熵损失函数,或者在训练模型的过程中的损失函数设置为自定义损失函数,该函数可以使用公式表示为:
其中,lc表示为上述的第一损失函数,即多标签图像分类损失函数,N为图像的具体数量,i表示N张图像中的第i个图像,L为标签的具体数量,即属性的具体数量,j表示L个标签中的第j个标签,xij∈R是第i张图像的第j个属性对应的逻辑斯特值,xij之后会被归一化,yij∈{0,1}表示第i张图像第j个标签对应的具体值,pj是训练图像集中的第j个属性正样本所占的比例,这里的pj是用来定义权重矩阵wij的,即根据属性正样本所占的比例来设置具体的权重矩阵,可以有效改善正负样本不均衡的问题。这里的逻辑斯特(logits)值是指模型中的未经过激活函数运算的参数值,这里的激活函数例如:sigmoid激活函数或者softmax激活函数,这里的逻辑斯特值可以理解为与标签具体值或者属性具体值正相关的参数值;其中,多个逻辑斯特值与多个训练图像是逐一对应的,即一个逻辑斯特值对应一个训练图像。
最后,以多个训练图像为训练数据,以多个训练图像对应的标签表为训练标签,对预先构建的第一神经网络中的网络参数进行迭代训练,获得训练后的第一网络模型。
在具体的实践过程中,对第二网络模型的构建方式例如:分别在第二网络模型的最后一层卷积层上添加GAP,获得维度为C的一维特征FC;然后分别在教师网络和学生网络的一维特征FC上添加个全连接层,获得第二网络模型。可以理解的是,上述的第二网络模型的构建阶段的实施方式与第一网络模型的构建阶段的实施方式是相似或类似的,因此,这里不再对该步骤的实施方式和实施原理进行说明,如有不清楚的地方,可以参考对步骤S111至步骤S113的描述。
在步骤S110之后,执行步骤S120:从第一网络模型中的第一批量正则化层中提取出第一参数和第二参数。
第一参数和第二参数,是指BN层中直接控制着特征分布的两个参数:第一参数和第二参数均是可学习的,第一参数可以使用α表示,α影响网络模型的特征分布的方差;第二参数可以使用β表示,β影响网络模型的特征分布的均值。需要说明的是,这里的两个参数α和β是BN层的可学习参数,就像卷积层中的权重一样,这里的α和β也可以被学习更新,不过每个通道有一个对应的α和β,其中,α影响特征分布的方差,β影响均值。
第一批量正则化层,是指在第一网络模型中的批量正则化层,即教师网络模型中的BN层,这里的BN层主要就是,首先将特征按通道归一化,特征在每个通道上的均值为0,方差为1,然后再通过α和β对均值和方差进行修正。上述的教师网络模型中的骨干网中的部分中间层或者中间模块可以使用It来表示,学生网络模型骨干网中的部分中间层或者中间模块可以使用Is来表示;因此,上述的从第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,可以理解为,将It中的BN层里的α参数和β参数提取出来,这里的α参数和β参数用于初始化Is中的BN层中对应的参数。
在上述的步骤S120中的提取参数之前,还可以对正则化层进行筛选,这里的对正则化层进行筛选的实施方式包括:
步骤S121:从第一网络模型中的多个批量正则化层筛选出至少一个第一批量正则化层,第一批量正则化层为跨步卷积计算之前的正则化层。
请参见图2示出的本申请实施例提供的从第一网络模型蒸馏至第二网络模型的蒸馏位置示意图;图中使用虚线连接的表示要进行蒸馏学习的模块(block),这里的模块(block)是一系列层构成的,一般的网络的模块会由卷积层、BN层和激活函数按顺序堆积而来,可能堆叠一次也可能堆叠多次。上述的模块可以是降采样模块,图2中除降采样模块以外均为普通模块,这里的降采样模块和普通模块均可以采用倒置残差模块,也可以不采用倒置残差模块;其中,这里的倒置残差模块用于先将特征压缩到一个比较小的通道数,然后进行特征提取,最后再将通道数扩张到比较大的维度。从图2中可以看出第一网络模型包括15个模块(Block),第二网络模型包括11个模块,这15个模块和11个模块均可以是倒置残差模块,这里的15个模块和11个模块即可以包括挤压和膨胀(squeeze and expand)结构,也可以不包括挤压和膨胀(squeeze and expand)结构,但是,这里的15个模块中的每个模块和11个模块中的每个模块均包括批量正则化(BN)层。
上述步骤S121的实施方式包括:为了便于说明和理解,这里以第一网络模型为mobilenetv3-large,且第二网络模型为mobilenetv3-small为例进行说明,一共有四对模块需要进行蒸馏学习,这四对模块分别是,第一网络模型中的第一模块(Block1)、第四模块(Block4)、第七模块(Block7)和第十四模块(Block14),以及第二网络模型中的第一模块(Block1)、第二模块(Block2)、第四模块(Block4)和第九模块(Block9);那么可以从第一网络模型中的所有模块对应的批量正则化层筛选出第一模块、第四模块、第七模块和第十四模块对应的批量正则化层,这里的第一模块、第四模块、第七模块和第十四模块对应的批量正则化层均为跨步卷积计算之前的批量正则化层,因此,可以将第一模块、第四模块、第七模块和第十四模块对应的批量正则化层中的任一批量正则化层确定为第一批量正则化层,可以将第一模块、第二模块、第四模块和第九模块对应的批量正则化层中的任一批量正则化层确定为下面的第二批量正则化层,这里的第一批量正则化层对应的模块可以使用It表示,下面的第二批量正则化层对应的模块可以使用Is表示。
步骤S122:从第一批量正则化层中提取出第一参数和第二参数。
上述步骤S122的实施方式包括:从It中的批量正则化层中提取出第一参数α和第二参数β,第一参数α可以影响网络模型的特征分布的方差;第二参数β可以影响网络模型的特征分布的均值。
在步骤S120之后,执行步骤S130:根据第一参数和第二参数对第二网络模型中的第二批量正则化层进行初始化,获得初始化后的第二网络模型。
上述步骤S130中的根据第一参数和第二参数对第二网络模型中的第二批量正则化层进行初始化的实施方式包括:
步骤S131:判断第一批量正则化层对应模块的通道数量是否大于第二批量正则化层对应模块的通道数量。
步骤S132:若第一批量正则化层对应模块的通道数量大于第二批量正则化层对应模块的通道数量,则使用第一批量正则化层中的第一参数对第二批量正则化层中的第一参数进行赋值,并使用第一批量正则化层中的第二参数对第一批量正则化层中的第二参数进行赋值。
上述步骤S131至步骤S132的实施方式例如:若出现第一批量正则化层对应的模块通道数与第二批量正则化层对应的模块通道数不一致的情况,即教师网络与学生网络中对应的要用于蒸馏的模块对儿出现通道数不一致的情况时,可以在教师网络中挑选靠前的对应通道数的参数来完成学生网络BN层的初始化。也就是说,若第一批量正则化层对应模块的通道数量大于第二批量正则化层对应模块的通道数量,则使用第一批量正则化层中的第一参数对第二批量正则化层中的第一参数进行赋值,并使用第一批量正则化层中的第二参数对第一批量正则化层中的第二参数进行赋值。
上述步骤S130的实施方式具体例如:若教师网络的其中一个模块(block)中的通道数为m,学生网络对应的模块的通道数为n,这里的m>n,那么教师网络在该模块中包括m个α,这里的m个α分别表示为而学生网络在该模块中包含n个α,这里的m个α分别表示为那么可以选择教师网络中的来初始化学生网络对应参数,这里的学生网络对应的参数是指同理地,对于β参数也是类似的方法来初始化。在上述的实现过程中,若第一批量正则化层对应模块的通道数量大于第二批量正则化层对应模块的通道数量,则使用第一批量正则化层中的第一参数对第二批量正则化层中的第一参数进行赋值,并使用第一批量正则化层中的第二参数对第一批量正则化层中的第二参数进行赋值,从而有效地改善第一批量正则化层对应模块的通道数量与第二批量正则化层对应模块的通道数量不一致的问题。
在步骤S130之后,执行步骤S140:使用第一网络模型对初始化后的第二网络模型进行蒸馏训练,获得蒸馏训练后的第二网络模型。
可以理解的是,在对预先构建的第二神经网络进行训练之前,若第二神经网络被预先训练过,那么可以先加载在原数据集上训练过的权重参数,具体例如:若预先使用ImageNet图像数据集对第二神经网络训练过,那么可以加载之前使用ImageNet图像数据集预训练过的权重参数。上述的步骤S140的实施方式包括:
步骤S141:根据第一批量正则化层对应的特征值和第二批量正则化层对应的特征值构建蒸馏损失函数,第一批量正则化层对应的特征值和第二批量正则化层对应的特征值均在批量正则化层之后且在激活函数计算之前的特征值。
蒸馏损失函数,是指表征第一网络模型和第二网络模型的蒸馏损失,这里的蒸馏损失函数具体可以表示为:其中,LD表示第一网络模型和第二网络模型的蒸馏损失,LT表示人体图像多属性分类任务损失,α是一个平衡损失数量级的超参数,具体地可以将α设置为30;LD可以是简单的平方差损失函数,I表示输入图像,T(I)和S(I)表示对应是模块(可以是倒置残差模块)中在第一个BN层后且在激活函数之前的特征值,L表示总损失函数,总损失函数表征第一网络模型和第二网络模型的蒸馏损失,以及人体图像多属性分类任务损失的加权和。
步骤S142:根据第一网络模型的分类损失函数和蒸馏损失函数对初始化后的第二网络模型进行蒸馏训练。
分类损失函数,是指表征第一网络模型对输入数据进行预测的分类标签与训练标签的分类任务损失的函数,这里的分类损失函数具体可以为人体图像多属性分类任务损失,即对人体图像的多个标签或者多个属性进行分类的任务损失。
上述步骤S142的实施方式包括:设置蒸馏训练过程中的超参数,这里的超参数例如:将训练批量大小(batch size)设置为128;将网络优化器设置为随机梯度下降(stochastic gradient descent,SGD);将动量(momentum)设置为0.9;将学习率(learnrate)的初始值设置为1e-3,在训练过程中,每增加5个时期(epoch)时,将学习率下降到原来的10%。与此同时,为了减少网络过拟合的概率或可能性,可以将权重衰减(weightdecay)设置为1e-4。根据第一网络模型的分类损失函数和蒸馏损失函数计算出总损失函数,然后根据总损失函数对初始化后的第二网络模型进行蒸馏训练;这里的总损失函数表征第一网络模型和第二网络模型的蒸馏损失,以及分类任务损失的加权和。
需要强调的是,在使用第一网络模型作为教师网络对第二网络模型作为学生网络进行蒸馏训练的时候,为了让蒸馏训练达到更好地效果,或者说更好地对第二网络模型进行蒸馏训练,即使用更少的时间让第二网络模型更快地蒸馏训练完成,以使得第二网络模型的性能(指对输入数据处理的正确率)达到几乎和第一网络模型的性能一样的水平,可以将第二网络模型的部分参数冻结,这里的部分参数冻结具体例如:将学生网络要进行蒸馏的残差对儿(即需要蒸馏训练一对倒置残差模块)中的BN层中的α和β的学习率设置为其他参数学习率的预设比例,这里的预设比例具体可以是1/10,当然也可以根据具体情况进行设置,预设比例例如可以为0.01、0.03、0.13、0.25和0.33等等,也就是说,为了更好的蒸馏训练,可以把学生网络这部分的参数设置为其他参数的1/10,这是为了尽量去保留这部分信息不受到太大的影响。
在上述的实现过程中,从复杂的第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,第一参数和第二参数均是可学习的,再根据第一参数和第二参数初始化第二网络模型中的第二批量正则化层,最后使用第一网络模型对初始化后的第二网络模型进行蒸馏训练;也就是说,通过将复杂神经网络模型中的特征分布规律直接赋值给简化神经网络模型,从而极大地提升了模型蒸馏的有效性,即在保证简化神经网络模型的性能几乎不受影响的情况下,快速有效地将复杂神经网络模型的参数直接迁移到简化神经网络模型上。
可选地,在本申请实施例中,在获得蒸馏训练后的第二网络模型之后,还可以使用蒸馏训练后的第二网络模型对图像进行预测,这里的对图像进行预测的过程可以包括:
在步骤S140之后,执行步骤S150:对待预测图像进行归一化处理,获得归一化后的图像。
在步骤S150之后,执行步骤S160:使用蒸馏训练后的第二网络模型对正则化后的图像进行预测,获得待预测图像对应的预测结果。
上述步骤S150至步骤S160的实施方式例如:用所有训练图像的均值对输入的训练图像进行归一化处理,获得归一化后的图像;利用训练好的学生网络对归一化后的图像进行预测,得到该归一化后的图像各个属性的概率值,可以将这里的归一化后的图像各个属性的概率值确定为待预测图像对应的预测结果。
在上述的实现过程中,通过对待预测图像进行归一化处理,获得归一化后的图像;使用蒸馏训练后的第二网络模型对正则化后的图像进行预测,获得待预测图像对应的预测结果;也就是说,通过使用蒸馏训练后的第二网络模型预测待预测图像,从而有效地提高了获得待预测图像对应的预测结果的准确率。
请参见图3示出的本申请实施例提供的模型蒸馏装置的结构示意图;本申请实施例提供了一种模型蒸馏装置200,包括:
模型获得模块210,用于获得预先训练的第一网络模型和未经训练的第二网络模型,第一网络模型的网络参数多于第二网络模型的网络参数。
参数提取模块220,用于从第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,第一参数影响网络模型的特征分布的方差,第二参数影响网络模型的特征分布的均值。
层初始化模块230,用于根据第一参数和第二参数对第二网络模型中的第二批量正则化层进行初始化,获得初始化后的第二网络模型。
蒸馏训练模块240,用于使用第一网络模型对初始化后的第二网络模型进行蒸馏训练,获得蒸馏训练后的第二网络模型。
可选地,在本申请实施例中,模型获得模块,包括:
图像标签获得模块,用于获得多个训练图像和多个训练图像对应的标签表,多个训练图像包括目标对象的原始图像和对原始图像进行空间变换获得的变换图像,标签表包括目标对象的至少一个标签。
神经网络训练模块,用于以多个训练图像为训练数据,以多个训练图像对应的标签表为训练标签,对预先构建的第一神经网络进行训练,获得训练后的第一网络模型。
可选地,在本申请实施例中,模型蒸馏装置,还包括:
分类网络获得模块,用于获得分类神经网络,分类神经网络包括:特征识别网络和归一化指数层。
特征网络获得模块,用于从分类神经网络中删除归一化指数层,获得特征识别网络。
神经网络构建模块,用于根据特征识别网络和全连接层构建第一神经网络。
可选地,在本申请实施例中,参数提取模块,包括:
正则化层筛选模块,用于从第一网络模型中的多个批量正则化层筛选出至少一个第一批量正则化层,第一批量正则化层为跨步卷积计算之前的正则化层。
正则化层提取模块,用于从第一批量正则化层中提取出第一参数和第二参数。
可选地,在本申请实施例中,层初始化模块,包括:
通道数量判断模块,用于判断第一批量正则化层对应模块的通道数量是否大于第二批量正则化层对应模块的通道数量。
正则化层赋值模块,用于若第一批量正则化层对应模块的通道数量大于第二批量正则化层对应模块的通道数量,则使用第一批量正则化层中的第一参数对第二批量正则化层中的第一参数进行赋值,并使用第一批量正则化层中的第二参数对第一批量正则化层中的第二参数进行赋值。
可选地,在本申请实施例中,蒸馏训练模块,包括:
损失函数构建模块,用于根据第一批量正则化层对应的特征值和第二批量正则化层对应的特征值构建蒸馏损失函数,蒸馏损失函数表征第一网络模型和第二网络模型的蒸馏损失,第一批量正则化层对应的特征值和第二批量正则化层对应的特征值均在批量正则化层之后且在激活函数计算之前的特征值。
模型蒸馏训练模块,用于根据第一网络模型的分类损失函数和蒸馏损失函数对初始化后的第二网络模型进行蒸馏训练,分类损失函数表征第一网络模型对输入数据进行预测的分类标签与训练标签的分类任务损失。
可选地,在本申请实施例中,模型蒸馏装置,还可以包括:
图像正则处理模块,用于对待预测图像进行归一化处理,获得归一化后的图像。
预测结果获得模块,用于使用蒸馏训练后的第二网络模型对正则化后的图像进行预测,获得待预测图像对应的预测结果。
应理解的是,该装置与上述的模型蒸馏方法实施例对应,能够执行上述方法实施例涉及的各个步骤,该装置具体的功能可以参见上文中的描述,为避免重复,此处适当省略详细描述。该装置包括至少一个能以软件或固件(firmware)的形式存储于存储器中或固化在装置的操作系统(operating system,OS)中的软件功能模块。
请参见图4示出的本申请实施例提供的电子设备的结构示意图。本申请实施例提供的一种电子设备300,包括:处理器310和存储器320,存储器320存储有处理器310可执行的机器可读指令,机器可读指令被处理器310执行时执行如上的方法。
本申请实施例还提供了一种存储介质330,该存储介质330上存储有计算机程序,该计算机程序被处理器310运行时执行如上的方法。
其中,存储介质330可以由任何类型的易失性或非易失性存储设备或者它们的组合实现,如静态随机存取存储器(Static Random Access Memory,简称SRAM),电可擦除可编程只读存储器(Electrically Erasable Programmable Read-Only Memory,简称EEPROM),可擦除可编程只读存储器(Erasable Programmable Read Only Memory,简称EPROM),可编程只读存储器(Programmable Red-Only Memory,简称PROM),只读存储器(Read-Only Memory,简称ROM),磁存储器,快闪存储器,磁盘或光盘。
本申请实施例所提供的几个实施例中,应该理解到,所揭露的装置和方法,也可以通过其他的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,附图中的流程图和框图显示了根据本申请实施例的多个实施例的装置、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现方式中,方框中所标注的功能也可以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。
以上的描述,仅为本申请实施例的可选实施方式,但本申请实施例的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请实施例揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请实施例的保护范围之内。
Claims (10)
1.一种模型蒸馏方法,其特征在于,包括:
获得预先训练的第一网络模型和未经训练的第二网络模型,所述第一网络模型的网络参数多于所述第二网络模型的网络参数;
从所述第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,所述第一参数影响网络模型的特征分布的方差,所述第二参数影响网络模型的特征分布的均值;
根据所述第一参数和所述第二参数对所述第二网络模型中的第二批量正则化层进行初始化,获得初始化后的第二网络模型;
使用所述第一网络模型对所述初始化后的第二网络模型进行蒸馏训练,获得蒸馏训练后的第二网络模型。
2.根据权利要求1所述的方法,其特征在于,所述获得预先训练的第一网络模型,包括:
获得多个训练图像和所述多个训练图像对应的标签表,所述多个训练图像包括目标对象的原始图像和对所述原始图像进行空间变换获得的变换图像,所述标签表包括所述目标对象的至少一个标签;
以所述多个训练图像为训练数据,以所述多个训练图像对应的标签表为训练标签,对预先构建的第一神经网络进行训练,获得训练后的所述第一网络模型。
3.根据权利要求2所述的方法,其特征在于,在所述对预先构建的第一神经网络进行训练之前,还包括:
获得分类神经网络,所述分类神经网络包括:特征识别网络和归一化指数层;
从所述分类神经网络中删除所述归一化指数层,获得所述特征识别网络;
根据所述特征识别网络和全连接层构建所述第一神经网络。
4.根据权利要求1所述的方法,其特征在于,所述从所述第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,包括:
从所述第一网络模型中的多个批量正则化层筛选出至少一个第一批量正则化层,所述第一批量正则化层为跨步卷积计算之前的正则化层;
从所述第一批量正则化层中提取出所述第一参数和所述第二参数。
5.根据权利要求4所述的方法,其特征在于,所述根据所述第一参数和所述第二参数对所述第二网络模型中的第二批量正则化层进行初始化,包括:
判断所述第一批量正则化层对应模块的通道数量是否大于所述第二批量正则化层对应模块的通道数量;
若是,则使用所述第一批量正则化层中的第一参数对所述第二批量正则化层中的第一参数进行赋值,并使用所述第一批量正则化层中的第二参数对所述第一批量正则化层中的第二参数进行赋值。
6.根据权利要求5所述的方法,其特征在于,所述使用所述第一网络模型对所述初始化后的第二网络模型进行蒸馏训练,包括:
根据所述第一批量正则化层对应的特征值和所述第二批量正则化层对应的特征值构建蒸馏损失函数,所述蒸馏损失函数表征所述第一网络模型和所述第二网络模型的蒸馏损失,所述第一批量正则化层对应的特征值和所述第二批量正则化层对应的特征值均在批量正则化层之后且在激活函数计算之前的特征值;
根据所述第一网络模型的分类损失函数和所述蒸馏损失函数对所述初始化后的第二网络模型进行蒸馏训练,所述分类损失函数表征所述第一网络模型对输入数据进行预测的分类标签与训练标签的分类任务损失。
7.根据权利要求1-6任一所述的方法,其特征在于,在所述获得蒸馏训练后的第二网络模型之后,还包括:
对待预测图像进行归一化处理,获得归一化后的图像;
使用所述蒸馏训练后的第二网络模型对所述正则化后的图像进行预测,获得所述待预测图像对应的预测结果。
8.一种模型蒸馏装置,其特征在于,包括:
模型获得模块,用于获得预先训练的第一网络模型和未经训练的第二网络模型,所述第一网络模型的网络参数多于所述第二网络模型的网络参数;
参数提取模块,用于从所述第一网络模型中的第一批量正则化层中提取出第一参数和第二参数,所述第一参数影响网络模型的特征分布的方差,所述第二参数影响网络模型的特征分布的均值;
层初始化模块,用于根据所述第一参数和所述第二参数对所述第二网络模型中的第二批量正则化层进行初始化,获得初始化后的第二网络模型;
蒸馏训练模块,用于使用所述第一网络模型对所述初始化后的第二网络模型进行蒸馏训练,获得蒸馏训练后的第二网络模型。
9.一种电子设备,其特征在于,包括:处理器和存储器,所述存储器存储有所述处理器可执行的机器可读指令,所述机器可读指令被所述处理器执行时执行如权利要求1至7任一所述的方法。
10.一种存储介质,其特征在于,该存储介质上存储有计算机程序,该计算机程序被处理器运行时执行如权利要求1至7任一所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010607520.4A CN111738436B (zh) | 2020-06-28 | 2020-06-28 | 一种模型蒸馏方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010607520.4A CN111738436B (zh) | 2020-06-28 | 2020-06-28 | 一种模型蒸馏方法、装置、电子设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111738436A true CN111738436A (zh) | 2020-10-02 |
CN111738436B CN111738436B (zh) | 2023-07-18 |
Family
ID=72653500
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010607520.4A Active CN111738436B (zh) | 2020-06-28 | 2020-06-28 | 一种模型蒸馏方法、装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111738436B (zh) |
Cited By (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112101484A (zh) * | 2020-11-10 | 2020-12-18 | 中国科学院自动化研究所 | 基于知识巩固的增量事件识别方法、系统、装置 |
CN112184508A (zh) * | 2020-10-13 | 2021-01-05 | 上海依图网络科技有限公司 | 一种用于图像处理的学生模型的训练方法及装置 |
CN112329467A (zh) * | 2020-11-03 | 2021-02-05 | 腾讯科技(深圳)有限公司 | 地址识别方法、装置、电子设备以及存储介质 |
CN112613312A (zh) * | 2020-12-18 | 2021-04-06 | 平安科技(深圳)有限公司 | 实体命名识别模型的训练方法、装置、设备及存储介质 |
CN112766463A (zh) * | 2021-01-25 | 2021-05-07 | 上海有个机器人有限公司 | 基于知识蒸馏技术优化神经网络模型的方法 |
CN112949433A (zh) * | 2021-02-18 | 2021-06-11 | 北京百度网讯科技有限公司 | 视频分类模型的生成方法、装置、设备和存储介质 |
CN113762368A (zh) * | 2021-08-27 | 2021-12-07 | 北京市商汤科技开发有限公司 | 数据蒸馏的方法、装置、电子设备和存储介质 |
CN113919444A (zh) * | 2021-11-10 | 2022-01-11 | 北京市商汤科技开发有限公司 | 目标检测网络的训练方法、目标检测方法及装置 |
CN114359649A (zh) * | 2021-11-22 | 2022-04-15 | 腾讯科技(深圳)有限公司 | 图像处理方法、装置、设备、存储介质及程序产品 |
CN114581946A (zh) * | 2022-02-25 | 2022-06-03 | 江西师范大学 | 人群计数方法、装置、存储介质及电子设备 |
CN114973156A (zh) * | 2022-08-02 | 2022-08-30 | 松立控股集团股份有限公司 | 一种基于知识蒸馏的夜间渣土车检测方法 |
US20220366226A1 (en) * | 2021-05-17 | 2022-11-17 | Marziehsadat TAHAEI | Methods and systems for compressing a trained neural network and for improving efficiently performing computations of a compressed neural network |
Citations (15)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108875787A (zh) * | 2018-05-23 | 2018-11-23 | 北京市商汤科技开发有限公司 | 一种图像识别方法及装置、计算机设备和存储介质 |
CN109447146A (zh) * | 2018-10-24 | 2019-03-08 | 厦门美图之家科技有限公司 | 分类优化方法及装置 |
CN110059672A (zh) * | 2019-04-30 | 2019-07-26 | 福州大学 | 一种利用增量学习对显微镜细胞图像检测模型进行增类学习的方法 |
CN110059747A (zh) * | 2019-04-18 | 2019-07-26 | 清华大学深圳研究生院 | 一种网络流量分类方法 |
CN110147836A (zh) * | 2019-05-13 | 2019-08-20 | 腾讯科技(深圳)有限公司 | 模型训练方法、装置、终端及存储介质 |
CN110163234A (zh) * | 2018-10-10 | 2019-08-23 | 腾讯科技(深圳)有限公司 | 一种模型训练方法、装置和存储介质 |
US20190325269A1 (en) * | 2018-04-20 | 2019-10-24 | XNOR.ai, Inc. | Image Classification through Label Progression |
WO2019222401A2 (en) * | 2018-05-17 | 2019-11-21 | Magic Leap, Inc. | Gradient adversarial training of neural networks |
CN111047054A (zh) * | 2019-12-13 | 2020-04-21 | 浙江科技学院 | 一种基于两阶段对抗知识迁移的对抗样例防御方法 |
CN111126573A (zh) * | 2019-12-27 | 2020-05-08 | 深圳力维智联技术有限公司 | 基于个体学习的模型蒸馏改进方法、设备及存储介质 |
CN111144496A (zh) * | 2019-12-27 | 2020-05-12 | 齐齐哈尔大学 | 一种基于混合卷积神经网络的垃圾分类方法 |
CN111242900A (zh) * | 2019-12-31 | 2020-06-05 | 电子科技大学中山学院 | 一种产品合格确定方法、装置、电子设备及存储介质 |
CN111275190A (zh) * | 2020-02-25 | 2020-06-12 | 北京百度网讯科技有限公司 | 神经网络模型的压缩方法及装置、图像处理方法及处理器 |
CN111310684A (zh) * | 2020-02-24 | 2020-06-19 | 东声(苏州)智能科技有限公司 | 一种模型训练方法、装置、电子设备及存储介质 |
CN111337768A (zh) * | 2020-03-02 | 2020-06-26 | 武汉大学 | 变压器油中溶解气体的深度并行故障诊断方法及系统 |
-
2020
- 2020-06-28 CN CN202010607520.4A patent/CN111738436B/zh active Active
Patent Citations (15)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20190325269A1 (en) * | 2018-04-20 | 2019-10-24 | XNOR.ai, Inc. | Image Classification through Label Progression |
WO2019222401A2 (en) * | 2018-05-17 | 2019-11-21 | Magic Leap, Inc. | Gradient adversarial training of neural networks |
CN108875787A (zh) * | 2018-05-23 | 2018-11-23 | 北京市商汤科技开发有限公司 | 一种图像识别方法及装置、计算机设备和存储介质 |
CN110163234A (zh) * | 2018-10-10 | 2019-08-23 | 腾讯科技(深圳)有限公司 | 一种模型训练方法、装置和存储介质 |
CN109447146A (zh) * | 2018-10-24 | 2019-03-08 | 厦门美图之家科技有限公司 | 分类优化方法及装置 |
CN110059747A (zh) * | 2019-04-18 | 2019-07-26 | 清华大学深圳研究生院 | 一种网络流量分类方法 |
CN110059672A (zh) * | 2019-04-30 | 2019-07-26 | 福州大学 | 一种利用增量学习对显微镜细胞图像检测模型进行增类学习的方法 |
CN110147836A (zh) * | 2019-05-13 | 2019-08-20 | 腾讯科技(深圳)有限公司 | 模型训练方法、装置、终端及存储介质 |
CN111047054A (zh) * | 2019-12-13 | 2020-04-21 | 浙江科技学院 | 一种基于两阶段对抗知识迁移的对抗样例防御方法 |
CN111126573A (zh) * | 2019-12-27 | 2020-05-08 | 深圳力维智联技术有限公司 | 基于个体学习的模型蒸馏改进方法、设备及存储介质 |
CN111144496A (zh) * | 2019-12-27 | 2020-05-12 | 齐齐哈尔大学 | 一种基于混合卷积神经网络的垃圾分类方法 |
CN111242900A (zh) * | 2019-12-31 | 2020-06-05 | 电子科技大学中山学院 | 一种产品合格确定方法、装置、电子设备及存储介质 |
CN111310684A (zh) * | 2020-02-24 | 2020-06-19 | 东声(苏州)智能科技有限公司 | 一种模型训练方法、装置、电子设备及存储介质 |
CN111275190A (zh) * | 2020-02-25 | 2020-06-12 | 北京百度网讯科技有限公司 | 神经网络模型的压缩方法及装置、图像处理方法及处理器 |
CN111337768A (zh) * | 2020-03-02 | 2020-06-26 | 武汉大学 | 变压器油中溶解气体的深度并行故障诊断方法及系统 |
Non-Patent Citations (5)
Title |
---|
YONGCHENG LIU 等: "Multi-Label Image Classification via Knowledge Distillation from Weakly-Supervised Detection", 《ARXIV:1809.05884》, pages 1 - 9 * |
候卫东: "面向移动应用的人体图像多属性分类算法研究", 《中国优秀硕士学位论文全文数据库 信息科技辑》, pages 138 - 1256 * |
方东祥: "视觉群体感知应用中的视觉隐私保护方法研究", 《中国优秀硕士学位论文全文数据库 信息科技辑》, pages 138 - 233 * |
王世豪: "基于卷积神经网络的目标检测算法其增量学习研究", 《中国优秀硕士学位论文全文数据库 信息科技辑》, pages 138 - 653 * |
蒋树林: "基于分类器逆向学习的最小代价检测规避方法研究", 《中国优秀硕士学位论文全文数据库 信息科技辑》, pages 140 - 367 * |
Cited By (17)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112184508A (zh) * | 2020-10-13 | 2021-01-05 | 上海依图网络科技有限公司 | 一种用于图像处理的学生模型的训练方法及装置 |
CN112184508B (zh) * | 2020-10-13 | 2021-04-27 | 上海依图网络科技有限公司 | 一种用于图像处理的学生模型的训练方法及装置 |
CN112329467A (zh) * | 2020-11-03 | 2021-02-05 | 腾讯科技(深圳)有限公司 | 地址识别方法、装置、电子设备以及存储介质 |
CN112101484B (zh) * | 2020-11-10 | 2021-02-12 | 中国科学院自动化研究所 | 基于知识巩固的增量事件识别方法、系统、装置 |
CN112101484A (zh) * | 2020-11-10 | 2020-12-18 | 中国科学院自动化研究所 | 基于知识巩固的增量事件识别方法、系统、装置 |
CN112613312B (zh) * | 2020-12-18 | 2022-03-18 | 平安科技(深圳)有限公司 | 实体命名识别模型的训练方法、装置、设备及存储介质 |
CN112613312A (zh) * | 2020-12-18 | 2021-04-06 | 平安科技(深圳)有限公司 | 实体命名识别模型的训练方法、装置、设备及存储介质 |
CN112766463A (zh) * | 2021-01-25 | 2021-05-07 | 上海有个机器人有限公司 | 基于知识蒸馏技术优化神经网络模型的方法 |
CN112949433A (zh) * | 2021-02-18 | 2021-06-11 | 北京百度网讯科技有限公司 | 视频分类模型的生成方法、装置、设备和存储介质 |
US20220366226A1 (en) * | 2021-05-17 | 2022-11-17 | Marziehsadat TAHAEI | Methods and systems for compressing a trained neural network and for improving efficiently performing computations of a compressed neural network |
CN113762368A (zh) * | 2021-08-27 | 2021-12-07 | 北京市商汤科技开发有限公司 | 数据蒸馏的方法、装置、电子设备和存储介质 |
CN113919444A (zh) * | 2021-11-10 | 2022-01-11 | 北京市商汤科技开发有限公司 | 目标检测网络的训练方法、目标检测方法及装置 |
CN114359649A (zh) * | 2021-11-22 | 2022-04-15 | 腾讯科技(深圳)有限公司 | 图像处理方法、装置、设备、存储介质及程序产品 |
CN114359649B (zh) * | 2021-11-22 | 2024-03-22 | 腾讯科技(深圳)有限公司 | 图像处理方法、装置、设备、存储介质及程序产品 |
CN114581946A (zh) * | 2022-02-25 | 2022-06-03 | 江西师范大学 | 人群计数方法、装置、存储介质及电子设备 |
CN114973156A (zh) * | 2022-08-02 | 2022-08-30 | 松立控股集团股份有限公司 | 一种基于知识蒸馏的夜间渣土车检测方法 |
CN114973156B (zh) * | 2022-08-02 | 2022-10-25 | 松立控股集团股份有限公司 | 一种基于知识蒸馏的夜间渣土车检测方法 |
Also Published As
Publication number | Publication date |
---|---|
CN111738436B (zh) | 2023-07-18 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111738436B (zh) | 一种模型蒸馏方法、装置、电子设备及存储介质 | |
CN110866140B (zh) | 图像特征提取模型训练方法、图像搜索方法及计算机设备 | |
CN108334605B (zh) | 文本分类方法、装置、计算机设备及存储介质 | |
CN111523621B (zh) | 图像识别方法、装置、计算机设备和存储介质 | |
CN111191791B (zh) | 基于机器学习模型的图片分类方法、装置及设备 | |
CN111507378A (zh) | 训练图像处理模型的方法和装置 | |
CN110866530A (zh) | 一种字符图像识别方法、装置及电子设备 | |
CN111639755B (zh) | 一种网络模型训练方法、装置、电子设备及存储介质 | |
AU2017101803A4 (en) | Deep learning based image classification of dangerous goods of gun type | |
CN111275046A (zh) | 一种字符图像识别方法、装置、电子设备及存储介质 | |
CN112232355B (zh) | 图像分割网络处理、图像分割方法、装置和计算机设备 | |
CN110837570B (zh) | 对图像数据进行无偏见分类的方法 | |
CN111476806A (zh) | 图像处理方法、装置、计算机设备和存储介质 | |
CN113569895A (zh) | 图像处理模型训练方法、处理方法、装置、设备及介质 | |
CN113761259A (zh) | 一种图像处理方法、装置以及计算机设备 | |
CN112488237A (zh) | 一种分类模型的训练方法及装置 | |
CN113283368B (zh) | 一种模型训练方法、人脸属性分析方法、装置及介质 | |
CN111652320B (zh) | 一种样本分类方法、装置、电子设备及存储介质 | |
CN112749737A (zh) | 图像分类方法及装置、电子设备、存储介质 | |
CN114511733A (zh) | 基于弱监督学习的细粒度图像识别方法、装置及可读介质 | |
CN114299304A (zh) | 一种图像处理方法及相关设备 | |
CN113869234A (zh) | 人脸表情识别方法、装置、设备及存储介质 | |
US11587323B2 (en) | Target model broker | |
Cira et al. | Evaluation of transfer learning techniques with convolutional neural networks (cnns) to detect the existence of roads in high-resolution aerial imagery | |
Kondaveeti et al. | A Transfer Learning Approach to Bird Species Recognition using MobileNetV2 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant | ||
TR01 | Transfer of patent right | ||
TR01 | Transfer of patent right |
Effective date of registration: 20230914 Address after: A515, 5th Floor, Taifeng Commercial Logistics Center, No. 33 Huawei Road, Xiangzhou District, Zhuhai City, Guangdong Province, 519075 Patentee after: Zhuhai Sule Technology Co.,Ltd. Address before: 528400, Xueyuan Road, 1, Shiqi District, Guangdong, Zhongshan Patentee before: University OF ELECTRONIC SCIENCE AND TECHNOLOGY OF CHINA, ZHONGSHAN INSTITUTE |