CN114549905A - 一种基于改进在线知识蒸馏算法的图像分类方法 - Google Patents

一种基于改进在线知识蒸馏算法的图像分类方法 Download PDF

Info

Publication number
CN114549905A
CN114549905A CN202210183421.7A CN202210183421A CN114549905A CN 114549905 A CN114549905 A CN 114549905A CN 202210183421 A CN202210183421 A CN 202210183421A CN 114549905 A CN114549905 A CN 114549905A
Authority
CN
China
Prior art keywords
network
student
training
student network
error
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210183421.7A
Other languages
English (en)
Inventor
陈莹
邵柏潭
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Jiangnan University
Original Assignee
Jiangnan University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Jiangnan University filed Critical Jiangnan University
Priority to CN202210183421.7A priority Critical patent/CN114549905A/zh
Publication of CN114549905A publication Critical patent/CN114549905A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)
  • Image Processing (AREA)

Abstract

本发明公开了一种基于改进在线知识蒸馏算法的图像分类方法,属于深度学习中的模型压缩技术领域。该方法设置错题集保存各学生网络训练过程中集成输出特征,通过训练图片的真实标签判断错题集中记录的集成输出特征是否正确,若记录的特征正确,则更新网络参数以使得网络集合中的每一个网络的输出特征zi的分布靠近错题集中记录的特征分布,若记录的特征错误,则让zi的分布远离错题集中记录的特征分布。即利用错题集保存网络集合集成输出的历史信息,将其作为学生网络优化训练的监督信息,并提出一种反思机制使学生网络的输出分布远离错误的历史信息,靠近正确的历史信息,迫使学生网络学习到更优质的特征,从而在对图像进行分类时得到更准确的分类结果。

Description

一种基于改进在线知识蒸馏算法的图像分类方法
技术领域
本发明涉及一种基于改进在线知识蒸馏算法的图像分类方法,属于深度学习中的模型压缩技术领域。
背景技术
深度学习方法在计算机视觉领域大放异彩,研究者们通常利用庞大的神经网络来提高识别精度,但是大网络存在多方面的问题:大网络参数量巨大,导致一些移动设备或嵌入式设备无法存储这样的网络模型;计算量巨大,导致执行一次推理需要花费许多时间,无法满足实时性的需求;而且能量消耗巨大这个问题在嵌入式设备上最为关键,嵌入式设备要考虑续航时间的需求。
因此模型压缩技术应运而生,如今模型压缩技术主要包括模型剪枝、模型量化、知识蒸馏、轻量化网络设计以及矢量分解,其中知识蒸馏本质上是一种基于迁移学习的方法,旨在将教师网络的知识蒸馏给学生网络,在部署时使用学生网络代替教师网络执行推理,以此实现模型压缩的目的。
近年来,知识蒸馏方法的研究层出不穷,按照蒸馏的方法可以分为离线知识蒸馏和在线知识蒸馏,其中离线知识蒸馏是教师-学生结构的二阶段的蒸馏方法,二阶段表示先训练一个大的教师网络然后将教师网络的知识迁移给学生网络,教师网络和学生网络的训练过程是分开进行的,因此需要两个阶段的训练过程;而在线知识蒸馏是学生-学生结构的一阶段的蒸馏方法,表示多个学生网络互相学习,在一个训练阶段同时优化,在线知识蒸馏不像离线知识蒸馏那样教师网络和学生网络的训练是分开进行的。因此从实现的角度来说,在线知识蒸馏相比离线知识蒸馏更加方便,从训练成本的角度来说在线知识蒸馏节省了单独训练一个教师网络的时间和计算资源成本,因此在线知识蒸馏的研究具有一定的意义。
但是现有的在线知识蒸馏方法缺少离线知识蒸馏的教师网络所能提供的稳定且合理的监督信息,因此在线知识蒸馏方法的几个学生网络在前期训练时产生的监督信息不够准确的问题,会导致蒸馏得到的网络性能受损,从而导致在无法很好的完成计算机视觉任务,比如在进行图像分类时,出现分类精度低的问题。
发明内容
为了提高使用在线知识蒸馏方法得到的网络进行图像分类的精度,本发明提供了一种基于改进在线知识蒸馏算法的图像分类方法,所述方法包括:
Step1:确定在线知识蒸馏算法中的学生网络集合,并设置错题集以保存各学生网络训练过程中的集成输出特征,通过公开的图像分类数据集中的训练图片训练各学生网络以更新网络参数,得到训练好的学生网络集合;更新各学生网络参数时,通过训练图片的真实标签判断错题集中记录的集成输出特征是否正确,如果记录的是正确的特征,则更新网络参数以使得网络集合中的每一个网络的输出特征zi的分布靠近错题集中记录的特征分布,如果记录的是错误的特征,就让zi的分布远离错题集中记录的特征分布;
Step2:将待分类图像采用不同的随机变换得到变换后的图像以输入训练好的各学生网络中,得到各学生网络的输出特征,将所有学生网络的输出特征取平均得到对应于待分类图像的集成输出特征,根据待分类图像的集成输出特征对待分类图像进行分类。
可选的,设公开的图像分类数据集包含N个训练图片,共C个类别,则初始化错题集为
Figure BDA0003501947440000021
的零矩阵;
Figure BDA0003501947440000022
为N×C维的实数集;
所述Step1中训练各学生网络以更新网络参数,包括:
对公开的图像分类数据集中的训练图片进行预处理;
使用反思机制训练学生网络集合中的各学生网络,设置网络训练批次为128,学习率为0.1,动量为0.9,权重衰减正则系数为0.0005,训练总轮数为300,在训练第150次和第225次时学习率衰减10倍,蒸馏温度系数设置为3.0,错题集的历史信息保留率γ为0.2;
各学生网络前向传播更新网络参数时,得到第i个学生网络的输出特征zi,将所有学生网络的输出平均得到集成输出特征ze,前向时将集成输出更新到错题集中,表达式表示为:
NB←γNB+(1-γ)ze
各学生网络反向传播更新网络参数时,通过训练图片的真实标签可以判断出错题集中记录的特征是否正确,如果记录的是正确的特征,则让学生网络集合中的每一个学生网络的输出特征zi的分布靠近错题集中记录的特征分布,如果记录的是错误的特征,则让zi的分布远离错题集中记录的特征分布;即各学生网络的损失函数设置为:
Figure BDA0003501947440000023
其中,LHKD(NB,zi)为蒸馏损失函数,LCE为交叉熵损失函数。
可选的,所述使用反思机制训练学生网络集合中的各学生网络时,设置初始反思系数μ=μ0为0.005,设验证监视窗口大小w为10。
可选的,当各学生网络完成一轮次的训练和验证后,记录各学生网络当前轮次的验证准确率,当各学生网络的验证准确率连续w轮超过第一阈值时,设置反思系数为第一预定值;当各学生网络的验证准确率连续ω轮超过第二阈值时,设置反思系数为第二预定值。
可选的,所述第一阈值为75%,对应的第一预定值为0.01,即令μ=2μ0为0.01。
可选的,所述第二阈值为93.5%,对应的第二预定值为0.02,即令μ=4μ0为0.02。
可选的,所述通过公开的图像分类数据集中的训练图片训练各学生网络以更新网络参数时,对不同的学生网络,采用不同的预处理对同一训练图片做变换后得到各学生网络的输入。
可选的,所述预处理包括随机水平翻转、随机裁剪、边界部分用零填充后将分辨率调整到32*32并对其进行归一化处理。
本发明有益效果是:
通过利用错题集模块保存网络集合集成输出的历史信息,并将其作为学生网络优化训练的监督信息,并提出一种反思机制使学生网络的输出分布远离错误的历史信息,靠近正确的历史信息,迫使学生网络学习到更优质的特征,从而在对图像进行分类时得到更准确的分类结果。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为反思协作学习在线知识蒸馏的方法结构图。
图2为CIFAR100数据集上resnet32网络DML、ONE、OKDDip、KDCL与本发明方法的准确率结果对比图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,下面将结合附图对本发明实施方式作进一步地详细描述。
实施例一:
本实施例提供一种基于改进在线知识蒸馏算法的图像分类方法,参见图1,所述方法包括:
Step1:确定在线知识蒸馏算法中的学生网络集合,并设置错题集以保存各学生网络训练过程中的集成输出特征,通过公开的图像分类数据集中的训练图片训练各学生网络以更新网络参数,得到训练好的学生网络集合;更新各学生网络参数时,通过训练图片的真实标签判断错题集中记录的集成输出特征是否正确,如果记录的是正确的特征,则更新网络参数以使得网络集合中的每一个网络的输出特征zi的分布靠近错题集中记录的特征分布,如果记录的是错误的特征,就让zi的分布远离错题集中记录的特征分布;
Step2:将待分类图像采用不同的随机变换得到变换后的图像以输入训练好的各学生网络中,得到各学生网络的输出特征,将所有学生网络的输出特征取平均得到对应于待分类图像的集成输出特征,根据待分类图像的集成输出特征对待分类图像进行分类。
学生网络集合中包含若干个学生网络,学生网络可以是任意的神经网络,比如ResNet、VGG、MobileNet、DenseNet、ShuffleNet等公开的网络。
实施例二
本实施例提供一种基于改进在线知识蒸馏算法的图像分类方法,所述方法针对图像分类任务,在现有的在线知识蒸馏算法基础上作出了改进,提出一种基于错题集机制的反思协作学习在线知识蒸馏算法,进而应用到图像分类任务上,提高了图像分类精度。
所述方法包括:
步骤A.1、设置网络集合中的学生网络的模型结构。
所述步骤A.1包括:
(1)学生网络集合中的模型结构的选择应该按照部署的场景需求进行选择,若追求速度应该选择容量小的轻量网络,比如MobileNet、ShuffleNet等;若追求精度可以选择容量适中或较大的网络,比如DenseNet、VGG;
学生网络集合中的神经网络的结构均采用已知的公开网络结构,比如VGG(Simonyan K,Zisserman A.Very deep convolutional networks for large-scaleimage recognition[J].arXiv preprint arXiv:1409.1556,2014.)网络是2014年提出的,证明了网络的深度能够在一定程度上提升网络的性能,并采用连续堆积的小卷积核代替一个较大的卷积核,保证在具有相同感知野的条件下增加网络的非线性程度。ResNet(He K,Zhang X,Ren S,et al.Deep residual learning for image recognition[C]//Proceedings of the IEEE conference on computer vision and patternrecognition.2016:770-778.)是2015年提出来的,主要参考VGG网络结构,并在其基础上进行修改,通过加入残差单元,缓解深层网络的退化问题。在一定程度上,也通过残差单元解决了深层网络在训练时梯度消失的问题。MobileNet(Howard A G,Zhu M,Chen B,etal.Mobilenets:Efficient convolutional neural networks for mobile visionapplications[J].arXiv preprint arXiv:1704.04861,2017.)是2017年提出来的一种轻量级深度网络。主要提出使用深度可分离卷积在不过多降低识别准确率的情况下降低标准卷积的参数量和计算量,缓解移动设备算力低的压力。并且使用ReLU6激活函数代替ReLU,使网络对低比特量化更加友好。
(2)集合中网络的数量应该根据模型训练机器的显卡存储空间进行选择,若显卡存储空间充裕,可以选择较多的网络进行训练,但应满足最少两个网络参与训练的要求。
步骤B.1、数据集若包含N个训练图片,共C个类别,则初始化错题集为
Figure BDA0003501947440000051
的零矩阵:
所述步骤B.1包括:
(1)数据集若包含N个训练图片,共C个类别,则初始化错题集为
Figure BDA0003501947440000052
的零矩阵。
步骤C.1、定义反思系数调整策略。
所述步骤C.1包括:
(1)定义反思系数调整策略,设初始反思系数μ=μ0为0.005,设验证监视窗口大小w为10,当模型完成一轮次的训练和验证后,记录模型当前轮次的验证准确率,当模型的验证准确率连续w轮超过75%时,令μ=2μ0为0.01,随着模型进一步的训练,当模型的验证准确率连续w轮超过93.5%时,令μ=4μ0为0.02。
步骤D.1、读入数据集的图像数据,对图片进行数据增强后将分辨率调整到32*32并对其归一化。
所述步骤D.1包括:
(1)读入数据集的图像数据,对图片进行随机水平翻转、随机裁剪、边界部分用零填充后将分辨率调整到32*32并对其归一化;
(2)网络集合中的每一个网络的输入应当采用不同的随机变换,即若网络集合中共设置了3个模型,对相同的输入图片做多种变换得到3个模型各自的输入,且3个输入是各不相同的,最后分辨率调整到32*32,使其尺寸能够输入给模型。
步骤E.1、使用反思机制训练网络集合中的网络,设置网络训练的批量数为128,网络前向传播时需要将网络集成输出按照指定的更新策略更新到错题集中。
所述步骤E.1包括:
(1)设置网络训练批次为128,学习率为0.1,动量为0.9,权重衰减正则系数为0.0005,训练总轮数为300,在训练第150次和第225次时学习率衰减10倍,蒸馏温度系数设置为3.0,错题集的历史信息保留率γ为0.2;
(2)网络前向传播时,得到第i个网络的输出特征zi,将所有网络的输出平均得到集成输出特征ze,前向时需要将集成输出更新到错题集中,表达式表示为:
NB←γNB+(1-γ)ze
(3)反向传播更新网络参数时,通过训练图片的真实标签可以判断出错题集中记录的特征是否正确,如果记录的是正确的特征,就让网络集合中的每一个网络的输出特征zi的分布靠近错题集中记录的特征分布,如果记录的是错误的特征,就让zi的分布远离错题集中记录的特征分布,从而实现反思的目的。损失函数表示为:
Figure BDA0003501947440000061
其中,LHKD(NB,zi)为(Hinton G,Vinyals O,Dean J.Distilling the knowledgein a neural network[EB/OL].https://arxiv.org/abs/1503.02531)提出的蒸馏损失函数,LCE为交叉熵损失函数。
步骤F.1、如果网络的验证精度达到某一设定好的条件时,调整训练时的反思系数。
所述步骤F.1包括:
(1)如果网络集合在验证集上的集成准确率满足设定好的条件时,按照设定好的策略调整反思系数μ,并使用新的反思系数进行新一轮的训练。
步骤G.1、完成网络集合的训练,并将网络集合中性能最佳的模型进行测试部署。
所述步骤G.1包括:
(1)模型部署时,如果优先考虑准确率,则可以将网络集合的所有网络进行部署,取网络集合中所有网络输出的平均值作为最终的集成结果可以获得更加准确的识别效果,如果优先考虑推理速度,可以选择网络集合中验证准确率最高的模型进行部署,可以兼具推理效率和准确率。
实施例三
本实施例提供一种基于改进在线知识蒸馏算法的图像分类方法,以应用于CIFAR100数据集上使用ResNet110网络进行图像分类为例进行说明,CIFAR100数据集为图像分类领域公开的数据集;所述方法包括:
A.1、设置网络集合中的每一个网络的模型结构。
所述步骤A.1包括:
(1)网络集合中的模型结构的选择应该按照部署的场景需求进行选择,若追求速度应该选择容量小的轻量网络,若追求精度可以选择容量适中或较大的网络;本实施例中选择ResNet110作为学生网络集合中的学生网络。
(2)集合中网络的数量应该根据模型训练机器的显卡存储空间进行选择,若显卡存储空间充裕,可以选择较多的网络进行训练,但应满足最少两个网络参与训练的要求。本实施例中选择3个网络,即3个网络均为ResNet110。
需要进行说明的是,在选择具体的学生网络的类型时,学生网络可以是相同的网络,也可以是不同的网络,比如3个网络可以分别选择VGG、ResNet和MobileNet,也可以如本实施例一样,均选择ResNet110。
B.1、数据集若包含N个训练图片,共C个类别,则初始化错题集为
Figure BDA0003501947440000071
的零矩阵,即初始化一个N行C列的零矩阵。
C.1、定义反思系数调整策略。
所述步骤C.1包括:
(1)定义反思系数调整策略,设初始反思系数μ=μ0为0.005,设验证监视窗口大小w为10,当模型完成一轮次的训练和验证后,记录模型当前轮次的验证准确率,当模型的验证准确率连续w轮超过75%时,令μ=2μ0为0.01,随着模型进一步的训练,当模型的验证准确率连续w轮超过93.5%时,令μ=4μ0为0.02。
D.1、读入数据集的图像数据,对图片进行数据增强后将分辨率调整到32*32并对其归一化。
所述步骤D.1包括:
(1)读入数据集的图像数据,对图片进行预处理,即随机水平翻转、随机裁剪、边界部分用零填充后将分辨率调整到32*32并对其归一化;
(2)网络集合中的每一个网络的输入应当采用不同的随机预处理,即若网络集合中共设置了3个学生网络,对相同的输入图片做多种变换得到3个模型各自的输入,且3个输入是各不相同的,最后分辨率调整到32*32,使其尺寸能够输入给模型。
E.1、使用反思机制训练网络集合中的网络,设置网络训练的批量数为128,网络前向传播时需要将网络集成输出按照指定的更新策略更新到错题集中。
所述步骤E.1包括:
(1)设置网络训练批次为128,学习率为0.1,动量为0.9,权重衰减正则系数为0.0005,训练总轮数为300,在训练第150次和第225次时学习率衰减10倍,蒸馏温度系数设置为3.0,错题集的历史信息保留率γ为0.2;
(2)网络前向传播时,得到第i个网络的输出特征zi,将所有网络的输出平均得到集成输出特征ze,前向时需要将集成输出更新到错题集中,表达式表示为:
NB←γNB+(1-γ)ze
(3)反向传播更新网络参数时,通过训练图片的真实标签可以判断出错题集中记录的特征是否正确,如果记录的是正确的特征,就让网络集合中的每一个网络的输出特征zi的分布靠近错题集中记录的特征分布,如果记录的是错误的特征,就让zi的分布远离错题集中记录的特征分布,从而实现反思的目的。损失函数表示为:
Figure BDA0003501947440000081
其中,LHKD(NB,zi)为(Hinton G,Vinyals O,Dean J.Distilling the knowledgein a neural network[EB/OL].https://arxiv.org/abs/1503.02531)提出的蒸馏损失函数,LCE为交叉熵损失函数。
F.1、如果网络的验证精度达到某一设定好的条件时,调整训练时的反思系数。
所述步骤F.1包括:
(1)如果网络集合在验证集上的集成准确率满足设定好的条件时,按照设定好的策略调整反思系数μ,并使用新的反思系数进行新一轮的训练。
G.1、完成网络集合的训练,并将网络集合中性能最佳的模型进行测试部署。
所述步骤G.1包括:
(1)模型部署时,如果优先考虑准确率,则可以将网络集合的所有网络进行部署,取网络集合中所有网络输出的平均值作为最终的集成结果可以获得更加准确的识别效果,如果优先考虑推理速度,可以选择网络集合中验证准确率最高的模型进行部署,可以兼具推理效率和准确率。
如图2所述,本实施例对比了采用不同的知识蒸馏方法对图像分类模型进行压缩后得到的模型对CIFAR100数据集的所有图片进行分类的图像分类准确度,其中:
DML方法可参考“Zhang Y,Xiang T,Hospedales T M,et al.Deep mutuallearning[C]//Proceedings of the IEEE Conference on Computer Vision andPattern Recognition.2018:4320-4328.”;
ONE方法可参考“Lan X,Zhu X,Gong S.Knowledge distillation by on-the-flynative ensemble[C]//Proceedings of the 32nd International Conference onNeural Information Processing Systems.2018:7528-7538.”;
OKDDip方法可参考“Chen D,Mei J P,Wang C,et al.Online knowledgedistillation with diverse peers[C]//Proceedings of the AAAI Conference onArtificial Intelligence.2020,34(04):3430-3437.”;
KDCL方法可参考“Guo Q,Wang X,Wu Y,et al.Online knowledge distillationvia collaborative learning[C]//Proceedings of the IEEE/CVF Conference onComputer Vision and Pattern Recognition.2020:11020-11029.”。
上述四种方法和本申请方法均采用3个网络模型进行在线知识蒸馏以进行图像分类,上述四种方法中,DML方法采用了深度互学习的想法,在每一个训练批次中要轮流对每一个学生网络进行优化。在训练其中一个网络时,将其他网络的输出作为监督信息,指导当前待优化的网络进行训练。ONE方法使用一个门模块,对多个网络分支的输出进行加权平均构建集成的监督信息,然后利用集成的监督信息轮流对每一个网络分支进行优化。KDDip方法提出二级蒸馏的方法,第一级蒸馏是使用注意力机制对几个网络分支的输出进行加权构建监督信息,监督几个网络分支的优化,第二级蒸馏是将几个网络分支的平均输出作为监督信息,指导一个学生领导网络进行优化。KDCL方法集成了所有学生网络的输出结果,将此结果作为监督信息,轮流指导每一个学生网络进行优化。由此可知,这四种现有方法都只关注当前的监督信息,而忽略了历史记录所能提供的监督信息,因而其分类准确度无法得到进一步的提高;图2给出了这四种方法和本申请方法最终得到的图像分类准确度,由图2可知,本申请方法得出的图像分类准确度为74.21%,高于所列出的现有四种方法,而现有对于采用不同的知识蒸馏方法蒸馏得到的网络进行图像分类的精度的提高非常困难,可以看到在之前的研究中,最高值为73.79%,而本申请通过利用错题集模块保存网络集合集成输出的历史信息,并将其作为学生网络优化训练的监督信息,并提出一种反思机制使学生网络的输出分布远离错误的历史信息,靠近正确的历史信息,迫使学生网络学习到更优质的特征,从而在对图像进行分类时得到更准确的分类结果,使得图像分类准确度提高至74.21%。
本发明实施例中的部分步骤,可以利用软件实现,相应的软件程序可以存储在可读取的存储介质中,如光盘或硬盘等。
以上所述仅为本发明的较佳实施例,并不用以限制本发明,凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (8)

1.一种基于改进在线知识蒸馏算法的图像分类方法,其特征在于,所述方法包括:
Step1:确定在线知识蒸馏算法中的学生网络集合,并设置错题集以保存各学生网络训练过程中的集成输出特征,通过公开的图像分类数据集中的训练图片训练各学生网络以更新网络参数,得到训练好的学生网络集合;更新各学生网络参数时,通过训练图片的真实标签判断错题集中记录的集成输出特征是否正确,如果记录的是正确的特征,则更新网络参数以使得网络集合中的每一个网络的输出特征zi的分布靠近错题集中记录的特征分布,如果记录的是错误的特征,就让zi的分布远离错题集中记录的特征分布;
Step2:将待分类图像采用不同的随机变换得到变换后的图像以输入训练好的各学生网络中,得到各学生网络的输出特征,将所有学生网络的输出特征取平均得到对应于待分类图像的集成输出特征,根据待分类图像的集成输出特征对待分类图像进行分类。
2.根据权利要求1所述的方法,其特征在于,设公开的图像分类数据集包含N个训练图片,共C个类别,则初始化错题集为
Figure FDA0003501947430000011
的零矩阵;
Figure FDA0003501947430000012
为N×C维的实数集;
所述Step1中训练各学生网络以更新网络参数,包括:
对公开的图像分类数据集中的训练图片进行预处理;
使用反思机制训练学生网络集合中的各学生网络,设置网络训练批次为128,学习率为0.1,动量为0.9,权重衰减正则系数为0.0005,训练总轮数为300,在训练第150次和第225次时学习率衰减10倍,蒸馏温度系数设置为3.0,错题集的历史信息保留率γ为0.2;
各学生网络前向传播更新网络参数时,得到第i个学生网络的输出特征zi,将所有学生网络的输出平均得到集成输出特征ze,前向时将集成输出更新到错题集中,表达式表示为:
NB←γNB+(1-γ)ze
各学生网络反向传播更新网络参数时,通过训练图片的真实标签可以判断出错题集中记录的特征是否正确,如果记录的是正确的特征,则让学生网络集合中的每一个学生网络的输出特征zi的分布靠近错题集中记录的特征分布,如果记录的是错误的特征,则让zi的分布远离错题集中记录的特征分布;即各学生网络的损失函数设置为:
Figure FDA0003501947430000013
其中,LHKD(NB,zi)为蒸馏损失函数,LCE为交叉熵损失函数。
3.根据权利要求2所述的方法,其特征在于,所述使用反思机制训练学生网络集合中的各学生网络时,设置初始反思系数μ=μ0为0.005,设验证监视窗口大小w为10。
4.根据权利要求3所述的方法,其特征在于,当各学生网络完成一轮次的训练和验证后,记录各学生网络当前轮次的验证准确率,当各学生网络的验证准确率连续w轮超过第一阈值时,设置反思系数为第一预定值;当各学生网络的验证准确率连续w轮超过第二阈值时,设置反思系数为第二预定值。
5.根据权利要求4所述的方法,其特征在于,所述第一阈值为75%,对应的第一预定值为0.01,即令μ=2μ0为0.01。
6.根据权利要求5所述的方法,其特征在于,所述第二阈值为93.5%,对应的第二预定值为0.02,即令μ=5μ0为0.02。
7.根据权利要求6所述的方法,其特征在于,所述通过公开的图像分类数据集中的训练图片训练各学生网络以更新网络参数时,对不同的学生网络,采用不同的预处理对同一训练图片做变换后得到各学生网络的输入。
8.根据权利要求7所述的方法,其特征在于,所述预处理包括随机水平翻转、随机裁剪、边界部分用零填充后将分辨率调整到32*32并对其进行归一化处理。
CN202210183421.7A 2022-02-11 2022-02-11 一种基于改进在线知识蒸馏算法的图像分类方法 Pending CN114549905A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210183421.7A CN114549905A (zh) 2022-02-11 2022-02-11 一种基于改进在线知识蒸馏算法的图像分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210183421.7A CN114549905A (zh) 2022-02-11 2022-02-11 一种基于改进在线知识蒸馏算法的图像分类方法

Publications (1)

Publication Number Publication Date
CN114549905A true CN114549905A (zh) 2022-05-27

Family

ID=81679240

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210183421.7A Pending CN114549905A (zh) 2022-02-11 2022-02-11 一种基于改进在线知识蒸馏算法的图像分类方法

Country Status (1)

Country Link
CN (1) CN114549905A (zh)

Similar Documents

Publication Publication Date Title
CN110070183B (zh) 一种弱标注数据的神经网络模型训练方法及装置
US20240013856A1 (en) Splicing Site Classification Using Neural Networks
CN107480261B (zh) 一种基于深度学习细粒度人脸图像快速检索方法
Tjandra et al. Compressing recurrent neural network with tensor train
CN113420651B (zh) 深度卷积神经网络的轻量化方法、系统及目标检测方法
CN111898689A (zh) 一种基于神经网络架构搜索的图像分类方法
CN109255381B (zh) 一种基于二阶vlad稀疏自适应深度网络的图像分类方法
CN112766062B (zh) 一种基于双流深度神经网络的人体行为识别方法
CN111062410B (zh) 基于深度学习的星型信息桥气象预测方法
CN114898151A (zh) 一种基于深度学习与支持向量机融合的图像分类方法
CN114998659A (zh) 随时间在线训练脉冲神经网络模型的图像数据分类方法
CN116844041A (zh) 一种基于双向卷积时间自注意力机制的耕地提取方法
CN116229323A (zh) 一种基于改进的深度残差网络的人体行为识别方法
CN114638408A (zh) 一种基于时空信息的行人轨迹预测方法
CN117033985A (zh) 一种基于ResCNN-BiGRU的运动想象脑电分类方法
CN113590748B (zh) 基于迭代网络组合的情感分类持续学习方法及存储介质
CN114298224A (zh) 图像分类方法、装置以及计算机可读存储介质
CN114169385A (zh) 基于混合数据增强的mswi过程燃烧状态识别方法
CN113989566A (zh) 一种图像分类方法、装置、计算机设备和存储介质
Nandan et al. Handwritten digit recognition using ensemble learning
CN117272040A (zh) 一种基于元学习框架的小样本时间序列预测方法
CN116543289A (zh) 一种基于编码器-解码器及Bi-LSTM注意力模型的图像描述方法
CN114549905A (zh) 一种基于改进在线知识蒸馏算法的图像分类方法
CN113435588B (zh) 基于深度卷积神经网络bn层尺度系数的卷积核嫁接方法
CN115063374A (zh) 模型训练、人脸图像质量评分方法、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination