CN111898735A - 蒸馏学习方法、装置、计算机设备和存储介质 - Google Patents
蒸馏学习方法、装置、计算机设备和存储介质 Download PDFInfo
- Publication number
- CN111898735A CN111898735A CN202010674185.XA CN202010674185A CN111898735A CN 111898735 A CN111898735 A CN 111898735A CN 202010674185 A CN202010674185 A CN 202010674185A CN 111898735 A CN111898735 A CN 111898735A
- Authority
- CN
- China
- Prior art keywords
- student
- teacher
- feature map
- network
- loss value
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 46
- 238000004821 distillation Methods 0.000 title claims abstract description 34
- 238000010586 diagram Methods 0.000 claims abstract description 108
- 238000009826 distribution Methods 0.000 claims description 120
- 238000004590 computer program Methods 0.000 claims description 25
- 238000004364 calculation method Methods 0.000 claims description 20
- 230000000694 effects Effects 0.000 abstract description 5
- 230000004044 response Effects 0.000 description 13
- 230000008569 process Effects 0.000 description 9
- 230000011218 segmentation Effects 0.000 description 8
- 238000012549 training Methods 0.000 description 7
- 238000005516 engineering process Methods 0.000 description 4
- 239000000284 extract Substances 0.000 description 3
- 241000282326 Felis catus Species 0.000 description 2
- 230000006870 function Effects 0.000 description 2
- 238000013140 knowledge distillation Methods 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 239000011159 matrix material Substances 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本申请涉及一种蒸馏学习方法、装置、计算机设备和存储介质。所述方法包括:将目标图像分别输入教师网络和学生网络,得到教师网络输出的教师特征图和学生网络输出的学生特征图;将教师特征图和学生特征图进行通道匹配,根据匹配结果获取教师特征图和学生特征图之间的目标损失值;根据目标损失值调整学生网络中的参数,得到目标学生网络。采用本方法能够提高学生网络对教师网络的蒸馏学习效果,减小学生网络与教师网络的性能差异。
Description
技术领域
本申请涉及机器学习技术领域,特别是涉及一种蒸馏学习方法、装置、计算机设备和存储介质。
背景技术
随着机器学习技术的发展,出现了蒸馏学习技术,蒸馏学习采用的是迁移学习,利用预先训练好的复杂网络模型(教师网络,Teacher model)的输出作为监督信号去训练另外一个简单的网络模型(学生网络,Student model),以获得结构精简且计算复杂度低,同时具有教师网络的知识的学生网络。
传统技术中,大多从像素点级的蒸馏、关系对蒸馏以及基于判别器的全局蒸馏这3个方面进行展开,但其蒸馏学习效果较差,得到的学生网络与教师网络的性能相差很大。
发明内容
基于此,有必要针对上述技术问题,提供一种蒸馏学习方法、装置、计算机设备和存储介质。
一种蒸馏学习方法,所述方法包括:
将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;
将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;
根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
在其中一个实施例中,所述将所述教师特征图和所述学生特征图进行通道匹配,包括:
将所述教师特征图和所述学生特征图的通道按照通道类型分别进行编号,得到教师通道编号和学生通道编号;其中,相同通道类型的所述教师通道编号和所述学生通道编号相同;
遍历所述教师通道编号和所述学生通道编号进行对应所述教师特征图的通道和所述学生特征图的通道的编号匹配。
在其中一个实施例中,所述遍历所述教师通道编号和所述学生通道编号进行对应所述教师特征图的通道和所述学生特征图的通道的编号匹配,包括:
按照所述教师通道编号由小到大的顺序,将每一所述教师特征图的通道再以所述学生通道编号由小到大的顺序,依次与每一所述学生特征图的通道进行编号匹配。
在其中一个实施例中,在所述根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值之前,包括:
根据所述教师特征图进行softmax计算,得到第一概率分布;
根据所述学生特征图进行softmax计算,得到第二概率分布;
相应地,所述根据匹配结果计算所述教师特征图和所述学生特征图之间的目标损失值,包括:
根据所述匹配结果、所述第一概率分布以及所述第二概率分布计算所述目标损失值。
在其中一个实施例中,所述根据所述匹配结果、所述第一概率分布以及所述第二概率分布计算所述目标损失值,包括:
获取通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;
根据通道匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第一KL散度,取所述第一KL散度的正值,作为第一损失值;
获取通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;
根据通道不匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第二KL散度,取所述第二KL散度的负值,作为第二损失值;
根据所述第一损失值和所述第二损失值,得到所述目标损失值。
在其中一个实施例中,所述根据所述第一损失值和所述第二损失值,得到所述目标损失值,包括:
将所述第一损失值加上所述第二损失值,得到所述目标损失值。
在其中一个实施例中,所述根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络,包括:
根据所述目标损失值调整所述学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,将所述参数调整后的学生网络作为所述目标学生网络。
一种蒸馏学习装置,所述装置包括:
特征输出模块,用于将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;
通道匹配模块,用于将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;
参数调整模块,用于根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:
将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;
将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;
根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现以下步骤:
将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;
将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;
根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
上述蒸馏学习方法、装置、计算机设备和存储介质,通过将目标图像分别输入教师网络和学生网络,得到教师网络中不同通道输出的教师特征图和学生网络中不同通道输出的学生特征图。由于教师特征图和学生特征图上的响应分布与输入的目标图像中对应语义的目标分布一致,因此可直接采用通道匹配的教师特征图训练学生网络,以实现教师网络对学生网络的知识迁移,提高学生网络对教师网络的蒸馏学习效果,减小学生网络与教师网络的性能差异。计算机设备具体是将教师特征图和学生特征图进行通道匹配,根据匹配结果获取教师特征图和学生特征图之间的目标损失值;根据目标损失值调整学生网络中的参数,以得到目标学生网络。教师特征图和学生特征图上的响应分布可直观的展现目标图像上的结构化信息,响应高的就对应着当前通道所提取的语义目标区域,即前景区域,响应低的就对应着其他目标(其它语义的目标以及背景),即背景区域。教师网络到学生网络知识蒸馏的成功需确保学生网络学习到更多的前景知识以及更少的背景知识,采用教师网络中通道级的教师特征图对学生特征图进行前后景相应分布的对齐,对齐的过程中不断调整学生网络中的参数,减小学生网络与教师网络的性能差异,以得到学习了教师网络优越性能的目标学习网络,提高了目标学生网络的进行语义分割的准确性。
附图说明
图1为一个实施例中蒸馏学习方法的流程示意图;
图2为一个实施例中通道匹配的流程示意图;
图3为一个实施例中获取目标损失值的流程示意图;
图4为另一个实施例中获取目标损失值的流程示意图;
图5为另一个实施例中蒸馏学习方法的流程示意图;
图6为一个实施例中蒸馏学习方法的应用示意图;
图7为一个实施例中蒸馏学习装置的结构框图;
图8为一个实施例中计算机设备的内部结构图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
在一个实施例中,如图1所示,提供了一种蒸馏学习方法,本实施例以该方法应用于终端进行举例说明,可以理解的是,该方法也可以应用于服务器,还可以应用于包括终端和服务器的系统,并通过终端和服务器的交互实现。其中,终端可以但不限于是各种个人计算机、笔记本电脑、智能手机、平板电脑和便携式可穿戴设备,服务器可以用独立的服务器或者是多个服务器组成的服务器集群来实现。本实施例中,该方法包括以下步骤:
S110、将目标图像分别输入教师网络和学生网络,得到教师网络输出的教师特征图和学生网络输出的学生特征图。
其中,教师网络是以resnet101-pspnet作为初始网络,学生网络是以resnet18-pspnet作为初始网络,采用主流数据集如cityscape或者Pascal Voc作为训练样本,分别输入上述初始网络进行训练得到的用于进行语义分割的网络模型。
具体地,计算机设备将目标图像输入教师网络得到教师网络输出的用于语义分割的最终教师特征图,将目标图像输入学生网络得到学生网络输出用于语义分割的最终学生特征图。计算机设备在最终教师特征图中提取得到每个通道的二维切片,作为相应通道得到的教师特征图,同样在最终学生特征图中提取得到每个通道的二维切片,作为相应通道得到的学生特征图。
S120、将教师特征图和学生特征图进行通道匹配,根据匹配结果获取教师特征图和学生特征图之间的目标损失值。
其中,目标损失值可用于表征教师网络与学生网络之间各个通道得到的教师特征图与学生特征图之间的整体差异。目标损失值越大,教师特征图与学生特征图之间的整体差异就越大。
具体地,计算机设备将教师特征图和学生特征图进行通道类型的匹配,根据匹配结果获取教师特征图和学生特征图之间的目标损失值。例如,计算通道类型匹配的教师特征图和学生特征图之间的损失值,以及通道类型不匹配的教师特征图和学生特征图之间的损失值,再根据上述两种损失值确定教师特征图和学生特征图之间整体上的目标损失值。
S130、根据目标损失值调整学生网络中的参数,得到目标学生网络。
具体地,计算机设备根据得到的目标损失值调整学生网络中的参数,使得通过参数调整后的学生网络得到的目标损失值不断减小,以得到目标学生网络。
本实施例中,计算机设备将目标图像分别输入教师网络和学生网络,得到教师网络中不同通道输出的教师特征图和学生网络中不同通道输出的学生特征图。由于教师特征图和学生特征图上的响应分布与输入的目标图像中对应语义的目标分布一致,因此可直接采用通道匹配的教师特征图训练学生网络,以实现教师网络对学生网络的知识迁移,提高学生网络对教师网络的蒸馏学习效果,减小学生网络与教师网络的性能差异。
计算机设备具体是将教师特征图和学生特征图进行通道匹配,根据匹配结果获取教师特征图和学生特征图之间的目标损失值;根据目标损失值调整学生网络中的参数,以得到目标学生网络。教师特征图和学生特征图上的响应分布可直观的展现目标图像上的结构化信息,响应高的就对应着当前通道所提取的语义目标区域,即前景区域,响应低的就对应着其他目标(其它语义的目标以及背景),即背景区域。教师网络到学生网络知识蒸馏的成功需确保学生网络学习到更多的前景知识以及更少的背景知识,采用教师网络中通道级的教师特征图对学生特征图进行前后景相应分布的对齐,对齐的过程中不断调整学生网络中的参数,减小学生网络与教师网络的性能差异,以得到学习了教师网络优越性能的目标学习网络,提高了目标学生网络的进行语义分割的准确性。
在一个实施例中,为了提高通道匹配的效率,如图2所示,S120包括:
S210、将教师特征图和学生特征图中的通道按照通道类型分别进行编号,得到教师通道编号和学生通道编号。
其中,相同通道类型的教师通道编号和学生通道编号相同。
具体地,计算机设备将教师特征图的通道和学生特征图的通道按照通道类型分别进行编号,并将教师特征图和学生特征图之间相同通道类型的通道编为相同的编号,得到教师通道编号和学生通道编号。例如,教师网络和学生网络是用于语义分割人、猫和狗的网络模型,计算机设备将提取得到人的特征图的通道编为1,将提取得到猫的特征图的通道编为2,将提取得到狗的特征图的通道编为3,则得到教师通道编号c=1,2,3和学生通道编号j=1,2,3。
S220、遍历教师通道编号和学生通道编号进行对应教师特征图的通道和学生特征图的通道的编号匹配。
其中,若通道的编号匹配,则通道匹配;若通道的编号不匹配,则通道不匹配。
具体地,计算机设备按照教师通道编号由小到大的顺序,将每一教师特征图的通道再以学生通道编号由小到大的顺序,依次与每一学生特征图的通道进行编号匹配。例如,计算机设备将教师通道编号c=1的通道依次与学生通道编号j=1,2,3的通道进行编号匹配,再将教师通道编号c=2的通道依次与学生通道编号j=1,2,3的通道进行编号匹配,再将教师通道编号c=3的通道依次与学生通道编号j=1,2,3的通道进行编号匹配。
进一步地,在编号匹配的过程中还包括:
计算机设备将教师通道编号c的通道依次与学生通道编号j的通道进行编号匹配结束后,判断c是否等于预设教师通道数C。若否,再将教师通道编号c+的通道依次与学生通道编号j的通道进行编号匹配;若是,则停止编号匹配。
同时,计算机设备判断与教师通道编号c的通道进行编号匹配的学生通道编号j是否等于预设教师通道数C。若否,再将教师通道编号c的通道与学生通道编号j+的通道进行编号匹配;若是,则将教师通道编号c+的通道依次与学生通道编号j的通道进行编号匹配。
本实施例中,计算机设备对教师特征图和学生特征图中的通道进行编号,并且将相同通道类型的通道编为相同的编号,以此利用编号匹配实现教师特征图和学生特征图之间的通道匹配,编号匹配的过程简单,方便,提高了教师特征图和学生特征图之间通道匹配的效率。
在一个实施例中,可对特征图进行softmax计算,将响应分布转换为概率分布,进而获取特征图之间的损失值。如图3所示,在根据匹配结果获取教师特征图和学生特征图之间的目标损失值之前,包括:
S310、根据教师特征图进行softmax计算,得到第一概率分布。
S320、根据学生特征图进行softmax计算,得到第二概率分布。
其中,softmax计算是通过softmax公式(1)将特征图中响应分布的数字矩阵转换为概率分布。
其中,i为数字矩阵中的第i个的数值,i∈j。
具体地,计算机设备分别对教师特征图和学生特征图中响应分布的数字矩阵采用softmax公式(1)进行softmax计算,对应得到第一概率分布和第二概率分布。
相应地,S130包括:
根据匹配结果、第一概率分布以及第二概率分布计算目标损失值。
具体地,计算机设备根据通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布,计算通道匹配的教师特征图和学生特征图之间的损失值。计算机设备再根据通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布,计算通道不匹配的教师特征图和学生特征图之间的损失值。计算机设备结合上述通道匹配得到的损失值和通道不配的得到的损失值确定教师特征图和学生特征图整体上的目标损失值。
本实施例中,计算机设备根据教师特征图和学生特征图上的响应分布进行softmax计算,对应得到第一概率分布和第二概率分布,以将教师特征图和学生特征图上的响应分布转换为概率分布,为计算教师特征图和学生特征图之间的目标损失值做好数据基础。
在一个实施例中,可采用KL(Kullback-Leibler divergence)散度即相对熵作为教师特征图和学生特征图之间的损失值。如图4所示,上述根据匹配结果、第一概率分布以及第二概率分布计算目标损失值,包括:
S410、获取通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布。
S420、根据通道匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第一KL散度,取第一KL散度的正值,作为第一损失值。
S430、获取通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布。
S440、根据通道不匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第二KL散度,取第二KL散度的负值,作为第二损失值。
S450、根据第一损失值和第二损失值,得到目标损失值。
其中,KL散度DKL可根据真实事件的概率分布p和拟合事件的概率分q布采用KL公式(2)得到。本实施例中,真实事件即为教师网络得到的教师特征图,拟合事件即为学生网络得到的学生特征图。
其中,p(xi)为真实事件的概率值,p(xi)为拟合事件的概率值。
具体地,计算机设备采用KL散度公式(2)根据通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布,计算通道匹配教师特征图和学生特征图之间的第一KL散度,根据通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布,计算通道不匹配的教师特征图和学生特征图之间的第一KL散度。计算机设备进一步取第一KL散度的正值作为第一损失值,取第二KL散度的负值作为第二损失值,并将第一损失值加上第二损失值,得到目标损失值。
本实施例中,教师网络对学生网络的训练目的是通过调整学生网络中的参数使得得到的目标损失函数减小。上述方法中计算机设备取根据通道匹配得到的第一K散度的正值作为第一损失值,取根据通道不匹配得到的第二K散度的正值作为第二损失值,再将第一损失值加上第二损失值作为目标损失值。这样在调整学生网络中的参数使得得到的目标损失函数减小的过程中,可使得第一损失值减小,以减小通道匹配的教师特征图和学生特征图之间的差异,也可以使得第二损失值增大,以增大通道不匹配的教师特征图和学生特征图之间的差异,从而实现学生网络学习到更多的前景至少以及更少的背景知识,使得最终得到的目标学生网络学习到教师网络的优越性能,提高了目标学生网络的性能,以及进行语义分割的准确性。
在一个实施例中,目标损失值达到预设损失值时,即可停止调整参数,得到目标学生网络,则S130包括:
根据目标损失值调整学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,将参数调整后的学生网络作为目标学生网络。
具体地,计算机设备根据得到的目标损失值采用梯度更新的方式调整学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,即可停止调整参数,将参数调整后的学生网络作为目标学生网络,训练即结束。训练结束后,还可重新采用数据集对得到的目标学生网络的性能进行验证。
本实施例中,计算机设备通过预设损失值,在确保目标学生网络性能的同时,缩短了训练时间,提高了训练效率。
在一个具体的实施例中,如图5所示,提供了一种蒸馏学习方法,包括:
S501、将目标图像分别输入教师网络和学生网络,得到教师网络输出的教师特征图和学生网络输出的学生特征图。
S502、将教师特征图和学生特征图的通道按照通道类型分别进行编号,得到教师通道编号和学生通道编号;其中,相同通道类型的教师通道编号和学生通道编号相同。
S503、按照教师通道编号由小到大的顺序,将每一教师特征图的通道再以学生通道编号由小到大的顺序,依次与每一学生特征图的通道进行编号匹配。
S504、根据教师特征图进行softmax计算,得到第一概率分布。
S505、根据学生特征图进行softmax计算,得到第二概率分布。
S506、获取通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布。
S507、根据通道匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第一KL散度,取第一KL散度的正值,作为第一损失值。
S508、获取通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布。
S509、根据通道不匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第二KL散度,取第二KL散度的负值,作为第二损失值。
S510、将第一损失值加上第二损失值,得到目标损失值。
S511、根据目标损失值调整学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,将参数调整后的学生网络作为目标学生网络。
结合图6,计算机设备将目标图像分别输入教师网络和学生网络,将过多次特征提取后,输出用于进行语义分割的最终教师特征图和最终学生特征图。计算机设备在最终教师特征图中提取得到每个通道的二维切片,作为相应通道得到的教师特征图,同理得到学生特征图。计算机设备再对教师特征图和学生特征图的通道进行编号,利用编号进行通道的匹配,取通道匹配的教师特征图和学生特征图之间第一KL散度的正值与通道不匹配的教师特征图和学生特征图之间第二KL散度的负值之和,得到目标损失值。计算机设备再根据目标损失值调整学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,得到目标学生网络,训练结束。
本实施例中,计算机设备通过上述方法实现性能较弱的学生网络中的前后景分布向性能优越的教师网络中的前后景分布对齐,基于通道级的教师特征图和学生特征图的比较,使得学生网络学习到更多的前景至少以及更少的背景知识,提高学生网络对教师网络的蒸馏学习效果,减小学生网络与教师网络的性能差异,得到学习到教师网络的优越性能的目标学生网络,提高了目标学生网络的性能,以及进行语义分割的准确性。
应该理解的是,虽然图1-5的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,图1-5中的至少一部分步骤可以包括多个步骤或者多个阶段,这些步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤中的步骤或者阶段的至少一部分轮流或者交替地执行。
在一个实施例中,如图7所示,提供了一种蒸馏学习装置,包括:特征输出模块701、通道匹配模块702和参数调整模块703,其中:
特征输出模块701用于将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;通道匹配模块702用于将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;参数调整模块703用于根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
在其中一个实施例中,通道匹配模块702具体用于:
将所述教师特征图和所述学生特征图的通道按照通道类型分别进行编号,得到教师通道编号和学生通道编号;其中,相同通道类型的所述教师通道编号和所述学生通道编号相同;遍历所述教师通道编号和所述学生通道编号进行对应所述教师特征图的通道和所述学生特征图的通道的编号匹配。
在其中一个实施例中,通道匹配模块702具体用于:
按照所述教师通道编号由小到大的顺序,将每一所述教师特征图的通道再以所述学生通道编号由小到大的顺序,依次与每一所述学生特征图的通道进行编号匹配。
在其中一个实施例中,所述装置还包括:概率计算模块;概率计算模块用于根据所述教师特征图进行softmax计算,得到第一概率分布;根据所述学生特征图进行softmax计算,得到第二概率分布。
相应地,通道匹配模块702具体用于:根据所述匹配结果、所述第一概率分布以及所述第二概率分布计算所述目标损失值。
在其中一个实施例中,通道匹配模块702具体用于:
获取通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;根据通道匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第一KL散度,取所述第一KL散度的正值,作为第一损失值;获取通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;根据通道不匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第二KL散度,取所述第二KL散度的负值,作为第二损失值;根据所述第一损失值和所述第二损失值,得到所述目标损失值。
在其中一个实施例中,通道匹配模块702具体用于:
将所述第一损失值加上所述第二损失值,得到所述目标损失值。
在其中一个实施例中,参数调整模块703具体用于:
根据所述目标损失值调整所述学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,将所述参数调整后的学生网络作为所述目标学生网络。
关于蒸馏学习装置的具体限定可以参见上文中对于蒸馏学习方法的限定,在此不再赘述。上述蒸馏学习装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图8所示。该计算机设备包括通过系统总线连接的处理器、存储器和网络接口。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统、计算机程序和数据库。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的数据库用于存储蒸馏学习数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种蒸馏学习方法。
本领域技术人员可以理解,图8中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,提供了一种计算机设备,包括存储器和处理器,存储器中存储有计算机程序,该处理器执行计算机程序时实现以下步骤:
将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:
将所述教师特征图和所述学生特征图的通道按照通道类型分别进行编号,得到教师通道编号和学生通道编号;其中,相同通道类型的所述教师通道编号和所述学生通道编号相同;遍历所述教师通道编号和所述学生通道编号进行对应所述教师特征图的通道和所述学生特征图的通道的编号匹配。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:
按照所述教师通道编号由小到大的顺序,将每一所述教师特征图的通道再以所述学生通道编号由小到大的顺序,依次与每一所述学生特征图的通道进行编号匹配。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:
根据所述教师特征图进行softmax计算,得到第一概率分布;根据所述学生特征图进行softmax计算,得到第二概率分布;根据所述匹配结果、所述第一概率分布以及所述第二概率分布计算所述目标损失值。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:
获取通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;根据通道匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第一KL散度,取所述第一KL散度的正值,作为第一损失值;获取通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;根据通道不匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第二KL散度,取所述第二KL散度的负值,作为第二损失值;根据所述第一损失值和所述第二损失值,得到所述目标损失值。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:
将所述第一损失值加上所述第二损失值,得到所述目标损失值。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:
根据所述目标损失值调整所述学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,将所述参数调整后的学生网络作为所述目标学生网络。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现以下步骤:
将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
将所述教师特征图和所述学生特征图的通道按照通道类型分别进行编号,得到教师通道编号和学生通道编号;其中,相同通道类型的所述教师通道编号和所述学生通道编号相同;遍历所述教师通道编号和所述学生通道编号进行对应所述教师特征图的通道和所述学生特征图的通道的编号匹配。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
按照所述教师通道编号由小到大的顺序,将每一所述教师特征图的通道再以所述学生通道编号由小到大的顺序,依次与每一所述学生特征图的通道进行编号匹配。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
根据所述教师特征图进行softmax计算,得到第一概率分布;根据所述学生特征图进行softmax计算,得到第二概率分布;根据所述匹配结果、所述第一概率分布以及所述第二概率分布计算所述目标损失值。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
获取通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;根据通道匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第一KL散度,取所述第一KL散度的正值,作为第一损失值;获取通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;根据通道不匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第二KL散度,取所述第二KL散度的负值,作为第二损失值;根据所述第一损失值和所述第二损失值,得到所述目标损失值。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
将所述第一损失值加上所述第二损失值,得到所述目标损失值。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:
根据所述目标损失值调整所述学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,将所述参数调整后的学生网络作为所述目标学生网络。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和易失性存储器中的至少一种。非易失性存储器可包括只读存储器(Read-Only Memory,ROM)、磁带、软盘、闪存或光存储器等。易失性存储器可包括随机存取存储器(Random Access Memory,RAM)或外部高速缓冲存储器。作为说明而非局限,RAM可以是多种形式,比如静态随机存取存储器(Static Random Access Memory,SRAM)或动态随机存取存储器(Dynamic Random Access Memory,DRAM)等。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请专利的保护范围应以所附权利要求为准。
Claims (10)
1.一种蒸馏学习方法,其特征在于,所述方法包括:
将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;
将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;
根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
2.根据权利要求1所述的方法,其特征在于,所述将所述教师特征图和所述学生特征图进行通道匹配,包括:
将所述教师特征图和所述学生特征图的通道按照通道类型分别进行编号,得到教师通道编号和学生通道编号;其中,相同通道类型的所述教师通道编号和所述学生通道编号相同;
遍历所述教师通道编号和所述学生通道编号进行对应所述教师特征图的通道和所述学生特征图的通道的编号匹配。
3.根据权利要求2所述的方法,其特征在于,所述遍历所述教师通道编号和所述学生通道编号进行对应所述教师特征图的通道和所述学生特征图的通道的编号匹配,包括:
按照所述教师通道编号由小到大的顺序,将每一所述教师特征图的通道再以所述学生通道编号由小到大的顺序,依次与每一所述学生特征图的通道进行编号匹配。
4.根据权利要求1所述的方法,其特征在于,在所述根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值之前,包括:
根据所述教师特征图进行softmax计算,得到第一概率分布;
根据所述学生特征图进行softmax计算,得到第二概率分布;
相应地,所述根据匹配结果计算所述教师特征图和所述学生特征图之间的目标损失值,包括:
根据所述匹配结果、所述第一概率分布以及所述第二概率分布计算所述目标损失值。
5.根据权利要求4所述的方法,其特征在于,所述根据所述匹配结果、所述第一概率分布以及所述第二概率分布计算所述目标损失值,包括:
获取通道匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;
根据通道匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第一KL散度,取所述第一KL散度的正值,作为第一损失值;
获取通道不匹配的教师特征图的第一概率分布,和学生特征图的第二概率分布;
根据通道不匹配的第一概率分布和第二概率分布上的概率值计算教师特征图和学生特征图之间的第二KL散度,取所述第二KL散度的负值,作为第二损失值;
根据所述第一损失值和所述第二损失值,得到所述目标损失值。
6.根据权利要求5所述的方法,其特征在于,所述根据所述第一损失值和所述第二损失值,得到所述目标损失值,包括:
将所述第一损失值加上所述第二损失值,得到所述目标损失值。
7.根据权利要求1所述的方法,其特征在于,所述根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络,包括:
根据所述目标损失值调整所述学生网络中的参数,直至参数调整后的学生网络所得到的目标损失值小于预设损失值,将所述参数调整后的学生网络作为所述目标学生网络。
8.一种蒸馏学习装置,其特征在于,所述装置包括:
特征输出模块,用于将目标图像分别输入教师网络和学生网络,得到所述教师网络输出的教师特征图和所述学生网络输出的学生特征图;
通道匹配模块,用于将所述教师特征图和所述学生特征图进行通道匹配,根据匹配结果获取所述教师特征图和所述学生特征图之间的目标损失值;
参数调整模块,用于根据所述目标损失值调整所述学生网络中的参数,得到目标学生网络。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至7中任一项所述方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至7中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010674185.XA CN111898735A (zh) | 2020-07-14 | 2020-07-14 | 蒸馏学习方法、装置、计算机设备和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010674185.XA CN111898735A (zh) | 2020-07-14 | 2020-07-14 | 蒸馏学习方法、装置、计算机设备和存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN111898735A true CN111898735A (zh) | 2020-11-06 |
Family
ID=73192641
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010674185.XA Pending CN111898735A (zh) | 2020-07-14 | 2020-07-14 | 蒸馏学习方法、装置、计算机设备和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111898735A (zh) |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112819050A (zh) * | 2021-01-22 | 2021-05-18 | 北京市商汤科技开发有限公司 | 知识蒸馏和图像处理方法、装置、电子设备和存储介质 |
CN112926740A (zh) * | 2021-03-30 | 2021-06-08 | 深圳市商汤科技有限公司 | 神经网络的训练方法、装置、计算机设备及存储介质 |
CN113255915A (zh) * | 2021-05-20 | 2021-08-13 | 深圳思谋信息科技有限公司 | 基于结构化实例图的知识蒸馏方法、装置、设备和介质 |
CN113344213A (zh) * | 2021-05-25 | 2021-09-03 | 北京百度网讯科技有限公司 | 知识蒸馏方法、装置、电子设备及计算机可读存储介质 |
CN113792871A (zh) * | 2021-08-04 | 2021-12-14 | 北京旷视科技有限公司 | 神经网络训练方法、目标识别方法、装置和电子设备 |
CN117576381A (zh) * | 2024-01-16 | 2024-02-20 | 深圳华付技术股份有限公司 | 目标检测训练方法及电子设备、计算机可读存储介质 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107247989A (zh) * | 2017-06-15 | 2017-10-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及装置 |
CN110633747A (zh) * | 2019-09-12 | 2019-12-31 | 网易(杭州)网络有限公司 | 目标检测器的压缩方法、装置、介质以及电子设备 |
CN110674880A (zh) * | 2019-09-27 | 2020-01-10 | 北京迈格威科技有限公司 | 用于知识蒸馏的网络训练方法、装置、介质与电子设备 |
CN110909815A (zh) * | 2019-11-29 | 2020-03-24 | 深圳市商汤科技有限公司 | 神经网络训练、图像处理方法、装置及电子设备 |
CN111260056A (zh) * | 2020-01-17 | 2020-06-09 | 北京爱笔科技有限公司 | 一种网络模型蒸馏方法及装置 |
CN111401406A (zh) * | 2020-02-21 | 2020-07-10 | 华为技术有限公司 | 一种神经网络训练方法、视频帧处理方法以及相关设备 |
-
2020
- 2020-07-14 CN CN202010674185.XA patent/CN111898735A/zh active Pending
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107247989A (zh) * | 2017-06-15 | 2017-10-13 | 北京图森未来科技有限公司 | 一种神经网络训练方法及装置 |
US20180365564A1 (en) * | 2017-06-15 | 2018-12-20 | TuSimple | Method and device for training neural network |
CN110633747A (zh) * | 2019-09-12 | 2019-12-31 | 网易(杭州)网络有限公司 | 目标检测器的压缩方法、装置、介质以及电子设备 |
CN110674880A (zh) * | 2019-09-27 | 2020-01-10 | 北京迈格威科技有限公司 | 用于知识蒸馏的网络训练方法、装置、介质与电子设备 |
CN110909815A (zh) * | 2019-11-29 | 2020-03-24 | 深圳市商汤科技有限公司 | 神经网络训练、图像处理方法、装置及电子设备 |
CN111260056A (zh) * | 2020-01-17 | 2020-06-09 | 北京爱笔科技有限公司 | 一种网络模型蒸馏方法及装置 |
CN111401406A (zh) * | 2020-02-21 | 2020-07-10 | 华为技术有限公司 | 一种神经网络训练方法、视频帧处理方法以及相关设备 |
Cited By (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112819050A (zh) * | 2021-01-22 | 2021-05-18 | 北京市商汤科技开发有限公司 | 知识蒸馏和图像处理方法、装置、电子设备和存储介质 |
WO2022156331A1 (zh) * | 2021-01-22 | 2022-07-28 | 北京市商汤科技开发有限公司 | 知识蒸馏和图像处理方法、装置、电子设备和存储介质 |
CN112819050B (zh) * | 2021-01-22 | 2023-10-27 | 北京市商汤科技开发有限公司 | 知识蒸馏和图像处理方法、装置、电子设备和存储介质 |
CN112926740A (zh) * | 2021-03-30 | 2021-06-08 | 深圳市商汤科技有限公司 | 神经网络的训练方法、装置、计算机设备及存储介质 |
CN113255915A (zh) * | 2021-05-20 | 2021-08-13 | 深圳思谋信息科技有限公司 | 基于结构化实例图的知识蒸馏方法、装置、设备和介质 |
CN113255915B (zh) * | 2021-05-20 | 2022-11-18 | 深圳思谋信息科技有限公司 | 基于结构化实例图的知识蒸馏方法、装置、设备和介质 |
CN113255915B8 (zh) * | 2021-05-20 | 2024-02-06 | 深圳思谋信息科技有限公司 | 基于结构化实例图的知识蒸馏方法、装置、设备和介质 |
CN113344213A (zh) * | 2021-05-25 | 2021-09-03 | 北京百度网讯科技有限公司 | 知识蒸馏方法、装置、电子设备及计算机可读存储介质 |
CN113792871A (zh) * | 2021-08-04 | 2021-12-14 | 北京旷视科技有限公司 | 神经网络训练方法、目标识别方法、装置和电子设备 |
CN113792871B (zh) * | 2021-08-04 | 2024-09-06 | 北京旷视科技有限公司 | 神经网络训练方法、目标识别方法、装置和电子设备 |
CN117576381A (zh) * | 2024-01-16 | 2024-02-20 | 深圳华付技术股份有限公司 | 目标检测训练方法及电子设备、计算机可读存储介质 |
CN117576381B (zh) * | 2024-01-16 | 2024-05-07 | 深圳华付技术股份有限公司 | 目标检测训练方法及电子设备、计算机可读存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111898735A (zh) | 蒸馏学习方法、装置、计算机设备和存储介质 | |
US11348249B2 (en) | Training method for image semantic segmentation model and server | |
CN111192292B (zh) | 基于注意力机制与孪生网络的目标跟踪方法及相关设备 | |
CN109902546B (zh) | 人脸识别方法、装置及计算机可读介质 | |
US11354906B2 (en) | Temporally distributed neural networks for video semantic segmentation | |
EP3937124A1 (en) | Image processing method, device and apparatus, and storage medium | |
WO2019100724A1 (zh) | 训练多标签分类模型的方法和装置 | |
WO2021022521A1 (zh) | 数据处理的方法、训练神经网络模型的方法及设备 | |
CN112613515B (zh) | 语义分割方法、装置、计算机设备和存储介质 | |
CN113255915B (zh) | 基于结构化实例图的知识蒸馏方法、装置、设备和介质 | |
CN111914908B (zh) | 一种图像识别模型训练方法、图像识别方法及相关设备 | |
CN112052837A (zh) | 基于人工智能的目标检测方法以及装置 | |
CN113505797B (zh) | 模型训练方法、装置、计算机设备和存储介质 | |
CN114549913B (zh) | 一种语义分割方法、装置、计算机设备和存储介质 | |
CN114282059A (zh) | 视频检索的方法、装置、设备及存储介质 | |
CN115018039A (zh) | 一种神经网络蒸馏方法、目标检测方法以及装置 | |
CN109101984B (zh) | 一种基于卷积神经网络的图像识别方法及装置 | |
CN114677611B (zh) | 数据识别方法、存储介质及设备 | |
CN108154522B (zh) | 目标追踪系统 | |
CN117132950A (zh) | 一种车辆追踪方法、系统、设备及存储介质 | |
CN111914809A (zh) | 目标对象定位方法、图像处理方法、装置和计算机设备 | |
CN111091198A (zh) | 一种数据处理方法及装置 | |
CN112614197B (zh) | 图像生成方法、装置、计算机设备和存储介质 | |
CN111898620A (zh) | 识别模型的训练方法、字符识别方法、装置、设备和介质 | |
CN114648762A (zh) | 语义分割方法、装置、电子设备和计算机可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |