CN110991556B - 一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质 - Google Patents

一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质 Download PDF

Info

Publication number
CN110991556B
CN110991556B CN201911300279.4A CN201911300279A CN110991556B CN 110991556 B CN110991556 B CN 110991556B CN 201911300279 A CN201911300279 A CN 201911300279A CN 110991556 B CN110991556 B CN 110991556B
Authority
CN
China
Prior art keywords
student
model
training
distillation
images
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN201911300279.4A
Other languages
English (en)
Other versions
CN110991556A (zh
Inventor
冯于树
胡浩基
李卓远
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU filed Critical Zhejiang University ZJU
Priority to CN201911300279.4A priority Critical patent/CN110991556B/zh
Publication of CN110991556A publication Critical patent/CN110991556A/zh
Application granted granted Critical
Publication of CN110991556B publication Critical patent/CN110991556B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质。该方法包括:获取图像的训练集和测试集,并对训练集和测试集的图像分别进行类别标注;对所有图像进行预处理操作;将预处理后的图片批量送入一个神经网络,进行迭代训练,得到训练好的教师模型T;将预处理后的图片同时批量送入每个学生模型和教师模型T中,进行学生的合作蒸馏训练,得到合作蒸馏模型,其中每个学生模型为具有相同网络结构的神经网络,模型参数量小于教师模型T;将测试集输入合作蒸馏模型对图片进行分类。本发明提出的方法相比原始的方法在图像分类算法中的分类效果提升了3.6%。

Description

一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及 介质
技术领域
本发明实施例涉及计算机视觉领域,特别涉及一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质。
背景技术
随着信息技术的高速发展,深度学习技术在图像分类任务上的性能已经远远超越了传统的图像识别方法。深度学习将输入的图像通过一些简单的非线性的模型转变成为更加抽象的表达,所提取到的特征更加接近图像的高级语义信息。
深度卷积神经网络(Convolutional Neural Network,CNN)是特别设计用于识别图像的多层感知器。CNN的权重共享网络结构与生物神经网络类似,通过对图像进行多次的卷积核池化操作,逐渐提取到图像的高层表达,再使用神经网络对特征进行分类,以此来实现对图像分类的功能。因此CNN在图像分类领域表现出极大的优势。
然而,CNN的强大的表达能力是以内存和其他资源的消耗为代价的。大量的神经网络权重会消耗大量的内存和存储器带宽,阻碍其在图像分类任务中的应用。在资源受限的场合下,CNN的模型大小受到限制,对应地,CNN对图像分类的性能将会下降。
发明内容
为了解决上述问题,本发明实施例提供了一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质。在同一个教师的监督下,通过同时训练多个具有相同结构且模型内存占据较小的学生网络,使学生之间能够实现信息的交流,最大程度地增大每个学生所获取到的信息量,以此提高学生网络的性能,从而使得CNN模型在大小受到限制的情况下,仍能有高效的图像表达能力。
本发明的目的可以通过以下的技术方法实现:
第一方面,本发明实施例提供一种基于多学生合作蒸馏的高效图像分类方法,包括以下步骤:
获取图像的训练集和测试集,并对训练集和测试集的图像分别进行类别标注;
对所有图像进行预处理操作;
将预处理后的图片批量送入一个神经网络,进行迭代训练,得到训练好的教师模型T;
将预处理后的图片同时批量送入每个学生模型和教师模型T中,进行学生的合作蒸馏训练,得到合作蒸馏模型,其中每个学生模型为具有相同网络结构的神经网络,模型参数量小于教师模型T;
将测试集输入合作蒸馏模型对图片进行分类。
进一步的,所述对所有图像进行预处理操作中,对训练集图片的预处理操作为,首先以50%的概率将该图片进行水平翻转,然后以50%的概率将该图片顺时针旋转,最后进行训练图片的归一化,将每一张图片的像素都减去全部训练集图像的像素均值,然后将每一张图片的像素都除以全部训练集图像的像素的标准差。
进一步的,所述对所有图像进行预处理操作中,对测试集图片的预处理操作为,进行训练图片的归一化,将每一张图片的像素都减去全部训练集图像的像素均值,然后将每一张图片的像素都除以全部训练集图像的像素的标准差。
进一步的,所述合作蒸馏训练,包括:
(4.1)每次迭代训练中,首先将所有学生模型加入到学生模型集合{i,=1,2,…,}中;计算教师模型T的输出概率与每个学生模型的输出概率之间的KL散度,并按照KL散度的大小,对学生模型集合{i}进行降序排序;
(4.2)按顺序从学生模型集合中取出一个学生Sk,并计算以下值:
(4.2.1)计算该学生的输出概率与训练图片的标签之间的交叉熵
(4.2.2)计算该学生的输出概率与教师的输出概率之间的KL散度
(4.2.3)若此时学生模型为空,则跳过该步骤;否则在剩下的学生模型集合中,针对每个学生Si,计算Si给予学生Sk的知识N(i,k),并计算Si和学生Sk的差异M(i,k),将N(i,k)与M(i,k)相乘,并进行累加,得到值
(4.2.4)若此时学生模型为空,则跳过该步骤;否则在剩下的模型集合中,计算所有学生的输出向量的平均值,计算学生Sk的输出向量与该平均值之间的绝对值距离D(k),计算多样性损失e-D(k),得到值
(4.2.5)将和/>进行累加,得到值Lk,作为学生Sk在本次训练的损失值,进行学生Sk的梯度更新;
(4.3)若学生模型集合为空,则本次迭代训练结束;否则重复步骤(4.2);
(4.4)当迭代次数达到预设值后,结束所有学生的训练。
进一步的,所述步骤(4.1)、(4.2.1)、(4.2.2)中,模型的输出概率具体为,图片经过神经网络模型后,得到最后一层的输出,然后经过softmax层,得到输出概率。
进一步的,所述步骤(4.2.4)中,模型的输出向量具体为,图片经过神经网络模型后,最后一层的输出。
进一步的,所述步骤(4.2.3)中,学生Si给予学生Sk的知识N(i,k)具体为,学生Sk的输出概率与学生Si的输出概率之间的KL散度。
进一步的,所述步骤(4.2.3)中,学生Si和学生Sk的差异M(i,k)具体为,学生Si的输出向量与学生Sk的输出向量之间的欧氏距离。
进一步的,所述步骤(4.2.5)中,Lk的计算具体为,其中,α、β和γ是需要手动设置的超参数,取值范围为0~1。
进一步的,所述步骤(5)中,选取一个学生模型作为最终的图像分类器,具体为,选择第一个学生模型,删除剩下的所有学生。
第二方面,本发明实施例提供一种基于多学生合作蒸馏的高效图像分类装置,包括:
获取标注模块,用于获取图像的训练集和测试集,并对训练集和测试集的图像分别进行类别标注;
预处理模块,用于对所有图像进行预处理操作;
教师模型建立模块,用于将预处理后的图片批量送入一个神经网络,进行迭代训练,得到训练好的教师模型T;
合作蒸馏模型建立模块,用于将预处理后的图片同时批量送入每个学生模型和教师模型T中,进行学生的合作蒸馏训练,得到合作蒸馏模型,其中每个学生模型为具有相同网络结构的神经网络,模型参数量小于教师模型T;
分类模块,用于将测试集输入合作蒸馏模型对图片进行分类。
第三方面,本发明实施例提供一种设备,包括:
一个或多个处理器;
存储器,用于存储一个或多个程序;
当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如第一方面所述的一种基于多学生合作蒸馏的高效图像分类方法。
第四方面,本发明实施例提供一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如第一方面所述的一种基于多学生合作蒸馏的高效图像分类方法。
采用上述技术方案,本发明实施例具有如下优点:
(1)本发明属于利用深度卷积神经网络进行图像分类的方法,相比于传统的图像分类方法,本发明能够实现更好的分类结果。本发明提出的方法相比原始的方法在图像分类算法中的分类效果提升了3.6%。
(2)本发明在蒸馏过程中让多个学生模型之间进行信息的交互,使得每个学生模型都能获取到其他学生模型从输入图像中提取到的高级语义特征,进而更加显著地提高图像分类的效率。
(3)本发明在蒸馏过程中通过设置学生模型的多样性损失,让每个学生模型适当地提取到不同于其他学生模型所提取到的图像特征,从而最大限度地提高学生模型交互时的信息量,进而更加显著地提高图像分类的效率。
(4)本发明提供的方法适用于绝大多数卷积神经网络的性能提高。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本发明的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1为本发明实施例中整体方法流程图;
图2为本发明实施例中多学生合作蒸馏的流程图。
图3为本实施例2中提供的一种基于多学生合作蒸馏的高效图像分类装置的结构示意图;
图4为本发明实施例3提供的一种设备的结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合本申请具体实施例及相应的附图对本申请技术方案进行清楚、完整地描述。显然,所描述的实施例仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
实施例1:
图1为本发明实施例中整体方法流程图;本发明实施例提供了一种基于多学生合作蒸馏的高效图像分类方法,该方法包括如下步骤:
S100、获取图像的训练集和测试集,并对训练集和测试集的图像分别进行类别标注;
具体的,准备数据集、网络训练框架等,本实施例采用开源的、已提供类别标注的CIFAR-100数据集,其中训练集为50000张、100类图像,测试集为10000张、100类图像。(CIFAR100下载链接:http://www.cs.toronto.edu/~kriz/cifar.html)。所使用的网络训练框架为PyTorch框架。
S200、对所有图像进行预处理。
具体为:首先将属于同一类的图片放在同一个文件夹下。对每一张训练图像,首先以50%的概率将该图片进行水平翻转,然后以50%的概率将该图片顺时针旋转15度,最后进行训练图片的归一化,将每一张图片的像素都减去全部训练集图像的像素均值,然后将每一张图片的像素都除以全部训练集图像的像素的标准差。对每一张测试图像,进行训练图片的归一化,将每一张图片的像素都减去全部训练集图像的像素均值,然后将每一张图片的像素都除以全部训练集图像的像素的标准差。
S300、将步骤S200预处理后的图片批量送入一个神经网络,进行迭代训练,得到训练好的教师模型T;
S400、将预处理后的图片同时批量送入每个学生模型和教师模型T中,进行学生的合作蒸馏训练,得到合作蒸馏模型,其中每个学生模型为具有相同网络结构的神经网络,模型参数量小于教师模型T;
具体的,设置总的训练次数为200个周期,迭代次数初始化为0。随机初始化三个学生模型,其网络结构均为ResNet-18。然后将步骤S200预处理后的图像训练集分别输送到每个学生模型和预训练好的教师模型T中去,通过迭代训练的方式使每个学生模型都能学习到每个类别的特征,并在迭代训练的过程中进行教师模型T指导下的三个学生模型之间的合作蒸馏,使得每个学生模型都能提高分类性能。对于三个学生的合作蒸馏部分,其流程图如图2所示,主要包含以下步骤:
S410、在每次迭代训练中,将三个学生模型加入到学生模型集合S中,计算每个学生的输出概率与教师模型T的输出概率之间的KL散度,并按照KL散度的大小,对模型集合进行降序排序,得到S={S1,S2,S3}。迭代次数加1。
S420、按顺序取出一个学生模型,为S1,此时学生模型集合S={S2,S3},初始化四个值为0,并计算:
S4201、计算该学生的输出概率与训练图片的标签之间的交叉熵
S4202、计算该学生的输出概率与教师的输出概率之间的KL散度
S4203、判断学生集合是否为空,是则跳到步骤S4206,否则进入步骤S4204。
S4204、剩下的模型集合S不为空,因此在剩下的学生模型集合S={S2,S3}中,针对每个学生Si,计算学生模型Si的输出概率与学生模型S1的输出概率之间的KL散度,记作N(i,1),并计算学生模型Si的输出向量和学生模型S1的输出向量的欧氏距离,记作M(i,1),将N(i,1)与M(i,1)相乘,并进行累加,得到值即此时/>
S4205、剩下的模型集合S不为空,因此在剩下的模型集合S={S2,S3}中,取所有学生模型的输出向量的平均值,计算学生S1的输出向量与该平均值之间的绝对值距离D(1),计算多样性损失e-D(1),得到值即/>
S4206、计算在本实验中,α取0.1,β取0.9,γ取0.7。将L1作为学生模型S1本次迭代过程中的损失值,并以此更新学生S1的权重。
S430、判断学生模型是否为空,是则结束本次迭代训练,否则重复步骤S420。
S440、判断迭代次数是否达到200个周期,若是则结束训练过程,否则重复步骤S410。
S500、保留第一个学生模型,作为新的图像分类器,并删除剩下两个模型。进行测试集在该学生网络的分类。本实验中,选用CIFAR-100测试集的10000张、共100类图片作为新的图像数据,得到分类结果。
实验结果表明,原始的ResNet-18在CIFAR-100测试集上的分类误差为24.39%(数据来源:https://github.com/weiaicunzai/pytorch-cifar100),经本发明方法训练得到的ResNet-18在CIFAR-100测试集上的分类误差为20.79%,与原来相比,模型的分类性能提升了3.6%。因此,本发明能够显著提升图像分类效率,使得在模型大小受限的情况下,仍能有优异的分类效果。
本发明在使用一个教师模型的监督下同时训练多个具有相同结构的学生网络,使学生模型之间实现信息的交流,以此提高学生模型的分类性能,从而使得神经网络在大小受到限制的情况下仍有高效的图像分类能力。本发明提出的方法相比原始的方法在图像分类算法中的分类效果提升了3.6%。
实施例2:
图3为本实施例2中提供的一种基于多学生合作蒸馏的高效图像分类装置的结构示意图,该装置底层基于互联网网络,该装置可以执行任意本发明任意实施例所提供的一种基于多学生合作蒸馏的高效图像分类方法,具备执行该方法相应的功能模块和有益效果。如图3所示,该装置包括:
获取标注模块,用于获取图像的训练集和测试集,并对训练集和测试集的图像分别进行类别标注;
预处理模块,用于对所有图像进行预处理操作;
教师模型建立模块,用于将预处理后的图片批量送入一个神经网络,进行迭代训练,得到训练好的教师模型T;
合作蒸馏模型建立模块,用于将预处理后的图片同时批量送入每个学生模型和教师模型T中,进行学生的合作蒸馏训练,得到合作蒸馏模型,其中每个学生模型为具有相同网络结构的神经网络,模型参数量小于教师模型T;
分类模块,用于将测试集输入合作蒸馏模型对图片进行分类。
实施例3:
图4为本发明实施例3提供的一种设备的结构示意图。图4示出了适于用来实现本发明实施方式的示例性设备1的框图。图4显示的设备仅仅是一个示例,不应对本发明实施例的功能和使用范围带来任何限制。
如图4所示,设备1以通用计算设备的形式表现。设备1的组件可以包括但不限于:一个或者多个处理器或者处理单元2,存储器3,连接不同系统组件(包括存储器3和处理单元2)的总线4。
总线4表示几类总线结构中的一种或多种,包括存储器总线或者存储器控制器,外围总线,图形加速端口,处理器或者使用多种总线结构中的任意总线结构的局域总线。举例来说,这些体系结构包括但不限于工业标准体系结构(ISA)总线,微通道体系结构(MAC)总线,增强型ISA总线、视频电子标准协会(VESA)局域总线以及外围组件互连(PCI)总线。
设备1典型地包括多种计算机系统可读介质。这些介质可以是任何能够被设备1访问的可用介质,包括易失性和非易失性介质,可移动的和不可移动的介质。
存储器3可以包括易失性存储器形式的计算机系统可读介质,例如随机存取存储器(RAM)5和/或高速缓存存储器6。设备1可以进一步包括其它可移动/不可移动的、易失性/非易失性计算机系统存储介质。仅作为举例,存储系统8可以用于读写不可移动的、非易失性磁介质(图3未显示,通常称为“硬盘驱动器”)。尽管图4中未示出,可以提供用于对可移动非易失性磁盘(例如“软盘”)读写的磁盘驱动器,以及对可移动非易失性光盘(例如CDROM,DVD-ROM或者其它光介质)读写的光盘驱动器。在这些情况下,每个驱动器可以通过一个或者多个数据介质接口与总线4相连。存储器3可以包括至少一个程序产品,该程序产品具有一组(例如至少一个)程序模块,这些程序模块被配置以执行本发明各实施例的功能。
具有一组(至少一个)程序模块8,可以存储在例如存储器3中,这样的程序模块8包括但不限于操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。程序模块8通常执行本发明所描述的实施例中的功能和/或方法。
设备1也可以与一个或多个外部设备10(例如键盘、指向设备、显示设备9等)通信,还可与一个或者多个使得用户能与该设备1交互的设备通信,和/或与使得该设备1能与一个或多个其它计算设备进行通信的任何设备(例如网卡,调制解调器等等)通信。这种通信可以通过输入/输出(I/O)接口11进行。并且,设备1还可以通过网络适配器12与一个或者多个网络(例如局域网(LAN),广域网(WAN)和/或公共网络,例如因特网)通信。如图4所示,网络适配器12通过总线4与设备1的其它模块通信。应当明白,尽管图4中未示出,可以结合设备1使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元、外部磁盘驱动阵列、RAID系统、磁带驱动器以及数据备份存储系统等。
处理单元2通过运行存储在存储器3中的程序,从而执行各种功能应用以及数据处理,例如实现本发明实施例所提供的基于互联网的多学生合作蒸馏的高效图像分类方法。
实施例4
本发明实施例4还提供了一种计算机可读存储介质,其上存储有计算机程序(或称为计算机可执行指令),该程序被处理器执行时用于执行一种基于多学生合作蒸馏的高效图像分类方法,该方法为实施例1所述的方法。
本发明实施例的计算机存储介质,可以采用一个或多个计算机可读的介质的任意组合。计算机可读介质可以是计算机可读信号介质或者计算机可读存储介质。计算机可读存储介质例如可以是——但不限于——电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CDROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本文件中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
计算机可读的信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。计算机可读的信号介质还可以是计算机可读存储介质以外的任何计算机可读介质,该计算机可读介质可以发送、传播或者传输用于由指令执行系统、装置或者器件使用或者与其结合使用的程序。
计算机可读介质上包含的程序代码可以用任何适当的介质传输,包括——但不限于无线、电线、光缆、RF等等,或者上述的任意合适的组合。
可以以一种或多种程序设计语言或其组合来编写用于执行本发明操作的计算机程序代码,所述程序设计语言包括面向对象的程序设计语言—诸如Java、Smalltalk、C++,还包括常规的过程式程序设计语言—诸如”C”语言或类似的程序设计语言。程序代码可以完全地在用户计算机上执行、部分地在用户计算机上执行、作为一个独立的软件包执行、部分在用户计算机上部分在远程计算机上执行、或者完全在远程计算机或服务器上执行。在涉及远程计算机的情形中,远程计算机可以通过任意种类的网络——包括局域网(LAN)或广域网(WAN)—连接到用户计算机,或者,可以连接到外部计算机(例如利用因特网服务提供商来通过因特网连接)。
本发明通过上述实施例来说明本发明的详细方法,但本发明并不局限于上述详细方法,即不意味着本发明必须依赖上述详细方法才能实施。所属技术领域的技术人员应该明了,对本发明的任何改进,对本发明产品各原料的等效替换及辅助成分的添加、具体方式的选择等,均落在本发明的保护范围和公开范围之内。

Claims (9)

1.一种基于多学生合作蒸馏的高效图像分类方法,其特征在于,该方法包括:
获取图像的训练集和测试集,并对训练集和测试集的图像分别进行类别标注;
对所有图像进行预处理操作;
将预处理后的图片批量送入一个神经网络,进行迭代训练,得到训练好的教师模型T;
将预处理后的图片同时批量送入每个学生模型和教师模型T中,进行学生的合作蒸馏训练,得到合作蒸馏模型,其中每个学生模型为具有相同网络结构的神经网络,模型参数量小于教师模型T;
所述合作蒸馏训练,包括:
(1)每次迭代训练中,先将所有学生模型加入到学生模型集合{Si,i=1,2,…,N}中;计算教师模型T的输出概率与每个学生模型的输出概率之间的KL散度,并按照KL散度的大小,对学生模型集合{Si}进行降序排序;
(2)按顺序从学生模型集合中取出一个学生Sk,并计算以下值:
(2.1)计算该学生的输出概率与训练图片的标签之间的交叉熵
(2.2)计算该学生的输出概率与教师的输出概率之间的KL散度
(2.3)若此时学生模型为空,则跳过该步骤;否则在剩下的学生模型集合中,针对每个学生Si,计算Si给予学生Sk的知识N(i,k),并计算Si和学生Sk的差异M(i,k),将N(i,k)与M(i,k)相乘,并进行累加,得到值
(2.4)若此时学生模型为空,则跳过该步骤;否则在剩下的模型集合中,计算所有学生的输出向量的平均值,计算学生Sk的输出向量与该平均值之间的绝对值距离D(k),计算多样性损失e-D(k),得到值
(2.5)将和/>进行累加,得到值Lk,作为学生Sk在本次训练的损失值,进行学生Sk的梯度更新;
(3)若学生模型集合为空,则本次迭代训练结束;否则重复步骤(4.2);
(4)当迭代次数达到预设值后,结束所有学生的训练;
将测试集输入合作蒸馏模型对图片进行分类。
2.根据权利要求1所述的一种基于多学生合作蒸馏的高效图像分类方法,其特征在于,所述对所有图像进行预处理操作中,对训练集图片的预处理操作为,首先以50%的概率将该图片进行水平翻转,然后以50%的概率将该图片顺时针旋转,最后进行训练图片的归一化,将每一张图片的像素都减去全部训练集图像的像素均值,然后将每一张图片的像素都除以全部训练集图像的像素的标准差。
3.根据权利要求1所述的一种基于多学生合作蒸馏的高效图像分类方法,其特征在于,所述对所有图像进行预处理操作中,对测试集图片的预处理操作为,进行训练图片的归一化,将每一张图片的像素都减去全部训练集图像的像素均值,然后将每一张图片的像素都除以全部训练集图像的像素的标准差。
4.根据权利要求1所述的一种基于多学生合作蒸馏的高效图像分类方法,其特征在于,所述步骤(2.3)中,学生Si给予学生Sk的知识N(i,k)具体为,学生Sk的输出概率与学生Si的输出概率之间的KL散度。
5.根据权利要求1所述的一种基于多学生合作蒸馏的高效图像分类方法,其特征在于,所述步骤(2.3)中,学生Si和学生Sk的差异M(i,k)具体为,学生Si的输出向量与学生Sk的输出向量之间的欧氏距离。
6.根据权利要求1所述的一种基于多学生合作蒸馏的高效图像分类方法,其特征在于,所述步骤(2.5)中,Lk的计算具体为,其中,α、β和γ是需要手动设置的超参数。
7.一种基于多学生合作蒸馏的高效图像分类装置,其特征在于,包括:
获取标注模块,用于获取图像的训练集和测试集,并对训练集和测试集的图像分别进行类别标注;
预处理模块,用于对所有图像进行预处理操作;
教师模型建立模块,用于将预处理后的图片批量送入一个神经网络,进行迭代训练,得到训练好的教师模型T;
合作蒸馏模型建立模块,用于将预处理后的图片同时批量送入每个学生模型和教师模型T中,进行学生的合作蒸馏训练,得到合作蒸馏模型,其中每个学生模型为具有相同网络结构的神经网络,模型参数量小于教师模型T;
所述合作蒸馏训练,包括:
(1)每次迭代训练中,先将所有学生模型加入到学生模型集合{Si,i=1,2,…,N}中;计算教师模型T的输出概率与每个学生模型的输出概率之间的KL散度,并按照KL散度的大小,对学生模型集合{Si}进行降序排序;
(2)按顺序从学生模型集合中取出一个学生Sk,并计算以下值:
(2.1)计算该学生的输出概率与训练图片的标签之间的交叉熵
(2.2)计算该学生的输出概率与教师的输出概率之间的KL散度
(2.3)若此时学生模型为空,则跳过该步骤;否则在剩下的学生模型集合中,针对每个学生Si,计算Si给予学生Sk的知识N(i,k),并计算Si和学生Sk的差异M(i,k),将N(i,k)与M(i,k)相乘,并进行累加,得到值
(2.4)若此时学生模型为空,则跳过该步骤;否则在剩下的模型集合中,计算所有学生的输出向量的平均值,计算学生Sk的输出向量与该平均值之间的绝对值距离D(k),计算多样性损失e-D(k),得到值
(2.5)将和/>进行累加,得到值Lk,作为学生Sk在本次训练的损失值,进行学生Sk的梯度更新;
(3)若学生模型集合为空,则本次迭代训练结束;否则重复步骤(4.2);
(4)当迭代次数达到预设值后,结束所有学生的训练;
分类模块,用于将测试集输入合作蒸馏模型对图片进行分类。
8.一种电子设备,其特征在于,包括:
一个或多个处理器;
存储器,用于存储一个或多个程序;
当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如权利要求1-6任一项所述的一种基于多学生合作蒸馏的高效图像分类方法。
9.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如权利要求1-6中任一项所述的一种基于多学生合作蒸馏的高效图像分类方法。
CN201911300279.4A 2019-12-16 2019-12-16 一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质 Active CN110991556B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201911300279.4A CN110991556B (zh) 2019-12-16 2019-12-16 一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201911300279.4A CN110991556B (zh) 2019-12-16 2019-12-16 一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质

Publications (2)

Publication Number Publication Date
CN110991556A CN110991556A (zh) 2020-04-10
CN110991556B true CN110991556B (zh) 2023-08-15

Family

ID=70094588

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201911300279.4A Active CN110991556B (zh) 2019-12-16 2019-12-16 一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质

Country Status (1)

Country Link
CN (1) CN110991556B (zh)

Families Citing this family (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111553298B (zh) * 2020-05-07 2021-02-05 卓源信息科技股份有限公司 一种基于区块链的火灾识别方法及系统
CN112396923B (zh) * 2020-11-25 2023-09-19 贵州轻工职业技术学院 一种市场营销的教学模拟系统
CN112528109B (zh) * 2020-12-01 2023-10-27 科大讯飞(北京)有限公司 一种数据分类方法、装置、设备及存储介质
CN113326768B (zh) * 2021-05-28 2023-12-22 浙江商汤科技开发有限公司 训练方法、图像特征提取方法、图像识别方法及装置
CN113610069B (zh) * 2021-10-11 2022-02-08 北京文安智能技术股份有限公司 基于知识蒸馏的目标检测模型训练方法
CN113888538B (zh) * 2021-12-06 2022-02-18 成都考拉悠然科技有限公司 一种基于内存分块模型的工业异常检测方法
CN115203419A (zh) * 2022-07-21 2022-10-18 北京百度网讯科技有限公司 语言模型的训练方法、装置及电子设备

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2018126213A1 (en) * 2016-12-30 2018-07-05 Google Llc Multi-task learning using knowledge distillation
CN110232411A (zh) * 2019-05-30 2019-09-13 北京百度网讯科技有限公司 模型蒸馏实现方法、装置、系统、计算机设备及存储介质

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
US11195093B2 (en) * 2017-05-18 2021-12-07 Samsung Electronics Co., Ltd Apparatus and method for student-teacher transfer learning network using knowledge bridge

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2018126213A1 (en) * 2016-12-30 2018-07-05 Google Llc Multi-task learning using knowledge distillation
CN110232411A (zh) * 2019-05-30 2019-09-13 北京百度网讯科技有限公司 模型蒸馏实现方法、装置、系统、计算机设备及存储介质

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
自适应性多教师多学生知识蒸馏学习;宋迦陵;《CNKI优秀硕士学位论文全文库(社会科学II辑)》;20190915;H127-50,正文第25、37-39、41页 *

Also Published As

Publication number Publication date
CN110991556A (zh) 2020-04-10

Similar Documents

Publication Publication Date Title
CN110991556B (zh) 一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质
US9990558B2 (en) Generating image features based on robust feature-learning
CN110674880A (zh) 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN111767366B (zh) 问答资源挖掘方法、装置、计算机设备及存储介质
US11640551B2 (en) Method and apparatus for recommending sample data
CN110598620B (zh) 基于深度神经网络模型的推荐方法和装置
CN113128478B (zh) 模型训练方法、行人分析方法、装置、设备及存储介质
US11030750B2 (en) Multi-level convolutional LSTM model for the segmentation of MR images
US20220092407A1 (en) Transfer learning with machine learning systems
CN111753863A (zh) 一种图像分类方法、装置、电子设备及存储介质
CN112052840B (zh) 图片筛选方法、系统、设备及存储介质
CN112381079A (zh) 图像处理方法和信息处理设备
CN111368878A (zh) 一种基于ssd目标检测的优化方法、计算机设备和介质
CN111667027A (zh) 多模态图像的分割模型训练方法、图像处理方法及装置
CN115238909A (zh) 一种基于联邦学习的数据价值评估方法及其相关设备
CN110781849A (zh) 一种图像处理方法、装置、设备及存储介质
CN113705215A (zh) 一种基于元学习的大规模多标签文本分类方法
CN112966743A (zh) 基于多维度注意力的图片分类方法、系统、设备及介质
CN113435499A (zh) 标签分类方法、装置、电子设备和存储介质
CN112966754A (zh) 样本筛选方法、样本筛选装置及终端设备
CN111259932A (zh) 分类方法、介质、装置和计算设备
CN111445545A (zh) 一种文本转贴图方法、装置、存储介质及电子设备
CN115544210A (zh) 基于持续学习的事件抽取的模型训练、事件抽取的方法
CN111062477B (zh) 一种数据处理方法、装置及存储介质
CN114185657A (zh) 一种云平台的任务调度方法、装置、存储介质及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant