CN116205290A - 一种基于中间特征知识融合的知识蒸馏方法和装置 - Google Patents

一种基于中间特征知识融合的知识蒸馏方法和装置 Download PDF

Info

Publication number
CN116205290A
CN116205290A CN202310499470.6A CN202310499470A CN116205290A CN 116205290 A CN116205290 A CN 116205290A CN 202310499470 A CN202310499470 A CN 202310499470A CN 116205290 A CN116205290 A CN 116205290A
Authority
CN
China
Prior art keywords
feature
knowledge
student
model
fusion
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202310499470.6A
Other languages
English (en)
Other versions
CN116205290B (zh
Inventor
王玉柱
张艾嘉
裘云蕾
段曼妮
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Lab
Original Assignee
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Lab filed Critical Zhejiang Lab
Priority to CN202310499470.6A priority Critical patent/CN116205290B/zh
Publication of CN116205290A publication Critical patent/CN116205290A/zh
Application granted granted Critical
Publication of CN116205290B publication Critical patent/CN116205290B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/74Image or video pattern matching; Proximity measures in feature spaces
    • G06V10/761Proximity, similarity or dissimilarity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/80Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
    • G06V10/806Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Multimedia (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Molecular Biology (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

一种基于中间特征知识融合的知识蒸馏方法和装置,将图像数据喂入教师模型和学生模型,并提取教师模型和学生模型的各阶段中间特征;构建阶段级残差连接,将学生模型某一阶段中间特征与上一阶段特征实现特征知识融合;将教师模型与融合后的学生模型分别经过全局平均池化,构建出语义类别特征向量,对该特征向量计算交叉熵损失,以最大化特征相似性;将学生模型预测输出与类别标签的分类损失与特征向量相似性损失加权求和,训练学生模型。还包括一种基于中间特征知识融合的知识蒸馏系统。本发明相较于现有技术,本发明充分融合教师模型的中间特征知识,知识蒸馏性能更优。

Description

一种基于中间特征知识融合的知识蒸馏方法和装置
技术领域
本发明涉及深度神经网络模型压缩领域,尤其是涉及一种基于中间特征知识融合的知识蒸馏方法和装置。
背景技术
近十年来,受益于更大的深度模型、大规模高质量的标注数据以及强大的硬件算力,深度神经网络在多种计算机视觉任务已取得显著进展,比如图像分类、目标检测、语义分割等。然而,由于计算资源和内存资源的限制,大模型在实际应用中难以部署。利用深度模型压缩技术,能够建立一个与大模型具有性能竞争力且对硬件资源要求低的高效模型。
知识蒸馏的优势在于能够实现跨模型结构的深度压缩而广受业界关注。知识蒸馏旨在利用更大的模型(教师)知识指导轻量化的小模型(学生)的训练,使学生模型能够实现更高的性能,其核心问题是如何从教师模型中提取有效的知识,并高效的传递给学生模型。知识蒸馏方法大体可分为两类:基于预测概率方法(logits-based)和基于中间特征方法(feature-based)。其中,基于中间特征的方法,如FitNet、OFD、ReviewKD等,通过引入精心设计的特征变换模块,能够有效提取教师模型的中间特征,进而显著提升知识蒸馏效果,但会额外引入不可避免的显著计算成本。基于预测概率的方法,如KD、DKD、DIS等,通过分析知识的表示形式,实现了对教师知识的高效利用,有效提升了知识蒸馏的性能,且不需要额外的计算成本。如何在较少计算成本下,能够有效利用教师模型中间阶段的特征知识,设计简单有效的知识表示及融合方法,进一步提升知识蒸馏效果,在模型压缩知识蒸馏领域的仍是一个待解决的关键问题。
发明内容
本发明要克服现有技术的上述缺点,提供一种基于中间特征知识融合的知识蒸馏方法和装置。
为了实现上述目的,本发明所述的一种基于中间特征知识融合的知识蒸馏方法,包括如下步骤:
S1,构建数据集:构建图像分类数据集;所述图像分类数据集中包括训练集和测试集;所述训练集由图像和分类标签构成;
S2,数据预处理:调整图像分类数据集中的图像的宽和高;对训练集图像做随机水平翻转、随机裁剪、标准化操作;对测试集图像做中心裁剪、标准化操作;
S3,融合残差知识:将预处理后的训练集数据分批次喂入教师模型和学生模型;对于教师模型,提取各阶段中间特征,所述中间特征为教师模型内部各阶段对所述图像分类数据集中的图像的表征,再对所述中间特征做全局平均池化处理得到特征向量
Figure SMS_1
;对于学生模型,提取各阶段中间特征/>
Figure SMS_2
,所述中间特征为学生模型内部各阶段对所述图像分类数据集中的图像的表征,其中,l为教师模型和学生模型的阶段数量;/>
对于阶段i,融合特征
Figure SMS_3
与阶段i-1的特征/>
Figure SMS_4
获得特征/>
Figure SMS_5
S4,变换特征:在S3步骤,为了保证特征
Figure SMS_7
与特征/>
Figure SMS_9
具有相同的尺度,对特征/>
Figure SMS_12
做特征变换操作,使特征/>
Figure SMS_8
与特征/>
Figure SMS_11
具有相同的宽、高、通道数量;同样的,为了使学生模型的融合特征/>
Figure SMS_13
与教师特征/>
Figure SMS_15
具有相同的尺度,对特征/>
Figure SMS_6
做特征变换和全局平均池化得到特征向量/>
Figure SMS_10
,最终学生特征向量为/>
Figure SMS_14
S5,计算特征相似性:考虑教师模型和学生模型中的第i个中间特征向量,分别为
Figure SMS_16
和/>
Figure SMS_17
,计算特征向量的相似性损失/>
Figure SMS_18
S6,计算分类损失:所述步骤S3中,输入图像的标签为
Figure SMS_19
,C为类别数,学生模型对输入图像的预测为/>
Figure SMS_20
,计算分类损失/>
Figure SMS_21
S7,学生网络训练过程的总损失为特征相似性损失与分类损失的加权和,表示为
Figure SMS_22
,其中/>
Figure SMS_23
为损失权重平衡因子。
进一步地,所述步骤S3中,学生模型与教师模型可以是相似结构,也可以是不同结构。喂入同样的图像数据,可以提取相同数量的中间特征。
进一步地,所述步骤S4中,对中间特征
Figure SMS_24
的特征变换为步长为2的3x3卷积;对特征/>
Figure SMS_25
的特征变换依次为1x1卷积、步长为2的3x3卷积、1x1卷积。
进一步地,所述步骤S5中,使用交叉熵最大化特征向量
Figure SMS_26
和/>
Figure SMS_27
之间的相似性,其过程如下式所述:
Figure SMS_28
其中,n为向量维度。
进一步地,所述步骤S5中,只对教师模型预测正确的图像样本,计算特征相似性。
进一步地,所述步骤S6中,使用交叉熵计算分类损失,其过程如下式所述:
Figure SMS_29
其中,输入图像的标签为
Figure SMS_30
,C为类别总数,学生模型对输入图像的预测为
Figure SMS_31
进一步地,所述步骤S7中,在训练集中划分出一定比例的验证集,根据学生模型在验证集上的准确率调整
Figure SMS_32
本发明还包括一种基于中间特征知识融合的知识蒸馏系统,包括:
数据集构建模块,用于构建图像分类数据集;所述图像分类数据集中包括训练集和测试集;所述训练集由图像和分类标签构成;
数据预处理模块:用于调整图像分类数据集中的图像的宽和高;对训练集图像做随机水平翻转、随机裁剪、标准化操作;对测试集图像做中心裁剪、标准化操作;
融合残差知识模块:用于融合特征向量,将预处理后的训练集数据分批次喂入教师模型和学生模型;对于教师模型,提取各阶段中间特征,再做全局平均池化处理得到特征向量
Figure SMS_33
;对于学生模型,提取各阶段中间特征/>
Figure SMS_34
,其中,l为教师模型和学生模型的阶段数量;对于阶段i,融合特征/>
Figure SMS_35
与阶段i-1的特征/>
Figure SMS_36
获得特征/>
Figure SMS_37
;/>
特征变换模块:用于统一教师网络与学生网络的特征尺度,对特征
Figure SMS_38
做特征变换操作,使特征/>
Figure SMS_42
与特征/>
Figure SMS_44
具有相同的宽、高、通道数量;为了使学生模型的融合特征/>
Figure SMS_40
与教师特征/>
Figure SMS_41
具有相同的尺度,对特征/>
Figure SMS_43
做特征变换和全局平均池化得到特征向量/>
Figure SMS_45
,最终学生特征向量为/>
Figure SMS_39
特征相似性计算模块:用于计算特征向量的相似性损失
Figure SMS_46
分类损失计算模块:用于计算学生网络分类损失
Figure SMS_47
学生网络训练模块:用于特征相似性损失与分类损失的加权求和,训练学生网络。
本发明还包括一种基于中间特征知识融合的知识蒸馏装置,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于上述的一种基于中间特征知识融合的知识蒸馏方法。
本发明还包括一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现上述的一种基于中间特征知识融合的知识蒸馏方法。
本发明的有益效果在于:
在知识蒸馏中,本发明将知识建模为教师模型对输入图像样本的预测结果的相对顺序,而不是强制要求学生模型严格学习教师模型预测的绝对值,能够降低学生模型的学习要求,有益于学生模型的优化。另外,本发明通过对学生模型跨阶段的中间特征知识融合,摒弃了现有方法需要手工设计精巧的特征变换模块,避免额外引入显著的计算成本。相较于以往知识蒸馏工作,本发明充分利用教师模型的中间特征,并且额外引入的计算成本较少,能够有效地教师模型的中间知识,并高效的传递给学生模型,充分发挥知识蒸馏的效果。
附图说明
图1是本发明一种基于中间特征知识融合的知识蒸馏方法的流程图。
图2是本发明在CIFAR100数据集上知识蒸馏损失曲线图。
图3是本发明在CIFAR100数据集上知识蒸馏准确率曲线图。
图4是本发明在ImageNet-1K数据集上知识蒸馏损失曲线图。
图5是本发明在ImageNet-1K数据集上知识蒸馏准确率曲线图。
图6是本发明一种基于中间特征知识融合的知识蒸馏装置的结构示意图。
图7是本发明的系统结构图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚明了,以下结合附图及实施例,对本发明进行详细说明。但是应当理解的是,此处所描述的具体实施方式仅用于说明和解释本发明,并不用于限制本发明的范围。
实施例1
以户外自然场景目标识别任务为例,包括动物、鸟类、植物、人等目标类别,本发明一种基于中间特征知识融合的知识蒸馏方法,参阅图1,具体过程如下:
S1,构建数据集:构建动物、鸟类、植物、人等100个类别的自然图像分类数据集,共6万张,其中训练集5万张,测试集1万张,在训练集中划分1万张作为验证集,用于调整超参数(CIFAR100公开数据集);类似的,构建1000个类别的自然图像分类数据集,共120余万张,其中训练集120万张,测试集5万张,在训练集中划分5万张作为验证集,用于调整超参数(ImageNet-1K公开数据集);
S2,数据预处理:调整图像分类数据集中的图像的宽和高;对训练集图像做随机水平翻转、随机裁剪、标准化操作;对测试集图像做中心裁剪、标准化操作;对CIFAR100数据集,将图像宽高调为32,对ImageNet-1K数据集,将图像宽高调为224;
S3,残差知识融合:对CIFAR100数据集,选择教师模型为DenseNet250,学生模型为ResNet110;对ImageNet-1K数据集,选择教师模型为ResNet-34,学生模型为ResNet-18;将预处理后的训练集数据分批次喂入教师模型和学生模型;对于教师模型,提取各阶段中间特征,所述中间特征为教师模型内部各阶段对所述图像分类数据集中的图像的表征,再对所述中间特征做全局平均池化处理得到特征向量
Figure SMS_48
;对于学生模型,提取各阶段中间特征/>
Figure SMS_49
,所述中间特征为学生模型内部各阶段对所述图像分类数据集中的图像的表征,其中,l为教师模型和学生模型的阶段数量。对于阶段i,融合特征/>
Figure SMS_50
与阶段i-1的特征/>
Figure SMS_51
获得特征/>
Figure SMS_52
S4,特征变换:在S3步骤,为了保证特征
Figure SMS_53
与特征/>
Figure SMS_57
具有相同的尺度,对特征/>
Figure SMS_60
做特征变换操作,使特征/>
Figure SMS_54
与特征/>
Figure SMS_56
具有相同的宽、高、通道数量;同样的,为了使学生模型的融合特征/>
Figure SMS_59
与教师特征/>
Figure SMS_62
具有相同的尺度,对特征/>
Figure SMS_55
做特征变换和全局平均池化得到特征向量/>
Figure SMS_58
,最终学生特征向量为/>
Figure SMS_61
S5,计算特征相似性:考虑教师模型和学生模型中的第i个中间特征向量,分别为
Figure SMS_63
和/>
Figure SMS_64
,计算特征向量的相似性损失/>
Figure SMS_65
S6,计算分类损失:所述步骤S3中,输入图像的标签为
Figure SMS_66
,C为类别数,学生模型对输入图像的预测为/>
Figure SMS_67
,计算分类损失/>
Figure SMS_68
S7,学生网络训练过程的总损失为特征相似性损失与分类损失的加权和,表示为
Figure SMS_69
,其中/>
Figure SMS_70
为损失权重平衡因子。
所述步骤S3中,学生模型与教师模型可以是相似结构,也可以是不同结构。喂入同样的图像数据,可以提取相同数量的中间特征。
所述步骤S4中,对中间特征
Figure SMS_71
的特征变换为步长为2的3x3卷积;对特征/>
Figure SMS_72
的特征变换依次为1x1卷积、步长为2的3x3卷积、1x1卷积。
所述步骤S5中,使用交叉熵最大化特征向量
Figure SMS_73
和/>
Figure SMS_74
之间的相似性,其过程如下式所述:
Figure SMS_75
其中,n为向量维度。
所述步骤S5中,只对教师模型预测正确的图像样本,计算特征相似性。
所述步骤S6中,使用交叉熵计算分类损失,其过程如下式所述:
Figure SMS_76
其中,输入图像的标签为
Figure SMS_77
,C为类别总数,学生模型对输入图像的预测为
Figure SMS_78
所述步骤S7中,在训练集中划分出一定比例的验证集,根据学生模型在验证集上的准确率调整
Figure SMS_79
。如表1所示,本发明在CIFAR100和ImageNet-1K数据集上与KD和ReviewKD的比较。在CIFAR100数据集上,设置教师网络和学生网络分别是DenseNet250和ResNet110。可以看到,本发明相对KD方法,准确率提升了2.21%;相对ReviewKD方法,准确率提升了1.11%。本发明的训练曲线如图2至图5所示。
Figure SMS_80
实施例2
参照图7,本发明还包括用于实现实施例1的一种基于中间特征知识融合的知识蒸馏方法的一种基于中间特征知识融合的知识蒸馏系统,包括:
数据集构建模块,用于构建图像分类数据集;所述图像分类数据集中包括训练集和测试集;所述训练集由图像和分类标签构成;
数据预处理模块:用于调整图像分类数据集中的图像的宽和高;对训练集图像做随机水平翻转、随机裁剪、标准化操作;对测试集图像做中心裁剪、标准化操作;
融合残差知识模块:用于融合特征向量,将预处理后的训练集数据分批次喂入教师模型和学生模型;对于教师模型,提取各阶段中间特征,再做全局平均池化处理得到特征向量
Figure SMS_81
;对于学生模型,提取各阶段中间特征/>
Figure SMS_82
,其中,l为教师模型和学生模型的阶段数量;对于阶段i,融合特征/>
Figure SMS_83
与阶段i-1的特征/>
Figure SMS_84
获得特征/>
Figure SMS_85
特征变换模块:用于统一教师网络与学生网络的特征尺度,对特征
Figure SMS_88
做特征变换操作,使特征/>
Figure SMS_90
与特征/>
Figure SMS_92
具有相同的宽、高、通道数量;为了使学生模型的融合特征/>
Figure SMS_87
与教师特征/>
Figure SMS_89
具有相同的尺度,对特征/>
Figure SMS_91
做特征变换和全局平均池化得到特征向量/>
Figure SMS_93
,最终学生特征向量为/>
Figure SMS_86
特征相似性计算模块:用于计算特征向量的相似性损失
Figure SMS_94
分类损失计算模块:用于计算学生网络分类损失
Figure SMS_95
;/>
学生网络训练模块:用于特征相似性损失与分类损失的加权求和,训练学生网络。
实施例3
本实施例涉及一种基于中间特征知识融合的知识蒸馏装置,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于上述实施例1的一种基于中间特征知识融合的知识蒸馏方法;装置实施例可以应用在任意具备数据处理能力的设备上,该任意具备数据处理能力的设备可以为诸如计算机等设备或装置。
如图6,在硬件层面,该知识蒸馏装置包括处理器、内部总线、网络接口、内存以及非易失性存储器,当然还可能包括其他业务所需要的硬件。处理器从非易失性存储器中读取对应的计算机程序到内存中然后运行,以实现上述图1所示的方法。当然,除了软件实现方式之外,本发明并不排除其他实现方式,比如逻辑器件抑或软硬件结合的方式等等,也就是说以下处理流程的执行主体并不限定于各个逻辑单元,也可以是硬件或逻辑器件。
对于一个技术的改进可以很明显地区分是硬件上的改进(例如,对二极管、晶体管、开关等电路结构的改进)还是软件上的改进(对于方法流程的改进)。然而,随着技术的发展,当今的很多方法流程的改进已经可以视为硬件电路结构的直接改进。设计人员几乎都通过将改进的方法流程编程到硬件电路中来得到相应的硬件电路结构。因此,不能说一个方法流程的改进就不能用硬件实体模块来实现。例如,可编程逻辑器件(ProgrammableLogic Device, PLD)(例如现场可编程门阵列(Field Programmable Gate Array,FPGA))就是这样一种集成电路,其逻辑功能由用户对器件编程来确定。由设计人员自行编程来把一个数字系统“集成”在一片PLD上,而不需要请芯片制造厂商来设计和制作专用的集成电路芯片。而且,如今,取代手工地制作集成电路芯片,这种编程也多半改用“逻辑编译器(logic compiler)”软件来实现,它与程序开发撰写时所用的软件编译器相类似,而要编译之前的原始代码也得用特定的编程语言来撰写,此称之为硬件描述语言(HardwareDescription Language,HDL),而HDL也并非仅有一种,而是有许多种,如ABEL(AdvancedBoolean Expression Language)、AHDL(Altera Hardware Description Language)、Confluence、CUPL(Cornell University Programming Language)、HDCal、JHDL(JavaHardware Description Language)、Lava、Lola、MyHDL、PALASM、RHDL(Ruby HardwareDescription Language)等,目前最普遍使用的是VHDL(Very-High-Speed IntegratedCircuit Hardware Description Language)与Verilog。本领域技术人员也应该清楚,只需要将方法流程用上述几种硬件描述语言稍作逻辑编程并编程到集成电路中,就可以很容易得到实现该逻辑方法流程的硬件电路。
控制器可以按任何适当的方式实现,例如,控制器可以采取例如微处理器或处理器以及存储可由该(微)处理器执行的计算机可读程序代码(例如软件或固件)的计算机可读介质、逻辑门、开关、专用集成电路(Application Specific Integrated Circuit,ASIC)、可编程逻辑控制器和嵌入微控制器的形式,控制器的例子包括但不限于以下微控制器:ARC 625D、Atmel AT91SAM、Microchip PIC18F26K20 以及Silicone Labs C8051F320,存储器控制器还可以被实现为存储器的控制逻辑的一部分。本领域技术人员也知道,除了以纯计算机可读程序代码方式实现控制器以外,完全可以通过将方法步骤进行逻辑编程来使得控制器以逻辑门、开关、专用集成电路、可编程逻辑控制器和嵌入微控制器等的形式来实现相同功能。因此这种控制器可以被认为是一种硬件部件,而对其内包括的用于实现各种功能的装置也可以视为硬件部件内的结构。或者甚至,可以将用于实现各种功能的装置视为既可以是实现方法的软件模块又可以是硬件部件内的结构。
上述实施例阐明的系统、装置、模块或单元,具体可以由计算机芯片或实体实现,或者由具有某种功能的产品来实现。一种典型的实现设备为计算机。具体的,计算机例如可以为个人计算机、膝上型计算机、蜂窝电话、相机电话、智能电话、个人数字助理、媒体播放器、导航设备、电子邮件设备、游戏控制台、平板计算机、可穿戴设备或者这些设备中的任何设备的组合。
还需要说明的是,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、商品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、商品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、商品或者设备中还存在另外的相同要素。
本领域技术人员应明白,本发明的实施例可提供为方法、系统或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明可以在由计算机执行的计算机可执行指令的一般上下文中描述,例如程序模块。一般地,程序模块包括执行特定任务或实现特定抽象数据类型的例程、程序、对象、组件、数据结构等等。也可以在分布式计算环境中实践本发明,在这些分布式计算环境中,由通过通信网络而被连接的远程处理设备来执行任务。在分布式计算环境中,程序模块可以位于包括存储设备在内的本地和远程计算机存储介质中。
实施例4
本发明实施例还提供一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现上述实施例1的一种基于中间特征知识融合的知识蒸馏方法。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的范围。

Claims (10)

1.一种基于中间特征知识融合的知识蒸馏方法,其特征在于包括如下步骤:
S1,构建数据集:构建图像分类数据集;所述图像分类数据集中包括训练集和测试集;所述训练集由图像和分类标签构成;
S2,数据预处理:调整图像分类数据集中的图像的宽和高;对训练集图像做随机水平翻转、随机裁剪、标准化操作;对测试集图像做中心裁剪、标准化操作;
S3,融合残差知识:将预处理后的训练集数据分批次喂入教师模型和学生模型;对于教师模型,提取各阶段中间特征,所述中间特征为教师模型内部各阶段对所述图像分类数据集中的图像的表征,再对所述中间特征做全局平均池化处理得到特征向量
Figure QLYQS_1
;对于学生模型,提取各阶段中间特征/>
Figure QLYQS_2
,所述中间特征为学生模型内部各阶段对所述图像分类数据集中的图像的表征,其中,l为教师模型和学生模型的阶段数量;对于阶段i,融合特征/>
Figure QLYQS_3
与阶段i-1的特征/>
Figure QLYQS_4
获得特征/>
Figure QLYQS_5
S4,变换特征:在S3步骤,为了保证特征
Figure QLYQS_7
与特征/>
Figure QLYQS_11
具有相同的尺度,对特征/>
Figure QLYQS_13
做特征变换操作,使特征/>
Figure QLYQS_8
与特征/>
Figure QLYQS_9
具有相同的宽、高、通道数量;同样的,为了使学生模型的融合特征/>
Figure QLYQS_12
与教师特征/>
Figure QLYQS_15
具有相同的尺度,对特征/>
Figure QLYQS_6
做特征变换和全局平均池化得到特征向量/>
Figure QLYQS_10
,最终学生特征向量为/>
Figure QLYQS_14
S5,计算特征相似性:考虑教师模型和学生模型中的第i个中间特征向量,分别为
Figure QLYQS_16
和/>
Figure QLYQS_17
,计算特征向量的相似性损失/>
Figure QLYQS_18
S6,计算分类损失:所述步骤S3中,输入图像的标签为
Figure QLYQS_19
,C为类别数,学生模型对输入图像的预测为/>
Figure QLYQS_20
,计算分类损失/>
Figure QLYQS_21
S7,学生网络训练过程的总损失为特征相似性损失与分类损失的加权和,表示为
Figure QLYQS_22
,其中/>
Figure QLYQS_23
为损失权重平衡因子。
2.根据权利要求1所述的基于中间特征知识融合的知识蒸馏方法,其特征在于,所述步骤S3中,学生模型与教师模型是相似结构或不同结构;喂入同样的图像数据,能提取相同数量的中间特征。
3.根据权利要求1所述的基于中间特征知识融合的知识蒸馏方法,其特征在于,所述步骤S4中,对中间特征
Figure QLYQS_24
的特征变换为步长为2的3x3卷积;对特征/>
Figure QLYQS_25
的特征变换依次为1x1卷积、步长为2的3x3卷积、1x1卷积。
4.根据权利要求1所述的基于中间特征知识融合的知识蒸馏方法,其特征在于,所述步骤S5中,使用交叉熵最大化特征向量
Figure QLYQS_26
和/>
Figure QLYQS_27
之间的相似性,其过程如下式所述:
Figure QLYQS_28
其中,n为向量维度。
5.根据权利要求4所述的基于中间特征知识融合的知识蒸馏方法,其特征在于,所述步骤S5中,只对教师模型预测正确的图像样本,计算特征相似性。
6.根据权利要求1所述的基于中间特征知识融合的知识蒸馏方法,其特征在于,所述步骤S6中,使用交叉熵计算分类损失,其过程如下式所述:
Figure QLYQS_29
其中,输入图像的标签为
Figure QLYQS_30
,C为类别总数,学生模型对输入图像的预测为/>
Figure QLYQS_31
7.根据权利要求1所述的基于中间特征知识融合的知识蒸馏方法,其特征在于,所述步骤S7中,在训练集中划分出部分作为的验证集,根据学生模型在验证集上的准确率调整
Figure QLYQS_32
8.一种基于中间特征知识融合的知识蒸馏系统,其特征包括:
数据集构建模块,用于构建图像分类数据集;所述图像分类数据集中包括训练集和测试集;所述训练集由图像和分类标签构成;
数据预处理模块:用于调整图像分类数据集中的图像的宽和高;对训练集图像做随机水平翻转、随机裁剪、标准化操作;对测试集图像做中心裁剪、标准化操作;
融合残差知识模块:用于融合特征向量,将预处理后的训练集数据分批次喂入教师模型和学生模型;对于教师模型,提取各阶段中间特征,再做全局平均池化处理得到特征向量
Figure QLYQS_33
;对于学生模型,提取各阶段中间特征/>
Figure QLYQS_34
,其中,l为教师模型和学生模型的阶段数量;对于阶段i,融合特征/>
Figure QLYQS_35
与阶段i-1的特征/>
Figure QLYQS_36
获得特征/>
Figure QLYQS_37
特征变换模块:用于统一教师网络与学生网络的特征尺度,对特征
Figure QLYQS_40
做特征变换操作,使特征/>
Figure QLYQS_41
与特征/>
Figure QLYQS_43
具有相同的宽、高、通道数量;为了使学生模型的融合特征/>
Figure QLYQS_39
与教师特征/>
Figure QLYQS_42
具有相同的尺度,对特征/>
Figure QLYQS_44
做特征变换和全局平均池化得到特征向量/>
Figure QLYQS_45
,最终学生特征向量为/>
Figure QLYQS_38
特征相似性计算模块:用于计算特征向量的相似性损失
Figure QLYQS_46
分类损失计算模块:用于计算学生网络分类损失
Figure QLYQS_47
学生网络训练模块:用于特征相似性损失与分类损失的加权求和,训练学生网络。
9.一种基于中间特征知识融合的知识蒸馏装置,其特征在于,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现权利要求1-7任一项所述的一种基于中间特征知识融合的知识蒸馏方法。
10.一种计算机可读存储介质,其特征在于,其上存储有程序,该程序被处理器执行时,实现权利要求1-7任一项所述的一种基于中间特征知识融合的知识蒸馏方法。
CN202310499470.6A 2023-05-06 2023-05-06 一种基于中间特征知识融合的知识蒸馏方法和装置 Active CN116205290B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310499470.6A CN116205290B (zh) 2023-05-06 2023-05-06 一种基于中间特征知识融合的知识蒸馏方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310499470.6A CN116205290B (zh) 2023-05-06 2023-05-06 一种基于中间特征知识融合的知识蒸馏方法和装置

Publications (2)

Publication Number Publication Date
CN116205290A true CN116205290A (zh) 2023-06-02
CN116205290B CN116205290B (zh) 2023-09-15

Family

ID=86509847

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310499470.6A Active CN116205290B (zh) 2023-05-06 2023-05-06 一种基于中间特征知识融合的知识蒸馏方法和装置

Country Status (1)

Country Link
CN (1) CN116205290B (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117009830A (zh) * 2023-10-07 2023-11-07 之江实验室 一种基于嵌入特征正则化的知识蒸馏方法和系统
CN117115469A (zh) * 2023-10-23 2023-11-24 腾讯科技(深圳)有限公司 图像特征提取网络的训练方法、装置、存储介质及设备
CN117725960A (zh) * 2024-02-18 2024-03-19 智慧眼科技股份有限公司 基于知识蒸馏的语言模型训练方法、文本分类方法及设备
CN117831138A (zh) * 2024-03-05 2024-04-05 天津科技大学 基于三阶知识蒸馏的多模态生物特征识别方法
CN117831138B (zh) * 2024-03-05 2024-05-24 天津科技大学 基于三阶知识蒸馏的多模态生物特征识别方法

Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200302295A1 (en) * 2019-03-22 2020-09-24 Royal Bank Of Canada System and method for knowledge distillation between neural networks
CN112116030A (zh) * 2020-10-13 2020-12-22 浙江大学 一种基于向量标准化和知识蒸馏的图像分类方法
CN112199535A (zh) * 2020-09-30 2021-01-08 浙江大学 一种基于集成知识蒸馏的图像分类方法
CN112418343A (zh) * 2020-12-08 2021-02-26 中山大学 多教师自适应联合知识蒸馏
CN112990447A (zh) * 2021-05-20 2021-06-18 之江实验室 一种知识显著性与局部模式一致性的知识蒸馏方法与装置
CN113240120A (zh) * 2021-05-07 2021-08-10 深圳思谋信息科技有限公司 基于温习机制的知识蒸馏方法、装置、计算机设备和介质
CN113361396A (zh) * 2021-06-04 2021-09-07 思必驰科技股份有限公司 多模态的知识蒸馏方法及系统
CN114049513A (zh) * 2021-09-24 2022-02-15 中国科学院信息工程研究所 一种基于多学生讨论的知识蒸馏方法和系统
CN114120319A (zh) * 2021-10-09 2022-03-01 苏州大学 一种基于多层次知识蒸馏的连续图像语义分割方法
US20220076136A1 (en) * 2020-09-09 2022-03-10 Peyman PASSBAN Method and system for training a neural network model using knowledge distillation
CN115984111A (zh) * 2023-01-06 2023-04-18 浙江大学 一种基于知识蒸馏压缩模型的图像超分辨率方法及装置
CN115995018A (zh) * 2022-12-09 2023-04-21 厦门大学 基于样本感知蒸馏的长尾分布视觉分类方法

Patent Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200302295A1 (en) * 2019-03-22 2020-09-24 Royal Bank Of Canada System and method for knowledge distillation between neural networks
US20220076136A1 (en) * 2020-09-09 2022-03-10 Peyman PASSBAN Method and system for training a neural network model using knowledge distillation
CN112199535A (zh) * 2020-09-30 2021-01-08 浙江大学 一种基于集成知识蒸馏的图像分类方法
CN112116030A (zh) * 2020-10-13 2020-12-22 浙江大学 一种基于向量标准化和知识蒸馏的图像分类方法
CN112418343A (zh) * 2020-12-08 2021-02-26 中山大学 多教师自适应联合知识蒸馏
CN113240120A (zh) * 2021-05-07 2021-08-10 深圳思谋信息科技有限公司 基于温习机制的知识蒸馏方法、装置、计算机设备和介质
CN112990447A (zh) * 2021-05-20 2021-06-18 之江实验室 一种知识显著性与局部模式一致性的知识蒸馏方法与装置
CN113361396A (zh) * 2021-06-04 2021-09-07 思必驰科技股份有限公司 多模态的知识蒸馏方法及系统
CN114049513A (zh) * 2021-09-24 2022-02-15 中国科学院信息工程研究所 一种基于多学生讨论的知识蒸馏方法和系统
CN114120319A (zh) * 2021-10-09 2022-03-01 苏州大学 一种基于多层次知识蒸馏的连续图像语义分割方法
CN115995018A (zh) * 2022-12-09 2023-04-21 厦门大学 基于样本感知蒸馏的长尾分布视觉分类方法
CN115984111A (zh) * 2023-01-06 2023-04-18 浙江大学 一种基于知识蒸馏压缩模型的图像超分辨率方法及装置

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
YUNNAN WANG 等: "The Chain of Self-Taught Knowledge Distillation Combining Output and Features", 《 2021 33RD CHINESE CONTROL AND DECISION CONFERENCE (CCDC)》, pages 5115 - 5120 *
葛仕明;赵胜伟;刘文瑜;李晨钰;: "基于深度特征蒸馏的人脸识别", 北京交通大学学报, no. 06, pages 32 - 38 *

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117009830A (zh) * 2023-10-07 2023-11-07 之江实验室 一种基于嵌入特征正则化的知识蒸馏方法和系统
CN117009830B (zh) * 2023-10-07 2024-02-13 之江实验室 一种基于嵌入特征正则化的知识蒸馏方法和系统
CN117115469A (zh) * 2023-10-23 2023-11-24 腾讯科技(深圳)有限公司 图像特征提取网络的训练方法、装置、存储介质及设备
CN117115469B (zh) * 2023-10-23 2024-01-05 腾讯科技(深圳)有限公司 图像特征提取网络的训练方法、装置、存储介质及设备
CN117725960A (zh) * 2024-02-18 2024-03-19 智慧眼科技股份有限公司 基于知识蒸馏的语言模型训练方法、文本分类方法及设备
CN117831138A (zh) * 2024-03-05 2024-04-05 天津科技大学 基于三阶知识蒸馏的多模态生物特征识别方法
CN117831138B (zh) * 2024-03-05 2024-05-24 天津科技大学 基于三阶知识蒸馏的多模态生物特征识别方法

Also Published As

Publication number Publication date
CN116205290B (zh) 2023-09-15

Similar Documents

Publication Publication Date Title
CN116205290B (zh) 一种基于中间特征知识融合的知识蒸馏方法和装置
TWI685761B (zh) 詞向量處理方法及裝置
CN111461004B (zh) 基于图注意力神经网络的事件检测方法、装置和电子设备
CN110348462A (zh) 一种图像特征确定、视觉问答方法、装置、设备及介质
CN109934253B (zh) 一种对抗样本生成方法及装置
CN111753878A (zh) 一种网络模型部署方法、设备及介质
CN111985525A (zh) 基于多模态信息融合处理的文本识别方法
CN114358243A (zh) 通用特征提取网络训练方法、装置及通用特征提取网络
CN117036829A (zh) 一种基于原型学习实现标签增强的叶片细粒度识别方法和系统
CN116630480B (zh) 一种交互式文本驱动图像编辑的方法、装置和电子设备
CN110717013A (zh) 文档的矢量化
CN115830633B (zh) 基于多任务学习残差神经网络的行人重识别方法和系统
CN117113174A (zh) 一种模型训练的方法、装置、存储介质及电子设备
CN115499635B (zh) 数据压缩处理方法及装置
CN115130621B (zh) 一种模型训练方法、装置、存储介质及电子设备
CN112307371B (zh) 小程序子服务识别方法、装置、设备及存储介质
CN115294336A (zh) 一种数据标注方法、装置及存储介质
CN114254080A (zh) 一种文本匹配方法、装置及设备
CN111539520A (zh) 增强深度学习模型鲁棒性的方法及装置
CN113221871B (zh) 一种文字识别方法、装置、设备及介质
CN112115952B (zh) 一种基于全卷积神经网络的图像分类方法、设备及介质
CN117079646B (zh) 一种语音识别模型的训练方法、装置、设备及存储介质
CN115423485B (zh) 数据处理方法、装置及设备
CN115953706B (zh) 虚拟形象处理方法及装置
CN117034942B (zh) 一种命名实体识别方法、装置、设备及可读存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant