CN116503676A - 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 - Google Patents
一种基于知识蒸馏小样本增量学习的图片分类方法及系统 Download PDFInfo
- Publication number
- CN116503676A CN116503676A CN202310764468.7A CN202310764468A CN116503676A CN 116503676 A CN116503676 A CN 116503676A CN 202310764468 A CN202310764468 A CN 202310764468A CN 116503676 A CN116503676 A CN 116503676A
- Authority
- CN
- China
- Prior art keywords
- network
- incremental
- distillation
- sample
- category
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 55
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 34
- 238000004821 distillation Methods 0.000 claims abstract description 71
- 238000012549 training Methods 0.000 claims abstract description 51
- 230000006870 function Effects 0.000 claims description 36
- 230000008014 freezing Effects 0.000 claims description 18
- 238000007710 freezing Methods 0.000 claims description 18
- 239000013598 vector Substances 0.000 claims description 8
- 238000004590 computer program Methods 0.000 claims description 5
- 241000764238 Isis Species 0.000 claims description 2
- XLYOFNOQVPJJNP-UHFFFAOYSA-N water Substances O XLYOFNOQVPJJNP-UHFFFAOYSA-N 0.000 claims 1
- 230000000694 effects Effects 0.000 abstract description 2
- 238000012360 testing method Methods 0.000 description 4
- 238000001514 detection method Methods 0.000 description 3
- 230000008569 process Effects 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 2
- 230000007786 learning performance Effects 0.000 description 2
- 238000005259 measurement Methods 0.000 description 2
- 238000007781 pre-processing Methods 0.000 description 2
- 101100163879 Acremonium egyptiacum ascI gene Proteins 0.000 description 1
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- RTAQQCXQSZGOHL-UHFFFAOYSA-N Titanium Chemical compound [Ti] RTAQQCXQSZGOHL-UHFFFAOYSA-N 0.000 description 1
- 230000004913 activation Effects 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000013434 data augmentation Methods 0.000 description 1
- 230000003247 decreasing effect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000010438 heat treatment Methods 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 239000000463 material Substances 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000003014 reinforcing effect Effects 0.000 description 1
- 239000004576 sand Substances 0.000 description 1
- 238000004088 simulation Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Vaporization, Distillation, Condensation, Sublimation, And Cold Traps (AREA)
Abstract
本发明公开了一种基于知识蒸馏小样本增量学习的图片分类方法及系统,通过蒸馏网络判断输入图片所属类别,该方法利用预热网络计算类别原型,对于每个episode执行一个小样本分类任务;然后将预热网络的参数作为增量网络的初始值,计算新增类别的类别原型,对每个episode执行一个小样本增量任务;将预热网络和增量网络通过知识蒸馏形成蒸馏网络,将增量网络的参数作为蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;利用所述蒸馏网络计算相似度得到输入图像的所属类别;本发明通过预热、增量学习和知识蒸馏三个阶段减少小样本过拟合问题,缓解了增量学习中的类别遗忘问题,提高了小样本问题下的图片分类效果。
Description
技术领域
本发明涉及一种图片分类方法及系统,尤其是基于知识蒸馏小样本增量学习的图片分类方法及系统。
背景技术
随着人工智能技术的不断发展和应用,增量学习因其强大的适用性逐渐受到了学术界和工业界的关注。增量学习,指的是对于一个已经训练好的模型,在面临新数据时,不需要使用全部数据重新训练整个模型,而是渐进地对模型进行更新。通过不断修正和加强以前的知识,使得模型在新数据上具有泛化性。增量学习降低了模型训练过程中对时间和空间的需求,广泛应用于推荐系统、图片分类等领域中。当前大多数增量学习方法的训练需要大量新类样本,而在现实环境中,受到人力、物力和客观因素的制约,数据获取往往十分困难导致样本量稀少,这严重影响了传统增量学习方法的性能。
知识蒸馏作为一种重要的学习范式,通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息训练这个轻量化模型,以达到更好的性能和精度。其中来自大模型输出的监督信息称之为知识,而小模型学习迁移来自大模型的监督信息称之为蒸馏。然而,传统的知识蒸馏方法往往依赖于大量的训练样本。在小样本场景中,由于缺乏足够多的样本,新旧类别之间的样本数量差异较大,模型在训练或预测过程中往往倾向于更大的旧类训练样本集,容易造成严重的类别不平衡问题导致性能下降,基类与新类样本之间的不平衡也使得模型难以学习新类别。
发明内容
发明目的:本发明的目的是提供一种能够提高小样本学习性能的基于知识蒸馏小样本增量学习的图片分类方法;本发明的第二目的是提供一种能够提高小样本学习性能的基于知识蒸馏小样本增量学习的图片分类系统。
技术方案:本发明所述的基于知识蒸馏小样本增量学习的图片分类方法,通过蒸馏网络判断输入图片所属类别,包括如下步骤:
(1)将随机初始化的ResNet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;
(2)冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;
(3)冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;
(4)利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
进一步地,步骤(1)中利用所述预热网络计算类别原型包括:类别c的类别原型p c 为:
;
其中,S i c 表示小样本分类任务中支持集S i 中类别为c的数据集,|S i c |表示S i c 的大小,x t 为小样本分类任务中样本的特征向量,y t 为对应样本的标签,代表该样本所属的类别;为预热网络。
进一步地,步骤(1)中训练所述预热网络的预热损失函数L H 为:
;
其中,Q i 为小样本分类任务的查询集,x q 为查询集Q i 中的新样本,y q 为对应样本的标签,代表该样本所属的类别;为归一化分类函数,每个类别c的归一化分类分数为/>,/>为softmax函数;/>为权重,每个类别c的权重为/>,d c 为类别c中类别原型与其他同类别样本的距离和,m为S i 中除类别c以外的其他类别。
进一步地,步骤(2)中利用所述增量网络计算新增类别的类别原型包括:
新增类别c' 的类别原型p' c' 为:
;
其中,S_new j c' 表示小样本增量任务中增量支持集S_new j 中类别为c' 的数据集;|S_new j c' |表示S_new j c' 的大小,x t' 为小样本增量任务中样本的特征向量,y t' 为对应样本的标签,代表该样本所属的类别;为增量网络。
进一步地,步骤(2)中训练所述增量网络的增量损失函数L R 为:
;
其中,Q_new j 为小样本增量任务中的增量查询集,x q' 为增量查询集Q_new j 中的新样本,y q' 为对应样本的标签,代表该样本所属的类别;为权重,/>为增量网络;
;
Q_new j c' 为Q_new j 中类别为的数据集,x n 表示Q_new j c' 中类别为c'的其他样本,y n 为对应样本的标签。
进一步地,步骤(3)所述交叉迭代训练所述增量网络和所述蒸馏网络直至收敛前,利用蒸馏损失函数对所述增量网络进行训练直至收敛,所述蒸馏损失函数的计算方法为:
使用任务无关的数据集D u 进行蒸馏学习,根据D u 在预热网络和增量网络上的输出分布f θ (x u )和g φ (x u )分别计算蒸馏损失项:
;
;
蒸馏损失函数为;
其中为蒸馏网络,T为蒸馏温度系数,x u 为D u 中的样本,λ为参数。
进一步地,步骤(3)所述交叉迭代训练所述增量网络和所述蒸馏网络直至收敛包括:
利用增量损失函数更新增量网络参数,冻结增量网络;计算蒸馏损失函数并更新蒸馏网络参数和增量网络参数;冻结蒸馏网络,利用更新的增量网络参数重新计算增量损失函数,优化增量网络;
重复上述步骤训练所述增量网络和所述蒸馏网络直至收敛。
进一步地,步骤(4)包括以下内容:计算新增类别c' 的最终类别原型 为:
;
计算样本与每个最终类别原型之间的相似度,/>;利用上述公式计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
本发明所述基于知识蒸馏小样本增量学习的图片分类系统,用于通过蒸馏网络判断输入图片所属类别,包括:
预热网络模块,用于将随机初始化的ResNet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对于每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;
增量网络模块,用于冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;
蒸馏网络及交叉迭代模块,用于冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;
预测模块,用于利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
本发明所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现所述的基于知识蒸馏小样本增量学习的图片分类方法。
有益效果:与现有技术相比,本发明的优点在于:(1)提出了一个三阶段算法,通过预热、增量学习和知识蒸馏三个阶段有效地提升了模型性能;(2)将增量学习任务和小样本学习任务分割开,减少小样本对模型带来的过拟合问题;(3)基于任务无关数据集的知识蒸馏方法,有利于提高模型的可扩展性,有效缓解了增量学习中的类别遗忘问题;(4)交互迭代更新方法能够使目标函数进一步收敛,更适用于小样本情况下的模型训练,同时提高了小样本问题下的图片分类效果。
附图说明
图1为本发明的图片分类方法流程图。
图2为本发明的模型迭代训练阶段的流程示意图。
图3为本发明实施例的分类检测准确率对比图。
具体实施方式
下面结合附图对本发明的技术方案作进一步说明。
对于给定图片判别该图片所属类别的任务,可以看作使用训练集训练出一个高效的检测模型,之后使用该模型对图片进行分类检测。如图1所示,本发明所述的基于知识蒸馏小样本增量学习的图片分类方法,包括以下步骤:数据预处理阶段、模型迭代训练阶段、预测阶段。
(1)数据预处理阶段:
选择miniImageNet作为预热阶段和增量学习阶段的数据集,cifar10作为知识蒸馏阶段中与任务无关的数据集。
对于miniImageNet和cifar10数据集分别按照8:2的比例划分为训练集和测试集。
将测试集中的图片使用中心裁剪方法,以图像中心点为参照,按照224×224像素大小从外向内进行裁剪。对于小于指定尺寸的图片,在原始图像外侧填充0,再进行中心裁剪。
对训练集中的样本进行数据增广。具体地,使用水平翻转方法基于随机的概率水平翻转图片,改变部分图片方向;然后使用图像抖动方法,随机改变图像的亮度、对比度、锐度和饱和度,让特征值在随机因子范围[10,30]内随机变换,增加样本多样性;之后使用图片旋转方法,不改变图片大小、亮度等特征的基础上得到新的数据,使图片分别旋转至、/>和/>,增大样本数量;最后使用随机裁剪方法,将图片随机裁剪为不同大小并改变宽高比,然后缩放至224×224像素大小。
(2)模型迭代训练阶段,优化嵌入网络,如图2所示。
(2.1)将预处理后的训练集作为模型的输入,随机在miniImageNet数据集中选取60个类别作为预热阶段的基础数据集,首先随机初始化的ResNet18作为预热网络,使用预热网络计算类别原型;具体步骤如下:使用基于任务的episode训练策略,对于每个episode,从基础训练集中随机抽取N个类,在每类都分别抽取K个样本组成支持集S,然后再从这N个类中剩余的样本抽取一部分数据作为查询集Q,构成的分类问题被称为N-way K-shot小样本任务,整体的训练任务由若干个小样本任务构成。对于每个episode执行一个小样本任务T i = {S i ,Q i },T i 为预热阶段第i个子任务,S i 为子任务中的支持集,Q i 为子任务中的查询集,计算类别c的原型为:
;
其中,p c 表示在特征空间中样本类别为c的类别原型,S i c 表示支持集S i 中类别为c的数据集,|S i c | 表示S i c 的大小,x t 为样本的特征向量,y t 为对应样本的标签。
(2.2)计算每个类别中原型与其他同类样本的距离之和:
;
根据每个类别中原型与同类样本距离计算类别的权重:
;
其中为softmax函数,m为属于S i 数据集的其他类别。
(2.3)对于来自查询集Q i 中的新样本x q ,利用如下的距离判别得到每个类别c的归一化分类分数:
;
其中为softmax函数。
(2.4)指定预热损失函数L H 为:
;
其中,x q 为属于查询集Q i 的样本,y q 为对应样本的标签,代表该样本所属的类别;使用上述损失函数对预热网络进行迭代训练直至模型收敛。
(2.5)使用miniImageNet数据集中剩余的40个类别作为增量阶段新类数据集,分8次逐步加入,每次增量学习任务新加入5个类别,每个类别随机采样K个样本。构建增量学习网络,首先冻结经过预热训练后的参数θ,并使用其作为增量网络的初始值;对于一个小样本增量任务/>,/>为增量学习阶段的第j个子任务,S_new j 为增量子任务中的支持集,Q_new j 为增量子任务中的查询集,计算新增类别c'的原型为:
;
其中,p' c' 表示在特征空间中样本类别为c' 的类别原型,S_new j c' 表示增量支持集S_new j 中类别为c' 的数据集,|S_new j c' |表示数据集S_new j c' 的大小,x t' 为样本的特征向量,y t' 为对应样本的标签。
(2.6)对于来自增量查询集Q_new j 中的新样本x q' ,根据每个样本到所属类别原型p' c' 的距离计算样本的权重:
;
其中表示标签为c'的样本x q' 的权重值,/>为softmax函数,x n 表示类别为c'的其他样本,y n 为对应样本的标签。
(2.7)构建增量损失函数L R :
;
其中,x q' 为属于增量查询集的样本,y q' 为对应样本的标签;使用上述损失函数对增量网络进行迭代训练直至模型收敛。
(2.8)使用cifar10作为知识蒸馏阶段的任务无关数据集,随机选取10个类别,每个类别随机选择1000张图片。构建蒸馏网络,首先冻结预热网络和增量网络的参数,并拷贝训练后的增量网络的参数作为蒸馏网络/>的初始值;使用任务无关的数据集/>进行蒸馏学习;根据其在预热网络和增量网络上的输出分布f θ (x u )和g φ (x u ),其中x u ∈D u ,分别计算蒸馏损失项:
;
;
其中,为softmax函数,T为蒸馏温度系数。
(2.9)用参数调整新旧类别比例并作累加,计算蒸馏损失函数为L KD :
;
在实验中设置λ = 0.1;使用上述损失函数对增量网络进行迭代训练直至模型收敛。
(2.10)在模型训练过程中,使用交叉迭代网络更新方法,具体包括:
首先使用增量损失函数L R 更新增量网络中的参数φ,之后冻结增量网络,计算得到蒸馏损失函数L KD ,对蒸馏网络/>中的参数σ和增量网络中的φ进行更新,接下来冻结蒸馏网络,根据更新后的参数φ得到新的增量损失函数L R ,进一步优化增量网络,重复上述交叉迭代网络更新步骤直至训练函数收敛。
(3)预测阶段:
(3.1)将预处理后的miniImageNet新类数据集中的测试集数据作为模型输入,使用训练后的蒸馏网络来计算样本的特征向量,计算每个类对应的支持集样本的平均值作为该类的原型:
;
其中,为特征空间中样本类别为c' 的最终类别原型。
(3.2)通过小样本图像分类函数计算测试样本与每个类别原型之间的相似度,最后得到相似度最高的类作为最终检测结果,小样本图像分类函数为:
。
通过仿真实验对本发明所述的基于知识蒸馏小样本增量学习的图片分类方法进行验证,使用python实现所述的模型训练方法与测试方法,并与iCaRL、EEIL、TOPIC等小样本增量学习方法对比,在miniImageNet数据集5-way 5-shot任务下对比结果如图3所示。所有的程序都是在配有Intel Core i7-8700 CPU,3.20GHz,32 GBRAM和NVIDIA TITAN RTX的标准服务器上执行的,采用激活函数为ReLu函数的ResNet18神经网络, 设置优化器为Adam。在预热阶段和增量学习阶段中,使用0.1作为初始学习率,并在训练过程中使之逐步递减为原值的十分之一。在知识蒸馏阶段学习率固定为0.001,迭代20轮后停止。从图3中可以看出,本发明所述的基于知识蒸馏小样本增量学习的图片分类方法的分类识别准确率比其他方法取得了较大程度的领先,相较于TOPIC算法提高了10%左右的最终分类准确率,表现出了更适合小样本学习这一特殊任务的优越性,显著高效地提升了模型性能。
本发明所述基于知识蒸馏小样本增量学习的图片分类系统,用于通过蒸馏网络判断输入图片所属类别,包括:
预热网络模块,用于将随机初始化的ResNet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对于每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;
增量网络模块,用于冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;
蒸馏网络及交叉迭代模块,用于冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;
预测模块,用于利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
本发明所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现所述的基于知识蒸馏小样本增量学习的图片分类方法。
所述计算机可读存储媒体可包括RAM、ROM、EEPROM、CD-ROM 或其它光盘存储装置、磁盘存储装置或其它磁性存储装置、快闪存储器或可用来存储指令或数据结构的形式的所要程序代码并且可由计算机存取的任何其它媒体。
处理器用于执行存储器存储的计算机程序,以实现上述实施例涉及的方法中的各个步骤。
Claims (10)
1.一种基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,通过蒸馏网络判断输入图片所属类别,包括如下步骤:
(1)将随机初始化的ResNet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;
(2)冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;
(3)冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;
(4)利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
2.根据权利要求1所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(1)中利用所述预热网络计算类别原型包括:
类别c的类别原型p c 为:
;
其中,S i c 表示小样本分类任务中支持集S i 中类别为c的数据集,|S i c | 表示S i c 的大小, x t 为小样本分类任务中样本的特征向量,y t 为对应样本的标签,代表该样本所属的类别;为预热网络。
3.根据权利要求2所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(1)中训练所述预热网络的预热损失函数L H 为:
;其中,Q i 为小样本分类任务的查询集,x q 为查询集Q i 中的新样本,y q 为对应样本的标签,代表该样本所属的类别;为归一化分类函数,每个类别c的归一化分类分数为,/>为softmax函数;/>为权重,每个类别c的权重为/>,d c 为类别c中类别原型与其他同类别样本的距离和,m为S i 中除类别c以外的其他类别。
4.根据权利要求1所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(2)中利用所述增量网络计算新增类别的类别原型包括:
新增类别c' 的类别原型p' c' 为:
;
其中,S_new j c' 表示小样本增量任务中增量支持集S_new j 中类别为c'的数据集;表示S_new j c' 的大小,x t' 为小样本增量任务中样本的特征向量,y t' 为对应样本的标签,代表该样本所属的类别;/>为增量网络。
5.根据权利要求4所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(2)中训练所述增量网络的增量损失函数L R 为:
;
其中,Q_new j 为小样本增量任务中的增量查询集,x q' 为增量查询集Q_new j 中的新样本,y q' 为对应样本的标签,代表该样本所属的类别;为权重,/>为增量网络;
;
Q_new j c' 为Q_new j 中类别为的数据集,x n 表示Q_new j c' 中类别为c' 的其他样本, y n 为对应样本的标签;/>为softmax函数。
6.根据权利要求5所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(3)所述交叉迭代训练所述增量网络和所述蒸馏网络直至收敛前,利用蒸馏损失函数对所述增量网络进行训练直至收敛,所述蒸馏损失函数的计算方法为:
使用任务无关的数据集D u 进行蒸馏学习,根据D u 在预热网络和增量网络上的输出分布f θ (x u )和g φ (x u )分别计算蒸馏损失项:
;
;
蒸馏损失函数为;
其中为蒸馏网络,/>为softmax函数,T为蒸馏温度系数,x u 为D u 中的样本,λ为参数。
7.根据权利要求6所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(3)所述交叉迭代训练所述增量网络和所述蒸馏网络直至收敛包括:
利用增量损失函数更新增量网络参数,冻结增量网络;计算蒸馏损失函数并更新蒸馏网络参数和增量网络参数;冻结蒸馏网络,利用更新的增量网络参数重新计算增量损失函数,优化增量网络;
重复上述步骤训练所述增量网络和所述蒸馏网络直至收敛。
8.根据权利要求1所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(4)包括以下内容:
计算新增类别c' 的最终类别原型 为:
;
计算样本与每个最终类别原型之间的相似度,/>;
其中为蒸馏网络,S_new j c' 表示小样本增量任务中增量支持集S_new j 中类别为c' 的数据集,/>表示S_new j c' 的大小,x t' 为小样本增量任务中样本的特征向量,y t' 为对应样本的标签,代表该样本所属的类别;
利用上述公式计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
9.一种基于知识蒸馏小样本增量学习的图片分类系统,其特征在于,用于通过蒸馏网络判断输入图片所属类别,包括:
预热网络模块,用于将随机初始化的ResNet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;
增量网络模块,用于冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;
蒸馏网络及交叉迭代模块,用于冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;
预测模块,用于利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现根据权利要求1-8任一项所述的基于知识蒸馏小样本增量学习的图片分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310764468.7A CN116503676B (zh) | 2023-06-27 | 2023-06-27 | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310764468.7A CN116503676B (zh) | 2023-06-27 | 2023-06-27 | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116503676A true CN116503676A (zh) | 2023-07-28 |
CN116503676B CN116503676B (zh) | 2023-09-22 |
Family
ID=87328759
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310764468.7A Active CN116503676B (zh) | 2023-06-27 | 2023-06-27 | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116503676B (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116910571A (zh) * | 2023-09-13 | 2023-10-20 | 南京大数据集团有限公司 | 一种基于原型对比学习的开集域适应方法及系统 |
CN117011672A (zh) * | 2023-09-27 | 2023-11-07 | 之江实验室 | 基于类特定元提示学习的小样本类增对象识别方法和装置 |
CN117195951A (zh) * | 2023-09-22 | 2023-12-08 | 东南大学 | 一种基于架构搜索和自知识蒸馏的学习基因继承方法 |
CN117975203A (zh) * | 2024-04-02 | 2024-05-03 | 山东大学 | 基于数据增强的小样本图像类增量学习方法及系统 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114298160A (zh) * | 2021-12-07 | 2022-04-08 | 浙江大学 | 一种基于孪生知识蒸馏与自监督学习的小样本分类方法 |
CN114492745A (zh) * | 2022-01-18 | 2022-05-13 | 天津大学 | 基于知识蒸馏机制的类增量辐射源个体识别方法 |
WO2023040147A1 (zh) * | 2021-09-14 | 2023-03-23 | 上海商汤智能科技有限公司 | 神经网络的训练方法及装置、存储介质和计算机程序 |
-
2023
- 2023-06-27 CN CN202310764468.7A patent/CN116503676B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2023040147A1 (zh) * | 2021-09-14 | 2023-03-23 | 上海商汤智能科技有限公司 | 神经网络的训练方法及装置、存储介质和计算机程序 |
CN114298160A (zh) * | 2021-12-07 | 2022-04-08 | 浙江大学 | 一种基于孪生知识蒸馏与自监督学习的小样本分类方法 |
CN114492745A (zh) * | 2022-01-18 | 2022-05-13 | 天津大学 | 基于知识蒸馏机制的类增量辐射源个体识别方法 |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116910571A (zh) * | 2023-09-13 | 2023-10-20 | 南京大数据集团有限公司 | 一种基于原型对比学习的开集域适应方法及系统 |
CN116910571B (zh) * | 2023-09-13 | 2023-12-08 | 南京大数据集团有限公司 | 一种基于原型对比学习的开集域适应方法及系统 |
CN117195951A (zh) * | 2023-09-22 | 2023-12-08 | 东南大学 | 一种基于架构搜索和自知识蒸馏的学习基因继承方法 |
CN117195951B (zh) * | 2023-09-22 | 2024-04-16 | 东南大学 | 一种基于架构搜索和自知识蒸馏的学习基因继承方法 |
CN117011672A (zh) * | 2023-09-27 | 2023-11-07 | 之江实验室 | 基于类特定元提示学习的小样本类增对象识别方法和装置 |
CN117011672B (zh) * | 2023-09-27 | 2024-01-09 | 之江实验室 | 基于类特定元提示学习的小样本类增对象识别方法和装置 |
CN117975203A (zh) * | 2024-04-02 | 2024-05-03 | 山东大学 | 基于数据增强的小样本图像类增量学习方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN116503676B (zh) | 2023-09-22 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116503676B (zh) | 一种基于知识蒸馏小样本增量学习的图片分类方法及系统 | |
CN107122809B (zh) | 基于图像自编码的神经网络特征学习方法 | |
CN108681752B (zh) | 一种基于深度学习的图像场景标注方法 | |
CN107392919B (zh) | 基于自适应遗传算法的灰度阈值获取方法、图像分割方法 | |
CN112699247A (zh) | 一种基于多类交叉熵对比补全编码的知识表示学习框架 | |
CN110929848B (zh) | 基于多挑战感知学习模型的训练、跟踪方法 | |
CN112685504B (zh) | 一种面向生产过程的分布式迁移图学习方法 | |
CN102521656A (zh) | 非平衡样本分类的集成迁移学习方法 | |
CN103116766A (zh) | 一种基于增量神经网络和子图编码的图像分类方法 | |
CN114092742B (zh) | 一种基于多角度的小样本图像分类装置和方法 | |
CN112509017B (zh) | 一种基于可学习差分算法的遥感影像变化检测方法 | |
CN109829414A (zh) | 一种基于标签不确定性和人体组件模型的行人再识别方法 | |
CN117934890B (zh) | 基于局部和全局邻居对齐的原型对比图像聚类方法及系统 | |
CN110399917B (zh) | 一种基于超参数优化cnn的图像分类方法 | |
CN116226689A (zh) | 一种基于高斯混合模型的配电网典型运行场景生成方法 | |
CN114492581A (zh) | 基于迁移学习和注意力机制元学习应用在小样本图片分类的方法 | |
CN117315534A (zh) | 一种基于vgg-16和鲸鱼优化算法的短视频分类方法 | |
CN116894948A (zh) | 基于不确定性引导的半监督图像分割方法 | |
CN111783688A (zh) | 一种基于卷积神经网络的遥感图像场景分类方法 | |
CN116630718A (zh) | 一种基于原型的低扰动的图像类增量学习算法 | |
CN116523877A (zh) | 一种基于卷积神经网络的脑mri图像肿瘤块分割方法 | |
CN112446432B (zh) | 基于量子自学习自训练网络的手写体图片分类方法 | |
CN114037866B (zh) | 一种基于可辨伪特征合成的广义零样本图像分类方法 | |
CN115100694A (zh) | 一种基于自监督神经网络的指纹快速检索方法 | |
CN113989567A (zh) | 垃圾图片分类方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |