CN115512156A - 一种用于图像分类模型训练的自蒸馏训练方法 - Google Patents

一种用于图像分类模型训练的自蒸馏训练方法 Download PDF

Info

Publication number
CN115512156A
CN115512156A CN202211173732.1A CN202211173732A CN115512156A CN 115512156 A CN115512156 A CN 115512156A CN 202211173732 A CN202211173732 A CN 202211173732A CN 115512156 A CN115512156 A CN 115512156A
Authority
CN
China
Prior art keywords
module
shallow
layer
characteristic diagram
attention
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211173732.1A
Other languages
English (en)
Inventor
朱明甫
倪水平
马新良
张威
马传琦
洪振东
朱智丹
常月光
李炳伸
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Henan Chuidian Technology Co ltd
Original Assignee
Henan Chuidian Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Henan Chuidian Technology Co ltd filed Critical Henan Chuidian Technology Co ltd
Priority to CN202211173732.1A priority Critical patent/CN115512156A/zh
Publication of CN115512156A publication Critical patent/CN115512156A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Multimedia (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Molecular Biology (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明提出的是一种用于图像分类模型训练的自蒸馏训练方法,该方法包括:1、针对图像分类模型完成自蒸馏框架的搭建;2、将深层分类器划分出四个模块:在第一个模块的基础上依次增加第一注意力模块、第一浅层模块、第一全连接层作为第一浅层分类器;在第一个模块与第二个模块的基础上依次增加第二注意力模块、第二浅层模块、第二全连接层作为第二浅层分类器;在第一个模块、第二个模块、第三个模块的基础上依次增加第三注意力模块、第三浅层模块、第三全连接层作为第三浅层分类器;在第一浅层模块的基础上增加第四全连接层,在第二浅层模块的基础上增加第五全连接层;3、使用数据集来进行自蒸馏训练,得到一个深层分类器和三个浅层分类器。

Description

一种用于图像分类模型训练的自蒸馏训练方法
技术领域
本发明涉及一种用于图像分类模型训练的自蒸馏训练方法,属于图像分类模型训练技术领域。
背景技术
随着人工智能发展,深度神经网络算法得到广泛的应用,在各种领域取得显著的成果;以图像分类领域为例,经典的图像分类模型有VGG网络、ResNet网络、ResNext网络等;针对特定的图像分类任务(如垃圾分类、交通标志分类、医学图像分类等),通常选用特定的数据集(垃圾图像数据集、交通标志图像数据集、医学图像数据集等)对图像分类模型进行训练;在现有的训练方式下,分类精度的高低,由图像分类模型自身结构所决定。
为追求更好的图像分类效果,图像分类模型的深度和宽度不断的增加,导致了存储和计算海量增长,图像分类模型难以在资源受限的边缘设备上部署;因此,就需要一系列对图像分类模型进行压缩的方案;其中,知识蒸馏的训练方式能够在不改变图像分类模型结构的条件下提升图像分类模型的分类精度;知识蒸馏的思想是利用复杂度较高的教师网络指导复杂度较低的学生网络进行训练,使学生网络在训练过程中,能够吸收教师网络传递的“知识”,提升学生网络性能,达到图像分类模型压缩的目的。
作为知识蒸馏的改进方法自蒸馏,可加大图像分类模型精度的提升幅度,现有自蒸馏框架可为实际的应用场景提供一个深层分类器以及多个轻量化的浅层分类器,但多个浅层分类器中引入的注意力模块主要由深度可分离卷积层构成,轻量化程度仍有提升的空间。
目前,也存在自蒸馏训练的改进方案,在三个浅层分类器中引入注意力模块,并将三个浅层分类器变得更加轻量,完善了蒸馏框架,提升了蒸馏的效率,但三个浅层分类器仍具有轻量化的提升空间;当深层分类器参数量较大时,整个自蒸馏训练过程计算量依然较大,导致蒸馏训练效率降低。
基于此,本发明提出了一种用于图像分类模型训练的自蒸馏训练方法,设计出一种新型自蒸馏框架,引入轻量的注意力模块到三个浅层分类器中,使其进一步轻量化,在不影响蒸馏效果的前提下,降低部署难度,能有效缓解边缘设备部署性能瓶颈;同时减少三个浅层分类器的计算量,缩短训练时长。
发明内容
本发明提出的是一种用于图像分类模型训练的自蒸馏训练方法,其目的旨在降低自蒸馏框架内浅层分类器的参数量与计算量。
本发明的技术解决方案:一种用于图像分类模型训练的自蒸馏训练方法,该方法包括:
1、将图像分类模型自身作为深层分类器,使用深层分类器作为教师网络;
2、将深层分类器按照网络深度划分出四个模块:第一个模块(Block1)、第二个模块(Block2)、第三个模块(Block3)、第四个模块(Block4);在第一个模块的基础上依次增加第一注意力模块、第一浅层模块、第一全连接层作为第一浅层分类器;在第一个模块与第二个模块的基础上依次增加第二注意力模块、第二浅层模块、第二全连接层作为第二浅层分类器;在第一个模块、第二个模块、第三个模块的基础上依次增加第三注意力模块、第三浅层模块、第三全连接层作为第三浅层分类器;将第一浅层分类器、第二浅层分类器、第三浅层分类器全部作为学生网络;同时,在第一浅层模块的基础上增加第四全连接层,在第二浅层模块的基础上增加第五全连接层;
3、使用数据集来进行自蒸馏训练,得到一个深层分类器和三个浅层分类器。
进一步地,所述深层分类器的第四个模块包含卷积模块与输出模块两个模块;在深层分类器第四个模块内卷积模块的基础上增加自适应平均池化层,用于辅助自蒸馏训练,便于深层分类器第四个模块内卷积模块的“知识”通过第四全连接层传授给第一浅层模块,通过第五全连接层传授给第二浅层模块。
进一步地,所述第一个模块(Block1)的输出特征图output1,作为第二个模块(Block2)的输入特征图,同时也作为第一注意力模块的输入特征图;第二个模块(Block2)的输出特征图output2,作为第三个模块(Block3)的输入特征图,同时也作为第二注意力模块的输入特征图;第三个模块(Block3)的输出特征图output3,作为第四个模块(Block4)的输入特征图,同时也作为第三注意力模块的输入特征图;在第四个模块(Block4)内部,第四个模块卷积模块的输出特征图作为自适应平均池化层的输入特征图,同样也作为第四个模块中输出模块的输入特征图。
进一步地,所述第一注意力模块对输入特征图output1的处理流程具体包括如下步骤:
1)将尺寸为H'×W'×C'的输入特征图output1按照通道数C'分为n组,得到n个尺寸为H'×W'×C'/n的中间特征图bi(i=1,2,…,n);
2)经过全局平均池化层对中间特征图bi(i=1,2,…,n)进行全局平均池化,得到n个尺寸为1×1×C'/n的第一特征图gi(i=1,2,…,n),将第一特征图与中间特征图进行对位点乘,即第一特征图gi与中间特征图bi对位点乘(i=1,2,…,n)得到n个初始注意力掩码ci(i=1,2,…,n),对n个初始注意力掩码ci(i=1,2,…,n)中的每一个初始注意力掩码分别求均值与标准差,将每一个初始注意力掩码进行标准化处理,得到n个H'×W'×1的第二特征图di(i=1,2,…,n);
3)将n个第二特征图di(i=1,2,…,n)中的每一个第二特征图使用Sigmoid函数激活得到最终的n个注意力掩码ei(i=1,2,…,n),n个注意力掩码ei(i=1,2,…,n)分别与相应组别的n个中间特征图bi(i=1,2,…,n)对位点乘,最终得到n个尺寸为H'×W'×C'/n的小组输出特征图fi(i=1,2,…,n);之后将这n个尺寸为H'×W'×C'/n的小组输出特征图fi(i=1,2,…,n)拼接为最终的尺寸为H'×W'×C'的输出特征图J1,输出特征图J1与输入特征图output1尺寸相同;第二注意力模块、第三注意力模块对各自输入特征图output2、输入特征图output3的处理流程与第一注意力模块对输入特征图output1处理流程完全相同,第二注意力模块的输出特征图J2与第二注意力模块的输入特征图output2尺寸相同;第三注意力模块的输出特征图J3与第三注意力模块的输入特征图output3尺寸相同。
进一步地,所述第一浅层模块、第二浅层模块、第三浅层模块分别为模块深度不同的浅层模块;所述第一浅层模块包含了三组模块结构和一个自适应平均池化层;第一注意力模块的输出特征图J1作为第一浅层模块内三组模块结构中的第一组模块结构的输入特征图经过第一组模块的处理之后,得到输出特征图R1_1作为第二组模块结构的输入特征图,特征图R1_1经过第二组模块的处理之后,得到第二组模块结构的输出特征图R1_2作为第三组模块结构的输入特征图,特征图R1_2经过第三组模块的处理之后得到第三组模块结构的输出特征图R1_3,第三组模块结构的输出特征图R1_3作为第一浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R1_4。
进一步地,所述第一浅层模块中的三组模块结构中的每一组模块结构完全相同,每一组模块结构均包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第一浅层注意力模块。
进一步地,所述第二浅层模块包含了两组模块结构和一个自适应平均池化层;第二注意力模块的输出特征图J2作为第二浅层模块内两组模块结构中的第一组模块结构的输入特征图经过第一组模块的处理之后,得到输出特征图R2_1作为第二组模块结构的输入特征图,输入特征图R2_1经过第二组模块结构的处理之后,得到第二组模块结构的输出特征图R2_2;输出特征图R2_2作为第二浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R2_3。
进一步地,所述第二浅层模块中的两组模块结构中的每一组模块结构完全相同,每一组模块结构均包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第二浅层注意力模块。
进一步地,所述第三浅层模块包含了一组模块结构和一个自适应平均池化层;第三注意力模块的输出特征图J3作为第三浅层模块内模块结构的输入特征图经过模块结构的处理之后,得到输出特征图R3_1,输出特征图R3_1作为第三浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R3_2;所述第三浅层模块中的一组模块结构包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第三浅层注意力模块。
进一步地,所述使用数据集来进行自蒸馏训练,具体包括如下步骤:
首先,对CIFAR10数据集的训练集与CIFAR100数据集的训练集进行如下处理:
(1)对图片进行随机裁剪,裁剪后尺寸(size)为32,填充边界的值(padding)设置为4,填充值(fill)设置为128;
(2)对图片进行随机水平翻转;
(3)将图片格式转换为tensor格式;将图片的每一个数值归一化到[0,1];
(4)将图片的每一个数值进行标准化处理,标准化处理的均值为(0.4914,0.4822,0.4465),方差为(0.2023,0.1994,0.2010);
其次,对CIFAR10数据集的测试集与CIFAR100数据集的测试集进行如下处理:
(1)将图片格式转换为tensor格式;将图片的每一个数值归一化到[0,1];
(2)将图片的每一个数值进行标准化处理,标准化处理的均值为(0.4914,0.4822,0.4465),标准差为(0.2023,0.1994,0.2010)。
本发明的有益效果:
1)本发明在自蒸馏框架内部使用了新型参数量和计算量可忽略的第一注意力模块、第二注意力模块、第三注意力模块、浅层注意力模块,相较于原注意力模块,进一步减少了三个浅层分类器的参数总量,降低三个浅层分类器的计算量,缩短整体框架自蒸馏模型蒸馏训练的时间,提升训练效率,降低了三个浅层分类器部署的难度与成本;
2)本发明具有普适性,不仅可应用于垃圾分类、交通标志分类、医学图像分类等领域,还可应用于基于知识蒸馏的模型攻防、基于知识蒸馏的目标检测、基于知识蒸馏的图像语义分割等领域。
附图说明
附图1为图像分类网络自蒸馏框架结构示意图。
附图2为ResNet18网络深度划分示意图。
附图3为ResNet50网络深度划分示意图。
附图4为VGG11(BN)网络深度划分示意图。
附图5为VGG16(BN)网络深度划分示意图。
附图6为Resnext50(32x4d)网络深度划分示意图。
附图7为注意力模块数据处理流程示意图。
附图8为浅层模块结构示意图。
附图9为第三浅层模块卷积流程示意图。
附图10为图像分类应用框架示意图。
附图11为图像分类模型训练部署流程示意图。
附图12为VGG11(BN)自蒸馏框架示意图。
具体实施方式
一种用于图像分类模型训练的自蒸馏训练方法,该方法包括:
1、针对图像分类模型完成自蒸馏框架的搭建;所述自蒸馏框架将图像分类模型自身作为深层分类器,使用深层分类器作为教师网络;
2、将深层分类器按照网络深度划分出四个模块:第一个模块(Block1)、第二个模块(Block2)、第三个模块(Block3)、第四个模块(Block4);在第一个模块的基础上依次增加第一注意力模块、第一浅层模块、第一全连接层作为第一浅层分类器;在第一个模块与第二个模块的基础上依次增加第二注意力模块、第二浅层模块、第二全连接层作为第二浅层分类器;在第一个模块、第二个模块、第三个模块的基础上依次增加第三注意力模块、第三浅层模块、第三全连接层作为第三浅层分类器;将第一浅层分类器、第二浅层分类器、第三浅层分类器全部作为学生网络;同时,还要在第一浅层模块的基础上增加第四全连接层,在第二浅层模块的基础上增加第五全连接层;
3、使用数据集来进行自蒸馏训练,提升深层分类器的性能表现,训练完成后,得到可完成分类任务的一个深层分类器和三个浅层分类器。
所述深层分类器第四个模块共包含卷积模块与输出模块两个模块;在深层分类器第四个模块内卷积模块的基础上增加自适应平均池化层,用于辅助自蒸馏训练,便于深层分类器第四个模块内卷积模块的“知识”通过第四全连接层传授给第一浅层模块,通过第五全连接层传授给第二浅层模块。
所述第一个模块(Block1)的输出特征图output1,作为第二个模块(Block2)的输入特征图,同时也作为第一注意力模块的输入特征图;第二个模块(Block2)的输出特征图output2,作为第三个模块(Block3)的输入特征图,同时也作为第二注意力模块的输入特征图;第三个模块(Block3)的输出特征图output3,作为第四个模块(Block4)的输入特征图,同时也作为第三注意力模块的输入特征图;在第四个模块(Block4)内部,第四个模块卷积模块的输出特征图作为自适应平均池化层的输入特征图,同样也作为第四个模块中输出模块的输入特征图。
如图1所示,附图1中AdaptiveAvgPool表示自适应平均池化层。
一种用于图像分类模型训练的自蒸馏训练方法,该方法还包括:
4、根据应用场景的实际需求挑选出相应的图像分类模型,转换为能够在边缘设备上部署的自蒸馏训练后的图像分类模型,进行部署。
所述深层分类器优选为ResNet18、ResNet50、VGG11(BN)、VGG16(BN)、Resnext50(32x4d)五种分类网络中的任意一种;每种网络深度的划分结构分别如附图2-附图6所示。
附图2中,Conv3×3表示卷积核大小为3的卷积层,AvgPool表示平均池化层,FC表示全连接层;以ResNet18图像分类模型自身为深层分类器,将深层分类器自身划分为四个模块:R18_Block1、R18_Block2、R18_Block3、R18_Block4;R18_Block1由第一个Conv3×3卷积层与Res18Block1构成;R18_Block2由Res18Block2构成;R18_Block3由Res18Block3构成;R18_Block4的卷积模块由Res18Block4构成,R18_Block4的输出模块由一个平均池化层与一个第六全连接层构成;同时,以R18_Block4内部的Res18Block4为基础添加自适应平均池化层。
附图3中,Conv1×1表示卷积核大小为1的卷积层;以ResNet50图像分类模型自身为深层分类器,将深层分类器自身划分为四个模块:R50_Block1、R50_Block2、R50_Block3、R50_Block4;R50_Block1由第一个Conv3×3卷积层与Res50Block1构成;R50_Block2由Res50Block2构成;R50_Block3由Res50Block3构成;R50_Block4的卷积模块由Res50Block4构成,R50_Block4的输出模块由一个平均池化层与一个第六全连接层构成;同时,还要以R50_Block4内部的Res50Block4为基础添加自适应平均池化层。
附图4中,maxpool表示最大池化层;以VGG11图像分类模型自身为深层分类器,将深层分类器自身划分为四个模块:V11_Block1、V11_Block2、V11_Block3、V11_Block4;V11_Block1由V11_layer1构成;V11_Block2由V11_layer2构成;V11_Block3由V11_layer3构成;V11_Block4的卷积模块由V11_layer4构成,V11_Block4的输出模块由一个最大池化层和三个全连接层构成,三个全连接层从左往右依次是第六全连接层、第七全连接层、第八全连接层;同时,以V11_Block4内部的V11_layer4为基础添加自适应平均池化层。
附图5中,以VGG16图像分类模型自身为深层分类器,将深层分类器自身划分为四个模块:V16_Block1、V16_Block2、V16_Block3、V16_Block4;V16_Block1由V16_layer1构成;V16_Block2由V16_layer2构成;V16_Block3由V16_layer3构成;V16_Block4的卷积结构由V16_layer4构成、V16_Block4的输出结构由一个最大池化层与三个全连接层构成,三个全连接层从左往右依次是第六全连接层、第七全连接层、第八全连接层;同时,以V16_Block4内部的V16_layer4为基础添加自适应平均池化层。
附图6中,C=32表示卷积层采用分组数为32的分组卷积方式;以ResNext50图像分类模型自身为深层分类器,将深层分类器自身划分为四个模块:RN50_Block1、RN50_Block2、RN50_Block3、RN50_Block4;RN50_Block1由第一个Conv3×3卷积层与ResnBlock1构成;RN50_Block2由ResnBlock2构成;RN50_Block3由ResnBlock3构成;RN50_Block4的卷积结构由ResnBlock4构成,输出结构由一个平均池化层和一个第六全连接层构成;同时,以RN50_Block4内部的ResnBlock4为基础添加自适应平均池化层。
所述第一注意力模块、第二注意力模块、第三注意力模块为同样的注意力模块;所述第一注意力模块、第二注意力模块、第三注意力模块对各自输入特征图output1、output2、output3的处理流程完全相同;故只给出第一注意力模块对输入特征图output1的处理流程。
所述第一注意力模块对输入特征图output1的处理流程如图7所示,具体包括如下步骤:
1)如附图7中,将尺寸为H'×W'×C'的输入特征图output1按照通道数C'分为n组,得到n个尺寸为H'×W'×C'/n的中间特征图bi(i=1,2,…,n);
2)经过全局平均池化层(Global Average Pooling)对中间特征图bi(i=1,2,…,n)进行全局平均池化,得到n个尺寸为1×1×C'/n的第一特征图gi(i=1,2,…,n),将第一特征图与中间特征图进行对位点乘(Position-wise dot Product),即gi与bi对位点乘(i=1,2,…,n)得到n个初始注意力掩码ci(i=1,2,…,n),对n个初始注意力掩码ci(i=1,2,…,n)中的每一个初始注意力掩码分别求均值与标准差,将每一个初始注意力掩码进行标准化处理(Normalization),得到n个H'×W'×1的第二特征图di(i=1,2,…,n);
3)将n个第二特征图di(i=1,2,…,n)中的每一个第二特征图使用Sigmoid函数激活得到最终的n个注意力掩码ei(i=1,2,…,n),n个注意力掩码ei(i=1,2,…,n)分别与相应组别的n个中间特征图bi(i=1,2,…,n)对位点乘,最终得到n个尺寸为H'×W'×C'/n的小组输出特征图fi(i=1,2,…,n);之后将这n个尺寸为H'×W'×C'/n的小组输出特征图fi(i=1,2,…,n)拼接为最终的尺寸为H'×W'×C'的输出特征图J1,输出特征图J1与输入特征图output1尺寸相同;由于第二注意力模块、第三注意力模块对各自输入特征图output2、output3的处理流程与第一注意力模块对输入特征图output1处理流程完全相同;故第二注意力模块的输出特征图J2与第二注意力模块的输入特征图output2尺寸相同;故第三注意力模块的输出特征图J3与第三注意力模块的输入特征图output3尺寸相同。
在第一注意力模块、第二注意力模块、第三注意力模块中对位点乘(Position-wise dot Product)、标准化处理(Normalization)、Sigmoid函数激活处理均是数学处理方式;故第一注意力模块、第二注意力模块、第三注意力模块主要由全局平均池化层构成,均使用了全局和局部特征的相关性来生成注意力掩码,所以第一注意力模块、第二注意力模块、第三注意力模块的参数量以及运算量可忽略。
所述第一浅层模块、第二浅层模块、第三浅层模块分别为模块深度不同的浅层模块;所述第一浅层模块包含了三组模块结构和一个自适应平均池化层;第一注意力模块的输出特征图J1作为第一浅层模块内三组模块结构中的第一组模块结构的输入特征图经过第一组模块的处理之后,得到输出特征图R1_1作为第二组模块结构的输入特征图,特征图R1_1经过第二组模块的处理之后,得到第二组模块结构的输出特征图R1_2作为第三组模块结构的输入特征图,特征图R1_2经过第三组模块的处理之后得到第三组模块结构的输出特征图R1_3,第三组模块结构的输出特征图R1_3作为第一浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R1_4。
第一浅层模块中的三组模块结构中的每一组模块结构完全相同,每一组模块结构均包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第一浅层注意力模块;故第一浅层模块中:第一组模块结构的输出特征图R1_1就是第一组模块结构的第一浅层注意力模块的输出特征图;第二组模块结构的输出特征图R1_2就是第二组模块结构的第一浅层注意力模块的输出特征图;第三组模块结构的输出特征图R1_3就是第三组模块结构的第一浅层注意力模块的输出特征图。
所述第二浅层模块包含了两组模块结构和一个自适应平均池化层;第二注意力模块的输出特征图J2作为第二浅层模块内两组模块结构中的第一组模块结构的输入特征图经过第一组模块的处理之后,得到输出特征图R2_1作为第二组模块结构的输入特征图,输入特征图R2_1经过第二组模块结构的处理之后,得到第二组模块结构的输出特征图R2_2;输出特征图R2_2作为第二浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R2_3。
第二浅层模块中的两组模块结构中的每一组模块结构完全相同,每一组模块结构均包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第二浅层注意力模块;故第二浅层模块中:第一组模块结构的输出特征图R2_1就是第一组模块结构的第二浅层注意力模块的输出特征图;第二组模块结构的输出特征图R2_2就是第二组模块结构的第二浅层注意力模块的输出特征图。
所述第三浅层模块包含了一组模块结构和一个自适应平均池化层;第三注意力模块的输出特征图J3作为第三浅层模块内模块结构的输入特征图经过模块结构的处理之后,得到输出特征图R3_1,输出特征图R3_1作为第三浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R3_2。
第三浅层模块中的一组模块结构包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第三浅层注意力模块;故第三浅层模块中:模块结构的输出特征图R3_1就是模块结构的第三浅层注意力模块的输出特征图。
其中,所有模块结构组中的第一浅层注意力模块、第二浅层注意力模块、第三浅层注意力模块与第一注意力模块、第二注意力模块、第三注意力模块结构完全相同;故方案中所有的注意力模块结构均相同,所有模块结构组中的第一浅层注意力模块、第二浅层注意力模块、第三浅层注意力模块各自输出特征图的输出流程均可参考上方第一注意力模块输出特征图J1的输出流程。
由于第一浅层模块包含了三组模块结构和一个自适应平均池化层;第二浅层模块包含了两组模块结构和一个自适应平均池化层;第三浅层模块包含了一组模块结构和一个自适应平均池化层;同时,三个浅层模块中所有模块结构中的每一组模块结构,结构均完全相同;三个浅层模块中的每一组模块结构的卷积方式均为深度可分离卷积方式,一次深度可分离卷积方式的卷积过程包括一次深度卷积与一次逐点卷积;因此,三个浅层模块中的每一组模块结构均包含两次深度可分离卷积;三个浅层模块的结构如图8所示;其中,第一浅层模块的模块结构分组数为3,故L=3;第二浅层模块的模块结构分组数为2,故L=2;第三浅层模块的模块结构分组数为1,故L=1。
由于三个浅层模块中所有模块结构中的每一组模块结构,结构均完全相同,故这里给出具有单组模块结构的第三浅层模块对自身输入特征图J3的处理流程;该流程中第三浅层模块内单组模块结构对自身输入特征图J3的处理流程与其余浅层模块中的每一组模块结构对各自相对应的输入特征图处理流程完全相同。
第三浅层模块中的一组模块结构对输入特征图J3的处理流程包括第一次深度可分离卷积(第一个深度卷积与第一个逐点卷积)、第二次深度可分离卷积(第二个深度卷积与第二个逐点卷积)以及第三浅层注意力模块的相关处理,具体过程如图9所示包括如下:
1)对尺寸为Hi In×Wi In×Ci In输入特征图J3进行第一个深度卷积;其中,Hi In表示输入特征图J3的高度,Wi In表示输入特征图J3的宽度,Ci In表示输入特征图J3的通道数,深度卷积的卷积核个数与输入特征图J3的通道数相同,故深度卷积的卷积核个数也为Ci In,每一个深度卷积的卷积核的尺寸为h1×w1×1,h1表示深度卷积卷积核的高度,w1表示深度卷积卷积核的宽度;Ci In个深度卷积卷积核深度卷积后产生尺寸为Hi M0×Wi M0×Ci In的中间特征图B1;其中,Hi M0表示中间特征图B1的高度,Wi M0表示中间特征图B1的宽度,Ci In表示中间特征图B1的通道数;
2)对中间特征图B1进行第一个逐点卷积,共有Ci In组逐点卷积的卷积核,每组逐点卷积的卷积核的尺寸为1×1×Ci In;中间特征图B1经Ci In组逐点卷积的卷积核逐点卷积后产生尺寸为Hi M1×Wi M1×Ci In输出特征图B2,其中,Hi M1表示中间特征图B2的高度,Wi M1表示中间特征图B2的宽度,中间特征图B2的输出通道数也为Ci In
3)对尺寸为Hi M1×Wi M1×Ci In的输出特征图B2进行第二个深度卷积;其中,深度卷积的卷积核个数与特征图B2的通道数相同,故深度卷积的卷积核个数也为Ci In,与第一个深度卷积相同每一个深度卷积的卷积核的尺寸也为h1×w1×1,h1表示深度卷积卷积核的高度,w1表示深度卷积卷积核的宽度;Ci In个深度卷积卷积核深度卷积后产生尺寸为Hi M2×Wi M2×Ci In的中间特征图B3;其中,Hi M2表示中间特征图B3的高度,Wi M2表示中间特征图B3的宽度,中间特征图B3的输出通道数也为Ci In
4)对尺寸为Hi M2×Wi M2×Ci In的中间特征图B3进行第二个逐点卷积,共有Ci out组逐点卷积的卷积核,每组逐点卷积的卷积核的尺寸为1×1×Ci In;中间特征图B3经Ci out组逐点卷积的卷积核逐点卷积后产生尺寸为Hi M3×Wi M3×Ci out的输出特征图B4,其中,Ci out为输出特征图B4的输出通道数,Hi M3表示输出特征图B4的高度,Wi M3表示输出特征图B4的宽度;
5)将输出特征图B4作为第三浅层注意力模块的输入特征图,经过第三浅层注意力模块处理之后,得到输出特征图R3_1。
第三浅层模块中的一组模块结构对输入特征图J3处理完成之后,输出特征图R3_1又经过自适应平均池化层得到尺寸为1×1×Ci out的输出特征图R3_2。
图9中,设第三浅层模块输入特征图J3尺寸为Hi In×Wi In×Ci In,第一个深度卷积层和第二个深度卷积层的卷积核尺寸均为h1×w1×1,经过第一个深度卷积层深度卷积,生成了尺寸为Hi M0×Wi M0×Ci In的中间特征图B1,则第一次深度卷积层深度卷积的参数量、计算量分别为h1×w1×Ci In、h1×w1×Ci In×Hi M0×Wi M0,再经过第一次逐点卷积层逐点卷积,生成了尺寸为Hi M1×Wi M1×Ci In中间特征图B2,则第一次逐点卷积层逐点卷积的参数量、计算量分别为Ci In×Ci In、Ci In×Ci In×Hi M1×Wi M1;经过第二次深度卷积层深度卷积,生成了尺寸为Hi M2×Wi M2×Ci In的中间特征图B3,则第二次深度卷积深度卷积的参数量、计算量分别为h1×w1×Ci In、h1×w1×Ci In×Hi M2×Wi M2,再经过第二次逐点卷积层逐点卷积,生成了尺寸为Hi M3×Wi M3×Ci out的中间特征图B4,则第二次逐点卷积层逐点卷积的参数量、计算量分别为Ci In×Ci out、Ci In×Ci out×Hi M3×Wi M3;第三浅层注意力模块的参数量及相关操作的计算量占比较小,可忽略;故具有单组模块结构的第三浅层模块总参数量为(h1×w1×Ci In+Ci In×Ci In)+(h1×w1×Ci In+Ci In×Ci out),总计算量为h1×w1×Ci In×Hi M0×Wi M0+Ci In×Ci In×Hi M1×Wi M1+h1×w1×Ci In×Hi M2×Wi M2+Ci In×Ci out×Hi M3×Wi M3
所述使用数据集来进行自蒸馏训练,同时使用CIFAR10数据集和CIFAR100数据集来进行自蒸馏训练;CIFAR10数据集由50000张图片的训练集和10000张图片的测试集构成;CIFAR100数据集由50000张图片的训练集和10000张图片的测试集构成。
所述使用数据集来进行自蒸馏训练,具体包括如下步骤:
首先,对CIFAR10数据集与CIFAR100数据集进行预处理,其中,对CIFAR10数据集的训练集与CIFAR100数据集的训练集进行如下处理:
(1)对图片进行随机裁剪,裁剪后尺寸(size)为32,填充边界的值(padding)设置为4,填充值(fill)设置为128;
(2)对图片进行随机水平翻转;
(3)将图片格式转换为tensor格式;将图片的每一个数值归一化到[0,1];
(4)将图片的每一个数值进行标准化处理,标准化处理的均值为(0.4914,0.4822,0.4465),方差为(0.2023,0.1994,0.2010)。
对CIFAR10数据集的测试集与CIFAR100数据集的测试集进行如下处理:
(1)将图片格式转换为tensor格式;将图片的每一个数值归一化到[0,1];
(2)将图片的每一个数值进行标准化处理,标准化处理的均值为(0.4914,0.4822,0.4465),标准差为(0.2023,0.1994,0.2010)。
在CIFAR10数据集上,所有第一注意力模块、第二注意力模块、第三注意力模块、第一浅层模块的第一浅层注意力模块,第二浅层模块的第二浅层注意力模块、第三浅层模块的第三浅层注意力模块将它们各自对应的输入特征图按照通道数分为16组;在CIFAR100数据集上,所有第一注意力模块、第二注意力模块、第三注意力模块、第一浅层模块的第一浅层注意力模块,第二浅层模块的第二浅层注意力模块、第三浅层模块的第三浅层注意力模块将它们各自对应的输入特征图按照通道数分为32组;优选地,在采用CIFAR10数据集与CIFAR100数据集进行自蒸馏训练时,均优选采用SGD优化器进行优化,所有实验均在GPU设备上,pytorch1.9.1环境下进行;具体自蒸馏细节以及参数设置参考下方所给实施例。
本发明中将深层分类器作为教师网络,三个浅层分类器都作为学生网络,对深层分类器和三个浅层分类器进行自蒸馏训练;自蒸馏训练全过程,“知识”仅在蒸馏框架内部流动,进而提升了深层分类器的性能;本发明通过在浅层分类器内浅层模块之前引入参数量和计算量可忽略的由池化层构成的第一注意力模块、第二注意力模块、第三注意力模块减少了浅层分类器的计算量,缩短了蒸馏训练的时长,提升了训练效率,实现了浅层分类器的轻量化,降低了模型部署难度;本发明通过在构建浅层分类器的浅层模块内模块结构中同样添加由引入参数量和计算量可忽略的由池化层构成的浅层注意力模块,确保深层分类器自蒸馏的效果。
本发明所述一种图像分类模型自蒸馏训练的实现方法,可广泛应用于图像分类中,如垃圾分类、交通标志分类、医学图像分类等;图像分类应用框架如图10所示,将经本发明自蒸馏训练后的深层分类器与三个浅层分类器结合具体的应用场景,挑选出满足需求的一个分类器作为最终的图像分类模型部署在边缘设备节点上,部署图像分类模型后的边缘设备接收到图像数据后,会对数据进行预处理,并将处理后的数据输入给该图像分类模型,该图像分类模型对预处理后的数据进行分类,最后通过数据中心和云端为PC端、移动端、API等提供图像分类应用服务,服务质量的好坏取决于图像分类模型的性能;图像分类模型的训练与部署如图11所示。
实施例
本实施例使用VGG11(BN)网络搭建自蒸馏框架,并使用CIFAR10数据集进行自蒸馏训练,如图12所示;首先,以VGG11(BN)自身为深层分类器,在V11_layer1模块的基础上依次增加第一注意力模块、第一浅层模块、具有10个节点的第一全连接层(FC1_10)作为第一浅层分类器;在V11_layer1模块与V11_layer2模块的基础上依次增加第二注意力模块、第二浅层模块、具有10个节点的第二全连接层(FC2_10)作为第二浅层分类器;在V11_layer1模块、V11_layer2模块与V11_layer3模块的基础上依次增加第三注意力模块、第三浅层模块、具有10个节点的第三全连接层(FC3_10)作为第三浅层分类器;同时,还要在第一浅层模块的基础上增加具有512个节点的第四全连接层(FC4_512),在第二浅层模块的基础上增加第五全连接层(FC5_512);最后,使用CIFAR10数据集进行自蒸馏训练;假设有N个样本:
Figure BDA0003863374520000151
样本一共有M个类别,每个类别相应的标签记作:
Figure BDA0003863374520000152
自蒸馏框架中一共有4个分类器,第一浅层分类器1记作θ1、第二浅层分类器2记作θ2、第三浅层分类器3记作θ3、深层分类器记作θ4
图12中,具有4096个节点的第六全连接层、具有4096个节点的第七全连接层与具有10个节点的第八全连接层分别使用(FC6-4096)、(FC7-4096)、(FC8-10)表示。在深层分类器具有10个节点的第八全连接层(FC8-10)上添加Softmax函数记作Softmax;在第一浅层分类器的第一全连接层(FC1_10)上添加Softmax函数记作Softmax1;在第二浅层分类器的第二全连接层(FC2_10)上添加Softmax函数记作Softmax2;在第三浅层分类器的第三全连接层(FC3_10)上添加Softmax函数记作Softmax3;同时,在Softmax函数中引入温度系数记作Softmax_T,可以通过修改温度系数的值来软化输出的标签,Softmax_T函数如公式(1)所示;在深层分类器的第八全连接层(FC8-10)上添加Softmax_T函数记作Soft_T;在第一浅层分类器的第一全连接层(FC1_10)上添加Softmax_T函数记作Soft_T1;在第二浅层分类器的第二全连接层(FC2_10)上添加Softmax_T函数记作Soft_T2;在第三浅层分类器的第三全连接层(FC3_10)上添加Softmax_T函数记作Soft_T3;
Figure BDA0003863374520000161
其中
Figure BDA0003863374520000162
为经过分类器θc(c=1,2,3,4)与Softmax函数相连接的全连接层之后第i类的输出结果,qic(c=1,2,3,4;i=1,2,3,4,5,6,7,8,9,10)为分类器θc(c=1,2,3,4)第i类的输出概率,当T设置为1,式(1)为普通的Softmax函数,T越大,标签就会越软。
在自蒸馏框架中,深层分类器仅仅受到真实标签(label)的监督。而三个浅层分类器中的每一个,在训练的时候要受到三方面来源的监督,三方面分别为:真实标签(label)、深层分类器的输出(FC8-10经过Soft_T函数的输出)以及深层分类器隐藏层(V11_layer4)的输出;基于此,自蒸馏训练时总的损失函数Loss有三部分(Loss1、Loss2、Loss3)构成,再通过添加两个超参数:α和λ,用来平衡三部分监督来源。
Loss1:从真实标签到深层分类器,再到所有浅层分类器的交叉熵损失;通过计算数据集中的真实标签label值与深层分类器和每个浅层分类器的Softmax输出得来的,如式(2)所示;通过这种方式,隐藏在数据集中的知识可直接从标签引入到三个浅层分类器;式(2)中Cr表示交叉熵损失函数,qi(i=4)表示深层分类器θc(c=4)中Softmax层的输出,qi(i=1,2,3)分别表示浅层分类器θc(c=1,2,3)中SoftmaxI(I=1,2,3)层的输出,y表示真实标签label的值;
(1-α)·Cr(qi,y) (2)。
Loss2:从深层分类器到各个浅层分类器的KL散度损失,如式(3)所示;将深层分类器的Soft_T输出结果引入到浅层分类器θc(c=1,2,3)的Soft_TI(I=1,2,3)层,通过这种方式,可以将深层分类器总结的知识转移到各个浅层分类器中;式(3)中KL表示KL散度,qj(j=1,2,3)表示浅层分类器θc(c=1,2,3)中Soft_TI(I=1,2,3)的输出,qC表示深层分类器Soft_T的输出;
α·KL(qj·qC) (3)。
Loss3:通过计算深层分类器与第一浅层分类器、第二浅层分类器的特征图之间的L2损失得到,如式(4)所示;通过这种方式,可以将深层分类器隐藏层的输出引入到第一浅层分类器1与第二浅层分类器2相对应的浅层模块中,式(4)中,Fi(i=1,2)表示浅层分类器θc(c=1,2)相对应的浅层模块对应全连接层(FCy_512,y=4,5)的输出,FC表示深层分类器隐藏层(V11_layer4)经过自适应平均池化层的输出;
Figure BDA0003863374520000171
综上,总的损失函数Loss由以上三部分构成,数学表达式如公式(5)所示:
Figure BDA0003863374520000172
VGG11(BN)网络在CIFAR10数据集上训练200个epoch,采用SGD优化器对神经网络进行优化,初始的学习率为0.1,当训练到66、133和190个epoch的时候,学习率除以10,weight_decay=5e-4,momentum=0.9;超参数α=0.3,λ=0.03;Batchsize为128,所有实验均在GPU设备上,pytorch1.9.1环境下进行。
图12中,Conv3×3,64表示卷积核大小为3,输出通道数为64的卷积层;Conv3×3,128表示卷积核大小为3,输出通道数为128的卷积层;Conv3×3,256表示卷积核大小为3,输出通道数为256的卷积层;Conv3×3,512表示卷积核大小为3,输出通道数为512的卷积层;maxpool表示池化核为2,步距为2的最大池化层;AdaptiveAvgPool表示自适应平均池化层,经过自适应平均池化层的特征图的长度和宽度都会变为1。
Picture表示尺寸为H×W×C的输入特征图_in,首先,经过V11_layer1模块处理:一共经过卷积层_1(Conv3×3,64),池化层_1(maxpool)和卷积层_2(Conv3×3,128);设卷积层_1输出特征图_mid1的高度为Hm1,宽度为Wm1,卷积层_1计算量为:3×3×C×64×Hm1×Wm1;卷积层_2输出特征图即为V11_layer1的输出特征图,记作输出特征图_mid2,输出特征图_mid2尺寸为H1×W1×C1,其中C1=128,卷积层_2计算量为:3×3×64×C1×H1×W1。
故V11_layer1层的计算总量为:3×3×C×64×Hm1×Wm1+3×3×64×C1×H1×W1,参数总量为:3×3×C×64+3×3×64×C1。
V11_layer1的输出特征图_mid2需经过V11_layer2层处理,一共经过池化层_2(maxpool)、卷积层_3(Conv3×3,256)与卷积层_4(Conv3×3,256);设卷积层_3输出特征图_mid3的高度为Hm2,宽度为Wm2,卷积层_3计算量为:3×3×C1×256×Hm2×Wm2;卷积层_4输出特征图即为V11_layer2的输出特征图,记作输出特征图_mid4,输出特征图_mid4尺寸为H2×W2×C2,其中C2=256,卷积层_4计算量为:3×3×256×C2×H2×W2;故V11_layer2层的计算总量为:3×3×C1×256×Hm2×Wm2+3×3×256×C2×H2×W2,参数总量为:3×3×C1×256+3×3×256×C2。
V11_layer2的输出特征图_mid4需经过V11_layer3层处理,一共经过池化层_3(maxpool)、卷积层_5(Conv3×3,512)与卷积层_6(Conv3×3,512);设卷积层_5输出特征图_mid5的高度为Hm3,宽度为Wm3,卷积层_5计算量为:3×3×C2×512×Hm3×Wm3,卷积层_6输出特征图即为V11_layer3的输出特征图,记作输出特征图_mid6,输出特征图_mid6尺寸为H3×W3×C3,其中C3=512,卷积层_6计算量为:3×3×512×C3×H3×W3;故V11_layer3层的计算总量为:3×3×C2×512×Hm3×Wm3+3×3×512×C3×H3×W3,参数总量为:3×3×C2×512+3×3×512×C3。
V11_layer3的输出特征图_mid6需经过V11_layer4层处理,一共经过池化层_4(maxpool)、卷积层_7(Conv3×3,512)与卷积层_8(Conv3×3,512);设卷积层_7输出特征图_mid7的高度为Hm4,宽度为Wm4,计算量为:3×3×C3×512×Hm4×Wm4,卷积层_8输出特征图即为V11_layer4的输出特征图,记作输出特征图_mid8,输出特征图_mid8尺寸为H4×W4×C4,其中C4=512,卷积层_8计算量为:3×3×512×C4×H4×W4;故V11_layer4层的计算总量为:3×3×C3×512×Hm4×Wm4+3×3×512×C4×H4×W4,参数总量为:3×3×C3×512+3×3×512×C4。
V11_layer4的输出特征图需经过池化层_5和第六全连接层、第七全连接层以及第八全连接层处理,具有4096个节点的第六全连接层FC6-4096,参数量与计算量为C4×4096,具有4096个节点的第七全连接层FC7-4096,参数量与计算量为4096×4096,具有10个节点的第八全连接层FC8-10,参数量与计算量为4096×10,故所有全连接层的参数总量与计算总量为:
C4×4096+4096×4096+4096×10。
所述第一浅层分类器主要由V11_layer1、第一注意力模块、第一浅层模块、第一全连接层(FC1_10)以及第四全连接层(FC4_512)构成;其中,V11_layer1层的计算总量为:3×3×C×64×Hm1×Wm1+3×3×64×C1×H1×W1,参数总量为:3×3×C×64+3×3×64×C1;V11_layer1的输出特征图_mid2,输出特征图_mid2尺寸为H1×W1×C1,经过第一注意力模块,尺寸不发生改变,再由第一浅层模块处理,第一浅层模块有3组模块结构,每组模块结构中的第一浅层注意力模块也不会改变所处理的特征图尺寸。设3组模块结构中所有深度卷积卷积核的尺寸均为h'×w'×1。
则第一浅层模块的参数总量为
Figure BDA0003863374520000201
总计算量为
Figure BDA0003863374520000202
Figure BDA0003863374520000203
第一浅层模块第j(j=1,2,3)组模块结构中经过第一次深度卷积,第一次深度卷积中间特征图尺寸变为Hj M0×Wj M0×Cj In,经过第一次逐点卷积,第一次逐点卷积中间特征图尺寸变为Hj M1×Wj M1×Cj In,经过第二次深度卷积,第二次深度卷积中间特征图尺寸变为了Hj M2×Wj M2×Cj In,再经过第二次逐点卷积,第二次逐点卷积中间特征图尺寸变为Hj M3×Wj M3×Cj out;其中,第一浅层模块输入特征图的通道数与V11_layer1层输出特征图的通道数相等,即C1 In=C1;第一浅层模块输出特征图的通道数为C3 out,经过自适应平均池化层,自适应平均池化层输出特征图尺寸变为1×1×C3 out,再与第一全连接层连接,具有10个节点的第一全连接层(FC1_10)参数量与计算量为10×C3 out,具有512个节点的第四全连接层(FC4_512)参数量与计算量为512×C3 out;因此,两个全连接层参数总量与计算总量为522×C3 out;综上,浅层分类器1的参数总量与计算总量分别为:
Figure BDA0003863374520000211
所述第二浅层分类器主要由V11_layer1、V11_layer2、第二注意力模块、第二浅层模块、第二全连接层(FC2_10)以及第五全连接层(FC5_512)构成;
其中,V11_layer2层的计算总量为:
3×3×C1×256×Hm2×Wm2+3×3×256×C2×H2×W2,V11_layer2层的参数总量为:3×3×C1×256+3×3×256×C2;V11_layer2的输出特征图_mid4,输出特征图_mid4尺寸为H2×W2×C2,经过第二注意力模块,尺寸不发生改变,再由第二浅层模块处理,第二浅层模块有2组模块结构,每组模块结构中的浅层注意力模块也不会改变所处理的特征图尺寸;设2组模块结构中所有深度卷积卷积核的尺寸为h'×w'×1,与第一浅层模块的模块结构深度卷积卷积核的尺寸相同;则第二浅层模块的参数总量与计算总量分别为:
Figure BDA0003863374520000212
第二浅层模块第t(t=1,2)组模块结构中经过第一次深度卷积,第一次深度卷积中间特征图尺寸变为Ht M0×Wt M0×Ct In,经过第一次逐点卷积,第一次逐点卷积中间特征图尺寸变为Ht M1×Wt M1×Ct In,经过第二次深度卷积,第二次深度卷积中间特征图尺寸变为了Ht M2×Wt M2×Ct In,再经过第二次逐点卷积,第二次逐点卷积中间特征图尺寸变为Ht M3×Wt M3×Ct out;其中,第二浅层模块输入特征图的通道数与V11_layer2层输出特征图的通道数相等,即C1 In=C2;第二浅层模块输出特征图的通道数为C2 out,经过自适应平均池化层,自适应平均池化层输出特征图尺寸变为1×1×C2 out,再与第二全连接层连接,具有10个节点的第二全连接层(FC2_10)参数量与计算量为10×C2 out,具有512个节点的第五全连接层(FC5_512)参数量与计算量为512×C2 out;因此,两个全连接层参数总量与计算总量为522×C2 out;综上,第二浅层分类器的参数总量与计算总量分别为:
Figure BDA0003863374520000221
第三浅层分类器主要由V11_layer1、V11_layer2、V11_layer3、第三注意力模块、第三浅层模块、第三全连接层(FC3-10)构成;其中,V11_layer3层的计算总量为:3×3×C2×512×Hm3×Wm3+3×3×512×C3×H3×W3,参数总量为:3×3×C2×512+3×3×512×C3;V11_layer3的输出特征图_mid6尺寸为H3×W3×C3,经过第三注意力模块,尺寸不发生改变,再由第三浅层模块处理,第三浅层模块有1组模块结构,该组模块结构中深度卷积卷积核的尺寸为h'×w'×1,深度卷积卷积核的尺寸与第二浅层模块的模块结构深度卷积卷积核的尺寸相同,则第三浅层模块的参数总量与计算总量分别为:
(h'×w'×C1 In+C1 In×C1 In)+(h'×w'×C1 In+C1 In×C1 out)、
h'×w'×C1 In×H1 M0×W1 M0+C1 In×C1 In×H1 M1×W1 M1+h'×w'×C1 In×H1 M2×W1 M2+C1 In×C1 out×H1 M3×W1 M3
第三浅层模块浅层卷积层中经过第一次深度卷积,第一次深度卷积中间特征图尺寸变为H1 M0×W1 M0×C1 In,经过第一次逐点卷积,第一次逐点卷积中间特征图尺寸变为H1 M1×W1 M1×C1 In,经过第二次深度卷积,第二次深度卷积中间特征图尺寸变为H1 M2×W1 M2×C1 In,再经过第二次逐点卷积,第二次逐点卷积中间特征图尺寸变为H1 M3×W1 M3×C1 out;其中,第三浅层模块输入特征图的通道数与V11_layer3层输出特征图的通道数相等,即C1 In=C3;第三浅层模块输出特征图的通道数为C1 out,经过自适应平均池化层,自适应平均池化层输出特征图尺寸变为1×1×C1 out,再与第三全连接层(FC3_10)连接,具有10个节点的第三全连接层(FC3_10),参数量与计算量为10×C1 out;综上,第三浅层分类器的参数总量与计算总量分别为:
Figure BDA0003863374520000231
Figure BDA0003863374520000232
综上:深层分类器VGG11(BN)主要由V11_layer1、V11_layer2、V11_layer3、V11_layer4、一个池化层(maxpool)和三个全连接层(FC6-4096、FC7-4096、FC8-10)构成,参数总量与计算总量分别为:
Figure BDA0003863374520000233
Figure BDA0003863374520000234
综上,本发明所述的技术方案,通过构建更加轻量的浅层分类器,减少计算量,降低了模型自蒸馏所需的时长,并降低浅层分类器部署在嵌入式设备上的难度;同时,在确保原自蒸馏相近的效果基础上,部分浅层分类器的性能甚至还超过了原始的浅层分类器。

Claims (10)

1.一种用于图像分类模型训练的自蒸馏训练方法,其特征是包括:
1、将图像分类模型自身作为深层分类器,使用深层分类器作为教师网络;
2、将深层分类器按照网络深度划分出四个模块:第一个模块(Block1)、第二个模块(Block2)、第三个模块(Block3)、第四个模块(Block4);在第一个模块的基础上依次增加第一注意力模块、第一浅层模块、第一全连接层作为第一浅层分类器;在第一个模块与第二个模块的基础上依次增加第二注意力模块、第二浅层模块、第二全连接层作为第二浅层分类器;在第一个模块、第二个模块、第三个模块的基础上依次增加第三注意力模块、第三浅层模块、第三全连接层作为第三浅层分类器;将第一浅层分类器、第二浅层分类器、第三浅层分类器全部作为学生网络;同时,在第一浅层模块的基础上增加第四全连接层,在第二浅层模块的基础上增加第五全连接层;
3、使用数据集来进行自蒸馏训练,得到一个深层分类器和三个浅层分类器。
2.根据权利要求1所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述深层分类器的第四个模块包含卷积模块与输出模块两个模块;在深层分类器第四个模块内卷积模块的基础上增加自适应平均池化层,用于辅助自蒸馏训练,便于深层分类器第四个模块内卷积模块的“知识”通过第四全连接层传授给第一浅层模块,通过第五全连接层传授给第二浅层模块。
3.根据权利要求1所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述第一个模块(Block1)的输出特征图output1,作为第二个模块(Block2)的输入特征图,同时也作为第一注意力模块的输入特征图;第二个模块(Block2)的输出特征图output2,作为第三个模块(Block3)的输入特征图,同时也作为第二注意力模块的输入特征图;第三个模块(Block3)的输出特征图output3,作为第四个模块(Block4)的输入特征图,同时也作为第三注意力模块的输入特征图;在第四个模块(Block4)内部,第四个模块卷积模块的输出特征图作为自适应平均池化层的输入特征图,同样也作为第四个模块中输出模块的输入特征图。
4.根据权利要求3所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述第一注意力模块对输入特征图output1的处理流程具体包括如下步骤:
1)将尺寸为H'×W'×C'的输入特征图output1按照通道数C'分为n组,得到n个尺寸为H'×W'×C'/n的中间特征图bi(i=1,2,…,n);
2)经过全局平均池化层对中间特征图bi(i=1,2,…,n)进行全局平均池化,得到n个尺寸为1×1×C'/n的第一特征图gi(i=1,2,…,n),将第一特征图与中间特征图进行对位点乘,即第一特征图gi与中间特征图bi对位点乘(i=1,2,…,n)得到n个初始注意力掩码ci(i=1,2,…,n),对n个初始注意力掩码ci(i=1,2,…,n)中的每一个初始注意力掩码分别求均值与标准差,将每一个初始注意力掩码进行标准化处理,得到n个H'×W'×1的第二特征图di(i=1,2,…,n);
3)将n个第二特征图di(i=1,2,…,n)中的每一个第二特征图使用Sigmoid函数激活得到最终的n个注意力掩码ei(i=1,2,…,n),n个注意力掩码ei(i=1,2,…,n)分别与相应组别的n个中间特征图bi(i=1,2,…,n)对位点乘,最终得到n个尺寸为H'×W'×C'/n的小组输出特征图fi(i=1,2,…,n);之后将这n个尺寸为H'×W'×C'/n的小组输出特征图fi(i=1,2,…,n)拼接为最终的尺寸为H'×W'×C'的输出特征图J1,输出特征图J1与输入特征图output1尺寸相同;第二注意力模块、第三注意力模块对各自输入特征图output2、输入特征图output3的处理流程与第一注意力模块对输入特征图output1处理流程完全相同,第二注意力模块的输出特征图J2与第二注意力模块的输入特征图output2尺寸相同;第三注意力模块的输出特征图J3与第三注意力模块的输入特征图output3尺寸相同。
5.根据权利要求1所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述第一浅层模块、第二浅层模块、第三浅层模块分别为模块深度不同的浅层模块;所述第一浅层模块包含了三组模块结构和一个自适应平均池化层;第一注意力模块的输出特征图J1作为第一浅层模块内三组模块结构中的第一组模块结构的输入特征图经过第一组模块的处理之后,得到输出特征图R1_1作为第二组模块结构的输入特征图,特征图R1_1经过第二组模块的处理之后,得到第二组模块结构的输出特征图R1_2作为第三组模块结构的输入特征图,特征图R1_2经过第三组模块的处理之后得到第三组模块结构的输出特征图R1_3,第三组模块结构的输出特征图R1_3作为第一浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R1_4。
6.根据权利要求5所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述第一浅层模块中的三组模块结构中的每一组模块结构完全相同,每一组模块结构均包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第一浅层注意力模块。
7.根据权利要求1所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述第二浅层模块包含了两组模块结构和一个自适应平均池化层;第二注意力模块的输出特征图J2作为第二浅层模块内两组模块结构中的第一组模块结构的输入特征图经过第一组模块的处理之后,得到输出特征图R2_1作为第二组模块结构的输入特征图,输入特征图R2_1经过第二组模块结构的处理之后,得到第二组模块结构的输出特征图R2_2;输出特征图R2_2作为第二浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R2_3。
8.根据权利要求7所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述第二浅层模块中的两组模块结构中的每一组模块结构完全相同,每一组模块结构均包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第二浅层注意力模块。
9.根据权利要求1所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述第三浅层模块包含了一组模块结构和一个自适应平均池化层;第三注意力模块的输出特征图J3作为第三浅层模块内模块结构的输入特征图经过模块结构的处理之后,得到输出特征图R3_1,输出特征图R3_1作为第三浅层模块最后自适应平均池化层的输入特征图,经过自适应平均池化层处理后得到输出特征图R3_2;所述第三浅层模块中的一组模块结构包含步距为2的第一个深度卷积层、步距为1的第一个逐点卷积层、步距为1的第二个深度卷积层、步距为1的第二个逐点卷积层与第三浅层注意力模块。
10.根据权利要求1所述的一种用于图像分类模型训练的自蒸馏训练方法,其特征是所述使用数据集来进行自蒸馏训练,具体包括如下步骤:
首先,对CIFAR10数据集的训练集与CIFAR100数据集的训练集进行如下处理:
(1)对图片进行随机裁剪,裁剪后尺寸(size)为32,填充边界的值(padding)设置为4,填充值(fill)设置为128;
(2)对图片进行随机水平翻转;
(3)将图片格式转换为tensor格式;将图片的每一个数值归一化到[0,1];
(4)将图片的每一个数值进行标准化处理,标准化处理的均值为(0.4914,0.4822,0.4465),方差为(0.2023,0.1994,0.2010);
其次,对CIFAR10数据集的测试集与CIFAR100数据集的测试集进行如下处理:
(1)将图片格式转换为tensor格式;将图片的每一个数值归一化到[0,1];
(2)将图片的每一个数值进行标准化处理,标准化处理的均值为(0.4914,0.4822,0.4465),标准差为(0.2023,0.1994,0.2010)。
CN202211173732.1A 2022-09-26 2022-09-26 一种用于图像分类模型训练的自蒸馏训练方法 Pending CN115512156A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211173732.1A CN115512156A (zh) 2022-09-26 2022-09-26 一种用于图像分类模型训练的自蒸馏训练方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211173732.1A CN115512156A (zh) 2022-09-26 2022-09-26 一种用于图像分类模型训练的自蒸馏训练方法

Publications (1)

Publication Number Publication Date
CN115512156A true CN115512156A (zh) 2022-12-23

Family

ID=84506902

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211173732.1A Pending CN115512156A (zh) 2022-09-26 2022-09-26 一种用于图像分类模型训练的自蒸馏训练方法

Country Status (1)

Country Link
CN (1) CN115512156A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116310667A (zh) * 2023-05-15 2023-06-23 鹏城实验室 联合对比损失和重建损失的自监督视觉表征学习方法
CN116416456A (zh) * 2023-01-13 2023-07-11 北京数美时代科技有限公司 基于自蒸馏的图像分类方法、系统、存储介质和电子设备

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116416456A (zh) * 2023-01-13 2023-07-11 北京数美时代科技有限公司 基于自蒸馏的图像分类方法、系统、存储介质和电子设备
CN116416456B (zh) * 2023-01-13 2023-10-24 北京数美时代科技有限公司 基于自蒸馏的图像分类方法、系统、存储介质和电子设备
CN116310667A (zh) * 2023-05-15 2023-06-23 鹏城实验室 联合对比损失和重建损失的自监督视觉表征学习方法
CN116310667B (zh) * 2023-05-15 2023-08-22 鹏城实验室 联合对比损失和重建损失的自监督视觉表征学习方法

Similar Documents

Publication Publication Date Title
CN115512156A (zh) 一种用于图像分类模型训练的自蒸馏训练方法
Zoumpourlis et al. Non-linear convolution filters for cnn-based learning
CN109584337B (zh) 一种基于条件胶囊生成对抗网络的图像生成方法
CN112348191B (zh) 一种基于多模态表示学习的知识库补全方法
CN106447626A (zh) 一种基于深度学习的模糊核尺寸估计方法与系统
CN112381211B (zh) 基于异构平台执行深度神经网络的系统及方法
CN106326899A (zh) 一种基于高光谱图像和深度学习算法的烟叶分级方法
Lin et al. Attribute-Aware Convolutional Neural Networks for Facial Beauty Prediction.
CN109753570A (zh) 一种基于Horn逻辑与图神经网络的场景图谱向量化方法
CN111192291A (zh) 一种基于级联回归与孪生网络的目标跟踪方法
CN113627376B (zh) 基于多尺度密集连接深度可分离网络的人脸表情识别方法
CN110197217B (zh) 一种基于深度交错融合分组卷积网络的图像分类方法
CN108596264A (zh) 一种基于深度学习的社区发现方法
CN106411572A (zh) 一种结合节点信息和网络结构的社区发现方法
CN114610897A (zh) 基于图注意力机制的医学知识图谱关系预测方法
CN111401117A (zh) 基于双流卷积神经网络的新生儿疼痛表情识别方法
CN113379655A (zh) 一种基于动态自注意力生成对抗网络的图像合成方法
CN112508181A (zh) 一种基于多通道机制的图池化方法
CN113610163A (zh) 一种基于知识蒸馏的轻量级苹果叶片病害识别方法
CN113610192A (zh) 一种基于连续性剪枝的神经网络轻量化方法及系统
CN115761240A (zh) 一种混沌反向传播图神经网络的图像语义分割方法及装置
CN115620238A (zh) 一种基于多元信息融合的园区行人属性识别方法
CN112613405B (zh) 任意视角动作识别方法
CN114549962A (zh) 一种园林植物叶病分类方法
CN111582202A (zh) 一种智能网课系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination