CN117195951A - 一种基于架构搜索和自知识蒸馏的学习基因继承方法 - Google Patents
一种基于架构搜索和自知识蒸馏的学习基因继承方法 Download PDFInfo
- Publication number
- CN117195951A CN117195951A CN202311232774.2A CN202311232774A CN117195951A CN 117195951 A CN117195951 A CN 117195951A CN 202311232774 A CN202311232774 A CN 202311232774A CN 117195951 A CN117195951 A CN 117195951A
- Authority
- CN
- China
- Prior art keywords
- network
- learning
- offspring
- operation block
- output
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 79
- 108090000623 proteins and genes Proteins 0.000 title claims abstract description 71
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 19
- 238000010586 diagram Methods 0.000 claims abstract description 23
- 230000002708 enhancing effect Effects 0.000 claims abstract description 3
- 238000012549 training Methods 0.000 claims description 53
- 230000006870 function Effects 0.000 claims description 30
- 238000012360 testing method Methods 0.000 claims description 20
- 238000000605 extraction Methods 0.000 claims description 8
- 238000013461 design Methods 0.000 claims description 4
- 230000008569 process Effects 0.000 claims description 4
- 230000008859 change Effects 0.000 claims description 3
- 238000011176 pooling Methods 0.000 claims description 3
- 238000002372 labelling Methods 0.000 abstract description 3
- 238000011423 initialization method Methods 0.000 abstract description 2
- 238000013135 deep learning Methods 0.000 description 18
- 238000010606 normalization Methods 0.000 description 4
- 230000006835 compression Effects 0.000 description 3
- 238000007906 compression Methods 0.000 description 3
- 230000001537 neural effect Effects 0.000 description 3
- 241000196324 Embryophyta Species 0.000 description 2
- 235000009499 Vanilla fragrans Nutrition 0.000 description 2
- 244000263375 Vanilla tahitensis Species 0.000 description 2
- 235000012036 Vanilla tahitensis Nutrition 0.000 description 2
- 230000004913 activation Effects 0.000 description 2
- 230000002068 genetic effect Effects 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 230000006978 adaptation Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000013138 pruning Methods 0.000 description 1
- 230000005477 standard model Effects 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 239000002699 waste material Substances 0.000 description 1
Classifications
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Image Analysis (AREA)
Abstract
本发明提供一种基于架构搜索和自知识蒸馏的学习基因继承方法,为后代网络搭建超网络;随机选取增强后的数据输入超网络和祖先网络;计算超网络与祖先网络特征图的差异来更新超网络的参数;从超网络中搜索相似度最高的后代网络架构。随机选取少量下游任务样本增强后输入后代网络,输出样本类别预测概率;输出后代网络中继承学习基因的层和未继承学习基因的层的特征图的相似度来蒸馏学习基因;利用分类和相似度损失更新后代网络。本发明方法即使在噪声数据上也具有优秀的分类性能;和随机初始化方法相比,在达到相似的分类精度时,需要更少的分类数据;在少数精细标注的数据条件下,本发明方法能快速训练自动生成的后代网络使其具有较高的分类性能。
Description
技术领域
本发明涉及一种基于架构搜索和自知识蒸馏的学习基因继承方法,属于计算机视觉技术领域。
背景技术
深度学习网络在计算机视觉领域中取得了巨大进展,诞生了一系列标准模型,例如CNN、ResNet和Transformer。深度学习网络应用在计算机视觉任务中主要经历了以下的几个阶段:首先收集大量数据,并对这些数据进行精细化的标准。随后由专业的算法研究人员根据任务的特点手工设计深度学习网络。在被标注的数据集上训练深度学习网络,并在验证集上验证当前网络在该数据集上的精度。调整网络的超参数,直至网络在验证集上获得最高的精度,保留此时网络训练的参数。用训练好的参数初始化网络并应用在下游任务中。
然而,上述阶段中存在着较大的不足。收集大量数据非常困难,仅有较大规模的公司才有能力获得大量数据,且对大规模数据进行精细化的标注费时费力。此外,科研人员设计深度学习网络依赖经验和能力,这导致设计的模型未必是最优的,从而需要不断更新网络结构。这个过程同样费时费力。具有大量参数的深度学习网络通常有着更好性能,然而训练大规模网络需要足够的硬件资源,在缺乏硬件资源的场景下,网络往往无法达到理想的效果。因此,研究在少数精细标注的数据条件下,快速训练自动生成的具有少量参数的深度学习网络使其具有较高的性能的问题是迫切的。
深度学习网络的技术现状如下:
①目前自动生成深度学习网络主要依靠单样本神经架构搜索技术,训练一个包含所有候选操作的超网络,并从训练好的超网络中自动搜索出在当前任务上表现最好的子网络。这一过程极大地减少了手工设计深度学习网络的开销。然而由于包含所有候选网络,超网络的规模较大,训练超网络仍然需要较大的开销。
②目前研究使用少量数据使网络达到较好性能的方法主要包括迁移学习和元学习。这些方法能仅用数个样本就使网络达到非常好的性能。然而,这些方法需要复用整个网络,这对计算和存储资源需求较大。
③目前研究减少网络硬件需求的方法主要是模型压缩技术,包括模型减枝和知识蒸馏。然而,这些方法为每个不同的任务重新执行一次压缩。当任务量庞大时,重新开始压缩的时间开销巨大。
发明内容
为解决在少数精细标注的数据条件下,如何快速训练自动生成的小规模深度学习网络使其在分类任务上具有较高的性能的问题,本发明提出了一种基于架构搜索和自知识蒸馏的学习基因继承方法。学习基因是从预先训练的具有大量参数的深度学习网络(祖先网络)中抽取的关键知识,并以继承的方式初始化具有少量参数的深度学习网络(后代网络)。本发明方法能够使得自动生成的小规模深度学习网络在只有少数精细标注的数据上取得良好的分类性能。
为了达到上述目的,本发明提供如下技术方案:
一种基于架构搜索和自知识蒸馏的学习基因继承方法,包括如下步骤:
为后代网络中没有继承学习基因的层设计可供选择的卷积操作,按顺序搭建超网络;随机从训练祖先网络的源数据中选取增强后的样本,作为超网络和祖先网络的输入,超网络输出卷积操作产生的特征图,与祖先网络输出的特征图计算相似度来更新超网络的参数;选择与祖先网络输出的特征图具有最高相似度的卷积操作构建后代网络;随机从下游任务的数据集中选取样本增强后作为后代网络的输入,输出对数据样本类别预测的概率;计算后代网络中继承学习基因的层和没有继承学习基因的层输出的特征图的相似度,用于蒸馏学习基因的知识;利用分类损失函数和计算相似度的损失函数更新后代网络。
进一步的,包括如下具体步骤:
步骤S1:随机从数据集中选取增强后的数据样本,训练具有大量参数的祖先网络按照祖先网络/>中每一操作块梯度的变化情况,将最后3个操作块提取为可以被继承的学习基因层,这之前的层被称为非学习基因层;
步骤S2:根据祖先网络中非学习基因层输出的特征图尺寸的变化情况,将其划分为N个连续的操作块,具体为:/>其中/>为第i个操作块,/>符号表示相邻操作的连接;
步骤S3:根据祖先网络划分的操作块的数量,搭建具有相同数量的超网络/>具体为/>其中/>为第i个操作块,超网络/>中每个操作块和祖先网络/>相同位置的操作块输出的特征图尺寸一致;
步骤S4:随机从数据集中选取数据样本x0进行增强,固定祖先网络参数,输入到训练好的祖先网络/>中的第一个操作块中并输出第一个操作块产生的特征图f1,具体为随后以特征图f1作为第二个操作块的输入并输出产生的特征图f2,具体为以此类推,最终输出祖先网络/>每个操作块生成的特征图f1,…,fN;
步骤S5:选取祖先网络操作块的输入作为超网络/>中对应的操作块的输入并返回超网络/>每个操作块生成的特征图/>
步骤S6:根据步骤S4和步骤S5输出的特征图,以祖先网络相同操作块输出的特征图作为标签,分别计算超网络/>中相同位置的操作块输出的4个特征图和标签的相似度差异/>来计算梯度下降,以此分别更新候选卷积操作的参数,具体为:
其中,i对应祖先网络和超网络/>第i个操作块,/>为超网络/>第二个操作块输出的特征图;
步骤S7:固定由步骤S6训练得到的超网络中每个候选卷积操作块的参数,随机从测试数据集中选取数据样本s0,输入到训练好的祖先网络/>中,采取和步骤S4相同的方式输出祖先网络/>每个操作块生成的特征图f1,…,fN;
步骤S8:采取和步骤S5相同的方式在采样的测试数据s0上,输出超网络每个操作块生成的特征图/>
步骤S9:在测试数据s0上,利用步骤S6中计算特征图相似度差异的损失函数衡量超网络每个操作块下候选卷积操作块和祖先网络/>对应位置操作块生成的特征图f1,…,fN的差异,选择差异最小的候选卷积操作块为后代网络/>的层;
步骤S10:将从祖先网络中提取到的学习基因层继承到后代网络/>中,构建被学习基因层初始化的后代网络层/>
步骤S11:将所有被选择的候选卷积操作块按顺序组合,形成后代网络的非学习基因层,将这些非学习基因层与从祖先网络/>中提取到的学习基因层按前后顺序组合,构成后代网络/>的特征提取层,最后在后代网络/>的特征提取层后组合全连接层FC,形成完整的后代网络/>具体为:/>
步骤S12:随机从下游任务数据集中选取少量训练数据(x,yc),其中yc是输入数据x的所属类别标签,作为后代网络的输入,训练后代网络/>的下游任务数据集和训练祖先网络/>超网络/>的数据集不一致且没有交集,输出对训练数据样本类别预测的概率/>和后代网络/>特征提取层中产生的特征图o1,…,oN,oN+1:
步骤S13:对于训练数据x,将后代网络预测的类别概率/>与标签yc做交叉熵损失,计算分类损失函数,所属的分类损失函数具体为:
其中log表示对数函数;
步骤S14:为后代网络的每个非学习基因层/>设计额外模块B1,…,BN用来输出与olg大小一致的特征图,将由非学习基因层输出的特征图o1,…,oN分别输入到对应的额外模块B1,…,BN中,输出特征图o’1,…,o’N,具体为:o’i=Bi(oi),i∈[1,N];
步骤S15:将由额外模块输出的特征图o’1,…,o’N分别与由学习基因层输出的特征图oN+1计算相似度差异,计算的公式为:
步骤S16:将步骤S13获得的分类损失函数和步骤S14获得的计算相似度差异损失函数整合,计算总体损失函数,表达式如下:
其中,α是超参数,用来调整两种损失的权重大小;
步骤S17:利用总体损失函数计算梯度下降,以此更新后代网络的参数。
进一步的,所述步骤S1中,增强的方式为:在训练环节,对于CIFAR100数据集,采取随机裁剪样本成长宽均为32、随机水平翻转的增强方式,对于ImageNet-Tiny数据集,采取调整样本长宽为224、随机裁剪和随机水平翻转的增强方式;在测试环节,对于CIFAR100数据集样本不进行数据增强的操作;对于ImageNet-Tiny数据集,只将其中样本的长宽调整为224。
进一步的,所述步骤S2中,祖先网络的非学习基因层被划分为4个操作块。
进一步的,所述步骤S3中,超网络的每一个操作块内都包含4个不同的候选卷积操作块。
进一步的,所述步骤S4中增强方式与步骤S1相同。
进一步的,所述步骤S5生成的特征图的过程包括:第一个操作块输出的特征图为其中k表示超网络/>每个操作块下候选卷积操作块中卷积核的大小,由于超网络/>中的每一个操作块包含4个候选卷积操作块,超网络/>中的每一个操作块将输出4个特征图,第二个操作块输出的特征图为/>以此类推,最终输出超网络/>每个操作块生成的特征图/>
进一步的,所述步骤S9具体包括如下过程:
超网络的某个操作块下的4个候选卷积操作块输出的特征图/>依次和祖先网络/>相同位置的操作块输出的特征图fi计算相似度,公式为:
从中选择值最小的,即选择输出了与fi最相似的特征图的候选卷积操作块作为后代网络/>第i层。
进一步的,所述步骤S14中设计的额外模块包括一个卷积操作层和一个池化操作层。
本发明还提供了一种图像分类的方法,包括如下步骤:
步骤S1:收集并整理用于分类的图像数据集;将该数据集划分为训练集和测试集两个部分,其中对于训练集中的图片,人工标注该图片所属的类别;
步骤S2:获取基于架构搜索和自知识蒸馏的学习基因继承方法产生的模型及其对应的参数,利用参数初始化该模型;
步骤S3:用该模型在步骤S1中制作的训练集上训练N轮,并保留训练好的模型参数;
步骤S4:使用训练好的模型参数初始化步骤S2中产生的模型,并在步骤S1中制作的测试集图像上进行预测,以完成识别图片所属的类别的任务。
与现有技术相比,本发明具有如下优点和有益效果:
本发明方法即使在噪声数据上也具有优秀的分类性能;和随机初始化的方法相比,在达到相似的分类精度时,需要更少的分类数据;在少数精细标注的数据条件下,本发明方法能快速训练自动生成的后代网络使其具有较高的分类性能。
附图说明
图1为本发明方法的框架图。
图2为本发明方法在与随机初始化训练方法在CIFAR100和ImageNet-Tiny数据集上基于ResNet和VGG的框架在有不同层级的噪声干扰下性能对比图。
图3为本发明方法在与随机初始化训练方法在CIFAR100和ImageNet-Tiny数据集上基于ResNet和VGG的框架达到相似分类性能所需要的样本数量对比图。
图4为本发明方法在CIFAR100和ImageNet-Tiny数据集上基于ResNet和VGG的框架与现有的应用在少量样本场景下的主要方法,包括MatchingNet、ProtoNet、Baseline++、BOIL、vanilla Learngene方法分类性能对比表。
具体实施方式
以下将结合具体实施例对本发明提供的技术方案进行详细说明,应理解下述具体实施方式仅用于说明本发明而不用于限制本发明的范围。
本发明提供了一种基于架构搜索和自知识蒸馏的学习基因继承方法,其框架图具体如图1所示。对于预先训练好的祖先网络,我们先将其高层级的语义层抽取为学习基因并继承到后代网络中。采用神经架构搜索方法自动地搜索后代网络的非学习基因层。将搜索到的后代网络的非学习基因层与继承的学习基因层组合构成完整的后代网络。在下游任务中训练后代网络,训练的过程中将后代网络的学习基因层的知识蒸馏到非学习基因层中。具体包括如下步骤:
步骤S1:随机从数据集中选取增强后的数据样本训练具有大量参数的祖先网络按照祖先网络/>中每一操作块梯度的变化情况,将最后3个操作块提取为可以被继承的学习基因层,这之前的层被称为非学习基因层。其中,增强的方式为:在训练环节,对于CIFAR100数据集,采取随机裁剪样本成长宽均为32、随机水平翻转的增强方式,对于ImageNet-Tiny数据集,采取调整样本长宽为224、随机裁剪和随机水平翻转的增强方式。在测试环节,对于CIFAR100数据集样本不进行数据增强的操作;对于ImageNet-Tiny数据集,只将其中样本的长宽调整为224。
步骤S2:根据祖先网络中非学习基因层输出的特征图尺寸的变化情况,将其划分为N个连续的操作块,具体为:/>其中/>为第i个操作块,/>符号表示相邻操作的连接。祖先网络/>的非学习基因层由于输出的特征图的尺寸变化了四次,其中每次特征图的尺寸缩小两倍,所以被划分为4个操作块(4个操作块为本实例中的示例,操作块数量可根据需要调整)。
步骤S3:根据祖先网络划分的操作块的数量,搭建具有相同数量的超网络/>具体为/>其中/>为第i个操作块,超网络/>的每一个操作块内都包含4个不同的候选卷积操作块。对于VGG架构,每一个候选卷积操作块包含卷积层、批归一化层和ReLU激活函数层,4个候选卷积操作块不同之处在于卷积层的卷积核大小的差异,分别是1、3、5和7。对于ResNet架构,每一个候选卷积操作块包含由卷积层、批归一化层、ReLU激活函数层、卷积层、批归一化层构成的残差分支和由卷积核大小为1×1的卷积层、批归一化层构成的跳跃连接分支。在残差分支中,第二个卷积操作的输入和输出通道数保持一致,均为该操作块的输出通道数。只有当残差分支第一个卷积层的步长不为1或者操作块的输入和输出通道数不一致时,才进行跳跃连接的操作。超网络/>中每个操作块和祖先网络/>相同位置的操作块输出的特征图尺寸一致。
步骤S4:随机从数据集中选取数据样本x0进行增强,增强方式与步骤S1一致。固定祖先网络参数,输入到训练好的祖先网络/>中的第一个操作块中并输出第一个操作块产生的特征图f1,具体为/>随后以特征图f1作为第二个操作块的输入并输出产生的特征图f2,具体为/>以此类推,最终输出祖先网络/>每个操作块生成的特征图f1,…,fN;
步骤S5:选取祖先网络操作块的输入作为超网络/>中对应的操作块的输入并返回超网络/>每个操作块生成的特征图,第一个操作块输出的特征图为其中k表示超网络/>每个操作块下候选卷积操作块中卷积核的大小,由于超网络/>中的每一个操作块包含4个候选卷积操作块,因此超网络/>中的每一个操作块将输出4个特征图,第二个操作块输出的特征图为/>以此类推,最终输出超网络/>每个操作块生成的特征图/>
步骤S6:根据步骤S4和步骤S5输出的特征图,以祖先网络相同操作块输出的特征图作为标签,分别计算超网络/>中相同位置的操作块输出的4个特征图和标签的相似度差异/>来计算梯度下降,以此分别更新候选卷积操作的参数,具体为:
其中,i对应祖先网络和超网络/>第i个操作块。在训练超网络/>时,从底层到高层的候选操作块的学习率分别为0.005、0.005、0.005、0.002,采用的优化器为Adam优化器,其中eps参数和weightdecay参数分别为1×10-8和1×10-4。
步骤S7:固定由步骤S6训练得到的超网络中每个候选卷积操作块的参数,随机从测试数据集中选取数据样本s0,输入到训练好的祖先网络/>中,采取和步骤S4相同的方式输出祖先网络/>每个操作块生成的特征图f1,…,fN。
步骤S8:采取和步骤S5相同的方式在采样的测试数据s0上,输出超网络每个操作块生成的特征图/>
步骤S9:在测试数据s0上,利用步骤S6中计算特征图相似度差异的损失函数衡量超网络每个操作块下候选卷积操作块和祖先网络/>对应位置操作块生成的特征图f1,…,fN的差异,具体来说,超网络/>的某个操作块下的4个候选卷积操作块输出的特征图依次和祖先网络/>相同位置的操作块输出的特征图fi计算相似度,公式为:
从中选择值最小的,即选择输出了与fi最相似的特征图的候选卷积操作块作为后代网络/>第i层。
步骤S10:将从祖先网络中提取到的学习基因层继承到后代网络/>中,构建被学习基因层初始化的后代网络层/>
步骤S11:将所有被选择的候选卷积操作块按顺序组合,形成后代网络的非学习基因层,将这些非学习基因层与从祖先网络/>中提取到的学习基因层按前后顺序组合,构成后代网络/>的特征提取层,最后在后代网络/>的特征提取层后组合全连接层FC,形成完整的后代网络/>具体为:/>
步骤S12:随机从下游任务数据集中选取少量训练数据(x,yc),其中yc是输入数据x的所属类别标签,作为后代网络的输入,训练后代网络/>的下游任务数据集和训练祖先网络/>超网络/>的数据集不一致且没有交集,输出对训练数据样本类别预测的概率/>和后代网络/>特征提取层中产生的特征图o1,…,oN,oN+1,具体为:
步骤S13:对于训练数据x,将后代网络预测的类别概率/>与标签yc做交叉熵损失,计算分类损失函数,所属的分类损失函数具体为:
其中log表示对数函数。
步骤S14:由于后代网络特征提取层中产生的特征图o1,…,oN,olg尺度不一致,所以为后代网络/>的每个非学习基因层/>设计额外模块B1,…,BN用来输出与olg大小一致的特征图,额外模块包括一个卷积操作层和一个池化操作层,将由非学习基因层输出的特征图o1,…,oN分别输入到对应的额外模块B1,…,BN中,输出特征图o’1,…,o’N,具体为:o’i=Bi(oi),i∈[1,N];
步骤S15:将由额外模块输出的特征图o’1,…,o’N分别与由学习基因层输出的特征图oN+1计算相似度差异,计算的公式为:
步骤S16:将步骤S13获得的分类损失函数和步骤S14获得的计算相似度差异损失函数整合,计算总体损失函数,表达式如下:
其中,α是超参数,用来调整两种损失的权重大小;
步骤S17:利用总体损失函数计算梯度下降,以此更新后代网络的参数。训练后代网络/>时,学习率被设置为0.0001,采用SGD优化器,其中momentum参数和weight decay参数分别为0.9和5×10-4。
测试例:
将本发明方法应用于图像分类任务时,采用以下步骤:
步骤S1:收集并整理用于分类的图像数据集。将该数据集划分为训练集和测试集两个部分,训练集图片的数量应大大小于测试集的部分。其中对于训练集中的图片,人工标注该图片所属的类别。
步骤S2:获取本发明方法产生的模型及其对应的参数,利用参数初始化该模型。
步骤S3:用该模型在步骤S1中制作的训练集上训练50轮,并保留训练好的模型参数。
步骤S4:使用训练好的模型参数初始化步骤S2中产生的模型,并在步骤S1中制作的测试集图像上进行预测,以完成识别图片所属的类别的任务。
在CIFAR100和ImageNet-Tiny数据集上,基于ResNet和VGG的框架的本发明方法在与随机初始化训练方法在有不同层级的噪声干扰下性能对比。噪声一共分为4个层级,分别是10%、20%、30%和40%的数据被噪声污染,产生错误的类别,结果展示在图2中。在所有的噪声层级中,本发明方法的分类精度都最高。具体来说,在图2(c)中,随着噪声程度的加深,本发明方法的训练精度下降了18.72%,然而随机初始化训练方法的训练精度下降了24.4%。这表明本发明方法在噪声数据上的分类性能优于随机初始化训练方式。
在CIFAR100和ImageNet-Tiny数据集上,基于ResNet和VGG的框架的本发明方法与随机初始化训练方法达到相似分类性能所需要的样本数量对比,结果展示在图3中。可以看出,在达到相似分类精度的时候,本发明方法需要更少的样本数量。具体来说,在图3(a)中,本发明方法只需要20个样本就可以超过随机初始化训练方法用220个样本得到的精度,样本数量减少了11倍。
将本发明方法应用在计算资源受限的场景下,在CIFAR100和ImageNet-Tiny数据集上,基于ResNet和VGG两种深度学习网络架构,与现有的应用在少量样本场景下的主要方法分类性能进行对比,主要方法包括MatchingNet、ProtoNet、Baseline++、BOIL和vanillaLearngene,结果展示在图4中。尽管这些方法在它们各自的领域中都取得了令人满意的成绩,然而由于计算资源受限,这些方法无法复用具有大量参数的网络,这导致它们无法取得更好的分类性能。然而,本发明将包含了大规模网络中核心知识的学习基因继承给小规模的网络,这使得本发明不受计算资源约束的同时也充分利用了大规模网络的信息,因此即使在少数样本场景下本发明取得了最好的结果。
综上,本案提出了一个有效的方法,以解决在少数精细标注的数据条件下,快速训练自动生成的小规模深度学习网络使其具有较高的分类性能的问题。为了使得小规模的后代网络也具有较强的学习能力,将来自大规模的祖先网络的关键知识浓缩为学习基因继承给后代网络。针对手工设计深度学习网络费时费力的问题,使用神经架构搜索方法自动搜索后代网络的非学习基因层。为了增加后代网络中非学习基因层和学习基因层的兼容性,在后代网络学习的过程中,将学习基因层的知识蒸馏到非学习基因层中。在两个数据集上,用两种架构进行了实验,去了良好的性能,证明所提方法的有效性。
本发明方案所公开的技术手段不仅限于上述实施方式所公开的技术手段,还包括由以上技术特征任意组合所组成的技术方案。应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也视为本发明的保护范围。
Claims (10)
1.一种基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,包括如下步骤:
为后代网络中没有继承学习基因的层设计可供选择的卷积操作,按顺序搭建超网络;随机从训练祖先网络的源数据中选取增强后的样本,作为超网络和祖先网络的输入,超网络输出卷积操作产生的特征图,与祖先网络输出的特征图计算相似度来更新超网络的参数;选择与祖先网络输出的特征图具有最高相似度的卷积操作构建后代网络;随机从下游任务的数据集中选取样本增强后作为后代网络的输入,输出对数据样本类别预测的概率;计算后代网络中继承学习基因的层和没有继承学习基因的层输出的特征图的相似度,用于蒸馏学习基因的知识;利用分类损失函数和计算相似度的损失函数更新后代网络。
2.根据权利要求1所述的基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,包括如下具体步骤:
步骤S1:随机从数据集中选取增强后的数据样本,训练具有大量参数的祖先网络按照祖先网络/>中每一操作块梯度的变化情况,将最后3个操作块提取为可以被继承的学习基因层,这之前的层被称为非学习基因层;
步骤S2:根据祖先网络中非学习基因层输出的特征图尺寸的变化情况,将其划分为N个连续的操作块,具体为:/>其中/>为第i个操作块,/>符号表示相邻操作的连接;
步骤S3:根据祖先网络划分的操作块的数量,搭建具有相同数量的超网络/>具体为其中/>为第i个操作块,超网络/>中每个操作块和祖先网络/>相同位置的操作块输出的特征图尺寸一致;
步骤S4:随机从数据集中选取数据样本x0进行增强,固定祖先网络参数,输入到训练好的祖先网络/>中的第一个操作块中并输出第一个操作块产生的特征图f1,具体为随后以特征图f1作为第二个操作块的输入并输出产生的特征图f2,具体为以此类推,最终输出祖先网络/>每个操作块生成的特征图f1,…,fN;
步骤S5:选取祖先网络操作块的输入作为超网络/>中对应的操作块的输入并返回超网络/>每个操作块生成的特征图/>
步骤S6:根据步骤S4和步骤S5输出的特征图,以祖先网络相同操作块输出的特征图作为标签,分别计算超网络/>中相同位置的操作块输出的4个特征图和标签的相似度差异来计算梯度下降,以此分别更新候选卷积操作的参数,具体为:
其中,i对应祖先网络和超网络/>第i个操作块,/>为超网络/>第二个操作块输出的特征图;
步骤S7:固定由步骤S6训练得到的超网络中每个候选卷积操作块的参数,随机从测试数据集中选取数据样本s0,输入到训练好的祖先网络/>中,采取和步骤S4相同的方式输出祖先网络/>每个操作块生成的特征图f1,…,fN;
步骤S8:采取和步骤S5相同的方式在采样的测试数据s0上,输出超网络每个操作块生成的特征图/>
步骤S9:在测试数据s0上,利用步骤S6中计算特征图相似度差异的损失函数衡量超网络每个操作块下候选卷积操作块和祖先网络/>对应位置操作块生成的特征图f1,…,fN的差异,选择差异最小的候选卷积操作块为后代网络/>的层;
步骤S10:将从祖先网络中提取到的学习基因层继承到后代网络/>中,构建被学习基因层初始化的后代网络层/>
步骤S11:将所有被选择的候选卷积操作块按顺序组合,形成后代网络的非学习基因层,将这些非学习基因层与从祖先网络/>中提取到的学习基因层按前后顺序组合,构成后代网络/>的特征提取层,最后在后代网络/>的特征提取层后组合全连接层FC,形成完整的后代网络/>具体为:/>
步骤S12:随机从下游任务数据集中选取少量训练数据(x,yc),其中yc是输入数据x的所属类别标签,作为后代网络的输入,训练后代网络/>的下游任务数据集和训练祖先网络超网络/>的数据集不一致且没有交集,输出对训练数据样本类别预测的概率/>和后代网络/>特征提取层中产生的特征图o1,…,oN,oN+1:
步骤S13:对于训练数据x,将后代网络预测的类别概率/>与标签yc做交叉熵损失,计算分类损失函数,所属的分类损失函数具体为:
其中log表示对数函数;
步骤S14:为后代网络的每个非学习基因层/>设计额外模块B1,…,BN用来输出与olg大小一致的特征图,将由非学习基因层输出的特征图o1,…,oN分别输入到对应的额外模块B1,…,BN中,输出特征图o'1,…,o'N,具体为:o'i=Bi(oi),i∈[1,N];
步骤S15:将由额外模块输出的特征图o′1,…,o′N分别与由学习基因层输出的特征图oN+1计算相似度差异,计算的公式为:
步骤S16:将步骤S13获得的分类损失函数和步骤S14获得的计算相似度差异损失函数整合,计算总体损失函数,表达式如下:
其中,α是超参数,用来调整两种损失的权重大小;
步骤S17:利用总体损失函数计算梯度下降,以此更新后代网络的参数。
3.根据权利要求2所述的基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,所述步骤S1中,增强的方式为:在训练环节,对于CIFAR100数据集,采取随机裁剪样本成长宽均为32、随机水平翻转的增强方式,对于ImageNet-Tiny数据集,采取调整样本长宽为224、随机裁剪和随机水平翻转的增强方式;在测试环节,对于CIFAR100数据集样本不进行数据增强的操作;对于ImageNet-Tiny数据集,只将其中样本的长宽调整为224。
4.根据权利要求2所述的基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,所述步骤S2中,祖先网络的非学习基因层被划分为4个操作块。
5.根据权利要求2所述的基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,所述步骤S3中,超网络的每一个操作块内都包含4个不同的候选卷积操作块。
6.根据权利要求2或3所述的基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,所述步骤S4中增强方式与步骤S1相同。
7.根据权利要求2所述的基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,所述步骤S5生成的特征图的过程包括:第一个操作块输出的特征图为其中k表示超网络/>每个操作块下候选卷积操作块中卷积核的大小,由于超网络/>中的每一个操作块包含4个候选卷积操作块,超网络/>中的每一个操作块将输出4个特征图,第二个操作块输出的特征图为/>以此类推,最终输出超网络/>每个操作块生成的特征图/>
8.根据权利要求1所述的基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,所述步骤S9具体包括如下过程:
超网络的某个操作块下的4个候选卷积操作块输出的特征图/>依次和祖先网络/>相同位置的操作块输出的特征图fi计算相似度,公式为:
从中选择值最小的,即选择输出了与fi最相似的特征图的候选卷积操作块作为后代网络/>第i层。
9.根据权利要求1所述的基于架构搜索和自知识蒸馏的学习基因继承方法,其特征在于,所述步骤S14中设计的额外模块包括一个卷积操作层和一个池化操作层。
10.一种图像分类的方法,其特征在于,包括如下步骤:
步骤S1:收集并整理用于分类的图像数据集;将该数据集划分为训练集和测试集两个部分,其中对于训练集中的图片,人工标注该图片所属的类别;
步骤S2:获取权利要求1-9中任意一项所述的基于架构搜索和自知识蒸馏的学习基因继承方法产生的模型及其对应的参数,利用参数初始化该模型;
步骤S3:用该模型在步骤S1中制作的训练集上训练N轮,并保留训练好的模型参数;
步骤S4:使用训练好的模型参数初始化步骤S2中产生的模型,并在步骤S1中制作的测试集图像上进行预测,以完成识别图片所属的类别的任务。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311232774.2A CN117195951B (zh) | 2023-09-22 | 2023-09-22 | 一种基于架构搜索和自知识蒸馏的学习基因继承方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311232774.2A CN117195951B (zh) | 2023-09-22 | 2023-09-22 | 一种基于架构搜索和自知识蒸馏的学习基因继承方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117195951A true CN117195951A (zh) | 2023-12-08 |
CN117195951B CN117195951B (zh) | 2024-04-16 |
Family
ID=88992223
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311232774.2A Active CN117195951B (zh) | 2023-09-22 | 2023-09-22 | 一种基于架构搜索和自知识蒸馏的学习基因继承方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117195951B (zh) |
Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112508104A (zh) * | 2020-12-08 | 2021-03-16 | 浙江工业大学 | 一种基于快速网络架构搜索的跨任务图像分类方法 |
US20220092381A1 (en) * | 2020-09-18 | 2022-03-24 | Baidu Usa Llc | Neural architecture search via similarity-based operator ranking |
US20220156596A1 (en) * | 2020-11-17 | 2022-05-19 | A.I.MATICS Inc. | Neural architecture search method based on knowledge distillation |
CN114821218A (zh) * | 2021-12-14 | 2022-07-29 | 上海悠络客电子科技股份有限公司 | 基于改进的通道注意力机制的目标检测模型搜索方法 |
US20230076457A1 (en) * | 2021-08-27 | 2023-03-09 | Zhejiang Lab | Edge calculation-oriented reparametric neural network architecture search method |
US20230105590A1 (en) * | 2021-05-17 | 2023-04-06 | Tencent Technology (Shenzhen) Company Limited | Data classification and recognition method and apparatus, device, and medium |
US20230153577A1 (en) * | 2021-11-16 | 2023-05-18 | Qualcomm Incorporated | Trust-region aware neural network architecture search for knowledge distillation |
WO2023091428A1 (en) * | 2021-11-16 | 2023-05-25 | Qualcomm Incorporated | Trust-region aware neural network architecture search for knowledge distillation |
US20230196067A1 (en) * | 2021-12-17 | 2023-06-22 | Lemon Inc. | Optimal knowledge distillation scheme |
CN116503676A (zh) * | 2023-06-27 | 2023-07-28 | 南京大数据集团有限公司 | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 |
CN116524282A (zh) * | 2023-06-26 | 2023-08-01 | 贵州大学 | 一种基于特征向量的离散相似度匹配分类方法 |
-
2023
- 2023-09-22 CN CN202311232774.2A patent/CN117195951B/zh active Active
Patent Citations (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20220092381A1 (en) * | 2020-09-18 | 2022-03-24 | Baidu Usa Llc | Neural architecture search via similarity-based operator ranking |
US20220156596A1 (en) * | 2020-11-17 | 2022-05-19 | A.I.MATICS Inc. | Neural architecture search method based on knowledge distillation |
CN112508104A (zh) * | 2020-12-08 | 2021-03-16 | 浙江工业大学 | 一种基于快速网络架构搜索的跨任务图像分类方法 |
US20230105590A1 (en) * | 2021-05-17 | 2023-04-06 | Tencent Technology (Shenzhen) Company Limited | Data classification and recognition method and apparatus, device, and medium |
US20230076457A1 (en) * | 2021-08-27 | 2023-03-09 | Zhejiang Lab | Edge calculation-oriented reparametric neural network architecture search method |
US20230153577A1 (en) * | 2021-11-16 | 2023-05-18 | Qualcomm Incorporated | Trust-region aware neural network architecture search for knowledge distillation |
WO2023091428A1 (en) * | 2021-11-16 | 2023-05-25 | Qualcomm Incorporated | Trust-region aware neural network architecture search for knowledge distillation |
CN114821218A (zh) * | 2021-12-14 | 2022-07-29 | 上海悠络客电子科技股份有限公司 | 基于改进的通道注意力机制的目标检测模型搜索方法 |
US20230196067A1 (en) * | 2021-12-17 | 2023-06-22 | Lemon Inc. | Optimal knowledge distillation scheme |
CN116524282A (zh) * | 2023-06-26 | 2023-08-01 | 贵州大学 | 一种基于特征向量的离散相似度匹配分类方法 |
CN116503676A (zh) * | 2023-06-27 | 2023-07-28 | 南京大数据集团有限公司 | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 |
Non-Patent Citations (2)
Title |
---|
QIUFENG WANG等: "Learngene: Inheriting Condensed Knowledge from the Ancestry Model to Descendant Models", ARXIV:2305.02279V3, 29 June 2023 (2023-06-29), pages 1 - 17 * |
赵立新;侯发东;吕正超;朱慧超;丁筱玲;: "基于迁移学习的棉花叶部病虫害图像识别", 农业工程学报, no. 07, 8 April 2020 (2020-04-08), pages 192 - 199 * |
Also Published As
Publication number | Publication date |
---|---|
CN117195951B (zh) | 2024-04-16 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111488137B (zh) | 一种基于共同注意力表征学习的代码搜索方法 | |
CN110362723A (zh) | 一种题目特征表示方法、装置及存储介质 | |
CN111400494B (zh) | 一种基于GCN-Attention的情感分析方法 | |
CN108763367B (zh) | 一种基于深度对齐矩阵分解模型进行学术论文推荐的方法 | |
JP6738769B2 (ja) | 文ペア分類装置、文ペア分類学習装置、方法、及びプログラム | |
CN107291895B (zh) | 一种快速的层次化文档查询方法 | |
CN114091450B (zh) | 一种基于图卷积网络的司法领域关系抽取方法和系统 | |
CN111222318A (zh) | 基于双通道双向lstm-crf网络的触发词识别方法 | |
CN112651324A (zh) | 视频帧语义信息的提取方法、装置及计算机设备 | |
CN113408581A (zh) | 一种多模态数据匹配方法、装置、设备及存储介质 | |
CN114049527B (zh) | 基于在线协作与融合的自我知识蒸馏方法与系统 | |
CN115456166A (zh) | 一种无源域数据的神经网络分类模型知识蒸馏方法 | |
CN115985520A (zh) | 基于图正则化矩阵分解的药物疾病关联关系的预测方法 | |
CN110310012B (zh) | 数据分析方法、装置、设备及计算机可读存储介质 | |
CN117421595A (zh) | 一种基于深度学习技术的系统日志异常检测方法及系统 | |
CN115995293A (zh) | 一种环状rna和疾病关联预测方法 | |
CN114579794A (zh) | 特征一致性建议的多尺度融合地标图像检索方法及系统 | |
CN116662565A (zh) | 基于对比学习预训练的异质信息网络关键词生成方法 | |
CN113806579A (zh) | 文本图像检索方法和装置 | |
CN113449076A (zh) | 基于全局信息和局部信息的代码搜索嵌入方法及装置 | |
CN117195951B (zh) | 一种基于架构搜索和自知识蒸馏的学习基因继承方法 | |
CN116861022A (zh) | 一种基于深度卷积神经网络和局部敏感哈希算法相结合的图像检索方法 | |
CN111753995A (zh) | 一种基于梯度提升树的局部可解释方法 | |
CN116662566A (zh) | 一种基于对比学习机制的异质信息网络链路预测方法 | |
CN111259176B (zh) | 融合有监督信息的基于矩阵分解的跨模态哈希检索方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |