CN113408571B - 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端 - Google Patents

一种基于模型蒸馏的图像分类方法、装置、存储介质及终端 Download PDF

Info

Publication number
CN113408571B
CN113408571B CN202110499675.5A CN202110499675A CN113408571B CN 113408571 B CN113408571 B CN 113408571B CN 202110499675 A CN202110499675 A CN 202110499675A CN 113408571 B CN113408571 B CN 113408571B
Authority
CN
China
Prior art keywords
model
feature
image
student
trained
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110499675.5A
Other languages
English (en)
Other versions
CN113408571A (zh
Inventor
廖丹萍
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Smart Video Security Innovation Center Co Ltd
Original Assignee
Zhejiang Smart Video Security Innovation Center Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Smart Video Security Innovation Center Co Ltd filed Critical Zhejiang Smart Video Security Innovation Center Co Ltd
Priority to CN202110499675.5A priority Critical patent/CN113408571B/zh
Publication of CN113408571A publication Critical patent/CN113408571A/zh
Application granted granted Critical
Publication of CN113408571B publication Critical patent/CN113408571B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/213Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于模型蒸馏的图像分类方法,该方法包括:获取待分类目标图像并输入预先训练的学生模型中;预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数;输出目标图像对应的多个类别概率值;基于多个类别概率值判定待分类目标图像的目标类别。因此,本申请实施例通过采用主成分分析算法对教师模型输出的特征进行特征降维,并约束学生模型的输出特征与降维后的特征一致,使得学生模型也能学到和教师模型相似区分度的特征,由于学生模型结构简单以及参数少,从而提升了硬件平台的运行速度,提高了图像分类效率。

Description

一种基于模型蒸馏的图像分类方法、装置、存储介质及终端
技术领域
本发明涉及计算机视觉技术领域,特别涉及一种基于模型蒸馏的图像分类方法、装置、存储介质及终端。
背景技术
近年来,深度神经网络使得很多计算机视觉任务的性能达到了前所未有的高度。神经网络的模型结构越复杂,参数越多,网络能学习到的知识就越丰富,学习效果也越好。然而,高额的存储空间以及计算资源使得大网络模型难以应用在各类移动平台,因此,设计更加轻量化且兼顾性能的网络模型成为了计算机视觉算法落地应用的关键研究之一。
在现有技术中,模型轻量化通常采用模型压缩方法通过对大模型进行参数裁剪、权重分解或者采用模型蒸馏等方法,减小模型对于计算空间和时间的消耗。例如,在模型蒸馏的方法中,先训练一个精度较高的大模型,称为教师模型,然后利用教师模型中学到的知识去指导训练参数量较少的学生模型。然而,当教师模型的特征维度高于学生模型的特征维度时,由于不同维度的特征无法直接计算距离,因此无法直接利用教师模型的高维特征指导学生模型训练。为了绕开这个问题,现有技术一般采用特征维度和教师网络一致的学生模型,这就使得学生模型的网络大小不能进一步进行压缩,影响了运行的时效性。
发明内容
本申请实施例提供了一种基于模型蒸馏的图像分类方法、装置、存储介质及终端。为了对披露的实施例的一些方面有一个基本的理解,下面给出了简单的概括。该概括部分不是泛泛评述,也不是要确定关键/重要组成元素或描绘这些实施例的保护范围。其唯一目的是用简单的形式呈现一些概念,以此作为后面的详细说明的序言。
第一方面,本申请实施例提供了一种基于模型蒸馏的图像分类方法,该方法包括:
获取待分类目标图像;
将待分类目标图像输入预先训练的学生模型中;其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数;
输出目标图像对应的多个类别概率值;
基于多个类别概率值判定待分类目标图像的目标类别。
可选的,基于多个类别概率值判定待分类目标图像的目标类别,包括:
选择多个类别概率值中的最大概率值;
识别选择的最大概率值对应的目标类别;
将目标类别确定为待分类目标图像的所属类别。
可选的,按照下述步骤生成预先训练的教师模型,包括:
采集多种类型的图像集生成模型训练样本;
创建教师模型;其中,所述教师模型的模型特征通道数为Ct
将模型训练样本输入教师模型中进行训练后,生成训练后的教师模型;
将训练后的教师模型确定为预先训练的教师模型。
可选的,按照下述步骤生成预先训练的学生模型,包括:
创建学生模型,其中,所述学生模型的模型特征通道数为Cs
基于预先训练的教师模型与主成分分析算法构建降维矩阵
Figure BDA0003055889280000021
其中,所述降维矩阵将特征从Ct维降至Cs维度;
从模型训练样本中获取第n图像;
将第n图像输入预先训练的教师模型中进行特征提取,生成大小为H×W×Ct的第一网络特征,H和W分别为特征的高度和宽度,Ct为第一网络特征的通道数量;
根据降维矩阵将第一网络特征中每个空间位置的特征由Ct维降至Cs维,得到大小为(H×W)×Cs的目标特征Fref
将第n图像输入预先训练的学生模型中进行特征提取,生成第二网络特征
Figure BDA0003055889280000022
根据目标特征Fref和第二网络特征
Figure BDA0003055889280000031
构造学生模型的目标损失函数;
将目标损失函数关联至学生模型上,生成关联函数后的学生模型;
将第n图像输入关联函数后的学生模型中进行训练。
可选的,方法还包括:
当所述模型的迭代训练次数小于预设值时,继续执行从所述模型训练样本中获取第n+1图像的步骤,并当所述n+1大于所述模型训练样本时,对所述模型训练样本中图像的顺序进行随机排列,并重置n=1。
可选的,基于预先训练的教师模型与主成分分析算法构建降维矩阵,包括:
从模型训练样本中K类图像集中随机选择N张样本,生成N×K张样本;
将N×K张样本输入预先训练的教师模型中进行特征提取,生成多个图像特征;
将多个图像特征中每个图像特征进行重排后,生成N×K个大小为(H×W)×Ct的矩阵;
将N×K个大小为(H×W)×Ct的矩阵进行纵向拼接,生成拼接矩阵
Figure BDA0003055889280000032
根据拼接矩阵
Figure BDA0003055889280000033
与主成分分析算法构建降维矩阵
Figure BDA0003055889280000034
其中,所述降维矩阵将特征从Ct维降至Cs维度。
可选的,损失函数为:
Figure BDA0003055889280000035
其中,λ为特征约束的权重,LCE为交叉熵损失函数,则LCE的计算公式为:LCE=∑i-pi×logqi,其中,qi为学生模型中网络输出的第i类的概率值,pi为图片标签向量的第i维数值。
第二方面,本申请实施例提供了一种基于模型蒸馏的图像分类装置,该装置包括:
图像获取模块,用于获取待分类目标图像;
图像输入模块,用于将待分类目标图像输入预先训练的学生模型中;其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数;
概率值输出模块,用于输出目标图像对应的多个类别概率值;
类别判定模块,用于基于多个类别概率值判定待分类目标图像的目标类别。
第三方面,本申请实施例提供一种计算机存储介质,计算机存储介质存储有多条指令,指令适于由处理器加载并执行上述的方法步骤。
第四方面,本申请实施例提供一种终端,可包括:处理器和存储器;其中,存储器存储有计算机程序,计算机程序适于由处理器加载并执行上述的方法步骤。
本申请实施例提供的技术方案可以包括以下有益效果:
在本申请实施例中,首先获取待分类目标图像并输入预先训练的学生模型中,其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数,然后输出目标图像对应的多个类别概率值,最后基于多个类别概率值判定待分类目标图像的目标类别。因此,本申请实施例通过采用主成分分析算法对教师模型输出的特征进行特征降维,并约束学生模型的输出特征与降维后的特征一致,使得学生模型也能学到和教师模型相似区分度的特征,由于学生模型参数少,从而提升了硬件平台的运行速度,进一步提高了图像分类效率。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本发明。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本发明的实施例,并与说明书一起用于解释本发明的原理。
图1是本申请实施例提供的一种基于模型蒸馏的图像分类方法的流程示意图;
图2是本申请实施例提供的一种基于模型蒸馏的图像分类过程的过程示意框图;
图3是本申请实施例提供的一种基于模型蒸馏的图像分类装置的装置示意图;
图4是本申请实施例提供的一种终端的结构示意图。
具体实施方式
以下描述和附图充分地示出本发明的具体实施方案,以使本领域的技术人员能够实践它们。
应当明确,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本发明相一致的所有实施方式。相反,它们仅是如所附权利要求书中所详述的、本发明的一些方面相一致的装置和方法的例子。
在本发明的描述中,需要理解的是,术语“第一”、“第二”等仅用于描述目的,而不能理解为指示或暗示相对重要性。对于本领域的普通技术人员而言,可以具体情况理解上述术语在本发明中的具体含义。此外,在本发明的描述中,除非另有说明,“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。
下面将结合附图1-附图2,对本申请实施例提供的基于模型蒸馏的图像分类方法进行详细介绍。该方法可依赖于计算机程序实现,可运行于基于冯诺依曼体系的基于模型蒸馏的图像分类装置上。该计算机程序可集成在应用中,也可作为独立的工具类应用运行。其中,本申请实施例中的基于模型蒸馏的图像分类装置可以为用户终端,包括但不限于:个人电脑、平板电脑、手持设备、车载设备、可穿戴设备、计算设备或连接到无线调制解调器的其它处理设备等。在不同的网络中用户终端可以叫做不同的名称,例如:用户设备、接入终端、用户单元、用户站、移动站、移动台、远方站、远程终端、移动设备、用户终端、终端、无线通信设备、用户代理或用户装置、蜂窝电话、无绳电话、个人数字处理(personal digitalassistant,PDA)、5G网络或未来演进网络中的终端设备等。
请参见图1,为本申请实施例提供了一种基于模型蒸馏的图像分类方法的流程示意图。如图1所示,本申请实施例的方法可以包括以下步骤:
S101,获取待分类目标图像;
其中,待分类的目标图像是用来测试学生模型性能的图像或者学生模型应用在分类应用场景时获取到的图像。
通常,当待分类的目标图像是用来测试学生模型性能的图像时,待分类的目标图像可以是从测试样本中获取的,也可以是从用户终端里获取到的图像,还可以是从云端下载到的图像。当待分类的目标图像是学生模型应用在分类应用场景时获取到的图像时,待分类的图像可以是通过图像采集设备实时采集的图像。
在一种可能的实现方式中,当基于教师模型训练的学生模型训练结束后,并将训练结束的学生模型部署在实际应用场景时,物体传感器或者物体监测算法当检测到有物体进入摄像头监视区域后,触发图像采集摄像的拍照功能采集进入监视区域的目标图像,最后将目标图像确定为待分类目标图像。
在另一种可能的实现方式中,当基于教师模型训练的学生模型训练结束后,需要检测训练完成的学生模型的图像分类性能,用户通过用户终端从样本测试集或者本地图库或者云端下载任何一个带物体图像,将该图像确定为待分类目标图像。
S102,将待分类目标图像输入预先训练的学生模型中;其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数;
其中,模型蒸馏的方法中,算法先训练一个精度较高的大模型,称为教师模型。然后利用教师模型中学到的知识,去指导训练一个参数量较少的学生模型。学生模型通过学习教师模型中对分类有益的信息,从而提高自己的性能。由于学生模型参数量少,运行速度快,因此可以方便地部署在各类硬件平台上。
通常,教师模型与学生模型都是通过神经网络创建的,该神经网络优选卷积神经网络。
在本申请实施例中,在训练教师模型时,首先采集多种类型的图像集生成模型训练样本,然后创建教师模型;其中,所述教师模型的模型特征通道数为Ct,再将模型训练样本输入教师模型中进行训练后,生成训练后的教师模型,最后将训练后的教师模型确定为预先训练的教师模型。
在本申请实施例中,在训练学生模型时,首先创建学生模型,其中,所述学生模型的模型特征通道数为Cs,再基于预先训练的教师模型与主成分分析算法构建降维矩阵
Figure BDA0003055889280000071
其中,所述降维矩阵将特征从Ct维降至Cs维度;然后从模型训练样本中获取第n图像,再将第n图像输入预先训练的教师模型中进行特征提取,生成大小为H×W×Ct的第一网络特征,H和W分别为特征的高度和宽度,Ct为第一网络特征的通道数量,其次根据降维矩阵将第一网络特征中每个空间位置的特征由Ct维降至Cs维,得到大小为(H×W)×Cs的目标特征Fref,再将第n图像输入预先训练的学生模型中进行特征提取,生成第二网络特征
Figure BDA0003055889280000072
再根据目标特征Fref和第二网络特征
Figure BDA0003055889280000073
构造学生模型的目标损失函数,再将目标损失函数关联至学生模型上,生成关联函数后的学生模型,并将第n图像输入关联函数后的学生模型中进行训练,最后当所述模型的迭代训练次数小于预设值时,继续执行从所述模型训练样本中获取第n+1图像的步骤,并当所述n+1大于所述模型训练样本时,对所述模型训练样本中图像的顺序进行随机排列,并重置n=1。
具体的,损失函数为:
Figure BDA0003055889280000074
其中,λ为特征约束的权重,LCE为交叉熵损失函数,则LCE的计算公式为:LCE=∑i-pi×logqi,其中,qi为学生模型中网络输出的第i类的概率值,pi为图片标签向量的第i维数值。
需要说明的是,在训练学生模型的过程中,在现有损失函数的基础上,加入了特征约束的权重,用来约束学生模型提取的特征与教师模型的一致。由于学生模型提取的特征维度通常小于教师模型的特征维度,因此,需要将教师模型的特征进行降维,变换到与学生模型的特征相同的维度,才能直接与学生模型的特征进行比较,从而指导学生模型进行训练。
进一步的,在基于预先训练的教师模型与主成分分析算法构建降维矩阵时,首先从模型训练样本中K类图像集中随机选择N张样本,生成N×K张样本,再将N×K张样本输入预先训练的教师模型中进行特征提取,生成多个图像特征,然后将多个图像特征中每个图像特征进行重排后,生成N×K个大小为(H×W)×Ct的矩阵,再将N×K个大小为(H×W)×Ct的矩阵进行纵向拼接,生成拼接矩阵
Figure BDA0003055889280000081
最后根据拼接矩阵
Figure BDA0003055889280000082
与主成分分析算法构建降维矩阵
Figure BDA0003055889280000083
其中,所述降维矩阵将特征从Ct维降至Cs维度。
在一种可能的实现方式中,在基于步骤S101获取到待分类目标图像后,将待分类目标图像输入预先训练的学生模型中进行处理。
S103,输出目标图像对应的多个类别概率值;
其中,概率值是代表该图像所属类别的类别概率分布。
在一种可能的实现方式中,当基于步骤S102使用预先训练的学生模型进行处理后,会得到多个类别的概率值,最后预先训练的学生模型将多个概率值进行输出,输出后得到目标图像对应的多个类别概率值。
例如,该图像通过预先训练的学生模型进行处理完成后输出的多个概率值为:动物类型概率23%、人体类型概率67%、其他类型概率为10%,从输出的概率可以知道,最大概率为人体类型概率67%,因此该图像中的物体为人体类型。
在本申请实施例中,采用PCA(主成分分析)算法对特征进行降维。得到教师网络降维后的特征后,约束学生网络的输出特征与该特征一致,使得学生网络也能学到和教师网络相似区分度的特征。
S104,基于多个类别概率值判定待分类目标图像的目标类别。
在一种可能的实现方式中,在得到待分类图像的多个类别概率值后,首先选择多个类别概率值中的最大概率值,然后识别选择的最大概率值对应的目标类别,最后将目标类别确定为待分类目标图像的所属类别。
例如图2所示,图2是本申请提供的基于模型蒸馏的图像分类过程的过程示意图,首先通过获取一个目标图像,然后将该目标图像输入预先训练的学生模型中,经过模型处理后输入概率值1、概率值2、概率值3以及概率值n,其次从输出的多个概率中选择概率值最大的概率值,并将概率值最大的概率值对应的类别确定为图像的最终所属类别。
在本申请实施例中,首先获取待分类目标图像并输入预先训练的学生模型中,其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数,然后输出目标图像对应的多个类别概率值,最后基于多个类别概率值判定待分类目标图像的目标类别。因此,本申请实施例通过采用主成分分析算法对教师模型输出的特征进行特征降维,并约束学生模型的输出特征与降维后的特征一致,使得学生模型也能学到和教师模型相似区分度的特征,由于学生参数少,从而提升了硬件平台的运行速度,进一步提高了图像分类效率。
下述为本发明装置实施例,可以用于执行本发明方法实施例。对于本发明装置实施例中未披露的细节,请参照本发明方法实施例。
请参见图3,其示出了本发明一个示例性实施例提供的基于模型蒸馏的图像分类装置的结构示意图。该基于模型蒸馏的图像分类装置可以通过软件、硬件或者两者的结合实现成为终端的全部或一部分。该装置1包括图像获取模块10、图像输入模块20、概率值输出模块30、类别判定模块40。
图像获取模块10,用于获取待分类目标图像;
图像输入模块20,用于将待分类目标图像输入预先训练的学生模型中;其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数;
概率值输出模块30,用于输出目标图像对应的多个类别概率值;
类别判定模块40,用于基于多个类别概率值判定待分类目标图像的目标类别。
需要说明的是,上述实施例提供的基于模型蒸馏的图像分类装置在执行基于模型蒸馏的图像分类方法时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的基于模型蒸馏的图像分类装置与基于模型蒸馏的图像分类方法实施例属于同一构思,其体现实现过程详见方法实施例,这里不再赘述。
上述本申请实施例序号仅仅为了描述,不代表实施例的优劣。
在本申请实施例中,首先获取待分类目标图像并输入预先训练的学生模型中,其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数,然后输出目标图像对应的多个类别概率值,最后基于多个类别概率值判定待分类目标图像的目标类别。因此,本申请实施例通过采用主成分分析算法对教师模型输出的特征进行特征降维,并约束学生模型的输出特征与降维后的特征一致,使得学生模型也能学到和教师模型相似区分度的特征,由于学生模型参数少,从而提升了硬件平台的运行速度,进一步提高了图像分类效率。
本发明还提供一种计算机可读介质,其上存储有程序指令,该程序指令被处理器执行时实现上述各个方法实施例提供的基于模型蒸馏的图像分类方法。本发明还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述各个方法实施例的基于模型蒸馏的图像分类方法。
请参见图4,为本申请实施例提供了一种终端的结构示意图。如图4所示,终端1000可以包括:至少一个处理器1001,至少一个网络接口1004,用户接口1003,存储器1005,至少一个通信总线1002。
其中,通信总线1002用于实现这些组件之间的连接通信。
其中,用户接口1003可以包括显示屏(Display)、摄像头(Camera),可选用户接口1003还可以包括标准的有线接口、无线接口。
其中,网络接口1004可选的可以包括标准的有线接口、无线接口(如WI-FI接口)。
其中,处理器1001可以包括一个或者多个处理核心。处理器1001利用各种借口和线路连接整个电子设备1000内的各个部分,通过运行或执行存储在存储器1005内的指令、程序、代码集或指令集,以及调用存储在存储器1005内的数据,执行电子设备1000的各种功能和处理数据。可选的,处理器1001可以采用数字信号处理(Digital Signal Processing,DSP)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)、可编程逻辑阵列(Programmable Logic Array,PLA)中的至少一种硬件形式来实现。处理器1001可集成中央处理器(Central Processing Unit,CPU)、图像处理器(Graphics Processing Unit,GPU)和调制解调器等中的一种或几种的组合。其中,CPU主要处理操作系统、用户界面和应用程序等;GPU用于负责显示屏所需要显示的内容的渲染和绘制;调制解调器用于处理无线通信。可以理解的是,上述调制解调器也可以不集成到处理器1001中,单独通过一块芯片进行实现。
其中,存储器1005可以包括随机存储器(Random Access Memory,RAM),也可以包括只读存储器(Read-Only Memory)。可选的,该存储器1005包括非瞬时性计算机可读介质(non-transitory computer-readable storage medium)。存储器1005可用于存储指令、程序、代码、代码集或指令集。存储器1005可包括存储程序区和存储数据区,其中,存储程序区可存储用于实现操作系统的指令、用于至少一个功能的指令(比如触控功能、声音播放功能、图像播放功能等)、用于实现上述各个方法实施例的指令等;存储数据区可存储上面各个方法实施例中涉及到的数据等。存储器1005可选的还可以是至少一个位于远离前述处理器1001的存储装置。如图4所示,作为一种计算机存储介质的存储器1005中可以包括操作系统、网络通信模块、用户接口模块以及基于模型蒸馏的图像分类应用程序。
在图4所示的终端1000中,用户接口1003主要用于为用户提供输入的接口,获取用户输入的数据;而处理器1001可以用于调用存储器1005中存储的基于模型蒸馏的图像分类应用程序,并具体执行以下操作:
获取待分类目标图像;
将待分类目标图像输入预先训练的学生模型中;其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数;
输出目标图像对应的多个类别概率值;
基于多个类别概率值判定待分类目标图像的目标类别。
在一个实施例中,处理器1001在执行基于多个类别概率值判定待分类目标图像的目标类别时,具体执行以下操作:
选择多个类别概率值中的最大概率值;
识别选择的最大概率值对应的目标类别;
将目标类别确定为待分类目标图像的所属类别。
在一个实施例中,处理器1001在执行生成预先训练的教师模型时,具体执行以下操作:
采集多种类型的图像集生成模型训练样本;
创建教师模型;其中,所述教师模型的模型特征通道数为Ct
将模型训练样本输入教师模型中进行训练后,生成训练后的教师模型;
将训练后的教师模型确定为预先训练的教师模型。
在一个实施例中,处理器1001在执行生成预先训练的学生模型时,具体执行以下操作:
创建学生模型,其中,所述学生模型的模型特征通道数为Cs
基于预先训练的教师模型与主成分分析算法构建降维矩阵
Figure BDA0003055889280000121
其中,所述降维矩阵将特征从Ct维降至Cs维度;
从模型训练样本中获取第n图像;
将第n图像输入预先训练的教师模型中进行特征提取,生成大小为H×W×Ct的第一网络特征,H和W分别为特征的高度和宽度,Ct为第一网络特征的通道数量;
根据降维矩阵将第一网络特征中每个空间位置的特征由Ct维降至Cs维,得到大小为(H×W)×Cs的目标特征Fref
将第n图像输入预先训练的学生模型中进行特征提取,生成第二网络特征
Figure BDA0003055889280000122
根据目标特征Fref和第二网络特征
Figure BDA0003055889280000131
构造学生模型的目标损失函数;
将目标损失函数关联至学生模型上,生成关联函数后的学生模型;
将第n图像输入关联函数后的学生模型中进行训练。
在一个实施例中,处理器1001在执行基于预先训练的教师模型与主成分分析算法构建降维矩阵时,具体执行以下操作:
从模型训练样本中K类图像集中随机选择N张样本,生成N×K张样本;
将N×K张样本输入预先训练的教师模型中进行特征提取,生成多个图像特征;
将多个图像特征中每个图像特征进行重排后,生成N×K个大小为(H×W)×Ct的矩阵;
将N×K个大小为(H×W)×nt的矩阵进行纵向拼接,生成拼接矩阵
Figure BDA0003055889280000132
根据拼接矩阵
Figure BDA0003055889280000133
与主成分分析算法构建降维矩阵
Figure BDA0003055889280000134
其中,所述降维矩阵将特征从Ct维降至Cs维度。
在本申请实施例中,首先获取待分类目标图像并输入预先训练的学生模型中,其中,预先训练的学生模型基于模型蒸馏法训练生成,模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,学生模型的模型特征通道数小于教师模型的模型特征通道数,然后输出目标图像对应的多个类别概率值,最后基于多个类别概率值判定待分类目标图像的目标类别。因此,本申请实施例通过采用主成分分析算法对教师模型输出的特征进行特征降维,并约束学生模型的输出特征与降维后的特征一致,使得学生模型也能学到和教师模型相似区分度的特征,由于学生模型参数少,从而提升了硬件平台的运行速度,进一步提高了图像分类效率。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,基于模型蒸馏的图像分类的程序可存储于计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,的存储介质可为磁碟、光盘、只读存储记忆体或随机存储记忆体等。
以上所揭露的仅为本申请较佳实施例而已,当然不能以此来限定本申请之权利范围,因此依本申请权利要求所作的等同变化,仍属本申请所涵盖的范围。

Claims (9)

1.一种基于模型蒸馏的图像分类方法,其特征在于,所述方法包括:
获取待分类目标图像;
将所述待分类目标图像输入预先训练的学生模型中;其中,所述预先训练的学生模型基于模型蒸馏法训练生成,所述模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,所述学生模型的模型特征通道数小于所述教师模型的模型特征通道数;
输出所述目标图像对应的多个类别概率值;
基于所述多个类别概率值判定所述待分类目标图像的目标类别;其中,
按照下述步骤生成预先训练的学生模型,包括:
创建学生模型,其中,所述学生模型的模型特征通道数为Cs
基于所述预先训练的教师模型与所述主成分分析算法构建降维矩阵
Figure FDA0003603364030000011
其中,所述降维矩阵将特征从Ct维降至Cs维度;其中,所述Ct是教师模型的特征通道数;
从所述模型训练样本中获取第n图像;
将所述第n图像输入所述预先训练的教师模型中进行特征提取,生成大小为H×W×Ct的第一网络特征,H和W分别为特征的高度和宽度,Ct为所述第一网络特征的通道数量;
根据所述降维矩阵将所述第一网络特征中每个空间位置的特征由Ct维降至Cs维,得到大小为(H×W)×Cs的目标特征Fref
将所述第n图像输入所述学生模型中进行特征提取,生成第二网络特征
Figure FDA0003603364030000012
根据所述目标特征Fref和所述第二网络特征
Figure FDA0003603364030000013
构造所述学生模型的目标损失函数;
将所述目标损失函数关联至所述学生模型上,生成关联函数后的学生模型;
将所述第n图像输入所述关联函数后的学生模型中进行训练。
2.根据权利要求1所述的方法,其特征在于,所述基于所述多个类别概率值判定所述待分类目标图像的目标类别,包括:
选择所述多个类别概率值中的最大概率值;
识别所述选择的最大概率值对应的目标类别;
将所述目标类别确定为所述待分类目标图像的所属类别。
3.根据权利要求1所述的方法,其特征在于,按照下述步骤生成预先训练的教师模型,包括:
采集多种类型的图像集生成模型训练样本;
创建教师模型;其中,所述教师模型的模型特征通道数为Ct
将所述模型训练样本输入所述教师模型中进行训练后,生成训练后的教师模型;
将所述训练后的教师模型确定为预先训练的教师模型。
4.根据权利要求1所述的方法,其特征在于,所述方法还包括:
当所述模型的迭代训练次数小于预设值时,继续执行从所述模型训练样本中获取第n+1图像的步骤,并当所述n+1大于所述模型训练样本时,对所述模型训练样本中图像的顺序进行随机排列,并重置n=1。
5.根据权利要求1所述的方法,其特征在于,所述基于所述预先训练的教师模型与所述主成分分析算法构建降维矩阵,包括:
从所述模型训练样本中K类图像集中随机选择N张样本,生成N×K张样本;
将所述N×K张样本输入所述预先训练的教师模型中进行特征提取,生成多个图像特征;
将所述多个图像特征中每个图像特征进行重排后,生成N×K个大小为(H×W)×Ct的特征矩阵;
将所述N×K个大小为(H×W)×Ct的矩阵进行纵向拼接,生成拼接矩阵
Figure FDA0003603364030000021
根据所述拼接矩阵
Figure FDA0003603364030000022
与所述主成分分析算法构建降维矩阵
Figure FDA0003603364030000023
其中,所述降维矩阵将特征从Ct维降至Cs维度。
6.根据权利要求1所述的方法,其特征在于,所述损失函数为:
Figure FDA0003603364030000031
其中,λ为特征约束的权重,LCE为交叉熵损失函数,则LCE的计算公式为:LCE=∑i-pi×log qi,其中,qi为所述学生模型中网络输出的第i类的概率值,pi为图片标签向量的第i维数值。
7.一种基于模型蒸馏的图像分类装置,其特征在于,所述装置包括:
图像获取模块,用于获取待分类目标图像;
图像输入模块,用于将所述待分类目标图像输入预先训练的学生模型中;其中,所述预先训练的学生模型基于模型蒸馏法训练生成,所述模型蒸馏法采用主成分分析算法对预先训练的教师模型输出的特征进行特征降维,所述学生模型的模型特征通道数小于所述教师模型的模型特征通道数;
概率值输出模块,用于输出所述目标图像对应的多个类别概率值;
类别判定模块,用于基于所述多个类别概率值判定所述待分类目标图像的目标类别;其中,
按照下述步骤生成预先训练的学生模型,包括:
创建学生模型,其中,所述学生模型的模型特征通道数为Cs
基于所述预先训练的教师模型与所述主成分分析算法构建降维矩阵
Figure FDA0003603364030000032
其中,所述降维矩阵将特征从Ct维降至Cs维度;其中,所述Ct是教师模型的特征通道数;
从所述模型训练样本中获取第n图像;
将所述第n图像输入所述预先训练的教师模型中进行特征提取,生成大小为H×W×Ct的第一网络特征,H和W分别为特征的高度和宽度,Ct为所述第一网络特征的通道数量;
根据所述降维矩阵将所述第一网络特征中每个空间位置的特征由Ct维降至Cs维,得到大小为(H×W)×Cs的目标特征Fref
将所述第n图像输入所述学生模型中进行特征提取,生成第二网络特征
Figure FDA0003603364030000033
根据所述目标特征Fref和所述第二网络特征
Figure FDA0003603364030000034
构造所述学生模型的目标损失函数;
将所述目标损失函数关联至所述学生模型上,生成关联函数后的学生模型;
将所述第n图像输入所述关联函数后的学生模型中进行训练。
8.一种计算机存储介质,其特征在于,所述计算机存储介质存储有多条指令,所述指令适于由处理器加载并执行如权利要求1-6任意一项的方法步骤。
9.一种终端,其特征在于,包括:处理器和存储器;其中,所述存储器存储有计算机程序,所述计算机程序适于由所述处理器加载并执行如权利要求1-6任意一项的方法步骤。
CN202110499675.5A 2021-05-08 2021-05-08 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端 Active CN113408571B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110499675.5A CN113408571B (zh) 2021-05-08 2021-05-08 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110499675.5A CN113408571B (zh) 2021-05-08 2021-05-08 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端

Publications (2)

Publication Number Publication Date
CN113408571A CN113408571A (zh) 2021-09-17
CN113408571B true CN113408571B (zh) 2022-07-19

Family

ID=77678305

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110499675.5A Active CN113408571B (zh) 2021-05-08 2021-05-08 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端

Country Status (1)

Country Link
CN (1) CN113408571B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115147418B (zh) * 2022-09-05 2022-12-27 东声(苏州)智能科技有限公司 缺陷检测模型的压缩训练方法和装置

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111027635A (zh) * 2019-12-12 2020-04-17 深圳前海微众银行股份有限公司 图像处理模型的构建方法、装置、终端及可读存储介质
WO2020177651A1 (zh) * 2019-03-01 2020-09-10 华为技术有限公司 图像分割方法和图像处理装置

Family Cites Families (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP6584250B2 (ja) * 2015-09-10 2019-10-02 株式会社Screenホールディングス 画像分類方法、分類器の構成方法および画像分類装置
CN107247989B (zh) * 2017-06-15 2020-11-24 北京图森智途科技有限公司 一种实时的计算机视觉处理方法及装置
WO2019013711A1 (en) * 2017-07-12 2019-01-17 Mastercard Asia/Pacific Pte. Ltd. MOBILE DEVICE PLATFORM FOR AUTOMATED VISUAL RECOGNITION OF RETAIL PRODUCTS
CN109871909B (zh) * 2019-04-16 2021-10-01 京东方科技集团股份有限公司 图像识别方法及装置
CN111832787B (zh) * 2019-04-23 2022-12-09 北京新唐思创教育科技有限公司 教师风格预测模型的训练方法及计算机存储介质
CN110659665B (zh) * 2019-08-02 2023-09-29 深圳力维智联技术有限公司 一种异维特征的模型构建方法及图像识别方法、装置
CN111950638B (zh) * 2020-08-14 2024-02-06 厦门美图之家科技有限公司 基于模型蒸馏的图像分类方法、装置和电子设备
CN112215334A (zh) * 2020-09-24 2021-01-12 北京航空航天大学 一种面向事件相机的神经网络模型压缩方法
CN112184508B (zh) * 2020-10-13 2021-04-27 上海依图网络科技有限公司 一种用于图像处理的学生模型的训练方法及装置
CN112116030B (zh) * 2020-10-13 2022-08-30 浙江大学 一种基于向量标准化和知识蒸馏的图像分类方法

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2020177651A1 (zh) * 2019-03-01 2020-09-10 华为技术有限公司 图像分割方法和图像处理装置
CN111027635A (zh) * 2019-12-12 2020-04-17 深圳前海微众银行股份有限公司 图像处理模型的构建方法、装置、终端及可读存储介质

Also Published As

Publication number Publication date
CN113408571A (zh) 2021-09-17

Similar Documents

Publication Publication Date Title
CN112434721B (zh) 一种基于小样本学习的图像分类方法、系统、存储介质及终端
CN111104962B (zh) 图像的语义分割方法、装置、电子设备及可读存储介质
WO2022083536A1 (zh) 一种神经网络构建方法以及装置
CN109471945B (zh) 基于深度学习的医疗文本分类方法、装置及存储介质
CN113408570A (zh) 一种基于模型蒸馏的图像类别识别方法、装置、存储介质及终端
CN108197652B (zh) 用于生成信息的方法和装置
CN111126258A (zh) 图像识别方法及相关装置
CN111401516A (zh) 一种神经网络通道参数的搜索方法及相关设备
CN112784778B (zh) 生成模型并识别年龄和性别的方法、装置、设备和介质
CN107909147A (zh) 一种数据处理方法及装置
CN110738102A (zh) 一种人脸识别方法及系统
CN111488985A (zh) 深度神经网络模型压缩训练方法、装置、设备、介质
CN113705775A (zh) 一种神经网络的剪枝方法、装置、设备及存储介质
CN111738403A (zh) 一种神经网络的优化方法及相关设备
CN112418320A (zh) 一种企业关联关系识别方法、装置及存储介质
CN115713715A (zh) 一种基于深度学习的人体行为识别方法及识别系统
CN113408571B (zh) 一种基于模型蒸馏的图像分类方法、装置、存储介质及终端
CN109978058B (zh) 确定图像分类的方法、装置、终端及存储介质
CN110069997B (zh) 场景分类方法、装置及电子设备
CN110717407A (zh) 基于唇语密码的人脸识别方法、装置及存储介质
CN111967478A (zh) 一种基于权重翻转的特征图重构方法、系统、存储介质及终端
CN112257840A (zh) 一种神经网络处理方法以及相关设备
CN115982965A (zh) 去噪扩散样本增量学习的碳纤维材料损伤检测方法及装置
CN112862073B (zh) 一种压缩数据分析方法、装置、存储介质及终端
CN112307243A (zh) 用于检索图像的方法和装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant