CN117253123A - 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法 - Google Patents
一种基于中间层特征辅助模块融合匹配的知识蒸馏方法 Download PDFInfo
- Publication number
- CN117253123A CN117253123A CN202311012546.4A CN202311012546A CN117253123A CN 117253123 A CN117253123 A CN 117253123A CN 202311012546 A CN202311012546 A CN 202311012546A CN 117253123 A CN117253123 A CN 117253123A
- Authority
- CN
- China
- Prior art keywords
- network
- student
- teacher
- module
- modules
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 230000004927 fusion Effects 0.000 title claims abstract description 77
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 25
- 238000000034 method Methods 0.000 title claims abstract description 21
- 238000012549 training Methods 0.000 claims abstract description 63
- 238000004821 distillation Methods 0.000 claims abstract description 15
- 230000006870 function Effects 0.000 claims description 41
- 125000000524 functional group Chemical group 0.000 claims description 23
- 238000012360 testing method Methods 0.000 claims description 16
- 238000012545 processing Methods 0.000 claims description 11
- 238000000605 extraction Methods 0.000 claims description 4
- 230000010354 integration Effects 0.000 claims description 3
- 238000010606 normalization Methods 0.000 claims description 3
- 238000007500 overflow downdraw method Methods 0.000 claims description 3
- 239000010410 layer Substances 0.000 claims 8
- 239000011229 interlayer Substances 0.000 claims 3
- 238000012512 characterization method Methods 0.000 abstract description 3
- 238000013508 migration Methods 0.000 abstract description 3
- 230000005012 migration Effects 0.000 abstract description 3
- 238000005516 engineering process Methods 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 238000010025 steaming Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/86—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using syntactic or structural representations of the image or video pattern, e.g. symbolic string recognition; using graph matching
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/44—Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Abstract
本发明公开了一种基于中间层辅助特征模块融合匹配的知识蒸馏方法,将教师网络和学生网络划分成若干个模块,利用所划分的模块构建分支网络和辅助训练模块,计算其辅助训练损失;再构建特征融合模块并利用注意力机制生成不同的融合权值对辅助训练模块中提取到的特征根据制定的融合策略进行特征融合,计算其特征融合损失;最后将利用总的蒸馏损失促使学生网络和教师网络进行充分地信息交流,并且辅助学生网络更好的分模块矫正参数。本发明解决了知识网络中存在的信息利用不足、信息交流不对等以及信息冗余问题,提升了学生模型对综合信息的学习和表征能力,提高了特征迁移的可靠性,增强了模型的泛化性和鲁棒性。
Description
技术领域
本发明涉及计算机视觉领域,具体涉及一种基于中间层特征辅助模块融合匹配的知识蒸馏方法。
背景技术
对于一般的神经网络模型,复杂模型往往是单个宽而深的复杂模型或若干个基础模型的集合,具有较好的收敛能力和任务处理性能。相反地,简单模型的基本结构单一并且网络模型呈现窄而浅的特点,其表征能力有限。知识蒸馏技术利用复杂模型处理任务能力强和简单模型存储量小的特点,对模型的知识进行迁移来完成模型的压缩处理。知识蒸留技术在处理同样任务时具有提升模型精度、降低模型时延,压缩网络参数的特点。
Seyed Iman Mirzadeh在《Improved Knowledge Distillation via TeacherAssistant》一文中采用引入中等规模的网络(教师助理)的方案来弥合学生模型和教师模型之间的差距,在一定程度上解决了由于教师模型和学生模型差异过大带来的问题,但是该方法的教师助理选择会耗费大量的实验和计算资源,且没有从根本上解决学生网络模型的表达能力有限这一问题。除此之外,按照大多数模型的建模逻辑,学生网络架构和教师网络架构的一致性对知识迁移效果的影响是至关重要的,教师不恰当的表征学习往往会导致知识蒸馏的次优性。
发明内容
本发明的目的在于提供一种基于中间层特征辅助模块融合匹配的知识蒸馏方法,充分挖掘出中间层特征的丰富的信息并加以利用,构建出辅助训练模块和迭代融合模块,解决了以往异构知识蒸馏网络中存在的信息利用不足和信息交流不对等的问题,既保证了不需要学生网络先验知识的便捷性和可以直接用于各种网络的广泛性,又提高了特征迁移的可靠性。
实现本发明目的的技术解决方案为:一种中间层特征辅助模块融合匹配的知识蒸馏方法,包括以下步骤:
步骤S1、在CIFAR-100数据集中随机采集K幅带标签的图像,10000<K≤60000,对上述K幅图像进行归一化处理,将像素大小统一为h×w,其中,h为图像高度,w为图像宽度;将统一尺寸后的图像按照5∶1的比例随机划分为训练数据集和测试数据集,对训练数据集进行数据增强构成教师-学生网络训练数据集,利用教师-学生网络训练数据集对教师网络进行预训练,得到教师主干网络,转入步骤S2。
步骤S2、根据卷积层的深度和特征图的大小,将教师主干网络划分为n个教师模块,学生主干网络划分为n个学生模块,转入步骤S3。
步骤S3、利用教师模块构建学生分支网络,利用学生模块构建教师分支网络,再利用分支网络中包含的子模块构建辅助训练模块,转入步骤S4。
步骤S4、提取步骤S2中各主干网络的输出特征以及步骤S3中辅助训练模块中各分支网络的输出特征,利用教师主干网络的输出特征和学生主干网络的输出特征计算传统蒸馏损失,利用辅助训练模块中各分支网络的输出特征与相应的主干网络的输出特征计算辅助训练损失,转入步骤S5。
步骤S5、制定分组融合策略:
利用步骤S3中辅助训练模块中功能相对应的教师分支网络的子模块和学生分支网络的子模块共同构成n-1个功能组,转入步骤S6。
步骤S6、构建特征融合模块,并利用步骤S5中n-1个功能组经过特征融合模块融合后的特征分别与学生主干网络中功能相对应的n-1个学生模块的输出特征计算特征融合损失,转入步骤S7。
步骤S7、将传统蒸馏损失、辅助训练损失以及特征融合损失加权求和,得到总的损失函数,并以此对学生网络的网络参数进行更新,最终获得训练好的学生网络,转入步骤S8。
步骤S8、将测试数据集输入到训练好的学生网络,输出测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。
与现有技术相比,本发明优点在于:
(1)构建了辅助训练模块。该模块使教师网络提供学生网络易于学习的可转移知识,并且辅助学生网络更好的分批矫正模块参数,促进教师网络模块与学生网络模块间的对等信息交流。
(2)构建了特征融合模块并制定了相应的分组融合策略,采用特征迭代融合的方法来整合特征信息、提供可信的特征指导学生网络模型训练。该模块解决了蒸馏信息冗余的问题,具有很强的信息综合能力,能很好地协调多种输入信息关系,进一步优化异构知识蒸馏网络使其实现先进的性能。
(3)利用了注意力机制将特征融合模块的不同通道设置不同的注意力卷积网络生成不同模块的的融合权值,将细节信息聚合,提取更为全面的信息且自适应的突出重要信息。
附图说明
图1为本发明基于中间层特征辅助模块融合匹配的知识蒸馏方法的模型图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,下面对本发明实施方式作进一步地详细描述。
结合图1,一种基于中间层特征辅助模块融合匹配的知识蒸馏方法,包括以下步骤:
步骤S1、在CIFAR-100数据集中随机采集K幅带标签的图像,10000<K≤60000,对上述K幅图像进行归一化处理,将像素大小统一为h×w,其中,h为图像高度,w为图像宽度;将统一尺寸后的图像按照5∶1的比例随机划分为训练数据集和测试数据集,对训练数据集进行数据增强构成教师-学生网络训练数据集,利用教师-学生网络训练数据集对教师网络进行预训练,得到教师主干网络,转入步骤S2。
步骤S2、根据卷积层的深度和特征图的大小,将教师主干网络划分为n个教师模块,学生主干网络划分为n个学生模块,转入步骤S3。
步骤S3、利用教师模块构建学生分支网络,利用学生模块构建教师分支网络,再利用分支网络中包含的子模块构建辅助训练模块,具体如下:
将步骤S2中教师主干网络和学生主干网络划分的n个模块分别用集合表示;教师模块的集合用表示,T表示教师主干网络,/>表示教师主干网络的第i个教师模块;学生模块的集合用/>表示,S表示学生网络,/>表示学生主干网络的第i个学生模块;然后在教师模块/>后延伸出分支,即依次接入n-i个学生模块/>以构成第v个教师分支网络分支,将这n-i个学生模块称作教师分支网络的子模块,将这n-i个子模块的集合记为/>其中/>表示该教师分支网络的第u个子模块;同理,在学生模块/>后延伸出分支,依次接入n-i个教师模块/> 以构成第v个学生分支网络,将这n-i个教师模块称作学生分支网络的子模块,将这n-i个子模块的集合记为/>其中/>表示该学生分支网络的第u个子模块;最多共有n-1个学生网络分支和n-1个教师网络分支,即1≤v≤nv1,其中,每条教师分支网络中的n-i个学生模块称教师分支网络的子模块,每条学生分支网络中的n-i个教师模块称为学生分支网络的子模块,即1≤u≤n-i;最后将学生分支网络的子模块集合BT1,BT2,…,BTv,…,BTn-1和教师分支网络的子模块集合BS1,BS2,…,BSv,…,BSn-1共同构成辅助训练模块Baux={BT1,BT2,...,BTv,...,BTn-1;BS1,BS2,...,BSv,…,BSn-1},该模块使教师网络提供学生网络易于学习的可转移知识,并且辅助学生网络更好的分批矫正模块参数,促进教师网络模块与学生网络模块间的对等信息交流。
步骤S4、提取步骤S2中各主干网络的输出特征以及步骤S3中辅助训练模块中各分支网络的输出特征,利用教师主干网络的输出特征和学生主干网络的输出特征计算传统蒸馏损失,利用辅助训练模块中各分支网络的输出特征与相应的主干网络的输出特征计算辅助训练损失,具体如下:
首先将第v条学生分支网络中的第u个子模块的输出特征和第v条教师分支网络中的第u个子模块的输出特征/>分别表示为:
其中,表示第v条学生分支网络的第u个子模块的特征提取函数,/>表示第v条学生分支网络的第u个子模块;/>表示第v条教师分支网络的第u个子模块的特征提取函数,/>表示第v条教师分支网络的第u个子模块,1≤v≤n-1,1≤u≤n-i。
再将教师主干网络的输出特征经过softmax函数处理后的输出定义为PT,学生主干网络的输出特征/>经过softmax函数处理后的输出定义为PS:
式中t表示温度的超参数。
利用PT、PS计算出教师主干网络和学生主干网络的输出层特征间的知识蒸馏损失即传统的知识蒸馏损失Lcla:
Lcla=KL(PT||PS)
再将第v条教师分支网络的输出特征经softmax函数处理后的类概率定义为将第v条学生分支网络的输出特征经softmax函数处理后的类概率定义为/>
利用PT计算出教师分支网络和教师主干网络的输出特征间的KL损失LTv,利用PS计算出学生分支网络和学生主干网络的KL损失LSv:
最后将辅助训练模块中各分支网络输出特征与主干网络的输出特征之间的辅助训练损失Laux重建为:
Laux=LTv+LSv。
步骤S5、制定分组融合策略,利用步骤S3中辅助训练模块中功能相对应的教师分支网络的子模块和学生分支网络的子模块共同构成n-1个功能组,具体如下:
按照相同位置的模块承担相同功能这个规则,利用第1个教师分支网络的第1个子模块的输出特征/>和第1个学生分支网络的第1个子模块/>的输出特征/>共同建立为第一个功能组/>利用第1个教师分支网络的第2个子模块/>的输出/>第2个教师分支网络的第1个子模块/>的输出特征/>第1个学生分支网络的第2个子模块的输出特征/>以及第2个学生分支网络的第1个子模块的输出特征/>共同建立为第二个功能组/>......;依次取出所有教师分支网络和学生分支网络中所有的子模块,将其中执行相同功能的子模块的输出特征划分为一组,直至建立出第n-1个功能组/>将所有功能组的集合定义为G={G1,G2,...,Gn-1},1≤v≤n-1,1≤u≤n-i。
步骤S6、构建特征融合模块,并利用步骤S5中n-1个功能组经过特征融合模块融合后的特征分别与学生主干网络中功能相对应的n-1个学生模块的输出特征计算特征融合损失,具体如下:
首先由3个大小为1×1、步长为1的卷积层和一次concat操作构成特征融合模块,同时利用注意力机制将特征融合模块的不同通道设置不同的注意力卷积网络生成不同的融合权值将细节信息聚合,提取更为全面的信息且自适应的突出重要信息;在此特征融合模块中采用特征迭代融合的方法,解决了蒸馏信息冗余的问题,具有很强的信息综合能力,能很好地协调多种输入信息关系,进一步优化异构知识蒸馏网络使其实现先进的性能,具体如下:
每两个特征根据不同的融合权值进行一次融合,再将得到的融合特征与下一个特征进行融合,如此逐次进行迭代融合直至遍历功能组中的所有元素;
再将特征融合模块的融合函数定义为fm,将第k个功能组Gk经过特征融合模块的输出特征表示为
其中,1≤k≤n-1。
将学生主干网络划分的学生模块集合的除去第一个学生模块后的n-1个学生模块的输出特征集合定义为/>利用L2归一化损失函数计算功能组经过特征融合模块后的输出特征/>和特征集合FSO中的输出特征/>之间的特征融合损失Lfuse:
步骤S7、将传统蒸馏损失Lcla、辅助训练损失Laux以及特征融合损失Lfuse加权求和,得到总的损失函数Ltotality,并以此对学生网络的网络参数进行更新,最终获得训练好的学生网络,具体如下:
Ltotality=λ1Lcla+λ2Laux+λ3Lfuse
其中,λ1为传统知识蒸馏损失的权重超参数,λ2为辅助训练损失的权重超参数,λ3为特征融合损失函数的权重超参数。
步骤S8、将测试数据集输入到训练好的学生网络,输出测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。
实施例1
本发明所述的一种基于中间层特征辅助模块融合匹配的知识蒸馏方法,步骤如下:
步骤S1、在CIFAR-100数据集中随机采集60000幅带标签的图像,对这60000幅图像进行归一化处理,将像素大小统一为32×32,将统一尺寸后的图像按照5∶1的比例随机划分为训练数据集和测试数据集,对训练数据集进行数据增强构成教师-学生网络训练数据集,利用教师-学生网络训练数据集对教师网络进行预训练,得到教师网络,其中数据增强操作包括图像缩放和随机翻转,图像缩放比例按照原始图像的10%向内缩放和向外缩放,随机翻转的角度在-20°到20°,图像类别数量为100类。
步骤S2、根据卷积层的深度和特征图的大小,将教师主干网络和学生主干网络各自划分为4个模块,转入步骤S3。
步骤S3、利用步骤S2中的教师模块构建3条学生分支网络,利用步骤S2中的学生模块构建3条教师分支网络;再利用这6条分支网络包含的模块共同构成辅助训练模块,具体如下:
将步骤S2中教师主干网络和学生主干网络划分的4个模块分别用集合表示;教师模块的集合用表示,T表示教师主干网络,/>表示教师主干网络的第i个教师模块;学生模块的集合用/>表示,S表示学生网络,/>表示学生主干网络的第i个学生模块;然后在教师模块/>后延伸出分支,即依次接入3个学生模块/>以构成第1个教师分支网络分支,其子模块集合/> 在教师模块/>后延伸出分支,即依次接入2个学生模块/>以构成第2个教师分支网络分支,其子模块集合在教师模块/>后延伸出分支,即依次接入1个学生模块/>以构成第3个教师分支网络分支,其子模块集合/>同理,在学生模块/>后延伸出分支,即依次接入3个教师模块/>以构成第1个学生网络分支,其子模块集合/> 在学生模块/>后延伸出分支,即依次接入2个教师模块/>以构成第2个学生分支网络分支,其子模块集合/>在学生模块/>后延伸出分支,即依次接入2个教师模块/>以构成第3个学生分支网络分支,其子模块集合/>最后将学生分支网络的子模块集合和教师分支网络的子模块集合共同构成辅助训练模块Baux={BT1,BT2,BT3,;BS1,BS2,BS3}。
步骤S4、提取步骤S2中各主干网络的输出特征和步骤S3中辅助训练模块中各分支网络的输出特征,利用预训练教师主干网络和学生主干网络的输出特征计算传统蒸馏损失,利用辅助训练模块中各分支网络的输出特征与相应的主干网络的输出特征计算辅助训练损失,转入步骤S5。
步骤S5、制定分组融合策略:
利用步骤S3中辅助训练模块中功能相对应的教师分支网络的子模块和学生分支网络的子模块共同构成3个功能组,具体如下:
按照相同位置的模块承担相同功能这个规则,利用第1个教师分支网络的第1个子模块的输出特征/>和第1个学生分支网络的第1个子模块/>的输出特征/>共同建立为第一个功能组/>利用第1个教师分支网络的第2个子模块/>的输出/>第2个教师分支网络的第1个子模块/>的输出特征/>第1个学生分支网络的第2个子模块的输出特征/>以及第2个学生分支网络的第1个子模块的输出特征/>共同建立为第二个功能组/>以此类推建立出第3个功能组将所有功能组的集合定义为G={G1,G2,G3}。
步骤S6、构建特征融合模块,并利用步骤S5中3个功能组经过特征融合模块融合后的特征分别与学生主干网络功能相对应的3个模块的输出特征计算特征融合损失,转入步骤S7。
步骤S7、将传统蒸馏损失、辅助训练损失以及特征融合损失加权求和,得到总的损失函数,并以此对学生网络的网络参数进行更新,最终获得训练好的学生网络,转入步骤S8。
步骤S8、将测试数据集输入到训练好的学生网络,输出测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。
Claims (8)
1.一种基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于,步骤如下:
步骤S1、在CIFAR-100数据集中随机采集K幅带标签的图像,10000<K≤60000,对上述K幅图像进行归一化处理,将像素大小统一为h×w,其中,h为图像高度,w为图像宽度;将统一尺寸后的图像按照5:1的比例随机划分为训练数据集和测试数据集,对训练数据集进行数据增强构成教师—学生网络训练数据集,利用教师—学生网络训练数据集对教师网络进行预训练,得到教师主干网络,转入步骤S2;
步骤S2、根据卷积层的深度和特征图的大小,将教师主干网络划分为n个教师模块,学生主干网络划分为n个学生模块,转入步骤S3;
步骤S3、利用教师模块构建学生分支网络,利用学生模块构建教师分支网络,再利用分支网络中包含的子模块构建辅助训练模块,转入步骤S4;
步骤S4、提取步骤S2中各主干网络的输出特征以及步骤S3中辅助训练模块中各分支网络的输出特征,利用教师主干网络的输出特征和学生主干网络的输出特征计算传统蒸馏损失,利用辅助训练模块中各分支网络的输出特征与相应的主干网络的输出特征计算辅助训练损失,转入步骤S5;
步骤S5、制定分组融合策略:
利用步骤S3中辅助训练模块中功能相对应的教师分支网络的子模块和学生分支网络的子模块共同构成n-1个功能组,转入步骤S6;
步骤S6、构建特征融合模块,并利用步骤S5中n-1个功能组经过特征融合模块融合后的特征分别与学生主干网络中功能相对应的n-1个学生模块的输出特征计算特征融合损失,转入步骤S7;
步骤S7、将传统蒸馏损失、辅助训练损失以及特征融合损失加权求和,得到总的损失函数,并以此对学生网络的网络参数进行更新,最终获得训练好的学生网络,转入步骤S8;
步骤S8、将测试数据集输入到训练好的学生网络,输出测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。
2.根据权利要求1所述的基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于,步骤S3中,利用教师模块构建学生分支网络,利用学生模块构建教师分支网络,再利用分支网络中包含的子模块构建辅助训练模块,具体如下:
将步骤S2中教师主干网络和学生主干网络划分的n个模块分别用集合表示;教师模块的集合用表示,T表示预训练教师主干网络,/>表示教师主干网络的第i个教师模块;学生模块的集合用/>表示,S表示学生网络,/>表示学生主干网络的第i个学生模块;然后在教师模块/>后延伸出分支,即依次接入n-i个学生模块/>以构成第v条教师分支网络分支,将这n-i个学生模块称作该条教师分支网络的子模块,将第v条教师分支网络分支中的n-i个子模块的集合记为/>其中/>表示该教师分支网络的第u个子模块;同理,在学生模块/>后延伸出分支,依次接入n-i个教师模块/>以构成第v条学生分支网络,将这n-i个教师模块称作该条学生分支网络的子模块,将这第v条学生分支网络中的n-i个子模块的集合记为其中/>表示该学生分支网络的第u个子模块;最多共有n-1个学生网络分支和n-1个教师网络分支,即1≤v≤n-1且1≤u≤n-i;最后将所有学生分支网络的子模块集合BT1,BT2,...,BTv,…,BTn-1和所有教师分支网络的子模块集合BS1,BS2,...,BSv,…,BSn-1共同构成辅助训练模块Baux={BT1,BT2,...,BTv,…,BTn-1;BS1,BS2,...,BSv,…,BSn-1}。
3.根据权利要求2所述的基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于,步骤S4中,提取步骤S2中各主干网络的输出特征以及步骤S3中辅助训练模块中各分支网络的输出特征,利用教师主干网络的输出特征和学生主干网络的输出特征计算传统蒸馏损失,利用辅助训练模块中各分支网络的输出特征与相应的主干网络的输出特征计算辅助训练损失,具体如下:
首先将第v条学生分支网络中的第u个子模块的输出特征和第v条教师分支网络中的第u个子模块的输出特征/>分别表示为:
其中,表示第v条学生分支网络的第u个子模块的特征提取函数,/>表示第v条学生分支网络的第u个子模块;/>表示第v条教师分支网络的第u个子模块的特征提取函数,/>表示第v条教师分支网络的第u个子模块,1≤v≤n-1,1≤u≤n-i;
再将教师主干网络的输出特征经过softmax函数处理后的输出定义为PT,学生主干网络的输出特征/>经过softmax函数处理后的输出定义为PS:
式中t表示温度的超参数;
利用PT、PS计算出教师主干网络和学生主干网络的输出层特征间的知识蒸馏损失即传统的知识蒸馏损失Lcla:
Lcla=KL(PT||PS)
最后将第v条教师分支网络的输出特征经softmax函数处理后的类概率定义为将第v条学生分支网络的输出特征经softmax函数处理后的类概率定义为/>
利用PT计算出教师分支网络和教师主干网络的输出特征间的KL损失LTv,利用/>PS计算出学生分支网络和学生主干网络的KL损失LSv:
最后将辅助训练模块中各分支网络输出特征与主干网络的输出特征之间的辅助训练损失Laux重建为:
Laux=LTv+LSv。
4.根据权利要求3所述的基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于,步骤S5中制定分组融合策略,利用步骤S3中辅助训练模块中功能相对应的教师分支网络的子模块和学生分支网络的子模块共同构成n-1个功能组,具体分组策略如下:
按照相同位置的模块承担相同功能这个规则,利用第1个教师分支网络的第1个子模块的输出特征/>和第1个学生分支网络的第1个子模块/>的输出特征/>共同建立第一个功能组/>利用第1个教师分支网络的第2个子模块/>的输出/>第2个教师分支网络的第1个子模块/>的输出特征/>第1个学生分支网络的第2个子模块/>的输出特征/>以及第2个学生分支网络的第1个子模块的输出特征/>共同建立为第二个功能组……;依次取出所有教师分支网络和学生分支网络中所有的子模块,将其中执行相同功能的子模块的输出特征划分为一组,直至建立出第n-1个功能组将所有功能组的集合定义为G={G1,G2,…,Gn-1},1≤v≤n-1,1≤u≤n-i。
5.根据权利要求4所述的基于中间层辅助特征融合匹配的知识蒸馏方法,其特征在于,步骤S6中构建特征融合模块,并利用步骤S5中n-1个功能组经过特征融合模块融合后的特征分别与学生主干网络功能相对应的n-1个模块的输出特征计算特征融合损失,具体如下:
首先由3个大小为1×1、步长为1的卷积层和一次concat操作构成特征融合模块,同时利用注意力机制将特征融合模块的不同通道设置不同的注意力卷积网络生成不同的融合权值;在此特征融合模块中采用特征迭代融合的方法,即每两个特征根据不同的融合权值进行一次融合,再将得到的融合特征与下一个特征进行融合,如此逐次进行迭代融合直至遍历功能组中的所有元素;
再将特征融合模块的融合函数定义为fm,将第k个功能组Gk经过特征融合模块的输出特征表示为
其中,1≤j≤n-1;
将学生主干网络划分的学生模块集合的除去第一个学生模块后的n-1个学生模块的输出特征集合定义为/>利用L2归一化损失函数计算功能组经过特征融合模块后的输出特征/>和特征集合FSO中的输出特征/>之间的特征融合损失Lfuse:
6.根据权利要求5所述的基于中间层特征辅助融合匹配的知识蒸馏方法,其特征在于,步骤S7中将传统蒸馏损失Lcla、辅助训练损失Laux以及特征融合损失Lfuse加权求和,得到总的损失函数Ltotality,并以此对学生网络的网络参数进行更新,最终获得训练好的学生网络,具体如下:
Ltotality=λ1Lcla+λ2Laux+λ3Lfuse
其中,λ1为传统知识蒸馏损失的权重超参数,λ2为辅助训练损失的权重超参数,λ3为特征融合损失函数的权重超参数。
7.根据权利要求6所述的基于中间层特征辅助融合匹配的知识蒸馏方法,其特征在于:λ1=0.5,λ2=0.1,λ3=0.1。
8.根据权利要求3所述的基于中间层特征辅助融合匹配的知识蒸馏方法,其特征在于:t=4。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311012546.4A CN117253123B (zh) | 2023-08-11 | 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311012546.4A CN117253123B (zh) | 2023-08-11 | 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117253123A true CN117253123A (zh) | 2023-12-19 |
CN117253123B CN117253123B (zh) | 2024-05-17 |
Family
ID=
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113326941A (zh) * | 2021-06-25 | 2021-08-31 | 江苏大学 | 基于多层多注意力迁移的知识蒸馏方法、装置及设备 |
CN113344206A (zh) * | 2021-06-25 | 2021-09-03 | 江苏大学 | 融合通道与关系特征学习的知识蒸馏方法、装置及设备 |
CN114373133A (zh) * | 2022-01-10 | 2022-04-19 | 中国人民解放军国防科技大学 | 一种基于稠密特征分组蒸馏的缺失模态地物分类方法 |
CN114611670A (zh) * | 2022-03-15 | 2022-06-10 | 重庆理工大学 | 一种基于师生协同的知识蒸馏方法 |
CN114782776A (zh) * | 2022-04-19 | 2022-07-22 | 中国矿业大学 | 基于MoCo模型的多模块知识蒸馏方法 |
US20230154202A1 (en) * | 2020-10-23 | 2023-05-18 | Xi'an Creation Keji Co., Ltd. | Method of road detection based on internet of vehicles |
US20230153943A1 (en) * | 2021-11-16 | 2023-05-18 | Adobe Inc. | Multi-scale distillation for low-resolution detection |
CN116258871A (zh) * | 2023-03-15 | 2023-06-13 | 西南科技大学 | 一种基于融合特征的目标网络模型获取方法及装置 |
CN116486285A (zh) * | 2023-03-15 | 2023-07-25 | 中国矿业大学 | 一种基于类别掩码蒸馏的航拍图像目标检测方法 |
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20230154202A1 (en) * | 2020-10-23 | 2023-05-18 | Xi'an Creation Keji Co., Ltd. | Method of road detection based on internet of vehicles |
CN113326941A (zh) * | 2021-06-25 | 2021-08-31 | 江苏大学 | 基于多层多注意力迁移的知识蒸馏方法、装置及设备 |
CN113344206A (zh) * | 2021-06-25 | 2021-09-03 | 江苏大学 | 融合通道与关系特征学习的知识蒸馏方法、装置及设备 |
US20230153943A1 (en) * | 2021-11-16 | 2023-05-18 | Adobe Inc. | Multi-scale distillation for low-resolution detection |
CN114373133A (zh) * | 2022-01-10 | 2022-04-19 | 中国人民解放军国防科技大学 | 一种基于稠密特征分组蒸馏的缺失模态地物分类方法 |
CN114611670A (zh) * | 2022-03-15 | 2022-06-10 | 重庆理工大学 | 一种基于师生协同的知识蒸馏方法 |
CN114782776A (zh) * | 2022-04-19 | 2022-07-22 | 中国矿业大学 | 基于MoCo模型的多模块知识蒸馏方法 |
CN116258871A (zh) * | 2023-03-15 | 2023-06-13 | 西南科技大学 | 一种基于融合特征的目标网络模型获取方法及装置 |
CN116486285A (zh) * | 2023-03-15 | 2023-07-25 | 中国矿业大学 | 一种基于类别掩码蒸馏的航拍图像目标检测方法 |
Non-Patent Citations (5)
Title |
---|
GAN HU 等: "Layer-fusion for online mutual knowledge distillation", 《MULTIMEDIA SYSTEMS》, vol. 29, 10 November 2022 (2022-11-10), pages 787 * |
PARK D Y 等: "Learning student-friendly teacher networks for knowledge distillation", 《ADVANCES IN NEURAL INFORMATION PROCESSING SYSTEMS》, vol. 34, 31 December 2021 (2021-12-31), pages 13292 - 13303 * |
刘志强: "基于残差通道注意力的图像超分辨网络轻量化的研究", 《中国优秀硕士学位论文全文数据库 信息科技》, vol. 2022, no. 4, 15 April 2022 (2022-04-15), pages 138 - 995 * |
张燕咏 等: "基于多模态融合的自动驾驶感知及计算", 《计算机研究与发展》, vol. 57, no. 9, 31 December 2020 (2020-12-31), pages 1781 - 1799 * |
葛仕明 等: "基于深度特征蒸馏的人脸识别", 《北京交通大学学报》, vol. 41, no. 6, 31 December 2017 (2017-12-31), pages 27 - 33 * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113240580B (zh) | 一种基于多维度知识蒸馏的轻量级图像超分辨率重建方法 | |
CN110349185B (zh) | 一种rgbt目标跟踪模型的训练方法及装置 | |
CN113190688B (zh) | 基于逻辑推理和图卷积的复杂网络链接预测方法及系统 | |
CN113516133B (zh) | 一种多模态图像分类方法及系统 | |
CN112348191A (zh) | 一种基于多模态表示学习的知识库补全方法 | |
CN113240683B (zh) | 基于注意力机制的轻量化语义分割模型构建方法 | |
CN110009700B (zh) | 基于rgb图和梯度图的卷积神经网络视觉深度估计方法 | |
CN112381179A (zh) | 一种基于双层注意力机制的异质图分类方法 | |
CN112686376A (zh) | 一种基于时序图神经网络的节点表示方法及增量学习方法 | |
CN113190654A (zh) | 一种基于实体联合嵌入和概率模型的知识图谱补全方法 | |
CN116402133B (zh) | 一种基于结构聚合图卷积网络的知识图谱补全方法及系统 | |
CN113628059A (zh) | 一种基于多层图注意力网络的关联用户识别方法及装置 | |
CN113094593A (zh) | 社交网络事件推荐方法、系统、设备及存储介质 | |
CN115545160A (zh) | 一种多学习行为协同的知识追踪方法及系统 | |
CN115170874A (zh) | 一种基于解耦蒸馏损失的自蒸馏实现方法 | |
CN112905894B (zh) | 一种基于增强图学习的协同过滤推荐方法 | |
CN116030537B (zh) | 基于多分支注意力图卷积的三维人体姿态估计方法 | |
CN117253123B (zh) | 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法 | |
CN116958324A (zh) | 图像生成模型的训练方法、装置、设备及存储介质 | |
CN117253123A (zh) | 一种基于中间层特征辅助模块融合匹配的知识蒸馏方法 | |
CN113962332A (zh) | 基于自优化融合反馈的显著目标识别方法 | |
CN115408505A (zh) | 一种基于双通道超图兴趣建模的对话推荐算法 | |
CN115908600A (zh) | 基于先验正则化的大批量图像重建方法 | |
CN114139674A (zh) | 行为克隆方法、电子设备、存储介质和程序产品 | |
CN113095328A (zh) | 一种基尼指数引导的基于自训练的语义分割方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant |