CN114049513A - 一种基于多学生讨论的知识蒸馏方法和系统 - Google Patents

一种基于多学生讨论的知识蒸馏方法和系统 Download PDF

Info

Publication number
CN114049513A
CN114049513A CN202111120541.4A CN202111120541A CN114049513A CN 114049513 A CN114049513 A CN 114049513A CN 202111120541 A CN202111120541 A CN 202111120541A CN 114049513 A CN114049513 A CN 114049513A
Authority
CN
China
Prior art keywords
student
network
teacher
distillation
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111120541.4A
Other languages
English (en)
Inventor
王蕊
刘俐君
吕飞霄
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Institute of Information Engineering of CAS
Original Assignee
Institute of Information Engineering of CAS
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Institute of Information Engineering of CAS filed Critical Institute of Information Engineering of CAS
Priority to CN202111120541.4A priority Critical patent/CN114049513A/zh
Publication of CN114049513A publication Critical patent/CN114049513A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于多学生讨论的知识蒸馏方法和系统。该方法的步骤包括:1)选取复杂网络ResNet32x4作为知识蒸馏的教师模型,对教师模型进行预训练;2)进行知识蒸馏,采用单老师多学生的蒸馏模式,多个小型学生网络的参数分别初始化并独立训练,分别学习来自教师网络的知识;3)借助讨论模块使得学生网络相互讨论,以各个学生模型的logits输出作为其输入,采用多层卷积神经网络,将各个学生网络的输出耦合在一起,输出最终的类别预测;4)将待分类的图像输入学生网络,再经过学生之间的讨论得到最终的图像分类结果。本发明大大提高了图像分类的准确率,并改善了知识蒸馏领域中师生模型表达能力差异较大的情况。

Description

一种基于多学生讨论的知识蒸馏方法和系统
技术领域
本发明属于计算机视觉技术领域,具体涉及一种基于多学生讨论的知识蒸馏方法和系统。
背景技术
随着算力的提升和大规模数据集的广泛出现,深度模型取得了很大的成功,尤其在图像以及语音识别等的任务上。但是大多数的深度学习模型都包含大量的参数,又深又广的模型在训练时需要消耗大量的计算资源,而且部署模型时依然存在很高的存储以及计算要求。因此,为了获得更快的计算速度,深度模型的压缩成为了近来的研究热点。其中知识蒸馏是一种模型压缩的有效方法,它旨在将复杂的模型或模型集合压缩为较小模型以进行部署。当复杂的模型被训练好之后,可以用来指导轻量级模型的学习,从而应用于实时场景,近年来知识蒸馏被广泛应用于自然语言处理领域和计算机视觉领域。
具体而言,知识蒸馏(Knowledge Distillation)旨在将知识从复杂的深度模型(教师模型)转移到轻量级的模型(学生模型),一般而言前者具有更强的学习和表示能力,同时性能也更高,而后者的计算复杂度低,便于在边缘设备上部署。Hinton等人在2015年首次提出了知识蒸馏的概念,为了进行模型之间暗知识的传递,将学生模型的目标设定为最小化教师输出和学生输出的Kullback-Leibler(KL)散度,学生模型通过模仿教师模型的软目标来提升自身性能。这种软目标的蒸馏方式之所以有效,是因为教师模型赋予不同类别的相对概率为学生模型的训练提供了丰富的信息。
上述基于软目标的蒸馏属于传统知识蒸馏的一种重要方法,此外近年来还出现了很多不同的方法尝试各种知识的迁移形式。例如通过学习教师模型的中间表示、解决问题的流程、注意力图、结构关系、激活图的相似度等以促进学生网络的优化过程。然而,所有的这些方法都是从单个教师模型中提取知识到学生网络,这会导致所学知识的单一甚至产生偏见。为了进一步提高学生网络在部署时的性能,最近的一些研究提出在蒸馏时使用多个教师模型。最简单的思路就是直接将多个教师软目标的均值作为学生网络的学习的指导,每个教师网络被分配了相同的权重。更进一步地,可以利用多个教师模型的加权平均值来指导学生网络的训练,其中权重作为超参数在训练过程中始终不变,并且不同教师网络被赋予不同的权重。为了进一步优化权重的赋值,又有学者提出采用强化学习的方法动态地为教师模型分配权重,以优化学生模型的性能。除了基于权重角度的研究,有学者提出了一种基于噪声的正则化方法,来模拟多位教师的学习。此外,多教师教学的蒸馏架构还可以跟不同的任务相结合来解决不同的问题。例如,可以通过将每个源域与教师相关联来研究域适应任务,以便以后进行蒸馏学习。在多任务学习领域,可以将多教师学习应用于多任务学习,其中每个教师对应一个任务。
总之,知识蒸馏作为一种有效的深度神经网络压缩和加速技术,已广泛应用于人工智能的不同领域,包括计算机视觉、语音识别、自然语言处理以及推荐系统等。具体到计算机视觉领域,知识蒸馏旨在为各种不同的视觉识别任务提供高效和有效的师生学习,使得轻量级学生网络得以在边缘设备部署。然而由于教师模型和学生模型的规模不同,其表示能力亦有较大差距,导致学生模型的性能依然不高。
发明内容
为了克服现有的知识蒸馏方法存在的教师和学生间知识表示的差距问题,本发明的目的在于提供一种基于多学生讨论的知识蒸馏方法,进一步提升学生网络的性能。
本发明首先利用不同的学生网络产生对同一张图像的不同预测,然后将这些结果送入讨论器网络,经过讨论不同的网络协商一致,最终得到对于该图像所属类别的准确预测。具体地,为产生多样化且性能优异的学生网络,首先借助于强大的教师网络对多个学生网络进行知识蒸馏的训练,然后将多个学生网络产生的预测结果输入讨论器(讨论模块),接着对讨论器进行图像分类的训练,使得讨论器同时也作为图像分类器输出最终的预测结果。由于讨论阶段的加入使得学生间能够优势互补,较好地解决了单一的教师学生网络之间表示能力的差距问题,最终得到更准确的分类结果。
本发明采用的技术方案如下:
一种基于多学生讨论的知识蒸馏方法,其步骤包括:
1)教师网络的预训练,知识蒸馏的起点是良好的教师网络,一般而言,教师网络需要选择规模较大的复杂网络,以便于更好地拟合复杂数据;
2)多学生网络的训练,采用单老师多学生的蒸馏模式,多个小型学生网络的参数分别初始化,独立训练分别学习来自教师网络的知识;
3)学生网络相互讨论,讨论模块完成讨论功能,以各个学生网络的logits输出作为其输入,采用多层卷积神经网络,将各个学生网络的输出耦合在一起,输出最终的类别预测。
4)将待分类的图像输入学生网络,再经过学生之间的讨论得到最终的准确率更高的分类结果。
进一步的,所述阶段1)教师网络的预训练,一个普通的知识蒸馏框架通常包含一个或多个大型预训练教师网络和一个小型学生网络。教师网络通常比学生网络大得多,需要提前在大型数据集上进行预训练,本发明优选使用ResNet32x4网络作为教师网络在CIFAR100数据集上进行预训练。
进一步的,所述阶段2)多学生网络的训练,在一般的知识蒸馏框架下,通常包含一个强大的教师网络和一个较小的学生网络,主要思想是在教师网络的指导下训练一个有效的学生网络以获得可比较的准确性。来自教师网络的监督信息,通常称为教师网络学到的“知识”,它可以帮助学生网络模仿教师网络的行为。但是传统的蒸馏模式只包含一个单一的学生网络,通常会由于模型过拟合以及模型表达能力不足的原因导致部署效果较差。因此,本发明采用单老师多学生的蒸馏模式,解决了单学生的蒸馏架构引起的表达能力不足的问题,缩小了教师网络和学生网络之间的性能差距。
进一步的,所述阶段3)学生网络相互讨论,基于上述步骤,得到了多个训练良好的学生网络,而这些网络通常是多样化的,为了充分发挥每个学生网络的优势,在部署阶段做到取长补短,本发明进一步设计讨论模块。讨论模块以各个学生网络的logits输出作为其输入,采用卷积神经网络(CNN)的架构,将各个学生网络的输出耦合在一起,利用CNN网络包含的卷积层、池化层以及其非线性激活的特性,捕获各个学生网络的优势类别,最终输出更加准确的类别预测。
进一步的,所述阶段4)获得待分类图像的预测结果,经过上述阶段2)和阶段3)的训练,已经得到了完整的网络架构以及与之相匹配的网络参数,为了在轻量级的学生网络上获得较好的图像分类效果,本阶段采用经过蒸馏的多学生网络以及讨论模块相结合的网络结构,以待分类图像作为输入,图像经多学生网络,获得多样化的中间结果,之后将中间结果送入讨论模块,讨论模块输出的预测结果即为最终的分类结果,经实验验证表明,上述过程所得分类结果较其他方法性能更优。
进一步的,对于所述阶段2)中每个学生网络的蒸馏过程,本发明中转移暗知识的方法表述如下:定义教师深度网络最后一个全连接层的输出的logits向量为z,令zi表示第i个类别的logit,那么输入属于第i个类别的概率pi可以是由softmax函数估计:
Figure BDA0003276950060000031
因此,教师网络获得的软目标的预测包含暗知识,可以作为监督信息将知识从教师网络转移到学生网络。
进一步的,引入温度因子T来控制每个软目标的重要性:
Figure BDA0003276950060000041
其中较高的温度会产生更柔和的概率分布。具体来说,当T→∞时,所有类共享相同的概率。当T→0时,软目标成为one-hot的标签,即硬目标。来自教师网络的软目标和真实标签对于提高学生网络的性能都非常重要,分别作为蒸馏损失和学生损失共同指导学生网络的优化。
优选地,本发明基于多学生讨论的知识蒸馏方法主要包括以下步骤:
1)使用常规的图像分类数据集CIFAR100作为训练数据集,选取复杂网络作为知识蒸馏的教师网络,对教师网络进行预训练;
2)多学生网络的训练,采用不同的学生网络,包括:ResNet20、ResNet32、ResNet8x4,在每个学生单独蒸馏的过程中采用logits蒸馏形式;
3)学生网络相互讨论,讨论模块采用多层卷积神经网络,具体包含了卷积层、池化层、全连接层以及非线性激活函数,将多个学生的logits输入讨论模块,产生图片的分类结果;
4)将待分类的图像输入经过训练的学生网络,再经过已经训练好的讨论模块输出最终的准确率更高的分类结果。
进一步的,步骤1)教师网络的预训练阶段,采用典型图像分类损失,即交叉熵(Cross Entropy)损失对ResNet32x4的教师网络进行训练,所述交叉熵损失的形式化表示为:
Figure BDA0003276950060000042
其中p(x)表示图像one-hot形式的真实标签分布,q(x)表示教师网络预测的分布。
进一步的,步骤2)多学生的蒸馏,在学生网络的选择上,为了增加学生网络的多样性,本发明选取架构不同的小型学生网络,并且各个网络的参数分别初始化。在学生网络的蒸馏过程中,各个网络分别学习来自教师网络的知识,因此训练过程相互独立。一方面,独立的训练使得学生之间的多样性进一步增强,另一方面,独立训练语允许多个学生网络并行训练,大大缩短了训练时间。进一步的,上述教师网络只有一个,其知识由学生网络共享,这大大简化了步骤1)中教师网络的预训练过程,减少了计算量以及存储空间的占用。进一步的,每个学生网络的优化过程采用了典型的知识蒸馏方法的知识表示,即学生基于教师logits的模仿,进一步的,所述logits指的是深度神经网络中最后一层的输出,将它视为来自教师网络的知识的载体。因此,每个学生网络的蒸馏损失被定义为匹配教师网络和学生网络之间的交叉熵,即:
Figure BDA0003276950060000051
其中zt和zs分别是教师网络和学生网络的logits。公式(4)使得教师网络的logits与学生网络的logits相匹配。
学生网络自身的分类损失是真实标签和学生网络的输出结果之间的交叉熵,其公式表示为:
Figure BDA0003276950060000052
其中,LS表示学生网络自身的分类损失,LCE表示学生网络输出与真实标签的交叉熵损失,y是ground truth(真实标签)的向量,其中只有一个元素为1,代表迁移训练样本的ground truth标签,其他元素为0。在蒸馏损失和学生自身的分类损失中,都使用学生网络的相同logits,但温度系数不同。学生分类损失中的温度系数为T=1,蒸馏损失中的温度系数为T=t,在本发明中t统一取4。因此,多学生知识蒸馏中每个学生最终的损失是蒸馏损失和分类损失的联合:
L(x,W)=α*LD(p(zt,T),p(zs,T))+(1-α)*LS(y,p(zs,T)) (6)
其中x是CIFAR100数据集中训练集上的训练输入,W是学生网络的参数,α是调节参数用于平衡二者的权重。
进一步的,所述步骤4),最终输入到网络中图片是来自CIFAR100的测试数据集,图片经过多个学生网络以及讨论模块的映射,最终得到其分类结果。
本发明还提供一种采用上述方法的基于多学生讨论的知识蒸馏系统,其包括:
教师网络预训练单元,用于对教师网络进行预训练;
多学生网络训练单元,用于采用单老师多学生的蒸馏模式,利用预训练的教师网络对多个学生网络进行知识蒸馏的训练;
讨论模块训练单元,用于将多个学生网络产生的图像分类预测结果输入讨论模块,对讨论模块进行图像分类的训练;
图像分类单元,用于将待分类的图像输入经过训练完成的学生网络,再经过训练完成的讨论模块,输出最终的图像分类结果。
综上所述,本发明设计了一种基于多学生讨论的知识蒸馏方法,使得蒸馏得到的学生网络能够产生更加准确的分类效果。与现有技术相比,本发明的优点在于:
1、采用单教师多学生的知识蒸馏模式,增加了被蒸馏学生数量,且所选取的学生网络架构各异,这使得蒸馏之后的中间结果更具多样性,为后面讨论模块取长补短奠定了基础;
2、本发明在蒸馏过程中采用各个学生网络并行训练的模式,各个学生网络之间相互独立没有参数共享,因而总的蒸馏时间即为最大的学生网络的蒸馏时间,这在一定程度上节约了模型训练的时间成本;
3、讨论模块的设计明显提升了预测性能,实验结果表明,本发明的多学生讨论式的蒸馏方法始终比经典的教师指导的一对一的蒸馏方法具有更好的泛化能力。与其他方法相比,在本发明的方法中观察到更大的学生多样性和更优的讨论结果。
附图说明
图1为基于多学生讨论的知识蒸馏方法的流程图;
图2为单教师多学生蒸馏训练架构图;
图3为多学生讨论架构图。
具体实施方式
下面通过具体实例和附图,对本发明做进一步的详细说明。本发明的基于多学生讨论的知识蒸馏方法的流程图如图1所示,主要分为训练阶段和测试阶段两个阶段。
训练阶段分为三个阶段,其步骤如下:
1)第一阶段预训练教师网络,选取ResNet32*4网络作为教师网络,训练数据集使用CIFAR100。
该步骤1)的处理过程为:使用CIFAR100数据集中训练集包含的全部数据(100个图像类别,每类包含500张32×32大小的彩图)对ResNet32*4网络进行训练以获得训练良好的教师网络。上述教师网络的初始化采用了随机初始化的策略,训练共经过了240个epoch的迭代,初始学习率为0.05,随后以0.1的倍数进行学习率的衰减。训练过程采用随机梯度下降法进行参数优化,并将其动量设置为0.9。
2)知识蒸馏阶段,采用单老师多学生的蒸馏方式,旨在通过蒸馏产生多样化的中间结果,训练过程示意图如图2所示。
该步骤2)的学生网络选取上,一方面要求学生网络的规模足够小,以便于部署在边缘设备上,另一方面,要求学生网络尽可能产生多样化的预测结果,使得网络之间可以相互学习。因此,本发明使用的学生网络分别是:ResNet20、ResNet32、ResNet8x4,它们的规模都比教师网络小很多。在学生网络的结构上,都是由一系列的不同尺度的卷积层、BN层、非线性激活层堆叠而成,在模型的最后由均值池化层连接一个全连接层作为输出层。为进行知识蒸馏,每个学生网络分别计算其最后一层输出与教师网络最后一层输出的交叉熵损失作为蒸馏损失,通过最小化蒸馏损失,使得学生网络学习来自于教师网络输出的暗知识。同步骤1)类似,学生网络的训练数据依然使用CIFAR100数据集的训练集,训练过程中的参数设置与步骤1)相同。此外,由于学生网络的目标函数中包含了最小化蒸馏损失以及分类的交叉熵损失,二者的权重都设为1,为了便于比较,所述蒸馏损失的温度系数本发明全部设置为4。
3)讨论模块训练阶段,由于讨论模块也是由一系列卷积层、池化层、全连接层以及非线性激活函数构成,其参数也需要根据目标函数进行迭代优化,讨论阶段的架构如图3所示。
该步骤3)的讨论模块由两个卷积层、1个最大池化层、全连接层以及最后全连接的输出层构成。其中第一个卷积层conv_1的输入是三个学生网络的logits连接(concat)到一起的结果,其尺寸为3x100,conv_1的卷积核大小为3x3。conv_2的卷积核大小也设置为3x3,随后两个全连接层的输出分别设置为1024、100。由于讨论模块的输出即为最终的类别预测,因此,最后采用softmax函数计算类别的概率分布。所述讨论模块的目标函数定义为模型输出与ground truth的标签的交叉熵损失,优化的参数仅仅涉及讨论模块,而不会回传到学生网络。训练过程中,讨论模块由于其结构简单,仅需要90个epoch即可训练至收敛。
测试阶段的步骤如下:
1)将测试图像并行输入训练好的多学生网络中,得到多个学生网络最后一层输出的logits,然后将所有的logits连接在一起作为讨论模块的输入。训练好的讨论模块经过计算最终输出对应测试图像的分类结果。所述测试图像本发明采用CIFAR100的测试数据集,共100类图像每类包含100张图像。
2)将该模型的最终预测结果与测试集本身的标注进行对比,计算预测正确的top1的准确率,作为对本模型分类效果的评价。
本发明提出的基于多学生讨论的知识蒸馏方法,其测试环境及实验结果为:
(1)测试环境:
系统环境:Ubuntu 16.04.5;
硬件环境:内存:24GB,GPU:NVIDIA Quadro P6000,硬盘:1TB。
(2)实验数据:
训练数据:使用CIFAR100数据集分别进行训练和测试,训练到模型稳定,效果不再提升。
测试数据:CIFAR100的测试数据集。
评估方法:在线评估。
(3)实验结果:
本发明实验结果与传统知识蒸馏方法对比,传统知识蒸馏方法的架构为单老师单学生的蒸馏模式,模型的选取上教师网络为ResNet32x4,学生网络为Resnet3,具体的知识表示形式也是最小化二者logits的差异。测试对比结果如表1所示:
表1.本发明的测试结果对比
方法对比 Accuracy
传统知识蒸馏方法 72.22
本方法 76.77
基于同一发明构思,本发明的另一个实施例提供一种基于多学生讨论的知识蒸馏系统,其包括:
教师网络预训练单元,用于对教师网络进行预训练;
多学生网络训练单元,用于采用单老师多学生的蒸馏模式,利用预训练的教师网络对多个学生网络进行知识蒸馏的训练;
讨论模块训练单元,用于将多个学生网络产生的图像分类预测结果输入讨论模块,对讨论模块进行图像分类的训练;
图像分类单元,用于将待分类的图像输入经过训练完成的学生网络,再经过训练完成的讨论模块,输出最终的图像分类结果。
其中各模块的具体实施过程参见前文对本发明方法的描述。
基于同一发明构思,本发明的另一实施例提供一种电子装置(计算机、服务器、智能手机等),其包括存储器和处理器,所述存储器存储计算机程序,所述计算机程序被配置为由所述处理器执行,所述计算机程序包括用于执行本发明方法中各步骤的指令。
基于同一发明构思,本发明的另一实施例提供一种计算机可读存储介质(如ROM/RAM、磁盘、光盘),所述计算机可读存储介质存储计算机程序,所述计算机程序被计算机执行时,实现本发明方法的各个步骤。
以上实施例仅用以说明本发明的技术方案而非对其进行限制,本领域的普通技术人员可以对本发明的技术方案进行修改或者等同替换,而不脱离本发明的精神和范围,本发明的保护范围应以权利要求书所述为准。

Claims (10)

1.一种基于多学生讨论的知识蒸馏方法,其特征在于,包括以下步骤:
对教师网络进行预训练;
采用单老师多学生的蒸馏模式,利用预训练的教师网络对多个学生网络进行知识蒸馏的训练;
将多个学生网络产生的图像分类预测结果输入讨论模块,对讨论模块进行图像分类的训练;
将待分类的图像输入经过训练完成的学生网络,再经过训练完成的讨论模块,输出最终的图像分类结果。
2.根据权利要求1所述的方法,其特征在于,所述对教师网络进行预训练,是使用ResNet32x4网络作为教师网络在CIFAR100数据集上进行预训练。
3.根据权利要求1所述的方法,其特征在于,所述多个学生网络包括ResNet20、ResNet32、ResNet8x4,在每个学生网络单独蒸馏的过程中采用logits蒸馏形式。
4.根据权利要求1所述的方法,其特征在于,每个学生网络的蒸馏过程包括:定义教师网络最后一个全连接层的输出的logits向量为z,令zi表示第i个类别的logit,那么输入属于第i个类别的概率pi是由softmax函数估计:
Figure FDA0003276950050000011
据此,教师网络获得的软目标的预测包含暗知识,能够作为监督信息将知识从教师网络转移到学生网络。
5.根据权利要求4所述的方法,其特征在于,引入温度因子T来控制每个软目标的重要性:
Figure FDA0003276950050000012
其中,较高的温度会产生更柔和的概率分布;当T→∞时所有类共享相同的概率,当T→0时软目标成为one-hot的标签,即硬目标;来自教师网络的软目标和真实标签分别作为蒸馏损失和学生损失共同指导学生网络的优化。
6.根据权利要求5所述的方法,其特征在于,每个学生网络的蒸馏损失被定义为匹配教师网络和学生网络之间的交叉熵,即:
Figure FDA0003276950050000013
其中zt和zs分别是教师网络和学生网络的logits,该公式使得教师网络的logits与学生网络的logits相匹配;
学生网络自身的分类损失是真实标签和学生网络的输出结果之间的交叉熵,其公式表示为:
Figure FDA0003276950050000021
其中y是真实标签的向量,其中只有一个元素为1,代表迁移训练样本的ground truth标签,其他元素为0;在蒸馏损失和学生网络自身的分类损失中,都使用学生网络的相同logits,但温度系数不同,学生网络分类损失中的温度系数为T=1,蒸馏损失中的温度系数为T=t,多学生知识蒸馏中每个学生网络最终的损失是蒸馏损失和分类损失的联合:
L(x,W)=α*LD(p(zt,T),p(zs,T))+(1-α)*Ls(y,p(zs,T))
其中x是CIFAR100数据集中训练集上的训练输入,W是学生网络的参数,α是调节参数用于平衡二者的权重。
7.根据权利要求1所述的方法,其特征在于,所述讨论模块以各个学生网络的logits输出作为其输入,采用CNN网络的架构将各个学生网络的输出耦合在一起,利用CNN网络包含的卷积层、池化层以及其非线性激活的特性,捕获各个学生网络的优势类别,最终输出准确的类别预测。
8.一种采用权利要求1~7中任一权利要求所述方法的基于多学生讨论的知识蒸馏系统,其特征在于,包括:
教师网络预训练单元,用于对教师网络进行预训练;
多学生网络训练单元,用于采用单老师多学生的蒸馏模式,利用预训练的教师网络对多个学生网络进行知识蒸馏的训练;
讨论模块训练单元,用于将多个学生网络产生的图像分类预测结果输入讨论模块,对讨论模块进行图像分类的训练;
图像分类单元,用于将待分类的图像输入经过训练完成的学生网络,再经过训练完成的讨论模块,输出最终的图像分类结果。
9.一种电子装置,其特征在于,包括存储器和处理器,所述存储器存储计算机程序,所述计算机程序被配置为由所述处理器执行,所述计算机程序包括用于执行权利要求1~7中任一权利要求所述方法的指令。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储计算机程序,所述计算机程序被计算机执行时,实现权利要求1~7中任一权利要求所述的方法。
CN202111120541.4A 2021-09-24 2021-09-24 一种基于多学生讨论的知识蒸馏方法和系统 Pending CN114049513A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111120541.4A CN114049513A (zh) 2021-09-24 2021-09-24 一种基于多学生讨论的知识蒸馏方法和系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111120541.4A CN114049513A (zh) 2021-09-24 2021-09-24 一种基于多学生讨论的知识蒸馏方法和系统

Publications (1)

Publication Number Publication Date
CN114049513A true CN114049513A (zh) 2022-02-15

Family

ID=80204643

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111120541.4A Pending CN114049513A (zh) 2021-09-24 2021-09-24 一种基于多学生讨论的知识蒸馏方法和系统

Country Status (1)

Country Link
CN (1) CN114049513A (zh)

Cited By (13)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114758180A (zh) * 2022-04-19 2022-07-15 电子科技大学 一种基于知识蒸馏的轻量化花卉识别方法
CN115019183A (zh) * 2022-07-28 2022-09-06 北京卫星信息工程研究所 基于知识蒸馏和图像重构的遥感影像模型迁移方法
CN115115879A (zh) * 2022-06-29 2022-09-27 合肥工业大学 可切换在线知识蒸馏的图像分类方法、装置及可存储介质
CN115131599A (zh) * 2022-04-19 2022-09-30 浙江大学 一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法
CN115223049A (zh) * 2022-09-20 2022-10-21 山东大学 面向电力场景边缘计算大模型压缩的知识蒸馏与量化技术
CN115511059A (zh) * 2022-10-12 2022-12-23 北华航天工业学院 一种基于卷积神经网络通道解耦的网络轻量化方法
CN115965964A (zh) * 2023-01-29 2023-04-14 中国农业大学 一种鸡蛋新鲜度识别方法、系统及设备
CN116205290A (zh) * 2023-05-06 2023-06-02 之江实验室 一种基于中间特征知识融合的知识蒸馏方法和装置
CN116453105A (zh) * 2023-06-20 2023-07-18 青岛国实科技集团有限公司 基于知识蒸馏深度神经网络的船牌号识别方法及系统
CN116486285A (zh) * 2023-03-15 2023-07-25 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN116719945A (zh) * 2023-08-08 2023-09-08 北京惠每云科技有限公司 一种医学短文本的分类方法、装置、电子设备及存储介质
WO2024000344A1 (zh) * 2022-06-30 2024-01-04 华为技术有限公司 一种模型训练方法及相关装置
CN117892841B (zh) * 2024-03-14 2024-05-31 青岛理工大学 基于渐进式联想学习的自蒸馏方法及系统

Cited By (23)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115131599B (zh) * 2022-04-19 2023-04-18 浙江大学 一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法
CN115131599A (zh) * 2022-04-19 2022-09-30 浙江大学 一种基于对抗偏差与鲁棒性知识蒸馏的图像分类方法
CN114758180A (zh) * 2022-04-19 2022-07-15 电子科技大学 一种基于知识蒸馏的轻量化花卉识别方法
CN114758180B (zh) * 2022-04-19 2023-10-10 电子科技大学 一种基于知识蒸馏的轻量化花卉识别方法
CN115115879A (zh) * 2022-06-29 2022-09-27 合肥工业大学 可切换在线知识蒸馏的图像分类方法、装置及可存储介质
CN115115879B (zh) * 2022-06-29 2024-03-19 合肥工业大学 可切换在线知识蒸馏的图像分类方法、装置及可存储介质
WO2024000344A1 (zh) * 2022-06-30 2024-01-04 华为技术有限公司 一种模型训练方法及相关装置
CN115019183A (zh) * 2022-07-28 2022-09-06 北京卫星信息工程研究所 基于知识蒸馏和图像重构的遥感影像模型迁移方法
CN115223049B (zh) * 2022-09-20 2022-12-13 山东大学 面向电力场景边缘计算大模型压缩的知识蒸馏与量化方法
CN115223049A (zh) * 2022-09-20 2022-10-21 山东大学 面向电力场景边缘计算大模型压缩的知识蒸馏与量化技术
CN115511059A (zh) * 2022-10-12 2022-12-23 北华航天工业学院 一种基于卷积神经网络通道解耦的网络轻量化方法
CN115511059B (zh) * 2022-10-12 2024-02-09 北华航天工业学院 一种基于卷积神经网络通道解耦的网络轻量化方法
CN115965964B (zh) * 2023-01-29 2024-01-23 中国农业大学 一种鸡蛋新鲜度识别方法、系统及设备
CN115965964A (zh) * 2023-01-29 2023-04-14 中国农业大学 一种鸡蛋新鲜度识别方法、系统及设备
CN116486285A (zh) * 2023-03-15 2023-07-25 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN116486285B (zh) * 2023-03-15 2024-03-19 中国矿业大学 一种基于类别掩码蒸馏的航拍图像目标检测方法
CN116205290B (zh) * 2023-05-06 2023-09-15 之江实验室 一种基于中间特征知识融合的知识蒸馏方法和装置
CN116205290A (zh) * 2023-05-06 2023-06-02 之江实验室 一种基于中间特征知识融合的知识蒸馏方法和装置
CN116453105B (zh) * 2023-06-20 2023-08-18 青岛国实科技集团有限公司 基于知识蒸馏深度神经网络的船牌号识别方法及系统
CN116453105A (zh) * 2023-06-20 2023-07-18 青岛国实科技集团有限公司 基于知识蒸馏深度神经网络的船牌号识别方法及系统
CN116719945A (zh) * 2023-08-08 2023-09-08 北京惠每云科技有限公司 一种医学短文本的分类方法、装置、电子设备及存储介质
CN116719945B (zh) * 2023-08-08 2023-10-24 北京惠每云科技有限公司 一种医学短文本的分类方法、装置、电子设备及存储介质
CN117892841B (zh) * 2024-03-14 2024-05-31 青岛理工大学 基于渐进式联想学习的自蒸馏方法及系统

Similar Documents

Publication Publication Date Title
CN114049513A (zh) 一种基于多学生讨论的知识蒸馏方法和系统
Jaafra et al. Reinforcement learning for neural architecture search: A review
CN108805270B (zh) 一种基于存储器的卷积神经网络系统
CN109325443B (zh) 一种基于多实例多标签深度迁移学习的人脸属性识别方法
WO2022252272A1 (zh) 一种基于迁移学习的改进vgg16网络猪的身份识别方法
CN109829541A (zh) 基于学习自动机的深度神经网络增量式训练方法及系统
CN110472730A (zh) 一种卷积神经网络的自蒸馏训练方法和可伸缩动态预测方法
CN108846384A (zh) 融合视频感知的多任务协同识别方法及系统
CN110070107A (zh) 物体识别方法及装置
CN116134454A (zh) 用于使用知识蒸馏训练神经网络模型的方法和系统
CN110490136A (zh) 一种基于知识蒸馏的人体行为预测方法
CN109102000A (zh) 一种基于分层特征提取与多层脉冲神经网络的图像识别方法
CN109829049A (zh) 利用知识库渐进时空注意力网络解决视频问答任务的方法
Lyu et al. Neural architecture search for portrait parsing
CN114332545B (zh) 一种基于低比特脉冲神经网络的图像数据分类方法和装置
Sang et al. Discriminative deep feature learning for facial emotion recognition
CN108171328A (zh) 一种卷积运算方法和基于该方法的神经网络处理器
CN114819091B (zh) 基于自适应任务权重的多任务网络模型训练方法及系统
CN114758180B (zh) 一种基于知识蒸馏的轻量化花卉识别方法
CN115966010A (zh) 一种基于注意力和多尺度特征融合的表情识别方法
CN113095251B (zh) 一种人体姿态估计方法及系统
CN110188621A (zh) 一种基于ssf-il-cnn的三维人脸表情识别方法
CN114202021A (zh) 一种基于知识蒸馏的高效图像分类方法及系统
CN114170659A (zh) 一种基于注意力机制的面部情感识别方法
Zhong et al. Face expression recognition based on NGO-BILSTM model

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination