CN111950638A - 基于模型蒸馏的图像分类方法、装置和电子设备 - Google Patents

基于模型蒸馏的图像分类方法、装置和电子设备 Download PDF

Info

Publication number
CN111950638A
CN111950638A CN202010817719.XA CN202010817719A CN111950638A CN 111950638 A CN111950638 A CN 111950638A CN 202010817719 A CN202010817719 A CN 202010817719A CN 111950638 A CN111950638 A CN 111950638A
Authority
CN
China
Prior art keywords
output result
model
loss function
calculating
label
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010817719.XA
Other languages
English (en)
Other versions
CN111950638B (zh
Inventor
陈宝林
黄炜
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Xiamen Meitu Technology Co Ltd
Original Assignee
Xiamen Meitu Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Xiamen Meitu Technology Co Ltd filed Critical Xiamen Meitu Technology Co Ltd
Priority to CN202010817719.XA priority Critical patent/CN111950638B/zh
Publication of CN111950638A publication Critical patent/CN111950638A/zh
Application granted granted Critical
Publication of CN111950638B publication Critical patent/CN111950638B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computational Linguistics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例提供了一种基于模型蒸馏的图像分类方法、装置和电子设备,涉及图像处理技术领域。该方法首先获取待处理图像。接着将待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。如此,将复杂的教师模型提炼成低复杂度、精度损失较小的学生模型,并且仅需将复杂度更小的学生模型应用到移动端即可,减少了移动端存储模型所需要的空间,减小了计算量,从而在保证处理结果的准确性的前提下,提高了图像处理在移动端的实现速率。

Description

基于模型蒸馏的图像分类方法、装置和电子设备
技术领域
本申请涉及图像处理技术领域,具体而言,涉及一种基于模型蒸馏的图像分类方法、装置和电子设备。
背景技术
目前,常常基于深度神经网络对图像进行处理,例如,对图像进行分类、分割等,但是由于常用的深度神经网络模型结构复杂,训练过程往往复杂耗时,且需要的存储空间较大且计算复杂,在移动端上使用时,处理过程迟缓。
如何在保证处理结果的准确性的前提下,提高图像处理在移动终端的实现速率是值得研究的问题。
发明内容
本申请提供了一种基于模型蒸馏的图像分类方法、装置和电子设备,以解决上述问题。
本申请的实施例可以这样实现:
第一方面,本申请实施例提供一种基于模型蒸馏的图像分类方法,所述方法包括:
获取待处理图像;
将所述待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,所述学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。
在可选的实施方式中,所述学生模型通过以下步骤训练得到:
获取原始图像和所述原始图像对应的标签,其中,所述标签通过对所述原始图像进行预先设定得到;
将所述原始图像作为训练样本,输入预先训练好的教师模型和预先构建好的学生网络中,得到所述教师模型输出的第一输出结果和所述学生网络输出的第二输出结果;
依据所述第一输出结果、所述第二输出结果及所述标签计算所述损失函数的损失值;
依据所述损失值,采用反向传播算法迭代更新所述学生网络的参数,直至迭代更新次数达到预设阈值,得到训练后的所述学生模型。
在可选的实施方式中,所述第一输出结果包括第一中间层输出结果及第一最终层输出结果,所述第二输出结果包括第二中间层输出结果及第二最终层输出结果,所述损失函数包括第一局部损失函数、第二局部损失函数及全局损失函数;
所述依据所述第一输出结果、所述第二输出结果及所述标签计算所述损失函数的损失值的步骤包括:
依据所述第一中间层输出结果及所述标签,计算所述第一局部损失函数的第一输出值;
依据所述第二中间层输出结果及所述标签,计算所述第二局部损失函数的第二输出值;
依据所述第一最终层输出结果及所述第二最终层输出结果,计算所述全局损失函数的第三输出值;
计算所述第一输出值、所述第二输出值及所述第三输出值的和,得到所述损失值。
在可选的实施方式中,所述依据所述第一输出结果、所述第二输出结果及所述标签计算所述损失函数的损失值的步骤包括:
依据所述标签对所述第一输出结果进行修正,得到修正后的第一输出结果;
依据所述修正后的第一输出结果、所述第二输出结果及所述标签计算所述损失函数的损失值。
在可选的实施方式中,所述依据所述第一中间层输出结果及所述标签,计算所述第一局部损失函数的第一输出值的步骤包括:
获取预先构建的负样本集及预先构建的全零标签,其中,所述负样本集包括多个负样本;
获取所有所述负样本的权重,得到权重矩阵,其中,所述权重矩阵为根据所述标签对所有负样本的权重进行初始化后得到,或者,所述权重矩阵为根据所述损失值,采用反向传播算法对所有负样本的权重进行迭代更新后得到;
将所述第一中间层输出结果输入全连接层,利用所述全连接层对所述第一中间层输出结果进行特征提取,并计算特征提取后的第一中间层输出结果的L2范数,得到第一特征向量;
计算所述第一特征向量与所述权重矩阵的乘积,得到第一初步结果;
计算所述第一初步结果与所述全零标签的交叉熵损失函数,得到所述第一局部损失函数的第一输出值。
在可选的实施方式中,所述依据所述第二中间层输出结果及所述标签,计算所述第二局部损失函数的第二输出值的步骤包括:
获取预先构建的负样本集及预先构建的全零标签,其中,所述负样本集包括多个负样本;
获取所有所述负样本的权重,得到权重矩阵,其中,所述权重矩阵为根据所述标签对所有负样本的权重进行初始化后得到,或者,所述权重矩阵为根据所述损失值,采用反向传播算法对所有负样本的权重进行迭代更新后得到;
将所述第二中间层输出结果输入全连接层,利用所述全连接层对所述第二中间层输出结果进行特征提取,并计算特征提取后的第二中间层输出结果的L2范数,得到第二特征向量;
计算所述第二特征向量与所述权重矩阵的乘积,得到第二初步结果;
计算所述第二初步结果与所述全零标签的交叉熵损失函数,得到所述第二局部损失函数的第二输出值。
在可选的实施方式中,所述依据所述第一最终层输出结果及所述第二最终层输出结果,计算所述全局损失函数的第三输出值的步骤包括:
依据所述第一最终层输出结果及所述第二最终层输出结果,按照以下公式计算所述全局损失函数的第三输出值:
Figure BDA0002633326750000041
其中,ai为第i个所述第二最终层输出结果,yi为第i个所述第一最终层输出结果,m为所述第二最终层输出结果或所述第一最终层输出结果的个数。
第二方面,本申请实施例提供一种基于模型蒸馏的图像分类装置,所述装置包括:
获取模块,用于获取待处理图像;
分类模块,用于将所述待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,所述学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。
第三方面,本申请实施例提供一种电子设备,所述电子设备包括处理器、存储器及总线,所述存储器存储有所述处理器可执行的机器可读指令,当电子设备运行时,所述处理器及所述存储器之间通过总线通信,所述处理器执行所述机器可读指令,以执行前述实施方式任意一项所述的基于模型蒸馏的图像分类方法的步骤。
第四方面,本申请实施例提供一种可读存储介质,所述可读存储介质中存储有计算机程序,所述计算机程序被执行时实现前述实施方式任意一项所述的基于模型蒸馏的图像分类方法。
本申请实施例提供了一种基于模型蒸馏的图像分类方法、装置和电子设备。该方法首先获取待处理图像。接着将待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。如此,将复杂的教师模型提炼成低复杂度、精度损失较小的学生模型,并且仅需将复杂度更小的学生模型应用到移动端即可,减少了移动端存储模型所需要的空间,减小了计算量,从而在保证处理结果的准确性的前提下,提高了图像处理在移动端的实现速率。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1为本申请实施例提供的电子设备的结构框图。
图2为本申请实施例提供的基于模型蒸馏的图像分类方法的流程图。
图3为本申请实施例提供的学生模型的训练方法流程图。
图4为本申请实施例提供的基于模型蒸馏的图像分类装置的功能模块框图。
图标:100-电子设备;110-存储器;120-处理器;130-基于模型蒸馏的图像分类装置;131-获取模块;132-分类模块;140-通信单元。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。通常在此处附图中描述和示出的本申请实施例的组件可以以各种不同的配置来布置和设计。
因此,以下对在附图中提供的本申请的实施例的详细描述并非旨在限制要求保护的本申请的范围,而是仅仅表示本申请的选定实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。
此外,若出现术语“第一”、“第二”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
需要说明的是,在不冲突的情况下,本申请的实施例中的特征可以相互结合。
如背景技术所介绍,目前,常常基于深度神经网络对图像进行处理,例如,对图像进行分类、分割等,但是由于常用的深度神经网络模型结构复杂,训练过程往往复杂耗时,且需要的存储空间较大且计算复杂,在移动终端上使用时,处理过程迟缓。
如何在保证处理结果的准确性的前提下,提高图像处理在移动终端的实现速率是值得研究的问题。
有鉴于此,本申请实施例提供了一种基于模型蒸馏的图像分类方法、装置和电子设备,该方法通过采用预先训练好的复杂模型(称为教师模型)的输出作为监督信号去训练另一个简单的网络。将复杂模型提炼成低复杂度、精度损失较小的深度网络小模型(称为学生模型),以最大程度的减小模型复杂度,减少模型存储需要的空间,加速模型的训练过程和应用过程。下面对上述方案进行详细阐述。
请参阅图1,图1为本申请实施例提供的一种电子设备100的结构框图。设备可以包括处理器120、存储器110、基于模型蒸馏的图像分类装置130及通信单元140,存储器110存储有处理器120可执行的机器可读指令,当电子设备100运行时,处理器120及存储器110之间通过总线通信,处理器120执行机器可读指令,并执行基于模型蒸馏的图像分类方法的步骤。
存储器110、处理器120以及通信单元140各元件相互之间直接或间接地电性连接,以实现信号的传输或交互。
例如,这些元件相互之间可通过一条或多条通讯总线或信号线实现电性连接。基于模型蒸馏的图像分类装置130包括至少一个可以软件或固件(firmware)的形式存储于存储器110中的软件功能模块。处理器120用于执行存储器110中存储的可执行模块,例如基于模型蒸馏的图像分类装置130所包括的软件功能模块或计算机程序。
其中,存储器110可以是,但不限于,随机读取存储器(Random ACCess memory,RAM),只读存储器(Read Only Memory,ROM),可编程只读存储器(Programmable Read-OnlyMemory,PROM),可擦除只读存储器(Erasable Programmable Read-Only Memory,EPROM),电可擦除只读存储器(Electric Erasable Programmable Read-Only Memory,EEPROM)等。
处理器120可以是一种集成电路芯片,具有信号处理能力。上述处理器120可以是通用处理器,包括中央处理器(Central Processing Unit,简称CPU)、网络处理器(NetworkProcessor,简称NP)等。
还可以是数字信号处理器(DSP)、专用集成电路(ASIC)、现场可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。可以实现或者执行本申请实施例中的公开的各方法、步骤及逻辑框图。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
本申请实施例中,存储器110用于存储程序,处理器120用于在接收到执行指令后,执行程序。本申请实施例任一实施方式所揭示的流程定义的方法可以应用于处理器120中,或者由处理器120实现。
通信单元140用于通过网络建立电子设备100与其他电子设备之间的通信连接,并用于通过网络收发数据。
在一些实施例中,网络可以是任何类型的有线或者无线网络,或者是他们的结合。仅作为示例,网络可以包括有线网络、无线网络、光纤网络、远程通信网络、内联网、因特网、局域网(Local Area Network,LAN)、广域网(Wide Area Network,WAN)、无线局域网(Wireless Local Area Networks,WLAN)、城域网(Metropolitan Area Network,MAN)、广域网(Wide Area Network,WAN)、公共电话交换网(Public Switched Telephone Network,PSTN)、蓝牙网络、ZigBEE网络、或近场通信(Near Field Communication,NFC)网络等,或其任意组合。
在本申请实施例中,电子设备100可以是但不限于智能手机、个人电脑、平板电脑等具有处理功能的设备。
可以理解,图1所示的结构仅为示意。电子设备100还可以具有比图1所示更多或者更少的组件,或者具有与图1所示不同的配置。图1所示的各组件可以采用硬件、软件或其组合实现。
基于上述电子设备100的实现架构,本申请实施例提供了一种基于模型蒸馏的图像分类方法,请结合参阅2,图2为本申请实施例提供的基于模型蒸馏的图像分类方法的流程图。下面结合图2所示的具体流程进行详细描述。
步骤S1,获取待处理图像。
步骤S2,将待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。
其中,待处理图像可以由当前电子设备100预先存储在存储器110中,当需要时从存储器110中获取即可,待处理图像也可以由当前电子设备100实时拍摄得到。
图像分类的目标是根据输入图片且根据预定义类别分配标签,因此,训练时,可以人工对每张原始图像进行标注,分配标签,例如,一张包含“狗”的图像,其标签可以是狗:95%,猫:4%,熊猫:1%。
当训练完成后,作为一种可能的实施场景,将一张待分类图片输入该训练好的学生模型中,则学生模型可为该待分类图片分配一个标签,例如,狗。如此,完成图像的分类。
需要说明的是,上述方案也可应用于其他实施场景,例如,应用于图像分割、图像美白等等。在其他实施场景中,将待处理图像输入学生模型中进行图像分割或图像美白等,即可得到图像分割或图像美白后的结果图像。
可以理解的是,上述学生模型可以是其他电子设备中预先训练得到,之后迁移至当前电子设备100的,也可以是在当前电子设备100中预先训练,并存储得到的。
请结合参阅图3,作为一种可能的实施方式,学生模型可以通过以下步骤训练得到:
步骤S100,获取原始图像和原始图像对应的标签,其中,标签通过对原始图像进行预先设定得到。
步骤S200,将原始图像作为训练样本,输入预先训练好的教师模型和预先构建好的学生网络中,得到教师模型输出的第一输出结果和学生网络输出的第二输出结果。
步骤S300,依据第一输出结果、第二输出结果及标签计算损失函数的损失值。
步骤S400,依据损失值,采用反向传播算法迭代更新学生网络的参数,直至迭代更新次数达到预设阈值,得到训练后的学生模型。
其中,原始图像可以是CiFar100数据集、MNIST数据集或MPEG数据集等其他数据集。可选地,本申请实施例以CiFar100数据集为训练样本,对学生模型进行训练。
进一步地,上述教师模型为通过上述任意一种或多种数据集预先训练好的模型,该教师模型可以是ResNet50,即具有50层深度的残差网络。该教师模型也可以是MobileNetV3_L,该MobileNetV3_L为标准的MobileNetV3的大模型。该教师模型还可以是MobileNetV2+,该MobileNetV2+网络结构将标准的MobileNetV2的宽度系数乘上1.5所得到的。
作为一种可选的实施方式,当教师模型为ResNet50时,对应的学生网络可以是resnet14x4。该resnet14x4学生网络是本申请实施例针对深度残差网络(Deep residualnetwork,ResNet)结构进行改造使其适应于CiFar100的分类数据集,网络的输入维度为32x32,其中resnet14x4表示具有14层深度的残差网络,且网络中三组基本块的输入通道数分别为64,128,256。
当教师模型为MobileNetV3_L时,对应的学生网络可以是MobileNetV3_S。
当教师模型为MobileNetV2+时,对应的学生网络可以是resnet14x4。
作为另一种可选的实施方式,当教师模型为ResNet50时,对应的学生网络还可以是MobileNetV2+。
如此,将训练好的结构复杂、计算量大但是性能优秀的教师模型,对结构相对简单、计算量较小的学生网络进行指导,以提升学生网络的性能。
现有的蒸馏方案大多通过拟合教师模型和学生网络的中间层输出或最后几层输出的相似程度来提升学生网络的精度,学生网络的精度效果很大程度上取决于使用者对其的压缩程度以及与教师模型的网络结构的相似性,高压缩的学生网络与教师模型还是存在较大精度损失,如何确保训练好的学生模型精度良好的同时,压缩学生模型对应网络的大小、减少学生模型前向推理的时延。
因此,在上述基础上,本申请实施例重新构建了损失函数,以确保训练好的学生模型精度良好的同时,压缩学生模型对应网络的大小、减少学生模型前向推理的时延。
作为一种可选的实施方式,第一输出结果包括第一中间层输出结果及第一最终层输出结果,第二输出结果包括第二中间层输出结果及第二最终层输出结果,损失函数包括第一局部损失函数、第二局部损失函数及全局损失函数。
在上述实施方式的基础上,可通过以下步骤实现步骤S300,以计算损失函数的损失值:依据第一中间层输出结果及标签,计算第一局部损失函数的第一输出值。依据第二中间层输出结果及标签,计算第二局部损失函数的第二输出值。依据第一最终层输出结果及第二最终层输出结果,计算全局损失函数的第三输出值。计算第一输出值、第二输出值及第三输出值的和,得到损失值。
应当理解,在其它实施例中,本申请实施例实现步骤S300中的部分步骤的顺序可以根据实际需要相互交换。例如,可以先计算第一输出值、再计算第二输出值、接着再计算第三输出值。还可以先计算第二输出值、再计算第三输出值、接着再计算第一输出值。还可以同时计算第一输出值、第二输出值及第三输出值。
可选地,上述第一中间层可以是教师模型包括的倒数第二层特征层,即第一最终层的前一层特征层。上述第二中间层可以是学生网络包括的倒数第二层特征层,即第二最终层的前一层特征层。
如此,可通过教师模型与学生模型中的不同层输出的结果共同计算损失值,帮助学生模型更好地学习到教师模型中的参数,以提高训练好的学生模型的精度。
由于教师模型输出的结果可能与训练样本的真实的标签不对应,可能导致学生模型学到错的数据,因此,作为另一种可选的实施方式,在计算损失函数的损失值,实现步骤S300时,还可以先对第一输出结果和第二输出结果进行修正,使得学生模型学到正确的数据,进一步提高学生模型的精度。
例如,依据标签对第一输出结果进行修正,得到修正后的第一输出结果。依据修正后的第一输出结果、第二输出结果及标签计算损失函数的损失值。
作为一种可能的实施方式,可通过以下方式对第一输出结果和第二输出结果进行修正:获取第一输出结果中第一最终层输出结果包括的至少一个输出结果,按照大小对至少一个输出结果进行排序。比较第一最终层输出结果中的最大值对应的标签与标签是否一致,若不一致,则将第一最终层输出结果中的次大值对应的标签与最大值对应的标签互换,将次大值对应的标签作为第一最终层输出结果。若一致,则不做更改。
例如,第一最终层输出结果包括:猫:50%,狗:49%,熊猫1%。同时,若训练样本对应的真实的标签为狗,由于最大值对应的标签与真实标签不一致,此时将次大值对应的标签:“狗”代替最大值对应的标签。即,最后的第一最终层输出结果为“狗:49%”,而非“猫:50%”。
如此,当教师模型输出的标签与真实的标签不一致时,则对教师模型输出的第一最终层输出结果进行修正,使得训练得到的学生模型的准确性更高。
需要说明的是,若对教师模型输出的第一最终层输出结果进行了修正,则利用修正后的第一输出结果、第二输出结果及标签计算损失函数的损失值。其原理及计算过程与利用不进行修正的得到的第一输出结果、第二输出结果及标签计算损失函数的损失值的原理及计算过程一致,在此不做赘述。
进一步地,作为一种可选的实施方式,可通过以下步骤依据第一中间层输出结果及标签,计算第一局部损失函数的第一输出值:
获取预先构建的负样本集及预先构建的全零标签,其中,负样本集包括多个负样本。获取所有负样本的权重,得到权重矩阵,其中,权重矩阵为根据标签对所有负样本的权重进行初始化后得到,或者,权重矩阵为根据损失值,采用反向传播算法对所有负样本的权重进行迭代更新后得到。
将第一中间层输出结果输入全连接层,利用全连接层对第一中间层输出结果进行特征提取,并计算特征提取后的第一中间层输出结果的L2范数,得到第一特征向量。
计算第一特征向量与权重矩阵的乘积,得到第一初步结果。计算第一初步结果与全零标签的交叉熵损失函数,得到第一局部损失函数的第一输出值。
其中,当分类数为多类时,可以转为二分类的情况,即把除自身以外的所有类别当成负样本。预先构建的负样本集中含有所有类别的大量负样本,假设原始图像所属的数据集中共有D张图片、N个类别,则负样本集为DxN’的均匀分布矩阵,其中N’可以取大于或等于N的值。
同样以原始图像所属的数据集中共有D张图片、N个类别为例,可选地,可采用Xavier初始化方法,利用标签对所有负样本的权重进行初始化,得到的DxN’的随机初始化均匀分布矩阵即为上述权重矩阵。全零标签可以是:建筑:0,车:0,树木:0。可根据N’的数量确定全零标签的个数,并根据N’的类别确定全零标签包括的类别,在此不做限定。
如此,通过上述步骤计算第一局部损失函数的第一输出值,可在使用反向传播算法更新学生模型的参数时,指导权重矩阵的更新,以起到指导学生模型的作用,使得学生模型尽可能的区分不同类别。
进一步地,作为一种可选的实施方式,依据第二中间层输出结果及标签,计算第二局部损失函数的第二输出值的步骤包括:
获取预先构建的负样本集及预先构建的全零标签,其中,负样本集包括多个负样本。获取所有负样本的权重,得到权重矩阵,其中,权重矩阵为根据标签对所有负样本的权重进行初始化后得到,或者,权重矩阵为根据损失值,采用反向传播算法对所有负样本的权重进行迭代更新后得到。
将第二中间层输出结果输入全连接层,利用全连接层对第二中间层输出结果进行特征提取,并计算特征提取后的第二中间层输出结果的L2范数,得到第二特征向量。
计算第二特征向量与权重矩阵的乘积,得到第二初步结果。计算第二初步结果与全零标签的交叉熵损失函数,得到第二局部损失函数的第二输出值。
上述依据第二中间层输出结果及标签,计算第二局部损失函数的第二输出值的原理及过程与计算第一局部损失函数的第一输出值的原理及过程相同,在此不做赘述。
同样地,通过上述步骤计算第二局部损失函数的第二输出值,可在使用反向传播算法更新学生模型的参数时,指导权重矩阵的更新,以起到指导学生模型的作用,使得学生模型尽可能的区分不同类别。
作为一种可选的实施方式,可按照以下步骤依据第一最终层输出结果及第二最终层输出结果,按照以下公式计算全局损失函数的第三输出值:
Figure BDA0002633326750000161
其中,ai为第i个第二最终层输出结果,yi为第i个第一最终层输出结果,m为第二最终层输出结果或第一最终层输出结果的个数。
如此,通过上述全局损失函数,可让学生模型的输出结果与教师模型的输出结果尽量的接近,以提高学生模型的准确性。同时,通过上述三种损失函数共同监督训练学生网络,使得高压缩的学生模型与教师模型的网络结构的相似性更高,从而在确保学生模型精度的同时,进一步压缩学生模型的大小并减少训练时前向推理的时延。
基于同一发明构思,请结合参阅图4,本申请实施例中还提供了与上述基于模型蒸馏的图像分类方法对应的基于模型蒸馏的图像分类装置130,装置包括:
获取模块131,用于获取待处理图像。
分类模块132,用于将待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。
由于本申请实施例中的装置解决问题的原理与本申请实施例上述基于模型蒸馏的图像分类方法相似,因此装置的实施原理可以参见方法的实施原理,重复之处不再赘述。
本申请实施例也提供了一种可读存储介质,可读存储介质中存储有计算机程序,计算机程序被执行时实现上述的基于模型蒸馏的图像分类方法。
综上所述,本申请实施例提供了一种基于模型蒸馏的图像分类方法、装置、电子设备和可读存储介质。该方法首先获取待处理图像。接着将待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。如此,将复杂的教师模型提炼成低复杂度、精度损失较小的学生模型,以减小模型的复杂度,减少了模型存储需要的空间,同时加速了模型的训练过程,从而在保证处理结果的准确性的前提下,提高了图像处理在移动端的实现速率。同时,本申请还通过重新构建损失函数,确保训练好的学生模型精度良好的同时,压缩学生模型对应网络的大小、减少学生模型前向推理的时延。
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以所述权利要求的保护范围为准。

Claims (10)

1.一种基于模型蒸馏的图像分类方法,其特征在于,所述方法包括:
获取待处理图像;
将所述待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,所述学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。
2.根据权利要求1所述的基于模型蒸馏的图像分类方法,其特征在于,所述学生模型通过以下步骤训练得到:
获取原始图像和所述原始图像对应的标签,其中,所述标签通过对所述原始图像进行预先设定得到;
将所述原始图像作为训练样本,输入预先训练好的教师模型和预先构建好的学生网络中,得到所述教师模型输出的第一输出结果和所述学生网络输出的第二输出结果;
依据所述第一输出结果、所述第二输出结果及所述标签计算所述损失函数的损失值;
依据所述损失值,采用反向传播算法迭代更新所述学生网络的参数,直至迭代更新次数达到预设阈值,得到训练后的所述学生模型。
3.根据权利要求2所述的基于模型蒸馏的图像分类方法,其特征在于,所述第一输出结果包括第一中间层输出结果及第一最终层输出结果,所述第二输出结果包括第二中间层输出结果及第二最终层输出结果,所述损失函数包括第一局部损失函数、第二局部损失函数及全局损失函数;
所述依据所述第一输出结果、所述第二输出结果及所述标签计算所述损失函数的损失值的步骤包括:
依据所述第一中间层输出结果及所述标签,计算所述第一局部损失函数的第一输出值;
依据所述第二中间层输出结果及所述标签,计算所述第二局部损失函数的第二输出值;
依据所述第一最终层输出结果及所述第二最终层输出结果,计算所述全局损失函数的第三输出值;
计算所述第一输出值、所述第二输出值及所述第三输出值的和,得到所述损失值。
4.根据权利要求2所述的基于模型蒸馏的图像分类方法,其特征在于,所述依据所述第一输出结果、所述第二输出结果及所述标签计算所述损失函数的损失值的步骤包括:
依据所述标签对所述第一输出结果进行修正,得到修正后的第一输出结果;
依据所述修正后的第一输出结果、所述第二输出结果及所述标签计算所述损失函数的损失值。
5.根据权利要求3所述的基于模型蒸馏的图像分类方法,其特征在于,所述依据所述第一中间层输出结果及所述标签,计算所述第一局部损失函数的第一输出值的步骤包括:
获取预先构建的负样本集及预先构建的全零标签,其中,所述负样本集包括多个负样本;
获取所有所述负样本的权重,得到权重矩阵,其中,所述权重矩阵为根据所述标签对所有负样本的权重进行初始化后得到,或者,所述权重矩阵为根据所述损失值,采用反向传播算法对所有负样本的权重进行迭代更新后得到;
将所述第一中间层输出结果输入全连接层,利用所述全连接层对所述第一中间层输出结果进行特征提取,并计算特征提取后的第一中间层输出结果的L2范数,得到第一特征向量;
计算所述第一特征向量与所述权重矩阵的乘积,得到第一初步结果;
计算所述第一初步结果与所述全零标签的交叉熵损失函数,得到所述第一局部损失函数的第一输出值。
6.根据权利要求3所述的基于模型蒸馏的图像分类方法,其特征在于,所述依据所述第二中间层输出结果及所述标签,计算所述第二局部损失函数的第二输出值的步骤包括:
获取预先构建的负样本集及预先构建的全零标签,其中,所述负样本集包括多个负样本;
获取所有所述负样本的权重,得到权重矩阵,其中,所述权重矩阵为根据所述标签对所有负样本的权重进行初始化后得到,或者,所述权重矩阵为根据所述损失值,采用反向传播算法对所有负样本的权重进行迭代更新后得到;
将所述第二中间层输出结果输入全连接层,利用所述全连接层对所述第二中间层输出结果进行特征提取,并计算特征提取后的第二中间层输出结果的L2范数,得到第二特征向量;
计算所述第二特征向量与所述权重矩阵的乘积,得到第二初步结果;
计算所述第二初步结果与所述全零标签的交叉熵损失函数,得到所述第二局部损失函数的第二输出值。
7.根据权利要求3所述的基于模型蒸馏的图像分类方法,其特征在于,所述依据所述第一最终层输出结果及所述第二最终层输出结果,计算所述全局损失函数的第三输出值的步骤包括:
依据所述第一最终层输出结果及所述第二最终层输出结果,按照以下公式计算所述全局损失函数的第三输出值:
Figure FDA0002633326740000041
其中,ai为第i个所述第二最终层输出结果,yi为第i个所述第一最终层输出结果,m为所述第二最终层输出结果或所述第一最终层输出结果的个数。
8.一种基于模型蒸馏的图像分类装置,其特征在于,所述装置包括:
获取模块,用于获取待处理图像;
分类模块,用于将所述待处理图像输入学生模型中进行分类,得到分类后的分类结果,其中,所述学生模型是利用预先训练的教师模型及预设的损失函数对预先构建的学生网络进行训练得到。
9.一种电子设备,其特征在于,所述电子设备包括处理器、存储器及总线,所述存储器存储有所述处理器可执行的机器可读指令,当电子设备运行时,所述处理器及所述存储器之间通过总线通信,所述处理器执行所述机器可读指令,以执行权利要求1-7任意一项所述的基于模型蒸馏的图像分类方法的步骤。
10.一种可读存储介质,其特征在于,所述可读存储介质中存储有计算机程序,所述计算机程序被执行时实现权利要求1-7任意一项所述的基于模型蒸馏的图像分类方法。
CN202010817719.XA 2020-08-14 2020-08-14 基于模型蒸馏的图像分类方法、装置和电子设备 Active CN111950638B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010817719.XA CN111950638B (zh) 2020-08-14 2020-08-14 基于模型蒸馏的图像分类方法、装置和电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010817719.XA CN111950638B (zh) 2020-08-14 2020-08-14 基于模型蒸馏的图像分类方法、装置和电子设备

Publications (2)

Publication Number Publication Date
CN111950638A true CN111950638A (zh) 2020-11-17
CN111950638B CN111950638B (zh) 2024-02-06

Family

ID=73343784

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010817719.XA Active CN111950638B (zh) 2020-08-14 2020-08-14 基于模型蒸馏的图像分类方法、装置和电子设备

Country Status (1)

Country Link
CN (1) CN111950638B (zh)

Cited By (15)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112508120A (zh) * 2020-12-18 2021-03-16 北京百度网讯科技有限公司 学生模型训练方法、装置、设备、介质和程序产品
CN112528109A (zh) * 2020-12-01 2021-03-19 中科讯飞互联(北京)信息科技有限公司 一种数据分类方法、装置、设备及存储介质
CN112668716A (zh) * 2020-12-29 2021-04-16 奥比中光科技集团股份有限公司 一种神经网络模型的训练方法及设备
CN112949786A (zh) * 2021-05-17 2021-06-11 腾讯科技(深圳)有限公司 数据分类识别方法、装置、设备及可读存储介质
CN113159085A (zh) * 2020-12-30 2021-07-23 北京爱笔科技有限公司 分类模型的训练及基于图像的分类方法、相关装置
CN113392938A (zh) * 2021-07-30 2021-09-14 广东工业大学 一种分类模型训练方法、阿尔茨海默病分类方法及装置
CN113411425A (zh) * 2021-06-21 2021-09-17 深圳思谋信息科技有限公司 视频超分模型构建处理方法、装置、计算机设备和介质
CN113408570A (zh) * 2021-05-08 2021-09-17 浙江智慧视频安防创新中心有限公司 一种基于模型蒸馏的图像类别识别方法、装置、存储介质及终端
CN113408571A (zh) * 2021-05-08 2021-09-17 浙江智慧视频安防创新中心有限公司 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端
CN113657523A (zh) * 2021-08-23 2021-11-16 科大讯飞股份有限公司 一种图像目标分类方法、装置、设备及存储介质
CN113762368A (zh) * 2021-08-27 2021-12-07 北京市商汤科技开发有限公司 数据蒸馏的方法、装置、电子设备和存储介质
CN114693995A (zh) * 2022-04-14 2022-07-01 北京百度网讯科技有限公司 应用于图像处理的模型训练方法、图像处理方法和设备
WO2022141859A1 (zh) * 2020-12-31 2022-07-07 平安科技(深圳)有限公司 图像检测方法、装置、电子设备及存储介质
CN115294407A (zh) * 2022-09-30 2022-11-04 山东大学 基于预习机制知识蒸馏的模型压缩方法及系统
WO2023169334A1 (zh) * 2022-03-09 2023-09-14 北京字跳网络技术有限公司 图像的语义分割方法、装置、电子设备及存储介质

Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2017158058A1 (en) * 2016-03-15 2017-09-21 Imra Europe Sas Method for classification of unique/rare cases by reinforcement learning in neural networks
CN107247989A (zh) * 2017-06-15 2017-10-13 北京图森未来科技有限公司 一种神经网络训练方法及装置
US20180174047A1 (en) * 2016-12-15 2018-06-21 WaveOne Inc. Data compression for machine learning tasks
CN110070183A (zh) * 2019-03-11 2019-07-30 中国科学院信息工程研究所 一种弱标注数据的神经网络模型训练方法及装置
CN110413993A (zh) * 2019-06-26 2019-11-05 重庆兆光科技股份有限公司 一种基于稀疏权值神经网络的语义分类方法、系统和介质
CN110598603A (zh) * 2019-09-02 2019-12-20 深圳力维智联技术有限公司 人脸识别模型获取方法、装置、设备和介质
CN110689043A (zh) * 2019-08-22 2020-01-14 长沙千视通智能科技有限公司 一种基于多重注意力机制的车辆细粒度识别方法及装置
US20200151497A1 (en) * 2018-11-12 2020-05-14 Sony Corporation Semantic segmentation with soft cross-entropy loss
CN111242297A (zh) * 2019-12-19 2020-06-05 北京迈格威科技有限公司 基于知识蒸馏的模型训练方法、图像处理方法及装置
EP3680823A1 (en) * 2019-01-10 2020-07-15 Visa International Service Association System, method, and computer program product for incorporating knowledge from more complex models in simpler models

Patent Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2017158058A1 (en) * 2016-03-15 2017-09-21 Imra Europe Sas Method for classification of unique/rare cases by reinforcement learning in neural networks
US20180174047A1 (en) * 2016-12-15 2018-06-21 WaveOne Inc. Data compression for machine learning tasks
CN107247989A (zh) * 2017-06-15 2017-10-13 北京图森未来科技有限公司 一种神经网络训练方法及装置
US20180365564A1 (en) * 2017-06-15 2018-12-20 TuSimple Method and device for training neural network
US20200151497A1 (en) * 2018-11-12 2020-05-14 Sony Corporation Semantic segmentation with soft cross-entropy loss
EP3680823A1 (en) * 2019-01-10 2020-07-15 Visa International Service Association System, method, and computer program product for incorporating knowledge from more complex models in simpler models
CN110070183A (zh) * 2019-03-11 2019-07-30 中国科学院信息工程研究所 一种弱标注数据的神经网络模型训练方法及装置
CN110413993A (zh) * 2019-06-26 2019-11-05 重庆兆光科技股份有限公司 一种基于稀疏权值神经网络的语义分类方法、系统和介质
CN110689043A (zh) * 2019-08-22 2020-01-14 长沙千视通智能科技有限公司 一种基于多重注意力机制的车辆细粒度识别方法及装置
CN110598603A (zh) * 2019-09-02 2019-12-20 深圳力维智联技术有限公司 人脸识别模型获取方法、装置、设备和介质
CN111242297A (zh) * 2019-12-19 2020-06-05 北京迈格威科技有限公司 基于知识蒸馏的模型训练方法、图像处理方法及装置

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
QUANDE LIU 等: "Semi-Supervised Medical Image Classification With Relation-Driven Self-Ensembling Model", 《IEEE TRANSACTIONS ON MEDICAL IMAGING》, vol. 39, no. 11, pages 3429, XP011816696, DOI: 10.1109/TMI.2020.2995518 *
侯卫东: "面向移动应用的人体图像多属性分类算法研究", 《中国优秀硕士学位论文全文数据库 (信息科技辑)》, no. 07, pages 138 - 1256 *
王峰: "基于深度学习的人脸认证方法研究", 《中国博士学位论文全文数据库 (信息科技辑)》, no. 04, pages 138 - 11 *

Cited By (20)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112528109A (zh) * 2020-12-01 2021-03-19 中科讯飞互联(北京)信息科技有限公司 一种数据分类方法、装置、设备及存储介质
CN112528109B (zh) * 2020-12-01 2023-10-27 科大讯飞(北京)有限公司 一种数据分类方法、装置、设备及存储介质
CN112508120A (zh) * 2020-12-18 2021-03-16 北京百度网讯科技有限公司 学生模型训练方法、装置、设备、介质和程序产品
CN112508120B (zh) * 2020-12-18 2023-10-10 北京百度网讯科技有限公司 学生模型训练方法、装置、设备、介质和程序产品
CN112668716A (zh) * 2020-12-29 2021-04-16 奥比中光科技集团股份有限公司 一种神经网络模型的训练方法及设备
CN113159085A (zh) * 2020-12-30 2021-07-23 北京爱笔科技有限公司 分类模型的训练及基于图像的分类方法、相关装置
CN113159085B (zh) * 2020-12-30 2024-05-28 北京爱笔科技有限公司 分类模型的训练及基于图像的分类方法、相关装置
WO2022141859A1 (zh) * 2020-12-31 2022-07-07 平安科技(深圳)有限公司 图像检测方法、装置、电子设备及存储介质
CN113408570A (zh) * 2021-05-08 2021-09-17 浙江智慧视频安防创新中心有限公司 一种基于模型蒸馏的图像类别识别方法、装置、存储介质及终端
CN113408571A (zh) * 2021-05-08 2021-09-17 浙江智慧视频安防创新中心有限公司 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端
CN112949786A (zh) * 2021-05-17 2021-06-11 腾讯科技(深圳)有限公司 数据分类识别方法、装置、设备及可读存储介质
CN112949786B (zh) * 2021-05-17 2021-08-06 腾讯科技(深圳)有限公司 数据分类识别方法、装置、设备及可读存储介质
CN113411425A (zh) * 2021-06-21 2021-09-17 深圳思谋信息科技有限公司 视频超分模型构建处理方法、装置、计算机设备和介质
CN113411425B (zh) * 2021-06-21 2023-11-07 深圳思谋信息科技有限公司 视频超分模型构建处理方法、装置、计算机设备和介质
CN113392938A (zh) * 2021-07-30 2021-09-14 广东工业大学 一种分类模型训练方法、阿尔茨海默病分类方法及装置
CN113657523A (zh) * 2021-08-23 2021-11-16 科大讯飞股份有限公司 一种图像目标分类方法、装置、设备及存储介质
CN113762368A (zh) * 2021-08-27 2021-12-07 北京市商汤科技开发有限公司 数据蒸馏的方法、装置、电子设备和存储介质
WO2023169334A1 (zh) * 2022-03-09 2023-09-14 北京字跳网络技术有限公司 图像的语义分割方法、装置、电子设备及存储介质
CN114693995A (zh) * 2022-04-14 2022-07-01 北京百度网讯科技有限公司 应用于图像处理的模型训练方法、图像处理方法和设备
CN115294407A (zh) * 2022-09-30 2022-11-04 山东大学 基于预习机制知识蒸馏的模型压缩方法及系统

Also Published As

Publication number Publication date
CN111950638B (zh) 2024-02-06

Similar Documents

Publication Publication Date Title
CN111950638A (zh) 基于模型蒸馏的图像分类方法、装置和电子设备
CN107730474B (zh) 图像处理方法、处理装置和处理设备
WO2019100724A1 (zh) 训练多标签分类模型的方法和装置
US10719693B2 (en) Method and apparatus for outputting information of object relationship
CN109492627B (zh) 一种基于全卷积网络的深度模型的场景文本擦除方法
CN109919183B (zh) 一种基于小样本的图像识别方法、装置、设备及存储介质
CN110738102A (zh) 一种人脸识别方法及系统
WO2023174036A1 (zh) 联邦学习模型训练方法、电子设备及存储介质
CN113408570A (zh) 一种基于模型蒸馏的图像类别识别方法、装置、存储介质及终端
CN114676704A (zh) 句子情感分析方法、装置、设备以及存储介质
CN111898735A (zh) 蒸馏学习方法、装置、计算机设备和存储介质
CN111814804B (zh) 基于ga-bp-mc神经网络的人体三维尺寸信息预测方法及装置
CN110659398A (zh) 一种基于数学图表类数据集的视觉问答方法
CN110879993A (zh) 神经网络训练方法、人脸识别任务的执行方法及装置
CN114399808A (zh) 一种人脸年龄估计方法、系统、电子设备及存储介质
CN115984930A (zh) 微表情识别方法、装置、微表情识别模型的训练方法
CN109241930B (zh) 用于处理眉部图像的方法和装置
CN110659641B (zh) 一种文字识别的方法、装置及电子设备
CN117315758A (zh) 面部表情的检测方法、装置、电子设备及存储介质
CN112532251A (zh) 一种数据处理的方法及设备
CN112257840A (zh) 一种神经网络处理方法以及相关设备
CN113408571B (zh) 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端
CN115795025A (zh) 一种摘要生成方法及其相关设备
CN115906861A (zh) 基于交互方面信息融合的语句情感分析方法以及装置
CN114998643A (zh) 类别描述的特征信息的获取方法、图像的处理方法及设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant