CN114266897A - 痘痘类别的预测方法、装置、电子设备及存储介质 - Google Patents
痘痘类别的预测方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN114266897A CN114266897A CN202111609463.4A CN202111609463A CN114266897A CN 114266897 A CN114266897 A CN 114266897A CN 202111609463 A CN202111609463 A CN 202111609463A CN 114266897 A CN114266897 A CN 114266897A
- Authority
- CN
- China
- Prior art keywords
- model
- teacher
- student
- models
- category
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Image Analysis (AREA)
Abstract
本申请实施例涉及图像处理技术领域,公开了一种痘痘类别的预测方法、装置、电子设备及存储介质,该痘痘类别的预测方法,一方面,通过包括多种类别的痘痘的图像的数据集对多个教师模型进行训练,使得多个教师模型能够学习到多种类别的痘痘的特征;另一方面,通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型,使得学生模型能够更好地提炼出教师模型中学习到的知识,从而提升学生模型的鲁棒性和准确度,进而使本申请能够提高预测痘痘类别的准确率。
Description
技术领域
本申请实施例涉及图像处理技术领域,尤其涉及一种痘痘类别的预测方法、装置、电子设备及存储介质。
背景技术
随着移动通信技术的快速发展以及人民生活水平的提升,各种智能终端已广泛应用于人民的日常工作和生活,使得人们越来越习惯于使用APP等软件,使得美颜自拍、拍照测肤此类功能的APP需求也变得越来越多,因此不少的用户希望此类APP能够自动分析出脸部的痘痘情况,根据痘痘类别情况,有针对性提出皮肤改善方案。
目前,分类算法常常采用集成分类算法,其是神经网络的集合,它的输出是通过加权平均或投票组合而成,但是集成分类算法识别的准确度偏低。
发明内容
本申请实施例主要解决的技术问题是提供一种痘痘类别的预测方法、装置、电子设备及存储介质,以提高预测痘痘类别的准确率。
第一方面,本申请实施例中提供一种痘痘类别的预测方法,包括:
获取图像数据集,其中,图像数据集包括多种类别的痘痘的图像;
基于图像数据集,对预设的多个教师模型进行训练,其中,教师模型包括多种不同的网络结构;
通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型;
根据训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
在一些实施例中,训练学生模型,包括:
构建多层损失函数,并基于多层损失函数对学生模型进行训练。
在一些实施例中,多层损失函数包括:相似度损失函数、类别损失函数以及交叉熵损失函数中的至少一个。
在一些实施例中,多层损失函数为:
其中,Loss为多层损失函数,Ll1-sim为相似度损失函数,LKD为类别损失函数,Ls为交叉熵损失函数,i为痘痘的类别,c为特征图的大小,为教师模型的特征图,为学生模型的特征图,n为痘痘类别的数量,为教师模型预测的第i类别痘痘的概率值,为学生模型预测的第i类别痘痘的概率值,yi为真实痘痘类别。
在一些实施例中,通过多个教师模型对预设的学生模型进行知识蒸馏,包括:
根据训练后的多个教师模型,对图像数据集中的图像进行特征提取,以确定多个第一特征图,其中,每一教师模型对应一个第一特征图;
在每次迭代中,确定第二特征图,并随机选择一个教师模型对学生模型进行知识蒸馏,其中,第二特征图与第一特征图的大小相同。
在一些实施例中,基于多层损失函数对学生模型进行训练,包括:
基于多层损失函数对学生模型进行迭代训练;
若迭代次数大于第一次数阈值,或者,学生模型的损失小于第一损失阈值,则停止迭代训练。
在一些实施例中,痘痘类别包括粉刺、逗后红斑、炎症性丘疹、脓包、结节和囊肿中的至少一种。
第二方面,本申请实施例提供一种痘痘类别的预测装置,包括:
数据集获取单元,用于获取图像数据集,其中,图像数据集包括多种类别的痘痘的图像;
教师模型训练单元,用于基于图像数据集,对预设的多个教师模型进行训练,其中,教师模型包括多种不同的网络结构;
学生模型训练单元,用于通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型;
痘痘类别预测单元,用于根据训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
第三方面,本申请实施例提供一种电子设备,包括:
存储器以及一个或多个处理器,一个或多个处理器用于执行存储在存储器中的一个或多个计算机程序,一个或多个处理器在执行一个或多个计算机程序时,使得电子设备实现如第一方面的方法。
第四方面,本申请实施例提供一种计算机可读存储介质,计算机可读存储介质存储有计算机程序,计算机程序包括程序指令,程序指令当被处理器执行时使处理器执行如第一方面的方法。
本申请实施例的有益效果:区别于现有技术的情况,本申请实施例提供的一种痘痘类别的预测方法、装置、电子设备及存储介质,该方法包括:获取图像数据集,其中,图像数据集包括多种类别的痘痘的图像;基于图像数据集,对预设的多个教师模型进行训练,其中,教师模型包括多种不同的网络结构;通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型;根据训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
一方面,通过包括多种类别的痘痘的图像的数据集对多个教师模型进行训练,使得多个教师模型能够学习到多种类别的痘痘的特征;另一方面,通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型,使得学生模型能够更好地提炼出教师模型中学习到的知识,从而提升学生模型的鲁棒性和准确度,进而使本申请能够提高预测痘痘类别的准确率。
附图说明
一个或多个实施例通过与之对应的附图中的图片进行示例性说明,这些示例性说明并不构成对实施例的限定,附图中具有相同参考数字标号的元件表示为类似的元件,除非有特别申明,附图中的图不构成比例限制。
图1是本申请实施例提供的一种痘痘类别的预测方法的应用环境示意图;
图2是本申请实施例提供的一种痘痘类别的预测方法的流程示意图;
图3是本申请实施例提供的一种教师模型训练学生模型的示意图;
图4是本申请实施例提供的一种学生模型的迭代训练的流程示意图;
图5是本申请实施例提供的一种痘痘类别的预测装置的结构示意图;
图6是本申请实施例提供的一种电子设备的硬件结构示意图。
具体实施方式
下面结合具体实施例对本申请进行详细说明。以下实施例将有助于本领域的技术人员进一步理解本申请,但不以任何形式限制本申请。应当指出的是,对本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进。这些都属于本申请的保护范围。
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本申请,并不用于限定本申请。
需要说明的是,如果不冲突,本申请实施例中的各个特征可以相互结合,均在本申请的保护范围之内。另外,虽然在装置示意图中进行了功能模块划分,在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于装置中的模块划分,或流程图中的顺序执行所示出或描述的步骤。此外,本文所采用的“第一”、“第二”、“第三”等字样并不对数据和执行次序进行限定,仅是对功能和作用基本相同的相同项或相似项进行区分。
除非另有定义,本说明书所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本说明书中在本申请的说明书中所使用的术语只是为了描述具体的实施方式的目的,不是用于限制本申请。本说明书所使用的术语“和/或”包括一个或多个相关的所列项目的任意的和所有的组合。
此外,下面所描述的本申请各个实施方式中所涉及到的技术特征只要彼此之间未构成冲突就可以相互组合。
在对本申请进行详细说明之前,对本申请实施例中涉及的名词和术语进行说明,本申请实施例中涉及的名词和术语适用于如下的解释:
(1)神经网络,也简称为神经网络(NNs)或称作连接模型(Connection Model),它是一种模仿动物神经网络行为特征,进行分布式并行信息处理的算法数学模型。神经网络依靠系统的复杂程度,通过调整内部大量节点之间相互连接的关系,从而达到处理信息的目的。具体的,神经网络可以是由神经单元组成的,具体可以理解为具有输入层、隐含层、输出层的神经网络,一般来说第一层是输入层,最后一层是输出层,中间的层数都是隐含层。其中,具有很多层隐含层的神经网络则称为深度神经网络(deep neural network,DNN)。神经网络中的每一层的工作可以用数学表达式y=a(W·x+b)来描述,从物理层面,神经网络中的每一层的工作可以理解为通过五种对输入空间(输入向量的集合)的操作,完成输入空间到输出空间的变换(即矩阵的行空间到列空间),这五种操作包括:1、升维/降维;2、放大/缩小;3、旋转;4、平移;5、“弯曲”。其中1、2、3的操作由“W·x”完成,4的操作由“+b”完成,5的操作则由“a()”来实现这里之所以用“空间”二字来表述是因为被分类的对象并不是单个事物,而是一类事物,空间是指这类事物所有个体的集合,其中,W是神经网络各层的权重矩阵,该矩阵中的每一个值表示该层的一个神经元的权重值。该矩阵W决定着上文所述的输入空间到输出空间的空间变换,即神经网络每一层的W控制着如何变换空间。训练神经网络的目的,也就是最终得到训练好的神经网络的所有层的权重矩阵。因此,神经网络的训练过程本质上就是学习控制空间变换的方式,更具体的就是学习权重矩阵。
需要注意的是,在本申请实施例中,基于机器学习任务所采用的模型,本质都是神经网络。神经网络中的常用组件有卷积层、池化层、归一化层和反向卷积层等,通过组装神经网络中的这些常用组件,设计得到模型,当确定模型参数(各层的权重矩阵)使得模型误差满足预设条件或调整模型参数的数量达到预设阈值时,模型收敛。
其中,卷积层配置有多个卷积核、每个卷积核设置有对应的步长,以对图像进行卷积运算。卷积运算的目的是提取输入图像的不同特征,第一层卷积层可能只能提取一些低级的特征如边缘、线条和角等层级,更深的卷积层能从低级特征中迭代提取更复杂的特征。
反向卷积层用于将一个低维度的空间映射到高维度,同时保持他们之间的连接关系/模式(这里的连接关系即是指卷积时候的连接关系)。反向卷积层配置有多个卷积核、每个卷积核设置有对应的步长,以对图像进行反卷积运算。一般,用于设计神经网络的框架库(例如PyTorch库)中内置有upsumple()函数,通过调用该upsumple()函数可以实现低维度到高维度的空间映射。
池化层(pooling)是模仿人的视觉系统可以对数据进行降维或用更高层次的特征表示图像。池化层的常见操作包括最大值池化、均值池化、随机池化、中值池化和组合池化等。通常来说,神经网络的卷积层之间都会周期性插入池化层以实现降维。
归一化层用于对中间层的所有神经元进行归一化运算,以防止梯度爆炸和梯度消失。
(2)损失函数,指的是是将随机事件或其有关随机变量的取值映射为非负实数以表示该随机事件的“风险”或“损失”的函数。损失函数是一个非负实数函数,用来量化模型预测的预测标签和真实标签之间的差异。在应用中,损失函数通常作为学习准则与优化问题相联系,即通过最小化损失函数求解和评估模型。例如在统计学和机器学习中被用于模型的参数估计(parametric estimation)。在训练神经网络的过程中,因为希望神经网络的输出尽可能的接近真正想要预测的值,可以通过比较当前网络的预测值和真正想要的目标值,再根据两者之间的差异情况来更新每一层神经网络的权重矩阵(然,在第一次更新之前通常会有初始化的过程,即为神经网络中的各层预先配置参数),比如,如果网络的预测值高了,就调整权重矩阵让它预测低一些,不断的调整,直到神经网络能够预测出真正想要的目标值。因此,就需要预先定义“如何比较预测值和目标值之间的差异”,这便是损失函数(loss function)或目标函数(objective function),它们是用于衡量预测值和目标值的差异的重要方程。其中,以损失函数举例,损失函数的输出值(loss)越高表示差异越大,那么神经网络的训练就变成了尽可能缩小这个loss的过程。
(3)知识蒸馏,即(Knowledge Distillation,KD),指的是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法,其通过将已经训练好的模型包含的知识(Knowledge),蒸馏(Distill)提取到另一个模型里面去。通过引入与教师网络(teachernetwork:复杂、但推理性能优越)相关的软目标(soft-target)作为total loss的一部分,以诱导学生网络(student network:精简、低复杂度)的训练,实现知识迁移(knowledgetransfer),即先训练另外一个更复杂(一般为多个网络的集成)的教师网络(TeacherNetwork),并使用大网络的输出作为软目标来训练学生网络(Student Network)。
下面结合说明书附图具体阐述本申请的技术方案。
请参阅图1,图1是本申请实施例提供的一种痘痘类别的预测方法的应用环境示意图;
如图1所示,该应用环境100包括:电子设备101和服务器102,该电子设备101和服务器102通过有线或无线通信方式进行通信。
其中,电子设备101可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表等,但并不局限于此。电子设备101中可以设有客户端,该客户端可以是视频客户端、浏览器客户端、线上购物客户端、即时通信客户端等,本申请对客户端的类型不加以限定。
电子设备101以及服务器102可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。电子设备101可以获取目标图像,并预测该目标图像的痘痘类别,在可视化界面上展示该目标图像及其痘痘类别,其中,该目标图像可以是电子设备101的存储器中存储的图像或者接受其他设备发送的图像。
或者,电子设备101可以接收服务器102发送的目标图像的痘痘类别,并在可视化界面上展示该目标图像及其痘痘类别。用户可以电子设备中存储的人脸图像进行浏览,通过触发任一个人脸图像对应的痘痘类别预测按钮,来触发对该人脸图像的痘痘类别预测指令,电子设备可以响应于该痘痘类别预测指令,通过图像采集器件来获取人脸图像,将该人脸图像作为目标图像,其中,该图像采集器件可以内置于电子设备101中,还可以外接于电子设备101,本申请对此不加以限定。
电子设备101可以将该目标图像均发送给服务器102,并接收服务器102返回的目标图像的痘痘类别预测值,进而将该目标图像及其痘痘类别预测值展示在可视化界面上,以便用户了解目标图像的预测结果。
可以理解的是,电子设备101可以泛指多个电子设备中的一个,本申请实施例仅以电子设备101来举例说明。本领域技术人员可以知晓,上述电子设备的数量可以更多或更少。比如上述电子设备可以仅为一个,或者上述电子设备为几十个或几百个,或者更多数量,本申请实施例对电子设备的数量和设备类型不加以限定。
其中,服务器102可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。
服务器102以及电子设备101可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。服务器102可以维护有一个人脸图像数据库,用于存储多个人脸图像。服务器102可以接收电子设备101发送的痘痘类别预测指令和目标图像,并根据该痘痘类别预测指令,基于目标图像,对目标图像进行痘痘类别预测,得到目标图像的痘痘类别预测值,进而将该目标图像的痘痘类别预测值发送给电子设备101。
可以理解的是,上述服务器102的数量可以更多或更少,本申请实施例对此不加以限定。当然,服务器102还可以包括其他功能服务器,以便提供更全面且多样化的服务。
请参阅图2,图2是本申请实施例提供的一种痘痘类别的预测方法的流程示意图;
其中,该痘痘类别的预测方法,应用于电子设备,具体的,该痘痘类别的预测方法的执行主体为电子设备的一个或多个处理器。
如图2所示,该痘痘类别的预测方法,包括:
步骤S201:获取图像数据集,其中,图像数据集包括多种类别的痘痘的图像;
具体的,该图像数据集由多种类别的痘痘的人脸图像组成,即该图像数据集包括痘痘类别数据集,例如:该图像数据集中的各图像包括人脸,并且各图像为三通道彩色图像,通过收集多个痘痘图像,其中,每一痘痘图像标注有类别标签,该类别标签用于表征痘痘图像的痘痘类别。在一些实施方式中,痘痘类别包括粉刺、逗后红斑、炎症性丘疹、脓包、结节和囊肿,一共有六种类别,通过one-hot类别标签算法对痘痘图像进行类别标签,比如:[0,1,0,0,0,0]表示痘痘类别为逗后红斑。
进一步地,为了提供额外的类内和类间关系,本申请实施例还对痘痘类别的类别标签进行处理,得到软标签。例如:[0,1,0,0,0,0]表示痘痘类别为逗后红斑,经过处理之后,得到的软标签为[0,1,0.8,0.02,0.04,0.03,0]。可以理解的是,软标签中的每一痘痘类别对应的位置表示该类别的概率值。
在本申请实施例中,由于痘痘图像较小,因此,需要对痘痘图像进行归一化操作,具体的,将痘痘图像的大小调整为预设分辨率,例如:调整痘痘图像的分辨率为40*40。本申请实施例中的痘痘图像的大小还可以是其他分辨率,在此不进行限定。
可以理解的是,该图像数据集可以为由图像获取装置采集到的彩色证件照或彩色自拍照等。可以理解的是,图像数据集也可以是现有的开源人脸库中的数据,其中,开源人脸库可以为FERET人脸数据库、CMU Multi-PIE人脸数据库或YALE人脸数据库等。在此,对图像样本的来源不做限制,只要图像为包括人脸和痘痘的彩色图像即可,例如RGB格式的人脸图像。
步骤S202:基于图像数据集,对预设的多个教师模型进行训练,其中,教师模型包括多种不同的网络结构;
具体的,预设多个教师模型,其中,教师模型包括多种不同的网络结构,例如:Densenet网络结构、Googlenet网络结构、Resnet网络结构以及VGG网络结构中的至少一种,可以理解的是,每一种网络结构对应一种网络模型,即,Densenet网络结构对应Densenet模型,Googlenet网络结构对应Googlenet模型,Resnet网络结构对应Resnet模型,VGG网络结构对应VGG模型。
在本申请实施例中,为了实现学生模型能够学习到多种不同的教师模型学习到的多种类别的痘痘的特征,进一步地,设置多个教师模型中的每一教师模型包括Densenet模型,Googlenet模型,Resnet模型以及VGG模型中的至少一种,并且,每一个教师模型的网络结构不同,以避免重复的教师模型,从而加快学生模型的训练,以提高训练效率。
在本申请实施例中,训练多个教师模型,包括:
为每一教师模型构建多类别交叉损失函数,对每一教师模型进行模型训练,以使每一教师模型预测各个痘痘类别的概率值。其中,多类别交叉损失函数用于将真实的标签类别与预测的标签类别进行概率相乘,以得到相应的损失,比如:真实标签y=[0,1,0,0,0,0],预测的类别p=[0.1,0.8,0.1.0.0.0],最后计算得到的损失loss=-y*logp。
在本申请实施例中,教师模型的训练通过Adam算法进行参数优化。
步骤S203:通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型;
具体的,通过多个教师模型对预设的学生模型进行知识蒸馏,包括:
根据训练后的多个教师模型,对图像数据集中的图像进行特征提取,以确定多个第一特征图,其中,每一教师模型对应一个第一特征图;
在每次迭代中,确定第二特征图,并随机选择一个教师模型对学生模型进行知识蒸馏,其中,第二特征图与第一特征图的大小相同。
例如:假设输入的是40*40大小的痘痘图像,在每次迭代中,随机选择一个教师模型,该教师模型为Densenet模型,Googlenet模型,Resnet模型以及VGG模型中的至少一种,该教师模型进行特征提取,从而获取第一特征图,例如:第一特征图的大小为20*20,10*10或5*5,而学生模型同样需要获取与第一特征图的大小相同的第二特征图,此时,无论教师模型还是学生模型,都是在对应大小的特征图上采用相同的数量的卷积核,则教师模型和学生模型均输出相同大小的对应尺度的特征图,例如:教师模型和学生模型输出的都是20*20*32,10*10*64,5*5*128大小的特征图。
具体的,训练学生模型,包括:
构建多层损失函数,并基于多层损失函数对学生模型进行训练。
其中,该多层损失函数包括相似度损失函数、类别损失函数以及交叉熵损失函数中的至少一个。
具体的,该多层损失函数为:
其中,Loss为多层损失函数,Ll1-sim为相似度损失函数,LKD为类别损失函数,Ls为交叉熵损失函数,i为痘痘的类别,c为特征图的大小,为教师模型的特征图,为学生模型的特征图,n为痘痘类别的数量,为教师模型预测的第i类别痘痘的概率值,为学生模型预测的第i类别痘痘的概率值,yi为真实痘痘类别。
例如:特征图的大小c={20,10,5},n为痘痘类别的数量,即类别数,和分别表示教师模型和学生模型中各个大小相同的特征图,和分别表示教师模型和学生模型预测的第i类别痘痘的概率值,yi为真实痘痘类别,其采用one-hot类别标签算法进行表示。
具体的,请再参阅图3,图3是本申请实施例提供的一种教师模型训练学生模型的示意图;
如图3所示,通过向多个教师模型和学生模型输入痘痘图像,例如:向教师模型A(Teacher_A)-教师模型N(Teacher_N)分别输入同一副痘痘图像,向学生模型(Student)输入另一幅痘痘图像,其中,两幅痘痘图像的大小相同。
通过每一教师模型与学生模型输出的特征图进行计算,得到相似度损失Ll1-sim,其中,相似度损失该相似度损失用于将教师模型与学生模型对应的特征图的位置大小进行相减,以使学生模型的对应大小的特征图学习到的特征与教师模型的对应大小的特征图学习到的特征相近。
同时,为了更好地利用教师模型输出的软标签,对于学生模型训练的类别损失函数采用KD损失函数,以计算得到KD损失LKD,使得模型训练时往KL损失值最小的方向优化,以更好地学习到痘痘类别之间的相似特征,其中,KD损失
进一步地,由于教师模型输出的软标签,存在误分类的情况,因此,对多层损失函数加入交叉熵损失函数,以计算交叉熵损失Ls,其中,该交叉熵损失为学生模型预测的结果与真实标签之间的损失。
可以理解的是,本申请实施例中的学生模型的模型结构可以为任意一种网络模型,例如:Densenet模型,Googlenet模型,Resnet模型、VGG模型以及Mobilenet模型中的至少一种,而不同的网络模型最终的目的都是为了预测出痘痘类别,但是由于学习到的特征有所不同,因此,为了满足模型大小、检测速度以及在公共图像分类数据ImageNet所表现的分类准确度,本申请实施例中优选该学生模型的模型结构为Mobilenet模型,使得学生模型的模型大小合适,并且检测速度快,分类准确度高,有利于更好地预测痘痘类别。
步骤S204:根据训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
具体的,在学生模型被训练完成之后,调用训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
在本申请实施例中,通过多个教师模型对学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型,使得学生模型能够更好地提炼出教师模型中学习到的知识,从而提升学生模型的鲁棒性和准确度,进而使本申请能够提高预测痘痘类别的准确率。
请再参阅图4,图4是本申请实施例提供的一种学生模型的迭代训练的流程示意图;
如图4所示,该学生模型的迭代训练的流程,包括:
步骤S401:构建多个教师模型和学生模型;
具体的,多个教师模型中的每一教师模型的网络结构不同,例如:每一教师模型为Densenet模型,Googlenet模型,Resnet模型以及VGG模型中的一种。该学生模型的网络结构为Mobilenet模型。
步骤S402:构建多层损失函数;
具体的,多层损失函数包括相似度损失函数、类别损失函数以及交叉熵损失函数,例如:该多层损失函数为:
其中,Loss为多层损失函数,Ll1-sim为相似度损失函数,LKD为类别损失函数,Ls为交叉熵损失函数,i为痘痘的类别,c为特征图的大小,为教师模型的特征图,为学生模型的特征图,n为痘痘类别的数量,为教师模型预测的第i类别痘痘的概率值,为学生模型预测的第i类别痘痘的概率值,yi为真实痘痘类别。
步骤S403:基于多层损失函数对学生模型进行迭代训练;
具体的,在每一次训练时,随机选择一个教师模型对学生模型进行指导训练,以使学生模型学习到该教师模型的特征提取方式和特征融合方式。例如:该教师模型为教师网络,该学生模型为学生网络,通过教师网络的损失或中间特征对学生网络的损失和中间特征进行约束,从而学习到教师网络的特征提取方式和特征融合方式。可以理解的是,教师模型和学生模型均为生成网络,例如:深度学习图像分割网络,比如:UNet网络,用于将不同输入的目标图像通过损失函数或者约束条件,生成痘痘类别的预测结果。
在本申请实施例中,通过利用不同的教师模型的神经网络结构的不同,通过不同输出的组合作为监督来指导学生模型的训练,能够保证模型大小和预测的速度的前提下,通过多个不同教师模型的指导,能够更好的提炼出教师模型中学习到的知识,从而提升学生模型的鲁棒性和准确度。
具体的,通过教师模型对学生模型进行知识蒸馏,包括:
通过教师模型中的特征层对学生模型中的特征层进行距离计算,将计算得到的距离作为损失函数,训练学生模型,使得学生模型中的特征层接近教师模型中的特征层。
通过教师模型对学生模型进行知识蒸馏,使得学生模型学习到教师模型的特征提取方式和特征融合方式,能够提高学生模型的网络表达能力,有利于提高预测痘痘类别的准确率。
步骤S404:迭代次数是否大于第一次数阈值;
具体的,本申请实施例采用Adam算法(Adaptive Moment Estimation Algorithm)来优化模型参数。例如:迭代次数设置为500次,初始化学习率设置为0.001,权重衰减设置为0.0005,每50次迭代,学习率衰减为原来的1/10。
可以理解的是,Adam算法(Adaptive Moment Estimation Algorithm),可以看作动量法和RMSprop算法的结合,不但使用动量作为参数更新方向,而且可以自适应调整学习率。
具体的,判断迭代次数是否大于第一次数阈值,该第一次数阈值被预先设置,例如:设置为500次。若迭代次数大于第一次数阈值,则进入步骤S406:训练完成;若迭代次数不大于第一次数阈值,则进入步骤S405:学生模型的损失是否小于第一损失阈值。
可以理解的是,该第一次数阈值根据具体需要进行具体设置,在此不进行限定。
步骤S405:学生模型的损失是否小于第一损失阈值;
具体的,判断学生模型的损失是否小于第一损失阈值,若是,则进入步骤S406:训练完成;若否,则返回步骤S403:基于多层损失函数对学生模型进行迭代训练;
在本申请实施例中,判断学生模型的损失是否小于第一损失阈值,即判断多层损失函数计算得出的损失是否小于第一损失阈值,从而确定是否在迭代次数小于第一次数阈值时提前结束迭代过程,即停止迭代训练,以快速得到训练后的学生模型。其中,该多层损失函数为:
在本申请实施例中,第一损失阈值可以设置为0.0005、0.001,可以理解的是,该第一损失阈值根据具体需要进行具体设置,在此不进行限定。
步骤S406:训练完成;
具体的,在训练完成之后,得到训练后的学生模型,此时可以通过调用该训练后的学生模型,预测目标图像的痘痘类别。
在本申请实施例中,通过提供一种痘痘类别的预测方法,该方法包括:获取图像数据集,其中,图像数据集包括多种类别的痘痘的图像;基于图像数据集,对预设的多个教师模型进行训练,其中,教师模型包括多种不同的网络结构;通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型;根据训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
一方面,通过包括多种类别的痘痘的图像的数据集对多个教师模型进行训练,使得多个教师模型能够学习到多种类别的痘痘的特征;另一方面,通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型,使得学生模型能够更好地提炼出教师模型中学习到的知识,从而提升学生模型的鲁棒性和准确度,进而使本申请能够提高预测痘痘类别的概率。
请参阅图5,图5是本申请实施例提供的一种痘痘类别的预测装置的结构示意图;
其中,该痘痘类别的预测装置,应用于电子设备,具体的,该痘痘类别的预测装置应用于电子设备的一个或多个处理器。
如图5所示,该痘痘类别的预测装置50,包括:
数据集获取单元501,用于获取图像数据集,其中,图像数据集包括多种类别的痘痘的图像;
教师模型训练单元502,用于基于图像数据集,对预设的多个教师模型进行训练,其中,教师模型包括多种不同的网络结构;
学生模型训练单元503,用于通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型;
痘痘类别预测单元504,用于根据训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
在本申请实施例中,痘痘类别的预测装置亦可以由硬件器件搭建成的,例如,痘痘类别的预测装置可以由一个或两个以上的芯片搭建而成,各个芯片可以互相协调工作,以完成上述各个实施例所阐述的痘痘类别的预测方法。再例如,痘痘类别的预测装置还可以由各类逻辑器件搭建而成,诸如由通用处理器、数字信号处理器(Digital SignalProcess,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field Programmable Gate Array,FPGA)、单片机、ARM处理器(AdvancedRISC Machines,ARM)或其它可编程逻辑器件、分立门或晶体管逻辑、分立的硬件组件或者这些部件的任何组合而搭建成。
本申请实施例中的痘痘类别的预测装置可以是装置,也可以是终端中的部件、集成电路、或芯片。该装置可以是移动电子设备,也可以为非移动电子设备。示例性的,移动电子设备可以为手机、平板电脑、笔记本电脑、掌上电脑、车载电子设备、可穿戴设备、超级移动个人计算机(ultra-mobile personal computer,UMPC)、上网本或者个人数字助理(personal digital assistant,PDA)等,非移动电子设备可以为服务器、网络附属存储器(Network Attached Storage,NAS)、个人计算机(personal computer,PC)、电视机(television,TV)、柜员机或者自助机等,本申请实施例不作具体限定。
本申请实施例中的痘痘类别的预测装置可以为具有操作系统的装置。该操作系统可以为安卓(Android)操作系统,可以为ios操作系统,还可以为其他可能的操作系统,本申请实施例不作具体限定。
本申请实施例提供的痘痘类别的预测装置能够实现图2实现的各个过程,为避免重复,这里不再赘述。
需要说明的是,上述痘痘类别的预测装置可执行本申请实施例所提供的痘痘类别的预测方法,具备执行方法相应的功能模块和有益效果。未在痘痘类别的预测装置实施例中详尽描述的技术细节,可参见本申请实施例所提供的痘痘类别的预测方法。
在本申请实施例中,通过提供一种痘痘类别的预测装置,包括:数据集获取单元,用于获取图像数据集,其中,图像数据集包括多种类别的痘痘的图像;教师模型训练单元,用于基于图像数据集,对预设的多个教师模型进行训练,其中,教师模型包括多种不同的网络结构;学生模型训练单元,用于通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型;痘痘类别预测单元,用于根据训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
一方面,通过包括多种类别的痘痘的图像的数据集对多个教师模型进行训练,使得多个教师模型能够学习到多种类别的痘痘的特征;另一方面,通过多个教师模型对学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型,使得学生模型能够更好地提炼出教师模型中学习到的知识,从而提升学生模型的鲁棒性和准确度,进而使本申请能够提高预测痘痘类别的概率。
本申请实施例还提供了一种电子设备,请参阅图6,图6是本申请实施例提供的一种电子设备的硬件结构示意图;
如图6所示,该电子设备60包括通信连接的至少一个处理器601和存储器602(图6中以总线连接、一个处理器为例)。
其中,处理器601用于提供计算和控制能力,以控制电子设备60执行相应任务,例如,控制电子设备60执行上述任一方法实施例中的痘痘类别的预测方法,包括:获取图像数据集,其中,图像数据集包括多种类别的痘痘的图像;基于图像数据集,对预设的多个教师模型进行训练,其中,教师模型包括多种不同的网络结构;通过多个教师模型对预设的学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型;根据训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的目标图像的痘痘类别。
一方面,通过包括多种类别的痘痘的图像的数据集对多个教师模型进行训练,使得多个教师模型能够学习到多种类别的痘痘的特征;另一方面,通过多个教师模型对学生模型进行知识蒸馏,以训练学生模型,得到训练后的学生模型,使得学生模型能够更好地提炼出教师模型中学习到的知识,从而提升学生模型的鲁棒性和准确度,进而使本申请能够提高预测痘痘类别的概率。
处理器601可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(Network Processor,NP)、硬件芯片或者其任意组合;还可以是数字信号处理器(Digital Signal Processing,DSP)、专用集成电路(Application SpecificIntegrated Circuit,ASIC)、可编程逻辑器件(programmable logic device,PLD)或其组合。上述PLD可以是复杂可编程逻辑器件(complex programmable logic device,CPLD),现场可编程逻辑门阵列(field-programmable gate array,FPGA),通用阵列逻辑(genericarray logic,GAL)或其任意组合。
存储器602作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序、非暂态性计算机可执行程序以及模块,如本申请实施例中的痘痘类别的预测方法对应的程序指令/模块。处理器601通过运行存储在存储器602中的非暂态软件程序、指令以及模块,可以实现下述任一方法实施例中的痘痘类别的预测方法。具体地,存储器602可以包括易失性存储器(volatile memory,VM),例如随机存取存储器(random access memory,RAM);存储器602也可以包括非易失性存储器(non-volatile memory,NVM),例如只读存储器(read-only memory,ROM),快闪存储器(flash memory),硬盘(hard disk drive,HDD)或固态硬盘(solid-state drive,SSD)或其他非暂态固态存储器件;存储器502还可以包括上述种类的存储器的组合。
在本申请实施例中,存储器602还可以包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
在本申请实施例中,电子设备60还可以具有有线或无线网络接口、键盘以及输入输出接口等部件,以便进行输入输出,电子设备60还可以包括其他用于实现设备功能的部件,在此不做赘述。
本申请实施例还提供了一种计算机可读存储介质,例如包括程序代码的存储器,上述程序代码可由处理器执行以完成上述实施例中的痘痘类别的预测方法。例如,该计算机可读存储介质可以是只读存储器(Read-Only Memory,ROM)、随机存取存储器(RandomAccess Memory,RAM)、只读光盘(Compact Disc Read-Only Memory,CDROM)、磁带、软盘和光数据存储设备等。
本申请实施例还提供了一种计算机程序产品,该计算机程序产品包括一条或多条程序代码,该程序代码存储在计算机可读存储介质中。电子设备的处理器从计算机可读存储介质读取该程序代码,处理器执行该程序代码,以完成上述实施例中提供的痘痘类别的预测方法的方法步骤。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来程序代码相关的硬件完成,该程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
通过以上的实施方式的描述,本领域普通技术人员可以清楚地了解到各实施方式可借助软件加通用硬件平台的方式来实现,当然也可以通过硬件。本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程是可以通过计算机程序来指令相关的硬件来完成,程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)或随机存储记忆体(Random Access Memory,RAM)等。
最后应说明的是:以上实施例仅用以说明本申请的技术方案,而非对其限制;在本申请的思路下,以上实施例或者不同实施例中的技术特征之间也可以进行组合,步骤可以以任意顺序实现,并存在如上述的本申请的不同方面的许多其它变化,为了简明,它们没有在细节中提供;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的范围。
Claims (10)
1.一种痘痘类别的预测方法,其特征在于,包括:
获取图像数据集,其中,所述图像数据集包括多种类别的痘痘的图像;
基于所述图像数据集,对预设的多个教师模型进行训练,其中,所述教师模型包括多种不同的网络结构;
通过多个所述教师模型对预设的学生模型进行知识蒸馏,以训练所述学生模型,得到训练后的学生模型;
根据所述训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的所述目标图像的痘痘类别。
2.根据权利要求1所述的方法,其特征在于,所述训练所述学生模型,包括:
构建多层损失函数,并基于所述多层损失函数对所述学生模型进行训练。
3.根据权利要求2所述的方法,其特征在于,所述多层损失函数包括:相似度损失函数、类别损失函数以及交叉熵损失函数中的至少一个。
5.根据权利要求1所述的方法,其特征在于,所述通过多个所述教师模型对预设的学生模型进行知识蒸馏,包括:
根据训练后的多个教师模型,对所述图像数据集中的图像进行特征提取,以确定多个第一特征图,其中,每一教师模型对应一个第一特征图;
在每次迭代中,确定第二特征图,并随机选择一个教师模型对所述学生模型进行知识蒸馏,其中,所述第二特征图与所述第一特征图的大小相同。
6.根据权利要求2所述的方法,其特征在于,所述基于所述多层损失函数对所述学生模型进行训练,包括:
基于所述多层损失函数对所述学生模型进行迭代训练;
若迭代次数大于第一次数阈值,或者,所述学生模型的损失小于第一损失阈值,则停止迭代训练。
7.根据权利要求1-6任一项所述的方法,其特征在于,所述痘痘类别包括粉刺、逗后红斑、炎症性丘疹、脓包、结节和囊肿中的至少一种。
8.一种痘痘类别的预测装置,其特征在于,包括:
数据集获取单元,用于获取图像数据集,其中,所述图像数据集包括多种类别的痘痘的图像;
教师模型训练单元,用于基于所述图像数据集,对预设的多个教师模型进行训练,其中,所述教师模型包括多种不同的网络结构;
学生模型训练单元,用于通过多个所述教师模型对预设的学生模型进行知识蒸馏,以训练所述学生模型,得到训练后的学生模型;
痘痘类别预测单元,用于根据所述训练后的学生模型,对包含痘痘的目标图像进行预测,以得到预测的所述目标图像的痘痘类别。
9.一种电子设备,其特征在于,包括:
存储器以及一个或多个处理器,所述一个或多个处理器用于执行存储在所述存储器中的一个或多个计算机程序,所述一个或多个处理器在执行所述一个或多个计算机程序时,使得所述电子设备实现如权利要求1-7任一项所述的方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机程序,所述计算机程序包括程序指令,所述程序指令当被处理器执行时使所述处理器执行如权利要求1-7任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111609463.4A CN114266897A (zh) | 2021-12-24 | 2021-12-24 | 痘痘类别的预测方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111609463.4A CN114266897A (zh) | 2021-12-24 | 2021-12-24 | 痘痘类别的预测方法、装置、电子设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114266897A true CN114266897A (zh) | 2022-04-01 |
Family
ID=80830109
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111609463.4A Pending CN114266897A (zh) | 2021-12-24 | 2021-12-24 | 痘痘类别的预测方法、装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114266897A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114926471A (zh) * | 2022-05-24 | 2022-08-19 | 北京医准智能科技有限公司 | 一种图像分割方法、装置、电子设备及存储介质 |
CN115965964A (zh) * | 2023-01-29 | 2023-04-14 | 中国农业大学 | 一种鸡蛋新鲜度识别方法、系统及设备 |
CN116091895A (zh) * | 2023-04-04 | 2023-05-09 | 之江实验室 | 一种面向多任务知识融合的模型训练方法及装置 |
CN116594349A (zh) * | 2023-07-18 | 2023-08-15 | 中科航迈数控软件(深圳)有限公司 | 机床预测方法、装置、终端设备以及计算机可读存储介质 |
-
2021
- 2021-12-24 CN CN202111609463.4A patent/CN114266897A/zh active Pending
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114926471A (zh) * | 2022-05-24 | 2022-08-19 | 北京医准智能科技有限公司 | 一种图像分割方法、装置、电子设备及存储介质 |
CN115965964A (zh) * | 2023-01-29 | 2023-04-14 | 中国农业大学 | 一种鸡蛋新鲜度识别方法、系统及设备 |
CN115965964B (zh) * | 2023-01-29 | 2024-01-23 | 中国农业大学 | 一种鸡蛋新鲜度识别方法、系统及设备 |
CN116091895A (zh) * | 2023-04-04 | 2023-05-09 | 之江实验室 | 一种面向多任务知识融合的模型训练方法及装置 |
CN116594349A (zh) * | 2023-07-18 | 2023-08-15 | 中科航迈数控软件(深圳)有限公司 | 机床预测方法、装置、终端设备以及计算机可读存储介质 |
CN116594349B (zh) * | 2023-07-18 | 2023-10-03 | 中科航迈数控软件(深圳)有限公司 | 机床预测方法、装置、终端设备以及计算机可读存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2020221200A1 (zh) | 神经网络的构建方法、图像处理方法及装置 | |
US20210012198A1 (en) | Method for training deep neural network and apparatus | |
WO2020238293A1 (zh) | 图像分类方法、神经网络的训练方法及装置 | |
WO2021043168A1 (zh) | 行人再识别网络的训练方法、行人再识别方法和装置 | |
WO2019228317A1 (zh) | 人脸识别方法、装置及计算机可读介质 | |
WO2022042713A1 (zh) | 一种用于计算设备的深度学习训练方法和装置 | |
WO2021159714A1 (zh) | 一种数据处理方法及相关设备 | |
WO2021022521A1 (zh) | 数据处理的方法、训练神经网络模型的方法及设备 | |
WO2022001805A1 (zh) | 一种神经网络蒸馏方法及装置 | |
CN114266897A (zh) | 痘痘类别的预测方法、装置、电子设备及存储介质 | |
CN117456297A (zh) | 图像生成方法、神经网络的压缩方法及相关装置、设备 | |
CN113705769A (zh) | 一种神经网络训练方法以及装置 | |
CN110222718B (zh) | 图像处理的方法及装置 | |
CN111898703B (zh) | 多标签视频分类方法、模型训练方法、装置及介质 | |
WO2021129668A1 (zh) | 训练神经网络的方法和装置 | |
WO2021018251A1 (zh) | 图像分类方法及装置 | |
WO2021175278A1 (zh) | 一种模型更新方法以及相关装置 | |
WO2022012668A1 (zh) | 一种训练集处理方法和装置 | |
WO2021184902A1 (zh) | 图像分类方法、装置、及其训练方法、装置、设备、介质 | |
WO2022156475A1 (zh) | 神经网络模型的训练方法、数据处理方法及装置 | |
CN115238909A (zh) | 一种基于联邦学习的数据价值评估方法及其相关设备 | |
CN114299304A (zh) | 一种图像处理方法及相关设备 | |
WO2021136058A1 (zh) | 一种处理视频的方法及装置 | |
CN114169393A (zh) | 一种图像分类方法及其相关设备 | |
WO2024046144A1 (zh) | 一种视频处理方法及其相关设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |