CN114758180B - 一种基于知识蒸馏的轻量化花卉识别方法 - Google Patents

一种基于知识蒸馏的轻量化花卉识别方法 Download PDF

Info

Publication number
CN114758180B
CN114758180B CN202210412189.XA CN202210412189A CN114758180B CN 114758180 B CN114758180 B CN 114758180B CN 202210412189 A CN202210412189 A CN 202210412189A CN 114758180 B CN114758180 B CN 114758180B
Authority
CN
China
Prior art keywords
network
flower
student
teacher
picture
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210412189.XA
Other languages
English (en)
Other versions
CN114758180A (zh
Inventor
韦旭东
张红雨
李博
史长凯
韩欢
钟山
王曦
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
University of Electronic Science and Technology of China
Original Assignee
University of Electronic Science and Technology of China
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by University of Electronic Science and Technology of China filed Critical University of Electronic Science and Technology of China
Priority to CN202210412189.XA priority Critical patent/CN114758180B/zh
Publication of CN114758180A publication Critical patent/CN114758180A/zh
Application granted granted Critical
Publication of CN114758180B publication Critical patent/CN114758180B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/243Classification techniques relating to the number of classes
    • G06F18/2431Multiple classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Abstract

本发明公开了一种基于知识蒸馏的轻量化花卉识别方法,包括以下步骤:S1.构建花卉数据集,并将花卉数据集划分为训练集和测试集;S2.选定教师网络和学生网络;S3.对教师网络初始化和训练,得到成熟的教师网络;S4.对学生网络进行初始化;S5.在教师网络的辅助下,使用花卉数据集训练初始化后的学生网络,得到成熟的学生神经网络;S6.将成熟的学生神经网络设置为eval模式,不进行反向传播;将待识别花卉图片输入成熟的学生神经网络,通过前向传播计算并输出识别结果,至此花卉识别结束。本发明使得轻量级花卉识别模型在模型大幅压缩的同时还能保持较高的准确率。

Description

一种基于知识蒸馏的轻量化花卉识别方法
技术领域
本发明涉及花卉识别,特别是涉及一种基于知识蒸馏的轻量化花卉识别方法。
背景技术
在农、林业发展中,花卉种类的快速准确鉴别具有重要的意义。传统的花卉识别方法易受到花卉形态多样性、背景环境复杂性及光照条件多变性的影响,其准确率与泛化性能有待提升。而深层卷积神经网络(Deep convolutional neural network,DCNN)在高速计算设备的辅助下可以自动学习视觉目标语义特征的特点,解决了复杂环境下的视觉目标的鲁棒性识别问题,在花卉识别应用中具有较大潜力。但在实际应用中,人们更希望能够利用便携式设备及时获得花卉的种类信息,从而在数据产生地点实时进行分析,以便于最有效地对花卉资源进行开发利用。因此在算力弱、存储成本高但是便于携带的AI边缘计算设备上高效运行DCNN花卉分类模型对于户外实时花卉识别具有重大的研究价值与意义。目前,相关研究人员已构建出多种CNN模型来进行花卉的识别;
为了追求更好的分类效果,大多数的网络模型结构变得愈发庞杂。虽然相关任务准确率得到了提升,但通过加深网络来提高准确率会增加较大的参数量,导致网络的运算量增加,需要花费极大的运算资源,使得其难以应用到AI边缘计算设备上。轻量级DCNN模型的优势主要在于构建出更加高效的卷积网络计算方式,在模型大幅压缩的同时兼顾良好的网络性能。
相较于重量级网络而言,轻量级网络的预测时间、运算力需求以及模型储存占用量都得到了极大减少,使得该类网络更加适合于移动平台的应用。但是经过实验对比发现,轻量级网络在识别的准确率上和重量级网络还有明显的差距。
发明内容
本发明的目的在于克服现有技术的不足,提供一种基于知识蒸馏的轻量化花卉识别方法,使用知识蒸馏的算法,利用重量级网络辅助训练轻量级网络,在模型大幅压缩的同时尽量减低准确率方面的损失,以此得到一个模型大幅压缩而且保持较高准确率的轻量级花卉识别模型。
本发明的目的是通过以下技术方案来实现的:一种基于知识蒸馏的轻量化花卉识别方法,包括以下步骤:
S1.构建花卉数据集,并将花卉数据集划分为训练集和测试集;
所述花卉数据集中包含m张花卉图片,根据每一张花卉图片的花类别,构建该图片的真实标签;所述真实标签由N个数字构成数组:若花卉图片属于第n个花类别,则真实标签的第n个数字为1,其余数字为0;花卉数据集中共有N个花类别,即花卉数据集中共有N个不同的真实标签;并且在所述花卉数据集中,每个花类别具有至少两张花卉图片;
在本申请的实施例中,所使用的花卉数据集为牛津大学制作并提供公开下载的Oxford-Flower102数据集或Oxford-Flower17数据集。其中Oxford-Flower102数据集包含102个花类别,每个类包含40到258个图片,共8189张图片;Oxford-Flower17数据集,包含17个花类别,每个类别80张图片,共1360张图片。
将花卉数据集划分为训练集和测试集,并使得训练集和测试集均包含N个花类别的花卉图片;
S2.选定教师网络和学生网络;
S3.对教师网络初始化和训练,得到成熟的教师网络;
S4.对学生网络进行初始化;
S5.在教师网络的辅助下,使用花卉数据集训练初始化后的学生网络,得到成熟的学生神经网络;
S6.将成熟的学生神经网络设置为eval模式,不进行反向传播;将待识别花卉图片输入成熟的学生神经网络,通过前向传播计算并输出识别结果,至此花卉识别结束。
其中,所述步骤S2中,选定一个模型较大准确率较高的神经网络作为教师网络,模型较小准确率较低的神经网络作为学生网络;
所述模型较大准确率较高的神经网络包括SeNet152网络或MobilNetV3-Large网络;
所述模型较小准确率较低的神经网络包括MobilNetV3-Small网络。
其中,所述步骤S3包括:
S301.教师网络加载预先设定的ImageNet预训练权重(ImageNet预训练权重为由Pytorch官方提供的ImageNet预训练权重),并根据花卉总类别的数目N,构建一个新的全连接层:新的全连接层的输出类别与花卉训练数据集的总类别数目相同且一一对应;
利用新建的全连接层替换教师网络原有的最后一个连接层,完成教师网络的初始化;当图片输入教师网络时,教师网络的全连接层输出的是:该图片为各个花卉类别的概率;
S302.对于训练集中的任一张图片,将该图片输入教师网络做前向运算得到教师网络的输出y:
设教师网络共有K层,其中第i层的输入输出表示为
yi=σi(xi*wi+bi)
其中i=1,2,…K;yi表示教师网络的第i层输出,xi表示教师网络的第i层的输入,σi表示教师网络第i层所用的激活函数;设教师网络最后一层的输出为y,教师网络最后一层的输出也叫作教师网络的输出,其中包含了输入图片为各个花类别的概率;
通过CrossEntropyLoss函数计算y和真实标签label之间的硬损失Lhard_t,
Lhard_t=CrossEntroyLoss(y,lable)
其中,label表示当前输入图片的真实标签,
使用Lhard_t对教师网络进行反向传播并结合Adam优化器,更新教师网络的参数:
Wi,Bi=Adam(Lhard_t,wi,bi,lr)
其中,Adam优化器表示为Adam函数,wi,bi表示教师网络第i层更新前的参数,Wi,Bi表示教师网络第i层更新后的参数,lr为学习率;
S303.对于训练集的每一张图片,重复执行步骤S302,对教师网络参数进行更新,所有图像下的更新完成时,得到训练后的教师网络;
S304.对于测试集每一张图片,将该图片输入S303训练后的教师网络做前向运算得到教师网络的预测输出y,将y和真实标签对比以判断当前图片是否预测正确,直到测试集的所有图像预测完成,统计得到教师网络的准确率;
S305.重复执行步骤S303~S304共200次,得到200个训练后的教师网络,选择其中在测试集具有最高准确率的一个训练后的神经网络,将其作为成熟的教师神经网络。
其中,所述步骤S4包括:
学生网络加载预先设定的ImageNet预训练权重(所述ImageNet预训练权重为由Pytorch官方提供的ImageNet预训练权重),并根据花卉总类别的数目N,构建一个新的全连接层:新的全连接层的输出类别与花卉训练数据集的总类别数目相同且一一对应;
利用新建的全连接层替换学生网络原有的最后一个连接层,完成学生网络的初始化;当图片输入学生网络时,学生网络的全连接层输出的是:该图片为各个花卉类别的概率。
其中,所述步骤S5包括以下子步骤:
S501.采用步骤S3中得到的成熟的教师网络,设置为eval模式,eval模式即评估模式,不参与反向传播;
S502.对于训练集中的任一张图片,将其同时输入成熟的教师网络和初始化后的学生网络,做前向运算;
计算学生网络硬输出和真实标签label之间的硬损失Lhard_s,学生网络软输出和教师网络软输出之间的蒸馏损失Lsoft,最终得到总损失L=(1-α)*Lhard_s+α*Lsoft
其中,α表示Lsoft在总损失中的比重;T表示蒸馏所用的温度;vj表示教师网络的硬预测输出在第j类花卉类别上概率值;zj表示学生网络的硬预测输出在第j类花卉类别上的概率值;表示教师网络在温度T下的软预测输出在第j类花卉类别上概率值;/>表示学生网络在温度T下的软预测输出在第j类花卉类别上概率值;cj表示真实标签在第j类上的值;N表示总类别数量;
S503.使用总损失L对学生网络进行反向传播并结合Adam优化器更新学生网络的参数;
Wi,Bi=A d a(m,Li,wi,b
其中wi,bi表示学生网络第i层更新前的参数,Wi,Bi表示学生网络第i层更新后的参数,lr为学习率;
S504.对于训练集的每一张图片,重复执行步骤S502-503,对学生网络参数进行更新,所有图像下的更新完成时,得到训练后的学生网络;
S505.对于测试集每一张图片,将该将图片输入S503训练后的学生网络做前向运算得到学生网络的预测输出y,将y和真实标签对比以判断当前图片是否预测正确,直到测试集的所有图像预测完成,统计得到学生网络的准确率;
S506.重复执行步骤S504~S505共200次,得到200个训练后的学生网络,选择其中在测试集具有最高准确率的一个训练后的学生网络,将其作为成熟的学生神经网络。
本发明的有益效果是:本发明利用重量级网络辅助训练轻量级网络,在模型大幅压缩的同时尽量减低准确率方面的损失,使得轻量级网络在模型大幅压缩的同时还能保持较高的准确率。
附图说明
图1为本发明训练过程的流程图;
图2为本发明识别过程的流程图。
具体实施方式
下面结合附图进一步详细描述本发明的技术方案,但本发明的保护范围不局限于以下所述。
知识蒸馏是用大神经网络去指导小神经网络的训练(就是把大网络的输出当做小网络训练时的学习目标),以此把大网络学到的知识迁移到小网络中,从而达到提高小网络性能或是压缩大网络模型的目的。之所以叫蒸馏是因为大网络的输出概率分布比较极端不均匀,使用的温度T使得大网络的输出软化,即变的更均匀,这样小网络比较容易学习大网络软化后的输出。本专利相当于把知识蒸馏算法应用到花卉分类领域中,进行花卉识别,具体地:
如图1所示,一种基于知识蒸馏的轻量化花卉识别方法,包括以下步骤:
S1.构建花卉数据集,并将花卉数据集划分为训练集和测试集;
所述花卉数据集中包含m张花卉图片,根据每一张花卉图片的花类别,构建该图片的真实标签;所述真实标签由N个数字构成数组:若花卉图片属于第n个花类别,则真实标签的第n个数字为1,其余数字为0;
例如,某一张花卉图片属于第一个花类别,则其真实标签[x1,x2,...xN]中,x1=1,x2~xN均为0,同理,若某一张花卉图片属于第二个花类别,则其真实标签[x1,x2,...xN]中,x2=1,x1以及x3~xN均为0。
花卉数据集中共有N个花类别,即花卉数据集中共有N个不同的真实标签;并且在所述花卉数据集中,每个花类别具有至少两张花卉图片;
将花卉数据集划分为训练集和测试集,并使得训练集和测试集均包含N个花类别的花卉图片;
S2.选定教师网络和学生网络;
S3.对教师网络初始化和训练,得到成熟的教师网络;
S4.对学生网络进行初始化;
S5.在教师网络的辅助下,使用花卉数据集训练初始化后的学生网络,得到成熟的学生神经网络;
S6.在训练完成后,如图2所示,将成熟的学生神经网络设置为eval模式,不进行反向传播;将待识别花卉图片输入成熟的学生神经网络,通过前向传播计算并输出识别结果,至此花卉识别结束。
其中,所述步骤S2中,选定一个模型较大准确率较高的神经网络作为教师网络,模型较小准确率较低的神经网络作为学生网络;
所述模型较大准确率较高的神经网络包括SeNet152网络或MobilNetV3-Large网络;
所述模型较小准确率较低的神经网络包括MobilNetV3-Small网络。
其中,所述步骤S3包括:
S301.教师网络加载预先设定的ImageNet预训练权重(ImageNet预训练权重为由Pytorch官方提供的ImageNet预训练权重),并根据花卉总类别的数目N,构建一个新的全连接层:新的全连接层的输出类别与花卉训练数据集的总类别数目相同且一一对应;
利用新建的全连接层替换教师网络原有的最后一个连接层,完成教师网络的初始化;当图片输入教师网络时,教师网络的全连接层输出的是:该图片为各个花卉类别的概率;
S302.对于训练集中的任一张图片,将该图片输入教师网络做前向运算得到教师网络的输出y:
设教师网络共有K层,其中第i层的输入输出表示为
yi=σi(xi*wi+bi)
其中i=1,2,…K;yi表示教师网络的第i层输出,xi表示教师网络的第i层的输入,σi表示教师网络第i层所用的激活函数;设教师网络最后一层的输出为y,教师网络最后一层的输出也叫作教师网络的输出,其中包含了输入图片为各个花类别的概率;
通过CrossEntropyLoss函数计算y和真实标签label之间的硬损失Lhard_t,
Lhard_t=CrossEntroyLoss(y,lable)
其中,label表示当前输入图片的真实标签,
使用Lhard_t对教师网络进行反向传播并结合Adam优化器,更新教师网络的参数:
Wi,Bi=Adam(Lhard_t,wi,bi,lr)
其中,Adam优化器表示为Adam函数,wi,bi表示教师网络第i层更新前的参数,Wi,Bi表示教师网络第i层更新后的参数,lr为学习率;
S303.对于训练集的每一张图片,重复执行步骤S302,对教师网络参数进行更新,所有图像下的更新完成时,得到训练后的教师网络;
S304.对于测试集每一张图片,将该图片输入S303训练后的教师网络做前向运算得到教师网络的预测输出y,将y和真实标签对比以判断当前图片是否预测正确,直到测试集的所有图像预测完成,统计得到教师网络的准确率;
S305.重复执行步骤S303~S304共200次,得到200个训练后的教师网络,选择其中在测试集具有最高准确率的一个训练后的神经网络,将其作为成熟的教师神经网络。
其中,所述步骤S4包括:
学生网络加载预先设定的ImageNet预训练权重(所述ImageNet预训练权重为由Pytorch官方提供的ImageNet预训练权重),并根据花卉总类别的数目N,构建一个新的全连接层:新的全连接层的输出类别与花卉训练数据集的总类别数目相同且一一对应;
利用新建的全连接层替换学生网络原有的最后一个连接层,完成学生网络的初始化;当图片输入学生网络时,学生网络的全连接层输出的是:该图片为各个花卉类别的概率。
其中,所述步骤S5包括以下子步骤:
S501.采用步骤S3中得到的成熟的教师网络,设置为eval模式,eval模式即评估模式,不参与反向传播;
S502.对于训练集中的任一张图片,将其同时输入成熟的教师网络和初始化后的学生网络,做前向运算;
计算学生网络硬输出和真实标签label之间的硬损失Lhard_s,学生网络软输出和教师网络软输出之间的蒸馏损失Lsoft,最终得到总损失L=(1-α)*Lhard_s+α*Lsoft
其中,α表示Lsoft在总损失中的比重;T表示蒸馏所用的温度;vj表示教师网络的硬预测输出在第j类花卉类别上概率值;zj表示学生网络的硬预测输出在第j类花卉类别上的概率值;表示教师网络在温度T下的软预测输出在第j类花卉类别上概率值;/>表示学生网络在温度T下的软预测输出在第j类花卉类别上概率值;cj表示真实标签在第j类上的值;N表示总类别数量;硬预测输出是指将图片输入教师网络或学生网络后,由教师网络或学生网络直接输出的数据;软预测输出是指与温度T相关的预测值,其计算方式在上述公式中已经给出。
S503.使用总损失L对学生网络进行反向传播并结合Adam优化器更新学生网络的参数;
Wi,Bi=A d a(m,Li,wi,b
其中wi,bi表示学生网络第i层更新前的参数,Wi,Bi表示学生网络第i层更新后的参数,lr为学习率;
S504.对于训练集的每一张图片,重复执行步骤S502-503,对学生网络参数进行更新,所有图像下的更新完成时,得到训练后的学生网络;
S505.对于测试集每一张图片,将该将图片输入S503训练后的学生网络做前向运算得到学生网络的预测输出y,将y和真实标签对比以判断当前图片是否预测正确,直到测试集的所有图像预测完成,统计得到学生网络的准确率;
S506.重复执行步骤S504~S505共200次,得到200个训练后的学生网络,选择其中在测试集具有最高准确率的一个训练后的学生网络,将其作为成熟的学生神经网络。
在本申请的实施例中,采用的数据集为牛津大学所制作的Oxford-Flower102数据集和Oxford-Flower17数据集.平台是联想Legion R700,处理器AMD Ryzen7 4800H,显卡NVIDIA GeForce GTX 1650,内存16.0GB,windows10操作系统.仿真软件PyCharm2021.1.3,运行环境python3.7,pytorch1.9。表1为不同模型在Oxford-Flower102上的大小和准确率对比,学生网络是MobileNetV3-small,教师网络是SeNet152,均使用迁移学习,epoch=200
表1
表2为不同模型在Oxford-Flower17上的大小和准确率对比,其中学生网络是MobileNetV3-small,教师网络是MobileNetV3-large,均使用迁移学习,epoch=200
表2
从表1和表2可以看出轻量级花卉识别模型在使用知识蒸馏算法进行训练后,模型大小不变而识别准确率明显提升,在Oxford-Flower102和Oxford-Flower17上分别提高了0.5%和0.6%。如表1所示,在Oxford-Flower102上,MobileNetV3-small+知识蒸馏的准确率相比SENet152低0.7%,而模型大小仅是SENet152的1/40;相比ResNet18,模型大小是其1/7,准确率反而高0.2%。以上证明了本发明所用算法的有效性。Oxford-Flower17是只有1360张图片的小数据集,表2的结果表明本发明所用算法在数据较少时同样有效。
上述说明示出并描述了本发明的一个优选实施例,但如前所述,应当理解本发明并非局限于本文所披露的形式,不应看作是对其他实施例的排除,而可用于各种其他组合、修改和环境,并能够在本文所述发明构想范围内,通过上述教导或相关领域的技术或知识进行改动。而本领域人员所进行的改动和变化不脱离本发明的精神和范围,则都应在本发明所附权利要求的保护范围内。

Claims (4)

1.一种基于知识蒸馏的轻量化花卉识别方法,其特征在于:包括以下步骤:
S1.构建花卉数据集,并将花卉数据集划分为训练集和测试集;
所述花卉数据集中包含m张花卉图片,根据每一张花卉图片的花类别,构建该图片的真实标签;所述真实标签由N个数字构成数组:若花卉图片属于第n个花类别,则真实标签的第n个数字为1,其余数字为0;花卉数据集中共有N个花类别,即花卉数据集中共有N个不同的真实标签;并且在所述花卉数据集中,每个花类别具有至少两张花卉图片;
将花卉数据集划分为训练集和测试集,并使得训练集和测试集均包含N个花类别的花卉图片;
S2.选定教师网络和学生网络;
S3.对教师网络初始化和训练,得到成熟的教师网络;
S4.对学生网络进行初始化;
S5.在教师网络的辅助下,使用花卉数据集训练初始化后的学生网络,得到成熟的学生神经网络;
所述步骤S5包括以下子步骤:
S501.采用步骤S3中得到的成熟的教师网络,设置为eval模式,eval模式即评估模式,不参与反向传播;
S502.对于训练集中的任一张图片,将其同时输入成熟的教师网络和初始化后的学生网络,做前向运算;
计算学生网络硬输出和真实标签label之间的硬损失Lhard_s,学生网络软输出和教师网络软输出之间的蒸馏损失Lsoft,最终得到总损失L=(1-α)*Lhard_s+α*Lsoft
其中,α表示Lsoft在总损失中的比重;T表示蒸馏所用的温度;vj表示教师网络的硬预测输出在第j类花卉类别上概率值;zj表示学生网络的硬预测输出在第j类花卉类别上的概率值;表示教师网络在温度T下的软预测输出在第j类花卉类别上概率值;/>表示学生网络在温度T下的软预测输出在第j类花卉类别上概率值;cj表示真实标签在第j类花卉类别上概率值;N表示总类别数量;
S503.使用总损失L对学生网络进行反向传播并结合Adam优化器更新学生网络的参数;
Wi,Bi=Adam(L,wi,bi,lr)
其中wi,bi表示学生网络第i层更新前的参数,Wi,Bi表示学生网络第i层更新后的参数,lr为学习率;
S504.对于训练集的每一张图片,重复执行步骤S502-503,对学生网络参数进行更新,所有图像下的更新完成时,得到训练后的学生网络;
S505.对于测试集每一张图片,将该图片输入S503训练后的学生网络做前向运算得到学生网络的预测输出y,将y和真实标签对比以判断当前图片是否预测正确,直到测试集的所有图像预测完成,统计得到学生网络的准确率;
S506.重复执行步骤S504~S505共200次,得到200个训练后的学生网络,选择其中在测试集具有最高准确率的一个训练后的学生网络,将其作为成熟的学生神经网络;
S6.将成熟的学生神经网络设置为eval模式,不进行反向传播;将待识别花卉图片输入成熟的学生神经网络,通过前向传播计算并输出识别结果,至此花卉识别结束。
2.根据权利要求1所述的一种基于知识蒸馏的轻量化花卉识别方法,其特征在于:所述步骤S2中,选定一个模型较大准确率较高的神经网络作为教师网络,模型较小准确率较低的神经网络作为学生网络;
所述模型较大准确率较高的神经网络包括SeNet152网络或MobilNetV3-Large网络;
所述模型较小准确率较低的神经网络包括MobilNetV3-Small网络。
3.根据权利要求1所述的一种基于知识蒸馏的轻量化花卉识别方法,其特征在于:所述步骤S3包括:
S301.教师网络加载预先设定的ImageNet预训练权重,并根据花卉总类别的数目N,构建一个新的全连接层:新的全连接层的输出类别与花卉训练数据集的总类别数目相同且一一对应;
利用新建的全连接层替换教师网络原有的最后一个连接层,完成教师网络的初始化;当图片输入教师网络时,教师网络的全连接层输出的是:该图片为各个花卉类别的概率;
S302.对于训练集中的任一张图片,将该图片输入教师网络做前向运算得到教师网络的输出y:
设教师网络共有K层,其中第i层的输入输出表示为
yi=σi(xi*wi+bi)
其中i=1,2,…K;yi表示教师网络的第i层输出,xi表示教师网络的第i层的输入,σi表示教师网络第i层所用的激活函数;设教师网络最后一层的输出为y,教师网络最后一层的输出也叫作教师网络的输出,其中包含了输入图片为各个花类别的概率;
通过CrossEntropyLoss函数计算y和真实标签label之间的硬损失Lhard_t,
Lhard_t=CrossEntroyLoss(y,lable)
其中,label表示当前输入图片的真实标签,
使用Lhard_t对教师网络进行反向传播并结合Adam优化器,更新教师网络的参数:
Wi,Bi=Adam(Lhard_t,wi,bi,lr)
其中,Adam优化器表示为Adam函数,wi,bi表示教师网络第i层更新前的参数,Wi,Bi表示教师网络第i层更新后的参数,lr为学习率;
S303.对于训练集的每一张图片,重复执行步骤S302,对教师网络参数进行更新,所有图像下的更新完成时,得到训练后的教师网络;
S304.对于测试集每一张图片,将该图片输入S303训练后的教师网络做前向运算得到教师网络的预测输出y,将y和真实标签对比以判断当前图片是否预测正确,直到测试集的所有图像预测完成,统计得到教师网络的准确率;
S305.重复执行步骤S303~S304共200次,得到200个训练后的教师网络,选择其中在测试集具有最高准确率的一个训练后的神经网络,将其作为成熟的教师神经网络。
4.根据权利要求1所述的一种基于知识蒸馏的轻量化花卉识别方法,其特征在于:所述步骤S4包括:
学生网络加载预先设定的ImageNet预训练权重,并根据花卉总类别的数目N,构建一个新的全连接层:新的全连接层的输出类别与花卉训练数据集的总类别数目相同且一一对应;
利用新建的全连接层替换学生网络原有的最后一个连接层,完成学生网络的初始化当图片输入学生网络时,学生网络的全连接层输出的是:该图片为各个花卉类别的概率。
CN202210412189.XA 2022-04-19 2022-04-19 一种基于知识蒸馏的轻量化花卉识别方法 Active CN114758180B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210412189.XA CN114758180B (zh) 2022-04-19 2022-04-19 一种基于知识蒸馏的轻量化花卉识别方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210412189.XA CN114758180B (zh) 2022-04-19 2022-04-19 一种基于知识蒸馏的轻量化花卉识别方法

Publications (2)

Publication Number Publication Date
CN114758180A CN114758180A (zh) 2022-07-15
CN114758180B true CN114758180B (zh) 2023-10-10

Family

ID=82331990

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210412189.XA Active CN114758180B (zh) 2022-04-19 2022-04-19 一种基于知识蒸馏的轻量化花卉识别方法

Country Status (1)

Country Link
CN (1) CN114758180B (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116402116B (zh) * 2023-06-05 2023-09-05 山东云海国创云计算装备产业创新中心有限公司 神经网络的剪枝方法、系统、设备、介质及图像处理方法
CN117058437B (zh) * 2023-06-16 2024-03-08 江苏大学 一种基于知识蒸馏的花卉分类方法、系统、设备及介质

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112183577A (zh) * 2020-08-31 2021-01-05 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN114049513A (zh) * 2021-09-24 2022-02-15 中国科学院信息工程研究所 一种基于多学生讨论的知识蒸馏方法和系统
CN114241282A (zh) * 2021-11-04 2022-03-25 河南工业大学 一种基于知识蒸馏的边缘设备场景识别方法及装置

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20030177675A1 (en) * 2002-03-19 2003-09-25 Faulkner Willard M. Flexible plant identification display cards

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112183577A (zh) * 2020-08-31 2021-01-05 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN114049513A (zh) * 2021-09-24 2022-02-15 中国科学院信息工程研究所 一种基于多学生讨论的知识蒸馏方法和系统
CN114241282A (zh) * 2021-11-04 2022-03-25 河南工业大学 一种基于知识蒸馏的边缘设备场景识别方法及装置

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
Xudong Wei 等.A Lightweight Flower Classification Model Based on Improved Knowledge Distillation.2022 IEEE 10th Joint International Information Technology and Artificial Intelligence Conference.2022,全文. *
刘丰.基于注意力机制的低分辨率图像目标检测技术研究.中国优秀硕士学位论文全文数据库信息科技辑.2023,全文. *
李延超 等.自适应主动半监督学习方法.软件学报.2020,全文. *

Also Published As

Publication number Publication date
CN114758180A (zh) 2022-07-15

Similar Documents

Publication Publication Date Title
CN110598029B (zh) 基于注意力转移机制的细粒度图像分类方法
CN114758180B (zh) 一种基于知识蒸馏的轻量化花卉识别方法
CN109544524A (zh) 一种基于注意力机制的多属性图像美学评价系统
CN109829541A (zh) 基于学习自动机的深度神经网络增量式训练方法及系统
CN109657780A (zh) 一种基于剪枝顺序主动学习的模型压缩方法
CN111160474A (zh) 一种基于深度课程学习的图像识别方法
CN110134964B (zh) 一种基于层次化卷积神经网络和注意力机制的文本匹配方法
CN108334499A (zh) 一种文本标签标注设备、方法和计算设备
CN107358293A (zh) 一种神经网络训练方法及装置
CN114049513A (zh) 一种基于多学生讨论的知识蒸馏方法和系统
CN114332545B (zh) 一种基于低比特脉冲神经网络的图像数据分类方法和装置
Islam et al. InceptB: a CNN based classification approach for recognizing traditional bengali games
CN109740012B (zh) 基于深度神经网络对图像语义进行理解和问答的方法
CN116797423B (zh) 一种基于全局优化的高校自动快速排课方法与系统
CN114943345A (zh) 基于主动学习和模型压缩的联邦学习全局模型训练方法
US20230222768A1 (en) Multiscale point cloud classification method and system
CN117236421B (zh) 一种基于联邦知识蒸馏的大模型训练方法
CN112667797B (zh) 自适应迁移学习的问答匹配方法、系统及存储介质
CN110047088B (zh) 一种基于改进教与学优化算法的ht-29图像分割方法
Zhang Modern art design system based on the deep learning algorithm
CN113032613A (zh) 一种基于交互注意力卷积神经网络的三维模型检索方法
WO2023134142A1 (zh) 一种多尺度点云分类方法及系统
CN114444654A (zh) 一种面向nas的免训练神经网络性能评估方法、装置和设备
CN109726690A (zh) 基于DenseCap网络的学习者行为图像多区域描述方法
CN113128661A (zh) 信息处理装置和信息处理方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant