CN117610608A - 基于多阶段特征融合的知识蒸馏方法、设备及介质 - Google Patents

基于多阶段特征融合的知识蒸馏方法、设备及介质 Download PDF

Info

Publication number
CN117610608A
CN117610608A CN202311370731.0A CN202311370731A CN117610608A CN 117610608 A CN117610608 A CN 117610608A CN 202311370731 A CN202311370731 A CN 202311370731A CN 117610608 A CN117610608 A CN 117610608A
Authority
CN
China
Prior art keywords
fusion
stage
network model
student network
student
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202311370731.0A
Other languages
English (en)
Inventor
李刚
王坤
徐传运
何攀
阮子涵
吕鹏飞
蒋建忠
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Chongqing Linlue Technology Co ltd
Chongqing Normal University
Chongqing University of Technology
Original Assignee
Chongqing Linlue Technology Co ltd
Chongqing Normal University
Chongqing University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Chongqing Linlue Technology Co ltd, Chongqing Normal University, Chongqing University of Technology filed Critical Chongqing Linlue Technology Co ltd
Priority to CN202311370731.0A priority Critical patent/CN117610608A/zh
Publication of CN117610608A publication Critical patent/CN117610608A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/042Knowledge-based neural networks; Logical representations of neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/80Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
    • G06V10/806Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Software Systems (AREA)
  • General Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • Computational Linguistics (AREA)
  • Mathematical Physics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Databases & Information Systems (AREA)
  • Medical Informatics (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

本发明属于计算机技术领域,尤其涉及基于多阶段特征融合的知识蒸馏方法、设备及介质,首先获取原始数据集,并对原始数据集进行预处理;然后采用原始数据集训练教师网络模型,获取训练完成的教师网络模型;接着冻结训练完成的教师网络模型的预训练权重,使用多阶段特征融合框架、跨阶段特征融合注意力模块和相同阶段融合特征对比损失函数训练学生网络模型,生成训练完成的学生网络模型;最后运行训练完成的学生网络模型,并在推理阶段只保留学生网络架构。本发明能够解决由于教师网络与学生网络之间存在特征分布差异性,导致学生网络难以充分学习教师网络中间层特征隐含知识的问题。

Description

基于多阶段特征融合的知识蒸馏方法、设备及介质
技术领域
本发明属于计算机技术领域,尤其涉及基于多阶段特征融合的知识蒸馏方法、设备及介质。
背景技术
近年来,深度学习领域中的卷积神经网络极大地促进了计算机视觉的发展,并在分类识别、目标检测和目标分割等方面得到了广泛的应用。然而,受限于边缘计算设备的算力约束和显存限制,卷积网络大模型难以部署应用。如何平衡计算开销和模型性能仍是一个十分具有挑战性的问题,而知识蒸馏是一种有效解决方法。知识蒸馏通过教师教授学生的形式将隐含知识从教师网络大模型传递到学生网络小模型上,进而大幅提升学生模型性能。该方法简单有效,被广泛应用于卷积网络和视觉任务上。
常见的知识蒸馏方法一般分为两类,一类是基于软标签分类知识,该类方法使用不同的温度软化教师网络和学生网络输出的分类标签,由此通过缩减两者软化标签间的最终分类知识差异性,来提升学生网络的识别精度。另一类是基于中间层特征,一般而言,教师网络与学生网络的结构和学习过程有一定的相似性,学生网络可以学习教师网络在中间特征层中隐含的知识,取得更优的学习过程,实现自身的精度提升效果。
教师网络与学生网络在相同阶段之间的特征分布往往存在较大的区别,而且同一网络在不同阶段之间的特征分布也各有侧重点,深层特征注重概念信息,浅层特征注重纹理信息,这带来了特征知识分布差异性问题,导致学生网络难以直接学习教师网络的特征隐含知识。
发明内容
本发明所解决的技术问题在于提供一种基于多阶段特征融合的知识蒸馏方法、设备及介质,以解决由于教师网络与学生网络之间存在特征分布差异性,导致学生网络难以充分学习教师网络中间层特征隐含知识的问题。
本发明提供的基础方案:基于多阶段特征融合的知识蒸馏方法,包括:
S1:获取原始数据集,并对原始数据集进行预处理;
S2:采用原始数据集训练教师网络模型,获取训练完成的教师网络模型;
S3:冻结训练完成的教师网络模型的预训练权重,使用多阶段特征融合框架、跨阶段特征融合注意力模块和相同阶段融合特征对比损失函数训练学生网络模型,生成训练完成的学生网络模型;
S4:运行训练完成的学生网络模型,并在推理阶段只保留学生网络架构。
进一步,所述S3包括:
S3-1:冻结训练完成的教师网络模型的预训练权重;
S3-2:将教师网络模型和学生网络模型均使用多阶段特征融合框架,并构建学生网络特征融合和教师网络特征融合;
S3-3:通过跨阶段特征融合注意力模块对学生网络模型进行训练;
S3-4:构建相同阶段融合特征对比损失函数对S3-3训练完成的学生网络模型进行损失验证。
进一步,所述S3-2具体为:
定义教师网络模型为T,学生网络模型为S,教师网络模型T和学生网络模型S均包括n个特征输出阶段和n个对应的融合模块FFAi,i∈n,其对应的第i层特征分别为Ti和Si
设置首个融合模块具有一个输入入口,其余均有两个输入入口,且最后一个融合模块有一个输出出口,其余均有两个输出出口,记第i层融合输出特征为和/>
构建学生网络特征融合和教师网络特征融合,所述学生网络特征融合计算公式为:
所述教师网络特征融合的计算公式为:
进一步,所述S3-3具体为:
跨阶段特征融合注意力模块中,包括两个不同阶段特征I1和I2,且I1和I2的尺寸和通道数均不同;
通过卷积和归一化处理将输入特征I1的尺寸和通道数调整为与输入特征I2相一致,并相加得到初步融合特征I;
通过并联的通道注意力机制Ac和空间注意力机制As进行处理,将并联结果相加得到融合特征F;
再通过卷积和归一化处理,分别生成尺寸和通道数一般不同的两个输出特征F1和F2,并且融合特征模块FFA1输入特征只有I1,融合特征模块FFAn输出特征只有F1
进一步,所述跨阶段特征融合注意力模块的表达公式为:
(F1,F2)=As(S(I1)+I2)+Ac(S(I1)+I2)。
进一步,所述S3-4具体为:
构建相同阶段融合特征对比损失函数;
通过相同阶段融合特征对比损失函数分别将教师网络模型和学生网络模型的第i阶段融合特征对应的TFi和SFi按照预设的处理做Lmse相似度匹配;
结合真实标签与学生分类结果的交叉熵损失函数,以及权重调节超参数,构建完整损失函数,对学生网络模型进行损失验证;
所述按照预设的处理具体为:
不做处理,保留TFi和SFi
不改变特征空间尺寸,在通道上进行压缩处理,得到TFi 1和SFi 1
不改变通道数、在空间上进行压缩,得到TFi 2和SFi 2,结合权重调节超参数λ,构成n个阶段融合特征对比函数。
进一步,所述相同阶段融合特征对比损失函数的计算公式为:
Lscm=Lmse(TFi,SFi)+λLmse(TFi 1,SFi 1)+λLmse(TFi 2,SFi 2)
其中,Lscm表示相同阶段融合特征对比损失函数,λ表示权重调节超参数;
所述完整损失函数的计算公式为:
Ltotal=Lce+αLscm
其中,Ltotal表示完整损失函数,Lce表示交叉熵损失函数,α表示完整损失函数对应的权重调节超参数。
进一步,所述S4中在推理阶段只保留学生网络架构具体为:在学生网络模型推理阶段,剪去教师网络模型和多阶段特征融合框架,只保留学生网络架构部分。
一种电子设备,包括处理器和存储器,所述存储器中存储程序或指令,所述处理器通过调用所述存储器存储的程序或指令,执行如上所述的基于多阶段特征融合的知识蒸馏方法。
一种计算机可读存储介质,所述计算机可读存储介质存储程序或指令,所述程序或指令使计算机执行如上所述的基于多阶段特征融合的知识蒸馏方法。
本发明的原理及优点在于:在本申请中,通过多阶段特征融合框架,分别在教师网络和学生网络实现特征知识从浅层到深层的跨阶段知识传递,进而可以让学生网络的单一阶段从教师的不同阶段学习特征隐含知识,增强学生模型的泛化性和学习能力。通过跨阶段特征融合注意力模块可以实现相邻阶段特征之间的有机融合和有益知识增强,再搭配上相同阶段融合特征之间的空间和通道对比损失函数,可以让学生网络从通道和空间两个角度来学习教师网络的特征和对比两者之间的特征差异性,实现学生模型的进一步效果提升,并增强其模型泛化性。
附图说明
图1为本发明实施例的流程框图;
图2为本发明实施例的整体知识蒸馏方法结构图;
图3为本发明实施例的跨阶段特征融合注意力模块结构图;
图4为本发明实施例的相同阶段融合特征对比损失函数结构图;
图5为本发明实施例的框架的不同阶段组合对比图;
图6为本发明实施例的框架和模块消融实验对比图;
图7为本发明实施例的电子设备结构示意图。
具体实施方式
下面通过具体实施方式进一步详细说明:
说明书附图中的标记包括:电子设备400、处理器401、存储器402、输入装置403、输出装置404。
实施例基本如附图1所示:基于多阶段特征融合的知识蒸馏方法,包括:
S1:获取原始数据集,并对原始数据集进行预处理;
在本实施例中,获取的原始数据集为CIFAR100数据集,该数据集的图像初始尺寸为32,在进行预处理包括进行图像尺寸参数为32、填充参数为4的随机剪切,以及随机水平翻转和图片归一化的预处理。
S2:采用原始数据集训练教师网络模型,获取训练完成的教师网络模型;在本实施例中,教师网络模型包括resnet56、resnet110、resnet32×4、vgg13、WRN-40-2、ResNet50。训练教师网络模型用到的策略中,BatchSize为64,Epoch为240。训练教师网络模型用到的初始学习率为0.05,均采用梯度学习率方法,在150、180和210轮各乘以0.1。优化器选用随机梯度下降法,权重衰减为5e-4,动量为0.9。假设训练好的教师网络模型为T。
S3:冻结训练完成的教师网络模型的预训练权重,使用多阶段特征融合框架、跨阶段特征融合注意力模块和相同阶段融合特征对比损失函数训练学生网络模型,生成训练完成的学生网络模型;其中,学生网络模型包括resnet、VGG、WRN、ShuffleNet和MobileNet。训练学生网络模型用到的参数设置中,ShuffleNet和MobileNet的初始学习率为0.01,其他模型用到的初始学习率为0.05,其他参数设置与训练教师网络模型相同。假设训练好的学生网络模型为S。
S3包括:
S3-1:冻结训练完成的教师网络模型的预训练权重;
S3-2:将教师网络模型和学生网络模型均使用多阶段特征融合框架,并构建学生网络特征融合和教师网络特征融合;
S3-3:通过跨阶段特征融合注意力模块对学生网络模型进行训练;
S3-4:构建相同阶段融合特征对比损失函数对S3-3训练完成的学生网络模型进行损失验证。
具体的,首先在训练学生网络模型的过程中,冻结S2所得教师网络模型T的预训练权重,保证在训练过程中,教师网络模型的预训练权重不发生改变。
其次,如图2所示,在训练学生网络模型的过程中,教师网络模型和学生网络模型均使用多阶段特征融合框架,达到一种对称的网络架构,教师网络模型T和学生网络模型S均包括n个特征输出阶段和n个对应的融合模块FFAi,i∈n,其对应的第i层特征分别为Ti和Si
设置首个融合模块具有一个输入入口,其余均有两个输入入口,且最后一个融合模块有一个输出出口,其余均有两个输出出口,记第i层融合输出特征为Fi 1
构建学生网络特征融合和教师网络特征融合,所述学生网络特征融合计算公式为:
所述教师网络特征融合的计算公式为:
因此,本申请中通过多阶段特征融合框架实现了深层特征和浅层特征的有机融合,并且采用了对称的知识传递结构实现了教师网络模型和学生网络模型相同阶段融合特征之间的有效知识传递,可以有效提高学生网络模型对于教师网络模型有益特征知识的学习能力和对于原始图片的识别能力。
再如图3所示,在跨阶段特征融合注意力模块中,包括两个不同阶段特征I1和I2,且I1和I2的尺寸和通道数均不同;
通过卷积和归一化处理将输入特征I1的尺寸和通道数调整为与输入特征I2相一致,并相加得到初步融合特征I;
通过并联的通道注意力机制Ac和空间注意力机制As进行处理,将并联结果相加得到融合特征F;
再通过卷积和归一化处理,分别生成尺寸和通道数一般不同的两个输出特征F1和F2,并且融合特征模块FFA1输入特征只有I1,融合特征模块FFAn输出特征只有F1。所述跨阶段特征融合注意力模块的表达公式为:
(F1,F2)=As(S(I1)+I2)+Ac(S(I1)+I2)。
因此,跨阶段特征融合注意力模块实现了具有不同尺寸和通道数的相邻特征之间的有机融合和融合后特征知识的有效增强。通过卷积网络实现输入特征的尺度缩放,借由残差思想将两个不同尺寸的特征融合起来,实现相邻特征的融合,并且结合多阶段特征融合框架之后,实现了特征知识从浅层到深层的逐层传递,获得了学生网络模型进行有效知识学习的体系和方法。此外,对于融合特征,还用了并联的空间注意力机制和通道注意力机制来增强融合后特征,可以进一步加强输入特征的有机融合,并且让融合后的特征加强空间信息和通道信息,进一步提高学生网络模型的学习能力。
并且再如图4所示,构建相同阶段融合特征对比损失函数对S3-3训练完成的学生网络模型进行损失验证,具体为:
构建相同阶段融合特征对比损失函数;
通过相同阶段融合特征对比损失函数分别将教师网络模型和学生网络模型的第i阶段融合特征对应的TFi和SFi按照预设的处理做Lmse相似度匹配;所述按照预设的处理具体为:一是不做处理,保留TFi和SFi;二是不改变特征空间尺寸,在通道上进行压缩处理,得到/>和/>三是不改变通道数、在空间上进行压缩,得到TFi 2和SFi 2,结合权重调节超参数λ,构成n个阶段融合特征对比函数。相同阶段融合特征对比损失函数的计算公式为:
Lscm=Lmse(TFi,SFi)+λLmse(TFi 1,SFi 1)+λLmse(TFi 2,SFi 2)
其中,Lscm表示相同阶段融合特征对比损失函数,λ表示权重调节超参数;
此外,结合真实标签与学生分类结果的交叉熵损失函数,以及权重调节超参数,构建完整损失函数,对学生网络模型进行损失验证;完整损失函数的计算公式为:
Ltotal=Lce+αLscm
其中,Ltotal表示完整损失函数,Lce表示交叉熵损失函数,α表示完整损失函数对应的权重调节超参数。
因此,空间和通道对比损失函数的结构与特征融合注意力模块的空间和通道注意力机制相对应,用相近的学习方法和效果验证方式,可以有效检验所学知识的可行性,也可以进一步加强学生网络模型的学习能力和进一步缩减学生网络模型的学习方位,综合实现一种有效的学生网络模型的知识蒸馏学习方法。
对此,本申请通过表1记录了多阶段特征融合框架在使用不同阶段组合下的实验结果,这里基本保留了学生网络的融合模块,只对教师网络的学习阶段做调整。组合变化如图5所示,可以看出,与其他阶段性组合相比,本实验提供的特征融合学习框架是显著有效的。
表1不同阶段组合下的实验数据表
接着本申请通过表2对比了多阶段特征融合知识蒸馏方法的框架和模块消融实验,结构变化如图6所示,MS代表不使用融合框架下的多阶段直接对比,MSF表示使用融合框架下的多阶段直接对比,FFA表示使用跨阶段融合特征注意力模块,SCM表示使用相同阶段的空间和通道对比损失函数,实验证明了框架本身和模块组合的有效性,均可以有效提升学生模型识别效果,并且本发明的消融实验均取得了精度提升的实验结果,证明了本发明的可行性。
表2框架和模块消融实验数据表
S4:运行训练完成的学生网络模型,并在推理阶段只保留学生网络架构,本申请中,在推理阶段,剪去教师网络和多阶段特征融合框架,只保留学生网络架构部分,保证在不添加额外参数和不更改自身结构的基础上,提高学生网络的识别精度。表3和表4中实验结果表明,本发明的多阶段特征融合知识蒸馏方法MSFF适用范围广,在多种轻量化网络模型上取得了有竞争力的精度提升效果,让学生网络从教师网络中学习到了许多有效知识,实验结果显示,本发明与CRD和OFD基本在同一水平,略低于ReviewKD的SOTA结果。
表3在CIFAR100数据集上,同类型模型蒸馏效果
表4在CIFAR100数据集上,不同类型模型蒸馏效果
下表5则是进一步验证了本发明的实用性和泛用性,本表记录的实验是在CIFAR100上,教师网络均为WRN-40-2时,使用不同的知识蒸馏方法训练得到学生网络WRN-40-1,并将得到的不同WRN-40-1模型迁移到STL-10和TinyImageNet数据集检查精度。从表格中的数据可以观察到,与基准精度和其他知识蒸馏方法相比,本发明提供的知识蒸馏方法取得了有价值的精度提升效果,进一步验证了本方法的有效性和泛化性。
表5教师和学生组合为WRN-40-2和WRN-40-1的迁移实验效果
此外,还包括一种电子设备,如图7所示,电子设备400包括一个或多个处理器401和存储器402。
处理器401可以是中央处理单元(CPU)或者具有数据处理能力和/或指令执行能力的其他形式的处理单元,并且可以控制电子设备400中的其他组件以执行期望的功能。
存储器402可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存等。在所述计算机可读存储介质上可以存储一个或多个计算机程序指令,处理器401可以运行所述程序指令,以实现上文所说明的本发明任意实施例的基于多阶段特征融合的知识蒸馏方法以及/或者其他期望的功能。在所述计算机可读存储介质中还可以存储诸如初始外参、阈值等各种内容。
在一个示例中,电子设备400还可以包括:输入装置403和输出装置404,这些组件通过总线系统和/或其他形式的连接机构(未示出)互连。该输入装置403可以包括例如键盘、鼠标等等。该输出装置404可以向外部输出各种信息,包括预警提示信息、制动力度等。该输出装置404可以包括例如显示器、扬声器、打印机、以及通信网络及其所连接的远程输出设备等等。
当然,为了简化,图7中仅示出了该电子设备400中与本发明有关的组件中的一些,省略了诸如总线、输入/输出接口等等的组件。除此之外,根据具体应用情况,电子设备400还可以包括任何其他适当的组件。
除了上述方法和设备以外,本发明的实施例还可以是计算机程序产品,其包括计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本发明任意实施例所提供的基于多阶段特征融合的知识蒸馏方法的步骤。
所述计算机程序产品可以以一种或多种程序设计语言的任意组合来编写用于执行本发明实施例操作的程序代码,所述程序设计语言包括面向对象的程序设计语言,诸如Java、C++等,还包括常规的过程式程序设计语言,诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。
此外,本发明的实施例还可以是计算机可读存储介质,其上存储有计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本发明任意实施例所提供的基于多阶段特征融合的知识蒸馏方法的步骤。
所述计算机可读存储介质可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以包括但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
以上的仅是本发明的实施例,方案中公知的具体结构及特性等常识在此未作过多描述,所属领域普通技术人员知晓申请日或者优先权日之前发明所属技术领域所有的普通技术知识,能够获知该领域中所有的现有技术,并且具有应用该日期之前常规实验手段的能力,所属领域普通技术人员可以在本申请给出的启示下,结合自身能力完善并实施本方案,一些典型的公知结构或者公知方法不应当成为所属领域普通技术人员实施本申请的障碍。应当指出,对于本领域的技术人员来说,在不脱离本发明结构的前提下,还可以作出若干变形和改进,这些也应该视为本发明的保护范围,这些都不会影响本发明实施的效果和专利的实用性。本申请要求的保护范围应当以其权利要求的内容为准,说明书中的具体实施方式等记载可以用于解释权利要求的内容。

Claims (10)

1.基于多阶段特征融合的知识蒸馏方法,其特征在于:包括:
S1:获取原始数据集,并对原始数据集进行预处理;
S2:采用原始数据集训练教师网络模型,获取训练完成的教师网络模型;
S3:冻结训练完成的教师网络模型的预训练权重,使用多阶段特征融合框架、跨阶段特征融合注意力模块和相同阶段融合特征对比损失函数训练学生网络模型,生成训练完成的学生网络模型;
S4:运行训练完成的学生网络模型,并在推理阶段只保留学生网络架构。
2.根据权利要求1所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述S3包括:
S3-1:冻结训练完成的教师网络模型的预训练权重;
S3-2:将教师网络模型和学生网络模型均使用多阶段特征融合框架,并构建学生网络特征融合和教师网络特征融合;
S3-3:通过跨阶段特征融合注意力模块对学生网络模型进行训练;
S3-4:构建相同阶段融合特征对比损失函数对S3-3训练完成的学生网络模型进行损失验证。
3.根据权利要求2所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述S3-2具体为:
定义教师网络模型为T,学生网络模型为S,教师网络模型T和学生网络模型S均包括n个特征输出阶段和n个对应的融合模块FFAi,i∈n,其对应的第i层特征分别为Ti和Si
设置首个融合模块具有一个输入入口,其余均有两个输入入口,且最后一个融合模块有一个输出出口,其余均有两个输出出口,记第i层融合输出特征为和/>
构建学生网络特征融合和教师网络特征融合,所述学生网络特征融合计算公式为:
所述教师网络特征融合的计算公式为:
4.根据权利要求3所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述S3-3具体为:
跨阶段特征融合注意力模块中,包括两个不同阶段特征I1和I2,且I1和I2的尺寸和通道数均不同;
通过卷积和归一化处理将输入特征I1的尺寸和通道数调整为与输入特征I2相一致,并相加得到初步融合特征I;
通过并联的通道注意力机制Ac和空间注意力机制As进行处理,将并联结果相加得到融合特征F;
再通过卷积和归一化处理,分别生成尺寸和通道数一般不同的两个输出特征F1和F2,并且融合特征模块FFA1输入特征只有I1,融合特征模块FFAn输出特征只有F1
5.根据权利要求4所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述跨阶段特征融合注意力模块的表达公式为:
(F1,F2)=As(S(I1)+I2)+Ac(S(I1)+I2)。
6.根据权利要求5所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述S3-4具体为:
构建相同阶段融合特征对比损失函数;
通过相同阶段融合特征对比损失函数分别将教师网络模型和学生网络模型的第i阶段融合特征对应的TFi和SFi按照预设的处理做Lmse相似度匹配;
结合真实标签与学生分类结果的交叉熵损失函数,以及权重调节超参数,构建完整损失函数,对学生网络模型进行损失验证;
所述按照预设的处理具体为:
不做处理,保留TFi和SFi
不改变特征空间尺寸,在通道上进行压缩处理,得到和/>
不改变通道数、在空间上进行压缩,得到TFi 2和SFi 2,结合权重调节超参数λ,构成n个阶段融合特征对比函数。
7.根据权利要求6所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述相同阶段融合特征对比损失函数的计算公式为:
Lscm=Lmse(TFi,SFi)+λLmse(TFi 1,SFi 1)+λLmse(TFi 2,SFi 2)
其中,Lscm表示相同阶段融合特征对比损失函数,λ表示权重调节超参数;
所述完整损失函数的计算公式为:
Ltotal=Lce+αLscm
其中,Ltotal表示完整损失函数,Lce表示交叉熵损失函数,α表示完整损失函数对应的权重调节超参数。
8.根据权利要求7所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述S4中在推理阶段只保留学生网络架构具体为:在学生网络模型推理阶段,剪去教师网络模型和多阶段特征融合框架,只保留学生网络架构部分。
9.一种电子设备,其特征在于:包括处理器和存储器,所述存储器中存储程序或指令,所述处理器通过调用所述存储器存储的程序或指令,执行如上权利要求1-7任一项所述的基于多阶段特征融合的知识蒸馏方法。
10.一种计算机可读存储介质,其特征在于:所述计算机可读存储介质存储程序或指令,所述程序或指令使计算机执行如上权利要求1-7任一项所述的基于多阶段特征融合的知识蒸馏方法。
CN202311370731.0A 2023-10-20 2023-10-20 基于多阶段特征融合的知识蒸馏方法、设备及介质 Pending CN117610608A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311370731.0A CN117610608A (zh) 2023-10-20 2023-10-20 基于多阶段特征融合的知识蒸馏方法、设备及介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311370731.0A CN117610608A (zh) 2023-10-20 2023-10-20 基于多阶段特征融合的知识蒸馏方法、设备及介质

Publications (1)

Publication Number Publication Date
CN117610608A true CN117610608A (zh) 2024-02-27

Family

ID=89946919

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311370731.0A Pending CN117610608A (zh) 2023-10-20 2023-10-20 基于多阶段特征融合的知识蒸馏方法、设备及介质

Country Status (1)

Country Link
CN (1) CN117610608A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117831138A (zh) * 2024-03-05 2024-04-05 天津科技大学 基于三阶知识蒸馏的多模态生物特征识别方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117831138A (zh) * 2024-03-05 2024-04-05 天津科技大学 基于三阶知识蒸馏的多模态生物特征识别方法
CN117831138B (zh) * 2024-03-05 2024-05-24 天津科技大学 基于三阶知识蒸馏的多模态生物特征识别方法

Similar Documents

Publication Publication Date Title
US20220180202A1 (en) Text processing model training method, and text processing method and apparatus
US20220335711A1 (en) Method for generating pre-trained model, electronic device and storage medium
WO2022007823A1 (zh) 一种文本数据处理方法及装置
CN111368993B (zh) 一种数据处理方法及相关设备
CN111062489A (zh) 一种基于知识蒸馏的多语言模型压缩方法、装置
WO2022057776A1 (zh) 一种模型压缩方法及装置
WO2022068627A1 (zh) 一种数据处理方法及相关设备
GB2571825A (en) Semantic class localization digital environment
CN110234018B (zh) 多媒体内容描述生成方法、训练方法、装置、设备及介质
EP4336378A1 (en) Data processing method and related device
CN109271516B (zh) 一种知识图谱中实体类型分类方法及系统
CN108536735B (zh) 基于多通道自编码器的多模态词汇表示方法与系统
CN115221846A (zh) 一种数据处理方法及相关设备
CN113761868B (zh) 文本处理方法、装置、电子设备及可读存储介质
KR102635800B1 (ko) 신경망 모델의 사전 훈련 방법, 장치, 전자 기기 및 매체
CN117610608A (zh) 基于多阶段特征融合的知识蒸馏方法、设备及介质
US20220004721A1 (en) Translation quality detection method and apparatus, machine translation system, and storage medium
US20240152770A1 (en) Neural network search method and related device
CN112257860A (zh) 基于模型压缩的模型生成
WO2023173552A1 (zh) 目标检测模型的建立方法、应用方法、设备、装置及介质
CN115861995A (zh) 一种视觉问答方法、装置及电子设备和存储介质
CN114925320B (zh) 一种数据处理方法及相关装置
CN115795025A (zh) 一种摘要生成方法及其相关设备
Zhang et al. A lightweight multi-dimension dynamic convolutional network for real-time semantic segmentation
CN112784003A (zh) 训练语句复述模型的方法、语句复述方法及其装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination