CN111242222B - 分类模型的训练方法、图像处理方法及装置 - Google Patents

分类模型的训练方法、图像处理方法及装置 Download PDF

Info

Publication number
CN111242222B
CN111242222B CN202010040821.3A CN202010040821A CN111242222B CN 111242222 B CN111242222 B CN 111242222B CN 202010040821 A CN202010040821 A CN 202010040821A CN 111242222 B CN111242222 B CN 111242222B
Authority
CN
China
Prior art keywords
prediction result
classification model
classification
fully
confidence coefficient
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010040821.3A
Other languages
English (en)
Other versions
CN111242222A (zh
Inventor
张有才
常杰
危夷晨
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Megvii Technology Co Ltd
Original Assignee
Beijing Megvii Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Megvii Technology Co Ltd filed Critical Beijing Megvii Technology Co Ltd
Priority to CN202010040821.3A priority Critical patent/CN111242222B/zh
Publication of CN111242222A publication Critical patent/CN111242222A/zh
Application granted granted Critical
Publication of CN111242222B publication Critical patent/CN111242222B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Biophysics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明提供了一种分类模型的训练方法、图像处理方法及装置,分类模型包括特征提取网络和全连接网络,该方法包括:将样本图像输入至分类模型,得到样本图像对应的特征图;将特征图输入至全连接网络,得到分类模型输出的预测结果和所述预测结果对应的置信度;根据所述预测结果和所述置信度,确定目标损失值;根据目标损失值,更新分类模型的参数。本发明可以使训练后的分类模型输出预测结果的置信度。

Description

分类模型的训练方法、图像处理方法及装置
技术领域
本发明涉及神经网络技术领域,尤其是涉及一种分类模型的训练方法、图像处理方法及装置。
背景技术
神经网络是一种具有自学和自适应能力的机器学习模型,通过利用训练集对神经网络进行训练,可以使神经网络完成指定的任务,例如实现人脸识别或目标检测等任务,目前将图像或视频等输入至训练后的神经网络,可以得到训练后的神经网络针对输入的图像或视频输出的预测结果,但是用户无法获知该预测结果的可信程度。
发明内容
有鉴于此,本发明的目的在于提供一种分类模型的训练方法、图像处理方法及装置,可以使训练后的分类模型输出结果的置信度。
为了实现上述目的,本发明实施例采用的技术方案如下:
第一方面,本发明实施例提供了一种分类模型的训练方法,所述分类模型包括特征提取网络和全连接网络,所述方法包括:将样本图像输入至所述分类模型,得到所述样本图像对应的特征图;将所述特征图输入至所述全连接网络,得到所述分类模型输出的预测结果和所述预测结果对应的置信度;根据所述预测结果和所述置信度,确定目标损失值;根据所述目标损失值更新所述分类模型的参数。
在一种实施方式中,所述全连接网络包括与所述特征提取网络连接的第一全连接子网络和第二全连接子网络;所述第一全连接子网络和第二全连接子网络包括多个全连接层;将所述特征图输入至所述全连接网络,得到所述分类模型输出的预测结果和所述预测结果对应的置信度的步骤,包括:将所述特征图输入至所述第一全连接子网络,确定所述分类模型对所述样本图像的所述预测结果;将所述特征图输入至所述第二全连接子网络,通过所述第二全连接子网络,确定所述预测结果所对应的置信度。
在一种实施方式中,所述方法还包括:根据分类损失函数和正则化项确定所述目标损失函数,其中,所述正则化项用于表征采样特征与预设分布函数之间的相似度,所述采样特征通过对所述预测结果和所述置信度进行加权求和确定;根据所述预测结果、所述置信度和所述目标损失函数,确定目标损失值。
在一种实施方式中,所述方法还包括:在所述预设分布函数中进行随机采样,得到第一系数;基于所述第一系数,对所述预测结果和所述置信度进行加权求和,得到所述采样特征。
在一种实施方式中,所述正则化项为KL散度函数,根据所述预测结果和所述置信度计算得到。
在一种实施方式中,所述方法还包括:根据所述预测结果确定所述分类损失函数的第一分量;根据所述采样特征确定所述分类损失函数的第二分量;根据所述第一分量和/或所述第二分量,确定所述分类损失函数。
在一种实施方式中,所述特征图对应于所述样本图像的特征分布。
在一种实施方式中,将所述特征图输入至所述全连接网络,得到所述分类模型输出的预测结果,包括:通过所述全连接网络对所述样本图像的特征分布进行处理,得到所述样本图像的特征分布对应的预测分类分布;将所述预测分类分布的均值确定为所述预测结果。
在一种实施方式中,根据所述预测结果的方差,确定所述置信度。
第二方面,本发明实施例还提供一种图像处理方法,包括:获取待处理图像;通过预设分类模型对待处理图像进行处理,得到所述待处理图像对应的预测结果和所述预测结果对应的置信度,其中,所述预设分类模型为采用如第一方面提供的任一项所述的分类模型的训练方法训练得到。
在一种实施方式中,所述方法还包括:当所述预测结果对应的置信度低于预设置信度阈值,输出所述预测结果对应的提示消息。
第三方面,本发明实施例还提供一种分类模型的训练装置,所述分类模型包括特征提取网络和全连接网络,所述装置包括:特征提取模块,用于将样本图像输入至所述分类模型,得到所述样本图像对应的特征图;输出模块,用于将所述特征图输入至所述全连接网络,得到所述分类模型输出的预测结果和所述预测结果对应的置信度;损失计算模块,用于根据所述预测结果和所述置信度,确定目标损失值;训练模块,用于根据目标损失值,更新所述分类模型的参数。
第四方面,本发明实施例提供了一种图像处理装置,包括:图像获取模块,用于获取待处理图像;图像处理模块,用于通过预设分类模型对待处理图像进行处理,得到所述待处理图像对应的预测结果和所述预测结果对应的置信度,其中,所述预设分类模型为采用如第一方面提供的任一项所述的分类模型的训练方法训练得到。
第五方面,本发明实施例还提供一种电子设备,包括处理器和存储器;所述存储器上存储有计算机程序,所述计算机程序在被所述处理器运行时执行如第一方面提供的任一项所述的方法,或执行如第二方面提供的任一项所述的方法。
第六方面,本发明实施例提供了一种计算机存储介质,用于储存为第一方面提供的任一项所述方法所用的计算机软件指令,或存储为第二方面提供的任一项所述方法所用的计算机软件指令。
本发明实施例提供了一种分类模型的训练方法及装置,其中,分类模型包括特征提取网络和全连接网络,该方法首先将样本图像输入至分类模型得到特征图,再将特征图输入至全连接网络得到分类模型输出的预测结果和预测结果对应的置信度,然后基于预测结果和置信度计算分类模型的目标损失值,进而可以利用目标损失值对分类模型进行训练。上述方法利用全连接网络得到分类模型输出的预测结果和预测结果对应的置信度,并基于预测结果和置信度计算得到的目标损失值对分类模型进行训练,不仅可以使训练后的分类模型输出准确率较高的预测结果,还可以使训练后的分类模型输出的预测结果携带有置信度,从而通过携带的置信度体现预测结果的可信程度。
本发明实施例提供了一种图像处理方法及装置,首先获取待处理图像,并通过上述分类模型的训练方法训练得到的分类模型对待处理图像进行处理,得到待处理图像对应的预测结果和预测结果对应的置信度。上述方法在对待处理图像进行处理后可以得到预测结果和其对应的置信度,相较于现有技术仅能获取图像处理的预测结果,本发明实施例可以利用分类模型直接得到预测结果的置信度,从而利用置信度对预测结果的可信程度进行评价,有利用用户根据分类模型输出的置信度作出是否采信或应用预测结果的决策。
本发明实施例的其他特征和优点将在随后的说明书中阐述,或者,部分特征和优点可以从说明书推知或毫无疑义地确定,或者通过实施本发明实施例的上述技术即可得知。
为使本发明的上述目的、特征和优点能更明显易懂,下文特举较佳实施例,并配合所附附图,作详细说明如下。
附图说明
为了更清楚地说明本发明具体实施方式或现有技术中的技术方案,下面将对具体实施方式或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施方式,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出了本发明实施例所提供的一种电子设备的结构示意图;
图2示出了本发明实施例所提供的一种分类模型的结构示意图;
图3示出了本发明实施例所提供的一种分类模型的训练方法的流程示意图;
图4示出了本发明实施例所提供的另一种分类模型的结构示意图;
图5示出了本发明实施例所提供的一种图像处理方法的流程示意图;
图6示出了本发明实施例所提供的一种分类模型的训练装置的结构示意图;
图7示出了本发明实施例所提供的一种图像处理装置的结构示意图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合附图对本发明的技术方案进行描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。
考虑到现有技术中的分类模型仅能输出结果而无法体现结果的可信程度,为改善此问题,本发明实施例提供了一种分类模型的训练方法、图像处理方法及装置,该技术可应用于需要对神经网络进行训练的场景,以下对本发明实施例进行详细介绍。
实施例一
首先,参照图1来描述用于实现本发明实施例的一种分类模型的训练方法、图像处理方法及装置的示例电子设备100。
如图1所示的一种电子设备的结构示意图,电子设备100包括一个或多个处理器102、一个或多个存储装置104、输入装置106、输出装置108以及图像采集装置110,这些组件通过总线系统112和/或其它形式的连接机构(未示出)互连。应当注意,图1所示的电子设备100的组件和结构只是示例性的,而非限制性的,根据需要,所述电子设备可以具有图1示出的部分组件,也可以具有图1未示出其他组件和结构。
所述处理器102可以采用数字信号处理器(DSP)、现场可编程门阵列(FPGA)、可编程逻辑阵列(PLA)中的至少一种硬件形式来实现,所述处理器102可以是中央处理单元(CPU)、图形处理单元(GPU)或者具有数据处理能力和/或指令执行能力的其它形式的处理单元中的一种或几种的组合,并且可以控制所述电子设备100中的其它组件以执行期望的功能。
所述存储装置104可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存等。在所述计算机可读存储介质上可以存储一个或多个计算机程序指令,处理器102可以运行所述程序指令,以实现下文所述的本发明实施例中(由处理器实现)的客户端功能以及/或者其它期望的功能。在所述计算机可读存储介质中还可以存储各种应用程序和各种数据,例如所述应用程序使用和/或产生的各种数据等。
所述输入装置106可以是用户用来输入指令的装置,并且可以包括键盘、鼠标、麦克风和触摸屏等中的一个或多个。
所述输出装置108可以向外部(例如,用户)输出各种信息(例如,图像或声音),并且可以包括显示器、扬声器等中的一个或多个。
所述图像采集装置110可以拍摄用户期望的图像(例如照片、视频等),并且将所拍摄的图像存储在所述存储装置104中以供其它组件使用。
示例性地,用于实现根据本发明实施例的分类模型的训练方法、图像处理方法及装置的示例电子设备可以被实现为诸如服务器、平板电脑、计算机等智能终端。
实施例二
本发明实施例提供了一种分类模型的训练方法,为便于对本发明实施例提供的分类模型的训练方法进行理解,其中,分类模型可以采用用于执行分类任务的神经网络模型,诸如CNN(Convolutional Neural Network,卷积神经网络)、VGG(Visual Geometry Group,视觉几何组)、ResNet(Residual Neural Network,残差网络)和GoogleNet网络模型等,本发明实施例对分类模型的结构不作限定。本发明实施例提供了一种分类模型的结构示意图,如图2所示,图2中示意出了分类模型包括特征提取网络、全连接网络和输出层,其中,特征提取网络用于提取样本图像的特征,得到样本图像对应的特征图,全连接网络用于基于特征图得到预测结果和预测结果对应的置信度,输出层用于输出预测结果和预测结果对应的置信度。在此基础上,本发明实施例提供了一种分类模型的训练方法,参见图3所示的一种分类模型的训练方法的流程示意图,该方法主要包括以下步骤S302至步骤S308:
步骤S302,将样本图像输入至分类模型,得到样本图像对应的特征图。
其中,样本图像可以为训练集中的任意图像,样本图像标注有标签,在实际应用中,可以根据分类模型需要执行的分类任务选择训练集,并将训练集中的图像作为样本图像,例如,分类模型用于图像分类,则可以将包含多种类别的图像作为样本图像,且样本图像中将标注有该样本图像所属的类别的标签,在一种实施方式中,可利用分类模型中的特征提取网络提取样本图像的特征,得到样本图像对应的特征图。
步骤S304,将特征图输入至全连接网络,得到分类模型输出的预测结果和预测结果对应的置信度。
其中,预测结果是样本图像的均值特征,特征图实质上对应的是样本图像的特征分布,而非确定值。在一种实施方式中,全连接网络可针对输入的特征图得到样本图像对应的预测分类分布,通过计算预测分类分布的均值即可得到的预测结果。置信度用于表征预测结果的可信程度,在一种实施方式中,对预测结果求取方差可以得到置信度。
步骤S306,根据预测结果和置信度,确定目标损失值。其中,目标损失值可以用于估量分类模型执行分类任务的性能。在一种实施方式中,可基于预测结果、置信度和目标损失函数中的正则化项计算正则化损失值,以及基于预测结果、置信度、标签和目标损失函数中的分类损失函数计算分类损失值,通过对正则化损失值和分类损失值进行加权求和,可以得到分类模型的目标损失值。其中,正则化项用于表征预测结果和置信度的加权和值与预设分布函数(诸如标准正态分布)之间的相似程度,分类损失函数用于表征的预测结果与真实值之间的差异程度。
步骤S308,根据目标损失值,更新分类模型的参数。在实际应用中,可通过诸如梯度下降等算法求解目标损失值的偏导数,进而利用偏导数调整分类模型中的参数,例如分类模型中特征提取网络的参数和全连接网络的参数,从而实现训练分类模型,并在分类模型达到收敛时停止训练,得到训练完成的分类模型。
本发明实施例提供的上述分类模型的训练方法,首先将样本图像输入至分类模型得到特征图,将特征图输出至全连接网络得到分类模型输出的预测结果和预测结果对应的置信度,并基于预测结果和置信度计算分类模型的目标损失值,进而可以利用目标损失值对分类模型进行训练。上述方法利用全连接网络得到分类模型输出的预测结果和预测结果对应的置信度,并基于预测结果和置信度计算得到的目标损失值对分类模型进行训练,不仅可以使训练后的分类模型输出准确率较高的预测结果,还可以使训练后的分类模型的输出的预测结果中携带有置信度,从而通过携带的置信度体现预测结果的可信程度。
为便于对前述实施例提供的预测结果和置信度进行理解,本发明实施例提供了一种将特征图输入至全连接网络,得到分类模型输出的预测结果的实施方式,其中,特征图对应于样本图像的特征分布,在实施时可以通过全连接网络对样本图像的特征分布进行处理,得到样本图像的特征分布对应的预测分类分布,并将预测分类分布的均值确定为预测结果。另外,本发明实施例还提供了一种将特征图输入至全连接网络,得到分类模型输出的预测结果对应的置信度的实施方式,可以根据预测结果的方差,确定置信度,也即,计算预测结果的方差并将计算得到的方差作为置信度。在此基础上,结合前述实施例提供的图2,本发明实施例提供了另一种分类模型,参见图4所示的另一种分类模型的结构示意图,图4中进一步示意出了全连接网络包括与特征提取网络连接的第一全连接子网络和第二全连接子网络,其中,第一全连接子网络用于基于特征图计算预测分类分布对应的均值,以得到预测结果,第二全连接子网络可以用于基于特征图计算预测结果的方差,以得到置信度,方差和均值可以用于表征特征图中特征的分布情况,在实际应用中,第一全连接子网络和第二全连接子网络均可以包括多个全连接层,且第一全连接子网络采用的全连接层与第二全连接子网络采用的全连接层的参数不同。基于图4所示的分类模型,本发明实施例提供了一种将特征图输入至全连接网络,得到所述分类模型输出的预测结果和预测结果对应的置信度的具体实施方式,(1)将特征图输入至第一全连接子网络,确定分类模型对样本图像的预测结果,如图4所示,第一全连接子网络采用相连的两个全连接层,其中,第一全连接子网络的输入为第i个样本图像xi的特征图fθ(xi),经两个全连接层利用fμ(fθ(xi))对特征图fθ(xi)进行处理后,输出均值μ,并将输出的均值确定为分类模型对样本图像的预测结果。(2)将特征图输入至第二全连接子网络,确定预测结果所对应的置信度,如图4所示,第二全连接子网络采用相连的两个全连接层,其中,第二全连接子网络的输入为第i个样本图像xi的特征图fθ(xi),经两个全连接层利用f(fθ(xi))对特征图fθ(xi)进行处理后,输出方差∑,并将输出的方差确定为预测结果所对应的置信度。
对于前述步骤S306,本发明实施例提供了一种根据预测结果和置信度确定目标损失值的实施方式,可根据分类损失函数和正则化项确定目标损失函数,然后根据预测结果、置信度和目标损失函数确定目标损失值。其中,正则化项用于表征采样特征与预设分布函数之间的相似度,预设分布函数可以采用诸如标准正太分布等,采样特征通过对预测结果和置信度进行加权求和确定。而在采样特征前,需要确定预测结果或置信度的系数,在一种具体的实施方式中,可以利用重采样技术确定预测结果或置信度的系数,(1)在预设分布函数中进行随机采样,得到第一系数。在具体实现时,可以先对标准正态分布N(0,1)进行随机采样,得到第一系数ε,其中ε∈N(0,1),且每次计算采样特征时均会对标准正态分布N(0,1)进行随机采样,以得到不同的第一系数ε。(2)基于第一系数,对预测结果和置信度进行加权求和,得到采样特征。其中,第一系数可以作为预测结果或置信度的权重,从而对预测结果和置信度进行加权求和,得到样本图像对应的采样特征,例如,将第一系数ε作为置信度的权重,得到采样特征g(μ,∑)=μ+ε*∑。本发明实施例通过利用重采样技巧对预测结果和置信度进行处理,可以将可训练部分(也即,前述全连接网络)与采样部分(也即,重采样过程)分离,从而实现将梯度反传至全连接网络,进而可以实现利用目标损失值对全连接网络进行训练。
为使正则化项能较好地表征采样特征与标准正态分布之间的相似度,本发明实施例提供的正则化项可采用单调递减函数,诸如采用KL(Kullback Leibler,散度约束)散度函数,本发明实施例以KL散度函数为例示意了一种正则化项,具体可如下所示:
其中,Lkl表示正则化损失值,N表示样本图像的总数量,∑i表示第i个样本图像的置信度,μi表示第i个样本图像的预测结果。
另外,本发明实施例示例性列举了确定上述分类损失函数的方式,首先根据预测结果确定分类损失函数的第一分量,在一种实施方式中,第一分量可以采用诸如softmax损失函数、am-softmax损失函数,ArcFace损失函数和triplet损失函数等任意一种损失函数;然后根据采样特征确定分类损失函数的第二分量,在另一种实施方式中,第二分量可以采用诸如softmax损失函数、am-softmax损失函数,ArcFace损失函数和triplet损失函数等任意一种损失函数,其中,第一分量采用的损失函数和第二分量采用的损失函数可以相同也可不同;最后根据第一分量和/或第二分量,确定分类损失函数。
在一种实施方式中,经上述方式确定目标损失函数中的分类损失函数和正则化项,从而根据目标损失函数、预测结果和置信度计算目标损失值,本发明实施例示例性提供了根据预测结果、置信度和目标损失函数确定目标损失值的实现方式,首先根据预测结果、置信度和标签,确定分类损失函数对应的分类损失值;然后根据预测结果和置信度计算正则化损失值;最后将分类损失值和正则化损失值的加权和,确定为目标损失值,具体可参见如下方式:
方式一:根据预测结果确定第一分量,并将第一分量作为第一分类损失函数。在实际应用中,第一分类损失函数可基于实际情况进行选择,诸如softmax损失函数、am-softmax损失函数,ArcFace损失函数和triplet损失函数等任意一种损失函数,以softmax损失函数为例,第一分类损失函数可如下所示:
其中,L1表示根据预测结果确定的第一分类损失值,μ为预测结果,d表示特征总维度,μj表示第j个特征维度对应的预测结果,μk表示第k个特征维度对应的预测结果,yi表示第i个样本图像的标签。
在上述分类损失函数的基础上,本发明实施例提供了一种上述根据预测结果、置信度和目标损失函数,确定目标损失值的实现方式,如下步骤1.1至步骤1.3所示:
步骤1.1,根据预测结果和标签,确定第一分类损失函数对应的第一分类损失值。在具体实现时,可将样本图像的预测结果和标签带入至上述根据预测结果确定的第一分类损失函数中即可得到第一分类损失值L1
步骤1.2,根据正则化项得到正则化损失值,正则化损失值根据预测结果和置信度计算得到。在具体实现时,可将预测结果和置信度带入至上述KL散度函数,即可得到正则化损失值Lkl
步骤1.3,将第一分类损失值和正则化损失值的加权和,确定为目标损失值。在实际应用中,可以分别设置第一分类损失函数和正则化损失函数的权重,例如,将正则化损失值的权重设置为γ,其中,权重γ可采用0至1中任意数值,则目标损失值为:L=L1+γLkl
方式二:根据采样特征确定第二分量,并将第二分量作为第二分类损失函数。在一种实施方式中,可以选择与计算方式一相同的损失函数,以softmax损失函数为例,第二分类损失函数可如下所示:
其中,L2表示根据采样特征确定的第二分类损失值,g(μ,∑)为采样特征,d表示特征总维度,g(μ,∑)j表示第j个特征维度对应的采样特征,g(μ,∑)k表示第k个特征维度对应的采样特征,yi表示第i个样本图像的标签。
在上述第二分类损失函数的基础上,本发明实施例提供了另一种上述根据预测结果、置信度和目标损失函数,确定目标损失值的实现方式,如下步骤2.1至步骤2.3所示:
步骤2.1,根据采样特征和标签,确定第二分类损失函数对应的第二分类损失值。在具体实现时,可将采样特征和标签带入至上述第二分类损失函数中即可得到第二分类损失值L2
步骤2.2,根据正则化项得到正则化损失值,正则化损失值根据预测结果和置信度计算得到。正则化损失值的计算方法可参见前述步骤1.2,此处不再赘述。
步骤2.3,将第二分类损失值和正则化损失值的加权和,确定为目标损失值。在实际应用中,可以分别设置第二分类损失函数和正则化损失函数的权重,例如,可以将正则化损失值的权重设置为γ,得到目标损失值L=L2+γLkl
方式三:根据预测结果确定第三分类损失函数的第一分量,根据采样特征确定第三分类损失函数的第二分量,根据第一分量和第二分量的加权和,确定第三分类损失函数。其中,第一分量可以采用前述方式一提供的L1(μ,y),第二分量可以采用前述方式二提供的L2(g(μ,∑),y),并为第一分量和第二分量配置权重β,其中,β可以基于实际情况进行选择,例如设置β=1,通过计算第一分量和第二分量进行加权求和,得到第三分类损失函数Lrecog=L1(μ,y)+βL2(g(μ,∑),y)。
在上述分类损失函数的基础上,本发明实施例提供了一种上述步骤S306的实现方式,如下步骤3.1至步骤3.4所示:
步骤3.1,根据预测结果和标签,确定第三分类损失函数中第一分量对应的第一分类损失值。在具体实现时,将样本图像的预测结果和标签带入至上述根据预测结果确定的第一分量中即可得到第一分类损失值L1
步骤3.2,根据采样特征和标签,确定第三分类损失函数中第二分量对应的第二分类损失值。在具体实现时,将采样特征和标签带入至上述根据采样特征确定的第二分量中即可得到第二分类损失值L2。在另一种实施方式中,可以第一分量和第二分量可采用不同损失函数,例如第一分量采用softmax损失函数而第二分量采用ArcFace损失函数。
步骤3.3,根据正则化项得到正则化损失值,正则化损失值根据预测结果和置信度计算得到。正则化损失值的计算方法可参见前述步骤1.2,此处不再赘述。
步骤3.4,将第一分类损失值、第二分类损失值和正则化损失值的加权和,确定为目标损失值。在实际应用中,可以分别设置第一分类损失值和第二分类损失值的权重,得到第三分类损失值Lrecog,并设置第三分类损失值Lrecog和正则化损失函数Lkl的权重,例如,将正则化损失值的权重设置为γ,得到目标损失值为:L=Lrecog+γLkl
本发明实施例通过上述方式计算目标损失值,当样本图像的置信度(也即,方差)较大时,可以认为该样本图像为携带有较多噪声的脏数据,例如样本图像模糊、图像标签标注错误、样本图像不包含目标对象或仅包含部分目标对象,此时计算的目标损失值将较小,从而在利用该目标损失值对分类模型进行训练时可以有效减少该样本图像对分类模型的危害,而当样本图像的置信度较小时,可以认为该样本图像为噪声较少的干净数据,此时计算的目标损失值将较大,从而在利用该目标损失值对分类模型进行训练时可以有效学习该图像样本的特征与标签之间的映射关系,相较于现有技术中对不断对脏数据进行清洗的方法,本发明实施例通过上述方法可以有效提高训练后的神经网络的性能。
综上所述,本发明实施例在分类模型中增加第一全连接子网络的分支,并通过第一全连接子网络生成的方差表征预测结果的置信度,另外由于目标损失函数受方差的影响,当样本图像携带有较多噪声时,计算得到的方法较大,并得到较小的目标损失值,从而在利用目标损失值对分类模型进行训练时,可以有效减少携带噪声的样本图像对分类模型的负面影响,进而有效提高分类模型的性能。
实施例三
在前述实施例二提供的分类模型的训练方法的基础上,本发明实施例提供了一种图像处理处理方法,该方法应用经前述实施例训练得到的分类模型对图像进行处理,参见图5所示的一种图像处理方法的流程示意图,该方法主要包括以下步骤S502至步骤S504:
步骤S502,获取待处理图像。训练后的分类模型可用于完成指定任务,如果训练后的分类模型用于人脸识别,则待处理图像中可能包含有人像或人脸等。
步骤S504,通过预设分类模型对待处理图像进行处理,得到待处理图像对应的预测结果和预测结果对应的置信度。其中,预设分类模型是采用如前述实施例二提供的分类模型的训练方法训练得到的。在实际应用中,如果训练后的分类模型用于人脸识别,则将待处理图像输入至训练后的分类模型后,分类模型将通过特征提取网络提取待处理图像中的人脸特征,将提取的特征分别输入至第一全连接子网络和第二全连接子网络,以分别计算人脸特征对应的预测结果和预测结果对应的置信度,当置信度较大时表明该预测结果可信度较低,当置信度较小时表明该预测结果可信度较高。
本发明实施例通过前述实施例提供的分类模型对待处理图像进行处理,输出预测结果和其对应的置信度,相较于现有技术仅能获取图像处理的预测结果,本发明实施例可以利用分类模型直接得到预测结果的置信度,从而利用置信度对预测结果的可信程度进行评价,有利用用户根据分类模型输出的置信度作出是否采信或应用预测结果的决策。
另外,为了进一步便于用于做出是否应该预测结果的决策,本发明实施例可以在预测结果对应的置信度低于预设置信度阈值时,输出预测结果和预测结果对应的提示消息。其中,提醒消息用于提示用户该预测结果的可信程度较低,在实际应用中,可以采用高亮输出预测结果、弹出无法识别图像的提醒框、输出无法识别图像的语音信息等多种提示方式,具体可基于实际情况采用所需的提示方式,从而以更醒目的方式告知用户该预测结果的可信程度。在另一种实施方式中,也可以不输出预测结果,或不输出预测结果但输出预测结果对应的提实消息,以起到提示用户该预测结果不可信的目的。
实施例四
对于实施例二中所提供的分类模型的训练方法,本发明实施例提供了一种分类模型的训练装置,参见图6所示的一种分类模型的训练装置的结构示意图,该装置包括以下模块:
特征提取模块602,用于将样本图像输入至分类模型,得到样本图像对应的特征图。
输出模块604,用于将特征图输入至全连接网络,得到分类模型输出的预测结果和预测结果对应的置信度。
损失计算模块606,用于根据预测结果和置信度,确定目标损失值。
训练模块608,用于根据目标损失值,更新分类模型的参数。
本发明实施例提供的分类模型的训练装置,利用全连接网络得到分类模型输出的预测结果和预测结果对应的置信度,并基于预测结果、置信度、目标损失函数和标签计算得到的目标损失值对分类模型进行训练,不仅可以使训练后的分类模型输出准确率较高的预测结果,还可以使训练后的分类模型输出的预测结果中携带置信度,从而通过携带的置信度体现预测结果的可信程度。
在一种实施方式中,上述全连接网络包括与特征提取网络连接的第一全连接子网络和第二全连接子网络;第一全连接子网络和第二全连接子网络包括多个全连接层;上述输出模块604还用于:将特征图输入至第一全连接子网络,确定分类模型对样本图像的预测结果;将特征图输入至第二全连接子网络,确定预测结果所对应的置信度。
在一种实施方式中,上述损失计算模块606,还用于根据分类损失函数和正则化项确定目标损失函数,其中,正则化项用于表征采样特征与预设分布函数之间的相似度,采样特征通过对预测结果和置信度进行加权求和确定;根据预测结果、置信度和目标损失函数,确定目标损失值。
在一种实施方式中,上述分类模型的训练装置还包括加权计算模块,用于:在预设分布函数中进行随机采样,得到第一系数;基于第一系数,对预测结果和置信度进行加权求和,得到的采样特征。
在一种实施方式中,上述正则化项为KL散度函数,根据预测结果和置信度计算得到。
在一种实施方式中,上述分类模型的训练装置还包括分类函数确定模块,用于根据预测结果确定分类损失函数的第一分量;根据采样特征确定分类损失函数的第二分量;根据第一分量和/或第二分量,确定分类损失函数。
在一种实施方式中,特征图对应于样本图像的特征分布。
在一种实施方式中,上述输出模块604还用于:通过全连接网络对样本图像的特征分布进行处理,得到样本图像的特征分布对应的预测分类分布;将预测分类分布的均值确定为预测结果。
在一种实施方式中,上述输出模块604还用于:根据预测结果的方差,确定置信度。
对于实施例三中所提供的图像处理方法,本发明实施例提供了一种图像处理装置,参见图7所示的一种图像处理装置的结构示意图,该装置包括以下模块:
图像获取模块702,用于获取待处理图像。
图像处理模块704,用于通过预设分类模型对待处理图像进行处理,得到待处理图像对应的预测结果和预测结果对应的置信度,其中,预设分类模型为采用如实施例二提供的分类模型的训练方法训练得到。
本发明实施例提供的上述图像处理装置,在对待处理图像进行处理后可以得到预测结果和其对应的置信度,相较于现有技术仅能获取图像处理的预测结果,本发明实施例可以利用分类模型直接得到预测结果的置信度,从而利用置信度对预测结果的可信程度进行评价,有利用用户根据分类模型输出的置信度作出是否采信或应用预测结果的决策。
在一种实施方式中,上述图像处理装置还包括提示模块,用于当预测结果对应的置信度低于预设置信度阈值,输出预测结果和预测结果对应的提示消息。
本实施例所提供的装置,其实现原理及产生的技术效果和前述实施例相同,为简要描述,装置实施例部分未提及之处,可参考前述方法实施例中相应内容。
实施例五
本发明实施例所提供的分类模型的训练方法、图像处理方法及装置的计算机程序产品,包括存储了程序代码的计算机可读存储介质,所述程序代码包括的指令可用于执行前面方法实施例中所述的方法,具体实现可参见方法实施例,在此不再赘述。
另外,在本发明实施例的描述中,除非另有明确的规定和限定,术语“安装”、“相连”、“连接”应做广义理解,例如,可以是固定连接,也可以是可拆卸连接,或一体地连接;可以是机械连接,也可以是电连接;可以是直接相连,也可以通过中间媒介间接相连,可以是两个元件内部的连通。对于本领域的普通技术人员而言,可以具体情况理解上述术语在本发明中的具体含义。
所述功能如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
在本发明的描述中,需要说明的是,术语“中心”、“上”、“下”、“左”、“右”、“竖直”、“水平”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。此外,术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性。
最后应说明的是:以上所述实施例,仅为本发明的具体实施方式,用以说明本发明的技术方案,而非对其限制,本发明的保护范围并不局限于此,尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,其依然可以对前述实施例所记载的技术方案进行修改或可轻易想到变化,或者对其中部分技术特征进行等同替换;而这些修改、变化或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的精神和范围,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以所述权利要求的保护范围为准。

Claims (13)

1.一种分类模型的训练方法,其特征在于,所述分类模型包括特征提取网络和全连接网络,所述方法包括:
将样本图像输入至所述分类模型,得到所述样本图像对应的特征图;
将所述特征图输入至所述全连接网络,得到所述分类模型输出的预测结果和所述预测结果对应的置信度;
根据所述预测结果和所述置信度,确定目标损失值;
根据所述目标损失值,更新所述分类模型的参数;
所述根据所述预测结果和所述置信度,确定目标损失值,包括:
根据分类损失函数和正则化项确定目标损失函数,其中,所述正则化项用于表征采样特征与预设分布函数之间的相似度,所述采样特征通过对所述预测结果和所述置信度进行加权求和确定;
根据所述预测结果、所述置信度和所述目标损失函数,确定目标损失值;
所述方法还包括:
根据所述预测结果确定所述分类损失函数的第一分量;
根据所述采样特征,确定所述分类损失函数的第二分量;
根据所述第一分量和/或所述第二分量,确定所述分类损失函数。
2.根据权利要求1所述的方法,其特征在于,所述全连接网络包括与所述特征提取网络连接的第一全连接子网络和第二全连接子网络,所述第一全连接子网络和第二全连接子网络包括多个全连接层;
将所述特征图输入至所述全连接网络,得到所述分类模型输出的预测结果和所述预测结果对应的置信度的步骤,包括:
将所述特征图输入至所述第一全连接子网络,确定所述分类模型对所述样本图像的所述预测结果;
将所述特征图输入至所述第二全连接子网络,确定所述预测结果所对应的置信度。
3.根据权利要求1所述的方法,其特征在于,所述方法还包括:
在所述预设分布函数中进行随机采样,得到第一系数;
基于所述第一系数,对所述预测结果和所述置信度进行加权求和,得到所述采样特征。
4.根据权利要求3所述的方法,其特征在于,所述正则化项为KL散度函数,根据所述预测结果和所述置信度计算得到。
5.根据权利要求1-4任一项所述的方法,其特征在于,所述特征图对应于所述样本图像的特征分布。
6.根据权利要求5所述的方法,其特征在于,将所述特征图输入至所述全连接网络,得到所述分类模型输出的预测结果,包括:
通过所述全连接网络对所述样本图像的特征分布进行处理,得到所述样本图像的特征分布对应的预测分类分布;
将所述预测分类分布的均值确定为所述预测结果。
7.根据权利要求6所述的方法,其特征在于,根据所述预测结果的方差,确定所述置信度。
8.一种图像处理方法,其特征在于,包括:
获取待处理图像;
通过预设分类模型对待处理图像进行处理,得到所述待处理图像对应的预测结果和所述预测结果对应的置信度,其中,所述预设分类模型采用如权利要求1-7任一项所述的分类模型的训练方法训练得到。
9.根据权利要求8所述的方法,其特征在于,所述方法还包括:
当所述预测结果对应的置信度低于预设置信度阈值,输出所述预测结果对应的提示消息。
10.一种分类模型的训练装置,其特征在于,所述分类模型包括特征提取网络和全连接网络,所述装置包括:
特征提取模块,用于将样本图像输入至所述分类模型,得到所述样本图像对应的特征图;
输出模块,用于将所述特征图输入至所述全连接网络,得到所述分类模型输出的预测结果和所述预测结果对应的置信度;
损失计算模块,用于根据所述预测结果和所述置信度,确定目标损失值;
训练模块,用于根据所述目标损失值,更新所述分类模型的参数;
所述损失计算模块还用于:
根据分类损失函数和正则化项确定目标损失函数,其中,所述正则化项用于表征采样特征与预设分布函数之间的相似度,所述采样特征通过对所述预测结果和所述置信度进行加权求和确定;
根据所述预测结果、所述置信度和所述目标损失函数,确定目标损失值;
还包括分类函数确定模块,用于:
根据所述预测结果确定所述分类损失函数的第一分量;
根据所述采样特征,确定所述分类损失函数的第二分量;
根据所述第一分量和/或所述第二分量,确定所述分类损失函数。
11.一种图像处理装置,其特征在于,包括:
图像获取模块,用于获取待处理图像;
图像处理模块,用于通过预设分类模型对待处理图像进行处理,得到所述待处理图像对应的预测结果和所述预测结果对应的置信度,其中,所述预设分类模型为采用如权利要求1-7任一项所述的分类模型的训练方法训练得到。
12.一种电子设备,其特征在于,包括处理器和存储器;
所述存储器上存储有计算机程序,所述计算机程序在被所述处理器运行时执行如权利要求1至7任一项所述的方法,或执行如权利要求8-9任一项所述的方法。
13.一种计算机存储介质,其特征在于,用于储存为权利要求1至7任一项所述方法所用的计算机软件指令,或存储为权利要求8-9任一项所述方法所用的计算机软件指令。
CN202010040821.3A 2020-01-14 2020-01-14 分类模型的训练方法、图像处理方法及装置 Active CN111242222B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010040821.3A CN111242222B (zh) 2020-01-14 2020-01-14 分类模型的训练方法、图像处理方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010040821.3A CN111242222B (zh) 2020-01-14 2020-01-14 分类模型的训练方法、图像处理方法及装置

Publications (2)

Publication Number Publication Date
CN111242222A CN111242222A (zh) 2020-06-05
CN111242222B true CN111242222B (zh) 2023-12-19

Family

ID=70876552

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010040821.3A Active CN111242222B (zh) 2020-01-14 2020-01-14 分类模型的训练方法、图像处理方法及装置

Country Status (1)

Country Link
CN (1) CN111242222B (zh)

Families Citing this family (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11650351B2 (en) * 2020-02-12 2023-05-16 Nec Corporation Semi-supervised deep model for turbulence forecasting
CN112308153B (zh) * 2020-11-02 2023-11-24 创新奇智(广州)科技有限公司 一种烟火检测方法和装置
CN112200173B (zh) * 2020-12-08 2021-03-23 北京沃东天骏信息技术有限公司 多网络模型训练方法、图像标注方法和人脸图像识别方法
CN114612373A (zh) * 2020-12-09 2022-06-10 航天信息股份有限公司 一种图像识别方法及服务器
CN112579587B (zh) * 2020-12-29 2024-07-02 纽扣互联(北京)科技有限公司 数据清洗方法及装置、设备和存储介质
CN112734008A (zh) * 2020-12-31 2021-04-30 平安科技(深圳)有限公司 分类网络构建方法以及基于分类网络的分类方法
CN113610766A (zh) * 2021-07-12 2021-11-05 北京阅视智能技术有限责任公司 显微图像分析方法、装置、存储介质及电子设备
CN113869353A (zh) * 2021-08-16 2021-12-31 深延科技(北京)有限公司 模型训练方法、老虎关键点检测方法及相关装置
CN114255381B (zh) * 2021-12-23 2023-05-12 北京瑞莱智慧科技有限公司 图像识别模型的训练方法、图像识别方法、装置及介质

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108846340A (zh) * 2018-06-05 2018-11-20 腾讯科技(深圳)有限公司 人脸识别方法、装置及分类模型训练方法、装置、存储介质和计算机设备
CN109102024A (zh) * 2018-08-14 2018-12-28 中山大学 一种用于物体精细识别的层次语义嵌入模型及其实现方法
CN109299716A (zh) * 2018-08-07 2019-02-01 北京市商汤科技开发有限公司 神经网络的训练方法、图像分割方法、装置、设备及介质
CN110070067A (zh) * 2019-04-29 2019-07-30 北京金山云网络技术有限公司 视频分类方法及其模型的训练方法、装置和电子设备
CN110321952A (zh) * 2019-07-02 2019-10-11 腾讯医疗健康(深圳)有限公司 一种图像分类模型的训练方法及相关设备
CN110647916A (zh) * 2019-08-23 2020-01-03 苏宁云计算有限公司 基于卷积神经网络的色情图片识别方法及装置

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US9158971B2 (en) * 2014-03-03 2015-10-13 Xerox Corporation Self-learning object detectors for unlabeled videos using multi-task learning

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108846340A (zh) * 2018-06-05 2018-11-20 腾讯科技(深圳)有限公司 人脸识别方法、装置及分类模型训练方法、装置、存储介质和计算机设备
CN109299716A (zh) * 2018-08-07 2019-02-01 北京市商汤科技开发有限公司 神经网络的训练方法、图像分割方法、装置、设备及介质
CN109102024A (zh) * 2018-08-14 2018-12-28 中山大学 一种用于物体精细识别的层次语义嵌入模型及其实现方法
CN110070067A (zh) * 2019-04-29 2019-07-30 北京金山云网络技术有限公司 视频分类方法及其模型的训练方法、装置和电子设备
CN110321952A (zh) * 2019-07-02 2019-10-11 腾讯医疗健康(深圳)有限公司 一种图像分类模型的训练方法及相关设备
CN110647916A (zh) * 2019-08-23 2020-01-03 苏宁云计算有限公司 基于卷积神经网络的色情图片识别方法及装置

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
Learning Cross-Modal Aligned Representation With Graph Embedding;YOUCAI ZHANG, JIAYAN CAO,XIAODONG GU;《IEEE Access》;第77321-77333页 *

Also Published As

Publication number Publication date
CN111242222A (zh) 2020-06-05

Similar Documents

Publication Publication Date Title
CN111242222B (zh) 分类模型的训练方法、图像处理方法及装置
JP2022141931A (ja) 生体検出モデルのトレーニング方法及び装置、生体検出の方法及び装置、電子機器、記憶媒体、並びにコンピュータプログラム
CN111414946B (zh) 基于人工智能的医疗影像的噪声数据识别方法和相关装置
KR20180057096A (ko) 표정 인식과 트레이닝을 수행하는 방법 및 장치
CN112395979B (zh) 基于图像的健康状态识别方法、装置、设备及存储介质
CN112889108A (zh) 使用视听数据进行说话分类
CN111639744A (zh) 学生模型的训练方法、装置及电子设备
CN112052761A (zh) 一种对抗人脸图像的生成方法和装置
CN111401521B (zh) 神经网络模型训练方法及装置、图像识别方法及装置
CN112418195B (zh) 一种人脸关键点检测方法、装置、电子设备及存储介质
CN111160555A (zh) 基于神经网络的处理方法、装置及电子设备
CN111382791B (zh) 深度学习任务处理方法、图像识别任务处理方法和装置
CN111414875A (zh) 基于深度回归森林的三维点云头部姿态估计系统
CN113743426A (zh) 一种训练方法、装置、设备以及计算机可读存储介质
CN109214616B (zh) 一种信息处理装置、系统和方法
CN109101984B (zh) 一种基于卷积神经网络的图像识别方法及装置
CN111523586A (zh) 一种基于噪声可知的全网络监督目标检测方法
US11915419B1 (en) Auto-normalization for machine learning
CN113128526B (zh) 图像识别方法、装置、电子设备和计算机可读存储介质
CN117112766A (zh) 视觉对话方法、装置、电子设备和计算机可读存储介质
WO2023154986A1 (en) Method, system, and device using a generative model for image segmentation
CN111445545A (zh) 一种文本转贴图方法、装置、存储介质及电子设备
CN111898465B (zh) 一种人脸识别模型的获取方法和装置
CN117475187A (zh) 一种训练图像分类模型的方法、装置、设备及存储介质
CN112070022A (zh) 人脸图像识别方法、装置、电子设备和计算机可读介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant