CN116563642A - 图像分类模型可信训练及图像分类方法、装置、设备 - Google Patents
图像分类模型可信训练及图像分类方法、装置、设备 Download PDFInfo
- Publication number
- CN116563642A CN116563642A CN202310624459.8A CN202310624459A CN116563642A CN 116563642 A CN116563642 A CN 116563642A CN 202310624459 A CN202310624459 A CN 202310624459A CN 116563642 A CN116563642 A CN 116563642A
- Authority
- CN
- China
- Prior art keywords
- model
- loss
- image
- student
- student model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 172
- 238000013145 classification model Methods 0.000 title claims abstract description 61
- 238000000034 method Methods 0.000 title claims abstract description 54
- 239000013598 vector Substances 0.000 claims abstract description 230
- 230000006870 function Effects 0.000 claims description 22
- 238000004364 calculation method Methods 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 12
- 238000013140 knowledge distillation Methods 0.000 description 8
- 238000000605 extraction Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 238000004891 communication Methods 0.000 description 5
- 230000000694 effects Effects 0.000 description 2
- 238000002679 ablation Methods 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 210000001747 pupil Anatomy 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 230000001052 transient effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
- G06V10/765—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了图像分类模型可信训练及图像分类方法、装置、设备,包括:将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;基于所述几何结构关系损失更新所述学生模型的参数;当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。这样,能够提升学生模型的图像分类精度。
Description
技术领域
本申请涉及图像分类技术领域,特别涉及图像分类模型可信训练及图像分类方法、装置、设备。
背景技术
在深度学习领域,知识可以理解为输入与输出之间隐含的映射关系。而知识蒸馏就是从大模型(或者模型集合)向小模型迁移知识的过程,其中大模型称为老师模型,而小模型称为学生模型。
在图像多分类任务中知识蒸馏常用做法是将样本同时输入到(已经训练好的)老师模型和学生模型,让学生模型学会老师模型中的知识。但是,存在学生模型进行图像分类时,精度不高的问题。
发明内容
有鉴于此,本申请的目的在于提供图像分类模型可信训练及图像分类方法、装置、设备,能够提升学生模型的图像分类精度。其具体方案如下:
第一方面,本申请公开了一种图像分类模型可信训练方法,包括:
将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;
若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;
基于所述几何结构关系损失更新所述学生模型的参数;
当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
可选的,还包括:
确定任意两个图像训练样本的老师模型特征向量对应的第一向量距离以及学生模型特征向量对应的第二向量距离;
若所述任意两个图像训练样本为同类样本,且所述第一向量距离小于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信;
若所述任意两个图像训练样本为不同类样本,且所述第一向量距离大于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信。
可选的,所述几何结构关系损失包括距离关系损失,相应的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:
利用公式loss2=leq+lneq计算距离关系损失;
其中,loss2表示距离关系损失,leq=MSELoss(dS i,j,dT i,j),lneq=MSELoss(dS i,j,dT i,j),并且,leq表示同类样本之间的距离损失,lneq表示不同类样本之间的距离损失,i、j分别表示第i个图像训练样本、第j个图像训练样本,dT i,j表示第一向量距离,dS i,j表示第二向量距离,MSELoss为平均平方误差损失函数。
可选的,所述几何结构关系损失包括角度关系损失,相应的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:
确定与所述任意两个图像训练样本不属于相同分类的图像训练样本作为锚点图像训练样本;
利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量以及所述锚点图像训练样本对应的老师模型特征向量计算角度关系损失。
可选的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:
利用公式loss3=l’ eq+l’neq计算角度关系损失;
其中,loss3表示角度关系损失,l’ eq=MSELoss(aS i,j,k,aT i,j,k),l’neq=MSELoss(aS i,j,k,aT i,j,k),aT i,j,k=cos(embdT i-embdT k,embdT j-embdT k),aS i,j,k=cos(embdS i-embdT k,embdS j-embdT k),并且,l’ eq表示同类的第i个图像训练样本与j个图像训练样本之间的角度损失,l’neq表示不同类的第i个图像训练样本与j个图像训练样本之间的角度损失,k表示锚点样本,aS i,j,k表示学生模型对应的角度关系,aT i,j,k表示老师模型对应的角度关系,embdT表示老师模型输出的特征向量,embdS表示学生模型输出的特征向量,MSELoss为平均平方误差损失函数。
可选的,所述基于所述几何结构关系损失更新所述学生模型的参数,包括:
基于特征向量损失和所述几何结构关系损失计算综合训练损失;其中,特征向量损失的计算公式为:loss1=MSELoss(embdS,embdT);
基于所述综合训练损失更新所述学生模型的参数。
第二方面,本申请公开了一种图像分类方法,包括:
获取待分类图像;
将所述待分类图像输入目标图像分类模型,得到图像分类结果;
其中,所述目标图像分类模型为基于前述的图像分类模型可信训练方法训练得到。
第三方面,本申请公开了一种图像分类模型可信训练装置,包括:
特征向量获取模块,用于将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;
关系损失计算模块,用于若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;
模型参数更新模块,用于基于所述几何结构关系损失更新所述学生模型的参数;
分类模型确定模块,用于当参数更新后的所述学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
第四方面,本申请公开了一种电子设备,包括存储器和处理器,其中:
所述存储器,用于保存计算机程序;
所述处理器,用于执行所述计算机程序,以实现前述的图像分类模型可信训练方法,和/或如前述的图像分类方法。
第五方面,本申请公开了一种计算机可读存储介质,用于保存计算机程序,其中,所述计算机程序被处理器执行时实现前述的图像分类模型可信训练方法,和/或如前述的图像分类方法。
可见,本申请将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量,其中,所述老师模型为训练后的图像分类模型,若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,之后基于所述几何结构关系损失更新所述学生模型的参数,当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。也即,本申请将图像训练样本同时输入训练得到的老师模型以及学生模型,在任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信时,利用这两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,基于损失更新学生模型,以使学生模型拟合老师模型学习出来的几何结构关系,这样,给知识蒸馏添加限制,避免了在学生模型特征向量之间的几何结构关系可信时,学生模型仍拟合老师模型学习出来的几何结构关系,能够提升学生模型的图像分类精度。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1为本申请公开的一种图像分类模型可信训练方法流程图;
图2为本申请公开的一种具体的模型输入输出示意图;
图3为本申请公开的一种具体的图像分类方法示意图;
图4为本申请公开的一种图像分类模型可信训练装置结构示意图;
图5为本申请公开的一种电子设备结构图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
目前,在图像多分类任务中知识蒸馏常用做法是将样本同时输入到(已经训练好的)老师模型和学生模型,让学生模型学会老师模型中的知识。但是,存在学生模型进行图像分类时,精度不高的问题。为此,本申请提供了一种图像分类模型可信训练方案,能够提升学生模型的图像分类精度。
参见图1所示,本申请实施例公开了一种图像分类模型可信训练方法,包括:
步骤S11:将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型。
在具体的实施方式中,可以先利用图像训练样本集训练一个用于图像分类的老师模型,具体为多分类,训练后老师模型的参数不再更新,图像训练样本集包括图像训练样本以及图像训练样本对应的标签信息,损失函数为交叉熵,老师模型的参数量大于学生模型。然后训练利用该图像训练样本集训练学生模型。老师模型和学生模型均包括两个模块,在前的模块是特征向量提取模块,在后模块是特征向量分类模块。知识蒸馏仅仅针对特征向量提取模块。本申请实施例中,老师模型特征向量为老师模型中特征向量提取模块输出的特征向量,学生模型特征向量为学生模型中特征向量提取模块输出的特征向量,特征向量分类模块输入的是特征向量,输出的是分类结果。老师和学生模型可以使用相同的特征向量分类模块。
步骤S12:若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失。
在具体的实施方式中,可以确定任意两个图像训练样本的老师模型特征向量对应的第一向量距离以及学生模型特征向量对应的第二向量距离;若所述任意两个图像训练样本为同类样本,且所述第一向量距离小于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信;若所述任意两个图像训练样本为不同类样本,且所述第一向量距离大于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信。
进一步的,在一种实施方式中,由于老师模型特征向量约等于学生模型特征向量,并且因为老师模型经过长时间的训练,老师模型特征向量要比学生模型向量更精准,所以在计算第二向量距离时,可以将所述任意两个图像训练样本中的任一样本对应的学生模型特征向量替换为老师模型特征向量,与另一样本对应的学生模型特征向量计算出向量距离,得到第二向量距离。
并且,在一种实施方式中,几何结构关系损失包括距离关系损失,可以利用公式loss2=leq+lneq计算距离关系损失;其中,loss2表示距离关系损失,leq=MSELoss(dS i,j,dT i,j),lneq=MSELoss(dS i,j,dT i,j),并且,leq表示同类样本之间的距离损失,lneq表示不同类样本之间的距离损失,i、j分别表示第i个图像训练样本、第j个图像训练样本,dT i,j表示第一向量距离,dS i,j表示第二向量距离,MSELoss为平均平方误差损失函数。
另外,在一种实施方式中,所述几何结构关系损失包括角度关系损失,确定与所述任意两个图像训练样本不属于相同分类的图像训练样本作为锚点图像训练样本;利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量以及所述锚点图像训练样本对应的老师模型特征向量计算角度关系损失。需要指出的是,如果锚点样本对应的特征向量和这两个特征向量很接近(属于相同的分类),那么锚点细微的改变都会造成这两个特征向量之间的角度发生巨大的变化;所以本申请实施例限制锚点必须和这两个特征向量属于不同的分类。
进一步的,在具体的实施方式中,可以利用公式loss3=l’ eq+l’neq计算角度关系损失;其中,loss3表示角度关系损失,l’ eq=MSELoss(aS i,j,k,aT i,j,k),l’neq=MSELoss(aS i,j,k,aT i,j,k),aT i,j,k=cos(embdT i-embdT k,embdT j-embdT k),aS i,j,k=cos(embdS i-embdT k,embdS j-embdT k),并且,l’ eq表示同类的第i个图像训练样本与j个图像训练样本之间的角度损失,l’neq表示不同类的第i个图像训练样本与j个图像训练样本之间的角度损失,k表示锚点样本,aS i,j,k表示学生模型对应的角度关系,aT i,j,k表示老师模型对应的角度关系,embdT表示老师模型输出的特征向量,embdS表示学生模型输出的特征向量,MSELoss为平均平方误差损失函数。
可以理解的是,本申请实施例中,若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系可信,则不参与损失计算。例如,参见图2所示,图2为本申请实施例公开的一种具体的模型输入输出示意图。fT表示老师模型,fS表示学生模型,embdT表示老师模型输出的特征向量,embdS表示学生模型输出的特征向量,1、2、3分别表示3个样本,1和2属于相同分类,与3为不同分类。1和2样本,学生模型输出的特征向量之间的距离比老师模型更接近,学生模型输出的关系是可信的,所以学生模型不再需要模拟老师模型所输出的特征向量之间的距离和角度关系。
需要指出的是,同类样本的特征向量,距离越近越好;反之,属于不同分类的特征向量,它们之间的距离越远越好。如果相同分类的样本在老师模型中输出的特征向量之间的距离,比在学生模型中的远,说明使用学生模型中的特征向量,比使用老师模型中的能得到更精准的分类结果,即学生模型输出的特征向量之间的关系是可信的,反之是不可信的。如果学生模型输出的关系是可信的,再让学生模型拟合老师模型学习出来的关系,只会降低学生模型的效果。本申请实施例如果相同分类的样本在老师模型中输出的特征向量之间的距离比学生模型中的更远,或者不同分类的样本在老师模型中输出的特征向量之间的距离比学生模型中的更近的时候,使用学生模型的特征向量能得到比老师模型更精准的分类结果。此时给学生模型施加限制,不再让学生模型拟合老师模型学习出来的关系。这样,学生模型在学习老师模型特征向量之间的距离和角度关系时,受到条件的限制。当学生模型中输出的特征向量之间的关系,比老师模型输出的能让最后的分类模块表现更好时,学生模型就不再去模拟老师模型,以提升学生模型的效果。
步骤S13:基于所述几何结构关系损失更新所述学生模型的参数。
在具体的实施方式中,可以基于特征向量损失和所述几何结构关系损失计算综合训练损失;其中,特征向量损失的计算公式为:loss1=MSELoss(embdS,embdT);embdT表示老师模型输出的特征向量,embdS表示学生模型输出的特征向量,MSELoss为平均平方误差损失函数。基于所述综合训练损失更新所述学生模型的参数。
在一种实施方式中,可以利用特征向量损失以及特征向量损失对应的超参数、距离关系损失以及距离关系损失对应的超参数、角度关系损失以及角度关系损失对应的超参数计算综合训练损失,具体公式如下:
Loss=αloss1+βloss2+γloss3;其中,Loss表示综合训练损失,α、β、γ为超参数。
步骤S14:当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
可以理解的是,在训练过程中,本申请实施例是选取batch(批量样本)同时输入老师模型和学生模型,利用该批量样本对应的特征向量计算综合训练损失,基于综合训练损失更新学生模型的参数,完成一次迭代,之后重复迭代,直到参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。其中,可以基于综合训练损失更新学生模型的特征向量提取模块的参数,学生模型可以直接使用老师模型的特征向量分类模块。
下面以手写数字识别的数据集为例,采用本申请中的方案进行训练和测试:
步骤1、训练一个大的老师模型去学习一个多分类任务,选择交叉熵作为损失函数。
步骤2、开始训练一个小的学生模型去模拟老师模型中已经学习到的知识,具体做法如下:
步骤2.1、固定老师模型的参数,从数据集中取样一些样本,同时输入给老师模型和学生模型,并得到特征向量。
其中,embdT=fT(x),embdS=fS(x);fT表示老师模型的特向向量提取模块,fS表示学生模型的特征向量提取模块,embdT表示老师模型输出的特征向量,embdS表示学生模型输出的特征向量,x表示输入的图像训练样本。
步骤2.2、设计损失函数loss1,让学生模型的特征向量尽可能接近老师模型的特征向量。loss1=MSELoss(embdS,embdT)。
步骤2.3、设计损失函数loss2,在学生模型输出的关系不可信的时候,让学生模型学习老师模型特征向量之间的距离关系。具体方案如下:
dT i,j=distance(embdT i,embdT j);
dS i,j=distance(embdS i,embdS j);
其中,i和j是表示这些样本中的任意两条,i≠j,embdT i是第i条样本输入给老师模型所输出的特征向量,同理embdS j是第j条样本输入给学生模型所输出的特征向量,这两条样本在老师模型和学生模型输出的特征向量之间的距离是dT i,j、dS i,j,distance是距离函数,可以选择余弦距离或者欧式距离等。由于embdT i≈embdS i并且embdT i已经经过了老师模型长时间的训练,要比embdS i更精准,所以在公式中选择用embdT i来代替embdS i,即:dS i,j=distance(embdT i,embdS j);于是得到损失函数loss2:
leq=MSELoss(dS i,j,dT i,j),clsi=clsj且dS i,j>dT i,j;
lneq=MSELoss(dS i,j,dT i,j),clsi≠clsj且dS i,j<dT i,j;
loss2=leq+lneq;
其中,cls表示样本的分类。相同分类的样本,如果学生模型输出的特征向量之间的距离比老师模型的更远,就让学生模型学习老师模型输出的特征向量之间的距离关系,leq=就是对应的损失函数。不同分类的样本,如果学生模型输出的特征向量之间的距离比老师模型的更近,就让学生模型学习老师模型输出的特征向量之间的距离关系,lneq就是对应的损失函数。
步骤2.4、设计损失函数loss3,在学生模型输出的关系不可信的时候,让学生模型学习老师模型输出的特征向量之间的角度关系。要想计算两个特征向量之间的角度关系,还需要选择一个锚点embdT k。需要注意的是,如果锚点和这两个特征向量很接近,属于相同的分类,,那么锚点细微的改变都会造成这两个特征向量之间的角度发生巨大的变化;所以限制锚点必须和这两个特征向量属于不同的分类。
aT i,j,k=cos(embdT i-embdT k,embdT j-embdT k),aS i,j,k=cos(embdS i-embdT k,
embdS j-embdT k),i≠j≠k且clsk≠clsi且clsk≠clsj;
参考距离关系的损失函数,可以得到角度关系的损失函数:
l’ eq=MSELoss(aS i,j,k,aT i,j,k),clsi=clsj且dS i,j>dT i,j;
l’neq=MSELoss(aS i,j,k,aT i,j,k),clsi≠clsj且dS i,j<dT i,j;
loss3=l’ eq+l’neq;
步骤2.5、使用不同的超参数将三个损失函数融合在一起,作为最后的损失函数,就可以实现学生模型的训练。在使用学生模型做最后分类的时候,可以直接使用老师模型的分类模块。
Loss=αloss1+βloss2+γloss3。
这样,借助老师和学生模型输出的特征向量之间的距离关系,来评估学生模型是否需要去学习老师模型输出特征向量之间的距离和角度关系。在计算两个特征向量之间的角度关系时,选择与它们不同分类的特征向量作为锚点,防止因为锚点太过接近这两个特征向量而造成角度计算出现很大的误差。通过消融实验,在手写数字识别的数据集上,通过本申请的方案,将知识从大模型迁移到小模型,使得小模型取得了更高的精度。
可见,本申请实施例将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量,其中,所述老师模型为训练后的图像分类模型,若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,之后基于所述几何结构关系损失更新所述学生模型的参数,当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。也即,本申请实施例将图像训练样本同时输入训练得到的老师模型以及学生模型,在任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信时,利用这两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,基于损失更新学生模型,以使学生模型拟合老师模型学习出来的几何结构关系,这样,给知识蒸馏添加限制,避免了在学生模型特征向量之间的几何结构关系可信时,学生模型仍拟合老师模型学习出来的几何结构关系,能够提升学生模型的图像分类精度。
参见图3所示,图3为本申请实施例公开的一种图像分类方法,包括:
步骤S21:获取待分类图像;
步骤S22:将所述待分类图像输入目标图像分类模型,得到图像分类结果;
其中,所述目标图像分类模型为基于前述实施例所述的图像分类模型可信训练方法训练得到。
可见,本申请实施例获取待分类图像,将所述待分类图像输入基于前述实施例所述的图像分类模型可信训练方法训练得到的目标图像分类模型,得到图像分类结果,目标图像分类模型在训练过程中,在任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信时,利用这两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,基于损失更新学生模型,以使学生模型拟合老师模型学习出来的几何结构关系,这样,给知识蒸馏添加限制,避免了在学生模型特征向量之间的几何结构关系可信时,学生模型仍拟合老师模型学习出来的几何结构关系,能够提升学生模型的图像分类精度。
参见图4所示,本申请实施例公开了一种图像分类模型可信训练装置,包括:
特征向量获取模块11,用于将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;
关系损失计算模块12,用于若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;
模型参数更新模块13,用于基于所述几何结构关系损失更新所述学生模型的参数;
分类模型确定模块14,用于当参数更新后的所述学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
可见,本申请实施例将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量,其中,所述老师模型为训练后的图像分类模型,若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,之后基于所述几何结构关系损失更新所述学生模型的参数,当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。也即,本申请实施例将图像训练样本同时输入训练得到的老师模型以及学生模型,在任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信时,利用这两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,基于损失更新学生模型,以使学生模型拟合老师模型学习出来的几何结构关系,这样,给知识蒸馏添加限制,避免了在学生模型特征向量之间的几何结构关系可信时,学生模型仍拟合老师模型学习出来的几何结构关系,能够提升学生模型的图像分类精度。
进一步的,所述装置还包括可信判断模块,用于:
确定任意两个图像训练样本的老师模型特征向量对应的第一向量距离以及学生模型特征向量对应的第二向量距离;
若所述任意两个图像训练样本为同类样本,且所述第一向量距离小于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信;
若所述任意两个图像训练样本为不同类样本,且所述第一向量距离大于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信。
在一种具体的实施方式中,所述几何结构关系损失包括距离关系损失,相应的,关系损失计算模块12,具体用于:
利用公式loss2=leq+lneq计算距离关系损失;
其中,loss2表示距离关系损失,leq=MSELoss(dS i,j,dT i,j),lneq=MSELoss(dS i,j,dT i,j),并且,leq表示同类样本之间的距离损失,lneq表示不同类样本之间的距离损失,i、j分别表示第i个图像训练样本、第j个图像训练样本,dT i,j表示第一向量距离,dS i,j表示第二向量距离,MSELoss为平均平方误差损失函数。
在一种实施方式中,所述几何结构关系损失包括角度关系损失,相应的,关系损失计算模块12,具体用于:确定与所述任意两个图像训练样本不属于相同分类的图像训练样本作为锚点图像训练样本;利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量以及所述锚点图像训练样本对应的老师模型特征向量计算角度关系损失。
进一步的,关系损失计算模块12,具体用于:
利用公式loss3=l’ eq+l’neq计算角度关系损失;
其中,loss3表示角度关系损失,l’ eq=MSELoss(aS i,j,k,aT i,j,k),l’neq=MSELoss(aS i,j,k,aT i,j,k),aT i,j,k=cos(embdT i-embdT k,embdT j-embdT k),aS i,j,k=cos(embdS i-embdT k,embdS j-embdT k),并且,l’ eq表示同类的第i个图像训练样本与j个图像训练样本之间的角度损失,l’neq表示不同类的第i个图像训练样本与j个图像训练样本之间的角度损失,k表示锚点样本,aS i,j,k表示学生模型对应的角度关系,aT i,j,k表示老师模型对应的角度关系,embdT表示老师模型输出的特征向量,embdS表示学生模型输出的特征向量,MSELoss为平均平方误差损失函数。
所述装置还包括综合训练损失计算模块,用于:
基于特征向量损失和所述几何结构关系损失计算综合训练损失;其中,特征向量损失的计算公式为:loss1=MSELoss(embdS,embdT);
相应的,模型参数更新模块13,用于基于所述综合训练损失更新所述学生模型的参数。
参见图5所示,本申请实施例公开了一种电子设备20,包括处理器21和存储器22;其中,所述存储器22,用于保存计算机程序;所述处理器21,用于执行所述计算机程序,前述实施例公开的图像分类模型可信训练方法,和/或图像分类方法。
关于上述图像分类模型可信训练方法,和/或图像分类方法的具体过程可以参考前述实施例中公开的相应内容,在此不再进行赘述。
并且,所述存储器22作为资源存储的载体,可以是只读存储器、随机存储器、磁盘或者光盘等,存储方式可以是短暂存储或者永久存储。
另外,所述电子设备20还包括电源23、通信接口24、输入输出接口25和通信总线26;其中,所述电源23用于为所述电子设备20上的各硬件设备提供工作电压;所述通信接口24能够为所述电子设备20创建与外界设备之间的数据传输通道,其所遵循的通信协议是能够适用于本申请技术方案的任意通信协议,在此不对其进行具体限定;所述输入输出接口25,用于获取外界输入数据或向外界输出数据,其具体的接口类型可以根据具体应用需要进行选取,在此不进行具体限定。
进一步的,本申请实施例还公开了一种计算机可读存储介质,用于保存计算机程序,其中,所述计算机程序被处理器执行时实现前述实施例公开的图像分类模型可信训练方法,和/或图像分类方法。
关于上述图像分类模型可信训练方法,和/或图像分类方法的具体过程可以参考前述实施例中公开的相应内容,在此不再进行赘述。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其它实施例的不同之处,各个实施例之间相同或相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
结合本文中所公开的实施例描述的方法或算法的步骤可以直接用硬件、处理器执行的软件模块,或者二者的结合来实施。软件模块可以置于随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、硬盘、可移动磁盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质中。
以上对本申请所提供的图像分类模型可信训练及图像分类方法、装置、设备进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的一般技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。
Claims (10)
1.一种图像分类模型可信训练方法,其特征在于,包括:
将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;
若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;
基于所述几何结构关系损失更新所述学生模型的参数;
当参数更新后的学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
2.根据权利要求1所述的图像分类模型可信训练方法,其特征在于,还包括:
确定任意两个图像训练样本的老师模型特征向量对应的第一向量距离以及学生模型特征向量对应的第二向量距离;
若所述任意两个图像训练样本为同类样本,且所述第一向量距离小于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信;
若所述任意两个图像训练样本为不同类样本,且所述第一向量距离大于所述第二向量距离,则判定所述任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信。
3.根据权利要求2所述的图像分类模型可信训练方法,其特征在于,所述几何结构关系损失包括距离关系损失,相应的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:
利用公式loss2=leq+lneq计算距离关系损失;
其中,loss2表示距离关系损失,leq=MSELoss(dS i,j,dT i,j),lneq=MSELoss(dS i,j,dT i,j),并且,leq表示同类样本之间的距离损失,lneq表示不同类样本之间的距离损失,i、j分别表示第i个图像训练样本、第j个图像训练样本,dT i,j表示第一向量距离,dS i,j表示第二向量距离,MSELoss为平均平方误差损失函数。
4.根据权利要求3所述的图像分类模型可信训练方法,其特征在于,所述几何结构关系损失包括角度关系损失,相应的,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:
确定与所述任意两个图像训练样本不属于相同分类的图像训练样本作为锚点图像训练样本;
利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量以及所述锚点图像训练样本对应的老师模型特征向量计算角度关系损失。
5.根据权利要求4所述的图像分类模型可信训练方法,其特征在于,所述利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失,包括:
利用公式loss3=l’ eq+l’neq计算角度关系损失;
其中,loss3表示角度关系损失,l’ eq=MSELoss(aS i,j,k,aT i,j,k),l’neq=MSELoss(aS i,j,k,aT i,j,k),aT i,j,k=cos(embdT i-embdT k,embdT j-embdT k),aS i,j,k=cos(embdS i-embdT k,embdS j-embdT k),并且,l’ eq表示同类的第i个图像训练样本与j个图像训练样本之间的角度损失,l’neq表示不同类的第i个图像训练样本与j个图像训练样本之间的角度损失,k表示锚点样本,aS i,j,k表示学生模型对应的角度关系,aT i,j,k表示老师模型对应的角度关系,embdT表示老师模型输出的特征向量,embdS表示学生模型输出的特征向量,MSELoss为平均平方误差损失函数。
6.根据权利要求5所述的图像分类模型可信训练方法,其特征在于,所述基于所述几何结构关系损失更新所述学生模型的参数,包括:
基于特征向量损失和所述几何结构关系损失计算综合训练损失;其中,特征向量损失的计算公式为:loss1=MSELoss(embdS,embdT);
基于所述综合训练损失更新所述学生模型的参数。
7.一种图像分类方法,其特征在于,包括:
获取待分类图像;
将所述待分类图像输入目标图像分类模型,得到图像分类结果;
其中,所述目标图像分类模型为基于权利要求1至6任一项所述的图像分类模型可信训练方法训练得到。
8.一种图像分类模型可信训练装置,其特征在于,包括:
特征向量获取模块,用于将图像训练样本输入老师模型和学生模型,得到所述老师模型输出的老师模型特征向量以及所述学生模型输出的学生模型特征向量;其中,所述老师模型为训练后的图像分类模型;
关系损失计算模块,用于若任意两个图像训练样本对应的学生模型特征向量之间的几何结构关系不可信,则利用所述任意两个图像训练样本对应的老师模型特征向量和学生模型特征向量计算几何结构关系损失;
模型参数更新模块,用于基于所述几何结构关系损失更新所述学生模型的参数;
分类模型确定模块,用于当参数更新后的所述学生模型满足收敛条件,则将参数更新后的学生模型确定为目标图像分类模型。
9.一种电子设备,其特征在于,包括存储器和处理器,其中:
所述存储器,用于保存计算机程序;
所述处理器,用于执行所述计算机程序,以实现如权利要求1至6任一项所述的图像分类模型可信训练方法,和/或如权利要求7所述的图像分类方法。
10.一种计算机可读存储介质,其特征在于,用于保存计算机程序,其中,所述计算机程序被处理器执行时实现如权利要求1至6任一项所述的图像分类模型可信训练方法,和/或如权利要求7所述的图像分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310624459.8A CN116563642B (zh) | 2023-05-30 | 2023-05-30 | 图像分类模型可信训练及图像分类方法、装置、设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310624459.8A CN116563642B (zh) | 2023-05-30 | 2023-05-30 | 图像分类模型可信训练及图像分类方法、装置、设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116563642A true CN116563642A (zh) | 2023-08-08 |
CN116563642B CN116563642B (zh) | 2024-02-27 |
Family
ID=87498143
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310624459.8A Active CN116563642B (zh) | 2023-05-30 | 2023-05-30 | 图像分类模型可信训练及图像分类方法、装置、设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116563642B (zh) |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111814717A (zh) * | 2020-07-17 | 2020-10-23 | 腾讯科技(深圳)有限公司 | 人脸识别方法、装置及电子设备 |
KR102232138B1 (ko) * | 2020-11-17 | 2021-03-25 | (주)에이아이매틱스 | 지식 증류 기반 신경망 아키텍처 탐색 방법 |
CN112560631A (zh) * | 2020-12-09 | 2021-03-26 | 昆明理工大学 | 一种基于知识蒸馏的行人重识别方法 |
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN114067444A (zh) * | 2021-10-12 | 2022-02-18 | 中新国际联合研究院 | 基于元伪标签和光照不变特征的人脸欺骗检测方法和系统 |
CN114494776A (zh) * | 2022-01-24 | 2022-05-13 | 北京百度网讯科技有限公司 | 一种模型训练方法、装置、设备以及存储介质 |
CN114862764A (zh) * | 2022-04-12 | 2022-08-05 | 阿里巴巴达摩院(杭州)科技有限公司 | 瑕疵检测模型训练方法、装置、设备和存储介质 |
CN114973307A (zh) * | 2022-02-08 | 2022-08-30 | 西安交通大学 | 生成对抗和余弦三元损失函数的指静脉识别方法及系统 |
JP2023013293A (ja) * | 2021-07-15 | 2023-01-26 | グローリー株式会社 | 教師データ生成装置、学習モデル生成装置、および教師データの生成方法 |
-
2023
- 2023-05-30 CN CN202310624459.8A patent/CN116563642B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111814717A (zh) * | 2020-07-17 | 2020-10-23 | 腾讯科技(深圳)有限公司 | 人脸识别方法、装置及电子设备 |
KR102232138B1 (ko) * | 2020-11-17 | 2021-03-25 | (주)에이아이매틱스 | 지식 증류 기반 신경망 아키텍처 탐색 방법 |
CN112560631A (zh) * | 2020-12-09 | 2021-03-26 | 昆明理工大学 | 一种基于知识蒸馏的行人重识别方法 |
JP2023013293A (ja) * | 2021-07-15 | 2023-01-26 | グローリー株式会社 | 教師データ生成装置、学習モデル生成装置、および教師データの生成方法 |
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN114067444A (zh) * | 2021-10-12 | 2022-02-18 | 中新国际联合研究院 | 基于元伪标签和光照不变特征的人脸欺骗检测方法和系统 |
CN114494776A (zh) * | 2022-01-24 | 2022-05-13 | 北京百度网讯科技有限公司 | 一种模型训练方法、装置、设备以及存储介质 |
CN114973307A (zh) * | 2022-02-08 | 2022-08-30 | 西安交通大学 | 生成对抗和余弦三元损失函数的指静脉识别方法及系统 |
CN114862764A (zh) * | 2022-04-12 | 2022-08-05 | 阿里巴巴达摩院(杭州)科技有限公司 | 瑕疵检测模型训练方法、装置、设备和存储介质 |
Non-Patent Citations (2)
Title |
---|
HUGO MASSON ET.AL: "Exploiting prunability for person re-identification", 《EURASIP JOURNAL ON IMAGE AND VIDEO PROCESSING》, pages 1 - 31 * |
李大湘 等: "面向遥感图像场景分类的双知识蒸馏模型", 《电子与信息学报》, vol. 45, no. 10, pages 3558 - 3567 * |
Also Published As
Publication number | Publication date |
---|---|
CN116563642B (zh) | 2024-02-27 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108399428B (zh) | 一种基于迹比准则的三元组损失函数设计方法 | |
US10410362B2 (en) | Method, device, and non-transitory computer readable storage medium for image processing | |
Escanciano et al. | Uniform convergence of weighted sums of non and semiparametric residuals for estimation and testing | |
CN111127364B (zh) | 图像数据增强策略选择方法及人脸识别图像数据增强方法 | |
WO2020232874A1 (zh) | 基于迁移学习的建模方法、装置、计算机设备和存储介质 | |
CN113837205B (zh) | 用于图像特征表示生成的方法、设备、装置和介质 | |
WO2021159815A1 (zh) | 人脸识别模型的训练方法、装置和计算机设备 | |
JP2023042582A (ja) | サンプル分析の方法、電子装置、記憶媒体、及びプログラム製品 | |
CN111161249A (zh) | 一种基于域适应的无监督医学图像分割方法 | |
CN112597124A (zh) | 一种数据字段映射方法、装置及存储介质 | |
CN113326825A (zh) | 伪标签生成方法、装置、电子设备及存储介质 | |
Lonij et al. | Open-world visual recognition using knowledge graphs | |
CN116756536B (zh) | 数据识别方法、模型训练方法、装置、设备及存储介质 | |
CN116563642B (zh) | 图像分类模型可信训练及图像分类方法、装置、设备 | |
WO2021258482A1 (zh) | 基于迁移与弱监督的美丽预测方法、装置及存储介质 | |
CN111062406B (zh) | 一种面向异构领域适应的半监督最优传输方法 | |
CN115795355B (zh) | 一种分类模型训练方法、装置及设备 | |
US20240020531A1 (en) | System and Method for Transforming a Trained Artificial Intelligence Model Into a Trustworthy Artificial Intelligence Model | |
CN111368792B (zh) | 特征点标注模型训练方法、装置、电子设备及存储介质 | |
CN114187470A (zh) | 垃圾分类模型的训练方法、垃圾分类方法及装置 | |
CN112784927A (zh) | 一种基于在线学习的半自动图像标注方法 | |
CN111814949B (zh) | 一种数据标注方法、装置及电子设备 | |
CN117743568B (zh) | 基于资源流量和置信度融合的内容生成方法和系统 | |
CN113837228B (zh) | 基于惩罚感知中心损失函数的用于细粒度物体检索方法 | |
WO2024119901A1 (zh) | 识别模型训练方法、装置、计算机设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant | ||
CP03 | Change of name, title or address |
Address after: No. 205, Building B1, Huigu Science and Technology Industrial Park, No. 336 Bachelor Road, Bachelor Street, Yuelu District, Changsha City, Hunan Province, 410000 Patentee after: Wisdom Eye Technology Co.,Ltd. Country or region after: China Address before: 410205, Changsha high tech Zone, Hunan Province, China Patentee before: Wisdom Eye Technology Co.,Ltd. Country or region before: China |
|
CP03 | Change of name, title or address |