CN115564987A - 一种基于元学习的图像分类模型的训练方法及应用 - Google Patents
一种基于元学习的图像分类模型的训练方法及应用 Download PDFInfo
- Publication number
- CN115564987A CN115564987A CN202211128442.5A CN202211128442A CN115564987A CN 115564987 A CN115564987 A CN 115564987A CN 202211128442 A CN202211128442 A CN 202211128442A CN 115564987 A CN115564987 A CN 115564987A
- Authority
- CN
- China
- Prior art keywords
- subtask
- model
- image classification
- training
- meta
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 112
- 238000000034 method Methods 0.000 title claims abstract description 76
- 238000013145 classification model Methods 0.000 title claims abstract description 75
- 238000013138 pruning Methods 0.000 claims abstract description 60
- 238000013139 quantization Methods 0.000 claims description 39
- 238000012360 testing method Methods 0.000 claims description 20
- 238000004590 computer program Methods 0.000 claims description 12
- 210000002569 neuron Anatomy 0.000 claims description 12
- 238000009826 distribution Methods 0.000 claims description 9
- 238000003860 storage Methods 0.000 claims description 8
- 238000005070 sampling Methods 0.000 claims description 6
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 238000013527 convolutional neural network Methods 0.000 claims description 3
- 238000011056 performance test Methods 0.000 claims description 3
- 238000004364 calculation method Methods 0.000 abstract description 20
- 230000008569 process Effects 0.000 description 28
- 238000010586 diagram Methods 0.000 description 8
- 230000006835 compression Effects 0.000 description 7
- 238000007906 compression Methods 0.000 description 7
- 238000004422 calculation algorithm Methods 0.000 description 6
- 238000005457 optimization Methods 0.000 description 6
- 230000000694 effects Effects 0.000 description 5
- 238000002474 experimental method Methods 0.000 description 4
- 239000011159 matrix material Substances 0.000 description 4
- 238000013135 deep learning Methods 0.000 description 3
- 230000006870 function Effects 0.000 description 3
- 238000012935 Averaging Methods 0.000 description 2
- 102100030148 Integrator complex subunit 8 Human genes 0.000 description 2
- 101710092891 Integrator complex subunit 8 Proteins 0.000 description 2
- 241001465754 Metazoa Species 0.000 description 2
- 230000001133 acceleration Effects 0.000 description 2
- 230000008901 benefit Effects 0.000 description 2
- 238000005520 cutting process Methods 0.000 description 2
- 238000004880 explosion Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 239000013589 supplement Substances 0.000 description 2
- 230000004913 activation Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 239000000969 carrier Substances 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 239000000047 product Substances 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 238000011002 quantification Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于元学习的图像分类模型的训练方法及应用,图像分类技术领域,在元学习的预训练阶段的第一轮内循环迭代结束时,对子任务模型进行剪枝操作,在保证模型精度的情况下省去了很多不必要的计算,大大减少了预训练阶段以及微调阶段的计算量,降低了元学习微调阶段在移动终端设备上的门槛;与此同时,本发明通过预训练完成后的微调阶段来弥补剪枝所带来的精度损失,能够在不改变模型精度的情况下降低了模型的大小,从而在移动终端设备上实现了快速高效的训练。
Description
技术领域
本发明属于图像分类技术领域,更具体地,涉及一种基于元学习的图像分类模型的训练方法及应用。
背景技术
图像分类问题在人类日常生活中的各个场景都有着广泛的应用,例如图片审核、物体识别等。利用深度学习来解决图像分类问题是当今研究的一个热点,然而在深度学习里,大多数情况都会使用某个场景的大量数据来训练模型,当场景或者数据集发生改变时,模型就需要重新训练,训练效率较低。为了解决这一问题,研究人员提出了元学习的概念,以各种学习任务为训练数据,在此基础上训练一个模型,模型就可以拥有在新任务上通过少量样本就可以完成学习的能力。元学习的出现为图像分类模型在不同应用场景、不同数据集上都能够训练以及有效运行提供了解决方案。
然而随着边缘智能设备的普及,为了充分利用边缘设备的计算资源,往往将训练任务从云端迁移到边缘上,使得移动终端设备成为了部署、训练模型不可或缺的重要载体,但是由于移动终端设备存储和计算能力有限,而元学习的预训练阶段往往需要在多个任务上进行训练,计算量较大,且元学习的应用特性导致在实际应用过程中要求元学习的Fine-tune微调阶段是快速、高效的;因此,现有的基于元学习的图像分类模型无法在移动终端设备上实现快速高效的训练。
发明内容
针对现有技术的以上缺陷或改进需求,本发明提供了一种基于元学习的图像分类模型的训练方法及应用,用以解决现有的基于元学习的图像分类模型无法在移动终端设备上实现快速高效的训练的技术问题。
为了实现上述问题,第一方面,本发明提供了一种基于元学习的图像分类模型的训练方法,包括以下步骤:
S1、将采集到的带分类标签的图像样本分为图像类别不同的元训练集和元测试集;
S2、从若干个预先准备的子任务中随机采样得到N1个子任务;对元训练集进行划分,得到N1个子任务所对应的子任务集;其中,每个子任务集均包括支持集和查询集;一个子任务对应一个子任务模型;N1为大于1的整数;
S3、分别采用元训练集下的各子任务集中的支撑集,训练对应的子任务模型;判断当前内循环迭代次数是否为1,若是,则对子任务模型进行剪枝操作,转至步骤S4;否则,直接转至步骤S4;
S4、重复步骤S3进行内循环迭代,直至当前内循环迭代次数达到第一预设迭代次数;
S5、基于各子任务模型的分类损失值的平均值对待训练的图像分类模型中的参数进行更新;其中,各子任务模型的分类损失值为将元训练集下的各子任务集中的查询集输入到对应子任务模型中进行性能测试时得到;
S6、重复步骤S2-S5进行外循环迭代,直至当前外循环迭代次数达到第二预设迭代次数;
S7、从预先准备的多个子任务中采样得到N2个子任务,对元测试集进行划分,得到N2个子任务所对应的子任务集,采用N2个子任务所对应的子任务集中的支撑集对图像分类模型进行微调,得到训练好的图像分类模型;其中,N2为大于或等于1的整数;
其中,子任务模型和图像分类模型均为基于神经网络的模型。
进一步优选地,步骤S3中,对子任务模型依次进行通道剪枝和权重剪枝。
进一步优选地,子任务模型为卷积神经网络;对子任务模型进行剪枝的方法,包括:
S31、通过对子任务模型中BN层的缩放因子施加L1范数惩罚,从而对子任务模型基于BN层的缩放因子进行稀疏化训练;将绝对值小于预设通道剪枝阈值的缩放因子所对应的卷积层通道去除,得到通道卷积后的子任务模型;
S32、对通道卷积后的子任务模型中所有相邻两层间的各神经元连接,去除神经元连接权重的L1范数值小于预设权重剪枝阈值的神经元连接,完成对子任务模型剪枝操作。
进一步优选地,上述基于元学习的图像分类模型的训练方法还包括:在步骤S5和步骤S6之间执行的步骤S8,以及在步骤S6和步骤S7之间执行的步骤S9;
步骤S8包括:对图像分类模型进行量化操作;
步骤S9包括:对图像分类模型进行反量化操作。
进一步优选地,每进行完一轮内循环迭代后,采用元训练集下的子任务集中的查询集测试对应子任务模型的性能,得到对应内循环迭代轮次下子任务模型的分类损失值;此时,上述步骤S5包括:为内循环所有迭代轮次分别赋予不同的权重值,并计算各内循环迭代轮次下子任务模型的分类损失值的加权平均值,基于所得加权平均值对待训练的图像分类模型中的参数进行更新。
进一步优选地,在子任务集中,支持集中的图像样本数量小于查询集中的图像样本数量。
进一步优选地,每个子任务集中,支持集和查询集中图像样本数量的比例为1:15。
进一步优选地,每个子任务集的数据分布相同。
第二方面,本发明提供了一种图像分类方法,包括:将待分类图像输入到采用本发明第一方面所提供的基于元学习的图像分类模型的训练方法训练得到的图像分类模型中,得到图像分类结果。
第三方面,本发明提供了一种图像分类系统,包括:存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时执行本发明第二方面所提供的图像分类方法。
第四方面,本发明还提供了一种计算机可读存储介质,所述计算机可读存储介质包括存储的计算机程序,其中,在所述计算机程序被处理器运行时控制所述存储介质所在设备执行本发明第一方面所提供的基于元学习的图像分类模型的训练方法和/或本发明第二方面所提供的图像分类方法。
总体而言,通过本发明所构思的以上技术方案,能够取得以下有益效果:
1、本发明提供了一种基于元学习的图像分类模型的训练方法,在元学习的预训练阶段的第一轮内循环迭代结束时,对子任务模型进行剪枝操作以将表达能力差的通道以及权重剪掉,从而在后续的循环过程中可以直接跳过涉及剪掉的通道的计算,达到精简模型运算的效果,在保证模型精度的情况下省去了很多不必要的计算,大大减少了预训练阶段以及微调阶段的计算量,降低了元学习微调阶段在移动终端设备上的门槛;与此同时,本发明通过预训练完成后的微调阶段来弥补剪枝所带来的精度损失,能够在不改变模型精度的情况下降低了模型的大小,从而在移动终端设备上实现了快速高效的训练。
2、本发明所提供的基于元学习的图像分类模型的训练方法,采用通道剪枝的压缩方法对子任务模型进行结构化的剪枝,将表达能力较弱的通道剪掉从而实现对子任务模型的粗粒度剪枝,并进一步地将修剪后的子训练模型的全连接层中表达能力较差的权重剪掉,以实现全连接层细粒度的剪枝,以大幅降低计算量进而达到加速的效果。
3、本发明所提供的基于元学习的图像分类模型的训练方法,在预训练阶段完成后,对图像分类模型进行量化,将图像分类模型从32-bit精度量化到8-bit精度,使图像分类模型的大小变为之前的四分之一左右,并在微调阶段通过反量化再次训练来弥补模型量化带来的精度损失,通过与微调阶段相配合,能够在不改变模型精度的情况下进一步降低模型的大小。
4、本发明所提供的基于元学习的图像分类模型的训练方法,在预训练阶段对待训练的图像分类模型中的参数进行更新时,分别计算所有内循环更新轮次的损失值,并分别赋予不同的权重,最后进行加权平均,依旧进行一次更新,这样不仅可以使得之前几轮内循环更新的损失参与优化,也降低了运算所需要的时间,减少了反向传播的次数,进一步增加了模型的稳定性以及泛化性能。
附图说明
图1为本发明实施例1提供的基于元学习的图像分类模型的训练方法流程图;
图2为本发明实施例1提供的训练数据分布示意图;
图3为本发明实施例1提供的预训练单个任务的解析示意图;
图4为本发明实施例1提供的微调阶段的流程图;
图5为本发明实施例1提供的通道剪枝过程示意图;
图6为本发明实施例1提供的权重剪枝示意图;
图7为本发明实施例1提供的元学习预训练阶段的模型量化计算方法流程图;
图8为本发明实施例1提供的预训练阶段的流程图;
图9为本发明实施例1提供的剪枝优化示意图。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。此外,下面所描述的本发明各个实施方式中所涉及到的技术特征只要彼此之间未构成冲突就可以相互组合。
实施例1、
一种基于元学习的图像分类模型的训练方法,如图1所示,包括以下步骤:
S1、将采集到的带分类标签的图像样本分为图像类别不同的元训练集和元测试集;
S2、从若干个预先准备的子任务中随机采样得到N1个子任务;对元训练集进行划分,得到N1个子任务所对应的子任务集;其中,每个子任务集均包括支持集和查询集;一个子任务对应一个子任务模型;N1为大于1的整数;
S3、分别采用元训练集下的各子任务集中的支撑集,训练对应的子任务模型;判断当前内循环迭代次数是否为1,若是,则对子任务模型进行剪枝操作,转至步骤S4;否则,直接转至步骤S4;
S4、重复步骤S3进行内循环迭代,直至当前内循环迭代次数达到第一预设迭代次数(本实施例中取值为5);
S5、基于各子任务模型的分类损失值的平均值对待训练的图像分类模型中的参数进行更新;其中,各子任务模型的分类损失值为将元训练集下的各子任务集中的查询集输入到对应子任务模型中进行性能测试时得到;
S6、重复步骤S2-S5进行外循环迭代,直至当前外循环迭代次数达到第二预设迭代次数(本实施例中取值为5);
S7、从预先准备的多个子任务中采样得到N2个子任务,对元测试集进行划分,得到N2个子任务所对应的子任务集,采用N2个子任务所对应的子任务集中的支撑集对图像分类模型进行微调,得到训练好的图像分类模型;其中,N2为大于或等于1的整数;
其中,子任务模型和图像分类模型均为基于神经网络的模型。
需要说明的是,本发明所提供的基于元学习的图像分类模型的训练方法主要分为两个阶段:第一个阶段为元学习的预训练阶段,通过在给定的子任务集上训练网络的初始化参数,具体对应上述的步骤S2-S6;第二个阶段,将初始化参数根据新任务进行微调,具体对应上述的步骤S7。整体思路在于通过迭代训练不同的Task,训练出对于不同Task敏感的参数,模型可以拥有对新任务敏感并能快速适应新任务的能力,可以使模型在新任务上通过少量的样本仅一或多次的梯度迭代中微调参数。
每个子任务集的数据分布相同。在子任务集中,支持集中的图像样本数量小于查询集中的图像样本数量。本实施例中,每个子任务集中,支持集和查询集中图像样本数量的比例为1:15。具体地,在整个流程中,所有的训练数据以及测试数据都是以子任务为单位;模型的训练过程都是围绕子任务展开的。在训练图像分类模型Mfine-tune的过程中,类别包括P1~P5,每类5个已标注图像样本用于训练,另外每类有15个已标注图像样本用来测试。实验的训练数据除了P1~P5中已标注的图像样本外,还包括另外十个类别的图片C1~C10,每类50个已标注样本,用于帮助训练元学习模型Mmeta。N-wayK-shot用来描述在一个Task中,包含N个数据样本分类,每个分类中有K个训练样本。常见的数据样本配置有:5-way1-shot、5way5-shot、20way1-shot、20way5-shot,本实施例中采用5-way5-shot的数据配置。
此时,C1~C10即元训练集,C1~C10包含的样本共计500个,记为Dmeta-train,是用来训练子任务模型Mmeta的数据集。与之相对的,P1~P5即元测试集,P1~P5的样本共计100个,记为Dmeta-test,是用来训练和测试图像分类模型Mfine-tune的数据集。
根据5-way5-shot的设置,在训练子任务模型Mmeta的阶段,从C1~C10中随机取5个类别,每个类别再随机取20个已标注样本,组成一个TaskT,其中再从20个样本中抽取5个已标注样本称为TaskT的SupportSet,另外15个样本称为TaskT的QuerySet。这个TaskT的所有样本,可以类比于普通机器学习中的训练数据样本。
按照这个模式随机抽取若干个Task就组成了训练时所需要的一个Batch,训练数据分布如图2所示,从图中可以看到整个训练集是一个Task的集合,称为Tasks,其中每一行都是一个Task,每个Task的训练集,也就是SupportSet,都是由5类动物的1个样本组成,每个Task的QuerySet也是由5类动物的1个样本组成。图中上半部分是实验的训练集,图中下半部分是实验的测试集,每个Task中数据的分布是相同的。
在第一个阶段,即预训练阶段中,首先通过元训练集下的子任务的支持集对各自子任务的子任务模型进行训练,分别训练出针对各自子任务的子任务模型参数,主要是用来获得每个任务的初始化参数并进行拟合,训练过程中第i个子任务模型参数θ'i的更新公式为:其中,为第i个子任务模型的交叉熵损失函数;α为内循环学习率;预训练单个任务的解析示意图如图3所示;具体地,在这个过程中,先对待训练的图像分类模型进行复制得到多个子任务模型,然后分别对各子任务模型进行训练,进行反向传播和参数更新。其次、用不同子任务中的查询集分别去测试子任务模型的性能,计算损失并进行平均,然后通过梯度下降的方式对待训练的图像分类模型中的参数θ进行更新,更新过程如公式所示;其中,β为外循环学习率;Ti为第i个子任务;p(T)为元训练集中的子任务分布;在这个过程中,并不对各子任务模型的参数进行更新,而是直接对图像分类模型的参数进行更新。
进一步说明的是,在上述过程中,在一个训练轮次会进行两次梯度的更新,分为内循环和外循环两次更新,其中外循环的更新是基于内循环最后一轮更新的损失函数的平均值。经过实验发现上述训练很不稳定,在做外循环的时候,由于权重参数需要多次通过网络产生,就会从外循环梯度回传到内循环,并且网络的每一层都会被回传好几次,这样就容易产生梯度爆炸或者梯度衰减,并且如果网络结构的深度很大时,很可能只传完一个网络就会出现梯度爆炸或者梯度衰减,那么当梯度值会出现较大的问题时,参数的更新自然就会出事,通过这个参数产生的结果也出现不稳定的现象;具体地,由于图像分类模型参数的更新取决于所有查询集上最后一步的损失值,因此通过这个值去做反向传播的时候,之前几轮内循环的参数只能被隐式的优化,他们产生的损失无法直接拿来显示的优化,因此算法的稳定性较差。
为了解决上述问题,一种具体实施方式是将改单步为多步,即将最后一轮内循环完成后再对图像分类模型中的参数进行更新改为每一轮内循环完成后都对图像分类模型中的参数进行更新,原理就是消耗训练时间,以更大的计算量的代价来换取算法的稳定性,因为模型在训练的过程中反向传播的次数会变多,会使算法更加稳定;但是该方法是通过时间来换取稳定性,成本较高,因此,在一种优选方案中,采用单步加权损耗优化方式,将预训练阶段的外循环更新的时机从内循环完成后更新改为了内循环每进行一步就计算损失,利用每一步损失的加权平均值更新,进一步增加了模型的稳定性以及泛化性能;即每进行完一轮内循环迭代,就采用元训练集下的子任务集中的查询集测试对应子任务模型的性能,得到对应内循环迭代轮次下子任务模型的分类损失值;此时,上述步骤S5包括:为内循环所有迭代轮次分别赋予不同的权重值,并计算各内循环迭代轮次下子任务模型的分类损失值的加权平均值,基于所得加权平均值对待训练的图像分类模型中的参数进行更新。具体地,可以看出导致算法稳定性不好的原因在于之前几轮内循环更新的损失无法参与优化,都是以最后一轮更新的损失为准,那么可以分别计算所有内循环更新轮次的损失值,并分别赋予不同的权重,最后进行加权平均,依旧进行一次更新,这样不仅可以使得之前几轮内循环更新的损失参与优化,也能降低运算所需要的时间,减少反向传播的次数。
第二个阶段即微调阶段的流程图如图4所示,这个阶段主要在测试数据集上进行训练,数据集依旧划分为支撑集和查询集,第一个阶段中训练任务的目的是找到一个好的超参设置,第二个阶段利用已经训练好的参数进行微调,让模型能够快速适应新的数据分布,算法大致流程和预训练阶段相同,不同点在于微调阶段不需要再初始化参数,而是利用训练好的图像分类模型的参数。而且微调阶段不需要再形成Batch,只需要从元测试集上抽取一个子任务Task进行学习,利用这个Task的支撑集训练模型,利用查询集测试模型。但是为了避免极端的情况,会从元测试集上随机抽取多个子任务Task,分别对预训练好的模型图像分类模型进行微调,最后对测试结果进行平均,从而避免极端的情况,并且因为测试的子任务Task的查询集是用来测试模型的,标签对模型是未知的,因此微调阶段没有第二次的梯度更新,而是直接利用第一次梯度计算的结果更新参数。
需要说明的是,元学习为小样本学习提供了新的思路和解决方法,旨在通过预训练将模型的参数训练到一个合适的位置,再通过微调阶段少数的训练轮次,使用新的任务和数据分布将参数精调到一个最佳的位置。这个过程中,虽然是小样本学习,每个子任务中的支撑集很少,但是子任务的数目还是需要有一定规模以保证让模型学会适应新任务的能力,导致轮次较多,需要进行大量的计算,为了解决上述问题,本发明对子任务模型依次进行剪枝操作,以避免一些不必要的计算,同时也可以减小预训练阶段结束之后模型的大小以及微调阶段计算的次数和时间,另外还可以减少预训练阶段过拟合的问题,提高模型的泛化能力。
具体地,在一种可选实施方式一下,上述步骤S3中,对子任务模型依次进行通道剪枝和权重剪枝,以实现在不损失过多原算法准确率的同时,减少子任务模型中的参数量和推理过程中的计算量,以此来提高模型推理过程的效率。具体地,以子任务模型为卷积神经网络,主要包括卷积层、BatchNorm、激活层;对子任务模型进行剪枝的方法,包括:
S31、通过对子任务模型中BN层的缩放因子施加L1范数惩罚,从而对子任务模型基于BN层的缩放因子进行稀疏化训练;将绝对值小于预设通道剪枝阈值(本实施例中取值为表达能力降序排序后的30%)的缩放因子所对应的卷积层通道去除,得到通道卷积后的子任务模型;
具体地,本发明通过裁剪每一层表达能力差的卷积层通道,达到模型压缩的目的。本实施方式选择BN层的缩放因子作为裁剪卷积层通道的衡量指标;由于L1范数惩罚会将BN层的缩放因子限制在0的附近,可以方便辨别出模型中表达能力差的通道,且正则化对性能的损失也比较小,甚至某些情况它会使得模型拥有更高的泛化性能,因此,在训练时,对BN层的缩放因子添加L1正则化,达到稀疏化的作用,这样就可以通过BN层的缩放因子趋于0来识别不重要的卷积层通道。具体地,通过公式 推导出BN层的缩放因子(记为γ参数)可以反映出卷积通道表达能力的大小。根据BN层通过L1范数计算后的γ参数的大小,按照排序筛选出绝对值小的γ参数所对应的卷积核或卷积层通道相应去除。通过上述过程将筛选出来的不重要的卷积核或者卷积核的相应通道去除,将留下的卷积核重新排列组合形成新的卷积层权重。当选中的通道被剪掉后会带来一定的精度损失,而在后续的微调阶段会进行精度弥补。
需要说明的是,在整个全局剪枝操作过程中,如果设定的剪枝率足够大时,可能会出现剪枝阈值大于某一层BN中的参数γ的最大值,基于这种考虑,在剪枝之前需要计算各个BN层参数γ最大值,取其中的最小值定为剪枝阈值的上限。当通过计算确定了每一层需要裁剪的通道后,就可以对整个网络进行剪枝,剪枝是通过将卷积层参数和BN层参数重新组合来实现,判断该剪去哪些卷积核通道的依据以及过程如图5所示的通道剪枝过程中卷积核的示意图。
这个过程中需要将输入通道的剪枝掩码Min和本层通道的剪枝掩码Mx所对应的卷积核及通道都剪掉,每一层的剪枝掩码根据当前卷积层BN层参数γ与剪枝阈值的判断组成。如图5所示,其最左边为该层的输入特征图,可以看到第二层通道被判断为对应BN层参数γ较小的通道,也就是表达能力较差的,需要被剪掉,那么根据卷积的计算过程,对应的卷积核张量的第二行也需要被剪掉,因为每个卷积核的第二层通道对应的是输入特征图的第二层;同理,最右边为该层的输出特征图,可以看到第一层通道也是表达能力差的通道,需要被剪枝,对应的卷积核张量的第一列也需要被剪掉。
S32、在全连接层采用权重剪枝,对通道卷积后的子任务模型中所有相邻两层间的各神经元连接,去除神经元连接权重的L1范数值小于预设权重剪枝阈值(本实施例中取值为表达能力降序排列的后20%)的神经元连接,完成对子任务模型剪枝操作。
需要说明的是,步骤S31采用通道剪枝的压缩方法对子任务模型进行结构化的剪枝,这个是最有效的粗粒度剪枝方法,可以大幅降低计算量进而达到加速的效果,步骤S32作为剪枝方法的补充,对修剪后的模型进行全连接层细粒度的剪枝,可以与通道剪枝做到相互补充的效果。其主体的剪枝思想和步骤和通道剪枝相同,都是需要先预训练一个网络,然后评估每一个神经元的权重,接下来根据剪枝率计算剪枝的阈值,然后将阈值以下的权重剪掉,这个过程会损失一些精度,接下来通过微调阶段来将损失的精度弥补回来。
步骤S32中,以获取通道卷积后的子任务模型中第x层和x+1层这相邻两层为例。如图6所示,计算第x层与x+1层间各个神经元连接的权重,将ix与ky的连接权重记录在矩阵f1的第x行y列,接下来计算每个神经元连接的L1范数,根据设定的剪枝率计算权重剪枝阈值,将低于权重剪枝阈值的权重进行标记,具体的标记方法为在矩阵f2中对应的位置置为0,没有标记的权重置为1,矩阵f2即为掩码,那么在最后一步剪枝的过程中,只需要将矩阵f1中的权重与矩阵f2中对应的掩码相乘即可。
进一步地,在深度学习网络的训练过程中,模型参数的乘加计算量是巨大的,通常数以百万记,这就需要一些专业的云计算平台才能满足实时计算的需求,这对边缘智能设备上的产品来说是不可接受的,本发明采用模型量化的方式来降低计算量。量化,即将网络中参数的精度从高精度转化成低精度的操作过程,量化算法根据位数不同有:二值量化、三值量化等量化方法,具体压缩的效果例如32位浮点数转化位8位整型数int8,将模型的大小压缩到了之前的四分之一,量化可以带来更小的模型尺寸,更低的功耗,更快的计算速度。以四位量化为例,也就是2-bit量化,将32-bit精度的权重量化为-1.0、0、1.5、2.0四个值,在这个过程中每个参数所占用的空间从32-bit压缩为了2-bit,缩减为十六分之一,这就是模型量化所带来的好处。
在一种可选实施方式二下,上述基于元学习的图像分类模型的训练方法还包括:在步骤S5和步骤S6之间执行的步骤S8,以及在步骤S6和步骤S7之间执行的步骤S9;
步骤S8包括:对图像分类模型进行量化操作;
步骤S9包括:对图像分类模型进行反量化操作。
具体地,图像分类模型参数经过量化后将浮点数转换为了整数类型,虽然恢复后的数值与压缩前存在一定的误差,但是事实证明,模型精度对这种压缩造成的噪音表现的很健壮,元学习预训练阶段的模型量化计算方法流程图如图7所示,可以看到模型会根据最大值与最小值的范围映射到8bit的数值范围。对于模型量化任务来说,第一步通常先在输入数据中统计出相应的最小值和最大值,这里的输入数据是图像分类模型中的权重或者激活值;第二步为选择合适的量化类型,其中又包含量化位数的选择以及量化方式的选择,比如对称量化或者非对称量化量化等;第三步则根据选择的量化类型、统计出的最小值以及最大值计算出量化参数偏移量(Zero point)和缩放因子(Scale),方便后续的计算;第四步则是根据标定数据对图像分类模型中的参数进行量化操作,比如选取的量化位数为8位,量化方式采取对称量化,则将参数从FP32转为INT8,最后一步则是验证模型量化后的模型性能,如果量化后带来的负面影响过大,即性能下降严重,则需要尝试使用不同的量化方式来计算Zero point和Scale这两个参数,然后再重新执行上面的操作,知道模型的压缩率和性能达到一个较为平衡的状态。
通常来说,神经网络的参数都是使用的32bit的浮点型数表示,在实际的训练过程中发现不需要保留那么高的精度,比如可以用0~255的uint类型的整型来表示原来32个bit所表示的精度,用一小部分的精度差值来换取空间。此外,SGD过程一般6~8bit的精度就够用,因此合理的量化网络也可以保证精度的情况下减小模型的存储体积。
本实施方式下,将将图像分类模型中的权重从32bit量化到8bit。
首先通过计算得到图像分类模型中权重的最小值min和最大值max。
紧接着计算量化后的零点,具体为:Z=Qmax-Rmax÷S;其中,Qmax为量化后的最大值,Rmax为全精度的最大值;
基于量化映射公式r=Round(S(q-Z))对图像分类模型中的权重进行量化;其中,r为量化前的浮点权重;q为量化后的权重;
最后在实际的计算过程中再进行反量化,具体为:R=(Q-Z)*S;其中,R为反量化后的浮点权重值;Q为量化后的权重值。
本实施方式在预训练结束后对图像分类模型进行量化,在保证模型性能的前提下将模型参数的精度从FP32降到INT8,能够在几乎不损失结果精度的情况下,将模型的大小降到了原来的25%。
综合上述实施方式一和实施方式二的预训练阶段的流程图如图8所示,如图8所示,预训练阶段总共分为三个部分:第一个部分为在元训练集上的子任务集的支撑集上进行训练,第二个部分为在固定的训练周期进行模型剪枝,第三个部分为在元训练集上的子任务集的查询集上进行测试,计算损失并进行更新,最后对模型进行量化,采用8-bit量化,将全精度的权重量化到int范围。通过对模型进行剪枝和量化,大大降低了模型在预训练阶段的计算量和预训练产出的模型大小,具体过程如图9所示。
实施例2、
一种图像分类方法,包括:将待分类图像输入到采用本发明实施例1所提供的基于元学习的图像分类模型的训练方法训练得到的图像分类模型中,得到图像分类结果。
相关技术方案同实施例1,这里不做赘述。
实施例3、
一种图像分类系统,包括:存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时执行本发明实施例2所提供的图像分类方法。
相关技术方案同实施例1,这里不做赘述。
实施例4、
一种计算机可读存储介质,所述计算机可读存储介质包括存储的计算机程序,其中,在所述计算机程序被处理器运行时控制所述存储介质所在设备执行本发明实施例1所提供的基于元学习的图像分类模型的训练方法和/或本发明实施例2所提供的图像分类方法。
本领域的技术人员容易理解,以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明的保护范围之内。
Claims (10)
1.一种基于元学习的图像分类模型的训练方法,其特征在于,包括以下步骤:
S1、将采集到的带分类标签的图像样本分为图像类别不同的元训练集和元测试集;
S2、从若干个预先准备的子任务中随机采样得到N1个子任务;对所述元训练集进行划分,得到N1个子任务所对应的子任务集;每个子任务集均包括支持集和查询集;一个子任务对应一个子任务模型;N1为大于1的整数;
S3、分别采用所述元训练集下的各子任务集中的支撑集,训练对应的子任务模型;判断当前内循环迭代次数是否为1,若是,则对子任务模型进行剪枝操作,转至步骤S4;否则,直接转至步骤S4;
S4、重复步骤S3进行内循环迭代,直至当前内循环迭代次数达到第一预设迭代次数;
S5、基于各子任务模型的分类损失值的平均值对待训练的图像分类模型中的参数进行更新;所述各子任务模型的分类损失值为将所述元训练集下的各子任务集中的查询集输入到对应子任务模型中进行性能测试时得到;
S6、重复步骤S2-S5进行外循环迭代,直至当前外循环迭代次数达到第二预设迭代次数;
S7、从预先准备的多个子任务中采样得到N2个子任务,对所述元测试集进行划分,得到N2个子任务所对应的子任务集,采用N2个子任务所对应的子任务集中的支撑集对图像分类模型进行微调,得到训练好的图像分类模型;N2为大于或等于1的整数;
其中子任务模型和图像分类模型均为基于神经网络的模型。
2.根据权利要求1所述的基于元学习的图像分类模型的训练方法,其特征在于,所述步骤S3中,对子任务模型依次进行通道剪枝和权重剪枝。
3.根据权利要求2所述的基于元学习的图像分类模型的训练方法,其特征在于,子任务模型为卷积神经网络;对子任务模型进行剪枝的方法,包括:
S31、通过对子任务模型中BN层的缩放因子施加L1范数惩罚,从而对子任务模型基于BN层的缩放因子进行稀疏化训练;将绝对值小于预设通道剪枝阈值的缩放因子所对应的卷积层通道去除,得到通道卷积后的子任务模型;
S32、对所述通道卷积后的子任务模型中所有相邻两层间的各神经元连接,去除神经元连接权重的L1范数值小于预设权重剪枝阈值的神经元连接,完成对子任务模型剪枝操作。
4.根据权利要求1所述的基于元学习的图像分类模型的训练方法,其特征在于,还包括:在所述步骤S5和所述步骤S6之间执行的步骤S8,以及在所述步骤S6和所述步骤S7之间执行的步骤S9;
所述步骤S8包括:对图像分类模型进行量化操作;
所述步骤S9包括:对图像分类模型进行反量化操作。
5.根据权利要求1-4任意一项所述的基于元学习的图像分类模型的训练方法,其特征在于,每进行完一轮内循环迭代后,采用所述元训练集下的子任务集中的查询集测试对应子任务模型的性能,得到对应内循环迭代轮次下子任务模型的分类损失值;此时,所述步骤S5包括:为内循环所有迭代轮次分别赋予不同的权重值,并计算各内循环迭代轮次下子任务模型的分类损失值的加权平均值,基于所得加权平均值对待训练的图像分类模型中的参数进行更新。
6.根据权利要求1-4任意一项所述的基于元学习的图像分类模型的训练方法,其特征在于,在子任务集中,支持集中的图像样本数量小于查询集中的图像样本数量。
7.根据权利要求1-4任意一项所述的基于元学习的图像分类模型的训练方法,其特征在于,每个子任务集的数据分布相同。
8.一种图像分类方法,其特征在于,包括:将待分类图像输入到采用权利要求1-7任意一项所述的基于元学习的图像分类模型的训练方法训练得到的图像分类模型中,得到图像分类结果。
9.一种图像分类系统,其特征在于,包括:存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时执行权利要求8所述的图像分类方法。
10.一种计算机可读存储介质,所述计算机可读存储介质包括存储的计算机程序,其中,在所述计算机程序被处理器运行时控制所述存储介质所在设备执行权利要求1-7任意一项所述的基于元学习的图像分类模型的训练方法和/或权利要求8所述的图像分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211128442.5A CN115564987A (zh) | 2022-09-16 | 2022-09-16 | 一种基于元学习的图像分类模型的训练方法及应用 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211128442.5A CN115564987A (zh) | 2022-09-16 | 2022-09-16 | 一种基于元学习的图像分类模型的训练方法及应用 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115564987A true CN115564987A (zh) | 2023-01-03 |
Family
ID=84740579
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211128442.5A Pending CN115564987A (zh) | 2022-09-16 | 2022-09-16 | 一种基于元学习的图像分类模型的训练方法及应用 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115564987A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117422960A (zh) * | 2023-12-14 | 2024-01-19 | 广州华微明天软件技术有限公司 | 一种基于元学习的图像识别持续学习方法 |
-
2022
- 2022-09-16 CN CN202211128442.5A patent/CN115564987A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117422960A (zh) * | 2023-12-14 | 2024-01-19 | 广州华微明天软件技术有限公司 | 一种基于元学习的图像识别持续学习方法 |
CN117422960B (zh) * | 2023-12-14 | 2024-03-26 | 广州华微明天软件技术有限公司 | 一种基于元学习的图像识别持续学习方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108764471B (zh) | 基于特征冗余分析的神经网络跨层剪枝方法 | |
WO2019060670A1 (en) | LOW PROFOUND CONVOLUTIVE NETWORK WEIGHT COMPRESSION | |
WO2020238237A1 (zh) | 一种基于幂指数量化的神经网络压缩方法 | |
CN114677548B (zh) | 基于阻变存储器的神经网络图像分类系统及方法 | |
CN113159276A (zh) | 模型优化部署方法、系统、设备及存储介质 | |
CN112183742A (zh) | 基于渐进式量化和Hessian信息的神经网络混合量化方法 | |
CN111127360A (zh) | 一种基于自动编码器的灰度图像迁移学习方法 | |
CN110647974A (zh) | 深度神经网络中的网络层运算方法及装置 | |
CN112686376A (zh) | 一种基于时序图神经网络的节点表示方法及增量学习方法 | |
CN111160524A (zh) | 一种两阶段的卷积神经网络模型压缩方法 | |
CN115564987A (zh) | 一种基于元学习的图像分类模型的训练方法及应用 | |
CN113392973A (zh) | 一种基于fpga的ai芯片神经网络加速方法 | |
CN115311506A (zh) | 基于阻变存储器的量化因子优化的图像分类方法及装置 | |
CN114511042A (zh) | 一种模型的训练方法、装置、存储介质及电子装置 | |
CN116188878A (zh) | 基于神经网络结构微调的图像分类方法、装置和存储介质 | |
CN117392406A (zh) | 一种单阶段实时目标检测模型低位宽混合精度量化方法 | |
CN115905546B (zh) | 基于阻变存储器的图卷积网络文献识别装置与方法 | |
CN116301914A (zh) | 基于gap8微处理器的卷积神经网络部署方法 | |
CN114372539B (zh) | 基于机器学习框架的分类方法及相关设备 | |
CN116306879A (zh) | 数据处理方法、装置、电子设备以及存储介质 | |
CN116306808A (zh) | 一种联合动态剪枝和条件卷积的卷积神经网络压缩方法及装置 | |
CN113157453B (zh) | 一种基于任务复杂度的高能效目标检测任务动态调度方法 | |
CN115392441A (zh) | 量化神经网络模型的片内适配方法、装置、设备及介质 | |
CN112488291B (zh) | 一种神经网络8比特量化压缩方法 | |
CN114519423A (zh) | 用于压缩神经网络的方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |