CN114037876A - 一种模型优化方法和装置 - Google Patents
一种模型优化方法和装置 Download PDFInfo
- Publication number
- CN114037876A CN114037876A CN202111546892.1A CN202111546892A CN114037876A CN 114037876 A CN114037876 A CN 114037876A CN 202111546892 A CN202111546892 A CN 202111546892A CN 114037876 A CN114037876 A CN 114037876A
- Authority
- CN
- China
- Prior art keywords
- pseudo
- samples
- loss
- image classification
- classification model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Probability & Statistics with Applications (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种模型优化方法和装置,用以解决由于半监督学习的监督信息不足而导致图像分类模型训练效果差的问题。包括:获取图像分类模型的训练样本,训练样本包括有标签样本和无标签样本,无标签样本包括有伪标签样本;将有标签样本和有伪标签样本作为伪监督对比学习样本输入图像分类模型,以确定伪监督对比学习损失,伪监督对比学习损失根据图像分类模型的特征提取网络对有标签样本和有伪标签样本提取的特征确定;根据伪监督对比学习损失优化图像分类模型。本方案利用伪标签进行对比表示学习,能从特征提取网络提取的特征维度优化损失函数,使损失函数更准确地表达模型损失,从而有效丰富监督信号,优化模型训练效果。
Description
技术领域
本发明涉及机器学习领域,尤其涉及一种模型优化方法和装置。
背景技术
在机器学习领域,往往需要利用样本对模型进行训练与优化。当样本中包含无标签样本时,可以应用半监督学习、无监督学习等方式进行模型训练。但是,半监督学习算法能获得的信息全部来自有标签数据的监督信息和无标签数据的一致性信息。仅使用少量标签数据的有监督训练和低密度假设所能提供的监督信息是不足的。而在无监督表示学习中,对比学习由于不能充分利用关键且有价值的有标签数据,导致对比学习在半监督任务中表现效果不理想。
如何优化图像分类模型训练效果,是本申请所要解决的技术问题。
发明内容
本申请实施例的目的是提供一种模型优化方法和装置,用以解决由于半监督学习的监督信息不足而导致图像分类模型训练效果差的问题。
第一方面,提供了一种模型优化方法,包括:
获取图像分类模型的训练样本,所述训练样本包括有标签样本和无标签样本,所述无标签样本包括有伪标签样本;
将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以确定伪监督对比学习损失,所述伪监督对比学习损失根据所述图像分类模型的特征提取网络对所述有标签样本和所述有伪标签样本提取的特征确定;
根据所述伪监督对比学习损失优化所述图像分类模型
第二方面,提供了一种模型优化装置,包括:
获取模块,获取图像分类模型的训练样本,所述样本包括有标签样本和无标签样本,所述无标签样本包括有伪标签样本;
确定模块,将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以确定伪监督对比学习损失,所述伪监督对比学习损失根据所述图像分类模型的特征提取网络对所述有标签样本和所述有伪标签样本提取的特征确定;
优化模块,根据所述伪监督对比学习损失优化所述图像分类模型。
第三方面,提供了一种电子设备,该电子设备包括处理器、存储器及存储在该存储器上并可在该处理器上运行的计算机程序,该计算机程序被该处理器执行时实现如第一方面的方法的步骤。
第四方面,提供了一种计算机可读存储介质,该计算机可读存储介质上存储计算机程序,该计算机程序被处理器执行时实现如第一方面的方法的步骤。
在本申请实施例中,获取图像分类模型的训练样本,训练样本包括有标签样本和无标签样本,无标签样本包括有伪标签样本;将有标签样本和有伪标签样本作为伪监督对比学习样本输入图像分类模型,以确定伪监督对比学习损失,伪监督对比学习损失根据图像分类模型的特征提取网络对有标签样本和有伪标签样本提取的特征确定;根据伪监督对比学习损失优化图像分类模型。本方案利用伪标签进行对比表示学习,能从特征提取网络提取的特征维度优化损失函数,使损失函数更准确地表达模型损失,从而有效丰富监督信号,提高模型优化训练的效果。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本发明的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1是本发明的一个实施例一种模型优化方法的流程示意图之一。
图2是本发明的一个实施例一种模型优化方法的流程示意图之二。
图3是本发明的一个实施例一种模型优化方法的流程示意图之三。
图4是本发明的一个实施例一种模型优化方法的流程示意图之四。
图5是本发明的一个实施例一种模型优化方法的流程示意图之五。
图6是本发明的一个实施例一种模型优化方法的流程示意图之六。
图7是本发明的一个实施例一种模型优化方法的流程示意图之七。
图8是本发明的一个实施例一种模型优化方法的流程示意图之八。
图9是本发明的一个实施例一种模型优化方法的流程示意图之九。
图10是本发明的一个实施例一种模型优化方法的流程示意图之十。
图11是本发明的一个实施例一种模型优化方法的流程示意图之十一。
图12是本发明的一个实施例一种模型优化装置的结构示意图之十二。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。本申请中附图编号仅用于区分方案中的各个步骤,不用于限定各个步骤的执行顺序,具体执行顺序以说明书中描述为准。
本申请提供的方案涉及机器学习领域,首先对本领域所包含的名词与概念进行解释:
半监督学习的目的是通过有效利用少量有标签样本和额外的大量无标签样本来提高神经网络的性能。最新的最有效的半监督方法大致可以分为两类:伪标签法(自训练)和基于一致性正则化的方法。其中,基于伪标签的方法将无标签样本输入训练一定程度后的模型得到中间预测结果,对无标签样本生成标签,然后选择置信度足够高的样本视为具有伪标签的样本用于后续训练,由此形成自训练过程。一致性正则方法认为:模型在面对相同输入图像的不同扰动时,应预测出相似的分类结果。扰动方式已从早期的高斯噪声发展到平移、裁剪等弱增强方式和自动增强、随机增强等强增强方式。
对比学习是一类自监督学习方法,它们通过对比损失来衡量正样本对和负样本对的相似性来学习表征。对于自监督学习,已有的方法主要集中在人工设计的辅助任务,例如,预测图像块的相对位置,解决拼图问题,对特征进行聚类以及预测图像旋转程度等。随着计算机视觉的迅速发展,大量使用自监督表示学习范式的作品在对比学习框架下取得了显著的成果。具体地说,对比学习通过激励一个特征表示向量gθ(x)与其对应正样本x+相似并与其它所有负样本x-相异来训练模型表征能力,这一过程可以表述为下式(1-1):
score(gθ(x),gθ(x+))>>score(gθ(x),gθ(x-)) (1-1)
上述模型学习方案中,存在以下缺陷:
对于半监督学习,目前算法能获得的信息全部来自有标签数据的监督信息和无标签数据的一致性信息。仅使用少量标签数据的有监督训练和低密度假设所能提供的监督信息是不足的。
同时,对比学习在无监督表示学习中得到了广泛的应用,但由于不能充分利用关键且有价值的有标签数据,导致对比学习在半监督任务中表现效果不理想。
为了解决现有技术中存在的问题,本申请实施例提供一种模型优化方法,
如图1所示,本申请实施例包括以下步骤:
S11:获取图像分类模型的训练样本,所述训练样本包括有标签样本和无标签样本,所述无标签样本包括有伪标签样本。
在本申请实施例中,训练样本中的有标签样本和有伪标签样本包括可输入至图像分类模型执行训练的图像以及图像对应的标签。上述有标签样本是指样本图像以及对应的标签,该标签可以是人工标记的标签。上述有伪标签样本是指样本图像以及对应的伪标签,该伪标签可以是将无标签样本输入至上述图像分类模型预测输出的伪标签。
本申请的图像分类模型可以用于执行图像分类任务,在本实例中,一个包含可学习参数Θ的图像分类模型在概念上可以分为两部分:特征提取网络和分类器。
S12:将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以确定伪监督对比学习损失,所述伪监督对比学习损失根据所述图像分类模型的特征提取网络对所述有标签样本和所述有伪标签样本提取的特征确定。
在本步骤中,将有标签样本和有伪标签样本混合共同输入图像分类模型,用于伪监督对比学习,并将学习目标表示为伪监督对比损失。本步骤中使用图像分类模型的特征提取网络对上述有标签样本和有伪标签样本提取到的特征进行对比学习,得到的伪监督对比学习损失能表征特征维度的损失。
通过伪监督对比学习,能使同类(将伪标签视为大概率同类)样本的特征表示向量在特征空间中应尽可能相似,不同类样本的特征表示向量在特征空间中应尽可能相异。举例而言,正样本对的特征表示向量在特征空间中尽可能相似,负样本对的特征表示向量在特征空间中尽可能相异。
S13:根据所述伪监督对比学习损失优化所述图像分类模型。
基于上述步骤中得到的伪监督对比学习损失能表示模型提取的各类别内与各类别间样本的特征在特征维度的损失,换言之,该伪监督对比学习损失能从特征维度表达上述图像分类模型的损失。本步骤中根据该伪监督对比学习损失对图像分类模型执行优化,具体而言,以降低损失为优化目标调整图像分类模型,能从特征维度的损失优化图像分类模型,使得图像分类模型中的特征提取网络提取的特征在特征空间中更接近期望的特征分布,使优化后的图像分类模型提取的特征更准确,进而提高图像分类模型的预测效果,有效优化图像分类模型。
基于上述实施例提供的方案,可选的,如图2所示,上述步骤S12,包括:
S21:将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以得到所述特征提取网络对所述伪监督对比学习样本提取的特征数据。
在本步骤中,将有标签样本和有伪标签样本混合作为伪监督对比学习样本,将伪监督对比学习样本的各个图像输入图像分类模型后,由图像分类模型的特征提取网络输出图像对应的特征数据。具体的,该特征数据可以包括多个特征值,每个特征值可以表征图像在一种特征维度上的值,这些特征值共同表征了图像在多种特征维度上的特征。
S22:根据所述特征提取网络提取的特征数据确定所述伪监督对比学习样本中每个样本分别对应的多维特征向量。
上述特征向量的维度数量可以与上述特征数据中的特征值的数量一致,多维特征向量表达了图像在多个维度上的特征。
举例而言,如果将有标签样本与有伪标签样本用于本实例中的伪监督对比学习。其中,有伪标签样本属于无标签样本表示为实质上该有伪标签样本中包括无标签样本的图像以及模型对该无标签样本的图像预测输出的伪标签。在伪监督对比学习的步骤中,可以对上述伪监督对比学习样本进行数据增强,以优化学习效果。对于任意的伪监督对比学习样本表示其用于对比学习的特征表示向量,其中,β(xξ)表示对样本xξ采用数据增强方式β执行增强后的结果。应理解的是,也可以不执行数据增强,直接应用样本图像和对应的标签执行伪监督对比学习,那么特征向量可以表示为
S23:根据各个伪监督对比学习样本对应的多维特征向量确定伪监督对比学习损失。
上述伪监督对比学习损失表达了图像分类模型在特征提取维度上区分类内样本与类间样本的能力。本实例中确定的伪监督对比学习损失可以表示为下式(1-2):
其中,表示属于相同类的样本,即正样本。此外,<·>表示内积计算,τ表示一个标量的“温度”超参数(在本实例中固定为0.1)用于调节权重。该公式可以直观解释为,同类(将伪标签视为大概率同类)样本的特征表示向量在特征空间中应尽可能相似,不同类样本的特征表示向量在特征空间中应尽可能相异。举例而言,正样本对的特征表示向量在特征空间中尽可能相似,负样本对的特征表示向量在特征空间中尽可能相异。
基于上述实施例提供的方案,可选的,所述无标签样本还包括无伪标签样本,如图3所示,所述方法还包括:
S31:将所述无伪标签样本作为无监督对比学习样本输入所述图像分类模型,以确定无监督对比学习损失,所述无监督对比学习损失根据所述图像分类模型的特征提取网络对所述无伪标签样本提取的特征确定。
其中,上述无伪标签样本是指包含样本图像但不具有相对应的真实标签也不具有模型预测的伪标签的样本。换言之,这些无伪标签样本是未确定分类结果的样本。
在本步骤中,将无伪标签样本作为无监督对比学习样本进行无监督对比学习,其中应用了图像分类模型的特征提取网络提取到的特征执行对比学习。在无监督对比学习的过程中,使同一样本应用数据增强后得到的不同图像提取出的特征更相似,使不相同样本应用数据增强后得到的图像提取出的特征更相异,从而在特征提取维度上实现优化学习。
其中,上述步骤S13,包括:
S32:根据所述伪监督对比学习损失及所述无监督对比学习损失优化所述图像分类模型。
上述伪监督对比学习损失以及无监督对比学习损失都是基于图像分类模型的特征提取网络提取的特征确定的,都能从特征维度上表达提取到的特征与期望特征分布之间的差距。本步骤中,基于这两种对比学习损失对图像分类模型进行优化,能使图像分类模型提取的特征在特征空间中更接近于期望的特征分布,从而优化图像分类模型的分类结果。
可选的,在优化图像分类模型的步骤中,还可以基于预设权重结合上述两种对比学习损失执行模型优化。比如说,损失函数可以为其中的λpsc和λuc可以是分别对应于伪监督对比学习损失和无监督对比学习损失的权重。在实际应用中,可以根据伪监督对比学习样本数量或质量、无监督对比学习样本数量或质量来确定上述权重,以优化模型训练效果。
可选的,在本步骤中,可以先基于伪监督对比学习损失对图像分类模型执行优化之后,再执行本实例步骤S31和S32。基于伪监督对比学习损失对图像分类模型进行优化,能够在特征维度上对模型实现一定程度的优化。在基于伪监督对比学习损失进行优化后,模型在特征表达上更准确,此时再执行无监督对比学习的步骤S31,就能对无监督对比学习样本提取出更准确的特征,使无监督对比学习的效果更好。
基于上述实施例提供的方案,可选的,如图4所示,上述步骤S21,包括:
S41:将所述无伪标签样本作为无监督对比学习样本输入所述图像分类模型,以得到所述特征提取网络对所述无监督对比学习样本提取的特征数据。
在本实例中,将无伪标签样本用于伪监督对比学习,并将学习目标表示为无监督对比学习损失。本步骤中使用图像分类模型的特征提取网络对上述无伪标签样本提取到的特征数据进行对比表示学习,得到的无监督对比学习损失能表征特征维度的损失。其中,特征数据可以包括样本在多个特征维度上的特征值,特征值表征了样本在相应维度上的特征。
S42:根据所述特征提取网络提取的特征数据确定所述无监督对比学习样本中每个样本分别对应的多维特征向量。
上述特征向量的维度数量可以与上述特征数据中的特征值的数量一致,多维特征向量表达了图像在多个维度上的特征。
具体地,在本步骤中,无伪标签样本实质属于无标签样本,实质上该无伪标签样本中包括样本图像且不具有标签或伪标签。本步骤中的无伪标签样本是指无标签样本中有伪标签样本以外的其他无标签样本,可以表示为 对于任意的样本特征表示向量可以表示为
S43:根据各个无监督对比学习样本对应的多维特征向量确定无监督对比学习损失。
本实施例中的无监督对比学习样本为进行数据增强后的无监督对比学习样本。
本步骤中确定的无监督对比学习损失可以表示为下式(1-3):
其中,i′表示正样本。此外,τ′表示另一个标量的“温度”超参数(本实施例中固定为0.1)用于调节权重。该公式可以直观解释为,同一样本的随机增广结果的特征表示向量(正样本)在特征空间中应尽可能相似,某样本的随机增广结果的特征表示向量与其他样本的随机增广结果的特征表示向量(负样本)在特征空间中应尽可能相异。
基于上述实施例提供的方案,能对图像分类模型实现优化。可选的,还可以将上述无标签样本输入优化后的图像分类模型,进而根据优化后的图像分类模型输出的伪标签确定无标签样本中的有伪标签样本和无伪标签样本。随后,再通过本申请实施例提供的各个步骤再次对上述优化后的图像分类模型执行优化,从而实现迭代优化,使图像分类模型的预测结果更加接近于真实标签,直至符合应用需求。
可选的,基于上述实施例提供的方案,如图5所示,在上述步骤S11之前,还包括:
S51:获取多个无标签样本。
本步骤中的无标签样本是与所述图像分类模型相匹配的无标签样本,该无标签样本中包括样本图像,且该样本图像不具有相对应的标签或伪标签。
S52:将多个所述无标签样本分别输入所述图像分类模型,以得到所述图像分类模型的分类器输出的预测分类结果,所述预测分类结果包括与各个所述无标签样本对应的分类概率。
在本步骤中,将无标签样本输入图像分类模型,可以得到图像分类模型对每个无标签样本预测出的属于各个分类的概率,即上述预测分类结果。预测分类结果中的分类概率表示无标签样本属于对应分类的概率。对于一个无标签样本,图像分类模型输出的预测分类结果中可能包括多个分类以及与各个分类相对应的分类概率。
S53:将所述预测分类结果中概率大于或等于预设概率的无标签样本对应的分类确定为伪标签,以生成所述有伪标签样本。
基于上述步骤得到的预测分类结果,本步骤中将分类概率高的分类确定为无标签样本对应的伪标签。在实际应用中,对于一个无标签样本,图像分类模型可能输出多个分类及其对应的分类概率。
举例而言,预测分类结果中包括无标签样本A和无标签样本B的分类概率。其中,对于无标签样本A输出的分类概率为:a类--2%;b类--86%;c类--12%。对于无标签样本B输出的分类及其对应的分类概率为:a类--40%;b类--35%;c类--25%。假设预设概率为75%,那么对于上述无标签样本A,由于b分类概率86%大于上述预设概率75%,所以将b分类确定为该无标签样本A的伪标签,生成伪标签样本。而对于上述无标签样本B,由于各分类概率都小于上述预设概率75%,所以该无标签样本B在本步骤中未确定出伪标签,属于无伪标签样本。
通过本申请实施例提供的方案,基于图像分类模型的预测结果和预设概率,为一部分无标签样本确定相对应的伪标签,充分利用无标签样本,将无标签样本分为有伪标签样本和无伪标签样本,从而在后续步骤中为确定损失函数提供更多数据支持。
基于上述实施例提供的方案,可选的,所述无伪标签样本包括所述预测分类结果中概率小于所述预设概率的无标签样本。
基于上述实例,如果无标签样本的各分类概率都小于预设概率,则表明现有的图像分类模型不能准确地确定出该无标签样本所属的分类,因此,可以将不能确定分类的无标签样本确定为无伪标签样本,用于在后续步骤中为确定损失函数提供数据支持。
需要说明的是,大部分无标签样本在训练的早期并不能获得足够自信的伪标签。但如果迭代执行本方案,随着迭代次数的增加,伪标签的数量会逐渐增多。因此,随着模型迭代训练,伪标签置信度将发生变化,本发明中伪监督对比学习和无监督对比学习的样本比例也会自动改变,能够实现模型的迭代优化。
基于上述实施例提供的方案,可选的,如图6所示,在上述步骤S52之前,还包括:
S61:对至少部分所述无标签样本进行强增强,以得到强增强样本。
在本步骤中,可以对无标签样本中样本图像分别进行弱增强与强增强,其中进行强增强后的无标签样本为强增强样本。另外,可以将进行弱增强后的无标签样本确定为弱增强样本。
上述弱增强和强增强是指数据增强,基于增强后的样本能够实现对模型提取特征能力的优化。具体而言,弱增强是指对样本图像执行平移、翻转等较小幅度的变换。通常而言,弱增强后的样本虽然与增强前的样本有一定区别,但依然比较容易辨认,即容易提取出特征并确定分类结果。而强增强是指对样本进行遮挡、亮度或明暗度改变、变色等较大幅度的变换。对模型而言,经过强增强后的样本与增强前的样本相比,更难以提取特征,进而更难以确定分类结果。通过本方案中所述的强增强、弱增强等数据增强方式,能够对样本图像进行变更,一定程度上提高模型对样本提取特征和分类的能力,进而使得模型通过训练能更准确地提取出特征并更准确地进行分类。
其中,上述步骤S52,包括:
S62:将所述强增强样本输入所述图像分类模型,以得到所述图像分类模型的分类器输出的强增强样本的预测分类结果。
在本步骤中,可以将上述强增强样本输入图像分类模型,分类器输出的预测分类结果可以表示为其中,表示无标签样本uξ进行强增强后的结果。可选的,上述弱增强样本也可以输入图像分类模型,得到分类器输出的预测分类结果表示为pm(y∣fΘ(α(uξ)))。
本实施例确定的预测分类结果能进一步用于确定无标签样本对应的分类概率,提高确定的分类概率准确性。
基于上述实施例提供的方案,可选的,如图7所示,在上述步骤S53之后,还包括:
S71:确定所述伪标签与所述强增强样本的预测分类结果的一致性损失。
举例而言,令qξ=pm(y∣fΘ(α(uξ))),挑选上述S32确定的预测分类结果中最大类的概率大于或等于预设定阈值γ的样本,保留其最大概率类别作为伪标签。再将伪标签与上述预测分类结果计算一致性损失本步骤具体可以采用交叉熵一致性损失可以表示为下式(1-4):
本步骤中确定的一致性损失用于确定损失函数,其中的一致性是指不经过强增强的样本的预测分类结果应与经过强增强的样本的预测分类结果相一致,一致性损失是指这增强后与增强前的样本的预测分类结果之间的差异。
S72:根据所述伪监督对比学习损失及所述一致性损失优化所述图像分类模型。
由于上述一致性损失能够表征分类器的损失,因此,本方案中用于优化模型的损失函数能表征特征维度上的损失以及分类结果上的损失。其中,特征维度上的损失由上述实施例所述的伪监督对比学习损失表征,分类结果上的损失由本实施例中所述的一致性损失表征。所以,通过本申请实施例提供的方案,能使损失函数从特征维度和分类结果维度更准确地表征损失,从而以降低损失函数为优化目标调整图像分类模型的参数能进一步提高模型优化训练的效果。
基于上述实施例提供的方案,可选的,如图8所示,在上述步骤S11之前,还包括:
S81:获取多个有标签图像样本。
在本步骤中,有标签图像样本包括样本图像和相对应的单分类标签,该单分类标签表征对应的样本图像所属的分类,可以表示为yξ。
S82:对至少部分所述有标签图像样本进行弱增强,以得到弱增强样本。
本步骤中,对有标签图像样本xξ执行弱增强,弱增强样本可以表示为α(xξ)。
S83:根据所述弱增强样本训练所述图像分类模型。
通过本申请实施例提供的方案,能利用具有单分类标签的有标签图像样本对模型进行训练,通过数据增强提升训练图像分类模型的效果。模型通过弱增强后的图像样本进行优化训练,能一定程度上提升模型预测结果的质量,为之后基于多种对比学习损失进一步优化提供基础,使模型阶梯性逐步优化,提升优化效果。
基于上述实施例提供的方案,可选的,如图9所示,上述步骤S83,包括:
S91:根据交叉熵损失函数确定所述弱增强样本的交叉熵损失。
其中,α(xξ)表示有标签样本xξ进行弱增强后的结果,yξ表示单分类标签。
S92:根据所述弱增强样本和所述交叉熵损失训练所述图像分类模型。
通过本申请实施例提供的方案,基于数据增强有效提升训练图像分类模型的效果。
基于上述实施例提供的方案,可选的,如图10所示,上述步骤S13,包括:
S101:根据所述一致性损失、所述伪监督对比学习损失和所述无监督对比学习损失确定损失函数其中,λu、λpsc与λuc分别为所述一致性损失、所述伪监督对比学习损失和所述无监督对比学习损失所占权重,为所述弱增强样本的交叉熵损失,为所述一致性损失,为所述伪监督对比学习损失,为所述无监督对比学习损失;
S102:根据所述损失函数优化训练所述图像分类模型。
本步骤中,以降低所述损失函数为优化目标,调整所述图像分类模型的参数,以得到优化后的图像分类模型。可选的,具体可以采用梯度下降法或其他方法执行优化,本申请实施例提供的方案综合上述一致性损失、所述伪监督对比学习损失和无监督对比学习损失共同确定损失函数,上述各损失所占权重可以根据实际需求预先设定。通过本方案,确定的损失函数能从特征维度以及分类结果维度表征损失,有效丰富监督信息。进而根据本实施例确定的损失优化图像分类模型,能从特征维度和分类结果维度实现模型优化,进一步提升模型优化训练效果。
综合上述多种实施例提供的方案,本申请方案流程示意图如图11所示,本申请实施例提供的方案中,利用了特征提取网络fΘ得到的表示特征进行对比表示学习扩充半监督训练中的特征获取。除了利用样本标签的一致性信息外,还能学习到样本在特征维度的额外辨别性信息,从而丰富监督信息,提高优化训练模型的效果。
本申请实施例提供的上述方案可以用于解决半监督图像分类问题,具体地说,半监督图像数据集包括有标签训练数据、无标签训练数据和测试集数据。应用本方案后,在包含有标签和无标签的训练数据集上对模型进行训练,相比仅在有标签训练数据集上采用传统有监督训练方法得到的结果,本发明得到的结果准确率更高。
之所以本发明得到的模型在测试数据集上准确率更高,这是由于本发明中合理引入了有标签和无标签样本在特征维度上的对比表示信息,这有利于模型在学习一致性信息外,学习额外有辨识力的对比表示特征。
为了解决现有技术中存在的问题,本申请实施例还提供一种模型优化装置120,如图12所示,包括:
获取模块121,获取图像分类模型的训练样本,所述样本包括有标签样本和无标签样本,所述无标签样本包括有伪标签样本;
确定模块122,将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以确定伪监督对比学习损失,所述伪监督对比学习损失根据所述图像分类模型的特征提取网络对所述有标签样本和所述有伪标签样本提取的特征确定;
优化模块123,根据所述伪监督对比学习损失优化所述图像分类模型。
通过本申请实施例提供的装置,获取图像分类模型的训练样本,训练样本包括有标签样本和无标签样本,无标签样本包括有伪标签样本;将有标签样本和有伪标签样本作为伪监督对比学习样本输入图像分类模型,以确定伪监督对比学习损失,伪监督对比学习损失根据图像分类模型的特征提取网络对有标签样本和有伪标签样本提取的特征确定;根据伪监督对比学习损失优化图像分类模型。本方案利用伪标签进行对比表示学习,能从特征提取网络提取的特征维度优化损失函数,使损失函数更准确地表达模型损失,从而有效丰富监督信号,提高模型优化训练的效果。
其中,本申请实施例提供的装置中的上述模块还可以实现上述方法实施例提供的方法步骤。或者,本申请实施例提供的装置还可以包括除上述模块以外的其他模块,用以实现上述方法实施例提供的方法步骤。且本申请实施例提供的装置能够实现上述方法实施例所能达到的技术效果。
优选的,本发明实施例还提供一种电子设备,包括处理器,存储器,存储在存储器上并可在所述处理器上运行的计算机程序,该计算机程序被处理器执行时实现上述一种模型优化方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。
本发明实施例还提供一种计算机可读存储介质,计算机可读存储介质上存储有计算机程序,该计算机程序被处理器执行时实现上述一种模型优化方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。其中,所述的计算机可读存储介质,如只读存储器(Read-Only Memory,简称ROM)、随机存取存储器(Random Access Memory,简称RAM)、磁碟或者光盘等。
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明是参照根据本发明实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
在一个典型的配置中,计算设备包括一个或多个处理器(CPU)、输入/输出接口、网络接口和内存。
内存可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM)。内存是计算机可读介质的示例。
计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括暂存电脑可读媒体(transitory media),如调制的数据信号和载波。
还需要说明的是,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、商品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、商品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、商品或者设备中还存在另外的相同要素。
本领域技术人员应明白,本申请的实施例可提供为方法、系统或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
以上所述仅为本申请的实施例而已,并不用于限制本申请。对于本领域技术人员来说,本申请可以有各种更改和变化。凡在本申请的精神和原理之内所作的任何修改、等同替换、改进等,均应包含在本申请的权利要求范围之内。
Claims (14)
1.一种模型优化方法,其特征在于,包括:
获取图像分类模型的训练样本,所述训练样本包括有标签样本和无标签样本,所述无标签样本包括有伪标签样本;
将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以确定伪监督对比学习损失,所述伪监督对比学习损失根据所述图像分类模型的特征提取网络对所述有标签样本和所述有伪标签样本提取的特征确定;
根据所述伪监督对比学习损失优化所述图像分类模型。
2.如权利要求1所述的方法,其特征在于,所述无标签样本还包括无伪标签样本,所述方法还包括:
将所述无伪标签样本作为无监督对比学习样本输入所述图像分类模型,以确定无监督对比学习损失,所述无监督对比学习损失根据所述图像分类模型的特征提取网络对所述无伪标签样本提取的特征确定;
其中,根据所述伪监督对比学习损失优化所述图像分类模型包括:
根据所述伪监督对比学习损失及所述无监督对比学习损失优化所述图像分类模型。
3.如权利要求1所述的方法,其特征在于,将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以确定伪监督对比学习损失,包括:
将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以得到所述特征提取网络对所述伪监督对比学习样本提取的特征数据;
根据所述特征提取网络提取的特征数据确定所述伪监督对比学习样本中每个样本分别对应的多维特征向量;
根据各个伪监督对比学习样本对应的多维特征向量确定伪监督对比学习损失。
4.如权利要求2所述的方法,其特征在于,将所述无伪标签样本作为无监督对比学习样本输入所述图像分类模型,以确定无监督对比学习损失,包括:
将所述无伪标签样本作为无监督对比学习样本输入所述图像分类模型,以得到所述特征提取网络对所述无监督对比学习样本提取的特征数据;
根据所述特征提取网络提取的特征数据确定所述无监督对比学习样本中每个样本分别对应的多维特征向量;
根据各个无监督对比学习样本对应的多维特征向量确定无监督对比学习损失。
5.如权利要求4所述的方法,其特征在于,在获取图像分类模型的训练样本之前,还包括:
获取多个无标签样本;
将多个所述无标签样本输入所述图像分类模型,以得到所述图像分类模型的分类器输出的预测分类结果,所述预测分类结果包括与各个所述无标签样本对应的分类概率;
将所述预测分类结果中概率大于或等于预设概率的无标签样本对应的分类确定为伪标签,以生成所述有伪标签样本。
6.如权利要求5所述的方法,其特征在于,所述无伪标签样本包括所述预测分类结果中概率小于所述预设概率的无标签样本。
7.如权利要求5所述的方法,其特征在于,在将多个所述无标签样本输入所述图像分类模型,以得到所述图像分类模型的分类器输出的预测分类结果之前,还包括:
对至少部分所述无标签样本进行强增强,以得到强增强样本;
其中,将多个所述无标签样本输入所述图像分类模型,以得到所述图像分类模型的分类器输出的预测分类结果,包括:
将所述强增强样本输入所述图像分类模型,以得到所述图像分类模型的分类器输出的强增强样本的预测分类结果。
8.如权利要求7所述的方法,其特征在于,将所述预测分类结果中概率大于或等于预设概率的无标签样本对应的分类确定为伪标签,以生成所述有伪标签样本之后,还包括:
确定所述伪标签与所述强增强样本的预测分类结果的一致性损失;
其中,根据所述伪监督对比学习损失优化所述图像分类模型,包括:
根据所述伪监督对比学习损失及所述一致性损失优化所述图像分类模型。
9.如权利要求8所述的方法,其特征在于,在获取图像分类模型的训练样本之前,还包括:
获取多个有标签图像样本;
对至少部分所述有标签图像样本进行弱增强,以得到弱增强样本;
根据所述弱增强样本训练所述图像分类模型。
10.如权利要求9所述的方法,其特征在于,根据所述弱增强样本训练所述图像分类模型,包括:
根据交叉熵损失函数确定所述弱增强样本的交叉熵损失;
根据所述弱增强样本和所述交叉熵损失训练所述图像分类模型。
12.一种模型优化装置,其特征在于,包括:
获取模块,获取图像分类模型的训练样本,所述样本包括有标签样本和无标签样本,所述无标签样本包括有伪标签样本;
确定模块,将所述有标签样本和所述有伪标签样本作为伪监督对比学习样本输入所述图像分类模型,以确定伪监督对比学习损失,所述伪监督对比学习损失根据所述图像分类模型的特征提取网络对所述有标签样本和所述有伪标签样本提取的特征确定;
优化模块,根据所述伪监督对比学习损失优化所述图像分类模型。
13.一种电子设备,其特征在于,包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现如权利要求1至11中任一项所述的方法的步骤。
14.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至11中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111546892.1A CN114037876A (zh) | 2021-12-16 | 2021-12-16 | 一种模型优化方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111546892.1A CN114037876A (zh) | 2021-12-16 | 2021-12-16 | 一种模型优化方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114037876A true CN114037876A (zh) | 2022-02-11 |
Family
ID=80147018
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111546892.1A Pending CN114037876A (zh) | 2021-12-16 | 2021-12-16 | 一种模型优化方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114037876A (zh) |
Cited By (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114565972A (zh) * | 2022-02-23 | 2022-05-31 | 中国科学技术大学 | 骨架动作识别方法、系统、设备与存储介质 |
CN115187787A (zh) * | 2022-09-09 | 2022-10-14 | 清华大学 | 用于自监督多视图表征学习的局部流形增强的方法及装置 |
CN115471717A (zh) * | 2022-09-20 | 2022-12-13 | 北京百度网讯科技有限公司 | 模型的半监督训练、分类方法装置、设备、介质及产品 |
CN115482418A (zh) * | 2022-10-09 | 2022-12-16 | 宁波大学 | 基于伪负标签的半监督模型训练方法、系统及应用 |
CN115496955A (zh) * | 2022-11-18 | 2022-12-20 | 之江实验室 | 图像分类模型训练方法、图像分类方法、设备和介质 |
CN116206164A (zh) * | 2023-05-06 | 2023-06-02 | 之江实验室 | 基于半监督对比学习的多相期ct分类系统及构建方法 |
CN116527399A (zh) * | 2023-06-25 | 2023-08-01 | 北京金睛云华科技有限公司 | 基于不可靠伪标签半监督学习的恶意流量分类方法和设备 |
WO2023221634A1 (zh) * | 2022-05-19 | 2023-11-23 | 腾讯科技(深圳)有限公司 | 视频检测方法、装置、设备、存储介质和程序产品 |
CN117154716A (zh) * | 2023-09-08 | 2023-12-01 | 国网河南省电力公司 | 一种分布式电源接入配电网的规划方法及系统 |
WO2024022376A1 (zh) * | 2022-07-29 | 2024-02-01 | 马上消费金融股份有限公司 | 图像处理方法、装置、设备和介质 |
CN117809087A (zh) * | 2023-12-20 | 2024-04-02 | 北京医准医疗科技有限公司 | 一种乳腺图像分类模型的训练方法、装置及电子设备 |
-
2021
- 2021-12-16 CN CN202111546892.1A patent/CN114037876A/zh active Pending
Cited By (18)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114565972B (zh) * | 2022-02-23 | 2024-04-02 | 中国科学技术大学 | 骨架动作识别方法、系统、设备与存储介质 |
CN114565972A (zh) * | 2022-02-23 | 2022-05-31 | 中国科学技术大学 | 骨架动作识别方法、系统、设备与存储介质 |
WO2023221634A1 (zh) * | 2022-05-19 | 2023-11-23 | 腾讯科技(深圳)有限公司 | 视频检测方法、装置、设备、存储介质和程序产品 |
WO2024022376A1 (zh) * | 2022-07-29 | 2024-02-01 | 马上消费金融股份有限公司 | 图像处理方法、装置、设备和介质 |
CN115187787A (zh) * | 2022-09-09 | 2022-10-14 | 清华大学 | 用于自监督多视图表征学习的局部流形增强的方法及装置 |
CN115187787B (zh) * | 2022-09-09 | 2023-01-31 | 清华大学 | 用于自监督多视图表征学习的局部流形增强的方法及装置 |
CN115471717A (zh) * | 2022-09-20 | 2022-12-13 | 北京百度网讯科技有限公司 | 模型的半监督训练、分类方法装置、设备、介质及产品 |
CN115482418B (zh) * | 2022-10-09 | 2024-06-07 | 北京呈创科技股份有限公司 | 基于伪负标签的半监督模型训练方法、系统及应用 |
CN115482418A (zh) * | 2022-10-09 | 2022-12-16 | 宁波大学 | 基于伪负标签的半监督模型训练方法、系统及应用 |
CN115496955B (zh) * | 2022-11-18 | 2023-03-24 | 之江实验室 | 图像分类模型训练方法、图像分类方法、设备和介质 |
CN115496955A (zh) * | 2022-11-18 | 2022-12-20 | 之江实验室 | 图像分类模型训练方法、图像分类方法、设备和介质 |
CN116206164B (zh) * | 2023-05-06 | 2023-08-18 | 之江实验室 | 基于半监督对比学习的多相期ct分类系统及构建方法 |
CN116206164A (zh) * | 2023-05-06 | 2023-06-02 | 之江实验室 | 基于半监督对比学习的多相期ct分类系统及构建方法 |
CN116527399A (zh) * | 2023-06-25 | 2023-08-01 | 北京金睛云华科技有限公司 | 基于不可靠伪标签半监督学习的恶意流量分类方法和设备 |
CN116527399B (zh) * | 2023-06-25 | 2023-09-26 | 北京金睛云华科技有限公司 | 基于不可靠伪标签半监督学习的恶意流量分类方法和设备 |
CN117154716A (zh) * | 2023-09-08 | 2023-12-01 | 国网河南省电力公司 | 一种分布式电源接入配电网的规划方法及系统 |
CN117154716B (zh) * | 2023-09-08 | 2024-04-26 | 国网河南省电力公司 | 一种分布式电源接入配电网的规划方法及系统 |
CN117809087A (zh) * | 2023-12-20 | 2024-04-02 | 北京医准医疗科技有限公司 | 一种乳腺图像分类模型的训练方法、装置及电子设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114037876A (zh) | 一种模型优化方法和装置 | |
CN110348580B (zh) | 构建gbdt模型的方法、装置及预测方法、装置 | |
CN113837370B (zh) | 用于训练基于对比学习的模型的方法和装置 | |
CN109063743B (zh) | 基于半监督多任务学习的医疗数据分类模型的构建方法 | |
Ma et al. | Lightweight attention convolutional neural network through network slimming for robust facial expression recognition | |
CN108629373B (zh) | 一种图像分类方法、系统、设备及计算机可读存储介质 | |
US11494689B2 (en) | Method and device for improved classification | |
Zhang et al. | Contrastive deep supervision | |
CN108629358B (zh) | 对象类别的预测方法及装置 | |
CN111178533B (zh) | 实现自动半监督机器学习的方法及装置 | |
CN110991500A (zh) | 一种基于嵌套式集成深度支持向量机的小样本多分类方法 | |
Peng et al. | Leaf disease image retrieval with object detection and deep metric learning | |
Cheng et al. | Multi-label few-shot learning for sound event recognition | |
CN115270797A (zh) | 一种基于自训练半监督学习的文本实体抽取方法及系统 | |
Myers-Dean et al. | Generalized few-shot semantic segmentation: All you need is fine-tuning | |
CN116910571B (zh) | 一种基于原型对比学习的开集域适应方法及系统 | |
Guo et al. | Vehicle detection based on superpixel and improved hog in aerial images | |
McNeely-White et al. | Inception and ResNet: Same training, same features | |
CN107895167A (zh) | 一种基于稀疏表示的红外光谱数据分类识别方法 | |
Ouni et al. | A new cbir model using semantic segmentation and fast spatial binary encoding | |
CN116503674B (zh) | 一种基于语义指导的小样本图像分类方法、装置及介质 | |
Yang et al. | Computing object-based saliency via locality-constrained linear coding and conditional random fields | |
Li et al. | Rapid and high-purity seed grading based on pruned deep convolutional neural network | |
Gao et al. | A Chinese dish detector with modified YOLO v3 | |
La Grassa et al. | Learning to Navigate in the Gaussian Mixture Surface |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |