CN111260056B - 一种网络模型蒸馏方法及装置 - Google Patents
一种网络模型蒸馏方法及装置 Download PDFInfo
- Publication number
- CN111260056B CN111260056B CN202010055355.6A CN202010055355A CN111260056B CN 111260056 B CN111260056 B CN 111260056B CN 202010055355 A CN202010055355 A CN 202010055355A CN 111260056 B CN111260056 B CN 111260056B
- Authority
- CN
- China
- Prior art keywords
- channel
- network model
- channel feature
- feature set
- distance
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000004821 distillation Methods 0.000 title claims abstract description 111
- 238000000034 method Methods 0.000 title claims abstract description 53
- 238000004422 calculation algorithm Methods 0.000 claims abstract description 43
- 239000011159 matrix material Substances 0.000 claims description 89
- 230000006870 function Effects 0.000 claims description 49
- 230000004913 activation Effects 0.000 claims description 8
- 238000011176 pooling Methods 0.000 claims description 7
- 238000004590 computer program Methods 0.000 claims description 5
- 238000010606 normalization Methods 0.000 claims description 5
- 230000009469 supplementation Effects 0.000 claims description 5
- 238000004364 calculation method Methods 0.000 description 8
- 238000006243 chemical reaction Methods 0.000 description 6
- 230000008569 process Effects 0.000 description 6
- 238000010586 diagram Methods 0.000 description 5
- 230000009471 action Effects 0.000 description 3
- 239000000047 product Substances 0.000 description 3
- 230000003213 activating effect Effects 0.000 description 2
- 238000010276 construction Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 230000009467 reduction Effects 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000004927 fusion Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 239000013589 supplement Substances 0.000 description 1
- 230000001502 supplementing effect Effects 0.000 description 1
- 230000002194 synthesizing effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例公开了一种网络模型蒸馏方法及装置,具体地,从第一网络模型(老师模型)的蒸馏位点获取第一通道特征集合,包括M个第一通道特征。同时从第二网络模型(学生模型)的蒸馏位点获取第二通道特征集合,包括N个第二通道特征。按照预设规则及匹配算法从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征集合,包括N个通道特征,使得第三通道特征集合与第二通道特征集合匹配。最后,根据第二通道特征集合与第三通道特征集合所匹配的一对通道特征,构建该对通道特征的距离损失函数,利用该距离损失函数对第二网络模型的参数更新,直至构建的距离损失函数满足预设距离阈值,使得第二网络模型学习到第一网络模型的特征表达。
Description
技术领域
本申请涉及自动机器学习技术领域,具体涉及一种网络模型蒸馏方法及装置。
背景技术
卷积神经网络模型蒸馏是一种在广泛使用的小模型训练方法,通常情况况下,小模型具有参数量少、运行速度快、计算资源消耗少的优点,但由于小模型的参数规模较小而存在性能瓶颈、识别准确率不高。模型蒸馏则是使用参数规模较大、性能优异的大模型去引导小模型的训练过程,使后者间接习得前者的特征表达方式,从而达到提升自身性能的目的。
其中,模型蒸馏中最主要的步骤在于训练过程中,在大模型和小模型的特定层级(蒸馏位点)的输出特征之间构建距离损失函数,通过该距离损失函数促使小模型的参数进行迭代更新,进而使得小模型输出的特征表达逼近大模型,以使得小模型的识别准确率提高。
然而,由于大模型和小模型的参数规模不同,导致从大模型选定的特征对应的通道数目与从小模型选定的特征对应的通道数目不对应,因此,在构造距离损失函数时,需要通过额外的转换算子来对大模型的通道数目进行缩减,但这种缩减会引入额外的参数,增加计算开销。
发明内容
有鉴于此,本申请实施例提供一种网络模型蒸馏方法及装置,以实现更为合理有效地使得两个模型之间的通道数据对应,并减小计算开销。
为解决上述问题,本申请实施例提供的技术方案如下:
在本申请实施例第一方面,提供了一种网络模型蒸馏方法,所述方法包括:
从第一网络模型的蒸馏位点获取第一通道特征集合,所述第一网络模型为利用训练样本预先训练生成的老师模型,所述第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数;
从第二网络模型的蒸馏位点获取第二通道特征集合,所述第二网络模型为学生模型,所述第二通道特征集合包括N个第二通道特征,其中,N为大于1的正整数,M大于N;
根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,所述第三通道特征集合包括N个通道特征;
针对所述第二通道特征集合和所述第三通道特征集合所匹配的一对通道特征,构建该对通道特征对应的距离损失函数,以根据所述距离损失函数对所述第二网络模型的参数进行更新,直至所述构建的距离损失函数满足预设距离阈值。
在一些可能的实现方式中,所述根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,包括:
当所述预设规则为稀疏匹配时,计算所述第二通道特征集合中每一个所述第二通道特征与所述第一通道特征集合中每个所述第一通道特征之间的距离,构成第一距离矩阵,所述第一距离矩阵大小为N*M;
对所述第一距离矩阵进行补充操作,添加P个距离数值,以使得补充后的第一距离矩阵大小为M*M,所述P等于M*M减去N*M;
针对所述补充后的距离矩阵中的任一行,选择最小距离数值;
将所述最小距离数值对应的第一通道特征确定为目标通道特征;
将各个所述目标通道特征构成第三通道特征集合。
在一些可能的实现方式中,所述根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,包括:
当所述预设规则为随机选择时,对所述第一通道特征集合和所述第二通道特征集合进行匹配,获得通道特征匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征;
利用随机函数从目标通道特征匹配对中选择一个所述第一通道特征作为目标通道特征,所述目标通道特征匹配对为任一通道特征匹配对;
将各个所述目标通道特征构成第三通道特征集合。
在一些可能的实现方式中,所述根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,包括:
当所述预设规则为最大特征值池化时,对所述第一通道特征集合和所述第二通道特征集合进行匹配,获得通道特征匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征;
将所述通道特征匹配对中各个所述第一通道特征对应的最大特征值进行融合,获得目标通道特征;
将各个所述目标通道特征构成第三通道特征集合。
在一些可能的实现方式中,当所述匹配算法为匈牙利算法时,所述对所述第一通道特征集合和所述第二通道特征集合进行匹配,获得通道特征匹配对,包括:
从所述第一通道特征集合中确定出所述第二参数S个第一通道特征,构成第四通道特征集合,其中,S=R*N,
计算所述第二通道特征集合中每个所述第二通道特征与所述第四通道特征集合中每个所述第一通道特征之间的距离,构成第二距离矩阵,所述第二距离矩阵大小为N*S;
对所述第二距离矩阵进行复制获得R份所述第二距离矩阵,并构成第三距离矩阵,所述第三距离矩阵大小为S*S;利用所述匈牙利算法对所述第三距离矩阵分析匹配,获得通道匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征。
在一些可能的实现方式中,所述将所述通道特征匹配对中各个所述第一通道特征进行融合,获得目标通道特征,包括:
对于所述通道特征匹配对中各个所述第一通道特征,依次比较同一位点各个所述第一通道特征对应的特征值以选择最大特征值;
将选择的各个最大特征值组成目标通道特征。
在一些可能的实现方式中,所述补充后的距离矩阵中的所补充的任一距离数值大于预设距离阈值。
在一些可能的实现方式中,所述方法还包括:
根据所述第一网络模型的类型确定蒸馏位点;和/或,
根据所述第二网络模型的类型确定蒸馏位点。
在一些可能的实现方式中,当所述第一网络模型和/或所述第二网络模型的类型为ResNet网络模型时,将所述ResNet网络模型的残差连接层之后激活层之前的位点确定为蒸馏位点;
当所述第一网络模型和/或所述第二网络模型的类型为MobileNet网络模型时,将所述MobileNet网络模型的归一化层之后激活层之前的位点确定为蒸馏位点;
当所述第一网络模型和/或所述第二网络模型的类型为ShuffleNet网络模型时,将所述ShuffleNet网络模型的shuffle层之后的位点确定为蒸馏位点。
在本申请实施例第二方面,提供了一种网络模型蒸馏装置,所述装置包括:
第一获取单元,用于从第一网络模型的蒸馏位点获取第一通道特征集合,所述第一网络模型为利用训练样本预先训练生成的老师模型,所述第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数;
第二获取单元,用于从第二网络模型的蒸馏位点获取第二通道特征集合,所述第二网络模型为学生模型,所述第二通道特征集合包括N个第二通道特征,其中,N为大于1的正整数,M大于N;
第一确定单元,用于根据预设规则从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,所述第三通道特征集合包括N个通道特征;
构建单元,用于针对所述第二通道特征集合和所述第三通道特征集合所匹配的一对通道特征,构建该对通道特征对应的距离损失函数,以根据所述距离损失函数对所述第二网络模型的参数进行更新,直至所述构建的距离损失函数满足第一预设距离阈值。
在本申请实施例第三方面,提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当所述指令在终端设备上运行时,使得所述终端设备执行第一方面所述的网络模型蒸馏方法。
在本申请实施例第四方面,提供了一种网络模型蒸馏设备,包括:存储器,处理器,及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时,实现第一方面所述的网络模型蒸馏方法。
由此可见,本申请实施例具有如下有益效果:
本申请实施例首先从第一网络模型(老师模型)的蒸馏位点获取第一通道特征集合,该第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数,即包括多个第一通道特征。同时从第二网络模型的蒸馏位点获取第二通道特征集合,该第二通道特征集合包括N个第二通道特征,且第二网络模型为学生模型,也就是,第二网络模型的参数规模小于第一网络模型的参数规模。然后,按照预设规则以及匹配算法从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征集合,该第三通道特征集合包括N个通道特征,从而使得第三通道特征集合与所述第二通道特征集合实现完全匹配。最后,根据第二通道特征集合与第三通道特征集合所匹配的一对通道特征,构建该对通道特征的距离损失函数,利用该距离损失函数对第二网络模型的参数进行更新,直至所构建的距离损失函数满足预设距离阈值,进而使得第二网络模型学习到第一网络模型的特征表达,提高识别准确率。
可见,通过本申请实施例提供的网络模型蒸馏方法,可以按照预设规则从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征,无需使用额外的可学习参数缩减第一通道特征集合的通道特征数量,进而减小蒸馏训练的计算开销。
附图说明
图1为本申请实施例提供的一种网络模型蒸馏方法流程图;
图2a为本申请实施例提供的一种ResNet网络模型结构图;
图2b为本申请实施例提供的一种MobileNet网络模型结构图;
图2c为本申请实施例提供的一种ShuffleNet网络模型结构图;
图3为本申请实施例提供的一种网络模型蒸馏装置结构图。
具体实施方式
为使本申请的上述目的、特征和优点能够更加明显易懂,下面结合附图和具体实施方式对本申请实施例作进一步详细的说明。
为便于理解本申请实施例所提供的技术方案,下面将先对本申请涉及的模型蒸馏技术进行说明。
网络模型的蒸馏过程包括:(1)构造并训练一个参数规模较大的网络,作为老师模型。(2)构造符合目标参数规模(通常较小)的网络,作为学生模型。(3)训练过程中,让老师模型和学生模型同时进行前向推理,在两个模型特定层级(蒸馏位点)的输出特征之间构造距离损失函数。(4)网络反向传播时,老师模型的参数保持固定,学生模型的参数依据梯度下降算法迭代更新。通过上述4个步骤,学生模型可以通过与老师模型之间构造的距离损失函数逐渐逼近老师模型的特征表达,从而达到蒸馏的目的。
然而,现有的卷积神经网络模型蒸馏方法存在一个典型的问题是,步骤(3)中所选定的两组输出特征(分别来源于老师模型和学生模型)往往具有不同的通道数目,因此在构造距离损失函数时需要使用额外的转换算子(增加卷积层)来进行通道缩减。该种方式将带来以下两个问题:一是转换算子对学生模型本身的特征会造成干扰,影响模型收敛;二是转换算子会引入额外的可学习参数,可能使蒸馏训练时的计算开销超出现有硬件的承受范围。
基于此,本申请实施例提供了一种网络模型蒸馏方法,该方法基于预设规则从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征集合,无需通过增加额外的可学习参数对第一通道特征集合的通道特征数量进行缩减,消除可学习参数对第二网络模型所输出的通道特征的干扰,降低蒸馏训练对应的计算开销。
基于上述描述,下面将结合附图对本申请实施例提供的网络模型蒸馏方法进行说明。
参见图1,该图为本申请实施例提供的一种网络模型蒸馏方法流程图,如图1所示,该方法可以包括:
S101:从第一网络模型的蒸馏位点获取第一通道特征集合。
本实施例中,对于已利用训练样本预先训练生成的老师模型,即第一网络模型,从该第一网络模型的蒸馏位点获取第一通道特征集合。即,将待处理数据输入第一网络模型,然后在其对应的蒸馏位点获取该第一网络模型输出的关于待处理数据的通道特征集合,即第一通道特征集合。
其中,第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数,该M与第一网络模型蒸馏位点对应的输出层的通道数目相等,每个通道输出第一通道特征,从而形成第一通道特征集合。其中,第一通道特征可以为特征矩阵,该特征矩阵的大小由蒸馏位点对应的输出层的大小决定。例如,第一网络模型的蒸馏位点对应的输出层为3*3*9的卷积层,其中,3*3为卷积核大小,9为通道数目,则每个第一通道特征为3*3的特征矩阵,该特征矩阵共包括9个参数,共存在9个第一通道特征,即第一通道特征集合包括9个第一通道特征。
S102:从第二网络模型的蒸馏位点获取第二通道特征集合。
对于学生模型即第二网络模型,将待处理数据输入第二网络模型,并在该第二网络模型的蒸馏位点获取第二通道特征集合。其中,第二通道特征集合包括N个第二通道特征,其中,N为大于1的正整数,N与第二网络模型蒸馏位点对应的输出层的通道数目相等,每个通道输出第二通道特征,从而形成第二通道特征集合。
其中,第二通道特征可以为特征矩阵,该特征矩阵的大小由蒸馏位点对应的输出层的大小决定。例如,蒸馏位点对应的输出层为3*3*3的卷积层,其中,3*3为卷积核大小,3为通道数目,则每个第二通道特征为3*3的特征矩阵,该特征矩阵共包括9个参数,共存在3个第二通道特征,即第二通道特征集合包括3个第二通道特征。需要说明的是,第一通道特征的特征尺寸与第二通道特征的特征尺寸相同。
可以理解的是,由于第一网络模型的参数规模大于第二网络模型的参数规模,第二网络模型学习第一网络模型的特征表达,因此M大于N。另外,需要说明的是,在具体实现时,第二网络模型可以为初始网络模型,即未经过训练的网络模型,也可以为预先经过训练的网络模型。当第二网络模型为初始网络模型时,所输入的待处理数据为带有标签的待处理数据,从而训练获得可以进行物体识别或分类的网络模型。
在实时应用中,为保证所获取的第一通道特征集合和第二通道特征集合可以体现各自对应的第一网络模型、第二网络模型的特点,还可以根据网络模型的类型确定蒸馏位点,以使得在该蒸馏位点提取的通道特征更准确。其中,关于确定蒸馏位点的实现将在后续实施例进行说明。
S103:根据预设规则以及匹配算法从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征集合。
当分别获得第一通道特征集合和第二通道特征集合后,根据预设规则以及匹配算法从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征集合,其中,第三通道特征集合包括N个通道特征。也就是说,按照预设规则以及匹配算法从第一通道特征集合中提取N个通道特征构成第三通道特征集合,使得第三通道特征集合中的每个通道特征与第二通道特征集合中的每个通道特征一一匹配。其中,匹配算法可以为匈牙利算法或其他算法,本实施例在此不做限定。
其中,预设规则可以为稀疏匹配规则,先计算第二通道特征集合中每个第二通道特征与第一通道特征集合中每个第一通道特征之间的距离,并构成距离矩阵(N*M)。然后对该距离矩阵进行补充操作,使得补充的距离矩阵为M*M。再利用通道特征匹配算法进行匹配,从而获得与第二通道特征匹配的第一通道特征,进而生成第三通道特征集合。其中,关于稀疏匹配的具体实现将在后续实施例进行说明。
预设规则还可以为随机选择规则,即利用通道特征匹配算法将第一通道特征集合与第二通道特征集合进行匹配,此时第二通道特征可以匹配至少一个第一通道特征。当存在第二通道特征匹配多个第一通道特征时,利用随机选择函数为该第二通道特征选择一个第二通道特征进行匹配,从而使得每个第二通道特征匹配一个第一通道特征。其中,关于利用随机选择函数确定第三通道特征集合的实现将在后续实施例进行说明。
预设规则还可以为最大特征值池化规则,具体为,首先对第一通道特征集合和第二通道特征集合进行通道特征匹配,当第二通道特征匹配多个第一通道特征时,将所匹配的多个第一通道特征的特征值进行融合,从而获得一个目标通道特征,实现一个第二通道特征匹配一个目标通道特征。其中,关于利用最大特征值池化规则确定第三通道特征集合的实现将在后续实施例进行说明。
S104:针对第二通道特征集合和第三通道特征集合所匹配的一对通道特征,构建该对通道特征对应的距离损失函数,以根据距离损失函数对第二网络模型的参数进行更新,直至构建的距离损失函数满足预设距离阈值。
当完成第二通道特征集合的匹配时,针对第二通道特征集合和第三通道特征集合所形成的一对通道特征,构建该对通道特征对应的距离损失函数。然后,利用该距离损失函数进行反向传播,以对第二网络模型的参数进行更新,重新进行蒸馏训练,直至所构建的距离损失函数满足预设距离阈值,表明第二网络模型在蒸馏位点所输出的各个通道特征已经逼近第一网络模型在蒸馏位点所输出的各个通道特征,达到蒸馏目的。其中,预设距离阈值可以根据实际应用情况进行设定,本实施例在此不做限定。
需要说明的是,在实际应用中,针对每对匹配的通道特征均构建距离损失函数,利用每对通道特征对应的距离损失函数对第二网络模型的参数进行更新。例如,第二通道特征集合包括3个通道特征,则构建3个距离损失函数,利用每个距离损失函数更新第二网络模型的参数。
基于上述描述可知,首先从第一网络模型(老师模型)的蒸馏位点获取第一通道特征集合,该第一通道特征集合包括M个第一通道特征,即包括多个第一通道特征。同时从第二网络模型的蒸馏位点获取第二通道特征集合,该第二通道特征集合包括N个第二通道特征,且第二网络模型为学生模型,也就是,第二网络模型的参数规模小于第一网络模型的参数规模。然后,按照预设规则及匹配算法从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征集合,该第三通道特征集合包括N个通道特征,从而使得第三通道特征集合与所述第二通道特征集合实现完全匹配。最后,根据第二通道特征集合与第三通道特征集合所匹配的一对通道特征,构建该对通道特征的距离损失函数,利用该距离损失函数对第二网络模型的参数进行更新,直至所构建的距离损失函数满足预设距离阈值,进而使得第二网络模型学习到第一网络模型的特征表达,提高识别准确率。
可见,通过本申请实施例提供的网络模型蒸馏方法,可以按照预设规则从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征,无需使用额外的可学习参数缩减第一通道特征集合的通道特征数量,进而减小蒸馏训练的计算开销。
在实际应用中,本实施例还可以针对网络模型的类型确定对应的蒸馏位点,以便在各自对应的蒸馏位点出获取通道特征集合。具体为,根据第一网络模型的类型确定蒸馏位点,和/或根据第二网络模型的类型确定蒸馏位点。可以理解的是,神经网络模型可以包括多种类型的网络模型,例如,ResNet网络模型、MobileNet网络模型、ShuffleNet网络模型、VggNet网络模型、GoogleNet网络模型等。本实施例针对ResNet网络模型、MobileNet网络模型、ShuffleNet网络模型提供了确定蒸馏位点的方法,具体为:
当第一网络模型和/或第二网络模型的类型为ResNet网络模型时,将ResNet网络模型的残差连接层之后激活层之前的位点确定为蒸馏位点。例如图2a所示,ResNet网络模型包括卷积层Conv、归一化处理层BN、残差连接层以及激活层ReLU,将残差连接层之后激活层之前的位点(圆所在的位置)确定为蒸馏位点,从该位置获取第一通道特征集合或第二通道特征集合。
当第一网络模型和/或第二网络模型的类型为MobileNet网络模型时,将MobileNet网络模型的归一化层之后激活层之前的位点确定为蒸馏位点;例如图2b所示,MobileNet网络模型包括卷积层Conv、归一化层BN+激活层ReLU、卷积层Conv_+归一化层BN、激活层ReLU、卷积层Conv+归一化层BN,将卷积层Conv+归一化层BN之后激活层ReLU之前的位点(圆所在的位置)确定为蒸馏点,从该位置提取第一通道特征集合或第二通道特征集合。
当第一网络模型和/或第二网络模型的类型为ShuffleNet网络模型时,将ShuffleNet网络模型的shuffle层之后的位点确定为蒸馏位点。如图2c所示,ShuffleNet网络模型包括卷积层Conv、归一化层BN+激活层ReLU、3个(卷积层Conv+归一化层BN)、激活层ReLU、全连接层、Shuffle层,将Shuffle层之后的位点(圆所在的位置)确定为蒸馏位点,从该位置获取第一通道特征集合或第二通道特征集合。
需要说明的是,第一网络模型和第二网络模型的类型可以相同,也可以不同,在实际应用时,根据各自的类型确定相应的蒸馏位点,本实施例在此不做任何限定。
基于上述实施例所提及的预设规则,下面将分别对上述三种预设规则进行说明:
一、稀疏匹配
1)计算第二通道特征集合中每个第二通道特征与第一通道特征集合中每个第一通道特征之间的距离,构成第一距离矩阵,该第一距离矩阵大小为N*M。
本实施例,针对第二通道特征集合中的任一个第二通道特征,计算该第二通道特征与第一通道特征集合中的每个第一通道特征之间的距离,从而构成第一距离矩阵。例如,第二通道特征集合包括3个第二通道特征,第一通道特征集合包括5个第一通道特征,则构成的第一距离矩阵为3*5,如(1)所示,该第一距离矩阵的每一行为某一个第二通道特征与各个第一通道特征之间的距离值,如,x11-x15为第一个第二通道特征与各个第一通道特征之间的距离值;x21-x25为第二个第二通道特征与各个第一通道特征之间的距离值;x31-x35为第三个第二通道特征与各个第一通道特征之间的距离值。
需要说明的是,由于在利用匈牙利算法进行通道特征匹配时,必须要求所匹配的两组通道特征数量一致,因此,需要对距离矩阵进行补充。
2)对第一距离矩阵进行补充操作,添加P个距离数值,以使得补充后的第一距离矩阵大小为M*M,其中,P等于M*M减去N*M。
即,对第一距离矩阵中添加额外的距离数值,使得补充后的第一距离矩阵大小为M*M,所添加的P距离数值中每M个距离数值作为新的一行添加至距离矩阵中。例如,第一距离矩阵为3*5,则补充后的第一距离矩阵为5*5,如(2)所示,其中,y41-y45以及y51-y55为补充的距离数值。
需要说明的是,在实际补充时,为避免所补充的距离数值被匹配,则所添加的距离数值大于预设距离数值。具体地,该预设距离阈值可以根据实际情况进行确定,即保证所添加的P个距离数值均为足够大的数值,无法被确定为最小距离数值。
3)针对补充后的第一距离矩阵中的任一行,选择最小距离数值。
4)将最小距离数值对应的第一通道特征确定为目标通道特征。
本实施例中,当完成距离矩阵的补充操作后,针对补充后的第一距离矩阵中的任一行,选择最小距离数值,将该最小距离数值对应的第一通道特征确定为目标通道特征。当确定出每一行最小距离数值对应的目标通道特征后,将各个目标通道特征构成第三通道特征集合。
例如,第一行对应的最小距离数值为x13,则x13对应第一通道特征集合中的第三个第一通道特征,则第三个第一通道特征为目标通道特征;第二行对应的最小距离数值为x22,则x22对应第一通道特征集合中的第二个第一通道特征,则第二个第一通道特征为目标通道特征;第三行对应的最小距离数值为x35,则x35对应第一通道特征集合中的第五个第一通道特征,则第五个第一通道特征为目标通道特征,则各个目标通道特征构成第三通道特征集合。
可见,本实施例通过稀疏匹配的方式从第一通道特征集合中提取与第二通道集合匹配的第三通道特征集合,无需利用额外的可学习参数缩减第一通道特征集合中的通道特征数量,不仅减小蒸馏计算量,还可以消除由于引入额外可学习参数对第二通道特征的干扰。
二、随机选择
1)对第一通道特征集合和第二通道特征集合进行匹配,获得通道特征匹配对。
本实施例中,首先利用通道特征匹配算法对第一通道特征集合和第二通道特征集合进行匹配,获得通道特征匹配对,该通道特征匹配对中每个第二通道特征至少匹配一个第一通道特征。
可以理解的是,由于第一通道特征集合中第一通道特征的数量大于第二通道特征集合中第二通道特征的数量,因此,在进行通道特征匹配时,将出现某一第二通道特征匹配多个第一通道特征。例如,第一通道特征集合为[a1 a2 a3 a4 a5 a6]、第二通道特征集合为[b1 b2 b3],则在进行通道特征匹配时,可以存在每个第二通道特征匹配两个第一通道特征,如b1匹配a4、a5;b2匹配a1、a2;b3匹配a3、a6。再例如,第一通道特征集合为[a1 a2 a3a4 a5],第二通道特征集合为[b1 b2 b3],则在进行通道特征匹配时,出现如b1匹配a4、a5;b2匹配a1、a2;b3匹配a3。
2)利用随机函数从目标通道特征匹配对中选择一个第一通道特征作为目标通道特征。
当完成通道特征匹配后,将每个通道特征匹配对作为目标通道特征匹配对,利用随机函数从目标通道特征匹配对中选择一个第一通道特征作为目标通道特征。可以理解的是,当某一通道特征匹配对中仅存在一个第一通道特征时,将该第一通道特征作为目标通道特征;当某一通道特征匹配对中存在多个第一通道特征时,利用随机函数从多个第一通道特征中选择一个第一通道特征作为目标通道特征。
例如,存在3个通道特征匹配对[b1 a4 a5]、[b2 a1 a2]、[b3 a3 a6],利用随机选择函数从第一个通道特征匹配对中选择第一通道特征a5、从第二个通道特征匹配对中选择第一通道特征a2、从第三个通道特征匹配对中选择a3。
3)将各个目标通道特征构成第三通道特征集合。
当从每个通道特征匹配对中选择出目标通道特征后,将所选择的所有目标通道特征构成第三通道特征集合。例如,从第一个通道特征匹配对中选择第一通道特征a5、从第二个通道特征匹配对中选择第一通道特征a2、从第三个通道特征匹配对中选择a3,则第三通道特征集合为[a5 a2 a3]。
可见,本实施例通过随机选择的方式从第一通道特征集合中提取与第二通道集合匹配的第三通道特征集合,无需利用额外的转换算子缩减第一通道特征集合中的通道特征数量,不仅减小蒸馏计算量,还可以消除由于引入额外转换算子对第二通道特征的干扰。
三、最大特征值池化
1)对第一通道特征集合和第二通道特征集合进行匹配,获得通道特征匹配对。
本实施例利用通道特征匹配算法对第一通道特征集合和第二通道特征集合进行匹配,获得通道特征匹配对,该通道特征匹配对中每个第二通道特征至少匹配一个第一通道特征。
由于第一通道特征集合中第一通道特征的数量大于第二通道特征集合中第二通道特征的数量,因此,在进行通道特征匹配时,将出现某一第二通道特征匹配多个第一通道特征。例如,第一通道特征集合为[a1 a2 a3 a4 a5 a6]、第二通道特征集合为[b1 b2 b3],则在进行通道特征匹配时,可以存在每个第二通道特征匹配两个第一通道特征,如b1匹配a4、a5;b2匹配a1、a2;b3匹配a3、a6。
2)将通道特征匹配对中各个第一通道特征对应的最大特征值进行融合,获得目标通道特征。
当某一通道特征匹配对中包括多个第一通道特征时,从每个第一通道特征中提取最大特征值,并利用提取的每个最大特征值组成新的通道特征作为目标通道特征。具体地,对于通道特征匹配对中各个第一通道特征,依次比较同一位点各个第一通道特征对应的特征值以选择最大特征值;再将选择的各个最大特征值组成目标通道特征。
例如,通道特征匹配对为[b2 a1 a2],其中,a1和a2均为3*3矩阵,共包括9个参数,则将a1中的每个参数a1ij与a2中同一位置对应的参数a2ij进行比较,选出最大值,从而选择9个最大值,构成目标通道特征c3。
3)将各个目标通道特征构成第三通道特征集合。
当确定出每个通道特征对对应的目标通道特征后,将所有的目标通道特征组成第三通道特征集合。
可见,本实施例通过最大特征值池化的方式从第一通道特征集合中提取与第二通道集合匹配的第三通道特征集合,无需利用额外的可学习参数缩减第一通道特征集合中的通道特征数量,不仅减小蒸馏计算量,还可以消除由于引入额外可学习参数对第二通道特征的干扰。
需要说明的是,当匹配算法为匈牙利算法时,由于匈牙利算法要求待匹配的两个通道特征集合中通道特征数量一致,而通常情况下老师模型的通道特征数量大于学生模型的通道特征数量,因此,需要先对老师模型的通道特征数量进行调整,以满足利用匈牙利算法的要求。则当匹配算法为匈牙利算法时,对第一通道特征集合和第二通道特征集合进行匹配,获得通道特征匹配对,具体为:
(1)根据第一通道特征个数M以及第二通道特征个数N确定出第一参数R,其中,
即,首先根据第一通道特征集合中第一通道特征的个数M以及第二通道特征集合中的第二通道特征的个数N确定出第一参数R。其中,R等于M除以N向下取整。例如,M=7,N=3,则R=2;M=14,N=3,则R=4。
(2)根据第一参数R和N的乘积,计算获得第二参数S。
(3)从第一通道特征集合中确定出第二参数S个第一通道特征,构成第四通道特征集合。
当确定出第一参数R后,根据第二通道特征的个数N确定出第四通道特征集合所包括的第一通道特征的个数S,其中,S等于R和N的乘积。具体地,可以从M个第一通道特征中随机选择S个第一通道特征构成第四通道特征,也可以按照预设规则从M个第一通道特征中选择S个第一通道特征构成第四通道特征,例如,选择奇数位的第一通道特征或偶数位的第一通道特征。例如,N=3,R=2,则S=6,即从包括7个第一通道特征的第一通道特征集合中[a1a2 a3 a4 a5 a6 a7],选择奇数位的第一通道特征a1、a3、a5、a7,则从剩余的[a2 a4 a6]中年继续选择奇数位a2、a6,则共选出6个第一通道特征,构成第四通道特征集合。
可以理解的是,当第一通道特征个数M为第二通道特征个数N的正数倍时,S等于M,则直接将第一通道特征集合确定为第四通道特征集合。例如,M=6,N=3时,R=2,则S=6。
(4)计算第二通道特征集合中每个第二通道特征与第四通道特征集合中每个第一通道特征之间的距离,构成第二距离矩阵,该第二距离矩阵大小为N*S。
本实施例,针对第二通道特征集合中的任一个第二通道特征,计算该第二通道特征与第四通道特征集合中的每个第一通道特征之间的距离,从而构成第二距离矩阵。例如,第二通道特征集合包括N=3个第二通道特征分别为[b1 b2 b3],第四通道特征集合为[a1a2 a3 a5 a6 a7],则形成一个3*6的第二距离矩阵,如(3)所示。
(5)对第二距离矩阵进行复制获得R份第二距离矩阵,构成第三距离矩阵,该第三距离矩阵大小为S*S。
当获得第二距离矩阵后,对该第二距离矩阵进行复制操作,共获得R份第二距离矩阵,将所有的第二距离矩阵合成获得第三距离矩阵,该第三距离矩阵大小为S*S,即第三距离矩阵中每个距离对应的第一通道特征与第二通道特征的数量是一致的。例如,R=2,S=6,则对3*6的第二距离矩阵进行复制,获得2个第二距离矩阵,2个第二距离矩阵构成第三距离矩阵6*6,如(4)所示。
(6)利用匈牙利算法对第三距离矩阵分析匹配,获得通道匹配对。
当获得第三距离矩阵后,利用匈牙利算法对第三距离矩阵进行分析匹配,获得通道匹配对。由于第三距离矩阵是通过增加第二距离矩阵的行数获得的,每一行表示一个第二通道特征与每个第一通道特征之间的距离。因此,在利用匈牙利算法对第三距离矩阵中第一通道特征和第二通道特征进行匹配时,每个第二通道特征将匹配至少一个第一通道特征。
例如,对于上述矩阵(4),第一行确定出第二通道特征b1匹配第一通道特征a3;第二行确定出第二通道特征b2匹配第一通道特征a1;第三行确定出第二通道特征b3匹配第一通道特征a5;第四行确定出第二通道特征b1匹配第一通道特征a7,第五行确定出第二通道特征b2匹配第一通道特征a2;第六行确定出第二通道特征b3匹配第一通道特征a6,则b1匹配a3和a7,b2匹配a1和a2,b3匹配a5和a6。
基于上述方法实施例,本申请实施例提供了一种网络模型蒸馏装置结构图,如图3所示,该装置可以包括:
第一获取单元301,用于从第一网络模型的蒸馏位点获取第一通道特征集合,所述第一网络模型为利用训练样本预先训练生成的老师模型,所述第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数;
第二获取单元302,用于从第二网络模型的蒸馏位点获取第二通道特征集合,所述第二网络模型为学生模型,所述第二通道特征集合包括N个第二通道特征,其中,N为大于1的正整数,M大于N;
确定单元303,用于根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,所述第三通道特征集合包括N个通道特征;
构建单元304,用于针对所述第二通道特征集合和所述第三通道特征集合所匹配的一对通道特征,构建该对通道特征对应的距离损失函数,以根据所述距离损失函数对所述第二网络模型的参数进行更新,直至所述构建的距离损失函数满足预设距离阈值。
在一种可能的实现方式中,所述第一确定单元,包括:
计算子单元,用于当所述预设规则为稀疏匹配时,计算所述第二通道特征集合中每一个所述第二通道特征与所述第一通道特征集合中每个所述第一通道特征之间的距离,构成第一距离矩阵,所述第一距离矩阵大小为N*M;
补充子单元,用于对所述第一距离矩阵进行补充操作,添加P个距离数值,以使得补充后的第一距离矩阵大小为M*M,所述P等于M*M减去N*M;
第一选择子单元,用于针对所述补充后的第一距离矩阵中的任一行,选择最小距离数值;
确定子单元,用于将所述最小距离数值对应的第一通道特征确定为目标通道特征;
第一构成子单元,用于将各个所述目标通道特征构成第三通道特征集合。
在一种可能的实现方式中,所述第一确定单元,包括:
第一匹配子单元,用于当所述预设规则为随机选择时,对所述第一通道特征集合和所述第二通道特征集合进行匹配,获得通道特征匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征;
第二选择子单元,用于利用随机函数从目标通道特征匹配对中选择一个所述第一通道特征作为目标通道特征,所述目标通道特征匹配对为任一通道特征匹配对;
第二构成子单元,用于将各个所述目标通道特征构成第三通道特征集合。
在一种可能的实现方式中,所述第一确定单元,包括:
第二匹配子单元,用于当所述预设规则为最大特征值池化时,对所述第一通道特征集合和所述第二通道特征集合进行匹配,获得通道特征匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征;
融合子单元,用于将所述通道特征匹配对中各个所述第一通道特征对应的最大特征值进行融合,获得目标通道特征;
第三构成子单元,用于将各个所述目标通道特征构成第三通道特征集合。
在一种可能的实现方式中,所述第一匹配子单元或所述第二匹配子单元,具体用于,根据所述第一通道特征个数M以及所述第二通道特征个数N确定出第一参数R,所述根据所述第一参数R和所述N的乘积,计算获得第二参数S;从所述第一通道特征集合中确定出所述第二参数S个第一通道特征,构成第四通道特征集合;计算所述第二通道特征集合中每个所述第二通道特征与所述第四通道特征集合中每个所述第一通道特征之间的距离,构成第二距离矩阵,所述第二距离矩阵大小为N*S;对所述第二距离矩阵进行复制获得R份所述第二距离矩阵,并构成第三距离矩阵,所述第三距离矩阵大小为S*S;利用所述匈牙利算法对所述第三距离矩阵分析匹配,获得通道匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征。
在一种可能的实现方式中,所述补充后的距离矩阵中的所补充的任一距离数值大于预设距离阈值。
在一种可能的实现方式中,所述装置还包括:
第二确定单元,用于根据所述第一网络模型的类型确定蒸馏位点;和/或,
第三确定单元,用于根据所述第二网络模型的类型确定蒸馏位点。
在一种可能的实现方式中,当所述第一网络模型和/或所述第二网络模型的类型为ResNet网络模型时,将所述ResNet网络模型的残差连接层之后激活层之前的位点确定为蒸馏位点;
当所述第一网络模型和/或所述第二网络模型的类型为MobileNet网络模型时,将所述MobileNet网络模型的归一化层之后激活层之前的位点确定为蒸馏位点;
当所述第一网络模型和/或所述第二网络模型的类型为ShuffleNet网络模型时,将所述ShuffleNet网络模型的shuffle层之后的位点确定为蒸馏位点。
需要说明的是,本实施例中各个单元的实现可以参见上述方法实施例,本实施例在此不再赘述。
另外,本申请实施例还提供了一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有指令,当所述指令在终端设备上运行时,使得所述终端设备执行所述的网络模型蒸馏方法。
本申请实施例提供了一种网络模型蒸馏设备,包括:存储器,处理器,及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时,实现所述的网络模型蒸馏方法。
基于上述描述可知,首先从第一网络模型(老师模型)的蒸馏位点获取第一通道特征集合,该第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数,即包括多个第一通道特征。同时从第二网络模型的蒸馏位点获取第二通道特征集合,该第二通道特征集合包括N个第二通道特征,且第二网络模型为学生模型,也就是,第二网络模型的参数规模小于第一网络模型的参数规模。然后,按照预设规则及匹配算法从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征集合,该第三通道特征集合包括N个通道特征,从而使得第三通道特征集合与所述第二通道特征集合实现完全匹配。最后,根据第二通道特征集合与第三通道特征集合所匹配的一对通道特征,构建该对通道特征的距离损失函数,利用该距离损失函数对第二网络模型的参数进行更新,直至所构建的距离损失函数满足预设距离阈值,进而使得第二网络模型学习到第一网络模型的特征表达,提高识别准确率。
可见,通过本申请实施例提供的网络模型蒸馏方法,可以按照预设规则从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征,无需使用额外的可学习参数缩减第一通道特征集合的通道特征数量,进而减小蒸馏训练的计算开销。
需要说明的是,本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的系统或装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
应当理解,在本申请中,“至少一个(项)”是指一个或者多个,“多个”是指两个或两个以上。“和/或”,用于描述关联对象的关联关系,表示可以存在三种关系,例如,“A和/或B”可以表示:只存在A,只存在B以及同时存在A和B三种情况,其中A,B可以是单数或者复数。字符“/”一般表示前后关联对象是一种“或”的关系。“以下至少一项(个)”或其类似表达,是指这些项中的任意组合,包括单项(个)或复数项(个)的任意组合。例如,a,b或c中的至少一项(个),可以表示:a,b,c,“a和b”,“a和c”,“b和c”,或“a和b和c”,其中a,b,c可以是单个,也可以是多个。
还需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
结合本文中所公开的实施例描述的方法或算法的步骤可以直接用硬件、处理器执行的软件模块,或者二者的结合来实施。软件模块可以置于随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、硬盘、可移动磁盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质中。
对所公开的实施例的上述说明,使本领域专业技术人员能够实现或使用本申请。对这些实施例的多种修改对本领域的专业技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本申请的精神或范围的情况下,在其它实施例中实现。因此,本申请将不会被限制于本文所示的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。
Claims (12)
1.一种网络模型蒸馏方法,其特征在于,所述方法应用于网络模型蒸馏设备,所述方法包括:
从第一网络模型的蒸馏位点获取第一通道特征集合,所述第一网络模型为利用训练样本预先训练生成的老师模型,所述第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数;所述第一通道特征集合是利用所述第一网络模型根据待处理数据得到的;
从第二网络模型的蒸馏位点获取第二通道特征集合,所述第二网络模型为学生模型,所述第二通道特征集合包括N个第二通道特征,其中,N为大于1的正整数,M大于N;所述第二通道特征是利用所述第二网络模型根据所述待处理数据得到的,所述第二网络模型为初始网络模型;
根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,所述第三通道特征集合包括N个通道特征;
针对所述第二通道特征集合和所述第三通道特征集合所匹配的一对通道特征,构建该对通道特征对应的距离损失函数,以根据所述距离损失函数对所述第二网络模型的参数进行更新,直至所构建的所述距离损失函数满足预设距离阈值;
利用更新后的第二网络模型进行物体识别或分类。
2.根据权利要求1所述的方法,其特征在于,所述根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,包括:
当所述预设规则为稀疏匹配时,计算所述第二通道特征集合中每一个所述第二通道特征与所述第一通道特征集合中每个所述第一通道特征之间的距离,构成第一距离矩阵,所述第一距离矩阵大小为N*M;
对所述第一距离矩阵进行补充操作,添加P个距离数值,以使得补充后的第一距离矩阵大小为M*M,所述P等于M*M减去N*M;
针对所述补充后的第一距离矩阵中的任一行,选择最小距离数值;
将所述最小距离数值对应的第一通道特征确定为目标通道特征;
将各个所述目标通道特征构成第三通道特征集合。
3.根据权利要求1所述的方法,其特征在于,所述根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,包括:
当所述预设规则为随机选择时,对所述第一通道特征集合和所述第二通道特征集合进行匹配,获得通道特征匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征;
利用随机函数从目标通道特征匹配对中选择一个所述第一通道特征作为目标通道特征,所述目标通道特征匹配对为任一通道特征匹配对;
将各个所述目标通道特征构成第三通道特征集合。
4.根据权利要求1所述的方法,其特征在于,所述根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,包括:
当所述预设规则为最大特征值池化时,对所述第一通道特征集合和所述第二通道特征集合进行匹配,获得通道特征匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征;
将所述通道特征匹配对中各个所述第一通道特征对应的最大特征值进行融合,获得目标通道特征;
将各个所述目标通道特征构成第三通道特征集合。
5.根据权利要求3或4所述的方法,其特征在于,当所述匹配算法为匈牙利算法时,所述对所述第一通道特征集合和所述第二通道特征集合进行匹配,获得通道特征匹配对,包括:
从所述第一通道特征集合中确定出第二参数S个第一通道特征,构成第四通道特征集合,其中,S=R*N,
计算所述第二通道特征集合中每个所述第二通道特征与所述第四通道特征集合中每个所述第一通道特征之间的距离,构成第二距离矩阵,所述第二距离矩阵大小为N*S;
对所述第二距离矩阵进行复制获得R份所述第二距离矩阵,并构成第三距离矩阵,所述第三距离矩阵大小为S*S;利用所述匈牙利算法对所述第三距离矩阵分析匹配,获得通道匹配对,所述通道特征匹配对中所述第二通道特征至少匹配一个所述第一通道特征。
6.根据权利要求4所述的方法,其特征在于,所述将所述通道特征匹配对中各个所述第一通道特征进行融合,获得目标通道特征,包括:
对于所述通道特征匹配对中各个所述第一通道特征,依次比较同一位点各个所述第一通道特征对应的特征值以选择最大特征值;
将选择的各个最大特征值组成目标通道特征。
7.根据权利要求2所述的方法,其特征在于,所述补充后的距离矩阵中的所补充的任一距离数值大于预设距离阈值。
8.根据权利要求1所述的方法,其特征在于,所述方法还包括:
根据所述第一网络模型的类型确定蒸馏位点;和/或,
根据所述第二网络模型的类型确定蒸馏位点。
9.根据权利要求8所述的方法,其特征在于,当所述第一网络模型和/或所述第二网络模型的类型为ResNet网络模型时,将所述ResNet网络模型的残差连接层之后激活层之前的位点确定为蒸馏位点;
当所述第一网络模型和/或所述第二网络模型的类型为MobileNet网络模型时,将所述MobileNet网络模型的归一化层之后激活层之前的位点确定为蒸馏位点;
当所述第一网络模型和/或所述第二网络模型的类型为ShuffleNet网络模型时,将所述ShuffleNet网络模型的shuffle层之后的位点确定为蒸馏位点。
10.一种网络模型蒸馏装置,其特征在于,所述装置部署在网络模型蒸馏设备,所述装置包括:
第一获取单元,用于从第一网络模型的蒸馏位点获取第一通道特征集合,所述第一网络模型为利用训练样本预先训练生成的老师模型,所述第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数;所述第一通道特征集合是利用所述第一网络模型根据待处理数据得到的;
第二获取单元,用于从第二网络模型的蒸馏位点获取第二通道特征集合,所述第二网络模型为学生模型,所述第二通道特征集合包括N个第二通道特征,其中,N为大于1的正整数,M大于N;所述第二通道特征是利用所述第二网络模型根据所述待处理数据得到的,所述第二网络模型为初始网络模型;
第一确定单元,用于根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,所述第三通道特征集合包括N个通道特征;
构建单元,用于针对所述第二通道特征集合和所述第三通道特征集合所匹配的一对通道特征,构建该对通道特征对应的距离损失函数,以根据所述距离损失函数对所述第二网络模型的参数进行更新,直至所构建的所述距离损失函数满足预设距离阈值;
识别分类单元,用于利用更新后的第二网络模型进行物体识别或分类。
11.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有指令,当所述指令在终端设备上运行时,使得所述终端设备执行如权利要求1-9任一项所述的网络模型蒸馏方法。
12.一种网络模型蒸馏设备,其特征在于,包括:存储器,处理器,及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时,实现如权利要求1-9任一项所述的网络模型蒸馏方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010055355.6A CN111260056B (zh) | 2020-01-17 | 2020-01-17 | 一种网络模型蒸馏方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010055355.6A CN111260056B (zh) | 2020-01-17 | 2020-01-17 | 一种网络模型蒸馏方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111260056A CN111260056A (zh) | 2020-06-09 |
CN111260056B true CN111260056B (zh) | 2024-03-12 |
Family
ID=70954195
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010055355.6A Active CN111260056B (zh) | 2020-01-17 | 2020-01-17 | 一种网络模型蒸馏方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111260056B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111898735A (zh) * | 2020-07-14 | 2020-11-06 | 上海眼控科技股份有限公司 | 蒸馏学习方法、装置、计算机设备和存储介质 |
CN114638238A (zh) * | 2020-12-16 | 2022-06-17 | 北京金山数字娱乐科技有限公司 | 一种神经网络模型的训练方法及装置 |
CN112819050B (zh) * | 2021-01-22 | 2023-10-27 | 北京市商汤科技开发有限公司 | 知识蒸馏和图像处理方法、装置、电子设备和存储介质 |
Citations (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108921294A (zh) * | 2018-07-11 | 2018-11-30 | 浙江大学 | 一种用于神经网络加速的渐进式块知识蒸馏方法 |
CN109409500A (zh) * | 2018-09-21 | 2019-03-01 | 清华大学 | 基于知识蒸馏与非参数卷积的模型加速方法及装置 |
CN109543817A (zh) * | 2018-10-19 | 2019-03-29 | 北京陌上花科技有限公司 | 用于卷积神经网络的模型蒸馏方法及装置 |
CN109740567A (zh) * | 2019-01-18 | 2019-05-10 | 北京旷视科技有限公司 | 关键点定位模型训练方法、定位方法、装置及设备 |
CN110009052A (zh) * | 2019-04-11 | 2019-07-12 | 腾讯科技(深圳)有限公司 | 一种图像识别的方法、图像识别模型训练的方法及装置 |
CN110135562A (zh) * | 2019-04-30 | 2019-08-16 | 中国科学院自动化研究所 | 基于特征空间变化的蒸馏学习方法、系统、装置 |
US10496884B1 (en) * | 2017-09-19 | 2019-12-03 | Deepradiology Inc. | Transformation of textbook information |
CN110674880A (zh) * | 2019-09-27 | 2020-01-10 | 北京迈格威科技有限公司 | 用于知识蒸馏的网络训练方法、装置、介质与电子设备 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107247989B (zh) * | 2017-06-15 | 2020-11-24 | 北京图森智途科技有限公司 | 一种实时的计算机视觉处理方法及装置 |
-
2020
- 2020-01-17 CN CN202010055355.6A patent/CN111260056B/zh active Active
Patent Citations (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US10496884B1 (en) * | 2017-09-19 | 2019-12-03 | Deepradiology Inc. | Transformation of textbook information |
CN108921294A (zh) * | 2018-07-11 | 2018-11-30 | 浙江大学 | 一种用于神经网络加速的渐进式块知识蒸馏方法 |
CN109409500A (zh) * | 2018-09-21 | 2019-03-01 | 清华大学 | 基于知识蒸馏与非参数卷积的模型加速方法及装置 |
CN109543817A (zh) * | 2018-10-19 | 2019-03-29 | 北京陌上花科技有限公司 | 用于卷积神经网络的模型蒸馏方法及装置 |
CN109740567A (zh) * | 2019-01-18 | 2019-05-10 | 北京旷视科技有限公司 | 关键点定位模型训练方法、定位方法、装置及设备 |
CN110009052A (zh) * | 2019-04-11 | 2019-07-12 | 腾讯科技(深圳)有限公司 | 一种图像识别的方法、图像识别模型训练的方法及装置 |
CN110135562A (zh) * | 2019-04-30 | 2019-08-16 | 中国科学院自动化研究所 | 基于特征空间变化的蒸馏学习方法、系统、装置 |
CN110674880A (zh) * | 2019-09-27 | 2020-01-10 | 北京迈格威科技有限公司 | 用于知识蒸馏的网络训练方法、装置、介质与电子设备 |
Non-Patent Citations (1)
Title |
---|
卷积神经网络算法模型的压缩与加速算法比较;李思奇;;信息与电脑(理论版);20190615(第11期);27-29 * |
Also Published As
Publication number | Publication date |
---|---|
CN111260056A (zh) | 2020-06-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111260056B (zh) | 一种网络模型蒸馏方法及装置 | |
CN109146076A (zh) | 模型生成方法及装置、数据处理方法及装置 | |
CN112487168B (zh) | 知识图谱的语义问答方法、装置、计算机设备及存储介质 | |
CN109919183B (zh) | 一种基于小样本的图像识别方法、装置、设备及存储介质 | |
WO2021043294A1 (en) | Neural network pruning | |
JP2020038704A (ja) | データ識別器訓練方法、データ識別器訓練装置、プログラム及び訓練方法 | |
CN112288086A (zh) | 一种神经网络的训练方法、装置以及计算机设备 | |
CN110659678B (zh) | 一种用户行为分类方法、系统及存储介质 | |
KR102134472B1 (ko) | 유전 알고리즘을 활용한 콘볼루션 뉴럴 네트워크의 최적 구조 탐색 방법 | |
CN110210558B (zh) | 评估神经网络性能的方法及装置 | |
US20220036189A1 (en) | Methods, systems, and media for random semi-structured row-wise pruning in neural networks | |
CN112508190A (zh) | 结构化稀疏参数的处理方法、装置、设备及存储介质 | |
CN115017178A (zh) | 数据到文本生成模型的训练方法和装置 | |
CN115860100A (zh) | 一种神经网络模型训练方法、装置及计算设备 | |
CN115496144A (zh) | 配电网运行场景确定方法、装置、计算机设备和存储介质 | |
US20230206054A1 (en) | Expedited Assessment and Ranking of Model Quality in Machine Learning | |
CN110222816B (zh) | 深度学习模型的建立方法、图像处理方法及装置 | |
CN112561050A (zh) | 一种神经网络模型训练方法及装置 | |
CN109697511B (zh) | 数据推理方法、装置及计算机设备 | |
US20230004791A1 (en) | Compressed matrix representations of neural network architectures based on synaptic connectivity | |
CN115909441A (zh) | 人脸识别模型建立方法、人脸识别方法和电子设备 | |
KR20190129422A (ko) | 뉴럴 네트워크를 이용한 변분 추론 방법 및 장치 | |
CN115148292A (zh) | 基于人工智能的dna模体预测方法、装置、设备及介质 | |
CN112801203B (zh) | 基于多任务学习的数据分流训练方法及系统 | |
CN111402121A (zh) | 图像风格的转换方法、装置、计算机设备和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant | ||
TG01 | Patent term adjustment |