CN113554078B - 一种基于对比类别集中提升连续学习下图分类精度的方法 - Google Patents
一种基于对比类别集中提升连续学习下图分类精度的方法 Download PDFInfo
- Publication number
- CN113554078B CN113554078B CN202110788454.XA CN202110788454A CN113554078B CN 113554078 B CN113554078 B CN 113554078B CN 202110788454 A CN202110788454 A CN 202110788454A CN 113554078 B CN113554078 B CN 113554078B
- Authority
- CN
- China
- Prior art keywords
- model
- sample
- training
- class
- classification
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于对比类别集中提升连续学习下图分类精度的方法。该方法用于对已经过历史数据训练的图分类模型进行类增长学习,具体步骤如下:S1:获取加入新类别的图像分类数据集,图像分类数据集中每个样本均带有其类别标签;S2:获取所述图分类模型在上一轮训练过程中进行参数更新前的旧模型和参数更新后的新模型,然后利用加入新类别的图像分类数据集构建训练数据,进行本轮训练;S3:保存本轮参数更新前的旧模型和参数更新后的新模型,并在进行下一轮训练之前利用本轮参数更新后的新模型进行图分类任务。本发明结合对比学习和知识蒸馏的思想,能够帮助模型学习到更加聚合的数据表征,从而缓解表征覆盖,帮助模型减少灾难性遗忘。
Description
技术领域
本发明涉及计算机视觉处理,尤其涉及一种基于对比类别集中提升连续学习下图分类精度的方法。
背景技术
类增长学习作为连续学习的一种,在工业界以及学界都受到了越来越多的关注。其学习过程不像传统的深度学习范式一样使用全部图片数据来训练模型,而是基于新加入类别数据来连续的更新模型参数且在更新的过程中不再使用过去的数据的学习方式,其更贴近企业数据的更新和淘汰情况。然而,单纯的基于新类别数据微调现有模型会导致模型对于老类别图片的分类精度有明显的下降。我们把这种现象叫做“灾难性遗忘”。
现阶段尝试阻止灾难性遗忘的主要策略是通过基于知识蒸馏的方法来保存过去模型所获得的知识,并放宽限制,设置有限的内存容量,在有限的内存中保存少量过去的类别图片数据来减缓灾难性遗忘。然而,在使用知识蒸馏的过程中,我们发现模型的表征在连续学习的过程中会发生新类别覆盖旧类别的现象。我们称其为“表征覆盖”。而这样的表征覆盖则很大程度上影响了模型在连续学习过程中对新老类别图片的分类精度。我们认为影响类增长学习过程中模型表征覆盖的原因主要有两点。其中一点是受限的内存容量,其次则是模型在连续学习过程中对相同类别表征的聚类能力。当内存容量增加时,更多的过去数据可以被保存,在新类别数据加入时,可以通过这些老类别来缓解表征层面的覆盖。然而由于内存容量受到限制,并且带入内存还会导致新老类别数据不平衡以及老类别数据利用率低等问题。这也使得许多基于内存设置的类增长学习方法不能用在无内存的基础框架上。
发明内容
本发明的目的是为了克服现有技术的不足,提供一种基于对比类别集中提升连续学习下图分类精度的方法。
本发明具体通过以下技术方案实现:
一种基于对比类别集中提升连续学习下图分类精度的方法,用于对已经过历史数据训练的图分类模型进行类增长学习,所述图分类模型由特征提取器和分类器组成,其具体方法如下:
S1:获取加入新类别的图像分类数据集,图像分类数据集中每个样本均带有其类别标签;
S2:获取所述图分类模型在上一轮训练过程中进行参数更新前的旧模型和参数更新后的新模型,然后利用加入新类别的图像分类数据集构建训练数据,按照 S21~S25进行本轮训练;
S21:将训练数据输入所述新模型的特征提取器中提取第一样本表征,在真实类别标签监督下计算对比误差损失,用于聚合相同类别样本的表征同时拉远不同类别样本的表征;
S22:将训练数据输入所述旧模型的特征提取器中提取第二样本表征,对所述第一样本表征和第二样本表征进行表征层面的蒸馏并计算表征蒸馏误差损失,用于从表征层面保存模型知识的同时保证表征空间相对稳定;
S23:基于所述的第一样本表征,通过所述新模型的分类器进行类别预测得到第一预测类别概率,并计算交叉熵损失;
S24:基于所述的第一样本表征,通过所述旧模型的分类器进行类别预测得到第二预测类别概率,并计算第一预测类别概率和第二预测类别概率之间的均方误差损失;
S25:以所述对比误差损失、表征蒸馏误差损失、交叉熵损失和均方误差损失之和作为对比类别集中误差损失,通过最小化该对比类别集中误差对所述新模型进行参数优化更新;
S3:保存本轮参数更新前的旧模型和参数更新后的新模型,并在进行下一轮训练之前利用本轮参数更新后的新模型进行图分类任务。
作为优选,通过内存存储所述图分类模型在训练过程中输入的不同类别的训练样本;其中每一轮训练开始时,将加入新类别的图像分类数据集与内存中存储的所有不同类别的训练样本一起作为训练数据对模型进行训练,每一轮训练结束后,通过随机采样的方式从本轮所输入的训练数据中选择部分属于新类别的样本放入内存中,同时从内存中的每个旧类别中移除固定数量的样本,从而保持内存总容量不变的同时增加类别数。
作为优选,所述训练数据在用于本轮训练之前,预先经过数据增广以扩充样本量。
进一步的,所述数据增广方式包括随机的切割、翻转、对比度和灰度调整。
在上述任一技术方案的基础上,作为进一步优选,所述S21的具体实现步骤如下:
S211:将当前第t轮训练对应的训练数据D(t)输入所述新模型M(t)中,使用内部的特征提取器Eθ (t)获得训练数据D(t)中每一个样本xi (t)的特征ri (t),所有样本的特征构成第一样本表征R(t);
S212:基于第一样本表征R(t)中每个特征所对应的类别标签,计算每个特征与其他特征之间的相似性;所述相似性使用点乘之后取二范数来表示,任意两个特征ri (t)和rj (t)之间的相似性s(xi (t),xj (t))计算公式为:
式中:||·||表示二范数;
S213:将相同类别样本的特征视为正样本,不同类别样本的特征作为负样本,通过计算对比误差损失来聚合相同类别样本的表征同时拉远不同类别样本的表征;所述对比误差损失函数公式为:
其中:B为输入的训练数据D(t)中的样本个数,为训练数据D(t)中除样本xi (t)之外的其余样本的索引集合。
进一步的,所述S22的具体实现步骤包括:
S221:将当前第t轮训练对应的训练数据D(t)输入所述旧模型M(t-1)中,使用内部的特征提取器Eθ (t-1)获得训练数据D(t)中每一个样本xi (t)的特征ri (t-1),所有样本的特征构成第二样本表征R(t-1);
S222:结合所述第一样本表征R(t)和第二样本表征R(t-1)进行表征层面的蒸馏,计算表征蒸馏误差损失计算公式为:
其中:表示训练数据D(t)中所有样本的索引集合。
进一步的,所述S23的具体实现步骤包括:
S231:将前述的第一样本表征R(t)输入前述新模型M(t)中的分类器中进行分类预测,获得当前训练数据中各样本的第一预测类别概率;
S232:结合训练数据D(t)中各样本的第一预测类别概率和真实标签,计算交叉熵损失
进一步的,所述S24的具体实现步骤包括:
S241:将第二样本表征R(t-1)输入所述旧模型M(t-1)中的分类器中进行分类预测,获得当前训练数据的第二预测类别概率;
S242:基于训练数据D(t)中各样本的第一预测类别概率和第二预测类别概率, 计算均方误差损失计算公式为:
式中:和/>分别表示样本xi (t)的第一预测类别概率和第二预测类别概率,n(t)表示引入本轮类增长学习后的图片分类类别总数,MSE表示计算均方误差。
进一步的,所述S25的具体实现步骤包括:
S251:将S21~S24中的四种损失进行加和,得到总的对比类别集中误差损失
S252:以最小化所述对比类别集中误差损失为目标,对所述新模型进行网络参数优化,完成本轮模型训练。
作为优选,所述特征提取器为ResNet32网络。
相对于现有技术而言,本发明具有以下有益效果:
传统的类增长学习都是通过知识蒸馏的方式减缓灾难性遗忘来提高模型在图片上的分类精度,或者通过更好的优化设置内存带来的数据不平衡和数据利用率问题来减缓灾难性遗忘从而提高模型的图片分类精度。而本发明则另辟蹊径,从模型的表征层面着手,通过利用对比学习所特有的类集中方式利用类别信息进行表征聚类从而缓解了表征覆盖提高了模型的分类精度。本发明的方法不受内存数据影响,实验表明不论是在有内存的设置下还是在无内存的设置下,该方法都提高了模型在连续学习下的精度。
附图说明
图1为基于对比类别集中来提高模型在连续学习中的图片分类精度的方法的流程图。
图2为本发明在连续学习过程中的表征分布与之前方法的对比结果。
图3为一种在有内存的设置下结合数据增广的类增长学习框架示意图。
图4为本发明与之前方法在表征层面的类覆盖情况对比结果。
具体实施方式
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图对本发明的具体实施方式做详细的说明。
本发明提供了一种基于对比类别集中提升连续学习下图分类精度的方法,适用于在模型应用过程中依然不断有加入新类别的数据产生状态下的类增长学习。该方法既可以用于有内存的设置下对图分类模型进行类增长学习,也可以用于无内存的设置下对图分类模型进行类增长学习。需说明的是,本发明中的图分类模型基本结构由特征提取器和分类器组成,具体的网络结构不限。在本发明的后续实施例中前述特征提取器采用ResNet32网络。
下面分别对两种方式的具体实现过程进行展开说明。
如图1所示,在本发明的一个较佳实施例中,提供了一种基于对比类别集中提升连续学习下图分类精度的方法,用于无内存的设置下对已经过历史数据训练的图分类模型进行类增长学习,其具体过程如下:
S1:从新产生的数据流中获取加入新类别的图像分类数据集,图像分类数据集是由众多图像样本组成的,每个样本均带有其类别标签,作为模型训练的真值。加入新类别的图像分类数据集中存在原来的图分类模型中没有的图片类别,因此需要重新训练图分类模型。
S2:由于该图分类模型在本轮训练之前已经经过了训练,因此先获取图分类模型在上一轮训练过程中进行参数更新前的旧模型和参数更新后的新模型,旧模型和新模型的网络结构是完全相同的,其区别仅在于新模型中的网络参数在上一轮训练过程中被更新了。当获取新模型和旧模型后,利用加入新类别的图像分类数据集直接作为本轮训练所需的训练数据。
当得到本轮的训练数据后,即可按照S21~S25进行当前轮训练,为了便于描述我们将当前的训练轮数记为第t轮,第t轮训练对应的训练数据记为D(t),新模型记为M(t),旧模型记为M(t-1)。下面对S21~S25的具体训练过程详述如下:
S21:将训练数据输入前述新模型的特征提取器中提取第一样本表征,在真实类别标签监督下计算对比误差损失,用于聚合相同类别样本的表征同时拉远不同类别样本的表征。在本实施例中,该步骤的具体实现过程如S211~S213所示:
S211:将当前第t轮训练对应的训练数据D(t)输入前述新模型M(t)中,使用内部的特征提取器Eθ (t)获得训练数据D(t)中每一个样本xi (t)的特征ri (t),所有样本的特征构成第一样本表征R(t);
S212:基于第一样本表征R(t)中每个特征所对应的类别标签,计算每个特征与其他特征之间的相似性;前述相似性使用点乘之后取二范数来表示,任意两个特征ri (t)和rj (t)之间的相似性s(xi (t),xj (t))计算公式为:
式中:||·||表示二范数;
S213:将相同类别样本的特征视为正样本,不同类别样本的特征作为负样本,通过计算对比误差损失来聚合相同类别样本的表征同时拉远不同类别样本的表征;前述对比误差损失函数公式为:
其中:B为输入的训练数据D(t)中的样本个数,为训练数据D(t)中除样本xi (t)之外的其余样本的索引集合。
S22:将训练数据输入前述旧模型的特征提取器中提取第二样本表征,对前述第一样本表征和第二样本表征进行表征层面的蒸馏并计算表征蒸馏误差损失,用于从表征层面保存模型知识的同时保证表征空间相对稳定。在本实施例中,该步骤的具体实现过程如S221~S222所示:
S221:将当前第t轮训练对应的训练数据D(t)输入前述旧模型M(t-1)中,使用内部的特征提取器Eθ (t-1)获得训练数据D(t)中每一个样本xi (t)的特征ri (t-1),所有样本的特征构成第二样本表征R(t-1);
S222:结合前述第一样本表征R(t)和第二样本表征R(t-1)进行表征层面的蒸馏,计算表征蒸馏误差损失来减少表征的相对空间在连续学习过程中的变化,从而更好的缓解了模型的在表征层面的表征覆盖,计算公式为:
其中:表示训练数据D(t)中所有样本的索引集合。
S23:基于前述的第一样本表征,通过前述新模型的分类器进行类别预测得到第一预测类别概率,并计算交叉熵损失。在本实施例中,该步骤的具体实现过程如S231~S232所示:
S231:将前述的第一样本表征R(t)输入前述新模型M(t)中的分类器中进行分类预测,获得当前训练数据中各样本的第一预测类别概率;
S232:结合训练数据D(t)中各样本的第一预测类别概率和真实标签,计算交叉熵损失
S24:基于前述的第一样本表征,通过前述旧模型的分类器进行类别预测得到第二预测类别概率,并计算第一预测类别概率和第二预测类别概率之间的均方误差损失。在本实施例中,该步骤的具体实现过程如S241~S242所示:
S241:将第二样本表征R(t-1)输入前述旧模型M(t-1)中的分类器中进行分类预测,获得当前训练数据的第二预测类别概率;
S242:基于训练数据D(t)中各样本的第一预测类别概率和第二预测类别概率, 计算均方误差损失来保存新模型的分类器对过去类别的知识,计算公式为:
式中:和/>分别表示样本xi (t)的第一预测类别概率和第二预测类别概率,n(t)表示引入本轮类增长学习后的图片分类类别总数,MSE表示计算均方误差。
S25:以前述对比误差损失、表征蒸馏误差损失、交叉熵损失和均方误差损失之和作为对比类别集中误差损失,通过最小化该对比类别集中误差对前述新模型进行参数优化更新。在本实施例中,该步骤的具体实现过程如S251~S252所示:
S251:将S21~S24中的四种损失进行加和,得到总的对比类别集中误差损失
S252:以最小化前述对比类别集中误差损失为目标,对前述新模型进行网络参数优化,完成本轮模型训练。
S3:保存本轮参数更新前的旧模型和参数更新后的新模型,并在进行下一轮训练之前利用本轮参数更新后的新模型进行图分类任务。
需注意的是,上述S1~S3的类增长学习过程是可以不断循环进行的,当后续的数据流中出现新类别的数据,且样本量达到一定数量后即可再次执行上述 S1~S3的过程,进而进行新一轮的模型训练,实现对新类别的适应。
上述S1~S3的过程是没有内存设置情况下的类增长学习过程,但是加入在有内存设置的情况下,则需要在构建每一轮的训练数据的时候将内存中存储的历史样本一并纳入训练数据范围。
如图2所示,在本发明的另一个较佳实施例中,展示了另一种基于对比类别集中提升连续学习下图分类精度的方法,用于有内存的设置下对已经过历史数据训练的图分类模型进行类增长学习,其每一轮的训练过程S21~S25与前一实施例基本类似,区别仅仅在于S1和S3步骤。在图分类模型的历次训练过程中,需要通过内存存储图分类模型在训练过程中输入的不同类别的训练样本。而在内存的设置下,执行S1步骤过程的区别在于:每一轮训练开始前构建训练数据时,除了将加入新类别的图像分类数据集一起作为训练数据之外,还需要将内存中存储的所有不同类别的训练样本也一并加入训练数据,这两部分一起作为训练数据对模型进行训练;对应的,执行S3步骤过程的区别在于:每一轮训练结束后,还需要通过随机采样的方式从本轮所输入的训练数据中选择部分属于新类别的样本放入内存中,同时从内存中的每个旧类别中移除固定数量的样本,从而保持内存总容量不变的同时增加类别数。
上述图1和图2中所展示的方法,从模型的表征层面着手,通过利用对比学习所特有的类集中方式利用类别信息进行表征聚类,因此均可以提高模型在连续学习下的精度。
另外需注意的是,不论是有内存的设置还是无内存的设置下构建训练数据时,若构建得到的训练数据D(t)的样本量足够多,则可以直接将其作为本轮训练所用的训练数据输入模型即可;但是如果构建得到的训练数据D(t)中样本量不足,那么可以考虑对D(t)中的已有样本数据先预先经过数据增广获得对应的增广数据,然后再将增广数据加入原训练数据D(t)一起作为增广后的训练数据(记为) 用于本轮训练,由此实现样本量的扩充。数据增广方式包括随机的切割、翻转、对比度和灰度调整。
如图3所示,展示了一种在有内存的设置下结合数据增广的类增长学习框架,图中W表示分类器中的权重参数,其上标表示训练轮数。在该框架下,若从数据流中获取的加入新类别的图像分类数据集为D(t),而从内存中获取的样本集合记为Dmem,那么两者结合后记得到数据集D(t*),数据集D(t*)再经过数据增广,最终得到样本扩充后的在该框架下,最终以数据集/>作为本轮训练输入模型的训练数据,即/>
为了验证本发明方法在类别分类上的效果,下面使用cifar-100数据集对上述内存的设置还是无内存的设置下的方法进行图片分类精度的测试。最后使用模型在连续学习过程中对过去数据的分类精度和之前所有数据的平均精度作为评价指标。该测试时由于样本量充足因此未做数据增广。另外,为了对于其他现有技术方法与本发明的区别,同步引入了四种现有技术中的主流类增长学习方法进行相同的图片分类任务。其中各主流方法的具体实现方式可参见以下现有技术论文:
[1]Zhizhong Li and Derek Hoiem.Learning without forgetting.IEEEtransactions on patternanalysis and machine intelligence,40(12):2935–2947,2017.
[2]Francisco M Castro,Manuel Jas Guil,CordeliaSchmid,and Karteek Ala-hari.End-to-end incremental learning.InProceedingsofthe European conference on computer vision(ECCV),pages 233–248,2018.
[3]Yue Wu,Yinpeng Chen,Lijuan Wang,Yuancheng Ye,Zicheng Liu,YandongGuo,and Yun Fu.Large scale incremental learning.InProceedings of theIEEEConference on Computer Vision and Pattern Recog-nition,pages 374–382,2019
[4]Sylvestre-Alvise Rebuffi,AlexanderKolesnikov,Georg Sperl,andChristoph H Lampert.icarl:Incremental classifier and representationlearning.InPro-ceedings of the IEEE conference on Computer Vision andPatternRecognition,pages 2001–2010, 2017
图4为本发明与之前方法在表征层面的类覆盖情况对比,另外测试精度的量化结果如下表所示:
方法 | 2-phase | 5-phase | 10-phase | 20-phase |
LWF[1](无内存设置方法) | 52.80 | 47.15 | 39.90 | 29.64 |
C4IL.NoMem(本发明在无内存设置下) | 56.04 | 51.79 | 44.04 | 34.14 |
ETE[2] | 61.15 | 63.45 | 63.61 | 63.28 |
BiC[3] | 64.31 | 64.50 | 62.99 | 62.22 |
iCaRL[4] | 62.09 | 63.31 | 61.52 | 59.70 |
C4IL.Mem(本发明在有内存设置下) | 65.02 | 66.25 | 66.79 | 66.28 |
从上述结果可见,不论有无内存,本发明基于对比类别集中来提高模型在连续学习中的图片分类精度的方法均明显优于现有主流方法。
以上所述的实施例只是本发明的一种较佳的方案,然其并非用以限制本发明。有关技术领域的普通技术人员,在不脱离本发明的精神和范围的情况下,还可以做出各种变化和变型。因此凡采取等同替换或等效变换的方式所获得的技术方案,均落在本发明的保护范围内。
Claims (10)
1.一种基于对比类别集中提升连续学习下图分类精度的方法,用于对已经过历史数据训练的图分类模型进行类增长学习,所述图分类模型由特征提取器和分类器组成,其特征在于:
S1:获取加入新类别的图像分类数据集,图像分类数据集中每个样本均带有其类别标签;
S2:获取所述图分类模型在上一轮训练过程中进行参数更新前的旧模型和参数更新后的新模型,然后利用加入新类别的图像分类数据集构建训练数据,按照S21~S25进行本轮训练;
S21:将训练数据输入所述新模型的特征提取器中提取第一样本表征,在真实类别标签监督下计算对比误差损失,用于聚合相同类别样本的表征同时拉远不同类别样本的表征;
S22:将训练数据输入所述旧模型的特征提取器中提取第二样本表征,对所述第一样本表征和第二样本表征进行表征层面的蒸馏并计算表征蒸馏误差损失,用于从表征层面保存模型知识的同时保证表征空间相对稳定;
S23:基于所述的第一样本表征,通过所述新模型的分类器进行类别预测得到第一预测类别概率,并计算交叉熵损失;
S24:基于所述的第一样本表征,通过所述旧模型的分类器进行类别预测得到第二预测类别概率,并计算第一预测类别概率和第二预测类别概率之间的均方误差损失;
S25:以所述对比误差损失、表征蒸馏误差损失、交叉熵损失和均方误差损失之和作为对比类别集中误差损失,通过最小化该对比类别集中误差对所述新模型进行参数优化更新;
S3:保存本轮参数更新前的旧模型和参数更新后的新模型,并在进行下一轮训练之前利用本轮参数更新后的新模型进行图分类任务。
2.如权利要求1所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,通过内存存储所述图分类模型在训练过程中输入的不同类别的训练样本;其中每一轮训练开始时,将加入新类别的图像分类数据集与内存中存储的所有不同类别的训练样本一起作为训练数据对模型进行训练,每一轮训练结束后,通过随机采样的方式从本轮所输入的训练数据中选择部分属于新类别的样本放入内存中,同时从内存中的每个旧类别中移除固定数量的样本,从而保持内存总容量不变的同时增加类别数。
3.如权利要求1所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,所述训练数据在用于本轮训练之前,预先经过数据增广以扩充样本量。
4.如权利要求3所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,所述数据增广方式包括随机的切割、翻转、对比度和灰度调整。
5.如权利要求1~3任一所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,所述S21的具体实现步骤如下:
S211:将当前第t轮训练对应的训练数据D(t)输入所述新模型M(t)中,使用内部的特征提取器Eθ (t)获得训练数据D(t)中每一个样本xi (t)的特征ri (t),所有样本的特征构成第一样本表征R(t);
S212:基于第一样本表征R(t)中每个特征所对应的类别标签,计算每个特征与其他特征之间的相似性;所述相似性使用点乘之后取二范数来表示,任意两个特征ri (t)和rj (t)之间的相似性s(xi (t),xj (t))计算公式为:
式中:||·||表示二范数;
S213:将相同类别样本的特征视为正样本,不同类别样本的特征作为负样本,通过计算对比误差损失来聚合相同类别样本的表征同时拉远不同类别样本的表征;所述对比误差损失函数公式为:
其中:B为输入的训练数据D(t)中的样本个数,为训练数据D(t)中除样本xi (t)之外的其余样本的索引集合。
6.如权利要求5所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,所述S22的具体实现步骤包括:
S221:将当前第t轮训练对应的训练数据D(t)输入所述旧模型M(t-1)中,使用内部的特征提取器Eθ (t-1)获得训练数据D(t)中每一个样本xi (t)的特征ri (t-1),所有样本的特征构成第二样本表征R(t-1);
S222:结合所述第一样本表征R(t)和第二样本表征R(t-1)进行表征层面的蒸馏,计算表征蒸馏误差损失计算公式为:
其中:表示训练数据D(t)中所有样本的索引集合。
7.如权利要求6所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,所述S23的具体实现步骤包括:
S231:将前述的第一样本表征R(t)输入前述新模型M(t)中的分类器中进行分类预测,获得当前训练数据中各样本的第一预测类别概率;
S232:结合训练数据D(t)中各样本的第一预测类别概率和真实标签,计算交叉熵损失
8.如权利要求7所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,所述S24的具体实现步骤包括:
S241:将第二样本表征R(t-1)输入所述旧模型M(t-1)中的分类器中进行分类预测,获得当前训练数据的第二预测类别概率;
S242:基于训练数据D(t)中各样本的第一预测类别概率和第二预测类别概率,计算均方误差损失计算公式为:
式中:和/>分别表示样本xi (t)的第一预测类别概率和第二预测类别概率,n(t)表示引入本轮类增长学习后的图片分类类别总数,MSE表示计算均方误差。
9.如权利要求8所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,所述S25的具体实现步骤包括:
S251:将S21~S24中的四种损失进行加和,得到总的对比类别集中误差损失
S252:以最小化所述对比类别集中误差损失为目标,对所述新模型进行网络参数优化,完成本轮模型训练。
10.如权利要求1所述的基于对比类别集中提升连续学习下图分类精度的方法,其特征在于,所述特征提取器为ResNet32网络。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110788454.XA CN113554078B (zh) | 2021-07-13 | 2021-07-13 | 一种基于对比类别集中提升连续学习下图分类精度的方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110788454.XA CN113554078B (zh) | 2021-07-13 | 2021-07-13 | 一种基于对比类别集中提升连续学习下图分类精度的方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113554078A CN113554078A (zh) | 2021-10-26 |
CN113554078B true CN113554078B (zh) | 2023-10-17 |
Family
ID=78131660
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110788454.XA Active CN113554078B (zh) | 2021-07-13 | 2021-07-13 | 一种基于对比类别集中提升连续学习下图分类精度的方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113554078B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115984946A (zh) * | 2023-02-01 | 2023-04-18 | 浙江大学 | 一种基于集成学习的人脸识别模型遗忘方法及系统 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN1529506A (zh) * | 2003-09-29 | 2004-09-15 | �Ϻ���ͨ��ѧ | 基于运动检测的视频对象分割方法 |
CN110059896A (zh) * | 2019-05-15 | 2019-07-26 | 浙江科技学院 | 一种基于强化学习的股票预测方法及系统 |
CN111199242A (zh) * | 2019-12-18 | 2020-05-26 | 浙江工业大学 | 一种基于动态修正向量的图像增量学习方法 |
CN112149741A (zh) * | 2020-09-25 | 2020-12-29 | 北京百度网讯科技有限公司 | 图像识别模型的训练方法、装置、电子设备及存储介质 |
CN112559784A (zh) * | 2020-11-02 | 2021-03-26 | 浙江智慧视频安防创新中心有限公司 | 基于增量学习的图像分类方法及系统 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110163234B (zh) * | 2018-10-10 | 2023-04-18 | 腾讯科技(深圳)有限公司 | 一种模型训练方法、装置和存储介质 |
-
2021
- 2021-07-13 CN CN202110788454.XA patent/CN113554078B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN1529506A (zh) * | 2003-09-29 | 2004-09-15 | �Ϻ���ͨ��ѧ | 基于运动检测的视频对象分割方法 |
CN110059896A (zh) * | 2019-05-15 | 2019-07-26 | 浙江科技学院 | 一种基于强化学习的股票预测方法及系统 |
CN111199242A (zh) * | 2019-12-18 | 2020-05-26 | 浙江工业大学 | 一种基于动态修正向量的图像增量学习方法 |
CN112149741A (zh) * | 2020-09-25 | 2020-12-29 | 北京百度网讯科技有限公司 | 图像识别模型的训练方法、装置、电子设备及存储介质 |
CN112559784A (zh) * | 2020-11-02 | 2021-03-26 | 浙江智慧视频安防创新中心有限公司 | 基于增量学习的图像分类方法及系统 |
Non-Patent Citations (1)
Title |
---|
基于增量学习的菜品分类研究;陶杨;《中国优秀硕士学位论文全文数据库 工程科技Ⅰ辑》(第6期);第1-52页 * |
Also Published As
Publication number | Publication date |
---|---|
CN113554078A (zh) | 2021-10-26 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Cui et al. | A new hyperparameters optimization method for convolutional neural networks | |
CN110263230B (zh) | 一种基于密度聚类的数据清洗方法及装置 | |
CN110515931B (zh) | 一种基于随机森林算法的电容型设备缺陷预测方法 | |
CN113554078B (zh) | 一种基于对比类别集中提升连续学习下图分类精度的方法 | |
CN111985825A (zh) | 一种用于滚磨机定向仪的晶面质量评估方法 | |
CN114091603A (zh) | 一种空间转录组细胞聚类、分析方法 | |
CN110288026B (zh) | 一种基于度量关系图学习的图像分割方法及装置 | |
CN113377991B (zh) | 一种基于最难正负样本的图像检索方法 | |
CN117112852B (zh) | 一种大语言模型驱动的向量数据库检索方法及系统 | |
CN113159294A (zh) | 基于同伴学习的样本选择算法 | |
Ortelli et al. | Faster estimation of discrete choice models via dataset reduction | |
CN112738724B (zh) | 一种区域目标人群的精准识别方法、装置、设备和介质 | |
CN113610350B (zh) | 复杂工况故障诊断方法、设备、存储介质及装置 | |
CN115457269A (zh) | 一种基于改进DenseNAS的语义分割方法 | |
CN115100694A (zh) | 一种基于自监督神经网络的指纹快速检索方法 | |
CN111340291B (zh) | 一种基于云计算技术的中长期电力负荷组合预测系统及方法 | |
CN108932550B (zh) | 一种基于模糊密集稀疏密集算法进行图像分类的方法 | |
CN113989567A (zh) | 垃圾图片分类方法及装置 | |
CN115688873A (zh) | 图数据处理方法、设备及计算机程序产品 | |
He et al. | Multilevel thresholding based on fuzzy masi entropy | |
CN111079750A (zh) | 一种基于局部区域聚类的电力设备故障区域提取方法 | |
CN116203929B (zh) | 一种面向长尾分布数据的工业过程故障诊断方法 | |
CN116257735B (zh) | 用于智慧城市治理的数据处理方法及系统 | |
CN115618921B (zh) | 知识蒸馏方法、装置、电子设备和存储介质 | |
CN118606756A (zh) | 一种基于二次模糊学习机的标记噪声识别方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |