CN114611670A - 一种基于师生协同的知识蒸馏方法 - Google Patents

一种基于师生协同的知识蒸馏方法 Download PDF

Info

Publication number
CN114611670A
CN114611670A CN202210254811.9A CN202210254811A CN114611670A CN 114611670 A CN114611670 A CN 114611670A CN 202210254811 A CN202210254811 A CN 202210254811A CN 114611670 A CN114611670 A CN 114611670A
Authority
CN
China
Prior art keywords
network
output
teacher
student
branch
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210254811.9A
Other languages
English (en)
Inventor
李刚
高文建
徐传运
张杨
刘欢
王影
郑宇�
徐昊
张晴
宋志瑶
马莹丽
曹铠洪
陈志远
朱鑫
李梦伟
白南兰
陈鹏
孙成杰
王克亚
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Chongqing University of Technology
Original Assignee
Chongqing University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Chongqing University of Technology filed Critical Chongqing University of Technology
Priority to CN202210254811.9A priority Critical patent/CN114611670A/zh
Publication of CN114611670A publication Critical patent/CN114611670A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q50/00Information and communication technology [ICT] specially adapted for implementation of business processes of specific business sectors, e.g. utilities or tourism
    • G06Q50/10Services
    • G06Q50/20Education
    • G06Q50/205Education administration or guidance

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Business, Economics & Management (AREA)
  • Theoretical Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Artificial Intelligence (AREA)
  • Data Mining & Analysis (AREA)
  • Educational Administration (AREA)
  • Educational Technology (AREA)
  • Strategic Management (AREA)
  • Tourism & Hospitality (AREA)
  • Economics (AREA)
  • Human Resources & Organizations (AREA)
  • Marketing (AREA)
  • Primary Health Care (AREA)
  • General Business, Economics & Management (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明涉及知识蒸馏技术领域,具体涉及一种基于师生协同的知识蒸馏方法,包括:构建经过预先训练的教师网络以及具有多层级的分支输出的学生网络;将训练数据分别输入教师网络和学生网络,得到教师网络输出的概率分布以及各个分支输出的概率分布和特征;计算各个分支的知识蒸馏损失和自蒸馏损失;然后通过各个分支的知识蒸馏损失和自蒸馏损失计算对应的整体损失函数,并更新学生网络的参数;对学生网络各个分支输出的概率分布进行融合,得到对应的最终概率分布;重复上述步骤,直至学生网络训练至收敛。本发明能够通过教师网络和学生网络自身来协同优化和训练学生网络,使得不增加教师网络的复杂度并能够基于学生网络的输出进行自监督和自学习。

Description

一种基于师生协同的知识蒸馏方法
技术领域
本发明涉及知识蒸馏技术领域,具体涉及一种基于师生协同的知识蒸馏方法。
背景技术
随着深度学习的快速发展,深度卷积网络在计算机视觉的各项任务中表现出了出色的性能。然而卷积神经网络越来越深,导致模型参数巨大、计算复杂、延时高,使得在有限硬件资源条件下难以部署到终端。其中,知识蒸馏(knowledge distillation)作为神经网络模型压缩的一种重要方法,其目的是使用一种轻量化的模型从过度参数化的模型中学习有效的知识,获得近似于复杂模型的性能,从而达到模型压缩的目的。
现有技术通常将知识蒸馏模型结构称为师生网络,在具有一位经验丰富的教师网络的条件下,学生网络通过知识蒸馏学习教师网络丰富信息,提升自身网络的性能。例如,公开号为CN112418343A的中国专利就公开了《多教师自适应联合知识蒸馏》,其包括:将训练好的多个教师网络的特征输入到一个深度神经网络进行二次分类,将深度神经网络的中间层作为教师网络的特征融合模型;将同一批训练数据输入教师网络和学生网络,得到各个教师网络的特征和概率分布;用训练好的深度神经网络融合特征,用加权预测融合各个教师网络的预测结果;构造损失函数,并基于损失函数更新学生网络的参数,固定其他模型的参数;重复上述步骤,直到学生网络收敛。
上述现有方案中的多教师自适应联合知识蒸馏方法,通过将不同教师网络传递的知识有差异的结合,形成软标签引导学生网络的学习,使得学生网络的学习更加有效。但申请人发现,上述现有方案在训练学生网络时,需要构建并融合多个教师网络的特征和概率分布,这大大增加了整个知识蒸馏模型的复杂度和训练成本。同时,现有方案仅专注于提升教师网络的性能以及如何传递有效的信息,而忽略了挖掘学生网络的潜在价值,导致学生网络的性能有待进一步提高。因此,如何设计一种能够降低知识蒸馏模型复杂度并提高学生网络性能的知识蒸馏方法是亟需解决的技术问题。
发明内容
针对上述现有技术的不足,本发明所要解决的技术问题是:如何提供一种基于师生协同的知识蒸馏方法,以能够通过教师网络和学生网络自身来协同优化和训练学生网络,使得不增加教师网络的复杂度并能够基于学生网络的输出进行自监督和自学习,从而能够降低知识蒸馏模型的复杂度并提高学生网络的性能。
为了解决上述技术问题,本发明采用了如下的技术方案:
基于师生协同的知识蒸馏方法,包括以下步骤:
S1:构建经过预先训练的教师网络,以及具有多层级的分支输出的学生网络;
S2:将训练数据分别输入教师网络和学生网络,得到教师网络输出的概率分布以及各个分支输出的概率分布和特征;
S3:通过教师网络输出的概率分布和各个分支输出的概率分布计算各个分支的知识蒸馏损失;然后通过各个分支输出的概率分布和特征计算各个分支的自蒸馏损失;最后通过各个分支的知识蒸馏损失和自蒸馏损失计算对应的整体损失函数,并更新学生网络的参数;
S4:对学生网络各个分支输出的概率分布进行融合,得到对应的最终概率分布;
S5:重复步骤S1至S4,直至学生网络训练至收敛。
优选的,步骤S1中,使用过参数化的ResNet模型或VGG模型作为教师网络,并对教师网络进行训练。
优选的,步骤S1中,在学生网络的不同阶段添加自适应瓶颈层和全连接层,使得学生网络能够形成由浅到深的多个层级的分支输出。
优选的,步骤S2中,步骤S2中,自适应瓶颈层的结构由1x1、3x3、1x1的三层卷积模块组成,其自适应体现在根据不同特征图的大小使用不同数量的瓶颈模块。
优选的,步骤S3中,分支的知识蒸馏损失包括教师网络输出的概率分布和对应分支输出的概率分布之间的KL散度,以及对应分支输出的概率分布与训练数据的真实标签之间的交叉熵损失。
优选的,步骤S3中,知识蒸馏损失通过如下公式计算:
Figure BDA0003548131370000021
其中,yt=ft(x,wt);
yi=fs(x,ws);
式中:
Figure BDA0003548131370000022
表示第i个分支的知识蒸馏损失;i∈[1,n];T2LKL(yi,yt)表示教师网络输出的概率分布yt和第i个分支输出的概率分布yi之间的KL散度;LCE(yi,y)表示第i个分支输出的概率分布yi与训练数据的真实标签y之间的交叉熵损失;wt、ws表示教师网络和学生网络的权重参数;x表示教师网络和学生网络的输入;ft和fs表示教师网络和学生网络的特征。
优选的,步骤S3中,分支的自蒸馏损失包括对应分支输出的概率分布与主干网络输出的概率分布之间的KL散度,以及对应分支输出的特征与主干网络输出的特征之间的L2损失;其中,将最深层级分支的输出作为主干网络的输出。
优选的,步骤S3中,自蒸馏损失通过如下公式计算:
Figure BDA0003548131370000023
其中,yi,fi=fs(x,ws);
式中:
Figure BDA0003548131370000031
表示第i个分支的自蒸馏损失;i∈[1,n];T2LKL(yi,yn)表示第i个分支输出的概率分布yi与主干网络输出的概率分布yn之间的KL散度;||ui(fi)-fn||2表示第i个分支输出的特征fi与主干网络输出的特征fn之间的L2损失。
优选的,步骤S3中,整体损失函数表示为:
Figure BDA0003548131370000032
式中:Loss表示整体损失;
Figure BDA0003548131370000033
表示第i个分支的知识蒸馏损失;
Figure BDA0003548131370000034
表示第i个分支的自蒸馏损失;i∈[1,n];a、β表示设置的超参数。
优选的,步骤S4中,最终概率分布通过如下公式计算:
Figure BDA0003548131370000035
式中:ys表示学生网络输出的最终概率分布;yi表示第i个分支输出的概率分布;i∈[1,n]。
本发明中的基于师生协同的知识蒸馏方法与现有技术相比,具有如下有益效果:
本发明通过构建教师网络以及具有多层级分支输出的学生网络,进而通过教师网络输出的概率分布以及各个分支输出的概率分布和特征分别构建基于知识蒸馏和自蒸馏结合的整体损失函数,使得能够在教师网络指导的基础上,通过学生网络多层级的分支输出的概率分布和特征进行自我监督,即在教师网络蒸馏结构的基础上将学生网络作为第二个老师,能够通过教师网络和学生网络自身来协同优化和训练学生网络,本发明仅需在学生网络的主干网络中添加少许层,而无需提高教师网络的复杂度,使得不增加教师网络的复杂度并能够基于学生网络的输出进行自监督和自学习,从而能够降低知识蒸馏模型的复杂度并提高学生网络的性能,并兼顾知识蒸馏模型的训练成本和训练效果。
附图说明
为了使发明的目的、技术方案和优点更加清楚,下面将结合附图对本发明作进一步的详细描述,其中:
图1为基于师生协同的知识蒸馏方法的逻辑框图;
图2为教师网络和学生网络的网络结构示意图;
图3为四个分支输出特征的示意图。
具体实施方式
下面通过具体实施方式进一步详细的说明:
实施例:
本实施例中公开了一种基于师生协同的知识蒸馏方法。
如图1和图2所示,基于师生协同的知识蒸馏方法,包括以下步骤:
S1:构建经过预先训练的教师网络,以及具有多层级的分支输出的学生网络;
S2:将训练数据分别输入教师网络和学生网络,得到教师网络输出的概率分布以及各个分支输出的概率分布和特征;本实施例中,需固定教师网络的参数。
S3:通过教师网络输出的概率分布和各个分支输出的概率分布计算各个分支的知识蒸馏损失;然后通过各个分支输出的概率分布和特征计算各个分支的自蒸馏损失;最后通过各个分支的知识蒸馏损失和自蒸馏损失计算对应的整体损失函数,并更新学生网络的参数;
S4:对学生网络各个分支输出的概率分布进行融合,得到对应的最终概率分布;
S5:重复步骤S1至S4,直至学生网络训练至收敛。
本发明通过构建教师网络以及具有多层级分支输出的学生网络,进而通过教师网络输出的概率分布以及各个分支输出的概率分布和特征分别构建基于知识蒸馏和自蒸馏结合的整体损失函数,使得能够在教师网络指导的基础上,通过学生网络多层级的分支输出的概率分布和特征进行自我监督,即在教师网络蒸馏结构的基础上将学生网络作为第二个老师,能够通过教师网络和学生网络自身来协同优化和训练学生网络,本发明仅需在学生网络的主干网络中添加少许层,而无需提高教师网络的复杂度,使得不增加教师网络的复杂度并能够基于学生网络的输出进行自监督和自学习,从而能够降低知识蒸馏模型的复杂度并提高学生网络的性能,并兼顾知识蒸馏模型的训练成本和训练效果。
具体实施过程中,使用过参数化的ResNet模型或VGG模型作为教师网络,并对教师网络进行训练。使用相比教师网络较小的模型作为学生网络,在学生网络的不同阶段添加自适应瓶颈层和全连接层,使得学生网络能够形成由浅到深的多个层级的(分类器)分支输出。
本实施例中,自适应瓶颈层的具体结构由1x1、3x3、1x1的三层卷积模块组成,自适应体现在根据不同特征图的大小使用不同数量的瓶颈模块。引入自适应瓶颈层一方面是为了保证学生网络不同阶段的输出特征是相同尺度的,另一方面是减小卷积的计算量。全连接层用于输出类别的概率分布。由于不同阶段的分支(分类器)网络结构不同,对样本有不同的拟合性能,这提供了丰富的类别信息。
具体实施过程中,分支的知识蒸馏损失包括教师网络输出的概率分布和对应分支输出的概率分布之间的KL散度,以及对应分支输出的概率分布与训练数据的真实标签之间的交叉熵损失。知识蒸馏损失通过如下公式计算:
Figure BDA0003548131370000041
其中,yt=ft(x,wt);
yi=fs(x,ws);
式中:
Figure BDA0003548131370000051
表示第i个分支的知识蒸馏损失;i∈[1,n];T2LKL(yi,yt)表示教师网络输出的概率分布yt和第i个分支输出的概率分布yi之间的KL散度;LCE(yi,y)表示第i个分支输出的概率分布yi与训练数据的真实标签y之间的交叉熵损失;wt、ws表示教师网络和学生网络的权重参数;x表示教师网络和学生网络的输入;ft和fs表示教师网络和学生网络的特征。
具体实施过程中,分支的自蒸馏损失包括对应分支输出的概率分布与主干网络输出的概率分布之间的KL散度,以及对应分支输出的特征与主干网络输出的特征之间的L2损失;其中,将最深层级分支的输出作为主干网络的输出。自蒸馏损失通过如下公式计算:
Figure BDA0003548131370000052
其中,yi,fi=fs(x,ws);
式中:
Figure BDA0003548131370000053
表示第i个分支的自蒸馏损失;i∈[1,n-1];T2LKL(yi,yn)表示第i个分支输出的概率分布yi与主干网络输出的概率分布yn之间的KL散度;||ui(fi)-fn||2表示第i个分支输出的特征fi与主干网络输出的特征fn之间的L2损失。
整体损失函数表示为:
Figure BDA0003548131370000054
式中:Loss表示整体损失;
Figure BDA0003548131370000055
表示第i个分支的知识蒸馏损失;
Figure BDA0003548131370000056
表示第i个分支的自蒸馏损失;i∈[1,n];a、β表示设置的超参数。
本实施例中,KL散度、交叉熵损失和L2损失的计算均采用现有手段,这里不再赘述。
其中,KL散度(Kullback-Leibler divergence,Kullback-Leibler散度)又称为相对熵(relative entropy)或信息散度(information divergence),是两个概率分布(probability distribution)间差异的非对称性度量。
在信息理论中,相对熵等价于两个概率分布的信息熵(Shannon entropy)的差值。相对熵是一些优化算法,例如最大期望算法(Expectation-Maximization algorithm,EM)的损失函数。此时参与计算的一个概率分布为真实分布,另一个为理论(拟合)分布,相对熵表示使用理论分布拟合真实分布时产生的信息损耗。
交叉熵(Cross Entropy)是Shannon信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。
语言模型的性能通常用交叉熵和复杂度(perplexity)来衡量。交叉熵的意义是用该模型对文本识别的难度,或者从压缩的角度来看,每个词平均要用几个位来编码。复杂度的意义是用该模型表示这一文本平均的分支数,其倒数可视为每个词的平均概率。平滑是指对没观察到的N元组合赋予一个概率值,以保证词序列总能通过语言模型得到一个概率值。通常使用的平滑技术有图灵估计、删除插值平滑、Katz平滑和Kneser-Ney平滑。
相对熵(relative entropy),又被称为Kullback-Leibler散度(Kullback-Leibler divergence)或信息散度(information divergence),是两个概率分布(probability distribution)间差异的非对称性度量。在信息理论中,相对熵等价于两个概率分布的信息熵(Shannon entropy)的差值。
相对熵是一些优化算法,例如最大期望算法(Expectation-Maximizationalgorithm,EM)的损失函数。此时参与计算的一个概率分布为真实分布,另一个为理论(拟合)分布,相对熵表示使用理论分布拟合真实分布时产生的信息损耗
L2损失(L2范数损失函数),也被称为最小平方误差(LSE)。它是把目标值与估计值的差值的平方和最小化。一般回归问题会使用此损失,离群点对次损失影响较大。
本发明通过教师网络输出的概率分布和分支输出的概率分布之间的KL散度以及分支输出的概率分布与训练数据的真实标签之间的交叉熵损失计算分支的知识蒸馏损失,通过分支输出的概率分布与主干网络输出的概率分布之间的KL散度,以及分支输出的特征与主干网络输出的特征之间的L2损失计算分支的自蒸馏损失,进而能够基于各个分支的知识蒸馏损失和自蒸馏损失计算整体损失函数,使得能够在教师网络指导的基础上,通过学生网络多层级的分支输出的概率分布和特征进行自我监督,即在教师网络蒸馏结构的基础上将学生网络作为第二个老师,从而能够通过教师网络和学生网络自身来协同优化和训练学生网络。
具体实施过程中,最终概率分布通过如下公式计算:
Figure BDA0003548131370000061
式中:ys表示学生网络输出的最终概率分布;yi表示第i个分支输出的概率分布;i∈[1,n]。
本发明通过平均集成的方式计算学生网络输出的最终概率分布,使得能够综合各个分支输出的概率分布来分析最终概率分布,从而能够进一步提高学生网络的性能。
为了更好的说明本发明的优势,本实施例中还公开了如下实验。
1、数据集和实验设置
1)CIFAR-100(来自A.Krizhevsky,Learning multiple layers of featuresfrom tinyimages):该数据集由AlexKrizhevsky,VinodNair和GeoffreyHinton收集,共有60K张大小为32x32的彩色图像,分成100个类别,其中训练样本50k,测试样本10k。数据预处理遵循CRD(来自Y.Tian,D.Krishnan,P.Isola,Contrastive representationdistillation)的数据集处理方法,将训练集图像各边填充4个像素,再随机裁剪为32x32,同时以0.5的概率进行随机水平翻转。测试时,采用原始图像进行评估。实验使用SGD优化,将权重衰减和动量分别设置为0.0001和0.9。batchsize设置为128,初始学习率为0.1,在epoch为150、180、210分别降低为原来的0.1倍,在240轮结束训练。
2)Tiny-ImageNet:作为大规模图像分类数据集ImageNet(来自J.Deng,W.Dong,R.Socher,L.-J.Li,K.Li,L.Fei-Fei,Imagenet:A large-scale hierarchical imagedatabase)的一个子集,由斯坦福大学2016年发布。共有120k张大小为64x64的彩色图像,分成200个类别,其中训练样本100k张,验证集、测试集各10k。实验仅采用简单的随机水平翻转进行预处理,以原图大小进行训练和测试。优化方式及超参数设置遵循CIFAR数据集。
2、对比基准方法
实验分别选用经典的ResNet(来自K.He,X.Zhang,S.Ren,J.Sun,Deep residuallearning forimage recognition)和VGG(来自J.Kim,S.Park,N.Kwak,Paraphrasingcomplex network:network compression via factor transfer)作为主干网络。为了融合教师网络和学生网络自身不同层次知识,我们在常规的教师指导下构造了多级输出的学生网络。方便起见,在特征空间分辨率下降的块间插入三个独立分类器分支,每一分支包含了瓶颈层和全连接层,其中瓶颈层保证了输出特征图大小保持一致,同时减轻浅层分类器之间影响。
与Zhang el al.(来自L.Zhang,J.Song,A.Gao,J.Chen,C.Bao,K.Ma,Be yourownteacher:Improve the performance of convolutional neuralnetworks via selfdistillation)不同,我们多个分支网络的全连接层采用共享权重,降低模型参数量。
表1显示了学生网络每个分支在CIFAR100上的表现,我们发现,由于网络深度不同,捕获的语义特征也不同,深层分类器较浅层拥有更高的分类精度。测试时,我们使用了一种平均集成方法,平衡多出口的分类差异,实验结果表明,我们最终的测试准确率对比基准值均有4%-7%的提升。另外,我们发现,基于师生协同的知识蒸馏方法让模型浅层出口的分类精度就已经能接近或超越整个模型的最终精度。
表1师生协同知识蒸馏方法与基准方法的分类准确率对比(%)
Figure BDA0003548131370000071
3、对比知识蒸馏方法
为了表明本发明提出的师生联合蒸馏方法的有效性和鲁棒性,我们选用了五种不同的师生架构,其中包含了同构和异构模型,并分别对比了一些主流的知识蒸馏方法。大多实验方法遵循原作者开源代码实现,少数方法按照Tian et al.(来自Y.Tian,D.Krishnan,P.Isola,Contrastive representation distillation)的复现,在CIFAR-100和Tiny-ImageNet两个数据集进行了实验。以分类准确率和参数量为评价指标,分类结果如表2,3所示。模型参数量如表4所示。由于我们在学生网络上构建了多出口网络,导致参数量略高于传统KD算法,但与教师网络仍有较大差距,也能达到较好的模型压缩效果。而且,从分类精度上看,我们方法对比一些优秀的蒸馏方法,学生网络均有1%-3%的提升。
表2 CIFAR100上师生协同知识蒸馏方法与知识蒸馏方法的分类准确率对比(%)
Figure BDA0003548131370000081
其中,KD来自G.Hinton,O.Vinyals,J.Dean,Distilling the knowledge inaneural network;
FIT来自A.Romero,N.Ballas,S.E.Kahou,A.Chassang,C.Gatta,Y.Bengio,Fitnets:Hints for thin deep nets;
AT来自S.Zagoruyko,N.Komodakis,Paying more attention to attention:Improving the performance of convolutional neural networks via attentiontransfer;
SP来自F.Tung,G.Mori,Similarity-preserving knowledge distillation;
CC来自B.Peng,X.Jin,J.Liu,D.Li,Y.Wu,Y.Liu,S.Zhou,Z.Zhang,Correlationcongruence for knowledge distillation;
VID来自S.Ahn,S.X.Hu,A.Damianou,N.D.Lawrence,Z.Dai,Variationalinformation distillation for knowledge transfer;
RKD来自W.Park,D.Kim,Y.Lu,M.Cho,Relational knowledge distillation;
PKT来自N.Passalis,A.Tefas,Learning deep representations withprobabilistic knowledge transfer;
AB来自B.Heo,M.Lee,S.Yun,J.Y.Choi,Knowledge transfer via distillationof activation boundaries formed by hidden neurons;
FT来自J.Kim,S.Park,N.Kwak,Paraphrasing complex network:networkcompression via factor transfer;
NST来自Z.Huang,N.Wang,Like what you like:Knowledge distill vianeuronselectivity transfer;
CRD来自Y.Tian,D.Krishnan,P.Isola,Contrastive representationdistillation。
表3 Tiny-ImaNet上师生协同知识蒸馏方法与知识蒸馏方法的分类准确率对比(%)
Figure BDA0003548131370000091
表4师生模型参数量对比(M)
Model Parameters
ResNet152 58.348
ResNet50 37.812
ResNet34 21.798
ResNet18 12.334
ResNet10 5.859
VGG13 9.923
VGG8 5.383
4、对比多出口网络(multi-exit net)
无教师自蒸馏模型通常是多出口结构,本发明的学生网络也可以看作一种基于知识蒸馏的多出口结构,与过去Zhang et al.提出的多分类器网络主要区别在于我们的每个分类器都接受来自教师网络的监督,而不是仅是深层分类器的监督,deeply supervisednet(DSN)(来自C.-Y.Lee,S.Xie,P.Gallagher,Z.Zhang,Z.Tu,Deeply supervised nets)则是用真实标签对中间层加以约束,通过减轻梯度爆炸或消失来提高分类精度。为了验证提出的方法的有效性,实验对比了这两种方法,选用ResNet152作为教师网络,分别用ResNet18和ResNet50作为多出口学生主干网络。实验结果如表5所示,无论是浅层分类器还是模型最终的输出,本发明基于师生协同蒸馏的多出口学生网络都表现出优越的性能。可以发现,知识蒸馏使多出口网络去匹配一个额外的教师网络知识是有效的,每个分类器捕获了更多的视图特征。
表5提出的方法对比其他多出口网络优化方法(%)
Figure BDA0003548131370000101
其中,DSN(深度监督网络)来自C.-Y.Lee,S.Xie,P.Gallagher,Z.Zhang,Z.Tu,Deeply supervised nets;
SD(自蒸馏)来自L.Zhang,J.Song,A.Gao,J.Chen,C.Bao,K.Ma,Be your ownteacher:Improve the performance of convolutional neural networks via selfdistillation。
5、实验分析
我们对实验观察展开进一步地分析:首先通过消融实验对每一部分策略进行讨论,然后分别分析了多出口蒸馏和集成模块的有效性,最后从信息论的和特征学习的角度对我们的整体方法提供解释。
5.1消融实验(Ablation Study)
由于我们的方法是基于师生之间知识蒸馏和学生网络自蒸馏实现的,实验效果来自于知识蒸馏还是自蒸馏是有争议的。为了进一步验证我们方法的有效性,选用不同的学生网络,分别实施随机梯度下降、知识蒸馏和学生自我蒸馏三种方法进行对比,以分类准确率为评价指标。
进一步,本发明提出的师生协同蒸馏方法融合了三个部分的监督:(i)教师网络输出logits对学生网络监督Logits(T),(ii)学生网络深层的soft logits对浅层指导Logits(S),(iii)学生网络浅层的特征匹配深层特征Feature(S),以及最后使用了平均集成策略。为了评估每一部分的有效性,我们选用ResNet152和ResNet18分别作为教师和学生网络,在CIFAR-100上进行了消融实验。
实验结果如表6所示。可以看出,每个策略对于分类精度都有不同程度的提升,且对比传统只用教师网络logits的知识蒸馏方法,有较大的提升,甚至优于教师网络。
表6 CIFAR100上对不同策略的消融实验结果
Figure BDA0003548131370000111
5.2 Multi-exits学生网络特征降维可视化
本发明构建了基于自蒸馏的多出口学生网络,其中主干网络最深层输出可以被认为是第二老师,类似于多教师蒸馏,不同网络学到了不同的视图特征,通过知识蒸馏和自蒸馏,使学生网络匹配多个模型的特征表示知识。我们对三个分支网络和主干网络中全连接层前的高维特征进行降维可视化。如图3所示,学生网络每个出口的分类效果显著,浅层的分类精度甚至接近深层的分类性能。
5.3平均集成及敏感度分析
这一部分我们讨论了多出口集成的有效性以及集成的出口数量对实验的影响。在学生网络中,我们构造了多个输出通道,每个输出通道都是一个独立的分类网络。在“多视图”数据中,每个数据类包含了多种视图特征,不同的网络往往学到了不同的视图特征,通过集成能有效融合多个模型学到的不同特征信息。同时,由于网络之间的差异可能较浅的网络容易过拟合或陷入局部最优,通过集成能有效减小类别概率间的方差,形成一个强分类器。我们的实验分别在CIFAR100和Tiny-ImageNet数据集进行,以及使用不同师生架构下验证了集成策略的有效性,进一步我们还探索了集成出口的数量对分类的精度的影响,结果表明,在一定范围内,集成网络出口数量越多能提升网络最终的性能。
5.4师生联合蒸馏
最后,从信息论和特征学习的角度分析基于师生协同的知识蒸馏方法性。回顾知识蒸馏,它的有效性很大程度来自于教师网络的软标签信息。信息量大小仅与概率有关,软标签比起one-hot标签信息熵更大,隐式地包含了类别之间的信息,丰富的信息让学生网络获得收益。这也为我们的方法提供了一种解释,我们在常规的师生架构中,让学生网络不仅匹配教师网络的软标签信息,而且学习学生网络的自身输出的软标签信息,我们知道学生网络对于教师网络而言,通常结构简单、性能较差,这就导致学生网络输出的类别概率不确定性更大,信息量也更大。因此我们联合学生和教师的指导,让类别信息得到进一步丰富,学生网络也从中受益。
从特征学习的角度看,由于学习的随机性,初始化不同的模型学到的视图特征不同。另一方面,输入数据经过同一网络的不同层次卷积核,也依次抽取了低维和高维特征视图。我们的协同蒸馏方法,将这二者结合起来,学生网络试图学习教师网络的学习的视图,同时学生网络自己从训练数据中学习新的视图。教师的指导起到了正则化的作用,限制了学生网络在教师网络学习的视图附近搜索新的视图,所以学生的效果更好。多出口网络也可以用多视图的角度来解释,每个出口的所构建的模型学到的就是训练数据的一个新的视图特征,所以多个出口就学到了多个视图,并且这些视图还是相关的,最后通过集成形成了一个学到了多种视图特征的强分类器。
6、结论
在本发明中,我们提出了一种的师生协同蒸馏方法。与传统知识蒸馏方法不同,我们引入知识蒸馏和自蒸馏相融合的思想,让模型从教师网络和自身学习新的视图特征知识。通过大量的实验和可视化分析,验证了我们提出的方法及每个组件的有效性,并表明该方法对知识蒸馏和多出口网络均有重要的指导意义。
最后需要说明的是,以上实施例仅用以说明本发明的技术方案而非限制技术方案,本领域的普通技术人员应当理解,那些对本发明的技术方案进行修改或者等同替换,而不脱离本技术方案的宗旨和范围,均应涵盖在本发明的权利要求范围当中。

Claims (10)

1.一种基于师生协同的知识蒸馏方法,其特征在于,包括以下步骤:
S1:构建经过预先训练的教师网络,以及具有多层级的分支输出的学生网络;
S2:将训练数据分别输入教师网络和学生网络,得到教师网络输出的概率分布以及各个分支输出的概率分布和特征;
S3:通过教师网络输出的概率分布和各个分支输出的概率分布计算各个分支的知识蒸馏损失;然后通过各个分支输出的概率分布和特征计算各个分支的自蒸馏损失;最后通过各个分支的知识蒸馏损失和自蒸馏损失计算对应的整体损失函数,并更新学生网络的参数;
S4:对学生网络各个分支输出的概率分布进行融合,得到对应的最终概率分布;
S5:重复步骤S1至S4,直至学生网络训练至收敛。
2.如权利要求1所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S1中,使用过参数化的ResNet模型或VGG模型作为教师网络,并对教师网络进行训练。
3.如权利要求1所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S1中,在学生网络的不同阶段添加自适应瓶颈层和全连接层,使得学生网络能够形成由浅到深的多个层级的分支输出。
4.如权利要求1所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S2中,自适应瓶颈层的结构由1x1、3x3、1x1的三层卷积模块组成,其自适应体现在根据不同特征图的大小使用不同数量的瓶颈模块。
5.如权利要求1所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S3中,分支的知识蒸馏损失包括教师网络输出的概率分布和对应分支输出的概率分布之间的KL散度,以及对应分支输出的概率分布与训练数据的真实标签之间的交叉熵损失。
6.如权利要求5所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S3中,知识蒸馏损失通过如下公式计算:
Figure FDA0003548131360000011
其中,yt=ft(x,wt);
yi=fs(x,ws);
式中:
Figure FDA0003548131360000012
表示第i个分支的知识蒸馏损失;i∈[1,n];T2LKL(yi,yt)表示教师网络输出的概率分布yt和第i个分支输出的概率分布yi之间的KL散度;LCE(yi,y)表示第i个分支输出的概率分布yi与训练数据的真实标签y之间的交叉熵损失;wt、ws表示教师网络和学生网络的权重参数;x表示教师网络和学生网络的输入;ft和fs表示教师网络和学生网络的特征。
7.如权利要求1所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S3中,分支的自蒸馏损失包括对应分支输出的概率分布与主干网络输出的概率分布之间的KL散度,以及对应分支输出的特征与主干网络输出的特征之间的L2损失;其中,将最深层级分支的输出作为主干网络的输出。
8.如权利要求7所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S3中,自蒸馏损失通过如下公式计算:
Figure FDA0003548131360000021
其中,yi,fi=fs(x,ws);
式中:
Figure FDA0003548131360000022
表示第i个分支的自蒸馏损失;i∈[1,n];T2LKL(yi,yn)表示第i个分支输出的概率分布yi与主干网络输出的概率分布yn之间的KL散度;‖ui(fi)-fn2表示第i个分支输出的特征fi与主干网络输出的特征fn之间的L2损失。
9.如权利要求1所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S3中,整体损失函数表示为:
Figure FDA0003548131360000023
式中:Loss表示整体损失;
Figure FDA0003548131360000024
表示第i个分支的知识蒸馏损失;
Figure FDA0003548131360000025
表示第i个分支的自蒸馏损失;i∈[1,n];a、β表示设置的超参数。
10.如权利要求1所述的基于师生协同的知识蒸馏方法,其特征在于:步骤S4中,最终概率分布通过如下公式计算:
Figure FDA0003548131360000026
式中:ys表示学生网络输出的最终概率分布;yi表示第i个分支输出的概率分布;i∈[1,n]。
CN202210254811.9A 2022-03-15 2022-03-15 一种基于师生协同的知识蒸馏方法 Pending CN114611670A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210254811.9A CN114611670A (zh) 2022-03-15 2022-03-15 一种基于师生协同的知识蒸馏方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210254811.9A CN114611670A (zh) 2022-03-15 2022-03-15 一种基于师生协同的知识蒸馏方法

Publications (1)

Publication Number Publication Date
CN114611670A true CN114611670A (zh) 2022-06-10

Family

ID=81862205

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210254811.9A Pending CN114611670A (zh) 2022-03-15 2022-03-15 一种基于师生协同的知识蒸馏方法

Country Status (1)

Country Link
CN (1) CN114611670A (zh)

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115661597A (zh) * 2022-10-28 2023-01-31 电子科技大学 一种基于动态权重定位蒸馏的可见光和红外融合目标检测方法
CN115774851A (zh) * 2023-02-10 2023-03-10 四川大学 基于分级知识蒸馏的曲轴内部缺陷检测方法及其检测系统
CN117057414A (zh) * 2023-08-11 2023-11-14 佛山科学技术学院 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统
CN117253123A (zh) * 2023-08-11 2023-12-19 中国矿业大学 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法
WO2024000344A1 (zh) * 2022-06-30 2024-01-04 华为技术有限公司 一种模型训练方法及相关装置
CN117057414B (zh) * 2023-08-11 2024-06-07 佛山科学技术学院 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统

Cited By (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2024000344A1 (zh) * 2022-06-30 2024-01-04 华为技术有限公司 一种模型训练方法及相关装置
CN115661597A (zh) * 2022-10-28 2023-01-31 电子科技大学 一种基于动态权重定位蒸馏的可见光和红外融合目标检测方法
CN115661597B (zh) * 2022-10-28 2023-08-15 电子科技大学 一种基于动态权重定位蒸馏的可见光和红外融合目标检测方法
CN115774851A (zh) * 2023-02-10 2023-03-10 四川大学 基于分级知识蒸馏的曲轴内部缺陷检测方法及其检测系统
CN115774851B (zh) * 2023-02-10 2023-04-25 四川大学 基于分级知识蒸馏的曲轴内部缺陷检测方法及其检测系统
CN117057414A (zh) * 2023-08-11 2023-11-14 佛山科学技术学院 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统
CN117253123A (zh) * 2023-08-11 2023-12-19 中国矿业大学 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法
CN117253123B (zh) * 2023-08-11 2024-05-17 中国矿业大学 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法
CN117057414B (zh) * 2023-08-11 2024-06-07 佛山科学技术学院 一种面向文本生成的多步协作式提示学习的黑盒知识蒸馏方法及系统

Similar Documents

Publication Publication Date Title
CN114611670A (zh) 一种基于师生协同的知识蒸馏方法
CN108319686B (zh) 基于受限文本空间的对抗性跨媒体检索方法
Gu et al. Stack-captioning: Coarse-to-fine learning for image captioning
CN108549658B (zh) 一种基于语法分析树上注意力机制的深度学习视频问答方法及系统
CN106202068B (zh) 基于多语平行语料的语义向量的机器翻译方法
CN114398961B (zh) 一种基于多模态深度特征融合的视觉问答方法及其模型
CN109753571B (zh) 一种基于二次主题空间投影的场景图谱低维空间嵌入方法
CN112685597B (zh) 一种基于擦除机制的弱监督视频片段检索方法和系统
CN110751698A (zh) 一种基于混和网络模型的文本到图像的生成方法
CN112464004A (zh) 一种多视角深度生成图像聚类方法
CN112527993B (zh) 一种跨媒体层次化深度视频问答推理框架
CN110851575B (zh) 一种对话生成系统及对话实现方法
CN113673535B (zh) 一种多模态特征融合网络的图像描述生成方法
CN116110022B (zh) 基于响应知识蒸馏的轻量化交通标志检测方法及系统
CN114254093A (zh) 多空间知识增强的知识图谱问答方法及系统
CN113239211A (zh) 一种基于课程学习的强化学习知识图谱推理方法
Hu et al. One-bit supervision for image classification
WO2023108873A1 (zh) 一种脑网络和脑成瘾连接计算方法及装置
CN109948589B (zh) 基于量子深度信念网络的人脸表情识别方法
Yang et al. Att-bm-som: A framework of effectively choosing image information and optimizing syntax for image captioning
CN113887471A (zh) 基于特征解耦和交叉对比的视频时序定位方法
KR20220066554A (ko) Qa 모델을 이용하여 지식 그래프를 구축하는 방법, 장치 및 컴퓨터 프로그램
CN111507472A (zh) 一种基于重要性剪枝的精度估计参数搜索方法
CN113554040B (zh) 一种基于条件生成对抗网络的图像描述方法、装置设备
CN115527052A (zh) 一种基于对比预测的多视图聚类方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination