CN109711544A - 模型压缩的方法、装置、电子设备及计算机存储介质 - Google Patents
模型压缩的方法、装置、电子设备及计算机存储介质 Download PDFInfo
- Publication number
- CN109711544A CN109711544A CN201811476137.9A CN201811476137A CN109711544A CN 109711544 A CN109711544 A CN 109711544A CN 201811476137 A CN201811476137 A CN 201811476137A CN 109711544 A CN109711544 A CN 109711544A
- Authority
- CN
- China
- Prior art keywords
- network
- teacher
- loss function
- data
- student
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Landscapes
- Image Analysis (AREA)
Abstract
本申请提供了模型压缩的方法、装置、电子设备及计算机存储介质。所述方法包括:获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。
Description
技术领域
本申请涉及人工智能领域,尤其涉及模型压缩的方法、装置、电子设备及计算机存储介质。
背景技术
近年来,深度学习网络在计算机视觉领域的目标检测应用中取得了巨大的成功。但由于深度学习网络模型往往包含大量的模型参数,计算量大、处理速度慢,其应用也多在云端,在终端落地仍面临巨大的挑战。
为了减少网络模型的冗余,国内外研究人员提出了蒸馏学习算法,在蒸馏学习中,通过将结构复杂的老师网络的知识提炼或者蒸馏到结构简单的学生网络模型,指导学生网络模型的训练,从而实现了对老师网络的压缩。但蒸馏后的学生网络性能不够理想,与老师网络的各方面检测性能仍存在一定差距。并且,当前的蒸馏学习都是基于两阶段(Two-stage)目标检测的网络,对单阶段(One-stage)目标检测中的应用也尚未得到探索。
发明内容
本申请提供了模型压缩的方法、装置、电子设备及计算机存储介质,能够使得模型压缩后得到的学生网络检测性能超越老师网络。
第一方面,提供了一种模型压缩的方法,所述方法包括以下步骤:
获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;
利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;
根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。
可选地,在所述获取训练样本数据之前,所述方法还包括:
利用所述有标签样本数据对老师网络模型进行训练,得到所述老师网络。
可选地,在所述获取训练样本数据之前,所述方法还包括:
获取无标签样本数据,并利用所述老师网络对所述无标签样本数据进行标注,得到标注后的样本数据;
将所述有标签样本数据和所述标注后的样本数据组成所述训练样本数据。
可选地,所述自适应蒸馏损失函数是根据所述老师网络和所述学生网络模型对同一样本数据的学习结果的差异从而确定的损失函数。
可选地,所述自适应蒸馏损失函数包括自适应蒸馏损失系数,所述自适应蒸馏损失系数用于调整所述训练样本数据中预定样本数据的权重,其中,所述预定样本数据包括所述老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本。
可选地,所述自适应蒸馏损失函数的公式为,
ADL=ADW·KL
ADW=(1-e-KL+βT(q))γ
其中,ADL为所述自适应蒸馏损失函数,ADW为所述自适应蒸馏损失系数,KL表示所述学生网络模型难模仿所述老师网络的样本的权重,T(q)表示所述老师网络难学习的样本的权重,γ、β表示权值。
可选地,所述方法还包括:所述训练后的学生网络进行自学习的过程。
第二方面,提供了一种模型压缩的装置,包括获取单元、训练单元以及反向传播单元,其中,
所述获取单元用于获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;
所述训练单元用于利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;
所述反向传播单元用于根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。
可选地,所述训练单元还用于在所述获取单元获取训练样本数据之前,利用所述有标签样本数据对老师网络模型进行训练,得到所述老师网络。
可选地,所述装置还包括标注单元,
所述标注单元用于在所述获取训练样本数据之前,获取无标签样本数据,并利用所述老师网络对所述无标签样本数据进行标注,得到标注后的样本数据;
所述标注单元还用于将所述有标签样本数据和所述标注后的样本数据组成所述训练样本数据。
可选地,所述自适应蒸馏损失函数是根据所述老师网络和所述学生网络模型对同一样本数据的学习结果的差异所确定的损失函数。
可选地,所述自适应蒸馏损失函数包括自适应蒸馏损失系数,所述自适应蒸馏损失系数用于调整所述训练样本数据中预定样本数据的权重,其中所述预定样本数据包括所述老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本。
可选地,所述自适应蒸馏损失函数的公式为,
ADL=ADW·KL
ADW=(1-e-KL+βT(q))γ
其中,ADL为所述自适应蒸馏损失函数,ADW为所述自适应蒸馏损失系数,KL表示所述学生网络模型难模仿所述老师网络的样本的权重,T(q)表示所述老师网络难学习的样本的权重,γ、β表示权值。
可选地,所述方法还包括:所述训练后的学生网络进行自学习的过程。
第三方面,提供了一种电子设备,包括处理器、输入设备、输出设备和存储器,所述处理器、输入设备、输出设备和存储器相互连接,其中,所述存储器用于存储计算机程序,所述计算机程序包括程序指令,所述处理器被配置用于调用所述程序指令,执行上述第一方面所述的方法
第四方面,提供了一种计算机可读存储介质,所述计算机存储介质存储有计算机程序,所述计算机程序包括程序指令,所述程序指令当被处理器执行时使所述处理器执行上述第一方面的方法。
基于本申请提供的模型压缩的方法、装置、电子设备以及计算机可读存储介质,通过获取训练样本数据,利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数,根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,从而获得训练后的学生网络。由于自适应蒸馏损失函数中包括控制老师网络难学习的样本和学生网络模型难模仿所述老师网络的样本权重的系数,使得老师网络从训练样本数据中提取的数据结构特征能有针对性的传递到学生网络中,从而使得学生网络的目标检测性能得到大大提升。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请提供的一种模型压缩的方法的流程示意图;
图2是本申请提供的一种模型压缩的方法中第一预测函数输出的概率分布与蒸馏温度参数T之间关系的示意图;
图3是本申请提供的一种模型压缩的方法中正响应样本数量a和负响应样本数量b与学生网络训练结果之间的关系示意图;
图4是本申请提供的一种模型压缩的方法中获得学生网络模型的自适应蒸馏损失函数ADL的流程示意图;
图5是本申请提供的一种模型压缩的装置结构示意图;
图6是本申请提供的一种电子设备结构示意框图。
具体实施方式
下面通过具体实施方式结合附图对本申请作进一步详细说明。在以下的实施方式中,很多细节描述是为了使得本申请能被更好的理解。然而,本领域技术人员可以毫不费力的认识到,其中部分特征在不同情况下是可以省略的,或者可以由其他方法所替代。在某些情况下,本申请相关的一些操作并没有在说明书中显示或描述,这是为了避免本申请的核心部分被过多的描述所淹没。对于本领域技术人员而言,详细描述这些相关操作并不是必要的,他们根据说明书中的描述以及本领域的一般技术知识即可完整了解相关操作。
应当理解,当在本说明书和所附权利要求书中使用术语时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
需要说明的是,在本申请实施例中使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本申请。在本申请实施例和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。
为了使本申请能够被更好的理解,下面对现有的蒸馏学习进行简要介绍。
蒸馏学习(Knowledge Distillation,KD)指的是将训练好的复杂模型中的“知识”迁移到一个结构更为简单的网络中,从而达到模型压缩的目的。蒸馏神经网络取名为蒸馏,其实是一个非常形象的过程。由于水蒸馏的过程是将水沸腾产生的蒸汽导入冷凝管,使之冷却凝结成纯净水,因此如果把样本中的数据结构特征信息和数据本身当作一个混合物,分布信息通过概率分布被分离出来,当温度参数T较大时,相当于用很高的温度将关键的分布信息从原有的数据中分离,接着恢复低温,得到正常的分布信息,最后让二者充分融合,最终可以得到最“纯净的”网络模型。一个简单的蒸馏学习方法的步骤可以是:提升老师网络输出函数softmax表达式中的温度参数T,使得老师网络产生一个合适的“软目标”;采用同样的温度参数T来训练小模型,使得它产生与老师网络相匹配的“软目标”,其中,老师网络和学生网络模型在训练过程中使用的样本数据集均为有标签样本数据集。
相关术语解释:
温度参数(Temperature):蒸馏学习中的温度参数T可以用跑步的例子来解释,例如:某运动员每次跑步均为负重跑步,那么在取下负重正常跑步的时候,就会非常轻松,也可以比其他运动员跑步速度更快。同理,温度参数T就是这个负重包,对于一个结构复杂的老师网络来说,训练结束后往往能够得到很好的学习效果,但是对于一个结构简单的学生网络来说,无法得到很好的学习效果,因此,为了帮助学生网络进行学习,在学生网络的输出函数softmax中增加一个温度参数T,加上这个温度参数以后,错误分类再经过softmax以后错误输出会被“放大”,正确分类会被“缩小”,也就是说,人为增大了训练的难度,一旦将T重新设置为1,分类结果会非常接近于老师网络的分类效果。
软目标(Soft target):软目标指的是老师网络使用带有温度参数T的输出函数softmax产生的输出结果。
硬目标(Hard target):硬目标指的是正常网络训练的目标,也就是有标签样本的真实标签,但是在本申请实施例中,由于使用的是训练样本数据包括有标签样本数据和无标签样本数据,其中无标签样本数据没有真实标签,因此本申请的硬目标指的是老师网络使用没有温度参数T的输出函数softmax产生的输出结果。
图1是本申请提供的一种模型压缩的方法。由图1可知,本申请提供的模型压缩的方法包括以下步骤:
S101:获取训练样本数据,其中,所述训练样本数据包括有标签样本数据。
在本申请具体的实施方式中,所述有标签样本数据(Labeled Data)可以是标注有真实分类结果的样本集。在所述获取训练样本数据之前,所述方法还包括:利用所述有标签样本数据对老师网络模型进行训练,得到所述老师网络。可以理解的是,利用所述有标签样本数据对老师网络模型进行训练的具体步骤可以是:根据老师网络模型的预测结果与真实标签之间差距计算损失(LOSS),根据LOSS值调节老师网络模型的权重,直到老师网络模型的LOSS值达到某一阈值时,从而获得老师网络。例如,一个5分类问题,输入的一张图片的真实分类结果为第4类,那么这张图片的真实标签可以是y=[0,0,0,1,0],当老师网络模型的预测结果为p=[0.1,0.15,0.05,0.6,0.1]时,预测的分类结果虽然是正确的,但是与真实标签仍存在差距,此时的LOSS=-log(0.6),假设LOSS的阈值为-log(0.95),那么此时的老师网络模型仍需要进一步的训练。因此,通过老师网络模型的LOSS函数来调节网络学习方向,能够获得最终的性能良好的老师网络。应理解,上述举例仅用于说明,并不能构成具体限定。
在本申请具体的实施方式中,在所述获取训练样本数据之前,所述方法还包括:获取无标签样本数据,并利用所述老师网络对所述无标签样本数据进行标注,得到标注后的样本数据;将所述有标签样本数据和所述标注后的样本数据组成所述训练样本数据。可以理解的是,所述无标签样本数据是没有标注过真实分类结果的数据。有标签样本数据数量少、获取困难,无标签样本数据与有标签样本数据相比,获取方式更加多元、便捷、成本低,只需要使用网络爬虫即可从网络爬取大量的无标签样本数据。因而将老师网络标注后的无标签样本数据作为训练样本数据,可以在使用较少的有标签样本数据的同时,获得性能更加优越的学生网络。
可选地,由于老师网络是用于蒸馏学习的网络,因此老师网络在对无标签样本数据进行标注时,标注结果可以包括软目标和硬目标,其中,软目标是老师网络对无标签样本数据使用第一预测函数进行预测获得的预测结果,硬目标是老师网络对无标签样本数据使用第二预测函数进行预测得到的预测结果,其中,第一预测函数是包含蒸馏学习的温度参数的函数,所述第二预测函数是不包含蒸馏学习的温度参数的函数。其中,所述第一预测函数的公式为
其中,q为第一预测函数输出的预测结果,z为第二预测函数输出的预测结果,T为预设的蒸馏学习温度参数。应理解,第二预测函数指的是正常情况下,神经网络进行预测时,输出的softmax函数,其中,softmax函数的输出结果为概率分布。因此,在第一预测函数公式中增加温度参数T后,第一预测函数输出的预测结果(即为蒸馏学习中的软目标)相比第二预测函数输出的预测结果(即为蒸馏学习中的硬目标)的概率分布更缓和、均匀,数值介于0-1之间。例如,图2是本申请提供的一种模型压缩的方法中第一预测函数输出的概率分布与蒸馏温度参数T之间关系的示意图,其中,横轴代表概率分布中的每个类别依次排列的编号,比如,1代表第1类,2代表第2类等等,纵轴代表输入图片属于每个对应的分类编号的概率值,比如,输入图片属于第1类的概率为0.1,属于第2类的概率为0.2等等。由图2可知,温度参数T的数值越大,软目标的分布越平缓(Soft),或者说,软目标的概率分布数值比硬目标越小。可以理解的是,分布平缓的软目标使得同一张输入图片,学生网络模型经过第一预测函数公式输出的错误分类结果,相比于经过第二预测公式输出的错误分类结果,由于指数函数的单调递增特性,学生网络模型的LOSS计算的值会更大,从而人为增大了训练的难度。并且,同一个样本,用在大规模神经网络(老师网络)上产生的软目标来训练一个小的网络(学生网络模型)时,因为并不是直接标注的一个目标,学生网络模型学习起来会更快收敛。更巧妙的是,本申请使用无标签样本数据产生的硬目标和软目标来训练学生网络,因为老师网络将无标签样本数据结构信息学习结果保存在自己产生的硬目标和软目标中,学生网络模型可以直接从软目标和硬目标中来获得知识,从而极大地提升了学生网络的目标检测性能。
可选地,在本申请提供的模型压缩的方法中,应该使用更多的老师网络标注过的数据进行学生网络的训练,理论上来说,当无标签样本数据全部为老师网络标注过的数据时训练的效果最好。但是,由于老师网络的预测结果是伪标签而不是真实标签,伪标签与真实标签仍然存在一定的误差,因此本申请使用的无标签样本数据中可以包含没有被老师网络标注过的无标签样本数据,其中,老师网络标注过的无标签样本数据可以称为正响应样本数据,老师网络没有标注过的无标签样本数据可以称为负响应样本数据,正响应样本数据数量与负响应样本数据数量的比例可以通过进一步的实验而确定。例如,图3是本申请提供的正响应样本数量a和负响应样本数量b与学生网络模型训练结果之间的关系示意图,其中,平均准确度(Mean Average Precision,mAP)指的是使用coco数据集训练的目标检测模型的测评指标,mAP是目标检测网络模型的检测精度、速度等等多个性能指标的综合指标,也是判别目标检测模型的检测性能最重要的一个指标,mAP的值越大,意味着目标检测模型各方面的综合性能越好,由图2可知,当的值越大,mAP的值越大,学生网络模型训练效果越好,但是当时,mAP的值不再产生变化,也就是说,的大小对学生网络模型的训练效果不再有影响。因此,所述正响应数据数量a与负响应数据数量b之间的关系用公式可以表示为:
应理解,图3以及公式(2)仅仅用于举例说明,并不能构成具体限定,本申请使用的无标签样本数据还可以全部是老师网络标注过的数据。
S102:利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数。
在本申请具体的实施方式中,老师网络是在蒸馏学习中用为学生网络模型提供更加准确数据结构特征的高性能神经网络。其中,学生网络模型是计算速度较快、适合部署到对实时性要求较高的单个神经网络,学生网络模型相比于老师网络,具有更大的运算吞吐量、更简单的网络结构和更少的模型参数。老师网络性能优良、准确率高,但是相对于学生网络模型,老师网络的结构复杂、参数权重多、计算速度较慢。例如:老师网络可以是用于人脸检测的残差神经网络Resnet101,学生网络模型可以是用于人脸检测的Resnet50,其中,老师网络的网络层数为101,学生网络模型的网络层数为50。应理解,上述举例仅用于说明,并不能构成具体限定。
在本申请具体的实施方式中,所述自适应蒸馏损失函数是根据所述老师网络和所述学生网络模型对同一样本数据的学习结果的差异从而确定的损失函数。其中,同一样本数据的学习结果的差异指的是,学生网络使用第一预测函数得到的第一预测结果,与老师网络使用第一预测函数得到的软目标之间的差异,或者是,学生网络使用第二预测函数得到的第二预测结果,与老师网络使用第二预测函数得到的硬目标之间的差异。
在本申请具体的实施方式中,所述训练样本数据包括有标签样本数据和无标签样本数据,因此,利用所述训练样本数据对所述老师网络和学生网络分别进行训练,得到自适应蒸馏损失函数和焦点损失函数的具体步骤可以是:使用所述有标签样本数据对学生网络模型进行训练获得学生网络;使用所述学生网络对所述无标签样本数据进行预测,使用所述第一预测函数获得第一预测结果,使用所述第二预测函数获得第二预测结果;根据所述第一预测结果与所述软目标之间的差异获得软目标的自适应蒸馏损失函数和焦点损失函数,根据所述第二预测结果与所述硬目标之间的差异获得硬目标的自适应蒸馏损失函数和焦点损失函数。
举例来讲,利用所述训练样本数据对所述老师网络和学生网络分别进行训练,得到自适应蒸馏损失函数ADL的具体步骤流程可以如图4所示,首先,使用有标签样本数据训练老师网络模型获得老师网络,使用有标签样本数据训练初级学生网络模型获得学生网络模型;其次,使用老师网络对无标签样本数据集进行预测,使用第一预测函数获得软目标,使用第二预测函数获得硬目标,使用学生网络模型对无标签样本数据集进行预测,使用第一预测函数获得第一预测结果,使用第二预测函数获得第二预测结果;第三,计算学生网络的ADL,根据第一预测结果与软目标之间的差距获得软目标的ADL,根据第二预测结果与硬目标之间的差距获得硬目标的ADL;最后,根据软目标的ADL与硬目标的ADL获得学生网络最终的自适应蒸馏损失函数ADL。同理,焦点损失函数也包括软目标的焦点损失函数和硬目标的焦点损失函数。应理解,蒸馏学习虽然将学生网络的学习目标分成了软目标和硬目标,可以得到软目标损失函数和硬目标损失函数,但是软目标损失函数和硬目标损失函数是使用相同的损失函数公式进行计算的,因此,为了便于理解,本申请下文在对损失函数的具体公式进行阐述时,不再分别针对软目标和硬目标进行分析。并且,图4所示的具体的训练流程仅仅是用于举例,并不能构成具体限定。
在本申请具体的实施方式中,所述学生网络模型的总损失函数包括自蒸馏损失函数和焦点损失函数,其中,所述焦点损失函数包括确定目标类别的焦点损失函数以及确定目标位置的焦点损失函数,所述总损失函数的具体公式为:
L=FL+Lloc+ADL (3)
其中,FL是确定目标类别的焦点损失函数,Lloc是确定目标位置的焦点损失函数,ADL是本申请提供自适应蒸馏损失函数。可以理解的是,公式(3)中的ADL指的是软目标的自适应蒸馏损失函数与硬目标的自适应蒸馏损失函数之和,FL以及Lloc是现有技术中的损失函数,因此本申请不再作赘述。但是,应理解,当前的蒸馏学习都是基于两阶段目标检测的网络进行训练的,而本申请的蒸馏学习是基于单阶段目标检测的网络,由于两阶段目标检测网络的损失函数无法在单阶段目标检测网络中使用,因此本申请提出了与当前蒸馏学习的损失函数不同的公式(3)作为本申请提供的模型压缩的方法的损失函数。
在本申请具体的实施方式中,所述自适应蒸馏损失函数包括自适应蒸馏损失系数,所述自适应蒸馏损失系数用于调整所述训练样本数据中预定样本数据的权重,其中,所述预定样本数据包括所述老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本。所述自适应蒸馏损失函数的公式为,
ADL=ADW·KL (4)
ADW=(1-e-KL+βT(q))γ (5)
其中,ADL为所述自适应蒸馏损失函数,ADW为所述自适应蒸馏损失系数,KL表示所述学生网络模型难模仿所述老师网络的样本的权重,T(q)表示所述老师网络难学习的样本的权重,γ、β表示权值。
在本申请具体的实施方式中,KL是用来描述所述老师网络与学生网络模型预测结果差异的相对熵,也就是说,KL体现了学生网络模型的学习结果与老师网络的预测结果之间的差异,因此KL可以用来控制所述学生网络模型确定的难模仿样本的权重,其中,KL的具体公式为,
其中,q为老师网络预测的软目标或硬目标,p为学生网络模型预测的结果,当KL越大,意味着学生网络模型对这个样本的学习结果与老师网络的预测结果差异越大,也就是说,这个样本是学生网络模型越难模仿的样本。由于学生网络模型的学习是一个动态的过程,因此KL也是随着学生网络模型的学习结果不断自适应调整的值,当学生网络模型对某一样本预测结果与老师网络预测结果之间的差异越来越小时,KL也越来越小,该样本对应的损失函数也越来越小,因此本申请的自适应蒸馏损失函数可以根据学生网络模型的学习结果不断自主调整,从而更有目的性、针对性的训练学生网络模型,使得学生网络的性能得到更大提升,从而超越老师网络。
在本申请具体的实施方式中,T(q)是用来描述所述老师网络预测结果不确定性的老师网络的熵,也就是说,T(q)体现了老师网络对于该样本是否是学生网络难以模仿的样本的判断,因此T(q)可以用于控制所述老师网络确定的难模仿样本的权重,其中,T(q)的具体公式为,
T(q)=-(qlog(q)+(1-q)log(1-q) (7)
其中,老师网络的熵T(q)在q=0.5时达到最大值,当q接近0或1时达到最小值。由于老师网络的熵T(q)反应了老师网络对某一样本的不确定性,因此当q越接近0.5时,老师网络认为这个样本是学生网络模型越难模仿的样本。可以理解的是,由于学生网络使用的损失函数中包括控制老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本的权重,使得学生网络模型难学习的样本和难模仿的样本在损失函数中的影响得到最大化。也就是说,学生网络在训练过程中学习的特征都是难学习的,难模仿的,从而人为加大了学生网络的学习难度,也使得学生网络能获得更多、更“纯粹”的知识特征。因此,本申请提供的模型压缩的方法可以使得学生网络在蒸馏学习的过程中,有目的性的将学习重点放在难学习、难模仿的困难样本中,从而获得更好的学习效果,得到目标检测性能更加优良的学生网络。
S103:根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。
在本申请具体的实施方式中,根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,根据所述自适应蒸馏损失函数和所述焦点损失函数调整所述学生网络模型的模型参数,直到所述总损失函数达到预设的阈值,从而获得训练后的学生网络。下表1是使用本申请提供的模型压缩的方法训练的学生网络的目标检测性能指标与老师网络性能指标的对比列表,由表1可知,使用本申请提供的模型压缩的方法,对学生网络模型进行半监督环境下的知识蒸馏,可以使得学生网络的目标检测性能超越老师网络。
表1:学生网络目标检测性能指标与老师网络目标检测性能指标的对比列表
在本申请具体的实施方式中,所述方法还包括:所述训练后的学生网络进行自学习的过程。也就是说,所述学生网络还可以进行自蒸馏(self d distillation)。自蒸馏的具体步骤包括:将训练后的学生网络作为新的老师网络,重新执行步骤S101至步骤S103,得到新的训练后的学生网络。新的训练后的学生网络的框架结构和规模较训练后的学生网络更简单、更小型化。下表2是学生网络模型进行自蒸馏的情况下,使用本申请提供的模型压缩的方法训练的学生网络自蒸馏前后的目标检测性能指标对比列表。由表2可知,使用本申请提供的模型压缩的方法,对学生网络进行半监督环境下的知识自蒸馏,可以使得学生网络的目标检测性能得到极大的提升。
表2:学生网络自蒸馏前后的目标检测性能指标对比列表
上述方法中,通过获取训练样本数据,利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数,根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,从而获得训练后的学生网络。由于自适应蒸馏损失函数中包括控制老师网络难学习的样本和学生网络模型难模仿所述老师网络的样本权重的系数,因此,通过上述方法,使得老师网络从训练样本数据中提取的数据结构特征能有针对性的传递到学生网络中,从而使得学生网络的目标检测性能得到大大提升。
图5是本申请提供的一种模型压缩的装置的结构示意图。由图5可知,本申请提供的模型压缩的装置包括获取单元510、训练单元520、反向传播单元530以及标注单元540,其中,
所述获取单元510用于获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;
所述训练单元520用于利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;
所述反向传播单元530用于根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。
在本申请具体的实施方式中,所述有标签样本数据(Labeled Data)可以是标注有真实分类结果的样本集。所述训练单元520还用于在所述获取单元510获取训练样本数据之前,利用所述有标签样本数据对老师网络模型进行训练,得到所述老师网络。可以理解的是,利用所述有标签样本数据对老师网络模型进行训练的具体步骤可以是:根据老师网络模型的预测结果与真实标签之间差距计算损失(LOSS),根据损失值调节老师网络模型的权重,直到老师网络模型的LOSS值达到某一阈值时,从而获得老师网络。例如,一个5分类问题,输入的一张图片的真实分类结果为第4类,那么这张图片的真实标签可以是y=[0,0,0,1,0],当老师网络模型的预测结果为p=[0.1,0.15,0.05,0.6,0.1]时,预测的分类结果虽然是正确的,但是与真实标签仍存在差距,此时的LOSS=-log(0.6),假设LOSS的阈值为-log(0.95),那么此时的老师网络模型仍需要进一步的训练。因此,通过老师网络模型的LOSS函数来调节网络学习方向,能够获得最终的性能良好的老师网络。应理解,上述举例仅用于说明,并不能构成具体限定。
在本申请具体的实施方式中,所述装置还包括标注单元540,所述标注单元540用于在所述获取单元510获取训练样本数据之前,获取无标签样本数据,并利用所述老师网络对所述无标签样本数据进行标注,得到标注后的样本数据;所述标注单元540还用于将所述有标签样本数据和所述标注后的样本数据组成所述训练样本数据。可以理解的是,所述无标签样本数据是没有标注过真实分类结果的数据。有标签样本数据数量少、获取困难,无标签样本数据与有标签样本数据相比,获取方式更加多元、便捷、成本低,只需要使用网络爬虫即可从网络爬取大量的无标签样本数据。因而将老师网络标注后的无标签样本数据作为训练样本数据,可以在使用较少的有标签样本数据的同时,获得性能更加优越的学生网络模型。
可选地,由于老师网络是用于蒸馏学习的网络,因此标注单元540使用老师网络在对无标签样本数据进行标注时,标注结果可以包括软目标和硬目标,其中,软目标是老师网络对无标签样本数据使用第一预测函数进行预测获得的预测结果,硬目标是老师网络对无标签样本数据使用第二预测函数进行预测得到的预测结果,其中,第一预测函数是包含蒸馏学习的温度参数的函数,所述第二预测函数是不包含蒸馏学习的温度参数的函数。其中,所述第一预测函数的公式为公式(1),其中,q为第一预测函数输出的预测结果,z为第二预测函数输出的预测结果,T为预设的蒸馏学习温度参数。应理解,第二预测函数指的是正常情况下,神经网络进行预测时,输出的softmax函数,其中,softmax函数的输出结果为概率分布。因此,在第一预测函数公式中增加温度参数T后,第一预测函数输出的预测结果(即为蒸馏学习中的软目标)相比第二预测函数输出的预测结果(即为蒸馏学习中的硬目标)的概率分布更缓和、均匀,数值介于0-1之间。例如,图2是本申请提供的一种模型压缩的方法中第一预测函数输出的概率分布与蒸馏温度参数T之间关系的示意图,其中,横轴代表概率分布中的每个类别依次排列的编号,比如,1代表第1类,2代表第2类等等,纵轴代表输入图片属于每个对应的分类编号的概率值,比如,输入图片属于第1类的概率为0.1,属于第2类的概率为0.2等等。由图2可知,温度参数T的数值越大,软目标的分布越平缓(Soft),或者说,软目标的概率分布数值比硬目标越小。可以理解的是,分布平缓的软目标使得同一张输入图片,学生网络模型经过第一预测函数公式输出的错误分类结果,相比于经过第二预测公式输出的错误分类结果,由于指数函数的单调递增特性,学生网络模型的LOSS计算的值会更大,从而人为增大了训练的难度。并且,同一个样本,用在大规模神经网络(老师网络)上产生的软目标来训练一个小的网络(学生网络模型)时,因为并不是直接标注的一个目标,学生网络模型学习起来会更快收敛。更巧妙的是,本申请使用无标签样本数据产生的硬目标和软目标来训练学生网络,因为老师网络将无标签样本数据结构信息学习结果保存在自己产生的硬目标和软目标中,学生网络可以直接从软目标和硬目标中来获得知识,从而极大地提升了学生网络的训练速度。
可选地,在本申请提供的模型压缩的装置中,应该使用更多的老师网络标注过的数据进行学生网络的训练,理论上来说,当无标签样本数据全部为老师网络标注过的数据时训练的效果最好。但是,由于老师网络的预测结果是伪标签而不是真实标签,伪标签与真实标签仍然存在一定的误差,因此本申请使用的无标签样本数据中可以包含没有被老师网络标注过的样本数据,其中,老师网络标注过的数据可以称为正响应样本数据,老师网络没有标注过的数据可以称为负响应样本数据,正响应样本数据数量与负响应样本数据数量的比例可以通过进一步的实验而确定。例如,图3是本申请提供的正响应样本数量a和负响应样本数量b与学生网络模型训练结果之间的关系示意图,其中,平均准确度(Mean AveragePrecision,mAP)指的是使用coco数据集训练的目标检测模型的测评指标,mAP是目标检测网络模型的检测精度、速度等等多个性能指标的综合指标,也是判别目标检测模型的检测性能最重要的一个指标,mAP的值越大,意味着目标检测模型各方面的综合性能越好,由图2可知,当的值越大,mAP的值越大,学生网络模型训练效果越好,但是当时,mAP的值不再产生变化,也就是说,的大小对学生网络模型的训练效果不再有影响。因此,所述正响应数据数量a与负响应数据数量b之间的关系用公式可以表示为:
应理解,图3以及公式(2)仅仅用于举例说明,并不能构成具体限定,本申请使用的无标签样本数据还可以全部是老师网络标注过的数据。
在本申请具体的实施方式中,老师网络是在蒸馏学习中用为学生网络模型提供更加准确数据结构特征的高性能神经网络。其中,学生网络模型是计算速度较快但性能较差的、适合部署到对实时性要求较高的单个神经网络,学生网络模型相比于老师网络,具有更大的运算吞吐量、更简单的网络结构和更少的模型参数。老师网络性能优良、准确率高,但是相对于学生网络模型,老师网络的结构复杂、参数权重多、计算速度较慢。例如:老师网络可以是用于人脸检测的残差神经网络Resnet101,学生网络模型可以是用于人脸检测的Resnet50,其中,老师网络的网络层数为101,学生网络模型的网络层数为50。应理解,上述举例仅用于说明,并不能构成具体限定。
在本申请具体的实施方式中,所述自适应蒸馏损失函数是根据所述老师网络和所述学生网络模型对同一样本数据的学习结果的差异从而确定的损失函数。其中,同一样本数据的学习结果的差异指的是,学生网络使用第一预测函数得到的第一预测结果,与老师网络使用第一预测函数得到的软目标之间的差异,或者是,学生网络使用第二预测函数得到的第二预测结果,与老师网络使用第二预测函数得到的硬目标之间的差异。
在本申请具体的实施方式中,所述训练样本数据包括有标签样本数据和无标签样本数据,因此,利用所述训练样本数据对所述老师网络和学生网络分别进行训练,得到自适应蒸馏损失函数和焦点损失函数的具体步骤可以是:使用所述有标签样本数据对学生网络模型进行训练获得学生网络;使用所述学生网络对所述无标签样本数据进行预测,使用所述第一预测函数获得第一预测结果,使用所述第二预测函数获得第二预测结果;根据所述第一预测结果与所述软目标之间的差异获得软目标的自适应蒸馏损失函数和焦点损失函数,根据所述第二预测结果与所述硬目标之间的差异获得硬目标的自适应蒸馏损失函数和焦点损失函数。
举例来讲,利用所述训练样本数据对所述老师网络和学生网络分别进行训练,得到自适应蒸馏损失函数ADL的具体步骤流程可以如图4所示,首先,使用有标签样本数据训练老师网络模型获得老师网络,使用有标签样本数据训练初级学生网络模型获得学生网络模型;其次,使用老师网络对无标签样本数据集进行预测,使用第一预测函数获得软目标,使用第二预测函数获得硬目标,使用学生网络模型对无标签样本数据集进行预测,使用第一预测函数获得第一预测结果,使用第二预测函数获得第二预测结果;第三,计算学生网络的ADL,根据第一预测结果与软目标之间的差距获得软目标的ADL,根据第二预测结果与硬目标之间的差距获得硬目标的ADL;最后,根据软目标的ADL与硬目标的ADL获得学生网络最终的自适应蒸馏损失函数ADL。同理,焦点损失函数也包括软目标的焦点损失函数和硬目标的焦点损失函数。应理解,蒸馏学习虽然将学生网络的学习目标分成了软目标和硬目标,可以得到软目标损失函数和硬目标损失函数,但是软目标损失函数和硬目标损失函数是使用相同的损失函数公式进行计算的,因此,为了便于理解,本申请下文在对损失函数的具体公式进行阐述时,不再分别针对软目标和硬目标进行分析。并且,图4所示的具体的训练流程仅仅是用于举例,并不能构成具体限定。
在本申请具体的实施方式中,所述学生网络模型的总损失函数包括自蒸馏损失函数和焦点损失函数,其中,所述焦点损失函数包括确定目标类别的焦点损失函数以及确定目标位置的焦点损失函数,所述总损失函数的具体公式如公式(3)所示,其中,FL是确定目标类别的焦点损失函数,Lloc是确定目标位置的焦点损失函数,ADL是本申请提供自适应蒸馏损失函数。可以理解的是,公式(3)中的ADL指的是软目标的自适应蒸馏损失函数与硬目标的自适应蒸馏损失函数之和,FL以及Lloc是现有技术中的损失函数,因此本申请不再作赘述,但是,应理解,当前的蒸馏学习都是基于两阶段目标检测的网络进行训练的,而本申请的蒸馏学习是基于单阶段目标检测的网络,由于两阶段目标检测网络的损失函数LOSS无法在单阶段目标检测网络中使用,因此本申请提出了与当前蒸馏学习损失函数不同的公式(3)作为本申请提供的模型压缩的装置的损失函数。
在本申请具体的实施方式中,所述自适应蒸馏损失函数包括自适应蒸馏损失系数,所述自适应蒸馏损失系数用于调整所述训练样本数据中预定样本数据的权重,其中,所述预定样本数据包括所述老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本。所述自适应蒸馏损失函数的公式为公式(4)以及公式(5),其中,ADL为所述自适应蒸馏损失函数,ADW为所述自适应蒸馏损失系数,KL表示所述学生网络模型难模仿所述老师网络的样本的权重,T(q)表示所述老师网络难学习的样本的权重,γ、β表示权值。
在本申请具体的实施方式中,KL是用来描述所述老师网络与学生网络预测结果差异的相对熵,也就是说,KL体现了学生网络的学习结果与老师网络的预测结果之间的差异,因此KL可以用来控制所述学生网络确定的难模仿样本的权重,其中,KL的具体公式为公式(6),其中,q为老师网络预测的软目标或硬目标,p为学生网络预测的结果,当KL越大,意味着学生网络对这个样本的学习结果与老师网络的预测结果差异越大,也就是说,这个样本是学生网络模型越难模仿的样本。由于学生网络模型的学习是一个动态的过程,因此KL也是随着学生网络模型的学习结果不断自适应调整的值,当学生网络模型对某一样本预测结果与老师网络预测结果之间的差异越来越小时,KL也越来越小,该样本对应的损失函数也越来越小,因此本申请的自适应蒸馏损失函数可以根据学生网络模型的学习结果不断自主调整,从而更有目的性、针对性的训练学生网络模型,使得学生网络的性能得到更大提升,从而超越老师网络。
在本申请具体的实施方式中,T(q)是用来描述所述老师网络预测结果不确定性的老师网络的熵,也就是说,T(q)体现了老师网络对于该样本是否是学生网络难以模仿的样本的判断,因此T(q)可以用于控制所述老师网络确定的难模仿样本的权重,T(q)的具体公式为公式(7),其中,老师网络的熵T(q)在q=0.5时达到最大值,当q接近0或1时达到最小值。由于老师网络的熵T(q)反应了老师网络对某一样本的不确定性,因此当q越接近0.5时,老师网络认为这个样本是学生网络模型越难模仿的样本。可以理解的是,由于学生网络使用的损失函数中包括控制老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本的权重,使得学生网络模型难学习的样本和难模仿的样本在损失函数中的影响得到最大化。也就是说,学生网络在训练过程中学习的特征都是难学习的,难模仿的,从而人为加大了学生网络的学习难度,也使得学生网络能获得更多、更“纯粹”的知识特征。因此,本申请提供的模型压缩的装置可以使得学生网络在蒸馏学习的过程中,有目的性的将学习重点放在难学习、难模仿的困难样本中,从而获得更好的学习效果,得到目标检测性能更加优良的学生网络。
在本申请具体的实施方式中,根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播可以是,根据所述自适应蒸馏损失函数和所述焦点损失函数调整所述学生网络模型的模型参数,直到所述总损失函数达到预设的阈值,从而获得训练后的学生网络模型。下表1是使用本申请提供的模型压缩的装置训练的学生网络的目标检测性能指标与老师网络性能指标的对比列表,由表1可知,使用本申请提供的模型压缩的装置,对学生网络模型进行半监督环境下的知识蒸馏,可以使得学生网络的目标检测性能超越老师网络。
在本申请具体的实施方式中,所述装置还包括:所述训练后的学生网络进行自学习的过程。也就是说,所述学生网络还可以进行自蒸馏(self d distillation)。自蒸馏的具体步骤包括:将训练后的学生网络作为新的老师网络,重新执行步骤S101至步骤S103,得到新的训练后的学生网络。新的训练后的学生网络的框架结构和规模较训练后的学生网络更简单、更小型化。表2是学生网络模型进行自蒸馏的情况下,使用本申请提供的模型压缩的装置训练的学生网络模型自蒸馏前后的目标检测性能指标对比列表。由表2可知,使用本申请提供的模型压缩的装置,对学生网络进行半监督环境下的知识自蒸馏,可以使得学生网络的目标检测性能得到极大的提升。
上述装置中,通过获取训练样本数据,利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数,根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,从而获得训练后的学生网络。由于自适应蒸馏损失函数中包括控制老师网络难学习的样本和学生网络模型难模仿所述老师网络的样本权重的系数,使得老师网络从训练样本数据中提取的数据结构特征能有针对性的传递到学生网络中,从而使得学生网络的目标检测性能得到大大提升。
图6是本申请提供的一种电子设备结构示意框图。如图6所示,本实施例中的电子设备可以包括:一个或多个处理器601;一个或多个输入设备602,一个或多个输出设备603和存储器604。上述处理器601、输入设备602、输出设备603和存储器604通过总线605连接。存储器602用于存储计算机程序,所述计算机程序包括程序指令,处理器601用于执行存储器602存储的程序指令。
在本申请实施例中,所称处理器601可以是中央处理单元(Central ProcessingUnit,CPU),该处理器还可以是其他通用处理器、数字信号处理器(Digital SignalProcessor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
输入设备602可以包括触控板、指纹采传感器(用于采集用户的指纹信息和指纹的方向信息)、麦克风等,输出设备603可以包括显示器(LCD等)、扬声器等。
存储器604可以包括易失性存储器,例如随机存取存储器(Random AccessMmemory,RAM);存储器也可以包括非易失性存储器,例如只读存储器(Read-Only Memory,ROM)、快闪存储器(Flash Memory)、硬盘(Hard Disk Drive,HDD)或固态硬盘(Solid-StateDrive,SSD),存储器还可以包括上述种类的存储器的组合。存储器604可以采用集中式存储,也可以采用分布式存储,此处不作具体限定。可以理解的是,存储器604用于存储计算机程序,例如:计算机程序指令等。在本申请实施例中,存储器604可以向处理器601提供指令和数据。
具体实现中,本申请实施例中所描述的处理器601、输入设备602、输出设备603、存储器604、总线605可执行本申请提供的模型压缩的方法的任一实施例中所描述的实现方式,在此不再赘述。
本申请提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序包括程序指令,所述程序指令被处理器执行时实现本申请提供的模型压缩的方法的任一实施例中所描述的实现方式,在此不再赘述。
所述计算机可读存储介质可以是前述任一实施例所述的终端的内部存储单元,例如终端的硬盘或内存。所述计算机可读存储介质也可以是所述终端的外部存储设备,例如所述终端上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(SecureDigital,SD)卡,闪存卡(Flash Card)等。进一步地,所述计算机可读存储介质还可以既包括所述终端的内部存储单元也包括外部存储设备。所述计算机可读存储介质用于存储所述计算机程序以及所述终端所需的其他程序和数据。所述计算机可读存储介质还可以用于暂时地存储已经输出或者将要输出的数据。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,上述描述的设备和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的电子设备、装置和方法,可以通过其它的方式实现。例如,以上所描述的电子设备实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口、装置或单元的间接耦合或通信连接,也可以是电的,机械的或其它的形式连接。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本申请实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以是两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分,或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器、随机存取存储器、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以权利要求的保护范围为准。
Claims (10)
1.一种模型压缩的方法,其特征在于,包括:
获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;
利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;
根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。
2.根据权利要求1所述的方法,其特征在于,在所述获取训练样本数据之前,所述方法还包括:
利用所述有标签样本数据对老师网络模型进行训练,得到所述老师网络。
3.根据权利要求1或2所述的方法,其特征在于,在所述获取训练样本数据之前,所述方法还包括:
获取无标签样本数据,并利用所述老师网络对所述无标签样本数据进行标注,得到标注后的样本数据;
将所述有标签样本数据和所述标注后的样本数据组成所述训练样本数据。
4.根据权利要求1所述的方法,其特征在于,所述自适应蒸馏损失函数是根据所述老师网络和所述学生网络模型对同一样本数据的学习结果的差异从而确定的损失函数。
5.根据权利要求1或4所述的方法,其特征在于,所述自适应蒸馏损失函数包括自适应蒸馏损失系数,所述自适应蒸馏损失系数用于调整所述训练样本数据中预定样本数据的权重,其中,所述预定样本数据包括所述老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本。
6.根据权利要求5所述的方法,其特征在于,所述自适应蒸馏损失函数的公式为,
ADL=ADW·KL
ADW=(1-e-KL+βT(q))γ
其中,ADL为所述自适应蒸馏损失函数,ADW为所述自适应蒸馏损失系数,KL表示所述学生网络模型难模仿所述老师网络的样本的权重,T(q)表示所述老师网络难学习的样本的权重,γ、β表示权值。
7.根据权利要求1所述的方法,其特征在于,所述方法还包括:所述训练后的学生网络进行自学习的过程。
8.一种模型压缩的装置,其特征在于,包括获取单元、训练单元以及反向传播单元,其中,
所述获取单元用于获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;
所述训练单元用于利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;
所述反向传播单元用于根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。
9.一种电子设备,其特征在于,包括处理器、输入设备、输出设备和存储器,所述处理器、输入设备、输出设备和存储器相互连接,其中,所述存储器用于存储计算机程序,所述计算机程序包括程序指令,所述处理器被配置用于调用所述程序指令,执行如权利要求1至7任一所述模型压缩的方法的操作。
10.一种计算机存储介质,用于存储计算机可读取的指令,其特征在于,所述指令被执行时执行如权利要求1至7任一所述模型压缩的方法的操作。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811476137.9A CN109711544A (zh) | 2018-12-04 | 2018-12-04 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811476137.9A CN109711544A (zh) | 2018-12-04 | 2018-12-04 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN109711544A true CN109711544A (zh) | 2019-05-03 |
Family
ID=66254611
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201811476137.9A Pending CN109711544A (zh) | 2018-12-04 | 2018-12-04 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN109711544A (zh) |
Cited By (37)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110232411A (zh) * | 2019-05-30 | 2019-09-13 | 北京百度网讯科技有限公司 | 模型蒸馏实现方法、装置、系统、计算机设备及存储介质 |
CN110246487A (zh) * | 2019-06-13 | 2019-09-17 | 苏州思必驰信息科技有限公司 | 用于单通道的语音识别模型的优化方法及系统 |
CN110276413A (zh) * | 2019-06-28 | 2019-09-24 | 深圳前海微众银行股份有限公司 | 一种模型压缩方法及装置 |
CN110348572A (zh) * | 2019-07-09 | 2019-10-18 | 上海商汤智能科技有限公司 | 神经网络模型的处理方法及装置、电子设备、存储介质 |
CN110472494A (zh) * | 2019-06-21 | 2019-11-19 | 深圳壹账通智能科技有限公司 | 脸部特征提取模型训练方法、脸部特征提取方法、装置、设备及存储介质 |
CN110648048A (zh) * | 2019-08-21 | 2020-01-03 | 阿里巴巴集团控股有限公司 | 小程序签约事件处理方法、装置、服务器及可读存储介质 |
CN110825970A (zh) * | 2019-11-07 | 2020-02-21 | 浙江同花顺智能科技有限公司 | 一种信息推荐方法、装置、设备及计算机可读存储介质 |
CN110837846A (zh) * | 2019-10-12 | 2020-02-25 | 深圳力维智联技术有限公司 | 一种图像识别模型的构建方法、图像识别方法及装置 |
CN111091177A (zh) * | 2019-11-12 | 2020-05-01 | 腾讯科技(深圳)有限公司 | 一种模型压缩方法、装置、电子设备和存储介质 |
CN111145026A (zh) * | 2019-12-30 | 2020-05-12 | 第四范式(北京)技术有限公司 | 一种反洗钱模型的训练方法及装置 |
CN111312271A (zh) * | 2020-02-28 | 2020-06-19 | 云知声智能科技股份有限公司 | 一种提高收敛速度和处理性能的模型压缩方法和系统 |
CN111461212A (zh) * | 2020-03-31 | 2020-07-28 | 中国科学院计算技术研究所 | 一种用于点云目标检测模型的压缩方法 |
CN111553479A (zh) * | 2020-05-13 | 2020-08-18 | 鼎富智能科技有限公司 | 一种模型蒸馏方法、文本检索方法及装置 |
CN111724867A (zh) * | 2020-06-24 | 2020-09-29 | 中国科学技术大学 | 分子属性测定方法、装置、电子设备及存储介质 |
CN111753878A (zh) * | 2020-05-20 | 2020-10-09 | 济南浪潮高新科技投资发展有限公司 | 一种网络模型部署方法、设备及介质 |
CN111783962A (zh) * | 2020-07-24 | 2020-10-16 | Oppo广东移动通信有限公司 | 数据处理方法、数据处理装置、存储介质与电子设备 |
CN111783606A (zh) * | 2020-06-24 | 2020-10-16 | 北京百度网讯科技有限公司 | 一种人脸识别网络的训练方法、装置、设备及存储介质 |
CN111898707A (zh) * | 2020-08-24 | 2020-11-06 | 鼎富智能科技有限公司 | 模型训练方法、文本分类方法、电子设备及存储介质 |
CN111967573A (zh) * | 2020-07-15 | 2020-11-20 | 中国科学院深圳先进技术研究院 | 数据处理方法、装置、设备及计算机可读存储介质 |
CN112163450A (zh) * | 2020-08-24 | 2021-01-01 | 中国海洋大学 | 基于s3d学习算法的高频地波雷达船只目标检测方法 |
CN112651975A (zh) * | 2020-12-29 | 2021-04-13 | 奥比中光科技集团股份有限公司 | 一种轻量化网络模型的训练方法、装置及设备 |
CN112712052A (zh) * | 2021-01-13 | 2021-04-27 | 安徽水天信息科技有限公司 | 一种机场全景视频中微弱目标的检测识别方法 |
CN112784677A (zh) * | 2020-12-04 | 2021-05-11 | 上海芯翌智能科技有限公司 | 模型训练方法及装置、存储介质、计算设备 |
RU2749970C1 (ru) * | 2019-10-24 | 2021-06-21 | Бейдзин Сяоми Интиллиджент Текнолоджи Ко., ЛТД. | Способ сжатия модели нейронной сети, а также способ и устройство для перевода языкового корпуса |
CN113191479A (zh) * | 2020-01-14 | 2021-07-30 | 华为技术有限公司 | 联合学习的方法、系统、节点及存储介质 |
CN113219357A (zh) * | 2021-04-28 | 2021-08-06 | 东软睿驰汽车技术(沈阳)有限公司 | 电池包健康状态计算方法、系统及电子设备 |
CN113469977A (zh) * | 2021-07-06 | 2021-10-01 | 浙江霖研精密科技有限公司 | 一种基于蒸馏学习机制的瑕疵检测装置、方法、存储介质 |
WO2021197223A1 (zh) * | 2020-11-13 | 2021-10-07 | 平安科技(深圳)有限公司 | 一种模型压缩方法、系统、终端及存储介质 |
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN113554059A (zh) * | 2021-06-23 | 2021-10-26 | 北京达佳互联信息技术有限公司 | 图片处理方法、装置、电子设备及存储介质 |
WO2022001232A1 (zh) * | 2020-10-30 | 2022-01-06 | 平安科技(深圳)有限公司 | 一种问答数据增强方法、装置、计算机设备及存储介质 |
WO2022083157A1 (zh) * | 2020-10-22 | 2022-04-28 | 北京迈格威科技有限公司 | 目标检测方法、装置及电子设备 |
CN114492793A (zh) * | 2022-01-27 | 2022-05-13 | 北京百度网讯科技有限公司 | 一种模型训练和样本生成方法、装置、设备及存储介质 |
WO2022104550A1 (zh) * | 2020-11-17 | 2022-05-27 | 华为技术有限公司 | 模型蒸馏训练的方法及相关装置和设备、可读存储介质 |
US20220199258A1 (en) * | 2019-09-26 | 2022-06-23 | Lunit Inc. | Training method for specializing artificial interlligence model in institution for deployment, and apparatus for training artificial intelligence model |
WO2022242076A1 (en) * | 2021-05-17 | 2022-11-24 | Huawei Technologies Co., Ltd. | Methods and systems for compressing trained neural network and for improving efficiently performing computations of compressed neural network |
CN116863278A (zh) * | 2023-08-25 | 2023-10-10 | 摩尔线程智能科技(北京)有限责任公司 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108030488A (zh) * | 2017-11-30 | 2018-05-15 | 北京医拍智能科技有限公司 | 基于卷积神经网络的心律失常的检测系统 |
US20180136633A1 (en) * | 2016-05-20 | 2018-05-17 | Moog Inc. | Outer space digital logistics system |
CN108875693A (zh) * | 2018-07-03 | 2018-11-23 | 北京旷视科技有限公司 | 一种图像处理方法、装置、电子设备及其存储介质 |
-
2018
- 2018-12-04 CN CN201811476137.9A patent/CN109711544A/zh active Pending
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180136633A1 (en) * | 2016-05-20 | 2018-05-17 | Moog Inc. | Outer space digital logistics system |
CN108030488A (zh) * | 2017-11-30 | 2018-05-15 | 北京医拍智能科技有限公司 | 基于卷积神经网络的心律失常的检测系统 |
CN108875693A (zh) * | 2018-07-03 | 2018-11-23 | 北京旷视科技有限公司 | 一种图像处理方法、装置、电子设备及其存储介质 |
Cited By (58)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110232411B (zh) * | 2019-05-30 | 2022-08-23 | 北京百度网讯科技有限公司 | 模型蒸馏实现方法、装置、系统、计算机设备及存储介质 |
CN110232411A (zh) * | 2019-05-30 | 2019-09-13 | 北京百度网讯科技有限公司 | 模型蒸馏实现方法、装置、系统、计算机设备及存储介质 |
CN110246487B (zh) * | 2019-06-13 | 2021-06-22 | 思必驰科技股份有限公司 | 用于单通道的语音识别模型的优化方法及系统 |
CN110246487A (zh) * | 2019-06-13 | 2019-09-17 | 苏州思必驰信息科技有限公司 | 用于单通道的语音识别模型的优化方法及系统 |
KR102385463B1 (ko) | 2019-06-21 | 2022-04-12 | 원 커넥트 스마트 테크놀로지 컴퍼니 리미티드 (썬전) | 얼굴 특징 추출 모델 학습 방법, 얼굴 특징 추출 방법, 장치, 디바이스 및 저장 매체 |
WO2020253127A1 (zh) * | 2019-06-21 | 2020-12-24 | 深圳壹账通智能科技有限公司 | 脸部特征提取模型训练方法、脸部特征提取方法、装置、设备及存储介质 |
JP2021532434A (ja) * | 2019-06-21 | 2021-11-25 | ワン・コネクト・スマート・テクノロジー・カンパニー・リミテッド・(シェンチェン) | 顔特徴抽出モデル訓練方法、顔特徴抽出方法、装置、機器および記憶媒体 |
JP6994588B2 (ja) | 2019-06-21 | 2022-01-14 | ワン・コネクト・スマート・テクノロジー・カンパニー・リミテッド・(シェンチェン) | 顔特徴抽出モデル訓練方法、顔特徴抽出方法、装置、機器および記憶媒体 |
KR20200145827A (ko) * | 2019-06-21 | 2020-12-30 | 원 커넥트 스마트 테크놀로지 컴퍼니 리미티드 (썬전) | 얼굴 특징 추출 모델 학습 방법, 얼굴 특징 추출 방법, 장치, 디바이스 및 저장 매체 |
CN110472494A (zh) * | 2019-06-21 | 2019-11-19 | 深圳壹账通智能科技有限公司 | 脸部特征提取模型训练方法、脸部特征提取方法、装置、设备及存储介质 |
CN110276413A (zh) * | 2019-06-28 | 2019-09-24 | 深圳前海微众银行股份有限公司 | 一种模型压缩方法及装置 |
CN110276413B (zh) * | 2019-06-28 | 2023-10-31 | 深圳前海微众银行股份有限公司 | 一种模型压缩方法及装置 |
CN110348572A (zh) * | 2019-07-09 | 2019-10-18 | 上海商汤智能科技有限公司 | 神经网络模型的处理方法及装置、电子设备、存储介质 |
CN110648048A (zh) * | 2019-08-21 | 2020-01-03 | 阿里巴巴集团控股有限公司 | 小程序签约事件处理方法、装置、服务器及可读存储介质 |
US20220199258A1 (en) * | 2019-09-26 | 2022-06-23 | Lunit Inc. | Training method for specializing artificial interlligence model in institution for deployment, and apparatus for training artificial intelligence model |
CN110837846A (zh) * | 2019-10-12 | 2020-02-25 | 深圳力维智联技术有限公司 | 一种图像识别模型的构建方法、图像识别方法及装置 |
CN110837846B (zh) * | 2019-10-12 | 2023-10-31 | 深圳力维智联技术有限公司 | 一种图像识别模型的构建方法、图像识别方法及装置 |
RU2749970C1 (ru) * | 2019-10-24 | 2021-06-21 | Бейдзин Сяоми Интиллиджент Текнолоджи Ко., ЛТД. | Способ сжатия модели нейронной сети, а также способ и устройство для перевода языкового корпуса |
US11556723B2 (en) | 2019-10-24 | 2023-01-17 | Beijing Xiaomi Intelligent Technology Co., Ltd. | Neural network model compression method, corpus translation method and device |
CN110825970A (zh) * | 2019-11-07 | 2020-02-21 | 浙江同花顺智能科技有限公司 | 一种信息推荐方法、装置、设备及计算机可读存储介质 |
CN111091177A (zh) * | 2019-11-12 | 2020-05-01 | 腾讯科技(深圳)有限公司 | 一种模型压缩方法、装置、电子设备和存储介质 |
CN111091177B (zh) * | 2019-11-12 | 2022-03-08 | 腾讯科技(深圳)有限公司 | 一种模型压缩方法、装置、电子设备和存储介质 |
CN111145026B (zh) * | 2019-12-30 | 2023-05-09 | 第四范式(北京)技术有限公司 | 一种反洗钱模型的训练方法及装置 |
CN111145026A (zh) * | 2019-12-30 | 2020-05-12 | 第四范式(北京)技术有限公司 | 一种反洗钱模型的训练方法及装置 |
CN113191479B (zh) * | 2020-01-14 | 2024-09-24 | 华为技术有限公司 | 联合学习的方法、系统、节点及存储介质 |
CN113191479A (zh) * | 2020-01-14 | 2021-07-30 | 华为技术有限公司 | 联合学习的方法、系统、节点及存储介质 |
CN111312271A (zh) * | 2020-02-28 | 2020-06-19 | 云知声智能科技股份有限公司 | 一种提高收敛速度和处理性能的模型压缩方法和系统 |
CN111461212B (zh) * | 2020-03-31 | 2023-04-07 | 中国科学院计算技术研究所 | 一种用于点云目标检测模型的压缩方法 |
CN111461212A (zh) * | 2020-03-31 | 2020-07-28 | 中国科学院计算技术研究所 | 一种用于点云目标检测模型的压缩方法 |
CN111553479B (zh) * | 2020-05-13 | 2023-11-03 | 鼎富智能科技有限公司 | 一种模型蒸馏方法、文本检索方法及装置 |
CN111553479A (zh) * | 2020-05-13 | 2020-08-18 | 鼎富智能科技有限公司 | 一种模型蒸馏方法、文本检索方法及装置 |
CN111753878A (zh) * | 2020-05-20 | 2020-10-09 | 济南浪潮高新科技投资发展有限公司 | 一种网络模型部署方法、设备及介质 |
CN111724867A (zh) * | 2020-06-24 | 2020-09-29 | 中国科学技术大学 | 分子属性测定方法、装置、电子设备及存储介质 |
CN111724867B (zh) * | 2020-06-24 | 2022-09-09 | 中国科学技术大学 | 分子属性测定方法、装置、电子设备及存储介质 |
CN111783606B (zh) * | 2020-06-24 | 2024-02-20 | 北京百度网讯科技有限公司 | 一种人脸识别网络的训练方法、装置、设备及存储介质 |
CN111783606A (zh) * | 2020-06-24 | 2020-10-16 | 北京百度网讯科技有限公司 | 一种人脸识别网络的训练方法、装置、设备及存储介质 |
CN111967573A (zh) * | 2020-07-15 | 2020-11-20 | 中国科学院深圳先进技术研究院 | 数据处理方法、装置、设备及计算机可读存储介质 |
CN111783962A (zh) * | 2020-07-24 | 2020-10-16 | Oppo广东移动通信有限公司 | 数据处理方法、数据处理装置、存储介质与电子设备 |
CN112163450A (zh) * | 2020-08-24 | 2021-01-01 | 中国海洋大学 | 基于s3d学习算法的高频地波雷达船只目标检测方法 |
CN111898707A (zh) * | 2020-08-24 | 2020-11-06 | 鼎富智能科技有限公司 | 模型训练方法、文本分类方法、电子设备及存储介质 |
WO2022083157A1 (zh) * | 2020-10-22 | 2022-04-28 | 北京迈格威科技有限公司 | 目标检测方法、装置及电子设备 |
WO2022001232A1 (zh) * | 2020-10-30 | 2022-01-06 | 平安科技(深圳)有限公司 | 一种问答数据增强方法、装置、计算机设备及存储介质 |
WO2021197223A1 (zh) * | 2020-11-13 | 2021-10-07 | 平安科技(深圳)有限公司 | 一种模型压缩方法、系统、终端及存储介质 |
WO2022104550A1 (zh) * | 2020-11-17 | 2022-05-27 | 华为技术有限公司 | 模型蒸馏训练的方法及相关装置和设备、可读存储介质 |
CN116438546A (zh) * | 2020-11-17 | 2023-07-14 | 华为技术有限公司 | 模型蒸馏训练的方法及相关装置和设备、可读存储介质 |
CN112784677A (zh) * | 2020-12-04 | 2021-05-11 | 上海芯翌智能科技有限公司 | 模型训练方法及装置、存储介质、计算设备 |
CN112651975A (zh) * | 2020-12-29 | 2021-04-13 | 奥比中光科技集团股份有限公司 | 一种轻量化网络模型的训练方法、装置及设备 |
CN112712052A (zh) * | 2021-01-13 | 2021-04-27 | 安徽水天信息科技有限公司 | 一种机场全景视频中微弱目标的检测识别方法 |
CN113219357A (zh) * | 2021-04-28 | 2021-08-06 | 东软睿驰汽车技术(沈阳)有限公司 | 电池包健康状态计算方法、系统及电子设备 |
CN113219357B (zh) * | 2021-04-28 | 2024-07-16 | 东软睿驰汽车技术(沈阳)有限公司 | 电池包健康状态计算方法、系统及电子设备 |
WO2022242076A1 (en) * | 2021-05-17 | 2022-11-24 | Huawei Technologies Co., Ltd. | Methods and systems for compressing trained neural network and for improving efficiently performing computations of compressed neural network |
CN113554059A (zh) * | 2021-06-23 | 2021-10-26 | 北京达佳互联信息技术有限公司 | 图片处理方法、装置、电子设备及存储介质 |
CN113469977B (zh) * | 2021-07-06 | 2024-01-12 | 浙江霖研精密科技有限公司 | 一种基于蒸馏学习机制的瑕疵检测装置、方法、存储介质 |
CN113469977A (zh) * | 2021-07-06 | 2021-10-01 | 浙江霖研精密科技有限公司 | 一种基于蒸馏学习机制的瑕疵检测装置、方法、存储介质 |
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN114492793A (zh) * | 2022-01-27 | 2022-05-13 | 北京百度网讯科技有限公司 | 一种模型训练和样本生成方法、装置、设备及存储介质 |
CN116863278A (zh) * | 2023-08-25 | 2023-10-10 | 摩尔线程智能科技(北京)有限责任公司 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
CN116863278B (zh) * | 2023-08-25 | 2024-01-26 | 摩尔线程智能科技(北京)有限责任公司 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109711544A (zh) | 模型压缩的方法、装置、电子设备及计算机存储介质 | |
CN109902546B (zh) | 人脸识别方法、装置及计算机可读介质 | |
WO2023280065A1 (zh) | 一种面向跨模态通信系统的图像重建方法及装置 | |
Zhang et al. | A return-cost-based binary firefly algorithm for feature selection | |
CN110633745B (zh) | 一种基于人工智能的图像分类训练方法、装置及存储介质 | |
CN105513591B (zh) | 用lstm循环神经网络模型进行语音识别的方法和装置 | |
WO2020107806A1 (zh) | 一种推荐方法及装置 | |
CN109496322B (zh) | 信用评价方法和装置以及梯度渐进决策树参数调整方法和装置 | |
EP3144859A2 (en) | Model training method and apparatus, and data recognizing method | |
CN107358293A (zh) | 一种神经网络训练方法及装置 | |
CN111259738B (zh) | 人脸识别模型构建方法、人脸识别方法及相关装置 | |
CN108388876A (zh) | 一种图像识别方法、装置以及相关设备 | |
CN108062572A (zh) | 一种基于DdAE深度学习模型的水电机组故障诊断方法与系统 | |
WO2019223250A1 (zh) | 一种确定剪枝阈值的方法及装置、模型剪枝方法及装置 | |
US20230023271A1 (en) | Method and apparatus for detecting face, computer device and computer-readable storage medium | |
CN112631560B (zh) | 一种推荐模型的目标函数的构建方法及终端 | |
CN110244689A (zh) | 一种基于判别性特征学习方法的auv自适应故障诊断方法 | |
CN106628097A (zh) | 一种基于改进径向基神经网络的船舶设备故障诊断方法 | |
CN108875836B (zh) | 一种基于深度多任务学习的简单-复杂活动协同识别方法 | |
CN112418302A (zh) | 一种任务预测方法及装置 | |
CN114037945A (zh) | 一种基于多粒度特征交互的跨模态检索方法 | |
CN110490028A (zh) | 基于深度学习的人脸识别网络训练方法、设备及存储介质 | |
CN116343080A (zh) | 一种动态稀疏关键帧视频目标检测方法、装置及存储介质 | |
CN109242089B (zh) | 递进监督深度学习神经网络训练方法、系统、介质和设备 | |
CN109190471B (zh) | 基于自然语言描述的视频监控行人搜索的注意力模型方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20190503 |